diff --git a/.devops/main-cuda.Dockerfile b/.devops/main-cuda.Dockerfile index c2bf0fbd1c6..7a21fc4e3db 100644 --- a/.devops/main-cuda.Dockerfile +++ b/.devops/main-cuda.Dockerfile @@ -25,7 +25,7 @@ ENV LD_LIBRARY_PATH /usr/local/cuda-${CUDA_MAIN_VERSION}/compat:$LD_LIBRARY_PATH COPY .. . # Enable cuBLAS -RUN make base.en CMAKE_ARGS="-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES='75;80;86;90'" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_CUDA=1 -DCMAKE_CUDA_ARCHITECTURES='75;80;86;90'" RUN find /app/build -name "*.o" -delete && \ find /app/build -name "*.a" -delete && \ diff --git a/.devops/main-intel.Dockerfile b/.devops/main-intel.Dockerfile index 1b5859715d4..a0c04ad34ad 100644 --- a/.devops/main-intel.Dockerfile +++ b/.devops/main-intel.Dockerfile @@ -1,6 +1,6 @@ -ARG ONEAPI_VERSION=2025.1.1-0-devel-ubuntu24.04 +ARG ONEAPI_VERSION=2025.3.3-0-devel-ubuntu24.04 -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS build +FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS build WORKDIR /app RUN apt-get update && \ @@ -10,13 +10,14 @@ RUN apt-get update && \ COPY .. . # Enable SYCL ARG GGML_SYCL_F16=OFF -RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN \ + if [ "${GGML_SYCL_F16}" = "ON" ]; then \ echo "GGML_SYCL_F16 is set" \ && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON"; \ fi && \ make base.en CMAKE_ARGS="-DGGML_SYCL=1 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx ${OPT_SYCL_F16}" -FROM intel/oneapi-basekit:$ONEAPI_VERSION AS runtime +FROM intel/deep-learning-essentials:$ONEAPI_VERSION AS runtime WORKDIR /app RUN apt-get update && \ diff --git a/.devops/main-musa.Dockerfile b/.devops/main-musa.Dockerfile index 026791e3f89..c68367830f1 100644 --- a/.devops/main-musa.Dockerfile +++ b/.devops/main-musa.Dockerfile @@ -16,7 +16,7 @@ RUN apt-get update && \ COPY .. . # Enable muBLAS -RUN make base.en CMAKE_ARGS="-DGGML_MUSA=1" +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_MUSA=1" RUN find /app/build -name "*.o" -delete && \ find /app/build -name "*.a" -delete && \ diff --git a/.devops/main-vulkan.Dockerfile b/.devops/main-vulkan.Dockerfile new file mode 100644 index 00000000000..16ee19dc689 --- /dev/null +++ b/.devops/main-vulkan.Dockerfile @@ -0,0 +1,20 @@ +FROM ubuntu:24.04 AS build +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y build-essential wget cmake git libvulkan-dev spirv-headers glslc \ + && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +COPY .. . +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en CMAKE_ARGS="-DGGML_VULKAN=1" + +FROM ubuntu:24.04 AS runtime +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y curl ffmpeg libsdl2-dev wget cmake git libvulkan1 mesa-vulkan-drivers \ + && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* + +COPY --from=build /app /app +ENV PATH=/app/build/bin:$PATH +ENTRYPOINT [ "bash", "-c" ] diff --git a/.devops/main.Dockerfile b/.devops/main.Dockerfile index e1eb9b33700..d0e809f4e13 100644 --- a/.devops/main.Dockerfile +++ b/.devops/main.Dockerfile @@ -6,7 +6,7 @@ RUN apt-get update && \ && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* COPY .. . -RUN make base.en +RUN --mount=type=secret,id=HF_TOKEN,required=false,env=HF_TOKEN make base.en FROM ubuntu:22.04 AS runtime WORKDIR /app diff --git a/.github/actions/ccache-clear/action.yml b/.github/actions/ccache-clear/action.yml new file mode 100644 index 00000000000..d38587efaf8 --- /dev/null +++ b/.github/actions/ccache-clear/action.yml @@ -0,0 +1,22 @@ +name: "ccache-clear" +description: "Delete all GitHub Actions caches matching a key prefix" +inputs: + key: + description: "Cache key prefix to match and delete" + required: true + +runs: + using: "composite" + steps: + - name: Clear caches + shell: bash + run: | + CACHES=$(gh cache list --key "ccache-${{ inputs.key }}" --json id,key --jq '.[] | "\(.id) \(.key)"' 2>/dev/null) + if [ -z "$CACHES" ]; then + echo "No caches found with key prefix: ${{ inputs.key }}" + exit 0 + fi + while read -r id key; do + echo "Deleting cache: $id ($key)" + gh cache delete "$id" + done <<< "$CACHES" diff --git a/.github/workflows/bindings-go.yml b/.github/workflows/bindings-go.yml index ff420f2b636..91f869e99cf 100644 --- a/.github/workflows/bindings-go.yml +++ b/.github/workflows/bindings-go.yml @@ -3,20 +3,20 @@ on: push: paths: - bindings/go/** - - whisper.h + - include/whisper.h pull_request: paths: - bindings/go/** - - whisper.h + - include/whisper.h jobs: ubuntu-22: runs-on: ubuntu-22.04 steps: - - uses: actions/setup-go@v5 + - uses: actions/setup-go@4a3601121dd01d1626a1e23e37211e3254c1c06c # v6 with: go-version: '^1.23' - - uses: actions/checkout@v4 + - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - run: | cd bindings/go make test diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml index 680862fb764..8cdb7a810f7 100644 --- a/.github/workflows/bindings-ruby.yml +++ b/.github/workflows/bindings-ruby.yml @@ -4,8 +4,19 @@ on: push: branches: - master + paths: + - bindings/ruby/** + - include/whisper.h + - examples/common-whisper.h + - ggml/include/ggml.h + pull_request: types: [opened, synchronize, reopened] + paths: + - bindings/ruby/** + - include/whisper.h + - examples/common-whisper.h + - ggml/include/ggml.h jobs: ubuntu-22: @@ -14,8 +25,8 @@ jobs: run: working-directory: bindings/ruby steps: - - uses: ruby/setup-ruby@v1 + - uses: ruby/setup-ruby@afeafc3d1ab54a631816aba4c914a0081c12ff2f # v1.310.0 with: - ruby-version: '3.2' - - uses: actions/checkout@v4 + ruby-version: '3.3' + - uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - run: rake test diff --git a/.github/workflows/build-android.yml b/.github/workflows/build-android.yml new file mode 100644 index 00000000000..571c35872c8 --- /dev/null +++ b/.github/workflows/build-android.yml @@ -0,0 +1,80 @@ +name: CI (android) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-android.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.java'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + android: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + with: + path: whisper + + - name: Install Java + uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 + with: + distribution: zulu + java-version: 21 + + - name: Setup Android SDK + uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1 + + - name: Build + run: | + cd whisper/examples/whisper.android + ./gradlew assembleRelease --no-daemon + + - name: Build with external ggml + run: | + export PATH_TO_GGML=$PWD/ggml + cd whisper/examples/whisper.android + ./gradlew assembleRelease --no-daemon + + android_java: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: set up JDK 11 + uses: actions/setup-java@be666c2fcd27ec809703dec50e508c2fdc7f6654 # v5 + with: + java-version: '11' + distribution: 'temurin' + cache: gradle + + - name: Setup Android SDK + uses: android-actions/setup-android@40fd30fb8d7440372e1316f5d1809ec01dcd3699 # v4.0.1 + with: + cmdline-tools-version: 9.0 + + - name: Build + run: | + cd examples/whisper.android.java + chmod +x ./gradlew + ./gradlew assembleRelease diff --git a/.github/workflows/build-binaries.yml b/.github/workflows/build-binaries.yml index aec894d595e..7f9c29d5324 100644 --- a/.github/workflows/build-binaries.yml +++ b/.github/workflows/build-binaries.yml @@ -15,7 +15,8 @@ permissions: contents: write env: - CUDA_ARCHITECTURES: "75;80;86;89" + # RTX 20-50 (Turing through Blackwell). sm_120 requires CUDA Toolkit >= 12.8. + CUDA_ARCHITECTURES: "75;80;86;89;120" jobs: build-macos-arm64: @@ -190,21 +191,21 @@ jobs: - name: Install Ninja run: choco install ninja -y - - name: Install CUDA Toolkit 12.4.0 + - name: Install CUDA Toolkit 12.9.1 run: | - $CUDA_VERSION = "12.4.0" + $CUDA_VERSION = "12.9.1" $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" - # Component versions for CUDA 12.4.0 - $CUDART_VER = "12.4.127" - $NVCC_VER = "12.4.131" - $NVRTC_VER = "12.4.127" - $CUBLAS_VER = "12.4.5.8" - $NVTX_VER = "12.4.127" - $PROFILER_VER = "12.4.127" - $VS_VER = "12.4.127" - $CCCL_VER = "12.4.127" + # Component versions for CUDA 12.9.1 + $CUDART_VER = "12.9.79" + $NVCC_VER = "12.9.86" + $NVRTC_VER = "12.9.86" + $CUBLAS_VER = "12.9.1.4" + $NVTX_VER = "12.9.79" + $PROFILER_VER = "12.9.79" + $VS_VER = "12.9.79" + $CCCL_VER = "12.9.27" # Create CUDA toolkit directory New-Item -ItemType Directory -Force -Path $CUDA_TOOLKIT_DIR @@ -400,7 +401,7 @@ jobs: sudo apt-get update sudo apt-get install -y build-essential cmake wget - - name: Install CUDA Toolkit 12.4 + - name: Install CUDA Toolkit 12.9 run: | # Download and install CUDA keyring wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb @@ -408,16 +409,16 @@ jobs: sudo apt-get update # Install minimal CUDA toolkit (compiler and libraries only, no driver) - sudo apt-get install -y cuda-toolkit-12-4 + sudo apt-get install -y cuda-toolkit-12-9 # Set environment variables - echo "/usr/local/cuda-12.4/bin" >> $GITHUB_PATH - echo "CUDA_PATH=/usr/local/cuda-12.4" >> $GITHUB_ENV - echo "LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH" >> $GITHUB_ENV + echo "/usr/local/cuda-12.9/bin" >> $GITHUB_PATH + echo "CUDA_PATH=/usr/local/cuda-12.9" >> $GITHUB_ENV + echo "LD_LIBRARY_PATH=/usr/local/cuda-12.9/lib64:$LD_LIBRARY_PATH" >> $GITHUB_ENV - name: Verify CUDA installation run: | - export PATH=/usr/local/cuda-12.4/bin:$PATH + export PATH=/usr/local/cuda-12.9/bin:$PATH nvcc --version - name: Setup ccache @@ -427,8 +428,8 @@ jobs: - name: Build whisper.cpp with CUDA run: | - export PATH=/usr/local/cuda-12.4/bin:$PATH - export CUDA_PATH=/usr/local/cuda-12.4 + export PATH=/usr/local/cuda-12.9/bin:$PATH + export CUDA_PATH=/usr/local/cuda-12.9 cmake -B build \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_C_COMPILER_LAUNCHER=ccache \ @@ -436,7 +437,7 @@ jobs: -DBUILD_SHARED_LIBS=OFF \ -DGGML_NATIVE=OFF \ -DGGML_CUDA=ON \ - -DCMAKE_CUDA_COMPILER=/usr/local/cuda-12.4/bin/nvcc \ + -DCMAKE_CUDA_COMPILER=/usr/local/cuda-12.9/bin/nvcc \ -DCMAKE_CUDA_ARCHITECTURES="${{ env.CUDA_ARCHITECTURES }}" cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build-clang.yml b/.github/workflows/build-clang.yml new file mode 100644 index 00000000000..20b7fec6494 --- /dev/null +++ b/.github/workflows/build-clang.yml @@ -0,0 +1,121 @@ +name: CI (clang) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-clang.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + ubuntu_image: "ubuntu:22.04" + +jobs: + ubuntu-22-clang: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + #arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + # TODO: arm/v7 disabled due to clang bug + # https://github.com/ggerganov/whisper.cpp/actions/runs/9657764109/job/26637633042?pr=2256#step:4:1990 + arch: [linux/amd64, linux/ppc64le] + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Set CCACHE_DIR + run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: clang-${{ matrix.arch }}-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 + + - name: Build ${{ matrix.arch }} + run: | + docker run --platform ${{ matrix.arch }} --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${CCACHE_DIR}:${CCACHE_DIR} \ + -e CCACHE_DIR=${CCACHE_DIR} \ + -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' + set -e + export DEBIAN_FRONTEND=noninteractive + sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + + apt update + apt install -y clang build-essential cmake libsdl2-dev git ccache + cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + make + ctest -L gh --output-on-failure' + + ubuntu-22-clang-arm64: + runs-on: ubuntu-22.04-arm + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: clang-arm64-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y clang build-essential cmake libsdl2-dev git + + - name: Build and Test + run: | + cmake . -DWHISPER_SDL2=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv8-a + make + ctest -L gh --output-on-failure diff --git a/.github/workflows/build-coreml.yml b/.github/workflows/build-coreml.yml new file mode 100644 index 00000000000..8dedd7819ed --- /dev/null +++ b/.github/workflows/build-coreml.yml @@ -0,0 +1,65 @@ +name: CI (coreml) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + tags: + - 'v*' + paths: ['.github/workflows/build-coreml.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal'] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + +jobs: + coreml-base-en: + runs-on: macos-latest + + steps: + - name: Checkout with full history + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + with: + fetch-depth: 0 + + - name: Set environment variables + id: set_vars + run: | + BUILD_NUMBER=$(git rev-list --count HEAD) + SHORT_HASH=$(git rev-parse --short=7 HEAD) + if [[ "${{ github.ref_type }}" == "tag" ]]; then + TAG_NAME="${{ github.ref_name }}" + elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + TAG_NAME="b${BUILD_NUMBER}" + else + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" + fi + echo "MODEL_NAME=base.en" >> $GITHUB_ENV + echo "GEN_MODEL_NAME=whisper-${TAG_NAME}-ggml-base.en-encoder.mlmodelc" >> $GITHUB_ENV + + - name: Download model + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + ./models/download-ggml-model.sh ${{ env.MODEL_NAME }} + + - name: Generate CoreML model + run: | + python3.11 -m venv venv + source venv/bin/activate + pip install ane_transformers openai-whisper coremltools + ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }} diff --git a/.github/workflows/build-cpu.yml b/.github/workflows/build-cpu.yml new file mode 100644 index 00000000000..e2b74881ea5 --- /dev/null +++ b/.github/workflows/build-cpu.yml @@ -0,0 +1,173 @@ +name: CI (cpu) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-cpu.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +# TODO: simplify the following jobs using a matrix +jobs: + ggml-ci-x64-cpu-low-perf: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-x64-cpu-low-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-arm64-cpu-low-perf: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-arm64-cpu-low-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-x64-cpu-high-perf: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-x64-cpu-high-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-arm64-cpu-high-perf: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-arm64-cpu-high-perf + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt + + ggml-ci-arm64-cpu-high-perf-sve: + runs-on: ubuntu-22.04-arm + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ggml-ci-arm64-cpu-high-perf-sve + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + id: depends + run: | + sudo apt-get update + sudo apt-get install build-essential libcurl4-openssl-dev + + - name: Test + id: ggml-ci + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt diff --git a/.github/workflows/build-freebsd.yml b/.github/workflows/build-freebsd.yml new file mode 100644 index 00000000000..64e78ad62f8 --- /dev/null +++ b/.github/workflows/build-freebsd.yml @@ -0,0 +1,47 @@ +name: CI (freebsd) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-freebsd.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + freeBSD-latest: + runs-on: macos-13 + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Build + uses: cross-platform-actions/action@fe0167d8082ac584754ef3ffb567fded22642c7d # v0.27.0 + with: + operating_system: freebsd + version: '14.2' + run: | + sudo pkg update + sudo pkg install -y gmake sdl2 cmake git + cmake -B build + cmake --build build --config Release diff --git a/.github/workflows/build-gcc.yml b/.github/workflows/build-gcc.yml new file mode 100644 index 00000000000..53c1b2d783c --- /dev/null +++ b/.github/workflows/build-gcc.yml @@ -0,0 +1,167 @@ +name: CI (gcc) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-gcc.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +env: + ubuntu_image: "ubuntu:22.04" + +jobs: + ubuntu-22-gcc: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + arch: [linux/amd64, linux/ppc64le] + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Set CCACHE_DIR + run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: gcc-${{ matrix.arch }}-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 + + - name: Build ${{ matrix.arch }} + run: | + docker run --platform ${{ matrix.arch }} --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${CCACHE_DIR}:${CCACHE_DIR} \ + -e CCACHE_DIR=${CCACHE_DIR} \ + -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' + set -e + export DEBIAN_FRONTEND=noninteractive + sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + + apt update + apt install -y build-essential cmake libsdl2-dev git ccache + cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + make + ctest -L gh --output-on-failure' + + ubuntu-22-gcc-arm64: + runs-on: ubuntu-22.04-arm + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: gcc-arm64-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake libsdl2-dev git + + - name: Configure CMake + run: | + cmake . \ + -DWHISPER_SDL2=ON \ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv8-a + + - name: Build and Test + run: | + make + ctest -L gh --output-on-failure + + ubuntu-22-gcc-arm-v7: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + build: [Debug, Release] + arch: [linux/arm/v7] + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Set CCACHE_DIR + run: echo "CCACHE_DIR=${{ runner.temp }}/ccache" >> $GITHUB_ENV + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: gcc-${{ matrix.arch }}-${{ matrix.build }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Set up QEMU + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 + + - name: Build ${{ matrix.arch }} + run: | + docker run --platform ${{ matrix.arch }} --rm \ + -v ${{ github.workspace }}:/workspace \ + -v ${CCACHE_DIR}:${CCACHE_DIR} \ + -e CCACHE_DIR=${CCACHE_DIR} \ + -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' + set -e + export DEBIAN_FRONTEND=noninteractive + sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list + + apt update + apt install -y build-essential cmake libsdl2-dev git ccache + cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DGGML_NATIVE=OFF \ + -DGGML_CPU_ARM_ARCH=armv7-a+fp \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + make + ctest -L gh --output-on-failure' diff --git a/.github/workflows/build-macos.yml b/.github/workflows/build-macos.yml new file mode 100644 index 00000000000..8b209e4eec8 --- /dev/null +++ b/.github/workflows/build-macos.yml @@ -0,0 +1,72 @@ +name: CI (macOS) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-macos.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + macOS-latest: + runs-on: macOS-latest + + strategy: + matrix: + destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS'] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: macos-${{ matrix.destination }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Dependencies + run: | + brew update + cmake --version + brew install sdl2 + + - name: Build + run: | + sysctl -a + cmake -B build -G Xcode \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DWHISPER_BUILD_EXAMPLES=OFF \ + -DWHISPER_BUILD_TESTS=OFF \ + -DWHISPER_BUILD_SERVER=OFF \ + -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" + cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) diff --git a/.github/workflows/build-quantize.yml b/.github/workflows/build-quantize.yml new file mode 100644 index 00000000000..1c9576af7f1 --- /dev/null +++ b/.github/workflows/build-quantize.yml @@ -0,0 +1,48 @@ +name: CI (quantize) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-quantize.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + quantize: + runs-on: ubuntu-22.04 + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: quantize-ubuntu-22 + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Test quantize + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + ./models/download-ggml-model.sh tiny.en + cmake -B build -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + cmake --build build --config Release + ./build/bin/whisper-quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0 diff --git a/.github/workflows/build-sanitize.yml b/.github/workflows/build-sanitize.yml new file mode 100644 index 00000000000..e517f7bade4 --- /dev/null +++ b/.github/workflows/build-sanitize.yml @@ -0,0 +1,82 @@ +name: CI (sanitize) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-sanitize.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' + - 'bindings/go/**' + - 'examples/addon.node/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ubuntu-22-gcc-sanitized: + runs-on: ubuntu-22.04 + + continue-on-error: true + + strategy: + fail-fast: false + matrix: + sanitizer: [ADDRESS, THREAD, UNDEFINED] + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: sanitize-${{ matrix.sanitizer }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Install dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake git + + - name: Build (undefined) + if: ${{ matrix.sanitizer == 'UNDEFINED' }} + run: | + cmake . -DCMAKE_BUILD_TYPE=Debug \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_OPENMP=OFF + make + + - name: Build + if: ${{ matrix.sanitizer == 'ADDRESS' }} + run: | + cmake . -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON + make + + - name: Build (no OpenMP) + if: ${{ matrix.sanitizer == 'THREAD' }} + run: | + cmake . -DCMAKE_BUILD_TYPE=RelWithDebInfo \ + -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ + -DGGML_OPENMP=OFF + make + + - name: Test + if: ${{ matrix.sanitizer != 'UNDEFINED' }} + run: | + ctest -L gh --output-on-failure diff --git a/.github/workflows/build-self-hosted.yml b/.github/workflows/build-self-hosted.yml new file mode 100644 index 00000000000..2286b63d6e7 --- /dev/null +++ b/.github/workflows/build-self-hosted.yml @@ -0,0 +1,116 @@ +name: CI (self-hosted) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: [ + '.github/workflows/build.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal', + '**/*.comp' + ] + + pull_request: + types: [opened, synchronize, reopened] + paths: [ + '.github/workflows/build-self-hosted.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.swift', + '**/*.m', + '**/*.mm', + '**/*.metal', + '**/*.comp' + ] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + gpu-cuda: + runs-on: [self-hosted, Linux, NVIDIA] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Test + id: ggml-ci + run: | + nvidia-smi + GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-vulkan-nvidia-cm: + runs-on: [self-hosted, Linux, NVIDIA] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-vulkan-nvidia-cm2: + runs-on: [self-hosted, Linux, NVIDIA, COOPMAT2] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-metal: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Test + id: ggml-ci + run: | + GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp + + gpu-vulkan: + runs-on: [self-hosted, macOS, ARM64] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp diff --git a/.github/workflows/build-sycl.yml b/.github/workflows/build-sycl.yml new file mode 100644 index 00000000000..e5361645f1e --- /dev/null +++ b/.github/workflows/build-sycl.yml @@ -0,0 +1,150 @@ +name: CI (sycl) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-sycl.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ubuntu-22-cmake-sycl: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + dwhisper_sycl: [ON] + dcmake_c_compiler: [icx] + dcmake_cxx_compiler: [icpx] + arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + + continue-on-error: true + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: add oneAPI to apt + shell: bash + run: | + cd /tmp + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" + + - name: install oneAPI dpcpp compiler + shell: bash + run: | + sudo apt update + sudo apt install intel-oneapi-compiler-dpcpp-cpp + + - name: install oneAPI MKL library + shell: bash + run: | + sudo apt install intel-oneapi-mkl-devel + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: sycl-${{ matrix.arch }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Build + id: cmake_build + env: + CCACHE_SLOPPINESS: time_macros + CCACHE_NODIRECT: 1 + run: | + source /opt/intel/oneapi/setvars.sh + export CCACHE_COMPILERCHECK="string:$(icpx --version 2>&1 | head -1)" + mkdir build + cd build + cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache .. + cmake --build . --config Release -j $(nproc) + + ubuntu-22-cmake-sycl-fp16: + runs-on: ubuntu-22.04 + + strategy: + fail-fast: false + matrix: + dwhisper_sycl: [ON] + dcmake_c_compiler: [icx] + dcmake_cxx_compiler: [icpx] + arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] + + continue-on-error: true + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: add oneAPI to apt + shell: bash + run: | + cd /tmp + wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB + sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" + + - name: install oneAPI dpcpp compiler + shell: bash + run: | + sudo apt update + sudo apt install intel-oneapi-compiler-dpcpp-cpp + + - name: install oneAPI MKL library + shell: bash + run: | + sudo apt install intel-oneapi-mkl-devel + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: sycl-fp16-${{ matrix.arch }} + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Build + id: cmake_build + env: + CCACHE_SLOPPINESS: time_macros + CCACHE_NODIRECT: 1 + run: | + source /opt/intel/oneapi/setvars.sh + export CCACHE_COMPILERCHECK="string:$(icpx --version 2>&1 | head -1)" + mkdir build + cd build + cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache .. + cmake --build . --config Release -j $(nproc) diff --git a/.github/workflows/build-vad.yml b/.github/workflows/build-vad.yml new file mode 100644 index 00000000000..dd0efa33efe --- /dev/null +++ b/.github/workflows/build-vad.yml @@ -0,0 +1,50 @@ +name: CI (vad) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-vad.yml', + '**/CMakeLists.txt', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + vad: + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: vad-ubuntu-latest + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Build + shell: bash + run: | + cmake -B build -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache + cmake --build build --config Release + + - name: Test + shell: bash + run: | + ctest -R ^test-vad$ --test-dir build --output-on-failure -VV diff --git a/.github/workflows/build-wasm.yml b/.github/workflows/build-wasm.yml new file mode 100644 index 00000000000..c17a44ae455 --- /dev/null +++ b/.github/workflows/build-wasm.yml @@ -0,0 +1,65 @@ +name: CI (wasm) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-wasm.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + emscripten: + runs-on: ubuntu-22.04 + + strategy: + matrix: + build: [Release] + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Setup emsdk + uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 + + - name: Verify + run: emcc -v + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: wasm-ubuntu-22 + evict-old-files: 1d + save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} + + - name: Build + env: + CCACHE_SLOPPINESS: time_macros,include_file_mtime,include_file_ctime + CCACHE_COMPILERCHECK: content + run: | + emcmake cmake -B build -DCMAKE_BUILD_TYPE=${{ matrix.build }} \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + "-DCMAKE_C_FLAGS=-ffile-prefix-map=$EMSDK=/emsdk" \ + "-DCMAKE_CXX_FLAGS=-ffile-prefix-map=$EMSDK=/emsdk" + cmake --build build -j $(nproc) diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml new file mode 100644 index 00000000000..76b7a7370ce --- /dev/null +++ b/.github/workflows/build-windows.yml @@ -0,0 +1,74 @@ +name: CI (windows) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: ['.github/workflows/build-windows.yml', + '**/CMakeLists.txt', + '**/Makefile', + '**/*.mk', + '**/*.cmake', + '**/*.in', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.cl'] + + pull_request: + types: [opened, synchronize, reopened] + paths-ignore: + - 'bindings/ruby/**' # handled by bindings-ruby.yml + - 'bindings/go/**' # handled by bindings-go.yml + - 'examples/addon.node/**' # handled by examples.yml + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + windows-msys2: + runs-on: windows-latest + + strategy: + fail-fast: false + matrix: + include: + - { sys: UCRT64, env: ucrt-x86_64, compiler: gcc, build: Release } + - { sys: CLANG64, env: clang-x86_64, compiler: clang, build: Release } + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Setup ${{ matrix.sys }} + uses: msys2/setup-msys2@cafece8e6baf9247cf9b1bf95097b0b983cc558d # v2 + with: + update: true + msystem: ${{matrix.sys}} + install: >- + mingw-w64-${{matrix.env}}-${{matrix.compiler}} + mingw-w64-${{matrix.env}}-cmake + mingw-w64-${{matrix.env}}-SDL2 + mingw-w64-${{matrix.env}}-openblas + + - name: Build using CMake + shell: msys2 {0} + run: | + cmake -B build -DWHISPER_SDL2=ON + cmake --build build --config ${{ matrix.build }} -j $(nproc) + + - name: Clean after building using CMake + shell: msys2 {0} + run: | + rm -rf build + + - name: Build using CMake w/ OpenBLAS + shell: msys2 {0} + run: | + cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS + cmake --build build --config ${{ matrix.build }} -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml deleted file mode 100644 index 5c1cf93ba2a..00000000000 --- a/.github/workflows/build.yml +++ /dev/null @@ -1,1580 +0,0 @@ -name: CI - -on: - push: - branches: - - master - tags: - - 'v*' - paths: ['.github/workflows/build.yml', - '**/CMakeLists.txt', - '**/Makefile', - '**/*.mk', - '**/*.cmake', - '**/*.in', - '**/*.h', - '**/*.hpp', - '**/*.c', - '**/*.cpp', - '**/*.cu', - '**/*.cuh', - '**/*.cl', - '**/*.swift', - '**/*.m', - '**/*.mm', - '**/*.metal', - '**/*.comp', - '**/*.java'] - - pull_request: - types: [opened, synchronize, reopened] - workflow_dispatch: - inputs: - create_release: - description: 'Create new release' - required: true - type: boolean - pre_release_tag: - description: 'Pre-release tag name' - required: false - type: string - run_type: - description: 'Workflow type to run' - required: true - type: choice - options: - - full-ci - - release-only - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} - cancel-in-progress: true - -permissions: - contents: write # for creating release - -env: - BRANCH_NAME: ${{ github.head_ref || github.ref_name }} - ubuntu_image: "ubuntu:22.04" - VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite" - -jobs: - determine-tag: - runs-on: ubuntu-latest - outputs: - tag_name: ${{ steps.tag.outputs.name }} - should_release: ${{ steps.tag.outputs.should_release }} - - steps: - - name: Checkout with full history - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: Determine tag name - id: tag - shell: bash - run: | - BUILD_NUMBER=$(git rev-list --count HEAD) - SHORT_HASH=$(git rev-parse --short=7 HEAD) - CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}" - SHOULD_RELEASE="false" - - echo "Raw values:" - echo "BUILD_NUMBER: $BUILD_NUMBER" - echo "SHORT_HASH: $SHORT_HASH" - echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}" - echo "CUSTOM_TAG: $CUSTOM_TAG" - - if [[ "${{ github.ref_type }}" == "tag" ]]; then - echo "Using pushed tag name" - TAG_NAME="${{ github.ref_name }}" - SHOULD_RELEASE="true" - elif [[ -n "$CUSTOM_TAG" ]]; then - echo "Using custom tag" - TAG_NAME="${CUSTOM_TAG}" - SHOULD_RELEASE="true" - elif [[ "${{ github.event.inputs.create_release }}" == "true" ]]; then - echo "Manual release requested" - SHOULD_RELEASE="true" - TAG_NAME="b${BUILD_NUMBER}" - elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then - echo "Using master branch format" - TAG_NAME="b${BUILD_NUMBER}" - SHOULD_RELEASE="false" - else - echo "Using non-master branch format" - SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') - TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" - SHOULD_RELEASE="false" - fi - - echo "Final tag name: $TAG_NAME" - echo "Should release: $SHOULD_RELEASE" - echo "name=$TAG_NAME" >> $GITHUB_OUTPUT - echo "should_release=$SHOULD_RELEASE" >> $GITHUB_OUTPUT - - - ubuntu-22: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - arch: [linux/amd64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential libsdl2-dev cmake git - cmake -B build - cmake --build build --config Release -j $(nproc)' - - ubuntu-22-arm64: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - arch: [linux/arm64] - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt-get update - apt-get install -y ca-certificates - sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list - - apt update - apt install -y build-essential libsdl2-dev cmake git - cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a - cmake --build build --config Release -j $(nproc)' - - ubuntu-22-arm-v7: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - arch: [linux/arm/v7] - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt-get update - apt-get install -y ca-certificates - sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list - - apt update - apt install -y build-essential libsdl2-dev cmake git - cmake -B build -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp - cmake --build build --config Release -j $(nproc)' - - macOS-latest: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: macOS-latest - - strategy: - matrix: - destination: ['generic/platform=macOS', 'generic/platform=iOS', 'generic/platform=tvOS'] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: macOS-latest-swift - evict-old-files: 1d - - - name: Dependencies - run: | - brew update - cmake --version - brew install sdl2 - - - name: Build - run: | - sysctl -a - cmake -B build -G Xcode \ - -DGGML_METAL_USE_BF16=ON \ - -DGGML_METAL_EMBED_LIBRARY=ON \ - -DWHISPER_BUILD_EXAMPLES=OFF \ - -DWHISPER_BUILD_TESTS=OFF \ - -DWHISPER_BUILD_SERVER=OFF \ - -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" - cmake --build build --config Release -j $(sysctl -n hw.logicalcpu) - - -# freeBSD-latest: -# runs-on: macos-13 -# -# steps: -# - name: Clone -# uses: actions/checkout@v4 -# -# - name: Build -# uses: cross-platform-actions/action@v0.27.0 -# with: -# operating_system: freebsd -# version: '14.2' -# run: | -# sudo pkg update -# sudo pkg install -y gmake sdl2 cmake git -# cmake -B build -# cmake --build build --config Release - - ubuntu-22-gcc: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - arch: [linux/amd64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} - make - ctest -L gh --output-on-failure' - - ubuntu-22-gcc-arm64: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - arch: [linux/arm64] - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt-get update - apt-get install -y ca-certificates - sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list - - apt update - apt install -y build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8-a - make - ctest -L gh --output-on-failure' - - ubuntu-22-gcc-arm-v7: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - arch: [linux/arm/v7] - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt-get update - apt-get install -y ca-certificates - sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list - - apt update - apt install -y build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv7-a+fp - make - ctest -L gh --output-on-failure' - - ubuntu-22-clang: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - build: [Debug, Release] - #arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - # TODO: arm/v7 disabled due to clang bug - # https://github.com/ggerganov/whisper.cpp/actions/runs/9657764109/job/26637633042?pr=2256#step:4:1990 - arch: [linux/amd64, linux/arm64, linux/ppc64le] - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt-get update - apt-get install -y ca-certificates - sed -i "s|http://ports.ubuntu.com|https://mirror.kumi.systems|g" /etc/apt/sources.list - - apt update - apt install -y clang build-essential cmake libsdl2-dev git - cmake . -DWHISPER_SDL2=ON -DCMAKE_BUILD_TYPE=${{ matrix.build }} -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_COMPILER=clang - make - ctest -L gh --output-on-failure' - - ubuntu-22-gcc-sanitized: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - sanitizer: [ADDRESS, THREAD, UNDEFINED] - arch: [linux/amd64] - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - - - name: Build ${{ matrix.arch }} - run: | - docker run --platform ${{ matrix.arch }} --rm \ - -v ${{ github.workspace }}:/workspace \ - -w /workspace ${{ env.ubuntu_image }} /bin/sh -c ' - set -e - export DEBIAN_FRONTEND=noninteractive - sed -i "s|archive.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - sed -i "s|security.ubuntu.com|mirrors.kernel.org|g" /etc/apt/sources.list - - apt update - apt install -y build-essential cmake git - cmake . -DCMAKE_BUILD_TYPE=Debug \ - -DWHISPER_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF - make - ctest -L gh --output-on-failure' - - ubuntu-22-cmake-sycl: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - dwhisper_sycl: [ON] - dcmake_c_compiler: [icx] - dcmake_cxx_compiler: [icpx] - arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - - continue-on-error: true - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: add oneAPI to apt - shell: bash - run: | - cd /tmp - wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" - - - name: install oneAPI dpcpp compiler - shell: bash - run: | - sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp git - - - name: install oneAPI MKL library - shell: bash - run: | - sudo apt install intel-oneapi-mkl-devel git - - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Build - id: cmake_build - run: | - source /opt/intel/oneapi/setvars.sh - mkdir build - cd build - cmake -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. - cmake --build . --config Release -j $(nproc) - - ubuntu-22-cmake-sycl-fp16: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - fail-fast: false - matrix: - dwhisper_sycl: [ON] - dcmake_c_compiler: [icx] - dcmake_cxx_compiler: [icpx] - arch: [linux/amd64, linux/arm64, linux/arm/v7, linux/ppc64le] - - continue-on-error: true - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: add oneAPI to apt - shell: bash - run: | - cd /tmp - wget https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo apt-key add GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - rm GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB - sudo add-apt-repository "deb https://apt.repos.intel.com/oneapi all main" - - - name: install oneAPI dpcpp compiler - shell: bash - run: | - sudo apt update - sudo apt install intel-oneapi-compiler-dpcpp-cpp git - - - name: install oneAPI MKL library - shell: bash - run: | - sudo apt install intel-oneapi-mkl-devel - - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Build - id: cmake_build - run: | - source /opt/intel/oneapi/setvars.sh - mkdir build - cd build - cmake -DGGML_SYCL_F16=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx .. - cmake --build . --config Release -j $(nproc) - - windows-msys2: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-latest - - strategy: - fail-fast: false - matrix: - include: - - { sys: UCRT64, env: ucrt-x86_64, build: Release } - - { sys: CLANG64, env: clang-x86_64, build: Release } - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Setup ${{ matrix.sys }} - uses: msys2/setup-msys2@v2 - with: - update: true - msystem: ${{matrix.sys}} - install: >- - base-devel - git - mingw-w64-${{matrix.env}}-toolchain - mingw-w64-${{matrix.env}}-cmake - mingw-w64-${{matrix.env}}-SDL2 - mingw-w64-${{matrix.env}}-openblas - - - name: Build using CMake - shell: msys2 {0} - run: | - cmake -B build -DWHISPER_SDL2=ON - cmake --build build --config ${{ matrix.build }} -j $(nproc) - - - name: Clean after building using CMake - shell: msys2 {0} - run: | - rm -rf build - - - name: Build using CMake w/ OpenBLAS - shell: msys2 {0} - run: | - cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS - cmake --build build --config ${{ matrix.build }} -j $(nproc) - - windows: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-latest - needs: determine-tag - - strategy: - matrix: - build: [Release] - arch: [Win32, x64] - sdl2: [ON] - include: - - arch: Win32 - s2arc: x86 - jnaPath: win32-x86 - - arch: x64 - s2arc: x64 - jnaPath: win32-x86-64 - - sdl2: ON - s2ver: 2.28.5 - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip - 7z x sdl2.zip - echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV - - - name: Configure - run: > - cmake -S . -B ./build -A ${{ matrix.arch }} - -DCMAKE_BUILD_TYPE=${{ matrix.build }} - -DBUILD_SHARED_LIBS=ON - -DWHISPER_SDL2=${{ matrix.sdl2 }} - - - name: Build - run: | - cd ./build - msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Upload SDL2.dll - if: matrix.sdl2 == 'ON' - uses: actions/upload-artifact@v4 - with: - name: ${{ matrix.s2arc }}_SDL2.dll - path: build/bin/${{ matrix.build }}/SDL2.dll - - - name: Upload whisper dll - uses: actions/upload-artifact@v4 - with: - name: whisper_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/whisper.dll - - - name: Upload ggml dll - uses: actions/upload-artifact@v4 - with: - name: ggml_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml.dll - - - name: Upload ggml base dll - uses: actions/upload-artifact@v4 - with: - name: ggml_base_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml-base.dll - - - name: Upload ggml cpu dll - uses: actions/upload-artifact@v4 - with: - name: ggml_cpu_${{ matrix.arch }}.dll - path: build/bin/${{ matrix.build }}/ggml-cpu.dll - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 - with: - name: whisper-bin-${{ matrix.arch }}.zip - path: whisper-bin-${{ matrix.arch }}.zip - - windows-blas: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-latest - - strategy: - matrix: - build: [Release] - arch: [Win32, x64] - blas: [ON] - sdl2: [ON] - blasver: [0.3.29] - include: - - arch: Win32 - s2arc: x86 - blasfile: x86 - - arch: x64 - s2arc: x64 - blasfile: x64_64 - - sdl2: ON - s2ver: 2.28.5 - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Export GitHub Actions cache environment variables - uses: actions/github-script@v7 - with: - script: | - core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); - core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 - - - name: Install OpenBLAS and pkgconfiglite - if: matrix.blas == 'ON' - run: | - Invoke-WebRequest "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${{matrix.blasver}}/OpenBLAS-${{matrix.blasver}}_${{matrix.blasfile}}.zip" -OutFile "OpenBLAS-${{matrix.blasver}}.zip" - Expand-Archive "OpenBLAS-${{matrix.blasver}}.zip" -DestinationPath "OpenBLAS-${{matrix.blasver}}" - choco install pkgconfiglite - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip - 7z x sdl2.zip - echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV - - - name: Configure - run: > - cmake -S . -B ./build -A ${{ matrix.arch }} - -DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake" - -DCMAKE_BUILD_TYPE=${{ matrix.build }} - -DGGML_BLAS=${{ matrix.blas }} - -DGGML_BLAS_VENDOR=OpenBLAS - -DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib" - -DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include" - -DWHISPER_SDL2=${{ matrix.sdl2 }} - - - name: Build - run: | - cd ./build - msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} - - - name: Copy openblas.dll - if: matrix.blas == 'ON' - run: copy "$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/bin/libopenblas.dll" build/bin/${{ matrix.build }} - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-blas-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 - with: - name: whisper-blas-bin-${{ matrix.arch }}.zip - path: whisper-blas-bin-${{ matrix.arch }}.zip - - windows-cublas: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: windows-2022 - needs: determine-tag - strategy: - fail-fast: false - matrix: - build: [Release] - arch: [x64] - cublas: [ON] - sdl2: [ON] - cuda-toolkit: [12.4.0, 11.8.0] - include: - - arch: x64 - sdl2: ON - sdl2_ver: 2.28.5 - steps: - - name: Clone repository - uses: actions/checkout@v4 - - - name: Install Ninja - id: install_ninja - run: | - choco install ninja - - - name: Install ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }} - variant: sccache - evict-old-files: 5d - - - name: Install Cuda Toolkit 11.8.0 - if: ${{ matrix.cuda-toolkit == '11.8.0' }} - run: | - $CUDA_VERSION = ${{ matrix.cuda-toolkit }} - $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" - $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" - - # Components versions - $CUDART_VER = "11.8.89" - $NVCC_VER = "11.8.89" - $NVRTC_VER = "11.8.89" - $CUBLAS_VER = "11.8.1.74" - $NVTX_VER = "11.8.86" - $VS_VER = "11.8.86" - $NVPROF_VER = "11.8.87" - $CCCL_VER = "11.8.89" - - # Create the directory where the CUDA Toolkit will be installed - mkdir -p $CUDA_TOOLKIT_DIR - - # Install unzip to extract the downloaded files - choco install unzip -y - - # Download all the required components - curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" - - # Extract all the downloaded files to the CUDA Toolkit directory - unzip '*.zip' -d $CUDA_TOOLKIT_DIR - - # Copy all the extracted files to the main CUDA Toolkit directory - xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - - # Visual Studio integration - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y - - # Set environment variables - echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - - - name: Install Cuda Toolkit 12.4.0 - if: ${{ matrix.cuda-toolkit == '12.4.0' }} - run: | - $CUDA_VERSION = ${{ matrix.cuda-toolkit }} - $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" - $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" - - # Components versions - $CUDART_VER = "12.4.127" - $NVCC_VER = "12.4.131" - $NVRTC_VER = "12.4.127" - $CUBLAS_VER = "12.4.5.8" - $NVTX_VER = "12.4.127" - $PROFILER_VER = "12.4.127" - $VS_VER = "12.4.127" - $NVPROF_VER = "12.4.128" - $CCCL_VER = "12.4.127" - - # Create the directory where the CUDA Toolkit will be installed - mkdir -p $CUDA_TOOLKIT_DIR - - # Install unzip to extract the downloaded files - choco install unzip -y - - # Download all the required components - curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" - curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" - - # Extract all the downloaded files to the CUDA Toolkit directory - unzip -q '*.zip' -d $CUDA_TOOLKIT_DIR - - # Copy all the extracted files to the main CUDA Toolkit directory - xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y - - # Visual Studio integration - xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y - - # Set environment variables - echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append - echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 - - - name: Add msbuild to PATH - uses: microsoft/setup-msbuild@v2 - - - name: Install 7-Zip - run: choco install 7zip -y - - - name: Fetch SDL2 and set SDL2_DIR - if: matrix.sdl2 == 'ON' - run: | - Invoke-WebRequest -Uri https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.sdl2_ver }}/SDL2-devel-${{ matrix.sdl2_ver }}-VC.zip -OutFile sdl2.zip - 7z x sdl2.zip - echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append - echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt - - - name: Install cmake - run: choco install cmake - - - name: Build Project - shell: cmd - run: | - call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" - cmake --version - where cmake - if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( - set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR - ) else ( - set CUDA_FLAGS= - ) - cmake -S . -B build -G "Ninja Multi-Config" ^ - -DCMAKE_BUILD_TYPE=${{ matrix.build }} ^ - -DGGML_CUDA=${{ matrix.cublas }} ^ - -DWHISPER_SDL2=${{ matrix.sdl2 }} ^ - -DSDL2_DIR="%SDL2_DIR%" ^ - -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^ - -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" - set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 - cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% - - - name: Check sccache status after build - run: | - sccache --show-stats - - - name: Copy CUDA DLLs - run: | - Get-ChildItem "$env:CUDA_PATH\bin\" -Filter "*.dll" | - Copy-Item -Destination "build/bin/${{ matrix.build }}" - - - name: Copy SDL2.dll - if: matrix.sdl2 == 'ON' - run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }} - - - name: Pack bin artifacts - shell: pwsh - run: | - Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip" - - - name: Upload binaries - if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 - with: - name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip - path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip - - emscripten: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - strategy: - matrix: - build: [Release] - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Setup emsdk - uses: mymindstorm/setup-emsdk@v14 - - - name: Verify - run: emcc -v - - - name: Build - run: | - emcmake cmake . -DCMAKE_BUILD_TYPE=${{ matrix.build }} - make - - ios-xcode-build: - runs-on: macos-latest - needs: determine-tag - - strategy: - matrix: - build: [Release] - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Configure - run: | - cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin - mkdir models/ggml-base.en-encoder.mlmodelc - - - name: Build - id: cmake_build - run: | - sysctl -a - mkdir build - cd build - cmake -G Xcode .. \ - -DGGML_METAL_USE_BF16=ON \ - -DGGML_METAL_EMBED_LIBRARY=ON \ - -DWHISPER_BUILD_EXAMPLES=OFF \ - -DWHISPER_BUILD_TESTS=OFF \ - -DWHISPER_BUILD_SERVER=OFF \ - -DCMAKE_SYSTEM_NAME=iOS \ - -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ - -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml - cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO - - - name: xcodebuild for swift package - id: xcodebuild - run: | - ./build-xcframework.sh - - - name: Build objc example - run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGN_IDENTITY="" CODE_SIGNING_REQUIRED=NO FRAMEWORK_FOLDER_PATH=./build-ios build - - - name: Build swiftui example - run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build - - - name: Pack artifacts - id: pack_artifacts - run: | - zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework - - - name: Upload artifacts - if: ${{ needs.determine-tag.outputs.should_release }} - uses: actions/upload-artifact@v4 - with: - path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip - name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip - - android: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@v4 - with: - path: whisper - - - name: Install Java - uses: actions/setup-java@v4 - with: - distribution: zulu - java-version: 21 - - - name: Setup Android SDK - uses: android-actions/setup-android@v3 - - - name: Build - run: | - cd whisper/examples/whisper.android - ./gradlew assembleRelease --no-daemon - - - name: Build with external ggml - run: | - export PATH_TO_GGML=$PWD/ggml - cd whisper/examples/whisper.android - ./gradlew assembleRelease --no-daemon - - android_java: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: set up JDK 11 - uses: actions/setup-java@v4 - with: - java-version: '11' - distribution: 'temurin' - cache: gradle - - - name: Setup Android SDK - uses: android-actions/setup-android@v3 - with: - cmdline-tools-version: 9.0 - - - name: Build - run: | - cd examples/whisper.android.java - chmod +x ./gradlew - ./gradlew assembleRelease - - bindings-java: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - needs: ['windows'] - runs-on: windows-latest - steps: - - uses: actions/checkout@v4 - - - name: Install Java - uses: actions/setup-java@v4 - with: - distribution: zulu - java-version: 20 - - - name: Download Whisper Windows lib - uses: actions/download-artifact@v4 - with: - name: whisper_x64.dll - - - name: Download GGML Windows lib - uses: actions/download-artifact@v4 - with: - name: ggml_x64.dll - - - name: Download GGML Base Windows lib - uses: actions/download-artifact@v4 - with: - name: ggml_base_x64.dll - - - name: Download GGML CPU Windows lib - uses: actions/download-artifact@v4 - with: - name: ggml_cpu_x64.dll - - - name: Download SDL2.dll - uses: actions/download-artifact@v4 - with: - name: x64_SDL2.dll - - - name: List downloaded files - shell: pwsh - run: | - Get-ChildItem -Path "." -Recurse -Filter "*.dll" - - - name: Move DLL to correct location - shell: pwsh - run: | - New-Item -Path "build\bin\Release" -ItemType Directory -Force - - Copy-Item -Path "whisper.dll" -Destination "build\bin\Release\whisper.dll" -Force - Write-Host "Copied whisper.dll to build\bin\Release\whisper.dll directory" - - Copy-Item -Path "ggml.dll" -Destination "build\bin\Release\ggml.dll" -Force - Write-Host "Copied ggml.dll to build\bin\Release\ggml.dll directory" - - Copy-Item -Path "ggml-base.dll" -Destination "build\bin\Release\ggml-base.dll" -Force - Write-Host "Copied ggml-base.dll to build\bin\Release\ggml-base.dll directory" - - Copy-Item -Path "ggml-cpu.dll" -Destination "build\bin\Release\ggml-cpu.dll" -Force - Write-Host "Copied ggml-cpu.dll to build\bin\Release\ggml-cpu.dll directory" - - Copy-Item -Path "SDL2.dll" -Destination "build\bin\Release\SDL2.dll" -Force - Write-Host "Copied SDL2.dll to build\bin\Release\SDL2.dll directory" - - - name: List build release files - shell: pwsh - run: | - Get-ChildItem -Path "build\Release" -Recurse -Filter "*.dll" - - - name: Build - run: | - models\download-ggml-model.cmd tiny.en models/ - cd bindings/java - chmod +x ./gradlew - ./gradlew build --info - - - name: Pack jar artifacts - shell: pwsh - run: | - Compress-Archive -Path "bindings/java/build/libs/whispercpp-*.jar" -DestinationPath "whispercpp.jar.zip" - - - name: Upload jar - uses: actions/upload-artifact@v4 - with: - name: whispercpp.jar.zip - path: whispercpp.jar.zip - -# - name: Publish package -# if: ${{ github.ref == 'refs/heads/master' }} -# uses: gradle/gradle-build-action@v2.4.2 -# with: -# arguments: publish -# build-root-directory: bindings/java -# env: -# MAVEN_USERNAME: ${{ secrets.JIRA_USER }} -# MAVEN_PASSWORD: ${{ secrets.JIRA_PASS }} -# PGP_SECRET: ${{ secrets.GPG_PRIVATE_KEY }} -# PGP_PASSPHRASE: ${{ secrets.GPG_PASSPHRASE }} - - quantize: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-22.04 - - steps: - - name: Clone - uses: actions/checkout@v4 - - - name: Test quantize - run: | - ./models/download-ggml-model.sh tiny.en - cmake -B build - cmake --build build --config Release - ./build/bin/whisper-quantize models/ggml-tiny.en.bin models/ggml-tiny.en-q4_0.bin q4_0 - - release: - if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' || startsWith(github.ref, 'refs/tags/v') }} - - runs-on: ubuntu-latest - - needs: - - determine-tag - - ios-xcode-build - - windows - - windows-blas - - windows-cublas - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - - name: ccache - uses: hendrikmuhs/ccache-action@v1.2.16 - with: - key: release - evict-old-files: 1d - - # Downloads all the artifacts from the previous jobs - - name: Download artifacts - id: download-artifact - uses: actions/download-artifact@v4 - with: - path: ./artifact - - - name: Move artifacts - id: move_artifacts - run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release - - - name: Create release - id: create_release - uses: ggml-org/action-create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ needs.determine-tag.outputs.tag_name }} - prerelease: ${{ github.event.inputs.pre_release_tag != '' }} - draft: true - - - name: Upload release - id: upload_release - uses: actions/github-script@v3 - with: - github-token: ${{secrets.GITHUB_TOKEN}} - script: | - const path = require('path'); - const fs = require('fs'); - const release_id = '${{ steps.create_release.outputs.id }}'; - for (let file of await fs.readdirSync('./artifact/release')) { - if (path.extname(file) === '.zip') { - console.log('uploadReleaseAsset', file); - await github.repos.uploadReleaseAsset({ - owner: context.repo.owner, - repo: context.repo.repo, - release_id: release_id, - name: file, - data: await fs.readFileSync(`./artifact/release/${file}`) - }); - } - } - - coreml-base-en: - if: ${{ (github.event_name == 'push' && github.ref == 'refs/heads/master') || - github.event.inputs.create_release == 'true' || - github.event.inputs.pre_release_tag != '' || - startsWith(github.ref, 'refs/tags/v') }} - runs-on: macos-latest - needs: determine-tag - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Set environment variables - id: set_vars - run: | - echo "MODEL_NAME=base.en" >> $GITHUB_ENV - echo "GEN_MODEL_NAME=whisper-${{ needs.determine-tag.outputs.tag_name }}-ggml-base.en-encoder.mlmodelc" >> $GITHUB_ENV - - - name: Download model - run: | - ./models/download-ggml-model.sh ${{ env.MODEL_NAME }} - - - name: Generate CoreML model - run: | - python3.11 -m venv venv - source venv/bin/activate - pip install ane_transformers openai-whisper coremltools - ./models/generate-coreml-model.sh ${{ env.MODEL_NAME }} - - vad: - if: ${{ github.event_name == 'push' || github.event_name == 'pull_request' || - github.event.inputs.run_type == 'full-ci' }} - runs-on: ubuntu-latest - - steps: - - name: Checkout - uses: actions/checkout@v4 - - - name: Build - shell: bash - run: | - cmake -B build - cmake --build build --config Release - - - name: Test - shell: bash - run: | - ctest -R ^test-vad$ --test-dir build --output-on-failure -VV - -# TODO: simplify the following workflows using a matrix - ggml-ci-x64-cpu-low-perf: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-x64-cpu-low-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-low-perf: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-arm64-cpu-low-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_LOW_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-x64-cpu-high-perf: - runs-on: ubuntu-22.04 - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-x64-cpu-high-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-high-perf: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-arm64-cpu-high-perf - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-arm64-cpu-high-perf-sve: - runs-on: ubuntu-22.04-arm - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: ggml-ci-arm64-cpu-high-perf-sve - evict-old-files: 1d - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential libcurl4-openssl-dev - - - name: Test - id: ggml-ci - run: | - LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt - - ggml-ci-x64-nvidia-cuda: - runs-on: [self-hosted, Linux, X64, NVIDIA] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Test - id: ggml-ci - run: | - nvidia-smi - GG_BUILD_CUDA=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp - - ggml-ci-x64-nvidia-vulkan-cm: - runs-on: [self-hosted, Linux, X64, NVIDIA] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 GGML_VK_DISABLE_COOPMAT2=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp - - ggml-ci-x64-nvidia-vulkan-cm2: - runs-on: [self-hosted, Linux, X64, NVIDIA, COOPMAT2] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp - - ggml-ci-x64-cpu-amx: - runs-on: [self-hosted, Linux, X64, CPU, AMX] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Test - id: ggml-ci - run: | - bash ./ci/run.sh ~/results/whisper.cpp /mnt/whisper.cpp - - ggml-ci-mac-metal: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Test - id: ggml-ci - run: | - GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp - - ggml-ci-mac-vulkan: - runs-on: [self-hosted, macOS, ARM64] - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v4 - - - name: Test - id: ggml-ci - run: | - vulkaninfo --summary - GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/whisper.cpp ~/mnt/whisper.cpp diff --git a/.github/workflows/examples-wasm.yml b/.github/workflows/deploy-examples-wasm.yml similarity index 85% rename from .github/workflows/examples-wasm.yml rename to .github/workflows/deploy-examples-wasm.yml index ebbbdfe20ca..55df14720b1 100644 --- a/.github/workflows/examples-wasm.yml +++ b/.github/workflows/deploy-examples-wasm.yml @@ -22,13 +22,13 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v4 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Setup Pages - uses: actions/configure-pages@v4 + uses: actions/configure-pages@983d7736d9b0ae728b81ab479565c72886d7745b # v5 - name: Setup emsdk - uses: mymindstorm/setup-emsdk@v14 + uses: emscripten-core/setup-emsdk@6ab9eb1bda2574c4ddb79809fc9247783eaf9021 # v14 - name: Build WASM Examples # Enable for real build later in whisper.cpp @@ -88,10 +88,10 @@ jobs: find staging -type f | sort - name: Upload artifact - uses: actions/upload-pages-artifact@v3 + uses: actions/upload-pages-artifact@7b1f4a764d45c48632c6b24a0339c27f5614fb0b # v4 with: path: ./staging - name: Deploy to GitHub Pages id: deployment - uses: actions/deploy-pages@v4 + uses: actions/deploy-pages@d6db90164ac5ed86f2b6aed7e0febac5b3c0c03e # v4 diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 0e2fb1f2b9e..2d95e1a697f 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -1,42 +1,39 @@ name: Publish Docker image on: - pull_request: - push: - branches: - - master + workflow_dispatch: # allows manual triggering + schedule: + # Rebuild daily rather than on every push because it is expensive + - cron: '12 4 * * *' jobs: push_to_registry: name: Push Docker image to Docker Hub - if: github.event.pull_request.draft == false - runs-on: ubuntu-22.04 + runs-on: ${{ matrix.config.runs_on }} env: COMMIT_SHA: ${{ github.sha }} strategy: fail-fast: false matrix: config: - - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64" } - - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64" } - - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64" } - - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64" } + - { tag: "main", dockerfile: ".devops/main.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-arm64", dockerfile: ".devops/main.Dockerfile", platform: "linux/arm64", runs_on: "ubuntu-24.04-arm" } + - { tag: "main-musa", dockerfile: ".devops/main-musa.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-intel", dockerfile: ".devops/main-intel.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-cuda", dockerfile: ".devops/main-cuda.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-vulkan", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/amd64", runs_on: "ubuntu-24.04" } + - { tag: "main-vulkan-arm64", dockerfile: ".devops/main-vulkan.Dockerfile", platform: "linux/arm64", runs_on: "ubuntu-24.04-arm" } steps: - name: Check out the repo - uses: actions/checkout@v3 - - - name: Set up QEMU - uses: docker/setup-qemu-action@v3 - with: - image: tonistiigi/binfmt:qemu-v7.0.0-28 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4 - name: Log in to Docker Hub - uses: docker/login-action@v3 + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4 with: registry: ghcr.io username: ${{ github.repository_owner }} @@ -61,16 +58,16 @@ jobs: id: tags run: | TAGS="ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}" - if [ "${{ github.event_name }}" == "push" ]; then - TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}" - fi + TAGS="$TAGS,ghcr.io/${{ github.repository }}:${{ matrix.config.tag }}-${{ env.COMMIT_SHA }}" echo "tags=$TAGS" >> $GITHUB_OUTPUT - name: Build and push Docker image (tagged) - uses: docker/build-push-action@v5 + uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7 with: context: . - push: ${{ github.event_name == 'push' }} + push: true platforms: ${{ matrix.config.platform }} tags: ${{ steps.tags.outputs.tags }} file: ${{ matrix.config.dockerfile }} + secrets: | + HF_TOKEN=${{ secrets.HF_TOKEN }} diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 74ef8e0faae..ac811712e78 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -1,13 +1,15 @@ name: Examples Tests on: push: + branches: + - master paths: - examples/addon.node/** - - whisper.h + - include/whisper.h pull_request: paths: - examples/addon.node/** - - whisper.h + - include/whisper.h jobs: addon_node-ubuntu-22: @@ -17,7 +19,7 @@ jobs: node-version: [ 16.x, 18.x ] steps: - name: Clone - uses: actions/checkout@v1 + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 - name: Dependencies run: | @@ -27,7 +29,7 @@ jobs: sudo apt-get install libsdl2-dev - name: Use Node.js ${{ matrix.node-version }} - uses: actions/setup-node@v1 + uses: actions/setup-node@48b55a011bda9f5d6aeb4c2d9c7362e8dae4041e # v6 with: node-version: ${{ matrix.node-version }} cache: 'npm' @@ -40,6 +42,8 @@ jobs: run: npx cmake-js compile -T addon.node -B Release - name: Download test model + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} run: | bash ./models/download-ggml-model.sh base.en - name: Test diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000000..8dcfeb9827c --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,653 @@ +name: Release + +on: + workflow_dispatch: + inputs: + create_release: + description: 'Create new release' + required: true + type: boolean + pre_release_tag: + description: 'Pre-release tag name' + required: false + type: string + + push: + tags: + - 'v*' + +env: + BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + VCPKG_BINARY_SOURCES: "clear;x-gha,readwrite" + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +permissions: + contents: write # for creating release + +jobs: + determine-tag: + runs-on: ubuntu-latest + outputs: + tag_name: ${{ steps.tag.outputs.name }} + should_release: ${{ steps.tag.outputs.should_release }} + + steps: + - name: Checkout with full history + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + with: + fetch-depth: 0 + + - name: Determine tag name + id: tag + shell: bash + run: | + BUILD_NUMBER=$(git rev-list --count HEAD) + SHORT_HASH=$(git rev-parse --short=7 HEAD) + CUSTOM_TAG="${{ github.event.inputs.pre_release_tag }}" + SHOULD_RELEASE="false" + + echo "Raw values:" + echo "BUILD_NUMBER: $BUILD_NUMBER" + echo "SHORT_HASH: $SHORT_HASH" + echo "BRANCH_NAME: ${{ env.BRANCH_NAME }}" + echo "CUSTOM_TAG: $CUSTOM_TAG" + + if [[ "${{ github.ref_type }}" == "tag" ]]; then + echo "Using pushed tag name" + TAG_NAME="${{ github.ref_name }}" + SHOULD_RELEASE="true" + elif [[ -n "$CUSTOM_TAG" ]]; then + echo "Using custom tag" + TAG_NAME="${CUSTOM_TAG}" + SHOULD_RELEASE="true" + elif [[ "${{ github.event.inputs.create_release }}" == "true" ]]; then + echo "Manual release requested" + SHOULD_RELEASE="true" + TAG_NAME="b${BUILD_NUMBER}" + elif [[ "${{ env.BRANCH_NAME }}" == "master" ]]; then + echo "Using master branch format" + TAG_NAME="b${BUILD_NUMBER}" + SHOULD_RELEASE="false" + else + echo "Using non-master branch format" + SAFE_NAME=$(echo "${{ env.BRANCH_NAME }}" | tr '/' '-') + TAG_NAME="${SAFE_NAME}-b${BUILD_NUMBER}-${SHORT_HASH}" + SHOULD_RELEASE="false" + fi + + echo "Final tag name: $TAG_NAME" + echo "Should release: $SHOULD_RELEASE" + echo "name=$TAG_NAME" >> $GITHUB_OUTPUT + echo "should_release=$SHOULD_RELEASE" >> $GITHUB_OUTPUT + + ubuntu-cpu: + runs-on: ${{ matrix.os }} + needs: determine-tag + if: ${{ needs.determine-tag.outputs.should_release == 'true' }} + + strategy: + matrix: + include: + - build: x64 + os: ubuntu-22.04 + - build: arm64 + os: ubuntu-22.04-arm + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: release-${{ matrix.os }}-cpu + evict-old-files: 1d + + - name: Dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential cmake + + - name: Build + run: | + cmake -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_RPATH='$ORIGIN' \ + -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \ + -DGGML_BACKEND_DL=ON \ + -DGGML_NATIVE=OFF \ + ${{ matrix.build == 'x64' && '-DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_CPU_ARM_ARCH=armv8-a' }} + cmake --build build --config Release -j $(nproc) + + - name: Pack artifacts + run: | + cp LICENSE ./build/bin/ + tar -czvf whisper-bin-ubuntu-${{ matrix.build }}.tar.gz \ + --transform "s,^\.,whisper-bin-ubuntu-${{ matrix.build }}," \ + -C ./build/bin . + + - name: Upload artifacts + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + path: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz + name: whisper-bin-ubuntu-${{ matrix.build }}.tar.gz + + windows: + runs-on: windows-latest + needs: determine-tag + + strategy: + matrix: + build: [Release] + arch: [Win32, x64] + sdl2: [ON] + include: + - arch: Win32 + s2arc: x86 + jnaPath: win32-x86 + - arch: x64 + s2arc: x64 + jnaPath: win32-x86-64 + - sdl2: ON + s2ver: 2.28.5 + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 + + - name: Fetch SDL2 and set SDL2_DIR + if: matrix.sdl2 == 'ON' + run: | + C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip + 7z x sdl2.zip + echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV + + - name: Configure + run: > + cmake -S . -B ./build -A ${{ matrix.arch }} + -DCMAKE_BUILD_TYPE=${{ matrix.build }} + -DBUILD_SHARED_LIBS=ON + -DWHISPER_SDL2=${{ matrix.sdl2 }} + -DGGML_NATIVE=OFF + ${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }} + + - name: Build + run: | + cd ./build + msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} + + - name: Copy SDL2.dll + if: matrix.sdl2 == 'ON' + run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} + + - name: Upload SDL2.dll + if: matrix.sdl2 == 'ON' + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + name: ${{ matrix.s2arc }}_SDL2.dll + path: build/bin/${{ matrix.build }}/SDL2.dll + + - name: Upload whisper dll + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + name: whisper_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/whisper.dll + + - name: Upload ggml dll + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + name: ggml_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/ggml.dll + overwrite: true + + - name: Upload ggml base dll + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + name: ggml_base_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/ggml-base.dll + + - name: Upload ggml cpu dll + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + name: ggml_cpu_${{ matrix.arch }}.dll + path: build/bin/${{ matrix.build }}/ggml-cpu.dll + + - name: Pack bin artifacts + shell: pwsh + run: | + Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-bin-${{ matrix.arch }}.zip" + + - name: Upload binaries + if: matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + name: whisper-bin-${{ matrix.arch }}.zip + path: whisper-bin-${{ matrix.arch }}.zip + + windows-blas: + runs-on: windows-latest + needs: determine-tag + + strategy: + matrix: + build: [Release] + arch: [Win32, x64] + blas: [ON] + sdl2: [ON] + blasver: [0.3.29] + include: + - arch: Win32 + s2arc: x86 + blasfile: x86 + - arch: x64 + s2arc: x64 + blasfile: x64_64 + - sdl2: ON + s2ver: 2.28.5 + + steps: + - name: Clone + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Export GitHub Actions cache environment variables + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 + with: + script: | + core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); + core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); + + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 + + - name: Install OpenBLAS and pkgconfiglite + if: matrix.blas == 'ON' + run: | + Invoke-WebRequest "https://github.com/OpenMathLib/OpenBLAS/releases/download/v${{matrix.blasver}}/OpenBLAS-${{matrix.blasver}}_${{matrix.blasfile}}.zip" -OutFile "OpenBLAS-${{matrix.blasver}}.zip" + Expand-Archive "OpenBLAS-${{matrix.blasver}}.zip" -DestinationPath "OpenBLAS-${{matrix.blasver}}" + choco install pkgconfiglite + + - name: Fetch SDL2 and set SDL2_DIR + if: matrix.sdl2 == 'ON' + run: | + C:/msys64/usr/bin/wget.exe -qO sdl2.zip https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.s2ver }}/SDL2-devel-${{ matrix.s2ver }}-VC.zip + 7z x sdl2.zip + echo "SDL2_DIR=$env:GITHUB_WORKSPACE/SDL2-${{ matrix.s2ver }}/cmake" >> $env:GITHUB_ENV + + - name: Configure + run: > + cmake -S . -B ./build -A ${{ matrix.arch }} + -DCMAKE_TOOLCHAIN_FILE="$env:VCPKG_INSTALLATION_ROOT/scripts/buildsystems/vcpkg.cmake" + -DCMAKE_BUILD_TYPE=${{ matrix.build }} + -DGGML_BLAS=${{ matrix.blas }} + -DGGML_BLAS_VENDOR=OpenBLAS + -DBLAS_LIBRARIES="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/lib/libopenblas.lib" + -DBLAS_INCLUDE_DIRS="$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/include" + -DWHISPER_SDL2=${{ matrix.sdl2 }} + -DGGML_NATIVE=OFF + ${{ matrix.arch == 'x64' && '-DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON' || '-DGGML_BMI2=OFF' }} + + - name: Build + run: | + cd ./build + msbuild ALL_BUILD.vcxproj -t:build -p:configuration=${{ matrix.build }} -p:platform=${{ matrix.arch }} + + - name: Copy openblas.dll + if: matrix.blas == 'ON' + run: copy "$env:GITHUB_WORKSPACE/OpenBLAS-${{matrix.blasver}}/bin/libopenblas.dll" build/bin/${{ matrix.build }} + + - name: Copy SDL2.dll + if: matrix.sdl2 == 'ON' + run: copy "$env:SDL2_DIR/../lib/${{ matrix.s2arc }}/SDL2.dll" build/bin/${{ matrix.build }} + + - name: Pack bin artifacts + shell: pwsh + run: | + Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-blas-bin-${{ matrix.arch }}.zip" + + - name: Upload binaries + if: matrix.blas == 'ON' && matrix.sdl2 == 'ON' && ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + name: whisper-blas-bin-${{ matrix.arch }}.zip + path: whisper-blas-bin-${{ matrix.arch }}.zip + + windows-cublas: + runs-on: windows-2022 + needs: determine-tag + strategy: + fail-fast: false + matrix: + build: [Release] + arch: [x64] + cublas: [ON] + sdl2: [ON] + cuda-toolkit: [12.4.0, 11.8.0] + include: + - arch: x64 + sdl2: ON + sdl2_ver: 2.28.5 + steps: + - name: Clone repository + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Install Ninja + id: install_ninja + run: | + choco install ninja + + - name: Install ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: ${{ github.job }}-${{ matrix.cuda-toolkit }}-${{ matrix.build }} + evict-old-files: 5d + + - name: Install Cuda Toolkit 11.8.0 + if: ${{ matrix.cuda-toolkit == '11.8.0' }} + run: | + $CUDA_VERSION = ${{ matrix.cuda-toolkit }} + $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" + $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" + + # Components versions + $CUDART_VER = "11.8.89" + $NVCC_VER = "11.8.89" + $NVRTC_VER = "11.8.89" + $CUBLAS_VER = "11.8.1.74" + $NVTX_VER = "11.8.86" + $VS_VER = "11.8.86" + $NVPROF_VER = "11.8.87" + $CCCL_VER = "11.8.89" + + # Create the directory where the CUDA Toolkit will be installed + mkdir -p $CUDA_TOOLKIT_DIR + + # Install unzip to extract the downloaded files + choco install unzip -y + + # Download all the required components + curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" + + # Extract all the downloaded files to the CUDA Toolkit directory + unzip '*.zip' -d $CUDA_TOOLKIT_DIR + + # Copy all the extracted files to the main CUDA Toolkit directory + xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + + # Visual Studio integration + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y + + # Set environment variables + echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V11_8=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + + - name: Install Cuda Toolkit 12.4.0 + if: ${{ matrix.cuda-toolkit == '12.4.0' }} + run: | + $CUDA_VERSION = ${{ matrix.cuda-toolkit }} + $CUDA_TOOLKIT_DIR = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v$CUDA_VERSION" + $CUDA_DOWNLOAD = "https://developer.download.nvidia.com/compute/cuda/redist" + + # Components versions + $CUDART_VER = "12.4.127" + $NVCC_VER = "12.4.131" + $NVRTC_VER = "12.4.127" + $CUBLAS_VER = "12.4.5.8" + $NVTX_VER = "12.4.127" + $PROFILER_VER = "12.4.127" + $VS_VER = "12.4.127" + $NVPROF_VER = "12.4.128" + $CCCL_VER = "12.4.127" + + # Create the directory where the CUDA Toolkit will be installed + mkdir -p $CUDA_TOOLKIT_DIR + + # Install unzip to extract the downloaded files + choco install unzip -y + + # Download all the required components + curl -O "$CUDA_DOWNLOAD/cuda_cudart/windows-x86_64/cuda_cudart-windows-x86_64-${CUDART_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvcc/windows-x86_64/cuda_nvcc-windows-x86_64-${NVCC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvrtc/windows-x86_64/cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/libcublas/windows-x86_64/libcublas-windows-x86_64-${CUBLAS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvtx/windows-x86_64/cuda_nvtx-windows-x86_64-${NVTX_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_profiler_api/windows-x86_64/cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/visual_studio_integration/windows-x86_64/visual_studio_integration-windows-x86_64-${VS_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_nvprof/windows-x86_64/cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive.zip" + curl -O "$CUDA_DOWNLOAD/cuda_cccl/windows-x86_64/cuda_cccl-windows-x86_64-${CCCL_VER}-archive.zip" + + # Extract all the downloaded files to the CUDA Toolkit directory + unzip -q '*.zip' -d $CUDA_TOOLKIT_DIR + + # Copy all the extracted files to the main CUDA Toolkit directory + xcopy "$CUDA_TOOLKIT_DIR\cuda_cudart-windows-x86_64-${CUDART_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvcc-windows-x86_64-${NVCC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvrtc-windows-x86_64-${NVRTC_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\libcublas-windows-x86_64-${CUBLAS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvtx-windows-x86_64-${NVTX_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_nvprof-windows-x86_64-${NVPROF_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_cccl-windows-x86_64-${CCCL_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\cuda_profiler_api-windows-x86_64-${PROFILER_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\*" "$CUDA_TOOLKIT_DIR" /E /I /H /Y + + # Visual Studio integration + xcopy "$CUDA_TOOLKIT_DIR\visual_studio_integration-windows-x86_64-${VS_VER}-archive\visual_studio_integration\MSBuildExtensions\*" "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Microsoft\VC\v170\BuildCustomizations" /E /I /H /Y + + # Set environment variables + echo "$CUDA_TOOLKIT_DIR\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "$CUDA_TOOLKIT_DIR\libnvvp" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "CUDA_PATH=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + echo "CUDA_PATH_V12_2=$CUDA_TOOLKIT_DIR" | Out-File -FilePath $env:GITHUB_ENV -Append -Encoding utf8 + + - name: Add msbuild to PATH + uses: microsoft/setup-msbuild@6fb02220983dee41ce7ae257b6f4d8f9bf5ed4ce # v2 + + - name: Install 7-Zip + run: choco install 7zip -y + + - name: Fetch SDL2 and set SDL2_DIR + if: matrix.sdl2 == 'ON' + run: | + Invoke-WebRequest -Uri https://github.com/libsdl-org/SDL/releases/download/release-${{ matrix.sdl2_ver }}/SDL2-devel-${{ matrix.sdl2_ver }}-VC.zip -OutFile sdl2.zip + 7z x sdl2.zip + echo "SDL2_DIR=${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" | Out-File -FilePath $env:GITHUB_ENV -Append + echo "${{ github.workspace }}\SDL2-${{ matrix.sdl2_ver }}\cmake" > SDL2_PATH.txt + + - name: Install cmake + run: choco install cmake + + - name: Build Project + shell: cmd + run: | + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" + cmake --version + where cmake + if "${{ matrix.cuda-toolkit }}" == "11.8.0" ( + set CUDA_FLAGS=-allow-unsupported-compiler -D_ALLOW_COMPILER_AND_STL_VERSION_MISMATCH -D_DISABLE_CONSTEXPR_MUTEX_CONSTRUCTOR + ) else ( + set CUDA_FLAGS= + ) + cmake -S . -B build -G "Ninja Multi-Config" ^ + -DCMAKE_BUILD_TYPE=${{ matrix.build }} ^ + -DGGML_CUDA=${{ matrix.cublas }} ^ + -DWHISPER_SDL2=${{ matrix.sdl2 }} ^ + -DSDL2_DIR="%SDL2_DIR%" ^ + -DCMAKE_POLICY_VERSION_MINIMUM=3.5 ^ + -DCMAKE_CUDA_FLAGS="%CUDA_FLAGS%" ^ + -DGGML_BACKEND_DL=ON ^ + -DGGML_NATIVE=OFF ^ + -DGGML_CPU_ALL_VARIANTS=ON + set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 + cmake --build build --config ${{ matrix.build }} -j %NUMBER_OF_PROCESSORS% + + - name: Check ccache status after build + run: | + ccache --show-stats + + - name: Copy CUDA DLLs + run: | + Get-ChildItem "$env:CUDA_PATH\bin\" -Filter "*.dll" | + Copy-Item -Destination "build/bin/${{ matrix.build }}" + + - name: Copy SDL2.dll + if: matrix.sdl2 == 'ON' + run: copy "$env:SDL2_DIR/../lib/${{ matrix.arch }}/SDL2.dll" build/bin/${{ matrix.build }} + + - name: Pack bin artifacts + shell: pwsh + run: | + Compress-Archive -Path "build/bin/${{ matrix.build }}" -DestinationPath "whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip" + + - name: Upload binaries + if: ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + name: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip + path: whisper-cublas-${{ matrix.cuda-toolkit }}-bin-${{ matrix.arch }}.zip + + ios-xcode-build: + runs-on: macos-latest + needs: determine-tag + + strategy: + matrix: + build: [Release] + + steps: + - name: Checkout code + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + + - name: Configure + run: | + cp models/for-tests-ggml-base.en.bin models/ggml-base.en.bin + mkdir models/ggml-base.en-encoder.mlmodelc + + - name: Build + id: cmake_build + run: | + sysctl -a + mkdir build + cd build + cmake -G Xcode .. \ + -DGGML_METAL_USE_BF16=ON \ + -DGGML_METAL_EMBED_LIBRARY=ON \ + -DWHISPER_BUILD_EXAMPLES=OFF \ + -DWHISPER_BUILD_TESTS=OFF \ + -DWHISPER_BUILD_SERVER=OFF \ + -DCMAKE_SYSTEM_NAME=iOS \ + -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0 \ + -DCMAKE_XCODE_ATTRIBUTE_DEVELOPMENT_TEAM=ggml + cmake --build . --config Release -j $(sysctl -n hw.logicalcpu) -- CODE_SIGNING_ALLOWED=NO + + - name: xcodebuild for swift package + id: xcodebuild + run: | + ./build-xcframework.sh + + - name: Build objc example + run: xcodebuild -project examples/whisper.objc/whisper.objc.xcodeproj -scheme whisper.objc -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGN_IDENTITY="" CODE_SIGNING_REQUIRED=NO FRAMEWORK_FOLDER_PATH=./build-ios build + + - name: Build swiftui example + run: xcodebuild -project examples/whisper.swiftui/whisper.swiftui.xcodeproj -scheme WhisperCppDemo -configuration ${{ matrix.build }} -sdk iphoneos CODE_SIGNING_REQUIRED=NO CODE_SIGN_IDENTITY= -destination 'generic/platform=iOS' FRAMEWORK_FOLDER_PATH=./build-ios build + + - name: Pack artifacts + id: pack_artifacts + run: | + zip --symlinks -r whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip build-apple/whisper.xcframework + + - name: Upload artifacts + if: ${{ needs.determine-tag.outputs.should_release }} + uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 + with: + path: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip + name: whisper-${{ needs.determine-tag.outputs.tag_name }}-xcframework.zip + + release: + if: ${{ github.event.inputs.create_release == 'true' || github.event.inputs.pre_release_tag != '' || startsWith(github.ref, 'refs/tags/v') }} + + runs-on: ubuntu-latest + + needs: + - determine-tag + - ubuntu-cpu + - ios-xcode-build + - windows + - windows-blas + - windows-cublas + + steps: + - name: Clone + id: checkout + uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6 + with: + fetch-depth: 0 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.21 + with: + key: release + evict-old-files: 1d + + # Downloads all the artifacts from the previous jobs + - name: Download artifacts + id: download-artifact + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7 + with: + path: ./artifact + + - name: Move artifacts + id: move_artifacts + run: mkdir -p ./artifact/release && mv ./artifact/*/*.zip ./artifact/release && mv ./artifact/*/*.tar.gz ./artifact/release 2>/dev/null || true + + - name: Create release + id: create_release + uses: ggml-org/action-create-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + tag_name: ${{ needs.determine-tag.outputs.tag_name }} + prerelease: ${{ github.event.inputs.pre_release_tag != '' }} + draft: true + + - name: Upload release + id: upload_release + uses: actions/github-script@ffc2c79a5b2490bd33e0a41c1de74b877714d736 # v3 + with: + github-token: ${{secrets.GITHUB_TOKEN}} + script: | + const path = require('path'); + const fs = require('fs'); + const release_id = '${{ steps.create_release.outputs.id }}'; + for (let file of await fs.readdirSync('./artifact/release')) { + if (path.extname(file) === '.zip' || file.endsWith('.tar.gz')) { + console.log('uploadReleaseAsset', file); + await github.repos.uploadReleaseAsset({ + owner: context.repo.owner, + repo: context.repo.repo, + release_id: release_id, + name: file, + data: await fs.readFileSync(`./artifact/release/${file}`) + }); + } + } diff --git a/.gitignore b/.gitignore index 957eeb75456..7a98228af3c 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ .DS_Store .vimspector.json /CMakeSettings.json +/CMakeUserPresets.json /talk-llama.dSYM/ build/ @@ -64,3 +65,6 @@ cmake-build-debug/ local.properties .log .exe + +# AGENTS +.pi/SYSTEM.md diff --git a/.pi/gg/SYSTEM.md b/.pi/gg/SYSTEM.md new file mode 100644 index 00000000000..1ae0e40674e --- /dev/null +++ b/.pi/gg/SYSTEM.md @@ -0,0 +1,27 @@ +You are a coding agent. Here are some very important rules that you must follow: + +General: +- Be very precise and concise when writing code, comments, explanations, etc. +- PR and commit titles format: ` : `. Lookup recents for examples +- Don't try to build or run the code unless you are explicitly asked to do so +- Use the `gh` CLI tool when querying PRs, issues, or other GitHub resources + +Coding: +- When in doubt, always refer to the CONTRIBUTING.md file of the project +- When referencing issues or PRs in comments, use the format: + - C/C++ code: `// ref: <url>` + - Other (CMake, etc.): `# ref: <url>` + +Pull requests (PRs): +- New branch names are prefixed with "gg/" +- Before opening a pull request, ask the user to confirm the description +- When creating a pull request, look for the repository's PR template and follow it +- For the AI usage disclosure section, write "YES. llama.cpp + pi + [MODEL]" +- Ask the user to tell you what model was used and write it in place of [MODEL] +- Always create the pull requests in draft mode + +Commits: +- On every commit that you make, include a "Assisted-by: llama.cpp:local pi" tag +- Do not explicitly set the git author in commits - rely on the default git config +- Always use `--no-gpg-sign` when committing +- Never `git push` without explicit confirmation from the user diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000000..f34f3249977 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,102 @@ +# Instructions for whisper.cpp + +> [!IMPORTANT] +> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity. +> +> Read more: [CONTRIBUTING.md](CONTRIBUTING.md) + +AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (see examples below). + +--- + +## Guidelines for Contributors Using AI + +whisper.cpp is built by humans, for humans. Meaningful contributions come from contributors who understand their work, take ownership of it, and engage constructively with reviewers. + +Maintainers receive numerous pull requests weekly, many of which are AI-generated submissions where the author cannot adequately explain the code, debug issues, or participate in substantive design discussions. Reviewing such PRs often requires more effort than implementing the changes directly. + +**A pull request represents a long-term commitment.** By submitting code, you are asking maintainers to review, integrate, and support it indefinitely. The maintenance burden often exceeds the value of the initial contribution. + +Most maintainers already have access to AI tools. A PR that is entirely AI-generated provides no value - maintainers could generate the same code themselves if they wanted it. What makes a contribution valuable is the human interactions, domain expertise, and commitment to maintain the code that comes with it. + +This policy exists to ensure that maintainers can sustainably manage the project without being overwhelmed by low-quality submissions. + +--- + +## Guidelines for Contributors + +Contributors are expected to: + +1. **Demonstrate full understanding of their code.** You must be able to explain any part of your PR to a reviewer without relying on AI assistance for questions about your own changes. + +2. **Take responsibility for maintenance.** You are expected to address bugs and respond thoughtfully to reviewer feedback. + +3. **Communicate clearly and concisely.** Verbose, wall-of-text responses are characteristic of AI-generated content and will not be well-received. Direct, human communication is expected. + +4. **Respect maintainers' time.** Search for existing issues and discussions before submitting. Ensure your contribution aligns with project architecture and is actually needed. + +Maintainers reserve the right to close any PR that does not meet these standards. This applies to all contributions to the main whisper.cpp repository. **Private forks are exempt.** + +### Permitted AI Usage + +AI tools may be used responsibly for: + +- **Learning and exploration**: Understanding codebase structure, techniques, and documentation +- **Code review assistance**: Obtaining suggestions on human-written code +- **Mechanical tasks**: Formatting, generating repetitive patterns from established designs, completing code based on existing patterns +- **Documentation drafts**: For components the contributor already understands thoroughly +- **Writing code**: Only when the contributor has already designed the solution and can implement it themselves - AI accelerates, not replaces, the contributor's work + +AI-generated code may be accepted if you (1) fully understand the output, (2) can debug issues independently, and (3) can discuss it directly with reviewers without AI assistance. + +**Disclosure is required** when AI meaningfully contributed to your code. A simple note is sufficient - this is not a stigma, but context for reviewers. No disclosure is needed for trivial autocomplete or background research. + +### Prohibited AI Usage + +The following will result in immediate PR closure: + +- **AI-written PR descriptions or commit messages** - these are typically recognizable and waste reviewer time +- **AI-generated responses to reviewer comments** - this undermines the human-to-human interaction fundamental to code review +- **Implementing features without understanding the codebase** - particularly new model support or architectural changes +- **Automated commits or PR submissions** - this may spam maintainers and can result in contributor bans + +--- + +## Guidelines for AI Coding Agents + +AI agents assisting contributors must recognize that their outputs directly impact volunteer maintainers who sustain this project. + +### Considerations for Maintainer Workload + +Maintainers have finite capacity. Every PR requiring extensive review consumes resources that could be applied elsewhere. Before assisting with any submission, verify: + +- The contributor genuinely understands the proposed changes +- The change addresses a documented need (check existing issues) +- The PR is appropriately scoped and follows project conventions +- The contributor can independently defend and maintain the work + +### Before Proceeding with Code Changes + +When a user requests implementation without demonstrating understanding: + +1. **Verify comprehension.** Ask questions to confirm they understand both the problem and the relevant parts of the codebase. +2. **Provide guidance rather than solutions.** Direct them to relevant code and documentation. Allow them to formulate the approach. +3. **Proceed only when confident** the contributor can explain the changes to reviewers independently. + +For first-time contributors, confirm they have reviewed [CONTRIBUTING.md](CONTRIBUTING.md) and acknowledge this policy. + +### Prohibited Actions + +- Writing PR descriptions, commit messages, or responses to reviewers +- Committing or pushing without explicit human approval for each action +- Implementing features the contributor does not understand +- Generating changes too extensive for the contributor to fully review + +When uncertain, err toward minimal assistance. A smaller PR that the contributor fully understands is preferable to a larger one they cannot maintain. + +### Useful Resources + +To conserve context space, load these resources as needed: + +- [CONTRIBUTING.md](CONTRIBUTING.md) +- [Existing issues](https://github.com/ggml-org/whisper.cpp/issues) and [Existing PRs](https://github.com/ggml-org/whisper.cpp/pulls) - always search here first diff --git a/CMakeLists.txt b/CMakeLists.txt index 06577bf1181..26037c26538 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required(VERSION 3.5) # for add_link_options and implicit target directories. project("whisper.cpp" C CXX) -project("whisper.cpp" VERSION 1.8.3) +project("whisper.cpp" VERSION 1.9.1) include(CheckIncludeFileCXX) set(SOVERSION 1) @@ -19,6 +19,7 @@ endif() list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) +set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) if (CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) set(WHISPER_STANDALONE ON) @@ -85,7 +86,7 @@ option(WHISPER_CURL "whisper: use libcurl to download model from an URL" OFF) option(WHISPER_SDL2 "whisper: support for libSDL2" OFF) if (CMAKE_SYSTEM_NAME MATCHES "Linux") - option(WHISPER_FFMPEG "whisper: support building and linking with ffmpeg libs (avcodec, swresample, ...)" OFF) + option(WHISPER_COMMON_FFMPEG "whisper: examples link with ffmpeg libs in order to decode more audio formats" OFF) endif() option(WHISPER_COREML "whisper: enable Core ML framework" OFF) @@ -121,6 +122,7 @@ whisper_option_depr(WARNING WHISPER_RPC GGML_RPC) whisper_option_depr(WARNING WHISPER_SYCL GGML_SYCL) whisper_option_depr(WARNING WHISPER_SYCL_F16 GGML_SYCL_F16) whisper_option_depr(WARNING WHISPER_CCACHE GGML_CCACHE) +whisper_option_depr(WARNING WHISPER_FFMPEG WHISPER_COMMON_FFMPEG) if (GGML_CUDA AND NOT MSVC) #GGML_CUDA enabled, add the necessary compile options -Wno-deprecated-gpu-targets @@ -179,12 +181,20 @@ set(WHISPER_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location get_directory_property(WHISPER_TRANSIENT_DEFINES COMPILE_DEFINITIONS) set_target_properties(whisper PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/whisper.h) + install(TARGETS whisper LIBRARY PUBLIC_HEADER) target_compile_definitions(whisper PRIVATE WHISPER_VERSION="${PROJECT_VERSION}" ) +set_target_properties(parakeet PROPERTIES PUBLIC_HEADER ${CMAKE_CURRENT_SOURCE_DIR}/include/parakeet.h) +install(TARGETS parakeet LIBRARY PUBLIC_HEADER) + +target_compile_definitions(parakeet PRIVATE + PARAKEET_VERSION="${PROJECT_VERSION}" +) + configure_package_config_file( ${CMAKE_CURRENT_SOURCE_DIR}/cmake/whisper-config.cmake.in ${CMAKE_CURRENT_BINARY_DIR}/whisper-config.cmake @@ -208,7 +218,36 @@ configure_file(cmake/whisper.pc.in @ONLY) install(FILES "${CMAKE_CURRENT_BINARY_DIR}/whisper.pc" - DESTINATION lib/pkgconfig) + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) + +set(PARAKEET_INCLUDE_INSTALL_DIR ${CMAKE_INSTALL_INCLUDEDIR} CACHE PATH "Location of header files") +set(PARAKEET_LIB_INSTALL_DIR ${CMAKE_INSTALL_LIBDIR} CACHE PATH "Location of library files") +set(PARAKEET_BIN_INSTALL_DIR ${CMAKE_INSTALL_BINDIR} CACHE PATH "Location of binary files") + +configure_package_config_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/parakeet-config.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + INSTALL_DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet + PATH_VARS + PARAKEET_INCLUDE_INSTALL_DIR + PARAKEET_LIB_INSTALL_DIR + PARAKEET_BIN_INSTALL_DIR) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + VERSION ${WHISPER_INSTALL_VERSION} + COMPATIBILITY SameMajorVersion) + +install(FILES ${CMAKE_CURRENT_BINARY_DIR}/parakeet-config.cmake + ${CMAKE_CURRENT_BINARY_DIR}/parakeet-version.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/parakeet) + +configure_file(cmake/parakeet.pc.in + "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + @ONLY) + +install(FILES "${CMAKE_CURRENT_BINARY_DIR}/parakeet.pc" + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) # # programs, examples and tests diff --git a/CMakePresets.json b/CMakePresets.json new file mode 100644 index 00000000000..b5afeb3c0f2 --- /dev/null +++ b/CMakePresets.json @@ -0,0 +1,95 @@ +{ + "version": 4, + "configurePresets": [ + { + "name": "base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." + } + }, + { + "name": "sycl-base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/build-${presetName}", + "cacheVariables": { + "CMAKE_EXPORT_COMPILE_COMMANDS": "ON", + "CMAKE_CXX_COMPILER": "icx", + "CMAKE_C_COMPILER": "cl", + "GGML_SYCL": "ON", + "CMAKE_INSTALL_RPATH": "$ORIGIN;$ORIGIN/.." + } + }, + { "name": "debug", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug" } }, + { "name": "release", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } }, + { "name": "reldbg", "hidden": true, "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo" } }, + { "name": "static", "hidden": true, "cacheVariables": { "GGML_STATIC": "ON" } }, + { "name": "sycl_f16", "hidden": true, "cacheVariables": { "GGML_SYCL_F16": "ON" } }, + { "name": "vulkan", "hidden": true, "cacheVariables": { "GGML_VULKAN": "ON" } }, + + { + "name": "x64-windows-llvm", "hidden": true, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/x64-windows-llvm.cmake" + } + }, + + { + "name": "arm64-windows-llvm", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-windows-llvm.cmake" + } + }, + + { + "name": "arm64-apple-clang", "hidden": true, + "architecture": { "value": "arm64", "strategy": "external" }, + "toolset": { "value": "host=x64", "strategy": "external" }, + "cacheVariables": { + "CMAKE_TOOLCHAIN_FILE": "${sourceDir}/cmake/arm64-apple-clang.cmake" + } + }, + { + "name": "x64-linux-gcc", "hidden": true, + "cacheVariables": { + "CMAKE_C_COMPILER": "gcc", + "CMAKE_CXX_COMPILER": "g++" + } + }, + { "name": "x64-linux-gcc-debug", "inherits": [ "base", "x64-linux-gcc", "debug" ] }, + { "name": "x64-linux-gcc-release", "inherits": [ "base", "x64-linux-gcc", "release" ] }, + { "name": "x64-linux-gcc-reldbg", "inherits": [ "base", "x64-linux-gcc", "reldbg" ] }, + { "name": "x64-linux-gcc+static-release", "inherits": [ "base", "x64-linux-gcc", "release", "static" ] }, + + { "name": "arm64-windows-llvm-debug", "inherits": [ "base", "arm64-windows-llvm", "debug" ] }, + { "name": "arm64-windows-llvm-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg" ] }, + { "name": "arm64-windows-llvm+static-release", "inherits": [ "base", "arm64-windows-llvm", "reldbg", "static" ] }, + + { "name": "arm64-apple-clang-debug", "inherits": [ "base", "arm64-apple-clang", "debug" ] }, + { "name": "arm64-apple-clang-release", "inherits": [ "base", "arm64-apple-clang", "reldbg" ] }, + { "name": "arm64-apple-clang+static-release", "inherits": [ "base", "arm64-apple-clang", "reldbg", "static" ] }, + + { "name": "x64-windows-llvm-debug", "inherits": [ "base", "x64-windows-llvm", "debug" ] }, + { "name": "x64-windows-llvm-release", "inherits": [ "base", "x64-windows-llvm", "release" ] }, + { "name": "x64-windows-llvm-reldbg", "inherits": [ "base", "x64-windows-llvm", "reldbg" ] }, + { "name": "x64-windows-llvm+static-release", "inherits": [ "base", "x64-windows-llvm", "reldbg", "static" ] }, + + { "name": "x64-windows-msvc-debug", "inherits": [ "base", "debug" ] }, + { "name": "x64-windows-msvc-release", "inherits": [ "base", "reldbg" ] }, + { "name": "x64-windows-msvc+static-release", "inherits": [ "base", "reldbg", "static" ] }, + + { "name": "x64-windows-sycl-debug", "inherits": [ "sycl-base", "debug" ] }, + { "name": "x64-windows-sycl-debug-f16", "inherits": [ "sycl-base", "debug", "sycl_f16" ] }, + { "name": "x64-windows-sycl-release", "inherits": [ "sycl-base", "release" ] }, + { "name": "x64-windows-sycl-release-f16", "inherits": [ "sycl-base", "release", "sycl_f16" ] }, + + { "name": "x64-windows-vulkan-debug", "inherits": [ "base", "vulkan", "debug" ] }, + { "name": "x64-windows-vulkan-release", "inherits": [ "base", "vulkan", "release" ] } + ] +} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000000..c301604f1de --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,176 @@ +# Contributors + +The project differentiates between 3 levels of contributors: + +- Contributors: people who have contributed before (no special privileges) +- Collaborators (Triage): people with significant contributions, who may be responsible for some parts of the code, and are expected to maintain and review contributions for the code they own +- Maintainers: responsible for reviewing and merging PRs, after approval from the code owners + +# AI Usage Policy + +> [!IMPORTANT] +> This project does **not** accept pull requests that are fully or predominantly AI-generated. AI tools may be utilized solely in an assistive capacity. +> +> Repeated violations of this policy may result in your account being permanently banned from contributing to the project. +> +> Detailed information regarding permissible and restricted uses of AI can be found in the [AGENTS.md](AGENTS.md) file. + +Code that is initially generated by AI and subsequently edited will still be considered AI-generated. AI assistance is permissible only when the majority of the code is authored by a human contributor, with AI employed exclusively for corrections or to expand on verbose modifications that the contributor has already conceptualized (e.g., generating repeated lines with minor variations). + +If AI is used to generate any portion of the code, contributors must adhere to the following requirements: + +1. Explicitly disclose the manner in which AI was employed. +2. Perform a comprehensive manual review prior to submitting the pull request. +3. Be prepared to explain every line of code they submitted when asked about it by a maintainer. +4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...). + +For more info, please refer to the [AGENTS.md](AGENTS.md) file. + +# Pull requests (for contributors & collaborators) + +Before submitting your PR: +- Search for existing PRs to prevent duplicating efforts +- whisper.cpp uses the ggml tensor library for model evaluation. If you are unfamiliar with ggml, consider taking a look at the [examples in the ggml repository](https://github.com/ggml-org/ggml/tree/master/examples/). [simple](https://github.com/ggml-org/ggml/tree/master/examples/simple) shows the bare minimum for using ggml. [gpt-2](https://github.com/ggml-org/ggml/tree/master/examples/gpt-2) has minimal implementations for language model inference using GPT-2. [mnist](https://github.com/ggml-org/ggml/tree/master/examples/mnist) demonstrates how to train and evaluate a simple image classifier +- Test your changes: + - Execute [the full CI locally on your machine](ci/README.md) before publishing +- Create separate PRs for each feature or fix: + - Avoid combining unrelated changes in a single PR + - For intricate features, consider opening a feature request first to discuss and align expectations +- If you are a new contributor + - Limit your open PRs to 1 + - Do not submit trivial fixes (e.g. typos, formatting changes) + +After submitting your PR: +- Expect requests for modifications to ensure the code meets whisper.cpp's standards for quality and long-term maintainability +- Maintainers will rely on your insights and approval when making a final decision to approve and merge a PR +- If your PR becomes stale, rebase it on top of latest `master` to get maintainers attention + +# Pull requests (for maintainers) + +- Squash-merge PRs +- Use the following format for the squashed commit title: `<module> : <commit title> (#<issue_number>)`. For example: `utils : fix typo in utils.py (#1234)` +- Optionally pick a `<module>` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules +- Let other maintainers merge their own PRs +- When merging a PR, make sure you have a good understanding of the changes +- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you) + +Maintainers reserve the right to decline review or close pull requests for any reason, without any questions, particularly under any of the following conditions: +- The proposed change is already mentioned in the roadmap or an existing issue, and it has been assigned to someone. +- The pull request duplicates an existing one. +- The contributor fails to adhere to this contributing guide or the AI policy. + +# Coding guidelines + +- Avoid adding third-party dependencies, extra files, extra headers, etc. +- Always consider cross-compatibility with other operating systems and architectures +- Avoid fancy-looking modern STL constructs, use basic `for` loops, avoid templates, keep it simple +- Vertical alignment makes things more readable and easier to batch edit +- Clean-up any trailing whitespaces, use 4 spaces for indentation, brackets on the same line, `void * ptr`, `int & a` +- Use sized integer types such as `int32_t` in the public API, e.g. `size_t` may also be appropriate for allocation sizes or byte offsets +- Declare structs with `struct foo {}` instead of `typedef struct foo {} foo` + - In C++ code omit optional `struct` and `enum` keyword whenever they are not necessary + ```cpp + // OK + llama_context * ctx; + const llama_rope_type rope_type; + + // not OK + struct llama_context * ctx; + const enum llama_rope_type rope_type; + ``` + + _(NOTE: this guideline is yet to be applied to the `whisper.cpp` codebase. New code should follow this guideline.)_ + +- Try to follow the existing patterns in the code (indentation, spaces, etc.). In case of doubt use `clang-format` (from clang-tools v15+) to format the added code +- For anything not covered in the current guidelines, refer to the [C++ Core Guidelines](https://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines) +- Tensors store data in row-major order. We refer to dimension 0 as columns, 1 as rows, 2 as matrices +- Matrix multiplication is unconventional: [`C = ggml_mul_mat(ctx, A, B)`](https://github.com/ggml-org/llama.cpp/blob/880e352277fc017df4d5794f0c21c44e1eae2b84/ggml.h#L1058-L1064) means $C^T = A B^T \Leftrightarrow C = B A^T.$ + +![matmul](media/matmul.png) + +# Naming guidelines + +- Use `snake_case` for function, variable and type names +- Naming usually optimizes for longest common prefix (see https://github.com/ggml-org/ggml/pull/302#discussion_r1243240963) + + ```cpp + // not OK + int small_number; + int big_number; + + // OK + int number_small; + int number_big; + ``` + +- Enum values are always in upper case and prefixed with the enum name + + ```cpp + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_NONE = 0, + LLAMA_VOCAB_TYPE_SPM = 1, + LLAMA_VOCAB_TYPE_BPE = 2, + LLAMA_VOCAB_TYPE_WPM = 3, + LLAMA_VOCAB_TYPE_UGM = 4, + LLAMA_VOCAB_TYPE_RWKV = 5, + }; + ``` + +- The general naming pattern is `<class>_<method>`, with `<method>` being `<action>_<noun>` + + ```cpp + llama_model_init(); // class: "llama_model", method: "init" + llama_sampler_chain_remove(); // class: "llama_sampler_chain", method: "remove" + llama_sampler_get_seed(); // class: "llama_sampler", method: "get_seed" + llama_set_embeddings(); // class: "llama_context", method: "set_embeddings" + llama_n_threads(); // class: "llama_context", method: "n_threads" + llama_adapter_lora_free(); // class: "llama_adapter_lora", method: "free" + ``` + + - The `get` `<action>` can be omitted + - The `<noun>` can be omitted if not necessary + - The `_context` suffix of the `<class>` is optional. Use it to disambiguate symbols when needed + - Use `init`/`free` for constructor/destructor `<action>` + +- Use the `_t` suffix when a type is supposed to be opaque to the user - it's not relevant to them if it is a struct or anything else + + ```cpp + typedef struct llama_context * llama_context_t; + + enum llama_pooling_type llama_pooling_type(const llama_context_t ctx); + ``` + + _(NOTE: this guideline is yet to be applied to the `whisper.cpp` codebase. New code should follow this guideline)_ + +- C/C++ filenames are all lowercase with dashes. Headers use the `.h` extension. Source files use the `.c` or `.cpp` extension +- Python filenames are all lowercase with underscores + +- _(TODO: abbreviations usage)_ + +# Preprocessor directives + +- _(TODO: add guidelines with examples and apply them to the codebase)_ + + ```cpp + #ifdef FOO + #endif // FOO + ``` + +# Code maintenance + +- New code should follow the guidelines (coding, naming, etc.) outlined in this document. Exceptions are allowed in isolated, backend-specific parts of the code that do not interface directly with the `ggml` interfaces. + _(NOTE: for legacy reasons, existing code is not required to follow this guideline)_ + +- For changes in server, please make sure to refer to the [server development documentation](./tools/server/README-dev.md) + +# Documentation + +- Documentation is a community effort +- When you need to look into the source code to figure out how to use an API consider adding a short summary to the header file for future reference +- When you notice incorrect or outdated documentation, please update it + +# Resources + +The Github issues, PRs and discussions contain a lot of information that can be useful to get familiar with the codebase. For convenience, some of the more important information is referenced from Github projects: + +https://github.com/ggml-org/whisper.cpp/projects diff --git a/LICENSE b/LICENSE index acb96ce78e0..e7dca554bcb 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023-2024 The ggml authors +Copyright (c) 2023-2026 The ggml authors Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 6d4988e6fa5..0e2d5f100d5 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ [![Conan Center](https://shields.io/conan/v/whisper-cpp)](https://conan.io/center/whisper-cpp) [![npm](https://img.shields.io/npm/v/whisper.cpp.svg)](https://www.npmjs.com/package/whisper.cpp/) -Stable: [v1.8.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.8.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) +Stable: [v1.9.1](https://github.com/ggml-org/whisper.cpp/releases/tag/v1.9.1) / [Roadmap](https://github.com/orgs/ggml-org/projects/4/) High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisper) automatic speech recognition (ASR) model: @@ -21,6 +21,7 @@ High-performance inference of [OpenAI's Whisper](https://github.com/openai/whisp - [Vulkan support](#vulkan-gpu-support) - Support for CPU-only inference - [Efficient GPU support for NVIDIA](#nvidia-gpu-support) +- [AMD ROCm GPU support](#amd-rocm-gpu-support) - [OpenVINO Support](#openvino-support) - [Ascend NPU Support](#ascend-npu-support) - [Moore Threads GPU Support](#moore-threads-gpu-support) @@ -340,6 +341,27 @@ cmake -B build -DGGML_VULKAN=1 cmake --build build -j --config Release ``` +## AMD ROCm GPU support + +With AMD GPUs the processing can be accelerated via HIP/ROCm. +First, make sure you have installed [ROCm](https://rocm.docs.amd.com/en/latest/). + +Now build `whisper.cpp` with HIP support: + +``` +cmake -B build -DGGML_HIP=1 -DAMDGPU_TARGETS="gfx1201" +cmake --build build -j --config Release +``` + +Replace `gfx1201` with your GPU architecture. You can find it with: + +``` +rocminfo | grep "gfx" +``` + +Common architectures: `gfx1100` (RX 7900 XTX), `gfx1101` (RX 7800 XT), `gfx1201` (RX 9070 XT). +For multiple GPUs with different architectures: `-DAMDGPU_TARGETS="gfx1100;gfx1201"`. + ## BLAS CPU support via OpenBLAS Encoder processing can be accelerated on the CPU via OpenBLAS. @@ -403,9 +425,10 @@ cmake -B build -DGGML_MUSA=1 -DMUSA_ARCHITECTURES="21" cmake --build build -j --config Release ``` -## FFmpeg support (Linux only) +## FFmpeg support (examples only) -If you want to support more audio formats (such as Opus and AAC), you can turn on the `WHISPER_FFMPEG` build flag to enable FFmpeg integration. +By default, the examples in this repo use the [miniaudio](https://github.com/mackron/miniaudio) library to decode audio files. +Some of the examples also can use FFmpeg for decoding and broader format support. To enable that, build with `WHISPER_COMMON_FFMPEG`. First, you need to install required libraries: @@ -420,7 +443,7 @@ sudo dnf install libavcodec-free-devel libavformat-free-devel libavutil-free-dev Then you can build the project as follows: ```bash -cmake -B build -D WHISPER_FFMPEG=yes +cmake -B build -D WHISPER_COMMON_FFMPEG=yes cmake --build build ``` @@ -443,11 +466,12 @@ ffmpeg -i samples/jfk.wav jfk.opus ### Images -We have two Docker images available for this project: +We have multiple Docker images available for this project: 1. `ghcr.io/ggml-org/whisper.cpp:main`: This image includes the main executable file as well as `curl` and `ffmpeg`. (platforms: `linux/amd64`, `linux/arm64`) 2. `ghcr.io/ggml-org/whisper.cpp:main-cuda`: Same as `main` but compiled with CUDA support. (platforms: `linux/amd64`) 3. `ghcr.io/ggml-org/whisper.cpp:main-musa`: Same as `main` but compiled with MUSA support. (platforms: `linux/amd64`) +4. `ghcr.io/ggml-org/whisper.cpp:main-vulkan`: Same as `main` but compiled with Vulkan support. (platforms: `linux/amd64`) ### Usage @@ -456,15 +480,27 @@ We have two Docker images available for this project: docker run -it --rm \ -v path/to/models:/models \ whisper.cpp:main "./models/download-ggml-model.sh base /models" + # transcribe an audio file docker run -it --rm \ -v path/to/models:/models \ -v path/to/audios:/audios \ whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f /audios/jfk.wav" + # transcribe an audio file in samples folder docker run -it --rm \ -v path/to/models:/models \ whisper.cpp:main "whisper-cli -m /models/ggml-base.bin -f ./samples/jfk.wav" + +# run the web server +docker run -it --rm -p "8080:8080" \ + -v path/to/models:/models \ + whisper.cpp:main "whisper-server --host 127.0.0.1 -m /models/ggml-base.bin" + +# run the bench too on the small.en model using 4 threads +docker run -it --rm \ + -v path/to/models:/models \ + whisper.cpp:main "whisper-bench -m /models/ggml-small.en.bin -t 4" ``` ## Installing with Conan @@ -742,7 +778,7 @@ argument to `whisper-cli`. In addition to this option a VAD model is also required. The way this works is that first the audio samples are passed through -the VAD model which will detect speech segments. Using this information the +the VAD model which will detect speech segments. Using this information, only the speech segments that are detected are extracted from the original audio input and passed to whisper for processing. This reduces the amount of audio data that needs to be processed by whisper and can significantly speed up the diff --git a/bindings/go/examples/go-model-download/main.go b/bindings/go/examples/go-model-download/main.go index 728c6df53d4..e72262eb7cb 100644 --- a/bindings/go/examples/go-model-download/main.go +++ b/bindings/go/examples/go-model-download/main.go @@ -282,13 +282,20 @@ func Download(ctx context.Context, p io.Writer, model, out string) (string, erro default: // Read body n, err := resp.Body.Read(data) + if n > 0 { + if m, err := w.Write(data[:n]); err != nil { + return path, err + } else { + count += int64(m) + } + } + if err != nil { - DownloadReport(p, pct, count, resp.ContentLength) - return path, err - } else if m, err := w.Write(data[:n]); err != nil { + if err == io.EOF { + DownloadReport(p, pct, count, resp.ContentLength) + return path, nil + } return path, err - } else { - count += int64(m) } } } diff --git a/bindings/go/pkg/whisper/context_test.go b/bindings/go/pkg/whisper/context_test.go index e98a4c2b80b..79f6a593024 100644 --- a/bindings/go/pkg/whisper/context_test.go +++ b/bindings/go/pkg/whisper/context_test.go @@ -2,6 +2,7 @@ package whisper_test import ( "os" + "strings" "testing" "github.com/ggerganov/whisper.cpp/bindings/go/pkg/whisper" @@ -92,6 +93,53 @@ func TestProcess(t *testing.T) { assert.NoError(err) } +func TestProcessMaxTokensPerSegment(t *testing.T) { + assert := assert.New(t) + + if _, err := os.Stat(ModelPath); os.IsNotExist(err) { + t.Skip("Skipping test, model not found:", ModelPath) + } + + fh, err := os.Open(SamplePath) + assert.NoError(err) + defer fh.Close() + + // Decode the WAV file - load the full buffer + dec := wav.NewDecoder(fh) + buf, err := dec.FullPCMBuffer() + assert.NoError(err) + assert.Equal(uint16(1), dec.NumChans) + + data := buf.AsFloat32Buffer().Data + + model, err := whisper.New(ModelPath) + assert.NoError(err) + assert.NotNil(model) + defer model.Close() + + context, err := model.NewContext() + assert.NoError(err) + + context.SetMaxTokensPerSegment(5) + + err = context.Process(data, nil, nil, nil) + assert.NoError(err) + + var text strings.Builder + nSegments := 0 + for { + segment, err := context.NextSegment() + if err != nil { + break + } + nSegments++ + text.WriteString(segment.Text) + } + + assert.Greater(nSegments, 1) + assert.Contains(text.String(), "country") +} + func TestDetectedLanguage(t *testing.T) { assert := assert.New(t) diff --git a/bindings/javascript/package.json b/bindings/javascript/package.json index 84139804314..09829326605 100644 --- a/bindings/javascript/package.json +++ b/bindings/javascript/package.json @@ -1,6 +1,6 @@ { "name": "whisper.cpp", - "version": "1.8.3", + "version": "1.9.1", "description": "Whisper speech recognition", "main": "whisper.js", "scripts": { diff --git a/bindings/ruby/.document b/bindings/ruby/.document new file mode 100644 index 00000000000..a8e9788fc7c --- /dev/null +++ b/bindings/ruby/.document @@ -0,0 +1,3 @@ +README.md +LICENSE +sig diff --git a/bindings/ruby/.rdoc_options b/bindings/ruby/.rdoc_options new file mode 100644 index 00000000000..cf14aa5f5b4 --- /dev/null +++ b/bindings/ruby/.rdoc_options @@ -0,0 +1,2 @@ +title: whispercpp +main_page: README.md diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index ea202753b67..7f6b7d92c09 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -202,6 +202,8 @@ whisper.transcribe("path/to/audio.wav", params, n_processors: Etc.nprocessors) Note that transcription occasionally might be low accuracy when it works in parallel. +If n_processors is greater than 1, you cannot set any callbacks including new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, and log_callback set by Whisper.log_set. + ### Segments ### Once `Whisper::Context#transcribe` called, you can retrieve segments by `#each_segment`: @@ -247,6 +249,58 @@ whisper.transcribe("path/to/audio.wav", params) ``` +### Tokens ### + +Each segment has tokens. + +To enable token timestamps, you need to set `Whisper::Params#token_timestamps = true`. Then, retrieve tokens from segments using `Whisper::Segment#each_token`. + +```ruby +whisper = Whisper::Context.new("base.en") +params = Whisper::Params.new(token_timestamps: true) +whisper + .transcribe("path/to/audio.wav", params) + .each_segment do |segment| + segment.each_token do |token| + token => {start_time:, end_time:, text:, probability:} + st = "%05.2fs" % (start_time / 1000.0) + et = "%05.2fs" % (end_time / 1000.0) + prob = "%.1f%%" % (probability * 100) + puts "[#{st} --> #{et}] #{text} (#{prob})" + end + end +``` + +``` +[00.00s --> 00.00s] [_BEG_] (84.2%) +[00.32s --> 00.37s] And (71.2%) +[00.37s --> 00.53s] so (98.5%) +[00.69s --> 00.85s] my (70.7%) +[00.85s --> 01.59s] fellow (99.5%) +[01.59s --> 02.10s] Americans (90.1%) +[02.85s --> 03.30s] , (28.4%) +[03.30s --> 04.14s] ask (79.8%) +[04.14s --> 04.28s] not (78.9%) +[05.03s --> 05.35s] what (93.3%) +[05.41s --> 05.74s] your (98.8%) +[05.74s --> 06.41s] country (99.6%) +[06.41s --> 06.74s] can (97.7%) +[06.74s --> 06.92s] do (99.0%) +[07.00s --> 07.00s] for (95.8%) +[07.01s --> 07.52s] you (98.5%) +[07.81s --> 08.05s] , (49.3%) +[08.19s --> 08.37s] ask (65.6%) +[08.37s --> 08.75s] what (98.8%) +[08.91s --> 09.04s] you (98.2%) +[09.04s --> 09.32s] can (96.9%) +[09.32s --> 09.38s] do (90.3%) +[09.44s --> 09.76s] for (91.8%) +[09.76s --> 09.99s] your (98.2%) +[10.02s --> 10.36s] country (99.6%) +[10.51s --> 10.99s] . (87.0%) +[11.00s --> 11.00s] [_TT_550] (7.6%) +``` + ### Models ### You can see model information: @@ -306,7 +360,7 @@ Whisper::Context.new("base") ### Low-level API to transcribe ### -You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility. +You can also call `Whisper::Context#full` and `#full_parallel` with a Ruby array as samples. Although `#transcribe` with audio file path is recommended because it extracts PCM samples in C++ and is fast, `#full` and `#full_parallel` give you flexibility. Unlike `#transcribe`, these methods requires 16,000 Hz, 32-bit float audio. ```ruby require "whisper" @@ -323,7 +377,69 @@ whisper end ``` -The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. +The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. + +If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. + +```ruby +require "torchaudio" +require "ndav/torch/tensor" +require "whisper" + +waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") +# Convert Torch::Tensor to NDAV +samples = waveform.squeeze.to_ndav + +whisper = Whisper::Context.new("base") +whisper + # NDAV exports MemoryView + .full(Whisper::Params.new, samples) +``` + +### Parakeet ### + +whispercpp gem now supports NVIDIA's ASR model Parakeet. + +If you want to use Parakeet instead of Whisper, the API should feel familiar. +In most cases, replace `Whisper::Context` and `Whisper::Params` with `Whisper::Parakeet::Context` and `Whisper::Parakeet::Params`, then use `#transcribe`, `#full`, `#each_segment`, and `#each_token` in the same way. + +```ruby +require "whisper" + +# It's useful to assign Whisper::Parakeet to top-level Parakeet constant unless you use Parakeet gem. +Parakeet = Whisper::Parakeet + +parakeet = Parakeet::Context.new("path/to/model") + +params = Parakeet::Params.new( + no_context: true +) + +parakeet + .transcribe("path/to/audio.wav", params) + .each_segment do |segment| + puts "[#{segment.start_time} --> #{segment.end_time}] #{segment.text}" + end +``` + +The main differences are: + +* Namespace is `Whisper::Parakeet`. +* Parakeet also supports `on_new_token` / `new_token_callback` in addition to segment and progress callbacks. + +Custom context params +--------------------- + +You can use customize `Whisper::Context`'s behavior using `Whisper::Context::Params`. + +```ruby +context_params = Whisper::Context::Params.new( + use_gpu: false, + flash_attn: false, + # etc +) +whisper = Whisper::Context.new("base", context_params) +``` Using VAD separately from ASR ----------------------------- @@ -334,13 +450,27 @@ VAD feature itself is useful. You can use it separately from ASR: vad = Whisper::VAD::Context.new("silero-v6.2.0") vad .detect("path/to/audio.wav", Whisper::VAD::Params.new) - .each_with_index do |segment, index| + .each.with_index do |segment, index| segment => {start_time: st, end_time: ed} # `Segment` responds to `#deconstruct_keys` puts "[%{nth}: %{st} --> %{ed}]" % {nth: index + 1, st:, ed:} end ``` +You may also low level API `Whisper::VAD::Context#segments_from_samples` as such `Whisper::Context#full`: + +```ruby +# Ruby Array +reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000)) +samples = reader.enum_for(:each_buffer).map(&:samples).flatten + +# Or, object which exports MemoryView +waveform, sample_rate = TorchAudio.load("test/fixtures/jfk.wav") +samples = waveform.squeeze.numo.to_arrow.to_arrow_array + +segments = vad.segments_from_samples(Whisper::VAD::Params.new, samples) +``` + Development ----------- diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index d9a66030de4..2327651a06a 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -16,7 +16,7 @@ EXTSOURCES.each do |src| file src directory dir file dest => [src, dir] do |t| - cp t.source, t.name + copy t.source, t.name end SOURCES.include dest end @@ -34,7 +34,7 @@ LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"]) SO_FILE = File.join("ext", LIB_NAME) LIB_FILE = File.join("lib", LIB_NAME) -file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t| +file "ext/Makefile" => SRC + SOURCES + FileList["ext/*.rb"] do |t| chdir "ext" do ruby "extconf.rb" end @@ -84,6 +84,21 @@ else end end +TEST_PARAKEET_MODEL = "test/fixtures/for-tests-ggml-parakeet-tdt.bin" +TEST_PARAKEET_MODEL_SRC = File.expand_path(File.join(__dir__, "..", "..", "models", "for-tests-ggml-parakeet-tdt.bin")) +TEST_PARAKEET_MODEL_DIR = TEST_PARAKEET_MODEL.pathmap("%d") +directory TEST_PARAKEET_MODEL_DIR +if File.exist? TEST_PARAKEET_MODEL_SRC + file TEST_PARAKEET_MODEL => [TEST_PARAKEET_MODEL_SRC, TEST_PARAKEET_MODEL_DIR] do |t| + symlink t.source, t.name + end +else + require "open-uri" + file TEST_PARAKEET_MODEL => TEST_PARAKEET_MODEL_DIR do |t| + File.write t.name, URI("https://github.com/ggml-org/whisper.cpp/raw/refs/heads/master/models/for-tests-ggml-parakeet-tdt.bin").read + end +end + TEST_MEMORY_VIEW = "test/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}" file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t| chdir "test/jfk_reader" do @@ -93,4 +108,4 @@ file TEST_MEMORY_VIEW => "test/jfk_reader/jfk_reader.c" do |t| end CLEAN.include TEST_MEMORY_VIEW -task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO] +task test: [LIB_FILE, TEST_MEMORY_VIEW, TEST_FIXTURE_AUDIO, TEST_PARAKEET_MODEL] diff --git a/bindings/ruby/ext/dependencies.rb b/bindings/ruby/ext/dependencies.rb index 2ba4b94b62b..b2eb9beb84f 100644 --- a/bindings/ruby/ext/dependencies.rb +++ b/bindings/ruby/ext/dependencies.rb @@ -22,13 +22,17 @@ def libs else nil end - }.reverse.collect {|lib| "lib#{lib}.a"} + }.reverse.collect {|lib| "#{prefix(lib)}#{lib}.#{RbConfig::CONFIG['LIBEXT']}"} end def to_s libs.join(" ") end + def local_libs + to_s + end + private def dot_path @@ -36,9 +40,7 @@ def dot_path end def generate_dot - args = ["-S", "sources", "-B", "build", "--graphviz", dot_path, "-D", "BUILD_SHARED_LIBS=OFF"] - args << @options.to_s unless @options.to_s.empty? - system @cmake, *args, exception: true + system @cmake, "-S", "sources", "-B", "build", *@options.graphviz_cmake_args, "--graphviz", dot_path, *@options, exception: true end def parse_dot @@ -59,6 +61,10 @@ def parse_dot end end + def prefix(lib) + "lib" + end + def tsort_each_node @nodes.each_key do |node| yield node diff --git a/bindings/ruby/ext/dependencies_for_windows.rb b/bindings/ruby/ext/dependencies_for_windows.rb new file mode 100644 index 00000000000..5574107182d --- /dev/null +++ b/bindings/ruby/ext/dependencies_for_windows.rb @@ -0,0 +1,17 @@ +require_relative "dependencies" + +class DependenciesForWindows < Dependencies + def local_libs + libs.collect {|lib| %|"#{lib_path(lib)}"|}.join(" ") + end + + private + + def prefix(lib) + lib.start_with?("ggml") ? "" : "lib" + end + + def lib_path(lib) + File.join(__dir__, lib).tr("\\", "/") + end +end diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 8a5ac67457b..99894f1234d 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -1,14 +1,27 @@ require "mkmf" -require_relative "options" -require_relative "dependencies" + +if RUBY_PLATFORM.match? /mswin|mingw|ucrt/ + require_relative "options_for_windows" + require_relative "dependencies_for_windows" + + Opts = OptionsForWindows + Deps = DependenciesForWindows +else + require_relative "options" + require_relative "dependencies" + + Opts = Options + Deps = Dependencies +end cmake = find_executable("cmake") || abort -options = Options.new(cmake).to_s +options = Opts.new(cmake) have_library("gomp") rescue nil -libs = Dependencies.new(cmake, options).to_s +libs = Deps.new(cmake, options) +append_cflags ["-O3", "-march=native"] $INCFLAGS << " -Isources/include -Isources/ggml/include -Isources/examples" -$LOCAL_LIBS << " #{libs}" +$LOCAL_LIBS << " #{libs.local_libs}" $cleanfiles << " build #{libs}" create_makefile "whisper" do |conf| @@ -16,7 +29,7 @@ $(TARGET_SO): #{libs} #{libs}: cmake-targets cmake-targets: - #{"\t"}#{cmake} -S sources -B build -D BUILD_SHARED_LIBS=OFF -D CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__} -D CMAKE_POSITION_INDEPENDENT_CODE=ON #{options} - #{"\t"}#{cmake} --build build --config Release --target common whisper + #{"\t"}"#{cmake}" -S sources -B build #{options} + #{"\t"}"#{cmake}" --build build --config Release --target common whisper parakeet EOF end diff --git a/bindings/ruby/ext/options.rb b/bindings/ruby/ext/options.rb index ede80c0656b..e723af9fd9a 100644 --- a/bindings/ruby/ext/options.rb +++ b/bindings/ruby/ext/options.rb @@ -1,26 +1,36 @@ +require "fileutils" + class Options def initialize(cmake="cmake") @cmake = cmake @options = {} configure + write_cache_file + end + + def to_a + [ + "-D", "BUILD_SHARED_LIBS=OFF", + "-D", "WHISPER_BUILD_TESTS=OFF", + "-D", "CMAKE_ARCHIVE_OUTPUT_DIRECTORY=#{__dir__}", + "-D", "CMAKE_POSITION_INDEPENDENT_CODE=ON", + "-C", cache_path + ] end def to_s - @options - .reject {|name, (type, value)| value.nil?} - .collect {|name, (type, value)| "-D #{name}=#{value == true ? "ON" : value == false ? "OFF" : value.shellescape}"} - .join(" ") + command_line(*to_a) end - def cmake_options - return @cmake_options if @cmake_options + def graphviz_cmake_args + [] + end - output = nil - Dir.chdir __dir__ do - output = `#{@cmake.shellescape} -S sources -B build -L` - end - @cmake_options = output.lines.drop_while {|line| line.chomp != "-- Cache values"}.drop(1) + private + + def cmake_options + @cmake_options ||= cmake_options_output.lines.drop_while {|line| line.chomp != "-- Cache values"}.drop(1) .filter_map {|line| option, value = line.chomp.split("=", 2) name, type = option.split(":", 2) @@ -34,7 +44,11 @@ def cmake_options }.to_h end - private + def cmake_options_output + Dir.chdir(__dir__) do + IO.popen([@cmake, "-S", "sources", "-B", "build", "-L"]) {|io| io.read} + end + end def configure cmake_options.each_pair do |name, (type, default_value)| @@ -74,12 +88,38 @@ def option_name(name) def enabled?(option) op = @options[option] - raise "Option not exist: #{option}" unless op - raise "Option not boolean: #{option}(#{op[0]})" unless op[0] == "BOOL" + return false unless op + return false unless op[0] == "BOOL" if op[1].nil? cmake_options[option][1] else op[1] end end + + def cache_path + File.join(__dir__, "sources", "Options.cmake") + end + + def write_cache_file + FileUtils.mkpath File.dirname(cache_path) + File.open cache_path, "w" do |file| + @options.reject {|name, (type, value)| value.nil?}.each do |name, (type, value)| + line = "set(CACHE{%<name>s} TYPE %<type>s FORCE VALUE %<value>s)" % { + name:, + type:, + value: value == true ? "ON" : value == false ? "OFF" : escape_cmake(value) + } + file.puts line + end + end + end + + def escape_cmake(str) + str.gsub(/[\\"]/, '\\\\\&') + end + + def command_line(*args) + args.collect {|arg| %|"#{arg.to_s.gsub(/[\\"]/, '\\\\\&')}"|}.join(" ") + end end diff --git a/bindings/ruby/ext/options_for_windows.rb b/bindings/ruby/ext/options_for_windows.rb new file mode 100644 index 00000000000..7db785d8a2d --- /dev/null +++ b/bindings/ruby/ext/options_for_windows.rb @@ -0,0 +1,51 @@ +require_relative "options" + +class OptionsForWindows < Options + def to_s + command_line(*generator_args, *to_a) + end + + def graphviz_cmake_args + generator_args + end + + private + + def arm? + RbConfig::CONFIG["host_cpu"].to_s.downcase.match?(/\A(?:arm64|aarch64)\z/) + end + + def cmake_options_output + Dir.chdir(__dir__) do + IO.popen([@cmake, "-S", "sources", "-B", "build", *generator_args, "-L"]) {|io| io.read} + end + end + + def generator_args + generator = cmake_generator + ["-G", generator] if generator && !generator.empty? + end + + def cmake_generator + return @cmake_generator if defined?(@cmake_generator) + + generator = ENV["CMAKE_GENERATOR"] + abort "CMAKE_GENERATOR=#{generator} is unsupported for mingw/ucrt Ruby" if visual_studio_generator_name?(generator) + return @cmake_generator = generator unless generator.nil? || generator.empty? + + ninja = find_executable("ninja") + return @cmake_generator = "Ninja" if ninja + + make = find_executable("make") + return @cmake_generator = "MSYS Makefiles" if make + + mingw32_make = find_executable("mingw32-make") + return @cmake_generator = "MinGW Makefiles" if mingw32_make + + @cmake_generator = nil + end + + def visual_studio_generator_name?(generator) + generator && generator.start_with?("Visual Studio") + end +end diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c index ac677e9e3df..7941b1a99dd 100644 --- a/bindings/ruby/ext/ruby_whisper.c +++ b/bindings/ruby/ext/ruby_whisper.c @@ -1,21 +1,29 @@ -#include <ruby.h> -#include <ruby/memory_view.h> #include "ruby_whisper.h" VALUE mWhisper; +VALUE mLogSettable; VALUE mVAD; +VALUE mParakeet; VALUE cContext; VALUE cParams; VALUE cVADContext; VALUE cVADParams; VALUE cVADSegments; VALUE cVADSegment; +VALUE cParakeetContext; +VALUE cParakeetContextParams; +VALUE cParakeetParams; +VALUE cParakeetSegment; +VALUE cParakeetModel; VALUE eError; VALUE cSegment; VALUE cToken; VALUE cModel; +VALUE mOutputContext; +VALUE mOutputSegment; + ID id_to_s; ID id_call; ID id___method__; @@ -29,13 +37,17 @@ ID id_pre_converted_models; ID id_coreml_compiled_models; ID id_cache; ID id_n_processors; - -static bool is_log_callback_finalized = false; +ID id_extended; +ID id_start_log_callback_thread; +ID id_log_callback_thread; +ID id_alive_p; +ID id_join; // High level API extern VALUE ruby_whisper_segment_allocate(VALUE klass); -extern void init_ruby_whisper_context(VALUE *mWhisper); +extern VALUE init_ruby_whisper_context(VALUE *mWhisper); +extern void init_ruby_whisper_context_params(VALUE *cContext); extern void init_ruby_whisper_params(VALUE *mWhisper); extern void init_ruby_whisper_error(VALUE *mWhisper); extern void init_ruby_whisper_segment(VALUE *mWhisper); @@ -45,8 +57,13 @@ extern void init_ruby_whisper_vad_params(VALUE *mVAD); extern void init_ruby_whisper_vad_context(VALUE *mVAD); extern void init_ruby_whisper_vad_segment(VALUE *mVAD); extern void init_ruby_whisper_vad_segments(VALUE *mVAD); +extern void init_ruby_whisper_parakeet(VALUE *mWhisper); extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context); +static ruby_whisper_log_queue whisper_log_queue; + +LOG_SETTABLE_SETUP(whisper_log_queue, mWhisper, whisper_log_set) + /* * call-seq: * lang_max_id -> Integer @@ -102,42 +119,6 @@ static VALUE ruby_whisper_s_system_info_str(VALUE self) { return rb_str_new2(whisper_print_system_info()); } -static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { - is_log_callback_finalized = true; - return Qnil; -} - -static void -ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) { - if (is_log_callback_finalized) { - return; - } - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - VALUE udata = rb_iv_get(mWhisper, "user_data"); - rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata); -} - -/* - * call-seq: - * log_set ->(level, buffer, user_data) { ... }, user_data -> nil - */ -static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) { - VALUE old_callback = rb_iv_get(self, "log_callback"); - if (!NIL_P(old_callback)) { - rb_undefine_finalizer(old_callback); - } - - rb_iv_set(self, "log_callback", log_callback); - rb_iv_set(self, "user_data", user_data); - - VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); - rb_define_finalizer(log_callback, finalize_log_callback); - - whisper_log_set(ruby_whisper_log_callback, NULL); - - return Qnil; -} - void Init_whisper() { id_to_s = rb_intern("to_s"); id_call = rb_intern("call"); @@ -152,9 +133,19 @@ void Init_whisper() { id_coreml_compiled_models = rb_intern("coreml_compiled_models"); id_cache = rb_intern("cache"); id_n_processors = rb_intern("n_processors"); + id_extended = rb_intern("extended"); + id_start_log_callback_thread = rb_intern("start_log_callback_thread"); + id_log_callback_thread = rb_intern("@log_callback_thread"); + id_alive_p = rb_intern("alive?"); + id_join = rb_intern("join"); mWhisper = rb_define_module("Whisper"); + rb_require("whisper/log_settable"); + mLogSettable = rb_path2class("Whisper::LogSettable"); mVAD = rb_define_module_under(mWhisper, "VAD"); + rb_require("whisper/output"); + mOutputContext = rb_path2class("Whisper::Output::Context"); + mOutputSegment = rb_path2class("Whisper::Output::Segment"); rb_define_const(mWhisper, "VERSION", rb_str_new2(whisper_version())); rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE)); @@ -164,15 +155,32 @@ void Init_whisper() { rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG)); rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT)); + rb_define_const(mWhisper, "AHEADS_NONE", INT2NUM(WHISPER_AHEADS_NONE)); + rb_define_const(mWhisper, "AHEADS_N_TOP_MOST", INT2NUM(WHISPER_AHEADS_N_TOP_MOST)); + rb_define_const(mWhisper, "AHEADS_CUSTOM", INT2NUM(WHISPER_AHEADS_CUSTOM)); + rb_define_const(mWhisper, "AHEADS_TINY_EN", INT2NUM(WHISPER_AHEADS_TINY_EN)); + rb_define_const(mWhisper, "AHEADS_TINY", INT2NUM(WHISPER_AHEADS_TINY)); + rb_define_const(mWhisper, "AHEADS_BASE_EN", INT2NUM(WHISPER_AHEADS_BASE_EN)); + rb_define_const(mWhisper, "AHEADS_BASE", INT2NUM(WHISPER_AHEADS_BASE)); + rb_define_const(mWhisper, "AHEADS_SMALL_EN", INT2NUM(WHISPER_AHEADS_SMALL_EN)); + rb_define_const(mWhisper, "AHEADS_SMALL", INT2NUM(WHISPER_AHEADS_SMALL)); + rb_define_const(mWhisper, "AHEADS_MEDIUM_EN", INT2NUM(WHISPER_AHEADS_MEDIUM_EN)); + rb_define_const(mWhisper, "AHEADS_MEDIUM", INT2NUM(WHISPER_AHEADS_MEDIUM)); + rb_define_const(mWhisper, "AHEADS_LARGE_V1", INT2NUM(WHISPER_AHEADS_LARGE_V1)); + rb_define_const(mWhisper, "AHEADS_LARGE_V2", INT2NUM(WHISPER_AHEADS_LARGE_V2)); + rb_define_const(mWhisper, "AHEADS_LARGE_V3", INT2NUM(WHISPER_AHEADS_LARGE_V3)); + rb_define_const(mWhisper, "AHEADS_LARGE_V3_TURBO", INT2NUM(WHISPER_AHEADS_LARGE_V3_TURBO)); + rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0); rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1); rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); rb_define_singleton_method(mWhisper, "system_info_str", ruby_whisper_s_system_info_str, 0); - rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); - rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); - init_ruby_whisper_context(&mWhisper); + LOG_SETTABLE_INIT(whisper_log_queue, mWhisper) + + cContext = init_ruby_whisper_context(&mWhisper); + init_ruby_whisper_context_params(&cContext); init_ruby_whisper_params(&mWhisper); init_ruby_whisper_error(&mWhisper); init_ruby_whisper_segment(&mWhisper); @@ -182,8 +190,10 @@ void Init_whisper() { init_ruby_whisper_vad_segment(&mVAD); init_ruby_whisper_vad_segments(&mVAD); init_ruby_whisper_vad_context(&mVAD); + init_ruby_whisper_parakeet(&mWhisper); - rb_require("whisper/context"); - rb_require("whisper/segment"); rb_require("whisper/model/uri"); + + rb_include_module(cContext, mOutputContext); + rb_include_module(cSegment, mOutputSegment); } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 3f5660c374d..10e90674953 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -1,7 +1,21 @@ #ifndef RUBY_WHISPER_H #define RUBY_WHISPER_H +#include <ruby.h> +#include <ruby/version.h> +#include <ruby/util.h> +#include <ruby/thread.h> +#include <ruby/thread_native.h> +#include <ruby/atomic.h> +#include <ruby/memory_view.h> #include "whisper.h" +#include "parakeet.h" +#include "ruby_whisper_log_settable.h" + +#if RUBY_API_VERSION_MAJOR < 4 +// Exists but not declared as public API +int ruby_thread_has_gvl_p(void); +#endif typedef struct { VALUE *context; @@ -10,10 +24,37 @@ typedef struct { VALUE callbacks; } ruby_whisper_callback_container; +typedef struct ruby_whisper_abort_callback_user_data { + volatile rb_atomic_t is_interrupted; + ruby_whisper_callback_container *callback_container; +} ruby_whisper_abort_callback_user_data; + +typedef struct ruby_whisper_log { + enum ggml_log_level level; + char *text; + size_t length; + size_t capacity; +} ruby_whisper_log; + +typedef struct ruby_whisper_log_queue { + rb_nativethread_lock_t lock; + rb_nativethread_cond_t cond; + bool is_open; + + size_t head; + size_t tail; + size_t size; + ruby_whisper_log *logs; +} ruby_whisper_log_queue; + typedef struct { struct whisper_context *context; } ruby_whisper; +typedef struct ruby_whisper_context_params { + struct whisper_context_params params; +} ruby_whisper_context_params; + typedef struct { struct whisper_full_params params; bool diarize; @@ -35,7 +76,7 @@ typedef struct { typedef struct { whisper_token_data *token_data; - const char *text; + VALUE text; } ruby_whisper_token; typedef struct { @@ -55,6 +96,70 @@ typedef struct { struct whisper_vad_context *context; } ruby_whisper_vad_context; +typedef struct parsed_samples_t { + float *samples; + int n_samples; + rb_memory_view_t memview; + bool memview_exported; +} parsed_samples_t; + +typedef struct { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} ruby_whisper_full_args; + +typedef struct ruby_whisper_full_parallel_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; + int n_processors; +} ruby_whisper_full_parallel_args; + +typedef struct { + struct parakeet_full_params params; + ruby_whisper_callback_container *new_segment_callback_container; + ruby_whisper_callback_container *new_token_callback_container; + ruby_whisper_callback_container *progress_callback_container; + ruby_whisper_callback_container *encoder_begin_callback_container; + ruby_whisper_callback_container *abort_callback_container; +} ruby_whisper_parakeet_params; + +typedef struct { + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_params; + +typedef struct { + struct parakeet_context *context; +} ruby_whisper_parakeet_context; + +typedef struct { + VALUE context; + int index; +} ruby_whisper_parakeet_segment; + +typedef struct { + parakeet_token_data *token_data; + VALUE text; +} ruby_whisper_parakeet_token; + +typedef struct { + VALUE context; +} ruby_whisper_parakeet_model; + +extern ID id_extended; +extern ID id_log_callback_thread; +extern ID id_start_log_callback_thread; +extern ID id_alive_p; +extern ID id_join; +extern void ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue); +extern void ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text); +extern VALUE ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue); + #define GetContext(obj, rw) do { \ TypedData_Get_Struct((obj), ruby_whisper, &ruby_whisper_type, (rw)); \ if ((rw)->context == NULL) { \ @@ -62,13 +167,28 @@ typedef struct { } \ } while (0) -#define GetToken(obj, rwt) do { \ +#define GetContextParams(obj, rwcp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_context_params, &ruby_whisper_context_params_type, (rwcp)); \ +} while (0) + +#define GetToken(obj, rwt) do { \ TypedData_Get_Struct((obj), ruby_whisper_token, &ruby_whisper_token_type, (rwt)); \ if ((rwt)->token_data == NULL) { \ rb_raise(rb_eRuntimeError, "Not initialized"); \ } \ } while (0) +#define GetVADContext(obj, rwvc) do { \ + TypedData_Get_Struct((obj), ruby_whisper_vad_context, &ruby_whisper_vad_context_type, (rwvc)); \ + if ((rwvc)->context == NULL) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetVADParams(obj, rwvp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_vad_params, &ruby_whisper_vad_params_type, (rwvp)); \ +} while (0) + #define GetVADSegments(obj, rwvss) do { \ TypedData_Get_Struct((obj), ruby_whisper_vad_segments, &ruby_whisper_vad_segments_type, (rwvss)); \ if ((rwvss)->segments == NULL) { \ @@ -76,4 +196,47 @@ typedef struct { } \ } while (0) +#define GetParakeetContextParams(obj, rwpcp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, (rwpcp)); \ +} while (0) + +#define GetParakeetContext(obj, rwpc) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, (rwpc)); \ + if ((rwpc)->context == NULL) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetParams(obj, rwpp) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, (rwpp)); \ + if (!(rwpp)->new_segment_callback_container || \ + !(rwpp)->new_token_callback_container || \ + !(rwpp)->progress_callback_container || \ + !(rwpp)->encoder_begin_callback_container || \ + !(rwpp)->abort_callback_container) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetSegment(obj, rwps) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, (rwps)); \ + if (!(rwps)->context) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetToken(obj, rwpt) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, (rwpt)); \ + if (!(rwpt)->token_data) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + +#define GetParakeetModel(obj, rwpm) do { \ + TypedData_Get_Struct((obj), ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, (rwpm)); \ + if (NIL_P((rwpm)->context)) { \ + rb_raise(rb_eRuntimeError, "Not initialized"); \ + } \ +} while (0) + #endif diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c index a7b5f8513db..9e5fc33e726 100644 --- a/bindings/ruby/ext/ruby_whisper_context.c +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -1,7 +1,11 @@ -#include <ruby.h> -#include <ruby/memory_view.h> #include "ruby_whisper.h" +#ifdef WORDS_BIGENDIAN + #define IS_BIGENDIAN true +#else + #define IS_BIGENDIAN false +#endif + extern ID id_to_s; extern ID id___method__; extern ID id_to_enum; @@ -20,13 +24,41 @@ extern VALUE eError; extern VALUE cModel; extern const rb_data_type_t ruby_whisper_params_type; +extern const rb_data_type_t ruby_whisper_context_params_type; extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); extern VALUE rb_whisper_model_s_new(VALUE context); extern VALUE rb_whisper_segment_s_new(VALUE context, int index); -extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context); +extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data); ID transcribe_option_names[1]; +typedef struct fill_samples_args { + float *dest; + VALUE *src; + int n_samples; +} fill_samples_args; + +typedef struct full_without_gvl_args { + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + int n_samples; + int result; +} full_without_gvl_args; + +typedef struct full_parallel_without_gvl_args { + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + int n_samples; + int n_processors; + int result; +} full_parallel_without_gvl_args; + +typedef struct full_ubf_args { + ruby_whisper_abort_callback_user_data *abort_callback_user_data; +} full_ubf_args; + static void ruby_whisper_free(ruby_whisper *rw) { @@ -54,7 +86,7 @@ static size_t ruby_whisper_memsize(const void *p) { const ruby_whisper *rw = (const ruby_whisper *)p; - size_t size = sizeof(rw); + size_t size = sizeof(*rw); if (!rw) { return 0; } @@ -124,16 +156,25 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { ruby_whisper *rw; VALUE whisper_model_file_path; + VALUE context_params; + struct whisper_context_params params; // TODO: we can support init from buffer here too maybe another ruby object to expose - rb_scan_args(argc, argv, "01", &whisper_model_file_path); + rb_scan_args(argc, argv, "11", &whisper_model_file_path, &context_params); TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw); whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path); if (!rb_respond_to(whisper_model_file_path, id_to_s)) { rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); } - rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params()); + if (NIL_P(context_params)) { + params = whisper_context_default_params(); + } else { + ruby_whisper_context_params *rwcp; + GetContextParams(context_params, rwcp); + params = rwcp->params; + } + rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), params); if (rw->context == NULL) { rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); } @@ -272,82 +313,219 @@ VALUE ruby_whisper_model_type(VALUE self) return rb_str_new2(whisper_model_type_readable(rw->context)); } -/* - * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text - * Not thread safe for same context - * Uses the specified decoding strategy to obtain the text. - * - * call-seq: - * full(params, samples, n_samples) -> nil - * full(params, samples) -> nil - * - * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. - */ -VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) +static bool +check_memory_view(rb_memory_view_t *memview) { - if (argc < 2 || argc > 3) { - rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + if (!memview->format) { + rb_warn("currently format is required"); + return false; } - ruby_whisper *rw; - ruby_whisper_params *rwp; - GetContext(self, rw); - VALUE params = argv[0]; - TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - VALUE samples = argv[1]; - int n_samples; - rb_memory_view_t view; - const bool memory_view_available_p = rb_memory_view_available_p(samples); - if (argc == 3) { - n_samples = NUM2INT(argv[2]); - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) < n_samples) { - rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); + if (strcmp(memview->format, "f") == 0) { + // accept + } else if (strcmp(memview->format, "e") == 0) { + if (IS_BIGENDIAN) { + rb_warn("currently format \"e\" is only supported on little-endian environment"); + return false; + } + } else { + rb_warn("currently only format \"f\" and \"e\" on little-endian environment is supported for MemoryView, but given: %s", memview->format); + return false; + } + + if (memview->ndim != 1 && !(memview->ndim == 2 && memview->shape[1] == 1)) { + // TODO: Accept ndim == 2 with shape [n_samples, channels] and channels > 1 by averaging the samples in different channels or just taking the first channel + rb_warn("currently only 1 dimensional MemoryView is supported, but given: %zd", memview->ndim); + return false; + } + + return true; +} + +static VALUE +fill_samples(VALUE rb_args) +{ + fill_samples_args *args = (fill_samples_args *)rb_args; + + if (RB_TYPE_P(*args->src, T_ARRAY)) { + for (int i = 0; i < args->n_samples; i++) { + args->dest[i] = RFLOAT_VALUE(rb_ary_entry(*args->src, i)); + } + } else { + // TODO: use rb_block_call + VALUE iter = rb_funcall(*args->src, id_to_enum, 1, rb_str_new2("each")); + for (int i = 0; i < args->n_samples; i++) { + // TODO: check if iter is exhausted and raise ArgumentError appropriately + VALUE sample = rb_funcall(iter, id_next, 0); + args->dest[i] = RFLOAT_VALUE(sample); + } + } + + return Qnil; +} + +parsed_samples_t +parse_samples(VALUE *samples, VALUE *n_samples) +{ + bool memview_available = rb_memory_view_available_p(*samples); + struct parsed_samples_t parsed = {0}; + parsed.memview_exported = false; + const bool is_array = RB_TYPE_P(*samples, T_ARRAY); + + if (!NIL_P(*n_samples)) { + parsed.n_samples = NUM2INT(*n_samples); + if (is_array) { + if (RARRAY_LEN(*samples) < parsed.n_samples) { + rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(*samples), parsed.n_samples); } } // Should check when samples.respond_to?(:length)? } else { - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) > INT_MAX) { + if (is_array) { + if (RARRAY_LEN(*samples) > INT_MAX) { rb_raise(rb_eArgError, "samples are too long"); } - n_samples = (int)RARRAY_LEN(samples); - } else if (memory_view_available_p) { - if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { - view.obj = Qnil; - rb_raise(rb_eArgError, "unable to get a memory view"); + parsed.n_samples = (int)RARRAY_LEN(*samples); + } else if (memview_available) { + bool memview_got = rb_memory_view_get(*samples, &parsed.memview, RUBY_MEMORY_VIEW_SIMPLE); + if (memview_got) { + parsed.memview_exported = check_memory_view(&parsed.memview); + if (!parsed.memview_exported) { + rb_memory_view_release(&parsed.memview); + parsed.memview = (rb_memory_view_t){0}; + } } - ssize_t n_samples_size = view.byte_size / view.item_size; - if (n_samples_size > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); + if (parsed.memview_exported) { + ssize_t n_samples_size = parsed.memview.byte_size / parsed.memview.item_size; + if (n_samples_size > INT_MAX) { + rb_memory_view_release(&parsed.memview); + rb_raise(rb_eArgError, "samples are too long: %zd", n_samples_size); + } + parsed.n_samples = (int)n_samples_size; + } else { + rb_warn("unable to get a memory view. falls back to Ruby object"); + if (rb_respond_to(*samples, id_length)) { + parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0)); + } else { + rb_raise(rb_eArgError, "samples must respond to :length"); + } } - n_samples = (int)n_samples_size; - } else if (rb_respond_to(samples, id_length)) { - n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); + } else if (rb_respond_to(*samples, id_length)) { + parsed.n_samples = NUM2INT(rb_funcall(*samples, id_length, 0)); } else { - rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); + rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of float when n_samples is not given"); } } - float * c_samples = (float *)malloc(n_samples * sizeof(float)); - if (memory_view_available_p) { - c_samples = (float *)view.data; + + if (parsed.memview_exported) { + parsed.samples = (float *)parsed.memview.data; } else { - if (TYPE(samples) == T_ARRAY) { - for (int i = 0; i < n_samples; i++) { - c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); - } - } else { - // TODO: use rb_block_call - VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); - for (int i = 0; i < n_samples; i++) { - // TODO: check if iter is exhausted and raise ArgumentError appropriately - VALUE sample = rb_funcall(iter, id_next, 0); - c_samples[i] = RFLOAT_VALUE(sample); - } + parsed.samples = ALLOC_N(float, parsed.n_samples); + fill_samples_args args = { + parsed.samples, + samples, + parsed.n_samples, + }; + int state; + rb_protect(fill_samples, (VALUE)&args, &state); + if (state) { + xfree(parsed.samples); + rb_jump_tag(state); } } - prepare_transcription(rwp, &self); - const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples); + + return parsed; +} + +VALUE +release_samples(VALUE rb_parsed_args) +{ + parsed_samples_t *parsed_args = (parsed_samples_t *)rb_parsed_args; + + if (parsed_args->memview_exported) { + rb_memory_view_release(&parsed_args->memview); + } else { + xfree(parsed_args->samples); + } + *parsed_args = (parsed_samples_t){0}; + + return Qnil; +} + +static void* +full_without_gvl(void *rb_args) +{ + full_without_gvl_args *args = (full_without_gvl_args *)rb_args; + args->result = whisper_full(args->context, *args->params, args->samples, args->n_samples); + return NULL; +} + +static void +full_ubf(void *rb_args) +{ + full_ubf_args *args = (full_ubf_args *)rb_args; + + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); +} + +VALUE +full_body(VALUE rb_args) +{ + ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args; + + ruby_whisper *rw; + ruby_whisper_params *rwp; + GetContext(*args->context, rw); + TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); + + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, 1, &abort_callback_user_data); + + struct full_without_gvl_args full_without_gvl_args = { + rw->context, + &rwp->params, + args->samples, + args->n_samples, + 0, + }; + full_ubf_args full_ubf_args = { + &abort_callback_user_data, + }; + rb_thread_call_without_gvl(full_without_gvl, (void *)&full_without_gvl_args, full_ubf, (void *)&full_ubf_args); + return INT2NUM(full_without_gvl_args.result); +} + +/* + * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + * Not thread safe for same context + * Uses the specified decoding strategy to obtain the text. + * + * call-seq: + * full(params, samples, n_samples) -> nil + * full(params, samples) -> nil + * + * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + */ +VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + ruby_whisper_full_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE rb_result = rb_ensure(full_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); if (0 == result) { return self; } else { @@ -355,6 +533,45 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) } } +static void* +full_parallel_without_gvl(void *rb_args) +{ + full_parallel_without_gvl_args *args = (full_parallel_without_gvl_args *)rb_args; + args->result = whisper_full_parallel(args->context, *args->params, args->samples, args->n_samples, args->n_processors); + return NULL; +} + +VALUE +full_parallel_body(VALUE rb_args) +{ + ruby_whisper_full_parallel_args *args = (ruby_whisper_full_parallel_args *)rb_args; + + ruby_whisper *rw; + ruby_whisper_params *rwp; + GetContext(*args->context, rw); + TypedData_Get_Struct(*args->params, ruby_whisper_params, &ruby_whisper_params_type, rwp); + + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + prepare_transcription(rwp, args->context, args->n_processors, &abort_callback_user_data); + + struct full_parallel_without_gvl_args full_parallel_without_gvl_args = { + rw->context, + &rwp->params, + args->samples, + args->n_samples, + args->n_processors, + 0, + }; + full_ubf_args full_ubf_args = { + &abort_callback_user_data, + }; + rb_thread_call_without_gvl(full_parallel_without_gvl, (void *)&full_parallel_without_gvl_args, full_ubf, (void *)&full_ubf_args); + return INT2NUM(full_parallel_without_gvl_args.result); +} + /* * Split the input audio in chunks and process each chunk separately using whisper_full_with_state() * Result is stored in the default state of the context @@ -372,19 +589,11 @@ static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) { if (argc < 2 || argc > 4) { - rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..4)", argc); } - ruby_whisper *rw; - ruby_whisper_params *rwp; - GetContext(self, rw); - VALUE params = argv[0]; - TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp); - VALUE samples = argv[1]; - int n_samples; + VALUE n_samples = argc == 2 ? Qnil : argv[2]; int n_processors; - rb_memory_view_t view; - const bool memory_view_available_p = rb_memory_view_available_p(samples); switch (argc) { case 2: n_processors = 1; @@ -396,56 +605,16 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) n_processors = NUM2INT(argv[3]); break; } - if (argc >= 3 && !NIL_P(argv[2])) { - n_samples = NUM2INT(argv[2]); - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) < n_samples) { - rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); - } - } - // Should check when samples.respond_to?(:length)? - } else if (memory_view_available_p) { - if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { - view.obj = Qnil; - rb_raise(rb_eArgError, "unable to get a memory view"); - } - ssize_t n_samples_size = view.byte_size / view.item_size; - if (n_samples_size > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); - } - n_samples = (int)n_samples_size; - } else { - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) > INT_MAX) { - rb_raise(rb_eArgError, "samples are too long"); - } - n_samples = (int)RARRAY_LEN(samples); - } else if (rb_respond_to(samples, id_length)) { - n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); - } else { - rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); - } - } - float * c_samples = (float *)malloc(n_samples * sizeof(float)); - if (memory_view_available_p) { - c_samples = (float *)view.data; - } else { - if (TYPE(samples) == T_ARRAY) { - for (int i = 0; i < n_samples; i++) { - c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); - } - } else { - // FIXME: use rb_block_call - VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); - for (int i = 0; i < n_samples; i++) { - // TODO: check if iter is exhausted and raise ArgumentError - VALUE sample = rb_funcall(iter, id_next, 0); - c_samples[i] = RFLOAT_VALUE(sample); - } - } - } - prepare_transcription(rwp, &self); - const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors); + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + const ruby_whisper_full_parallel_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + n_processors, + }; + const VALUE rb_result = rb_ensure(full_parallel_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); if (0 == result) { return self; } else { @@ -631,7 +800,7 @@ ruby_whisper_get_model(VALUE self) return rb_whisper_model_s_new(self); } -void +VALUE init_ruby_whisper_context(VALUE *mWhisper) { cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject); @@ -669,4 +838,6 @@ init_ruby_whisper_context(VALUE *mWhisper) rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0); rb_define_method(cContext, "model", ruby_whisper_get_model, 0); + + return cContext; } diff --git a/bindings/ruby/ext/ruby_whisper_context_params.c b/bindings/ruby/ext/ruby_whisper_context_params.c new file mode 100644 index 00000000000..87df21d4b5e --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_context_params.c @@ -0,0 +1,163 @@ +#include "ruby_whisper.h" + +#define NUM_PARAMS 6 + +#define DEF_BOOLEAN_ATTR_METHOD(name) \ +static VALUE \ +ruby_whisper_context_params_get_ ## name(VALUE self) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + return rwcp->params.name ? Qtrue : Qfalse; \ +} \ +static VALUE \ +ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + rwcp->params.name = RTEST(value); \ + return value; \ +} + +#define DEF_INT_ATTR_METHOD(name) \ +static VALUE \ +ruby_whisper_context_params_get_ ## name(VALUE self) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + return INT2NUM(rwcp->params.name); \ +} \ +static VALUE \ +ruby_whisper_context_params_set_ ## name(VALUE self, VALUE value) { \ + ruby_whisper_context_params *rwcp; \ + GetContextParams(self, rwcp); \ + rwcp->params.name = NUM2INT(value); \ + return value; \ +} + +#define DEFINE_PARAM(param_name, nth) \ + id_ ## param_name = rb_intern(#param_name); \ + param_names[nth] = id_ ## param_name; \ + rb_define_method(cContextParams, #param_name, ruby_whisper_context_params_get_ ## param_name, 0); \ + rb_define_method(cContextParams, #param_name "=", ruby_whisper_context_params_set_ ## param_name, 1); + +VALUE cContextParams; + +static ID param_names[NUM_PARAMS]; +static ID id_use_gpu; +static ID id_flash_attn; +static ID id_gpu_device; +static ID id_dtw_token_timestamps; +static ID id_dtw_aheads_preset; +static ID id_dtw_n_top; + +static size_t +ruby_whisper_context_params_memsize(const void *p) +{ + const ruby_whisper_context_params *rwcp = (ruby_whisper_context_params *)p; + if (!rwcp) { + return 0; + } + return sizeof(ruby_whisper_context_params); +} + +const rb_data_type_t ruby_whisper_context_params_type = { + "ruby_whisper_context_params", + {0, RUBY_DEFAULT_FREE, ruby_whisper_context_params_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_context_params_s_allocate(VALUE klass) +{ + ruby_whisper_context_params *rwcp; + return TypedData_Make_Struct(klass, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp); +} + +DEF_BOOLEAN_ATTR_METHOD(use_gpu); +DEF_BOOLEAN_ATTR_METHOD(flash_attn); +DEF_INT_ATTR_METHOD(gpu_device); +DEF_BOOLEAN_ATTR_METHOD(dtw_token_timestamps); +DEF_INT_ATTR_METHOD(dtw_aheads_preset); + +static VALUE +ruby_whisper_context_params_get_dtw_n_top(VALUE self) { + ruby_whisper_context_params *rwcp; + GetContextParams(self, rwcp); + + int dtw_n_top = rwcp->params.dtw_n_top; + + return dtw_n_top == -1 ? Qnil : INT2NUM(dtw_n_top); +} + +static VALUE +ruby_whisper_context_params_set_dtw_n_top(VALUE self, VALUE value) { + ruby_whisper_context_params *rwcp; + GetContextParams(self, rwcp); + + rwcp->params.dtw_n_top = NIL_P(value) ? -1 : NUM2INT(value); + + return value; +} + +#define SET_PARAM_IF_SAME(param_name) \ + if (id == id_ ## param_name) { \ + ruby_whisper_context_params_set_ ## param_name(self, value); \ + continue; \ + } + +static VALUE +ruby_whisper_context_params_initialize(int argc, VALUE *argv, VALUE self) +{ + ruby_whisper_context_params *rwcp; + TypedData_Get_Struct(self, ruby_whisper_context_params, &ruby_whisper_context_params_type, rwcp); + rwcp->params = whisper_context_default_params(); + + VALUE kw_hash; + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + VALUE values[NUM_PARAMS] = {Qundef}; + rb_get_kwargs(kw_hash, param_names, 0, NUM_PARAMS, values); + + ID id; + VALUE value; + for (int i = 0; i < NUM_PARAMS; i++) { + id = param_names[i]; + value = values[i]; + if (value == Qundef) { + continue; + } + SET_PARAM_IF_SAME(use_gpu) + SET_PARAM_IF_SAME(flash_attn) + SET_PARAM_IF_SAME(gpu_device) + SET_PARAM_IF_SAME(dtw_token_timestamps) + SET_PARAM_IF_SAME(dtw_aheads_preset) + SET_PARAM_IF_SAME(dtw_n_top) + } + + return Qnil; +} + +#undef SET_PARAM_IF_SAME + +void +init_ruby_whisper_context_params(VALUE *cContext) +{ + cContextParams = rb_define_class_under(*cContext, "Params", rb_cObject); + + rb_define_alloc_func(cContextParams, ruby_whisper_context_params_s_allocate); + rb_define_method(cContextParams, "initialize", ruby_whisper_context_params_initialize, -1); + + DEFINE_PARAM(use_gpu, 0) + DEFINE_PARAM(flash_attn, 1) + DEFINE_PARAM(gpu_device, 2) + DEFINE_PARAM(dtw_token_timestamps, 3) + DEFINE_PARAM(dtw_aheads_preset, 4) + DEFINE_PARAM(dtw_n_top, 5) +} + +#undef DEFINE_PARAM +#undef DEF_INT_ATTR_METHOD +#undef DEF_BOOLEAN_ATTR_METHOD +#undef NUM_PARAMS diff --git a/bindings/ruby/ext/ruby_whisper_log_queue.c b/bindings/ruby/ext/ruby_whisper_log_queue.c new file mode 100644 index 00000000000..6558a339c6f --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_log_queue.c @@ -0,0 +1,180 @@ +#include "ruby_whisper.h" + +#define LOG_QUEUE_CAPACITY 256 +#define LOG_DEFAULT_CAPACITY 1024 + +void +ruby_whisper_log_queue_initialize(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_initialize(&log_queue->lock); + rb_native_cond_initialize(&log_queue->cond); + log_queue->head = 0; + log_queue->tail = 0; + log_queue->size = 0; + log_queue->is_open = true; + log_queue->logs = ALLOC_N(ruby_whisper_log, LOG_QUEUE_CAPACITY); + for (size_t i = 0; i < LOG_QUEUE_CAPACITY; i++) { + // we cannot call Ruby API like ALLOC_N because this slot may be realloced without GVL + // this doesn't be freed because log queue lives until the end of process + char *slot = malloc(sizeof(char) * LOG_QUEUE_CAPACITY); + if (!slot) { + rb_raise(rb_eRuntimeError, "Could not allocate memory for log text"); + } + ruby_whisper_log log = { + 0, + slot, + 0, + LOG_QUEUE_CAPACITY, + }; + log_queue->logs[i] = log; + } +} + +void +ruby_whisper_log_queue_open(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + log_queue->is_open = true; + + rb_native_cond_signal(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); +} + +void +ruby_whisper_log_queue_close(ruby_whisper_log_queue *log_queue) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + log_queue->is_open = false; + rb_native_cond_broadcast(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); +} + +static size_t +calc_enough_cap(size_t len) +{ + size_t quot = len / LOG_DEFAULT_CAPACITY; + size_t rem = len % LOG_DEFAULT_CAPACITY; + + return sizeof(char) * (rem == 0 ? quot : quot + 1) * LOG_DEFAULT_CAPACITY; +} + +void +ruby_whisper_log_queue_enqueue(ruby_whisper_log_queue *log_queue, enum ggml_log_level level, const char *text) +{ + rb_nativethread_lock_lock(&log_queue->lock); + + if (!log_queue->is_open) { + rb_nativethread_lock_unlock(&log_queue->lock); + return; + } + + size_t len = strlen(text); + ruby_whisper_log *log = &log_queue->logs[log_queue->head]; + if (len > log->capacity) { + size_t new_cap = calc_enough_cap(len); + // we cannot call Ruby API like REALLOC_N because this function is called without GVL + char *slot = realloc(log->text, new_cap); + if (!slot) { + rb_nativethread_lock_unlock(&log_queue->lock); + return; + } + log->text = slot; + log->capacity = new_cap; + } + // we cannot call Ruby API like MEMCPY because this function is called without GVL + memcpy(log->text, text, sizeof(char) * len); + log->length = len; + log->level = level; + log_queue->head = (log_queue->head + 1) % LOG_QUEUE_CAPACITY; + bool is_full = log_queue->size >= LOG_QUEUE_CAPACITY; + log_queue->size = is_full ? LOG_QUEUE_CAPACITY : log_queue->size + 1; + if (is_full) { + log_queue->tail = log_queue->head; + } + + rb_native_cond_signal(&log_queue->cond); + rb_nativethread_lock_unlock(&log_queue->lock); +} + +static void* +ruby_whisper_log_queue_wait(void *args) +{ + ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args; + + rb_native_cond_wait(&log_queue->cond, &log_queue->lock); + rb_nativethread_lock_unlock(&log_queue->lock); + + return NULL; +} + +static void +ruby_whisper_log_queue_wait_ubf(void *args) +{ + ruby_whisper_log_queue *log_queue = (ruby_whisper_log_queue *)args; + + rb_native_cond_broadcast(&log_queue->cond); +} + +typedef struct { + enum ggml_log_level level; + size_t length; + char *text; +} log_snapshot; + +VALUE +ruby_whisper_log_queue_drain(ruby_whisper_log_queue *log_queue) +{ + log_snapshot logs[LOG_QUEUE_CAPACITY]; + + rb_nativethread_lock_lock(&log_queue->lock); + + while (log_queue->size == 0 && log_queue->is_open) { + rb_thread_call_without_gvl(ruby_whisper_log_queue_wait, (void *)log_queue, ruby_whisper_log_queue_wait_ubf, (void *)log_queue); + rb_nativethread_lock_lock(&log_queue->lock); + } + + if (log_queue->size == 0 && !log_queue->is_open) { + rb_native_cond_broadcast(&log_queue->cond); + rb_nativethread_lock_unlock(&log_queue->lock); + return Qnil; + } + + size_t size = log_queue->size; + ruby_whisper_log *log; + size_t i; + for (i = 0; i < size; i++) { + log = &log_queue->logs[(log_queue->tail + i) % LOG_QUEUE_CAPACITY]; + logs[i].level = log->level; + logs[i].length = log->length; + char *text = malloc(log->length); + if (!text) { + logs[i].text = NULL; + continue; + } + logs[i].text = text; + memcpy(logs[i].text, log->text, log->length); + } + log_queue->size = 0; + log_queue->tail = log_queue->head; + + rb_native_cond_signal(&log_queue->cond); + + rb_nativethread_lock_unlock(&log_queue->lock); + + VALUE rb_logs = rb_ary_new2(size); + VALUE rb_text; + for (i = 0; i < size; i++) { + if (!logs[i].text) { + continue; + } + rb_text = rb_str_new(logs[i].text, logs[i].length); + free(logs[i].text); + rb_ary_push(rb_logs, rb_ary_new3(2, INT2NUM(logs[i].level), rb_text)); + } + + return rb_logs; +} diff --git a/bindings/ruby/ext/ruby_whisper_log_settable.h b/bindings/ruby/ext/ruby_whisper_log_settable.h new file mode 100644 index 00000000000..b98fbac826b --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_log_settable.h @@ -0,0 +1,47 @@ +#ifndef RUBY_WHISPER_LOG_SETTABLE_H +#define RUBY_WHISPER_LOG_SETTABLE_H + +#define LOG_SETTABLE_SETUP(log_queue, mod, log_set) \ + static VALUE \ + ruby_whisper_##log_queue##_s_drain_logs(VALUE self) \ + { \ + return ruby_whisper_log_queue_drain(&log_queue); \ + } \ + static void \ + ruby_whisper_##log_queue##_log_callback(enum ggml_log_level level, const char *text, void *user_data) \ + { \ + ruby_whisper_log_queue_enqueue(&log_queue, level, text); \ + } \ + static VALUE \ + ruby_whisper_##log_queue##_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) \ + { \ + rb_iv_set(self, "@log_callback", log_callback); \ + rb_iv_set(self, "@log_callback_user_data", user_data); \ + if (NIL_P(log_callback)) { \ + log_set(NULL, NULL); \ + } else { \ + ruby_whisper_log_queue_open(&log_queue); \ + rb_funcall((mod), id_start_log_callback_thread, 0); \ + log_set(ruby_whisper_##log_queue##_log_callback, NULL); \ + } \ + return Qnil; \ + } \ + static void \ + ruby_whisper_##log_queue##_end_proc(VALUE args) \ + { \ + ruby_whisper_log_queue_close(&log_queue); \ + VALUE log_callback_thread = rb_ivar_get(mod, id_log_callback_thread); \ + if (!NIL_P(log_callback_thread) && RTEST(rb_funcall(log_callback_thread, id_alive_p, 0))) { \ + rb_funcall(log_callback_thread, id_join, 0); \ + } \ + } + +#define LOG_SETTABLE_INIT(log_queue, mod) \ + ruby_whisper_log_queue_initialize(&log_queue); \ + rb_define_singleton_method(mod, "drain_logs", ruby_whisper_##log_queue##_s_drain_logs, 0); \ + rb_define_singleton_method(mod, "log_set", ruby_whisper_##log_queue##_s_log_set, 2); \ + rb_set_end_proc(ruby_whisper_##log_queue##_end_proc, Qnil); \ + rb_extend_object(mod, mLogSettable); \ + rb_funcall(mLogSettable, id_extended, 1, mod); + +#endif diff --git a/bindings/ruby/ext/ruby_whisper_model.c b/bindings/ruby/ext/ruby_whisper_model.c index b196a8b5cb5..0e91fb3f87f 100644 --- a/bindings/ruby/ext/ruby_whisper_model.c +++ b/bindings/ruby/ext/ruby_whisper_model.c @@ -1,4 +1,3 @@ -#include <ruby.h> #include "ruby_whisper.h" extern const rb_data_type_t ruby_whisper_type; diff --git a/bindings/ruby/ext/ruby_whisper_parakeet.c b/bindings/ruby/ext/ruby_whisper_parakeet.c new file mode 100644 index 00000000000..d69369401d0 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet.c @@ -0,0 +1,49 @@ +#include "ruby_whisper.h" +#include <stdio.h> +#include <unistd.h> + +extern VALUE mParakeet; +extern VALUE mLogSettable; +extern VALUE cParakeetContext; +extern VALUE cParakeetSegment; +extern VALUE mOutputContext; +extern VALUE mOutputSegment; + +extern void init_ruby_whisper_parakeet_params(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_token(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_segment(VALUE *mParakeet); +extern VALUE init_ruby_whisper_parakeet_context(VALUE *mParakeet); +extern void init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext); +extern void init_ruby_whisper_parakeet_model(VALUE *mParakeet); + +static ruby_whisper_log_queue parakeet_log_queue; + +LOG_SETTABLE_SETUP(parakeet_log_queue, mParakeet, parakeet_log_set) + +static VALUE +ruby_whisper_parakeet_s_system_info_str(VALUE self) +{ + return rb_str_new2(parakeet_print_system_info()); +} + +void +init_ruby_whisper_parakeet(VALUE *mWhisper) +{ + mParakeet = rb_define_module_under(*mWhisper, "Parakeet"); + + rb_define_const(mParakeet, "VERSION", rb_str_new2(parakeet_version())); + + LOG_SETTABLE_INIT(parakeet_log_queue, mParakeet) + + rb_define_singleton_method(mParakeet, "system_info_str", ruby_whisper_parakeet_s_system_info_str, 0); + + init_ruby_whisper_parakeet_params(&mParakeet); + init_ruby_whisper_parakeet_token(&mParakeet); + init_ruby_whisper_parakeet_segment(&mParakeet); + cParakeetContext = init_ruby_whisper_parakeet_context(&mParakeet); + init_ruby_whisper_parakeet_context_params(&cParakeetContext); + init_ruby_whisper_parakeet_model(&mParakeet); + + rb_include_module(cParakeetContext, mOutputContext); + rb_include_module(cParakeetSegment, mOutputSegment); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context.c b/bindings/ruby/ext/ruby_whisper_parakeet_context.c new file mode 100644 index 00000000000..b4a2fc5c4b7 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context.c @@ -0,0 +1,304 @@ +#include "ruby_whisper.h" + +#define ITERATE_SEGMENT_ATTRS(ITERATOR) \ + ITERATOR(get_segment_t0, LONG) \ + ITERATOR(get_segment_t1, LONG) \ + ITERATOR(get_segment_text, STRING) \ + ITERATOR(n_tokens, INT) + +#define ITERATE_TOKEN_ATTRS(ITERATOR) \ + ITERATOR(get_token_text, STRING) \ + ITERATOR(get_token_id, INT) \ + ITERATOR(get_token_p, FLOAT) + +#define VAL_FROM_LONG(v) LONG2NUM(v) +#define VAL_FROM_STRING(v) rb_utf8_str_new_cstr(v) +#define VAL_FROM_INT(v) INT2NUM(v) +#define VAL_FROM_FLOAT(v) DBL2NUM(v) +#define READER(type) VAL_FROM_##type + +extern ID id_to_s; +extern ID id___method__; +extern ID id_to_enum; +extern ID id_new; + +extern VALUE cParakeetContext; +extern VALUE eError; + +extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); +extern VALUE ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params); +extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); +extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples); +extern VALUE release_samples(VALUE rb_parsed_args); +extern void ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data); +extern rb_data_type_t ruby_whisper_parakeet_params_type; +extern rb_data_type_t ruby_whisper_parakeet_context_params_type; +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); +extern VALUE ruby_whisper_parakeet_model_s_new(VALUE context); + +static void +ruby_whisper_parakeet_context_free(void *p) +{ + ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p; + if (rwpc->context) { + parakeet_free(rwpc->context); + rwpc->context = NULL; + } + xfree(rwpc); +} + +static size_t +ruby_whisper_parakeet_context_memsize(const void *p) +{ + ruby_whisper_parakeet_context *rwpc = (ruby_whisper_parakeet_context *)p; + if (!rwpc) { + return 0; + } + size_t size = sizeof(*rwpc); + return size; +} + +const rb_data_type_t ruby_whisper_parakeet_context_type = { + "ruby_whisper_parakeet_context", + {0, ruby_whisper_parakeet_context_free, ruby_whisper_parakeet_context_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_context_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context *rwpc; + + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); + rwpc->context = NULL; + + return obj; +} + +typedef struct { + struct parakeet_context **context; + char *model_path; + struct parakeet_context_params params; +} ruby_whisper_parakeet_context_init_args; + +static void* +ruby_whisper_parakeet_context_init_without_gvl(void *args) +{ + ruby_whisper_parakeet_context_init_args *init_args = (ruby_whisper_parakeet_context_init_args *)args; + *init_args->context = parakeet_init_from_file_with_params(init_args->model_path, init_args->params); + return NULL; +} + +static VALUE +ruby_whisper_parakeet_context_initialize(int argc, VALUE *argv, VALUE self) +{ + ruby_whisper_parakeet_context *rwpc; + VALUE model_path; + VALUE context_params; + struct parakeet_context_params params; + + rb_scan_args(argc, argv, "11", &model_path, &context_params); + TypedData_Get_Struct(self, ruby_whisper_parakeet_context, &ruby_whisper_parakeet_context_type, rwpc); + + model_path = ruby_whisper_normalize_model_path(model_path); + if (!rb_respond_to(model_path, id_to_s)) { + rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Parakeet::Context"); + } + if (NIL_P(context_params)) { + params = parakeet_context_default_params(); + } else { + ruby_whisper_parakeet_context_params *rwpcp; + GetParakeetContextParams(context_params, rwpcp); + params = rwpcp->params; + } + ruby_whisper_parakeet_context_init_args init_args = { + &rwpc->context, + StringValueCStr(model_path), + params, + }; + rb_thread_call_without_gvl(ruby_whisper_parakeet_context_init_without_gvl, (void *)&init_args, NULL, NULL); + if (rwpc->context == NULL) { + rb_raise(rb_eRuntimeError, "Failed to load model"); + } + + return Qnil; +} + +static VALUE +ruby_whisper_parakeet_context_full_n_segments(VALUE self) +{ + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + + return INT2NUM(parakeet_full_n_segments(rwpc->context)); +} + +#define DEF_SEGMENT_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment) \ + { \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(self, rwpc); \ + return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment))); \ + } + +ITERATE_SEGMENT_ATTRS(DEF_SEGMENT_ATTR) + +#define DEF_TOKEN_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_full_##name(VALUE self, VALUE i_segment, VALUE i_token) \ + { \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(self, rwpc); \ + return READER(type)(parakeet_full_##name(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token))); \ + } + +ITERATE_TOKEN_ATTRS(DEF_TOKEN_ATTR) + +static VALUE +ruby_whisper_parakeet_context_full_get_token_data(VALUE self, VALUE i_segment, VALUE i_token) +{ + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + parakeet_token_data token_data = parakeet_full_get_token_data(rwpc->context, NUM2INT(i_segment), NUM2INT(i_token)); + + return ruby_whisper_parakeet_token_s_from_token_data(rwpc->context, &token_data); +} + +static VALUE +ruby_whisper_parakeet_context_each_segment(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(self, rwpc); + + const int n_segments = parakeet_full_n_segments(rwpc->context); + for (int i = 0; i < n_segments; ++i) { + rb_yield(ruby_whisper_parakeet_segment_init(self, i)); + } + + return self; +} + +typedef struct { + struct parakeet_context *context; + struct parakeet_full_params *params; + float *samples; + int n_samples; + int result; +} parakeet_full_without_gvl_args; + +static void* +parakeet_full_without_gvl(void *rb_args) +{ + parakeet_full_without_gvl_args *args = (parakeet_full_without_gvl_args *)rb_args; + args->result = parakeet_full(args->context, *args->params, args->samples, args->n_samples); + + return NULL; +} + +typedef struct { + ruby_whisper_abort_callback_user_data *abort_callback_user_data; +} parakeet_full_ubf_args; + +static void +parakeet_full_ubf(void *rb_args) +{ + parakeet_full_ubf_args *args = (parakeet_full_ubf_args *)rb_args; + + RUBY_ATOMIC_SET(args->abort_callback_user_data->is_interrupted, 1); +} + +VALUE +ruby_whisper_parakeet_context_full_body(VALUE rb_args) +{ + ruby_whisper_full_args *args = (ruby_whisper_full_args *)rb_args; + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(*args->context, rwpc); + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(*args->params, rwpp); + + ruby_whisper_abort_callback_user_data abort_callback_user_data = { + 0, + NULL, + }; + ruby_whisper_parakeet_prepare_transcription(rwpp, args->context, &abort_callback_user_data); + + parakeet_full_without_gvl_args full_without_gvl_args = { + rwpc->context, + &rwpp->params, + args->samples, + args->n_samples, + 0 + }; + parakeet_full_ubf_args full_ubf_args = { + &abort_callback_user_data, + }; + rb_thread_call_without_gvl(parakeet_full_without_gvl, (void *)&full_without_gvl_args, parakeet_full_ubf, (void *)&full_ubf_args); + + return INT2NUM(full_without_gvl_args.result); +} + +static VALUE +ruby_whisper_parakeet_context_full(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + ruby_whisper_full_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE rb_result = rb_ensure(ruby_whisper_parakeet_context_full_body, (VALUE)&args, release_samples, (VALUE)&parsed); + const int result = NUM2INT(rb_result); + if (result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result)); + } +} + +static VALUE +ruby_whisper_parakeet_context_get_model(VALUE self) +{ + return ruby_whisper_parakeet_model_s_new(self); +} + +VALUE +init_ruby_whisper_parakeet_context(VALUE *mParakeet) +{ + cParakeetContext = rb_define_class_under(*mParakeet, "Context", rb_cObject); + + rb_define_alloc_func(cParakeetContext, ruby_whisper_parakeet_context_allocate); + + rb_define_method(cParakeetContext, "initialize", ruby_whisper_parakeet_context_initialize, -1); + rb_define_method(cParakeetContext, "transcribe", ruby_whisper_parakeet_transcribe, 2); + rb_define_method(cParakeetContext, "full_n_segments", ruby_whisper_parakeet_context_full_n_segments, 0); + rb_define_method(cParakeetContext, "full_get_token_data", ruby_whisper_parakeet_context_full_get_token_data, 2); + rb_define_method(cParakeetContext, "model", ruby_whisper_parakeet_context_get_model, 0); + rb_define_method(cParakeetContext, "each_segment", ruby_whisper_parakeet_context_each_segment, 0); + rb_define_method(cParakeetContext, "full", ruby_whisper_parakeet_context_full, -1); + +#define REGISTER_SEGMENT_ATTR(name, type) \ + rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 1); + + ITERATE_SEGMENT_ATTRS(REGISTER_SEGMENT_ATTR) + +#define REGISTER_TOKEN_ATTR(name, type) \ + rb_define_method(cParakeetContext, "full_" #name, ruby_whisper_parakeet_context_full_##name, 2); + + ITERATE_TOKEN_ATTRS(REGISTER_TOKEN_ATTR) + + return cParakeetContext; +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c new file mode 100644 index 00000000000..38bd6d57ce1 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_context_params.c @@ -0,0 +1,117 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(use_gpu, BOOL) \ + ITERATOR(gpu_device, INT) + +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_INT(v) (NUM2INT(v)) +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type + +#define DEF_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_context_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + return READER(type)(rwpcp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_context_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_context_params *rwpcp; \ + GetParakeetContextParams(self, rwpcp); \ + rwpcp->params.name = WRITER(type)(val); \ + return val; \ + } + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_CONTEXT_PARAMS_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS +}; + +extern VALUE cParakeetContextParams; + +typedef VALUE (*param_writer_t)(VALUE, VALUE); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS]; + +static size_t +ruby_whisper_parakeet_context_params_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_context_params); +} + +const rb_data_type_t ruby_whisper_parakeet_context_params_type = { + "ruby_whisper_parakeet_context_params", + {0, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_context_params_memsize,}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_context_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_context_params *rwpcp; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); +} + +static VALUE +ruby_whisper_parakeet_context_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_context_params *rwpcp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_context_params, &ruby_whisper_parakeet_context_params_type, rwpcp); + rwpcp->params = parakeet_context_default_params(); + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS, values); + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_CONTEXT_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_context_params(VALUE *cParakeetContext) +{ + cParakeetContextParams = rb_define_class_under(*cParakeetContext, "Params", rb_cObject); + + rb_define_alloc_func(cParakeetContextParams, ruby_whisper_parakeet_context_params_s_allocate); + + rb_define_method(cParakeetContextParams, "initialize", ruby_whisper_parakeet_context_params_initialize, -1); + + int i = 0; +#define REGISTER_ATTR(name, type) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_context_params_set_##name; \ + rb_define_method(cParakeetContextParams, #name, ruby_whisper_parakeet_context_params_get_##name, 0); \ + rb_define_method(cParakeetContextParams, #name "=", ruby_whisper_parakeet_context_params_set_##name, 1); \ + i++; + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_model.c b/bindings/ruby/ext/ruby_whisper_parakeet_model.c new file mode 100644 index 00000000000..dce43c688e7 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_model.c @@ -0,0 +1,84 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(n_vocab) \ + ITERATOR(n_audio_ctx) \ + ITERATOR(n_audio_state) \ + ITERATOR(n_audio_head) \ + ITERATOR(n_audio_layer) \ + ITERATOR(n_mels) \ + ITERATOR(ftype) + +extern rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE cParakeetModel; + +static void +ruby_whisper_parakeet_model_mark(void *p) +{ + ruby_whisper_parakeet_model *rwpm = (ruby_whisper_parakeet_model *)p; + if (!NIL_P(rwpm->context)) { + rb_gc_mark(rwpm->context); + } +} + +static size_t +ruby_whisper_parakeet_model_memsize(const void *p) +{ + if (!p) { + return 0; + } + return sizeof(ruby_whisper_parakeet_model); +} + +static const rb_data_type_t ruby_whisper_parakeet_model_type = { + "ruby_whisper_parakeet_model", + {ruby_whisper_parakeet_model_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_model_memsize}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_model_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_model *rwpm; + VALUE model = TypedData_Make_Struct(klass, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = Qnil; + + return model; +} + +VALUE +ruby_whisper_parakeet_model_s_new(VALUE context) +{ + const VALUE model = ruby_whisper_parakeet_model_s_allocate(cParakeetModel); + ruby_whisper_parakeet_model *rwpm; + TypedData_Get_Struct(model, ruby_whisper_parakeet_model, &ruby_whisper_parakeet_model_type, rwpm); + rwpm->context = context; + return model; +} + +#define DEF_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_model_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_model *rwpm; \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetModel(self, rwpm); \ + GetParakeetContext(rwpm->context, rwpc); \ + return INT2NUM(parakeet_model_##name(rwpc->context)); \ + } + +ITERATE_ATTRS(DEF_ATTR) + +void +init_ruby_whisper_parakeet_model(VALUE *mParakeet) +{ + cParakeetModel = rb_define_class_under(*mParakeet, "Model", rb_cObject); + + rb_define_alloc_func(cParakeetModel, ruby_whisper_parakeet_model_s_allocate); + +#define REGISTER_ATTR(name) \ + rb_define_method(cParakeetModel, #name, ruby_whisper_parakeet_model_get_##name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_params.c b/bindings/ruby/ext/ruby_whisper_parakeet_params.c new file mode 100644 index 00000000000..076e2a0cdfb --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_params.c @@ -0,0 +1,548 @@ +#include "ruby_whisper.h" + +#define ITERATE_PARAMS(ITERATOR) \ + ITERATOR(n_threads, INT) \ + ITERATOR(offset_ms, INT) \ + ITERATOR(duration_ms, INT) \ + ITERATOR(no_context, BOOL) \ + ITERATOR(audio_ctx, INT) + +#define ITERATE_NORMAL_CALLBACK_NAMES(ITERATOR, DATA) \ + ITERATOR(new_segment, DATA) \ + ITERATOR(new_token, DATA) \ + ITERATOR(progress, DATA) \ + ITERATOR(encoder_begin, DATA) + +#define ITERATE_NORMAL_CALLBACK_PARAM(name, ITERATOR) ITERATOR(name##_callback) +#define ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \ + ITERATE_NORMAL_CALLBACK_NAMES(ITERATE_NORMAL_CALLBACK_PARAM, ITERATOR) + +#define ITERATE_CALLBACK_PARAMS(ITERATOR) \ + ITERATE_NORMAL_CALLBACK_PARAMS(ITERATOR) \ + ITERATOR(abort_callback) + +enum { +#define DEF_IDX(name, type) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX_CALLBACK(name) RUBY_WHISPER_PARAKEET_PARAM_##name, +#define DEF_IDX_USER_DATA(name) RUBY_WHISPER_PARAKEET_PARAM_##name##_user_data, + ITERATE_PARAMS(DEF_IDX) + ITERATE_CALLBACK_PARAMS(DEF_IDX_CALLBACK) + ITERATE_CALLBACK_PARAMS(DEF_IDX_USER_DATA) + + RUBY_WHISPER_PARAKEET_NUM_PARAMS +}; + +#define VAL_TO_INT(v) (NUM2INT(v)) +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_TO_BOOL(v) (RTEST(v)) +#define VAL_FROM_BOOL(v) (v ? Qtrue : Qfalse) + +extern VALUE cParakeetParams; +extern ID id_call; + +extern void ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc); +extern ruby_whisper_callback_container* ruby_whisper_callback_container_allocate(void); +extern bool ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container); +extern VALUE ruby_whisper_parakeet_segment_init(VALUE context, int index); +extern VALUE ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data); + +static ID param_names[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; +typedef VALUE (*param_writer_t)(VALUE, VALUE); +static param_writer_t param_writers[RUBY_WHISPER_PARAKEET_NUM_PARAMS]; + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int n_new; +} call_parakeet_new_segment_callbacks_args; + +static void* +call_parakeet_new_segment_callbacks(void *v_args) +{ + call_parakeet_new_segment_callbacks_args *args = (call_parakeet_new_segment_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->n_new), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + const int n_segments = parakeet_full_n_segments_from_state(args->state); + for (int i = args->n_new; i > 0; i--) { + int i_segment = n_segments - i; + VALUE segment = ruby_whisper_parakeet_segment_init(*container->context, i_segment); + for (int j = 0; j < n_callbacks; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, segment); + } + } + + return NULL; +} + +static void +ruby_whisper_parakeet_new_segment_callback(struct parakeet_context *context, struct parakeet_state *state, int n_new, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_new_segment_callbacks_args args = { + container, + state, + n_new, + }; + rb_thread_call_with_gvl(call_parakeet_new_segment_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_context *context; + struct parakeet_state *state; + const parakeet_token_data *token_data; +} call_parakeet_new_token_callbacks_args; + +static void* +call_parakeet_new_token_callbacks(void *v_args) +{ + call_parakeet_new_token_callbacks_args *args = (call_parakeet_new_token_callbacks_args *)v_args; + VALUE token = Qnil; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data); + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, token, container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + if (NIL_P(token)) { + token = ruby_whisper_parakeet_token_s_from_token_data(args->context, args->token_data); + } + for (int i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, token); + } + + return NULL; +} + +static void +ruby_whisper_parakeet_new_token_callback(struct parakeet_context *context, struct parakeet_state *state, const parakeet_token_data *token_data, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_new_token_callbacks_args args = { + container, + context, + state, + token_data, + }; + rb_thread_call_with_gvl(call_parakeet_new_token_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + int progress; +} call_parakeet_progress_callbacks_args; + +static void* +call_parakeet_progress_callback(void *v_args) +{ + call_parakeet_progress_callbacks_args *args = (call_parakeet_progress_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(args->progress), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + rb_funcall(cb, id_call, 1, INT2NUM(args->progress)); + } + + return NULL; +} + +static void +ruby_whisper_parakeet_progress_callback(struct parakeet_context *context, struct parakeet_state *state, int progress, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_parakeet_progress_callbacks_args args = { + container, + state, + progress, + }; + rb_thread_call_with_gvl(call_parakeet_progress_callback, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct parakeet_state *state; + bool is_continued; +} call_parakeet_encoder_begin_callbacks_args; + +static void* +call_parakeet_encoder_begin_callbacks(void *v_args) +{ + call_parakeet_encoder_begin_callbacks_args *args = (call_parakeet_encoder_begin_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + for (long i = 0; i < n_callbacks; i++) { + VALUE cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } + } + + return NULL; +} + +static bool +ruby_whisper_parakeet_encoder_begin_callback(struct parakeet_context *context, struct parakeet_state *state, void *user_data) +{ + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return true; + } + + call_parakeet_encoder_begin_callbacks_args args = { + container, + state, + true, + }; + rb_thread_call_with_gvl(call_parakeet_encoder_begin_callbacks, (void *)&args); + + return args.is_continued; +} + +typedef struct { + const ruby_whisper_callback_container *container; + bool is_interrupted; +} call_parakeet_abort_callbacks_args; + +static void* +call_parakeet_abort_callbacks(void *v_args) +{ + call_parakeet_abort_callbacks_args *args = (call_parakeet_abort_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + + if (!NIL_P(container->callback)) { + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + if (NIL_P(container->callbacks)) { + return NULL; + } + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (n_callbacks == 0) { + return NULL; + } + VALUE cb; + for (long i = 0; i < n_callbacks; i++) { + cb = rb_ary_entry(container->callbacks, i); + result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; + } + } + + return NULL; +} + +static bool +ruby_whisper_parakeet_abort_callback(void *user_data) +{ + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; + + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { + return true; + } + + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { + return false; + } + + call_parakeet_abort_callbacks_args args = { + data->callback_container, + false, + }; + rb_thread_call_with_gvl(call_parakeet_abort_callbacks, (void *)&args); + + return args.is_interrupted; +} + +#define CALLBACK_CONTAINER_NAME(name) name ## _container + +void +ruby_whisper_parakeet_prepare_transcription(ruby_whisper_parakeet_params *rwpp, VALUE *context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) +{ +#define PARAM_NAME(name) name +#define USER_DATA_NAME(name) name##_user_data +#define REGISTER_CALLBACK(name) \ + if (ruby_whisper_callback_container_is_present(rwpp->CALLBACK_CONTAINER_NAME(name))) { \ + rwpp->CALLBACK_CONTAINER_NAME(name)->context = context; \ + rwpp->params.PARAM_NAME(name) = ruby_whisper_parakeet_##name; \ + rwpp->params.USER_DATA_NAME(name) = rwpp->CALLBACK_CONTAINER_NAME(name); \ + } + + ITERATE_NORMAL_CALLBACK_PARAMS(REGISTER_CALLBACK) + + if (ruby_whisper_callback_container_is_present(rwpp->abort_callback_container)) { + abort_callback_user_data->callback_container = rwpp->abort_callback_container; + } + rwpp->params.abort_callback = ruby_whisper_parakeet_abort_callback; + rwpp->params.abort_callback_user_data = (void *)abort_callback_user_data; +} + +static void +ruby_whisper_parakeet_params_mark(void *p) +{ + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + +#define MARK_CONTAINER(name) \ + if (rwpp->name##_container) { \ + ruby_whisper_callback_container_mark(rwpp->name##_container); \ + } + + ITERATE_CALLBACK_PARAMS(MARK_CONTAINER) +} + +static void +ruby_whisper_parakeet_params_free(void *p) +{ + ruby_whisper_parakeet_params *rwpp = (ruby_whisper_parakeet_params *)p; + +#define FREE_CONTAINER(name) \ + if (rwpp->name##_container) { \ + xfree(rwpp->name##_container); \ + } + + ITERATE_CALLBACK_PARAMS(FREE_CONTAINER) + + xfree(rwpp); +} + +static size_t +ruby_whisper_parakeet_params_memsize(const void *p) +{ + const struct ruby_whisper_parakeet_params *params = p; + if (!params) { + return 0; + } + return sizeof(ruby_whisper_parakeet_params); +} + +const rb_data_type_t ruby_whisper_parakeet_params_type = { + "ruby_whisper_parakeet_params", + {ruby_whisper_parakeet_params_mark, ruby_whisper_parakeet_params_free, ruby_whisper_parakeet_params_memsize,}, + 0, 0, + 0 +}; + +#define READER(type) VAL_FROM_##type +#define WRITER(type) VAL_TO_##type +#define DEF_PARAM_ATTR(name, type) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return READER(type)(rwpp->params.name); \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->params.name = WRITER(type)(val); \ + return val; \ + } + +#define DEF_CALLBACK_PARAM_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(name)->callback; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(name)->callback = (val); \ + return val; \ + } + +#define DEF_USER_DATA_PARAM_ATTR(name) \ + static VALUE \ + ruby_whisper_parakeet_params_get_##name##_user_data(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + return rwpp->CALLBACK_CONTAINER_NAME(name)->user_data; \ + } \ + static VALUE \ + ruby_whisper_parakeet_params_set_##name##_user_data(VALUE self, VALUE val) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + rwpp->CALLBACK_CONTAINER_NAME(name)->user_data = val; \ + return val; \ + } + +#define DEF_HOOK(name, data) \ + static VALUE \ + ruby_whisper_parakeet_params_on_##name(VALUE self) \ + { \ + ruby_whisper_parakeet_params *rwpp; \ + GetParakeetParams(self, rwpp); \ + const VALUE blk = rb_block_proc(); \ + if (NIL_P(rwpp->name##_callback_container->callbacks)) { \ + rwpp->name##_callback_container->callbacks = rb_ary_new(); \ + } \ + rb_ary_push(rwpp->name##_callback_container->callbacks, blk); \ + return Qnil; \ + } + +ITERATE_PARAMS(DEF_PARAM_ATTR) +ITERATE_CALLBACK_PARAMS(DEF_CALLBACK_PARAM_ATTR) +ITERATE_CALLBACK_PARAMS(DEF_USER_DATA_PARAM_ATTR) +ITERATE_NORMAL_CALLBACK_NAMES(DEF_HOOK, _) + +static VALUE +ruby_whisper_parakeet_params_abort_on(VALUE self) +{ + ruby_whisper_parakeet_params *rwpp; + GetParakeetParams(self, rwpp); + const VALUE blk = rb_block_proc(); + if (NIL_P(rwpp->abort_callback_container->callbacks)) { + rwpp->abort_callback_container->callbacks = rb_ary_new(); + } + rb_ary_push(rwpp->abort_callback_container->callbacks, blk); + + return Qnil; +} + +static VALUE +ruby_whisper_parakeet_params_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_params *rwpp; + VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + rwpp->params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + return obj; +} + +static VALUE +ruby_whisper_parakeet_params_initialize(int argc, VALUE *argv, VALUE self) +{ + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAKEET_NUM_PARAMS] = {Qundef}; + VALUE value; + ruby_whisper_parakeet_params *rwpp; + int i; + + TypedData_Get_Struct(self, ruby_whisper_parakeet_params, &ruby_whisper_parakeet_params_type, rwpp); + +#define INIT_CONTAINER(name) rwpp->name##_container = ruby_whisper_callback_container_allocate(); + + ITERATE_CALLBACK_PARAMS(INIT_CONTAINER) + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return Qnil; + } + + rb_get_kwargs(kw_hash, param_names, 0, RUBY_WHISPER_PARAKEET_NUM_PARAMS, values); + + for (i = 0; i < RUBY_WHISPER_PARAKEET_NUM_PARAMS; i++) { + value = values[i]; + if (value == Qundef) { + continue; + } + param_writers[i](self, value); + } + + return Qnil; +} + +void +init_ruby_whisper_parakeet_params(VALUE *mParakeet) +{ + cParakeetParams = rb_define_class_under(*mParakeet, "Params", rb_cObject); + rb_define_alloc_func(cParakeetParams, ruby_whisper_parakeet_params_s_allocate); + + rb_define_method(cParakeetParams, "initialize", ruby_whisper_parakeet_params_initialize, -1); + + int i = 0; +#define REGISTER_PARAM(name) \ + param_names[i] = rb_intern(#name); \ + param_writers[i] = ruby_whisper_parakeet_params_set_##name; \ + rb_define_method(cParakeetParams, #name, ruby_whisper_parakeet_params_get_##name, 0); \ + rb_define_method(cParakeetParams, #name "=", ruby_whisper_parakeet_params_set_##name, 1); \ + i++; + +#define REGISTER_PARAM_ATTR(name, type) REGISTER_PARAM(name) +#define REGISTER_CALLBACK_PARAM_ATTR(name) REGISTER_PARAM(name) +#define REGISTER_USER_DATA_PARAM_ATTR(name) REGISTER_PARAM(name##_user_data) + + ITERATE_PARAMS(REGISTER_PARAM_ATTR) + ITERATE_CALLBACK_PARAMS(REGISTER_CALLBACK_PARAM_ATTR) + ITERATE_CALLBACK_PARAMS(REGISTER_USER_DATA_PARAM_ATTR) + +#define REGISTER_HOOK(name, data) \ + rb_define_method(cParakeetParams, "on_" #name, ruby_whisper_parakeet_params_on_##name, 0); + + ITERATE_NORMAL_CALLBACK_NAMES(REGISTER_HOOK, _) + + rb_define_method(cParakeetParams, "abort_on", ruby_whisper_parakeet_params_abort_on, 0); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_segment.c b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c new file mode 100644 index 00000000000..b1e81ba930c --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_segment.c @@ -0,0 +1,157 @@ +#include "ruby_whisper.h" + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(start_time, t0, TIME) \ + ITERATOR(end_time, t1, TIME) \ + ITERATOR(text, text, STRING) + +enum { +#define DEF_IDX(name, c_name, type) RUBY_WHISPER_PARAKEET_SEGMENT_##name, + + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS, +}; + +#define VAL_FROM_TIME(v) (LONG2NUM((v) * 10)) +#define VAL_FROM_STRING(v) (rb_str_new2(v)) +#define READER(type) VAL_FROM_##type +#define DEF_ATTR(rb_name, c_name, type) \ + static VALUE \ + ruby_whisper_parakeet_get_##rb_name(VALUE self) \ + { \ + ruby_whisper_parakeet_segment *rwps; \ + GetParakeetSegment(self, rwps); \ + ruby_whisper_parakeet_context *rwpc; \ + GetParakeetContext(rwps->context, rwpc); \ + return READER(type)(parakeet_full_get_segment_##c_name(rwpc->context, rwps->index)); \ + } + +extern ID id___method__; +extern ID id_to_enum; +extern VALUE cParakeetSegment; +extern VALUE sym_start_time; +extern VALUE sym_end_time; +extern VALUE sym_text; +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern VALUE ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token); + +static void +rb_whisper_parakeet_segment_mark(void *p) +{ + ruby_whisper_parakeet_segment *rwps = (ruby_whisper_parakeet_segment *)p; + rb_gc_mark(rwps->context); +} + +static size_t +ruby_whisper_parakeet_segment_memsize(const void *p) +{ + const ruby_whisper_parakeet_segment *rwps = (const ruby_whisper_parakeet_segment *)p; + if (!rwps) { + return 0; + } + return sizeof(*rwps); +} + +static const rb_data_type_t ruby_whisper_parakeet_segment_type = { + "ruby_whisper_parakeet_segment", + {rb_whisper_parakeet_segment_mark, RUBY_DEFAULT_FREE, ruby_whisper_parakeet_segment_memsize,}, + 0, 0, + 0 +}; + +static VALUE +ruby_whisper_parakeet_segment_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_segment *rwps; + return TypedData_Make_Struct(klass, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps); +} + +VALUE +ruby_whisper_parakeet_segment_init(VALUE context, int index) +{ + ruby_whisper_parakeet_segment *rwps; + + const VALUE segment = ruby_whisper_parakeet_segment_s_allocate(cParakeetSegment); + TypedData_Get_Struct(segment, ruby_whisper_parakeet_segment, &ruby_whisper_parakeet_segment_type, rwps); + rwps->context = context; + rwps->index = index; + + return segment; +} + +ITERATE_ATTRS(DEF_ATTR) + +static VALUE +ruby_whisper_parakeet_segment_each_token(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper_parakeet_segment *rwps; + GetParakeetSegment(self, rwps); + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(rwps->context, rwpc); + + const int n_tokens = parakeet_full_n_tokens(rwpc->context, rwps->index); + for (int i = 0; i < n_tokens; i++) { + rb_yield(ruby_whisper_parakeet_token_s_from_index(rwpc->context, rwps->index, i)); + } + + return self; +} + +static VALUE +ruby_whisper_parakeet_segment_deconstruct_keys(VALUE self, VALUE keys) +{ + ruby_whisper_parakeet_segment *rwps; + GetParakeetSegment(self, rwps); + ruby_whisper_parakeet_context *rwpc; + GetParakeetContext(rwps->context, rwpc); + + VALUE hash = rb_hash_new(); + long n_keys; + if (NIL_P(keys)) { + keys = rb_ary_new3( + RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS, + sym_start_time, + sym_end_time, + sym_text + ); + n_keys = RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS; + } else { + n_keys = RARRAY_LEN(keys); + if (n_keys > RUBY_WHISPER_PARAKEET_SEGMENT_NUM_ATTRS) { + return hash; + } + } + for (int i = 0; i < n_keys; i++) { + VALUE key = rb_ary_entry(keys, i); + +#define CHECK_AND_SET_KEY(rb_name, c_name, type) \ + if (key == sym_##rb_name) { \ + rb_hash_aset(hash, key, ruby_whisper_parakeet_get_##rb_name(self)); \ + } + + ITERATE_ATTRS(CHECK_AND_SET_KEY) + } + + return hash; +} + +void +init_ruby_whisper_parakeet_segment(VALUE *mParakeet) +{ + cParakeetSegment = rb_define_class_under(*mParakeet, "Segment", rb_cObject); + + rb_define_alloc_func(cParakeetSegment, ruby_whisper_parakeet_segment_s_allocate); + +#define REGISTER_ATTR(rb_name, c_name, type) \ + rb_define_method(cParakeetSegment, #rb_name, ruby_whisper_parakeet_get_##rb_name, 0); + + ITERATE_ATTRS(REGISTER_ATTR) + + rb_define_method(cParakeetSegment, "each_token", ruby_whisper_parakeet_segment_each_token, 0); + rb_define_method(cParakeetSegment, "deconstruct_keys", ruby_whisper_parakeet_segment_deconstruct_keys, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_token.c b/bindings/ruby/ext/ruby_whisper_parakeet_token.c new file mode 100644 index 00000000000..a00b7ae1cbb --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_token.c @@ -0,0 +1,188 @@ +#include "ruby_whisper.h" + +#define ITERATE_MEMBERS(ITERATOR) \ + ITERATOR(id, id, id, id, INT) \ + ITERATOR(duration_idx, duration_idx, duration_idx, duration_idx, INT) \ + ITERATOR(duration_value, duration_value, duration_value, duration_value, INT) \ + ITERATOR(frame_index, frame_index, frame_index, frame_index, INT) \ + ITERATOR(probability, probability, p, p, FLOAT) \ + ITERATOR(log_probability, log_probability, plog, plog, FLOAT) \ + ITERATOR(start_time, start_time, start_time, t0, TIME) \ + ITERATOR(end_time, end_time, end_time, t1, TIME) \ + ITERATOR(word_start?, word_start, word_start_p, is_word_start, BOOL) + +#define ITERATE_ATTRS(ITERATOR) \ + ITERATOR(text, text, text, text, STRING) + +enum { +#define DEF_IDX(rb_name, s_key, c_name, p_name, type) RUBY_WHISPER_PARAKEET_TOKEN_##c_name, + + ITERATE_MEMBERS(DEF_IDX) + ITERATE_ATTRS(DEF_IDX) + RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, +}; + +#define VAL_FROM_INT(v) (INT2NUM(v)) +#define VAL_FROM_FLOAT(v) (DBL2NUM(v)) +#define VAL_FROM_TIME(v) (LONG2NUM(v * 10)) +#define VAL_FROM_BOOL(v) ((v) ? Qtrue : Qfalse) +#define VAL_FROM_STRING(v) (rb_str_new2(v)) + +#define READER(type) VAL_FROM_##type +#define MEMBER_NAME(name) name +#define DEF_MEMBER_ATTR(rb_name, s_key, c_name, p_name, type) \ + static VALUE \ + ruby_whisper_parakeet_token_get_##c_name(VALUE self) \ + { \ + ruby_whisper_parakeet_token *rwpt; \ + GetParakeetToken(self, rwpt); \ + return READER(type)(rwpt->token_data->MEMBER_NAME(p_name)); \ + } + +#define DEF_ATTR(rb_name, s_key, c_name, p_name, type) \ + static VALUE \ + ruby_whisper_parakeet_token_get_##c_name(VALUE self) \ + { \ + ruby_whisper_parakeet_token *rwpt; \ + GetParakeetToken(self, rwpt); \ + return rwpt->p_name; \ + } + +VALUE cParakeetToken; + +#define DEC_ATTR_SYMS(rb_name, s_key, c_name, p_name, type) static VALUE sym_##s_key; + +ITERATE_MEMBERS(DEC_ATTR_SYMS) +ITERATE_ATTRS(DEC_ATTR_SYMS) + +static void +ruby_whisper_parakeet_token_mark(void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + rb_gc_mark(rwpt->text); +} + +static void +ruby_whisper_parakeet_token_free(void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + if (rwpt->token_data) { + xfree(rwpt->token_data); + rwpt->token_data = NULL; + } + xfree(rwpt); +} + +static size_t +ruby_whisper_parakeet_token_memsize(const void *p) +{ + ruby_whisper_parakeet_token *rwpt = (ruby_whisper_parakeet_token *)p; + if (!rwpt) { + return 0; + } + size_t size = sizeof(*rwpt); + if (rwpt->token_data) { + size += sizeof(*rwpt->token_data); + } + + return size; +} + +static const rb_data_type_t ruby_whisper_parakeet_token_type = { + "ruby_whisper_parakeet_token", + {ruby_whisper_parakeet_token_mark, ruby_whisper_parakeet_token_free, ruby_whisper_parakeet_token_memsize}, + 0, 0, + 0, +}; + +static VALUE +ruby_whisper_parakeet_token_s_allocate(VALUE klass) +{ + ruby_whisper_parakeet_token *rwpt; + VALUE token = TypedData_Make_Struct(klass, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt); + + rwpt->token_data = NULL; + rwpt->text = Qnil; + + return token; +} + +VALUE +ruby_whisper_parakeet_token_s_from_token_data(struct parakeet_context *context, const parakeet_token_data *token_data) +{ + const VALUE token = ruby_whisper_parakeet_token_s_allocate(cParakeetToken); + ruby_whisper_parakeet_token *rwpt; + TypedData_Get_Struct(token, ruby_whisper_parakeet_token, &ruby_whisper_parakeet_token_type, rwpt); + + rwpt->token_data = ALLOC(parakeet_token_data); + *rwpt->token_data = *token_data; + rwpt->text = rb_utf8_str_new_cstr(parakeet_token_to_str(context, token_data->id)); + + return token; +} + +VALUE +ruby_whisper_parakeet_token_s_from_index(struct parakeet_context *context, int i_segment, int i_token) +{ + parakeet_token_data token_data = parakeet_full_get_token_data(context, i_segment, i_token); + return ruby_whisper_parakeet_token_s_from_token_data(context, &token_data); +} + +ITERATE_MEMBERS(DEF_MEMBER_ATTR) +// Define #text using parakeet_token_to_str or parakeet_token_to_text +ITERATE_ATTRS(DEF_ATTR) + +static VALUE +ruby_whisper_parakeet_token_deconstruct_keys(VALUE self, VALUE keys) +{ + ruby_whisper_parakeet_token *rwpt; + GetParakeetToken(self, rwpt); + + VALUE hash = rb_hash_new(); + long n_keys = 0; + + if (NIL_P(keys)) { + VALUE attrs[] = { +#define LIST_SYMS(rb_name, s_key, c_name, p_name, type) sym_##s_key, + + ITERATE_MEMBERS(LIST_SYMS) + ITERATE_ATTRS(LIST_SYMS) + }; + keys = rb_ary_new_from_values(RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS, attrs); + n_keys = RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS; + } else { + n_keys = RARRAY_LEN(keys); + if (n_keys > RUBY_WHISPER_PARAKEET_TOKEN_NUM_ATTRS) { + return hash; + } + } + for (long i = 0; i < n_keys; i++) { + VALUE key = rb_ary_entry(keys, i); + +#define CHECK_AND_SET_KEY(rb_name, s_key, c_name, p_name, type) \ + if (key == sym_##s_key) { \ + rb_hash_aset(hash, key, ruby_whisper_parakeet_token_get_##c_name(self)); \ + } + + ITERATE_MEMBERS(CHECK_AND_SET_KEY) + ITERATE_ATTRS(CHECK_AND_SET_KEY) + } + + return hash; +} + +void +init_ruby_whisper_parakeet_token(VALUE *mParakeet) +{ + cParakeetToken = rb_define_class_under(*mParakeet, "Token", rb_cObject); + rb_define_alloc_func(cParakeetToken, ruby_whisper_parakeet_token_s_allocate); + +#define REGISTER_ATTR(rb_name, s_key, c_name, p_name, type) \ + sym_##s_key = ID2SYM(rb_intern(#s_key)); \ + rb_define_method(cParakeetToken, #rb_name, ruby_whisper_parakeet_token_get_##c_name, 0); + + ITERATE_MEMBERS(REGISTER_ATTR) + ITERATE_ATTRS(REGISTER_ATTR) + + rb_define_method(cParakeetToken, "deconstruct_keys", ruby_whisper_parakeet_token_deconstruct_keys, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp new file mode 100644 index 00000000000..c4deccce84a --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_parakeet_transcribe.cpp @@ -0,0 +1,58 @@ +#include "ruby_whisper.h" +#include "common-whisper.h" +#include <string> +#include <vector> + +#ifdef __cplusplus +extern "C" { +#endif + +extern const rb_data_type_t ruby_whisper_parakeet_context_type; +extern const rb_data_type_t ruby_whisper_parakeet_params_type; + +extern VALUE ruby_whisper_parakeet_context_full_body(VALUE rb_args); + +extern ID id_to_path; +extern ID id_new; + +extern VALUE eError; + +VALUE +ruby_whisper_parakeet_transcribe(VALUE self, VALUE audio_path, VALUE params) +{ + if (rb_respond_to(audio_path, id_to_path)) { + audio_path = rb_funcall(audio_path, id_to_path, 0); + } + + std::string fname = StringValueCStr(audio_path); + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + + if (!read_audio_data(fname, pcmf32, pcmf32s, false)) { + rb_raise(rb_eRuntimeError, "Failed to open %s", fname.c_str()); + return Qnil; + } + + ruby_whisper_parakeet_context *rwpc; + ruby_whisper_parakeet_params *rwpp; + GetParakeetContext(self, rwpc); + GetParakeetParams(params, rwpp); + + ruby_whisper_full_args args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + }; + VALUE rb_result = ruby_whisper_parakeet_context_full_body((VALUE)&args); + const int result = NUM2INT(rb_result); + if (result == 0) { + return self; + } else { + rb_exc_raise(rb_funcall(eError, id_new, 1, rb_result)); + } +} + +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c index 4dfe2575a39..f38e9bde3ea 100644 --- a/bindings/ruby/ext/ruby_whisper_params.c +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -1,4 +1,3 @@ -#include <ruby.h> #include "ruby_whisper.h" #define BOOL_PARAMS_SETTER(self, prop, value) \ @@ -30,6 +29,7 @@ extern VALUE cParams; extern VALUE cVADParams; +extern VALUE mWhisper; extern ID id_call; @@ -76,8 +76,8 @@ static ID id_vad; static ID id_vad_model_path; static ID id_vad_params; -static void -rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) +void +ruby_whisper_callback_container_mark(ruby_whisper_callback_container *rwc) { if (rwc == NULL) return; @@ -86,28 +86,46 @@ rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) rb_gc_mark(rwc->callbacks); } -static ruby_whisper_callback_container* -rb_whisper_callback_container_allocate() { +ruby_whisper_callback_container* +ruby_whisper_callback_container_allocate() { ruby_whisper_callback_container *container; container = ALLOC(ruby_whisper_callback_container); container->context = NULL; container->user_data = Qnil; container->callback = Qnil; - container->callbacks = rb_ary_new(); + container->callbacks = Qnil; return container; } -static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; +bool +ruby_whisper_callback_container_is_present(const ruby_whisper_callback_container *container) { + return !NIL_P(container->callback) || !NIL_P(container->callbacks); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + int n_new; +} call_new_segment_callbacks_args; + +static void* +call_new_segment_callbacks(void *v_args) { + call_new_segment_callbacks_args *args = (call_new_segment_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + struct whisper_state *state = args->state; + int n_new = args->n_new; // Currently, doesn't support state because // those require to resolve GC-related problems. if (!NIL_P(container->callback)) { rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); } + if (NIL_P(container->callbacks)) { + return NULL; + } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { - return; + return NULL; } const int n_segments = whisper_full_n_segments_from_state(state); for (int i = n_new; i > 0; i--) { @@ -118,99 +136,225 @@ static void new_segment_callback(struct whisper_context *ctx, struct whisper_sta rb_funcall(cb, id_call, 1, segment); } } + + return NULL; } -static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { +static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - const VALUE progress = INT2NUM(progress_cur); - // Currently, doesn't support state because + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_new_segment_callbacks_args args = { + container, + state, + n_new + }; + rb_thread_call_with_gvl(call_new_segment_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + int progress_cur; +} call_progress_callbacks_args; + +static void* +call_progress_callbacks(void *v_args) { + call_progress_callbacks_args *args = (call_progress_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + int progress_cur = args->progress_cur; + + // Currently, doesn't support state because // those require to resolve GC-related problems. - if (!NIL_P(container->callback)) { - rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data); + if (!NIL_P(args->container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(progress_cur), container->user_data); + } + if (NIL_P(container->callbacks)) { + return NULL; } const long callbacks_len = RARRAY_LEN(container->callbacks); if (0 == callbacks_len) { - return; + return NULL; } for (int j = 0; j < callbacks_len; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); - rb_funcall(cb, id_call, 1, progress); + rb_funcall(cb, id_call, 1, INT2NUM(progress_cur)); } + + return NULL; } -static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) { +static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - bool is_aborted = false; - VALUE result; + if (!ruby_whisper_callback_container_is_present(container)) { + return; + } + + call_progress_callbacks_args args = { + container, + state, + progress_cur + }; + rb_thread_call_with_gvl(call_progress_callbacks, (void *)&args); +} + +typedef struct { + const ruby_whisper_callback_container *container; + struct whisper_state *state; + bool is_continued; +} call_encoder_begin_callbacks_args; + +static void* +call_encoder_begin_callbacks(void *v_args) { + call_encoder_begin_callbacks_args *args = (call_encoder_begin_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; // Currently, doesn't support state because // those require to resolve GC-related problems. if (!NIL_P(container->callback)) { result = rb_funcall(container->callback, id_call, 3, *container->context, Qnil, container->user_data); if (result == Qfalse) { - is_aborted = true; + args->is_continued = false; + return NULL; } } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return !is_aborted; - } - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - result = rb_funcall(cb, id_call, 0); - if (result == Qfalse) { - is_aborted = true; + if (!NIL_P(container->callbacks)) { + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return NULL; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + result = rb_funcall(cb, id_call, 0); + if (result == Qfalse) { + args->is_continued = false; + return NULL; + } } } - return !is_aborted; + + return NULL; } -static bool abort_callback(void * user_data) { +static bool encoder_begin_callback(struct whisper_context *ctx, struct whisper_state *state, void *user_data) { const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!ruby_whisper_callback_container_is_present(container)) { + return true; + } + + call_encoder_begin_callbacks_args args = { + container, + state, + true + }; + rb_thread_call_with_gvl(call_encoder_begin_callbacks, (void *)&args); + + return args.is_continued; +} + +typedef struct { + const ruby_whisper_callback_container *container; + bool is_interrupted; +} call_abort_callbacks_args; + +static void* +call_abort_callbacks(void *v_args) { + call_abort_callbacks_args *args = (call_abort_callbacks_args *)v_args; + const ruby_whisper_callback_container *container = args->container; + VALUE result = Qnil; + if (!NIL_P(container->callback)) { - VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { - return true; + result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; } } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return false; + if (NIL_P(container->callbacks)) { + return NULL; } - for (int j = 0; j < callbacks_len; j++) { + const long n_callbacks = RARRAY_LEN(container->callbacks); + if (0 == n_callbacks) { + return NULL; + } + for (int j = 0; j < n_callbacks; j++) { VALUE cb = rb_ary_entry(container->callbacks, j); - VALUE result = rb_funcall(cb, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { - return true; + VALUE result = rb_funcall(cb, id_call, 0); + if (RTEST(result)) { + args->is_interrupted = true; + return NULL; } } - return false; + + return NULL; +} + +static bool abort_callback(void * user_data) { + ruby_whisper_abort_callback_user_data *data = (ruby_whisper_abort_callback_user_data *)user_data; + + int is_interrupted = RUBY_ATOMIC_LOAD(data->is_interrupted); + if (is_interrupted) { + return true; + } + + if (!(data->callback_container) || !ruby_whisper_callback_container_is_present(data->callback_container)) { + return false; + } + + call_abort_callbacks_args args = { + data->callback_container, + false + }; + rb_thread_call_with_gvl(call_abort_callbacks, (void *)&args); + + return args.is_interrupted; } -static void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { - if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { +static void +check_thread_safety(ruby_whisper_params *rwp, int n_processors) +{ + if (n_processors == 1) { + return; + } + + // new_segment_callback is called only after multiple threads are joined + // progress_callback is not called when parallel + + if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { + rb_raise(rb_eRuntimeError, "encoder begin callback not supported on parallel transcription"); + } + + if (ruby_whisper_callback_container_is_present(rwp->abort_callback_container)) { + rb_raise(rb_eRuntimeError, "abort callback not supported on parallel transcription"); + } +} + +static void register_callbacks(ruby_whisper_params * rwp, VALUE * context, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { + if (ruby_whisper_callback_container_is_present(rwp->new_segment_callback_container)) { rwp->new_segment_callback_container->context = context; rwp->params.new_segment_callback = new_segment_callback; rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; } - if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->progress_callback_container)) { rwp->progress_callback_container->context = context; rwp->params.progress_callback = progress_callback; rwp->params.progress_callback_user_data = rwp->progress_callback_container; } - if (!NIL_P(rwp->encoder_begin_callback_container->callback) || 0 != RARRAY_LEN(rwp->encoder_begin_callback_container->callbacks)) { + if (ruby_whisper_callback_container_is_present(rwp->encoder_begin_callback_container)) { rwp->encoder_begin_callback_container->context = context; rwp->params.encoder_begin_callback = encoder_begin_callback; rwp->params.encoder_begin_callback_user_data = rwp->encoder_begin_callback_container; } - if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { - rwp->abort_callback_container->context = context; - rwp->params.abort_callback = abort_callback; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; - } + abort_callback_user_data->callback_container = rwp->abort_callback_container; + rwp->abort_callback_container->context = context; + rwp->params.abort_callback = abort_callback; + rwp->params.abort_callback_user_data = (void *)abort_callback_user_data; } static void set_vad_params(ruby_whisper_params *rwp) @@ -221,9 +365,10 @@ static void set_vad_params(ruby_whisper_params *rwp) } void -prepare_transcription(ruby_whisper_params *rwp, VALUE *context) +prepare_transcription(ruby_whisper_params *rwp, VALUE *context, int n_processors, ruby_whisper_abort_callback_user_data *abort_callback_user_data) { - register_callbacks(rwp, context); + check_thread_safety(rwp, n_processors); + register_callbacks(rwp, context, abort_callback_user_data); set_vad_params(rwp); } @@ -231,16 +376,30 @@ void rb_whisper_params_mark(void *p) { ruby_whisper_params *rwp = (ruby_whisper_params *)p; - rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); - rb_whisper_callbcack_container_mark(rwp->progress_callback_container); - rb_whisper_callbcack_container_mark(rwp->encoder_begin_callback_container); - rb_whisper_callbcack_container_mark(rwp->abort_callback_container); + ruby_whisper_callback_container_mark(rwp->new_segment_callback_container); + ruby_whisper_callback_container_mark(rwp->progress_callback_container); + ruby_whisper_callback_container_mark(rwp->encoder_begin_callback_container); + ruby_whisper_callback_container_mark(rwp->abort_callback_container); rb_gc_mark(rwp->vad_params); } void ruby_whisper_params_free(ruby_whisper_params *rwp) { + if (rwp->params.language) { + ruby_xfree((void *)rwp->params.language); + } + if (rwp->params.initial_prompt) { + ruby_xfree((void *)rwp->params.initial_prompt); + } + if (rwp->params.vad_model_path) { + ruby_xfree((void *)rwp->params.vad_model_path); + } + + xfree(rwp->new_segment_callback_container); + xfree(rwp->progress_callback_container); + xfree(rwp->encoder_begin_callback_container); + xfree(rwp->abort_callback_container); } void @@ -249,7 +408,7 @@ rb_whisper_params_free(void *p) ruby_whisper_params *rwp = (ruby_whisper_params *)p; // How to free user_data and callback only when not referred to by others? ruby_whisper_params_free(rwp); - free(rwp); + xfree(rwp); } static size_t @@ -277,12 +436,21 @@ ruby_whisper_params_allocate(VALUE klass) ruby_whisper_params *rwp; VALUE obj = TypedData_Make_Struct(klass, ruby_whisper_params, &ruby_whisper_params_type, rwp); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + if (rwp->params.language != NULL) { + rwp->params.language = ruby_strdup(rwp->params.language); + } + if (rwp->params.initial_prompt != NULL) { + rwp->params.initial_prompt = ruby_strdup(rwp->params.initial_prompt); + } + if (rwp->params.vad_model_path != NULL) { + rwp->params.vad_model_path = ruby_strdup(rwp->params.vad_model_path); + } rwp->diarize = false; rwp->vad_params = TypedData_Wrap_Struct(cVADParams, &ruby_whisper_vad_params_type, (void *)&rwp->params.vad_params); - rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); - rwp->progress_callback_container = rb_whisper_callback_container_allocate(); - rwp->encoder_begin_callback_container = rb_whisper_callback_container_allocate(); - rwp->abort_callback_container = rb_whisper_callback_container_allocate(); + rwp->new_segment_callback_container = ruby_whisper_callback_container_allocate(); + rwp->progress_callback_container = ruby_whisper_callback_container_allocate(); + rwp->encoder_begin_callback_container = ruby_whisper_callback_container_allocate(); + rwp->abort_callback_container = ruby_whisper_callback_container_allocate(); return obj; } @@ -297,10 +465,12 @@ ruby_whisper_params_set_language(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + ruby_xfree((void *)rwp->params.language); + rwp->params.language = NULL; if (value == Qfalse || value == Qnil) { - rwp->params.language = "auto"; + rwp->params.language = ruby_strdup("auto"); } else { - rwp->params.language = StringValueCStr(value); + rwp->params.language = ruby_strdup(StringValueCStr(value)); } return value; } @@ -609,7 +779,13 @@ ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); - rwp->params.initial_prompt = StringValueCStr(value); + ruby_xfree((void *)rwp->params.initial_prompt); + rwp->params.initial_prompt = NULL; + if (NIL_P(value)) { + rwp->params.initial_prompt = NULL; + } else { + rwp->params.initial_prompt = ruby_strdup(StringValueCStr(value)); + } return value; } /* @@ -1104,12 +1280,14 @@ ruby_whisper_params_set_vad_model_path(VALUE self, VALUE value) { ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); + ruby_xfree((void *)rwp->params.vad_model_path); + rwp->params.vad_model_path = NULL; if (NIL_P(value)) { rwp->params.vad_model_path = NULL; return value; } VALUE path = ruby_whisper_normalize_model_path(value); - rwp->params.vad_model_path = StringValueCStr(path); + rwp->params.vad_model_path = ruby_strdup(StringValueCStr(path)); return value; } @@ -1236,6 +1414,9 @@ ruby_whisper_params_on_new_segment(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->new_segment_callback_container->callbacks)) { + rwp->new_segment_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->new_segment_callback_container->callbacks, blk); return Qnil; } @@ -1256,6 +1437,9 @@ ruby_whisper_params_on_progress(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->progress_callback_container->callbacks)) { + rwp->progress_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->progress_callback_container->callbacks, blk); return Qnil; } @@ -1276,6 +1460,9 @@ ruby_whisper_params_on_encoder_begin(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->encoder_begin_callback_container->callbacks)) { + rwp->encoder_begin_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->encoder_begin_callback_container->callbacks, blk); return Qnil; } @@ -1300,6 +1487,9 @@ ruby_whisper_params_abort_on(VALUE self) ruby_whisper_params *rwp; TypedData_Get_Struct(self, ruby_whisper_params, &ruby_whisper_params_type, rwp); const VALUE blk = rb_block_proc(); + if (NIL_P(rwp->abort_callback_container->callbacks)) { + rwp->abort_callback_container->callbacks = rb_ary_new(); + } rb_ary_push(rwp->abort_callback_container->callbacks, blk); return Qnil; } diff --git a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c index 5229cb53900..cf0372797d3 100644 --- a/bindings/ruby/ext/ruby_whisper_segment.c +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -1,16 +1,15 @@ -#include <ruby.h> #include "ruby_whisper.h" #define N_KEY_NAMES 6 extern ID id___method__; extern ID id_to_enum; -static VALUE sym_start_time; -static VALUE sym_end_time; -static VALUE sym_text; -static VALUE sym_no_speech_prob; -static VALUE sym_speaker_turn_next; -static VALUE sym_n_tokens; +VALUE sym_start_time; +VALUE sym_end_time; +VALUE sym_text; +VALUE sym_no_speech_prob; +VALUE sym_speaker_turn_next; +VALUE sym_n_tokens; extern const rb_data_type_t ruby_whisper_type; diff --git a/bindings/ruby/ext/ruby_whisper_token.c b/bindings/ruby/ext/ruby_whisper_token.c index ea4f4e635d2..73f5a547daf 100644 --- a/bindings/ruby/ext/ruby_whisper_token.c +++ b/bindings/ruby/ext/ruby_whisper_token.c @@ -1,4 +1,3 @@ -#include <ruby.h> #include "ruby_whisper.h" #define N_KEY_NAMES 11 @@ -25,12 +24,34 @@ ruby_whisper_token_memsize(const void *p) if (!rwt) { return 0; } - return sizeof(rwt); + size_t size = sizeof(*rwt); + if (rwt->token_data) { + size += sizeof(*rwt->token_data); + } + return size; +} + +static void +ruby_whisper_token_mark(void *p) +{ + ruby_whisper_token *rwt = (ruby_whisper_token *)p; + rb_gc_mark(rwt->text); +} + +static void +ruby_whisper_token_free(void *p) +{ + ruby_whisper_token *rwt = (ruby_whisper_token *)p; + if (rwt->token_data) { + xfree(rwt->token_data); + rwt->token_data = NULL; + } + xfree(rwt); } static const rb_data_type_t ruby_whisper_token_type = { "ruby_whisper_token", - {0, RUBY_DEFAULT_FREE, ruby_whisper_token_memsize,}, + {ruby_whisper_token_mark, ruby_whisper_token_free, ruby_whisper_token_memsize,}, 0, 0, 0 }; @@ -41,19 +62,19 @@ ruby_whisper_token_allocate(VALUE klass) ruby_whisper_token *rwt; VALUE token = TypedData_Make_Struct(klass, ruby_whisper_token, &ruby_whisper_token_type, rwt); rwt->token_data = NULL; - rwt->text = NULL; + rwt->text = Qnil; return token; } VALUE ruby_whisper_token_s_init(struct whisper_context *context, int i_segment, int i_token) { - whisper_token_data token_data = whisper_full_get_token_data(context, i_segment, i_token); const VALUE token = ruby_whisper_token_allocate(cToken); ruby_whisper_token *rwt; TypedData_Get_Struct(token, ruby_whisper_token, &ruby_whisper_token_type, rwt); - rwt->token_data = &token_data; - rwt->text = whisper_full_get_token_text(context, i_segment, i_token); + rwt->token_data = ALLOC(whisper_token_data); + *(rwt->token_data) = whisper_full_get_token_data(context, i_segment, i_token); + rwt->text = rb_str_new2(whisper_full_get_token_text(context, i_segment, i_token)); return token; } @@ -183,10 +204,9 @@ ruby_whisper_token_get_text(VALUE self) { ruby_whisper_token *rwt; GetToken(self, rwt); - return rb_str_new2(rwt->text); + return rwt->text; } - /* * Start time of the token. * diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp index 594b2db90e3..73f606ca476 100644 --- a/bindings/ruby/ext/ruby_whisper_transcribe.cpp +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -1,4 +1,3 @@ -#include <ruby.h> #include "ruby_whisper.h" #include "common-whisper.h" #include <string> @@ -13,10 +12,30 @@ extern const rb_data_type_t ruby_whisper_params_type; extern ID id_to_s; extern ID id_call; +extern ID id_to_path; extern ID transcribe_option_names[1]; -extern void -prepare_transcription(ruby_whisper_params * rwp, VALUE * self); +extern void prepare_transcription(ruby_whisper_params * rwp, VALUE * self, int n_processors); +extern VALUE full_body(VALUE rb_args); +extern VALUE full_parallel_body(VALUE rb_args); + +typedef struct{ + struct whisper_context *context; + struct whisper_full_params *params; + float *samples; + size_t n_samples; + int n_processors; + int result; +} transcribe_without_gvl_args; + +static void* +transcribe_without_gvl(void *rb_args) +{ + transcribe_without_gvl_args *args = (transcribe_without_gvl_args *)rb_args; + args->result = whisper_full_parallel(args->context, *args->params, args->samples, args->n_samples, args->n_processors); + + return NULL; +} /* * transcribe a single file @@ -50,6 +69,9 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { rb_raise(rb_eRuntimeError, "Expected file path to wave file"); } + if (rb_respond_to(wave_file_path, id_to_path)) { + wave_file_path = rb_funcall(wave_file_path, id_to_path, 0); + } std::string fname_inp = StringValueCStr(wave_file_path); std::vector<float> pcmf32; // mono-channel F32 PCM @@ -59,20 +81,28 @@ ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); return self; } - // Commented out because it is work in progress - // { - // static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - - // rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - // bool is_aborted = *(bool*)user_data; - // return !is_aborted; - // }; - // rwp->params.encoder_begin_callback_user_data = &is_aborted; - // } - prepare_transcription(rwp, &self); - - if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), n_processors) != 0) { + VALUE rb_result; + if (n_processors == 1) { + ruby_whisper_full_args args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + }; + rb_result = full_body((VALUE)&args); + } else { + ruby_whisper_full_parallel_args parallel_args = { + &self, + ¶ms, + pcmf32.data(), + (int)pcmf32.size(), + n_processors, + }; + rb_result = full_parallel_body((VALUE)¶llel_args); + } + const int result = NUM2INT(rb_result); + if (result != 0) { fprintf(stderr, "failed to process audio\n"); return self; } diff --git a/bindings/ruby/ext/ruby_whisper_vad_context.c b/bindings/ruby/ext/ruby_whisper_vad_context.c index bf2ed2ba465..97c9736b6f4 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_context.c +++ b/bindings/ruby/ext/ruby_whisper_vad_context.c @@ -1,12 +1,23 @@ -#include <ruby.h> #include "ruby_whisper.h" extern ID id_to_s; extern VALUE cVADContext; +extern const rb_data_type_t ruby_whisper_vad_params_type; extern VALUE ruby_whisper_vad_detect(VALUE self, VALUE file_path, VALUE params); extern VALUE ruby_whisper_normalize_model_path(VALUE model_path); +extern parsed_samples_t parse_samples(VALUE *samples, VALUE *n_samples); +extern VALUE release_samples(VALUE parsed); + +extern VALUE ruby_whisper_vad_segments_s_init(struct whisper_vad_segments *segments); + +typedef struct segments_from_samples_args { + VALUE *context; + VALUE *params; + float *samples; + int n_samples; +} segments_from_samples_args; static size_t ruby_whisper_vad_context_memsize(const void *p) @@ -66,10 +77,46 @@ ruby_whisper_vad_context_initialize(VALUE self, VALUE model_path) return Qnil; } +static VALUE +segments_from_samples_body(VALUE rb_args) +{ + segments_from_samples_args *args = (segments_from_samples_args *)rb_args; + + ruby_whisper_vad_context *rwvc; + ruby_whisper_vad_params *rwvp; + GetVADContext(*args->context, rwvc); + GetVADParams(*args->params, rwvp); + + struct whisper_vad_segments *segments = whisper_vad_segments_from_samples(rwvc->context, rwvp->params, args->samples, args->n_samples); + + return ruby_whisper_vad_segments_s_init(segments); +} + +static VALUE +ruby_whisper_vad_segments_from_samples(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + VALUE n_samples = argc == 2 ? Qnil : argv[2]; + struct parsed_samples_t parsed = parse_samples(&argv[1], &n_samples); + segments_from_samples_args args = { + &self, + &argv[0], + parsed.samples, + parsed.n_samples, + }; + VALUE segments = rb_ensure(segments_from_samples_body, (VALUE)&args, release_samples, (VALUE)&parsed); + + return segments; +} + void init_ruby_whisper_vad_context(VALUE *mVAD) { cVADContext = rb_define_class_under(*mVAD, "Context", rb_cObject); rb_define_alloc_func(cVADContext, ruby_whisper_vad_context_s_allocate); rb_define_method(cVADContext, "initialize", ruby_whisper_vad_context_initialize, 1); + rb_define_method(cVADContext, "segments_from_samples", ruby_whisper_vad_segments_from_samples, -1); rb_define_method(cVADContext, "detect", ruby_whisper_vad_detect, 2); } diff --git a/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp b/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp index 58609f87742..802b0222dbd 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp +++ b/bindings/ruby/ext/ruby_whisper_vad_context_detect.cpp @@ -1,4 +1,3 @@ -#include <ruby.h> #include "ruby_whisper.h" #include "common-whisper.h" #include <string> @@ -8,6 +7,8 @@ extern "C" { #endif +extern ID id_to_path; + extern VALUE cVADSegments; extern const rb_data_type_t ruby_whisper_vad_context_type; @@ -25,12 +26,12 @@ ruby_whisper_vad_detect(VALUE self, VALUE file_path, VALUE params) { std::vector<std::vector<float>> pcmf32s; whisper_vad_segments *segments; - TypedData_Get_Struct(self, ruby_whisper_vad_context, &ruby_whisper_vad_context_type, rwvc); - if (rwvc->context == NULL) { - rb_raise(rb_eRuntimeError, "Doesn't have referenxe to context internally"); - } + GetVADContext(self, rwvc); TypedData_Get_Struct(params, ruby_whisper_vad_params, &ruby_whisper_vad_params_type, rwvp); + if (rb_respond_to(file_path, id_to_path)) { + file_path = rb_funcall(file_path, id_to_path, 0); + } cpp_file_path = StringValueCStr(file_path); if (!read_audio_data(cpp_file_path, pcmf32, pcmf32s, false)) { diff --git a/bindings/ruby/ext/ruby_whisper_vad_params.c b/bindings/ruby/ext/ruby_whisper_vad_params.c index f254bfa2138..28256650e32 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_params.c +++ b/bindings/ruby/ext/ruby_whisper_vad_params.c @@ -1,4 +1,3 @@ -#include <ruby.h> #include "ruby_whisper.h" #define DEFINE_PARAM(param_name, nth) \ diff --git a/bindings/ruby/ext/ruby_whisper_vad_segment.c b/bindings/ruby/ext/ruby_whisper_vad_segment.c index 49ff0aadcce..84a007bb725 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_segment.c +++ b/bindings/ruby/ext/ruby_whisper_vad_segment.c @@ -1,4 +1,3 @@ -#include <ruby.h> #include "ruby_whisper.h" #define N_KEY_NAMES 2 diff --git a/bindings/ruby/ext/ruby_whisper_vad_segments.c b/bindings/ruby/ext/ruby_whisper_vad_segments.c index 1bb375937a4..db62fdb6222 100644 --- a/bindings/ruby/ext/ruby_whisper_vad_segments.c +++ b/bindings/ruby/ext/ruby_whisper_vad_segments.c @@ -1,4 +1,3 @@ -#include <ruby.h> #include "ruby_whisper.h" extern ID id___method__; diff --git a/bindings/ruby/extsources.rb b/bindings/ruby/extsources.rb index b24f1a7f13d..850ac9841b1 100644 --- a/bindings/ruby/extsources.rb +++ b/bindings/ruby/extsources.rb @@ -5,37 +5,53 @@ .devops .github ci - examples/wchess/wchess.wasm + examples/addon.node + examples/bench.wasm + examples/command + examples/command.wasm + examples/lsp + examples/main + examples/python + examples/stream + examples/stream.wasm + examples/sycl + examples/talk-llama + examples/wchess examples/whisper.android examples/whisper.android.java + examples/whisper.nvim examples/whisper.objc examples/whisper.swiftui + examples/whisper.wasm grammars models samples scripts + tests ].collect {|dir| root/dir} ignored_files = %w[ AUTHORS Makefile - README.md - README_sycl.md .gitignore .gitmodules .dockerignore - whisper.nvim - twitch.sh - yt-wsp.sh - close-issue.yml - build-xcframework.sh +] +ignored_exts = %w[ + .yml + .sh + .md + .py + .js + .nvim ] EXTSOURCES = `git ls-files -z #{root}`.split("\x0") .collect {|file| Pathname(file)} .reject {|file| - ignored_dirs.any? {|dir| file.descend.any? {|desc| desc == dir}} || + ignored_exts.include?(file.extname) || ignored_files.include?(file.basename.to_path) || - (file.descend.to_a[1] != root && file.descend.to_a[1] != Pathname("..")/"javascript") + ignored_dirs.any? {|dir| file.descend.any? {|desc| desc == dir}} || + (file.descend.to_a[1] != root && file != Pathname("..")/"javascript"/"package-tmpl.json") } .collect(&:to_path) diff --git a/bindings/ruby/lib/whisper/context.rb b/bindings/ruby/lib/whisper/context.rb deleted file mode 100644 index c3a134b773d..00000000000 --- a/bindings/ruby/lib/whisper/context.rb +++ /dev/null @@ -1,15 +0,0 @@ -module Whisper - class Context - def to_srt - each_segment.with_index.reduce("") {|srt, (segment, index)| - srt << "#{index + 1}\n#{segment.to_srt_cue}\n" - } - end - - def to_webvtt - each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| - webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" - } - end - end -end diff --git a/bindings/ruby/lib/whisper/log_settable.rb b/bindings/ruby/lib/whisper/log_settable.rb new file mode 100644 index 00000000000..2f8218d26ee --- /dev/null +++ b/bindings/ruby/lib/whisper/log_settable.rb @@ -0,0 +1,36 @@ +require "mutex_m" + +module Whisper + module LogSettable + class << self + def extended(base) + base.extend Mutex_m + end + end + + private + + def start_log_callback_thread + return if @log_callback_thread&.alive? + + @log_callback_thread = Thread.new { + begin + while logs = drain_logs + begin + callback, user_data = synchronize {[@log_callback, @log_callback_user_data]} + next if callback.nil? + + logs.each do |(level, text)| + callback.call level, text, user_data + end + rescue => err + $stderr.puts err + end + end + rescue => err + $stderr.puts err + end + } + end + end +end diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index 8eb57e5e8cf..ef92eb901c4 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -41,6 +41,8 @@ def base_cache_dir def cache path = cache_path + return path if cache_path.exist? + headers = {} headers["if-modified-since"] = path.mtime.httpdate if path.exist? request @uri, headers @@ -216,8 +218,18 @@ def escaping(path) @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/whisper-vad/resolve/main/ggml-#{name}.bin") end + %w[ + parakeet-tdt-0.6b-v3-f16 + parakeet-tdt-0.6b-v3-f32 + parakeet-tdt-0.6b-v3-q4_0 + parakeet-tdt-0.6b-v3-q4_k + parakeet-tdt-0.6b-v3-q8_0 + ].each do |name| + @pre_converted_models[name] = URI.new("https://huggingface.co/ggml-org/parakeet-GGUF/resolve/main/ggml-#{name}.bin") + end + @coreml_compiled_models = @pre_converted_models.each_with_object({}) {|(name, uri), models| - next if name.end_with?("-tdrz") || name.start_with?("silero-") + next if name.end_with?("-tdrz") || name.start_with?("silero-") || name.start_with?("parakeet-") if matched = name.match(/\A(?<name>.*)-q\d_\d\z/) name = matched[:name] diff --git a/bindings/ruby/lib/whisper/output.rb b/bindings/ruby/lib/whisper/output.rb new file mode 100644 index 00000000000..1781af17a33 --- /dev/null +++ b/bindings/ruby/lib/whisper/output.rb @@ -0,0 +1,74 @@ +module Whisper + module Output + module Context + def to_srt + each_segment.with_index.reduce("") {|srt, (segment, index)| + srt << "#{index + 1}\n#{segment.to_srt_cue}\n" + } + end + + def to_webvtt + each_segment.with_index.reduce("WEBVTT\n\n") {|webvtt, (segment, index)| + webvtt << "#{index + 1}\n#{segment.to_webvtt_cue}\n" + } + end + end + + module Segment + SRT_ESCAPES = { + "&" => "&", + "<" => "<", + ">" => ">", + } + SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys) + private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE + + def to_srt_cue + "#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n" + end + + def to_webvtt_cue + "#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n" + end + + private + + def time_to_a(time) + sec, decimal_part = time.divmod(1000) + min, sec = sec.divmod(60) + hour, min = min.divmod(60) + [hour, min, sec, decimal_part] + end + + def srt_time(time) + "%02d:%02d:%02d,%03d" % time_to_a(time) + end + + def srt_start_time + srt_time(start_time) + end + + def srt_end_time + srt_time(end_time) + end + + def srt_text + text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES) + end + + def webvtt_time(time) + "%02d:%02d:%02d.%03d" % time_to_a(time) + end + + def webvtt_start_time + webvtt_time(start_time) + end + + def webvtt_end_time + webvtt_time(end_time) + end + + alias webvtt_text srt_text + end + end +end diff --git a/bindings/ruby/lib/whisper/segment.rb b/bindings/ruby/lib/whisper/segment.rb deleted file mode 100644 index dc187dcac36..00000000000 --- a/bindings/ruby/lib/whisper/segment.rb +++ /dev/null @@ -1,58 +0,0 @@ -module Whisper - class Segment - SRT_ESCAPES = { - "&" => "&", - "<" => "<", - ">" => ">", - } - SRT_ESCAPES_RE = Regexp.union(SRT_ESCAPES.keys) - private_constant :SRT_ESCAPES, :SRT_ESCAPES_RE - - def to_srt_cue - "#{srt_start_time} --> #{srt_end_time}\n#{srt_text}\n" - end - - def to_webvtt_cue - "#{webvtt_start_time} --> #{webvtt_end_time}\n#{webvtt_text}\n" - end - - private - - def time_to_a(time) - sec, decimal_part = time.divmod(1000) - min, sec = sec.divmod(60) - hour, min = min.divmod(60) - [hour, min, sec, decimal_part] - end - - def srt_time(time) - "%02d:%02d:%02d,%03d" % time_to_a(time) - end - - def srt_start_time - srt_time(start_time) - end - - def srt_end_time - srt_time(end_time) - end - - def srt_text - text.gsub(SRT_ESCAPES_RE, SRT_ESCAPES) - end - - def webvtt_time(time) - "%02d:%02d:%02d.%03d" % time_to_a(time) - end - - def webvtt_start_time - webvtt_time(start_time) - end - - def webvtt_end_time - webvtt_time(end_time) - end - - alias webvtt_text srt_text - end -end diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index 1137e3f36ab..c12e1fe55e5 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -5,10 +5,10 @@ module Whisper end type log_callback = ^(Integer level, String message, Object user_data) -> void - type new_segment_callback = ^(Whisper::Context, void, Integer n_new, Object user_data) -> void - type progress_callback = ^(Whisper::Context, void, Integer progress, Object user_data) -> void - type encoder_begin_callback = ^(Whisper::Context, void, Object user_data) -> void - type abort_callback = ^(Whisper::Context, void, Object user_data) -> boolish + type new_segment_callback = ^(Whisper::Context, untyped, Integer n_new, Object user_data) -> void + type progress_callback = ^(Whisper::Context, untyped, Integer progress, Object user_data) -> void + type encoder_begin_callback = ^(Whisper::Context, untyped, Object user_data) -> void + type abort_callback = ^(Whisper::Context, untyped, Object user_data) -> boolish VERSION: String LOG_LEVEL_NONE: Integer @@ -17,15 +17,44 @@ module Whisper LOG_LEVEL_ERROR: Integer LOG_LEVEL_DEBUG: Integer LOG_LEVEL_CONT: Integer + AHEADS_NONE: Integer + AHEADS_N_TOP_MOST: Integer + AHEADS_CUSTOM: Integer + AHEADS_TINY_EN: Integer + AHEADS_TINY: Integer + AHEADS_BASE_EN: Integer + AHEADS_BASE: Integer + AHEADS_SMALL_EN: Integer + AHEADS_SMALL: Integer + AHEADS_MEDIUM_EN: Integer + AHEADS_MEDIUM: Integer + AHEADS_LARGE_V1: Integer + AHEADS_LARGE_V2: Integer + AHEADS_LARGE_V3: Integer + AHEADS_LARGE_V3_TURBO: Integer def self.lang_max_id: () -> Integer def self.lang_id: (string name) -> Integer def self.lang_str: (Integer id) -> String def self.lang_str_full: (Integer id) -> String - def self.log_set: (log_callback, Object? user_data) -> log_callback + def self.log_set: (log_callback?, Object? user_data) -> log_callback def self.system_info_str: () -> String + module Output + module Context + def to_srt: () -> String + def to_webvtt: () -> String + end + + module Segment + def to_srt_cue: () -> String + def to_webvtt_cue: () -> String + end + end + class Context + include Output::Context + def self.new: (String | path | ::URI::HTTP) -> instance # transcribe a single file @@ -37,8 +66,11 @@ module Whisper # puts text # end # - def transcribe: (string, Params, ?n_processors: Integer) -> self - | (string, Params, ?n_processors: Integer) { (String) -> void } -> self + # If `n_processors` is greater than 1, you cannot set any callbacks including + # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, + # and log_callback set by Whisper.log_set + def transcribe: (path, Whisper::Params, ?n_processors: Integer) -> self + | (path, Whisper::Params, ?n_processors: Integer) { (String) -> void } -> self def model_n_vocab: () -> Integer def model_n_audio_ctx: () -> Integer @@ -56,7 +88,7 @@ module Whisper # puts segment.text # end # - # Returns an Enumerator if no block given: + # Returns an `Enumerator` if no block given: # # whisper.transcribe("path/to/audio.wav", params) # enum = whisper.each_segment @@ -73,25 +105,25 @@ module Whisper # def full_lang_id: () -> Integer - # Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + # Start time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). # # full_get_segment_t0(3) # => 1668 (16680 ms) # def full_get_segment_t0: (Integer) -> Integer - # End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + # End time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). # # full_get_segment_t1(3) # => 1668 (16680 ms) # def full_get_segment_t1: (Integer) -> Integer - # Whether the next segment indexed by +segment_index+ is predicated as a speaker turn. + # Whether the next segment indexed by `segment_index` is predicated as a speaker turn. # # full_get_segment_speacker_turn_next(3) # => true # def full_get_segment_speaker_turn_next: (Integer) -> (true | false) - # Text of a segment indexed by +segment_index+. + # Text of a segment indexed by `segment_index`. # # full_get_segment_text(3) # => "ask not what your country can do for you, ..." # @@ -99,27 +131,51 @@ module Whisper def full_get_segment_no_speech_prob: (Integer) -> Float - # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text - # Not thread safe for same context + # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + # Not thread safe for same context # Uses the specified decoding strategy to obtain the text. # - # The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + # The second argument `samples` must be an array of samples, respond to `:length`, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. # - def full: (Params, Array[Float] samples, ?Integer n_samples) -> self - | (Params, _Samples, ?Integer n_samples) -> self + def full: (Whisper::Params, Array[Float] samples, ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer n_samples) -> self - # Split the input audio in chunks and process each chunk separately using whisper_full_with_state() - # Result is stored in the default state of the context - # Not thread safe if executed in parallel on the same context. - # It seems this approach can offer some speedup in some cases. + # Split the input audio in chunks and process each chunk separately using `whisper_full_with_state()` + # Result is stored in the default state of the context + # Not thread safe if executed in parallel on the same context. + # It seems this approach can offer some speedup in some cases. # However, the transcription accuracy can be worse at the beginning and end of each chunk. # - def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self - | (Params, _Samples, ?Integer n_samples) -> self - | (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self + # If `n_processors` is greater than 1, you cannot set any callbacks including + # new_segment_callback, progress_callback, encoder_begin_callback, abort_callback, + # and log_callback set by Whisper.log_set + def full_parallel: (Whisper::Params, Array[Float], ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer n_samples) -> self + | (Whisper::Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self + + class Params + def self.new: ( + ?use_gpu: boolish, + ?flash_attn: boolish, + ?gpu_device: Integer, + ?dtw_token_timestamps: boolish, + ?dtw_aheads_preset: Integer, + ?dtw_n_top: Integer | nil, + ) -> instance - def to_srt: () -> String - def to_webvtt: () -> String + def use_gpu=: (boolish) -> boolish + def use_gpu: () -> (true | false) + def flash_attn=: (boolish) -> boolish + def flash_attn: () -> (true | false) + def gpu_device=: (Integer) -> Integer + def gpu_device: () -> Integer + def dtw_token_timestamps=: (boolish) -> boolish + def dtw_token_timestamps: () -> (true | false) + def dtw_aheads_preset=: (Integer) -> Integer + def dtw_aheads_preset: () -> Integer + def dtw_n_top=: (Integer | nil) -> (Integer | nil) + def dtw_n_top: () -> (Integer | nil) + end end class Params @@ -172,35 +228,35 @@ module Whisper def translate: () -> (true | false) def no_context=: (boolish) -> boolish - # If true, does not use past transcription (if any) as initial prompt for the decoder. + # If `true`, does not use past transcription (if any) as initial prompt for the decoder. # def no_context: () -> (true | false) def single_segment=: (boolish) -> boolish - # If true, forces single segment output (useful for streaming). + # If `true`, forces single segment output (useful for streaming). # def single_segment: () -> (true | false) def print_special=: (boolish) -> boolish - # If true, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.). + # If `true`, prints special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.). # def print_special: () -> (true | false) def print_progress=: (boolish) -> boolish - # If true, prints progress information. + # If `true`, prints progress information. # def print_progress: () -> (true | false) def print_realtime=: (boolish) -> boolish - # If true, prints results from within whisper.cpp. (avoid it, use callback instead) + # If `true`, prints results from within whisper.cpp. (avoid it, use callback instead) # def print_realtime: () -> (true | false) - # If true, prints timestamps for each text segment when printing realtime. + # If `true`, prints timestamps for each text segment when printing realtime. # def print_timestamps=: (boolish) -> boolish @@ -208,19 +264,19 @@ module Whisper def suppress_blank=: (boolish) -> boolish - # If true, suppresses blank outputs. + # If `true`, suppresses blank outputs. # def suppress_blank: () -> (true | false) def suppress_nst=: (boolish) -> boolish - # If true, suppresses non-speech-tokens. + # If `true`, suppresses non-speech-tokens. # def suppress_nst: () -> (true | false) def token_timestamps=: (boolish) -> boolish - # If true, enables token-level timestamps. + # If `true`, enables token-level timestamps. # def token_timestamps: () -> (true | false) @@ -232,16 +288,16 @@ module Whisper def split_on_word=: (boolish) -> boolish - # If true, split on word rather than on token (when used with max_len). + # If `true`, split on word rather than on token (when used with max_len). # def split_on_word: () -> (true | false) def initial_prompt=: (_ToS) -> _ToS def carry_initial_prompt=: (boolish) -> boolish - # Tokens to provide to the whisper decoder as initial prompt - # these are prepended to any existing text context from a previous call - # use whisper_tokenize() to convert text to tokens. + # Tokens to provide to the whisper decoder as initial prompt + # these are prepended to any existing text context from a previous call + # use whisper_tokenize() to convert text to tokens. # Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224). # def initial_prompt: () -> (String | nil) @@ -249,7 +305,7 @@ module Whisper def diarize=: (boolish) -> boolish - # If true, enables diarization. + # If `true`, enables diarization. # def diarize: () -> (true | false) @@ -378,7 +434,7 @@ module Whisper # def on_new_segment: { (Segment) -> void } -> void - # Hook called on progress update. Yields each progress Integer between 0 and 100. + # Hook called on progress update. Yields each progress `Integer` between 0 and 100. # def on_progress: { (Integer progress) -> void } -> void @@ -386,7 +442,7 @@ module Whisper # def on_encoder_begin: { () -> void } -> void - # Call block to determine whether abort or not. Return +true+ when you want to abort. + # Call block to determine whether abort or not. Return `true` when you want to abort. # # params.abort_on do # if some_condition @@ -399,6 +455,9 @@ module Whisper def abort_on: { (Object user_data) -> boolish } -> void end + module LogSettable + end + class Model def self.pre_converted_models: () -> Hash[String, Model::URI] def self.coreml_compiled_models: () -> Hash[Model::URI, Model::ZipURI] @@ -429,6 +488,8 @@ module Whisper end class Segment + include Output::Segment + type deconstructed_keys = { start_time: (Integer | nil), end_time: (Integer | nil), @@ -459,21 +520,18 @@ module Whisper # Yields each Whisper::Token: # - # whisper.each_segment.first.each_token do |token| - # p token - # end + # whisper.each_segment.first.each_token do |token| + # p token + # end # - # Returns an Enumerator if no block is given: + # Returns an `Enumerator` if no block is given: # - # whisper.each_segment.first.each_token.to_a # => [#<Whisper::Token>, ...] + # whisper.each_segment.first.each_token.to_a # => [#<Whisper::Token>, ...] # def each_token: { (Token) -> void } -> void | () -> Enumerator[Token] - def to_srt_cue: () -> String - def to_webvtt_cue: () -> String - - # Possible keys: :start_time, :end_time, :text, :no_speech_prob, :speaker_turn_next + # Possible keys: `:start_time`, `:end_time`, `:text`, `:no_speech_prob`, `:speaker_turn_next` # # whisper.each_segment do |segment| # segment => {start_time:, end_time:, text:, no_speech_prob:, speaker_turn_next:} @@ -483,7 +541,7 @@ module Whisper def deconstruct_keys: (Array[:start_time | :end_time | :text | :no_speech_prob | :speaker_turn_next | :n_tokens] | nil) -> deconstructed_keys end - module Token + class Token type deconstructed_keys = { id: (Integer | nil), tid: (Integer | nil), @@ -524,7 +582,7 @@ module Whisper # [EXPERIMENTAL] Token-level timestamps with DTW # - # Do not use if you haven't computed token-level timestamps with dtw. + # Do not use if you haven't computed token-level timestamps with dtw. # Roughly corresponds to the moment in audio in which the token was output. # def t_dtw: () -> Integer @@ -535,14 +593,14 @@ module Whisper # Start time of the token. # - # Token-level timestamp data. + # Token-level timestamp data. # Do not use if you haven't computed token-level timestamps. # def start_time: () -> Integer # End time of the token. # - # Token-level timestamp data. + # Token-level timestamp data. # Do not use if you haven't computed token-level timestamps. # def end_time: () -> Integer @@ -553,6 +611,336 @@ module Whisper def deconstruct_keys: (Array[:id | :tid | :probability | :log_probability | :pt | :ptsum | :t_dtw | :voice_length | :start_time | :end_time | :text] | nil) -> deconstructed_keys end + module Parakeet + extend LogSettable + + VERSION: String + + # Control logging output. The default behavior is to print to stderr. + # + def self.log_set: (nil, Object? user_data) -> nil + | (^(Integer level, String message, Object user_data) -> void, Object? user_data) -> nil + def self.system_info_str: () -> String + + class Context + include Output::Context + + # Load a Parakeet model from the given file path. + # + def self.new: (String | path | ::URI::HTTP, ?Params) -> instance + + # Transcribe a single audio file. + # + def transcribe: (path audio_file_path, Whisper::Parakeet::Params) -> self + + # Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text. + # Not thread safe for the same context. + # + # The second argument `samples` must be an array of samples, respond to `:length`, + # or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + # + def full: (Whisper::Parakeet::Params, Array[Float] samples, ?Integer n_samples) -> self + | (Whisper::Parakeet::Params, _Samples, ?Integer n_samples) -> self + + # Number of generated text segments. + # + def full_n_segments: () -> Integer + + # Start time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). + # + # full_get_segment_t0(3) # => 1668 (16680 ms) + # + def full_get_segment_t0: (Integer segment_index) -> Integer + + # End time of a segment indexed by `segment_index` in centiseconds (10 times milliseconds). + # + # full_get_segment_t1(3) # => 1668 (16680 ms) + # + def full_get_segment_t1: (Integer segment_index) -> Integer + + # Text of a segment indexed by `segment_index`. + # + # full_get_segment_text(3) # => "ask not what your country can do for you, ..." + # + def full_get_segment_text: (Integer segment_index) -> String + + # Number of tokens in the segment indexed by `segment_index`. + # + def full_n_tokens: (Integer segment_index) -> Integer + + # Text of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_text: (Integer segment_index, Integer token_index) -> String + + # Token id of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_id: (Integer segment_index, Integer token_index) -> Integer + + # Probability of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_p: (Integer segment_index, Integer token_index) -> Float + + # Token data of the token indexed by `token_index` in the segment indexed by `segment_index`. + # + def full_get_token_data: (Integer segment_index, Integer token_index) -> Token + + def model: () -> Model + + # Yields each Whisper::Parakeet::Segment: + # + # parakeet.transcribe("path/to/audio.wav", params) + # parakeet.each_segment do |segment| + # puts segment.text + # end + # + # Returns an `Enumerator` if no block given: + # + # parakeet.transcribe("path/to/audio.wav", params) + # enum = parakeet.each_segment + # enum.to_a # => [#<Whisper::Parakeet::Segment>, ...] + # + def each_segment: { (Segment) -> void } -> void + | () -> Enumerator[Segment] + + class Params + def self.new: (?use_gpu: boolish, ?gpu_device: Integer) -> instance + def use_gpu: () -> boolish + def use_gpu=: (boolish) -> boolish + def gpu_device: () -> Integer + def gpu_device=: (Integer) -> Integer + end + end + + class Params + def self.new: ( + ?n_threads: Integer, + ?offset_ms: Integer, + ?duration_ms: Integer, + ?no_context: boolish, + ?audio_ctx: Integer, + ?new_segment_callback: ^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void, + ?new_segment_callback_user_data: Object, + ?new_token_callback: ^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void, + ?new_token_callback_user_data: Object, + ?progress_callback: ^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void, + ?progress_callback_user_data: Object, + ?encoder_begin_callback: ^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish, + ?encoder_begin_callback_user_data: Object, + ?abort_callback: ^(Object user_data) -> boolish, + ?abort_callback_user_data: Object + ) -> instance + + # Number of threads to use. + # + def n_threads=: (Integer) -> Integer + def n_threads: () -> Integer + + # Start offset in ms. + # + def offset_ms=: (Integer) -> Integer + def offset_ms: () -> Integer + + # Audio duration to process in ms. + # + def duration_ms=: (Integer) -> Integer + def duration_ms: () -> Integer + + # If `true`, does not use past transcription (if any) as context. + # + def no_context=: (boolish) -> boolish + def no_context: () -> (true | false) + + # Overwrite the audio context size. `0` uses the default value. + # + def audio_ctx=: (Integer) -> Integer + def audio_ctx: () -> Integer + + # Sets new segment callback, called for every newly generated text segment. + # + # params.new_segment_callback = ->(context, _, n_new, user_data) { + # # ... + # } + # + def new_segment_callback=: (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) + def new_segment_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer n_new, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of new segment callback. + # + def new_segment_callback_user_data=: (Object?) -> Object? + def new_segment_callback_user_data: () -> Object? + + # Sets token callback, called for every newly predicted token. + # + def new_token_callback=: (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) + def new_token_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Whisper::Parakeet::Token, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of token callback. + # + def new_token_callback_user_data=: (Object?) -> Object? + def new_token_callback_user_data: () -> Object? + + # Sets progress callback, called on each progress update. + # + # +progress+ is an Integer between 0 and 100. + # + def progress_callback=: (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) -> (^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) + def progress_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Integer progress, Object user_data) -> void) | nil) + + # Sets user data passed to the last argument of progress callback. + # + def progress_callback_user_data=: (Object?) -> Object? + def progress_callback_user_data: () -> Object? + + # Sets encoder begin callback, called each time before the encoder starts. + # + # If it returns `false`, the computation is aborted. + # + def encoder_begin_callback=: (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) -> (^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) + def encoder_begin_callback: () -> ((^(Whisper::Parakeet::Context, untyped, Object user_data) -> boolish) | nil) + + # Sets user data passed to the last argument of encoder begin callback. + # + def encoder_begin_callback_user_data=: (Object?) -> Object? + def encoder_begin_callback_user_data: () -> Object? + + # Sets abort callback, called each time before ggml computation starts. + # + def abort_callback=: (^(Object user_data) -> boolish) -> (^(Object user_data) -> boolish) + def abort_callback: () -> ((^(Object user_data) -> boolish) | nil) + + # Sets user data passed to the last argument of abort callback. + # + def abort_callback_user_data=: (Object?) -> Object? + def abort_callback_user_data: () -> Object? + + # Hook called on new segment. Yields each Whisper::Parakeet::Segment. + # + def on_new_segment: { (Segment) -> void } -> void + + # Hook called on new token. Yields each Whisper::Parakeet::Token. + # + def on_new_token: { (Token) -> void } -> void + + # Hook called on progress update. Yields each progress `Integer` between 0 and 100. + # + def on_progress: { (Integer progress) -> void } -> void + + # Hook called each time before the encoder starts. + # + def on_encoder_begin: { () -> boolish } -> void + + # Call block to determine whether abort or not. Return `true` when you want to abort. + # + def abort_on: { () -> boolish } -> void + end + + class Segment + include Output::Segment + + type deconstructed_keys = { + start_time: (Integer | nil), + end_time: (Integer | nil), + text: (String | nil) + } + + # Start time in milliseconds. + # + def start_time: () -> Integer + + # End time in milliseconds. + # + def end_time: () -> Integer + + # Text of the segment. + # + def text: () -> String + + # Yields each Whisper::Parakeet::Token: + # + # parakeet.each_segment.first.each_token do |token| + # p token + # end + # + # Returns an `Enumerator` if no block is given: + # + # parakeet.each_segment.first.each_token.to_a # => [#<Whisper::Parakeet::Token>, ...] + # + def each_token: { (Token) -> void } -> void + | () -> Enumerator[Token] + + # Possible keys: `:start_time`, `:end_time`, `:text` + # + def deconstruct_keys: (Array[:start_time | :end_time | :text] | nil) -> deconstructed_keys + end + + class Token + type deconstructed_keys = { + id: (Integer | nil), + duration_idx: (Integer | nil), + duration_value: (Integer | nil), + frame_index: (Integer | nil), + probability: (Float | nil), + log_probability: (Float | nil), + start_time: (Integer | nil), + end_time: (Integer | nil), + word_start: ((true | false) | nil), + text: (String | nil), + } + + # Token ID. + # + def id: () -> Integer + + # Index into the model's durations array. + # + def duration_idx: () -> Integer + + # Actual duration value. + # + def duration_value: () -> Integer + + # Frame index of the token. + # + def frame_index: () -> Integer + + # Probability of the token. + # + def probability: () -> Float + + # Log probability of the token. + # + def log_probability: () -> Float + + # Start time of the token in milliseconds. + # + def start_time: () -> Integer + + # End time of the token in milliseconds. + # + def end_time: () -> Integer + + # Whether this token is the start of a word. + # + def word_start?: () -> (true | false) + + # Get the token text of the token. + # + def text: () -> String + + def deconstruct_keys: (Array[:id | :duration_idx | :duration_value | :frame_index | :probability | :log_probability | :start_time | :end_time | :word_start | :text] | nil) -> deconstructed_keys + end + + class Model + def n_vocab: () -> Integer + def n_audio_ctx: () -> Integer + def n_audio_state: () -> Integer + def n_audio_head: () -> Integer + def n_audio_layer: () -> Integer + def n_mels: () -> Integer + def ftype: () -> Integer + end + end + module VAD class Params def self.new: ( @@ -603,6 +991,8 @@ module Whisper class Context def self.new: (String | path | ::URI::HTTP model_name_or_path) -> instance + def segments_from_samples: (Params, Array[Float] samples, ?Integer n_samples) -> Segments + | (Params, _Samples, ?Integer n_samples) -> Segments def detect: (path wav_file_path, Params) -> Segments end diff --git a/bindings/ruby/test/helper.rb b/bindings/ruby/test/helper.rb index 56cd3849fdd..5e37ad98596 100644 --- a/bindings/ruby/test/helper.rb +++ b/bindings/ruby/test/helper.rb @@ -5,6 +5,8 @@ class TestBase < Test::Unit::TestCase AUDIO = File.join(__dir__, "fixtures", "jfk.wav") + Parakeet = Whisper::Parakeet + class << self def whisper return @whisper if @whisper diff --git a/bindings/ruby/test/jfk_reader/jfk_reader.c b/bindings/ruby/test/jfk_reader/jfk_reader.c index 6657176e767..62207aaa411 100644 --- a/bindings/ruby/test/jfk_reader/jfk_reader.c +++ b/bindings/ruby/test/jfk_reader/jfk_reader.c @@ -2,6 +2,24 @@ #include <ruby/memory_view.h> #include <ruby/encoding.h> +typedef struct { + VALUE audio_path; + int n_samples; + const char *audio_path_str; + float *data; + short *samples; +} jfk_alloc_args; + +static VALUE +jfk_reader_alloc_resources(VALUE arg) +{ + jfk_alloc_args *a = (jfk_alloc_args *)arg; + a->audio_path_str = StringValueCStr(a->audio_path); + a->data = ALLOC_N(float, a->n_samples); + a->samples = ALLOC_N(short, a->n_samples); + return Qnil; +} + static VALUE jfk_reader_initialize(VALUE self, VALUE audio_path) { @@ -13,21 +31,42 @@ static bool jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags) { VALUE audio_path = rb_iv_get(obj, "audio_path"); - const char *audio_path_str = StringValueCStr(audio_path); + // n_samples is a fixed constant (not derived from user input). const int n_samples = 176000; - float *data = (float *)malloc(n_samples * sizeof(float)); - short *samples = (short *)malloc(n_samples * sizeof(short)); - FILE *file = fopen(audio_path_str, "rb"); + + jfk_alloc_args args = { + .audio_path = audio_path, + .n_samples = n_samples, + .audio_path_str = NULL, + .data = NULL, + .samples = NULL, + }; + + int state; + rb_protect(jfk_reader_alloc_resources, (VALUE)&args, &state); + if (state) { + if (args.samples) xfree(args.samples); + if (args.data) xfree(args.data); + return false; + } + + FILE *file = fopen(args.audio_path_str, "rb"); + if (file == NULL) { + xfree(args.samples); + xfree(args.data); + return false; + } fseek(file, 78, SEEK_SET); - fread(samples, sizeof(short), n_samples, file); + fread(args.samples, sizeof(short), n_samples, file); fclose(file); for (int i = 0; i < n_samples; i++) { - data[i] = samples[i]/32768.0; + args.data[i] = args.samples[i] / 32768.0; } + xfree(args.samples); view->obj = obj; - view->data = (void *)data; + view->data = (void *)args.data; view->byte_size = sizeof(float) * n_samples; view->readonly = true; view->format = "f"; @@ -45,6 +84,10 @@ jfk_reader_get_memory_view(const VALUE obj, rb_memory_view_t *view, int flags) static bool jfk_reader_release_memory_view(const VALUE obj, rb_memory_view_t *view) { + if (view->data) { + xfree(view->data); + view->data = NULL; + } return true; } diff --git a/bindings/ruby/test/test_callback.rb b/bindings/ruby/test/test_callback.rb index a7f49245ade..6490c8abb48 100644 --- a/bindings/ruby/test/test_callback.rb +++ b/bindings/ruby/test/test_callback.rb @@ -129,6 +129,7 @@ def test_encoder_begin_callback_abort return false } @whisper.transcribe(@audio, @params) + sleep 0.5 # wait for logs dequeued assert_match(/encoder_begin_callback returned false - aborting/, logs.join) Whisper.log_set ->(level, buffer, user_data) {}, nil end diff --git a/bindings/ruby/test/test_context_params.rb b/bindings/ruby/test/test_context_params.rb new file mode 100644 index 00000000000..8d19fdc94cb --- /dev/null +++ b/bindings/ruby/test/test_context_params.rb @@ -0,0 +1,82 @@ +require_relative "helper" + +class TestContextParams < TestBase + PARAM_NAMES = [ + :use_gpu, + :flash_attn, + :gpu_device, + :dtw_token_timestamps, + :dtw_aheads_preset, + :dtw_n_top + ] + + def test_new + params = Whisper::Context::Params.new + assert_instance_of Whisper::Context::Params, params + end + + def test_attributes + params = Whisper::Context::Params.new + + assert_true params.use_gpu + params.use_gpu = false + assert_false params.use_gpu + + assert_true params.flash_attn + params.flash_attn = false + assert_false params.flash_attn + + assert_equal 0, params.gpu_device + params.gpu_device = 1 + assert_equal 1, params.gpu_device + + assert_false params.dtw_token_timestamps + params.dtw_token_timestamps = true + assert_true params.dtw_token_timestamps + + assert_equal Whisper::AHEADS_NONE, params.dtw_aheads_preset + params.dtw_aheads_preset =Whisper::AHEADS_BASE + assert_equal Whisper::AHEADS_BASE, params.dtw_aheads_preset + + assert_nil params.dtw_n_top + params.dtw_n_top = 6 + assert_equal 6, params.dtw_n_top + params.dtw_n_top = nil + assert_nil params.dtw_n_top + end + + def test_new_with_kw_args + params = Whisper::Context::Params.new(use_gpu: false) + assert_false params.use_gpu + end + + def test_new_with_kw_wargs_non_existent + assert_raise ArgumentError do + Whisper::Context::Params.new(non_existent: "value") + end + end + + data(PARAM_NAMES.collect {|param| [param, param]}.to_h) + def test_new_with_kw_args_default_values(param) + default_params = Whisper::Context::Params.new + default_value = default_params.send(param) + value = if param == :dtw_n_top + 6 + else + case default_value + in true | false + !default_value + in Integer + default_value + 1 + end + end + params = Whisper::Context::Params.new(param => value) + assert_equal value, params.send(param) + + PARAM_NAMES.reject {|name| name == param}.each do |name| + expected = default_params.send(name) + actual = params.send(name) + assert_equal expected, actual + end + end +end diff --git a/bindings/ruby/test/test_package.rb b/bindings/ruby/test/test_package.rb index 108f34efbeb..f99012cce83 100644 --- a/bindings/ruby/test/test_package.rb +++ b/bindings/ruby/test/test_package.rb @@ -1,12 +1,12 @@ require_relative "helper" require 'tempfile' require 'tmpdir' -require 'shellwords' +require 'open3' class TestPackage < TestBase def test_build Tempfile.create do |file| - assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true) + assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path, exception: true) assert file.size > 0 assert_path_exist file.to_path end @@ -20,7 +20,7 @@ def setup def test_install gemspec = Gem::Specification.load("whispercpp.gemspec") Dir.mktmpdir do |dir| - system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", exception: true + system "gem", "install", "--install-dir", dir, "--no-document", File.join("pkg", gemspec.file_name), exception: true assert_installed dir, gemspec.version end end @@ -29,13 +29,14 @@ def test_install_with_coreml omit_unless RUBY_PLATFORM.match?(/darwin/) do gemspec = Gem::Specification.load("whispercpp.gemspec") Dir.mktmpdir do |dir| - system "gem", "install", "--install-dir", dir.shellescape, "--no-document", "pkg/#{gemspec.file_name.shellescape}", "--", "--enable-whisper-coreml", exception: true + system "gem", "install", "--install-dir", dir, "--no-document", File.join("pkg", gemspec.file_name), "--", "--enable-whisper-coreml", exception: true assert_installed dir, gemspec.version libdir = File.join(dir, "gems", "#{gemspec.name}-#{gemspec.version}", "lib") assert_nothing_raised do system "ruby", "-I", libdir, "-r", "whisper", "-e", "Whisper::Context.new('tiny')", exception: true end - assert_match(/COREML = 1/, `ruby -I #{libdir.shellescape} -r whisper -e 'puts Whisper.system_info_str'`) + output, status = Open3.capture2("ruby", "-I", libdir, "-r", "whisper", "-e", "puts Whisper.system_info_str") + assert_match /COREML = 1/, output end end end diff --git a/bindings/ruby/test/test_parakeet.rb b/bindings/ruby/test/test_parakeet.rb new file mode 100644 index 00000000000..bfd57076f56 --- /dev/null +++ b/bindings/ruby/test/test_parakeet.rb @@ -0,0 +1,28 @@ +require_relative "helper" +require "stringio" + +class TestParakeet < TestBase + def test_log_set + log_callback = Parakeet.instance_variable_get("@log_callback") + user_data = Parakeet.instance_variable_get("@log_callback_user_data") + + $stdout = StringIO.new + Parakeet.log_set proc {|level, message, _| puts [level, message].join(": ")}, nil + Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + sleep 0.1 + $stdout.rewind + logs = $stdout.string + assert_match /loading model from/, logs + ensure + $stdout = STDOUT + Parakeet.log_set log_callback, user_data + end + + def test_system_info_str + assert_match /\APARAKEET : /, Parakeet.system_info_str + end + + def test_version + assert_instance_of String, Parakeet::VERSION + end +end diff --git a/bindings/ruby/test/test_parakeet_callback.rb b/bindings/ruby/test/test_parakeet_callback.rb new file mode 100644 index 00000000000..1209e960f09 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_callback.rb @@ -0,0 +1,107 @@ +require_relative "helper" + +class TestParakeetCallback < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + @params = Parakeet::Params.new + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + end + + def test_new_segment_callback + @params.new_segment_callback = ->(context, state, n_new, user_data) { + assert_kind_of Integer, n_new + assert n_new > 0 + assert_same @parakeet, context + + n_segments = context.full_n_segments + n_new.times do |i| + i_segment = n_segments - 1 + i + start_time = context.full_get_segment_t0(i_segment) * 10 + end_time = context.full_get_segment_t1(i_segment) * 10 + text = context.full_get_segment_text(i_segment) + + assert_kind_of Integer, start_time + assert start_time >= 0 + assert_kind_of Integer, end_time + assert end_time > 0 + assert_match(/ask not what your country can do for you, ask what you can do for your/, text) if i_segment == 0 + end + } + + @parakeet.transcribe AUDIO, @params + end + + def test_on_new_segment + seg = nil + index = 0 + @params.on_new_segment do |segment| + assert_instance_of Parakeet::Segment, segment + if index == 0 + seg = segment + assert_equal 0, segment.start_time + assert_match(/ask not what your country can do for you, ask what you can do for your/, segment.text) + end + index += 1 + end + @parakeet.transcribe AUDIO, @params + assert_equal 0, seg.start_time + assert_match /ask not what your country can do for you, ask what you can do for your/, seg.text + end + + def test_on_new_token + index = 0 + @params.on_new_token do |token| + assert_instance_of Parakeet::Token, token + if index == 0 + assert_instance_of Integer, token.start_time + assert_match "▁And", token.text + end + index += 1 + end + + @parakeet.transcribe AUDIO, @params + end + + def test_on_progress + first = nil + @params.on_progress do |progress| + assert_kind_of Integer, progress + assert 0 <= progress && progress <= 100 + first = progress if first.nil? + end + + @parakeet.transcribe AUDIO, @params + + assert_equal 0, first + end + + def test_on_encoder_begin + i = 0 + @params.on_encoder_begin do + i += 1 + end + + @parakeet.transcribe AUDIO, @params + + assert i > 0 + end + + def test_abort_on + do_abort = false + @params.on_new_segment do |segment| + do_abort = true if segment.text.match?(/ask/) + end + i = 0 + @params.abort_on do + i += 1 + do_abort + end + + @parakeet.transcribe(AUDIO, @params) rescue nil + + assert i > 0 + end +end diff --git a/bindings/ruby/test/test_parakeet_context.rb b/bindings/ruby/test/test_parakeet_context.rb new file mode 100644 index 00000000000..2d039ce75f5 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context.rb @@ -0,0 +1,116 @@ +require_relative "helper" +require "stringio" + +class TestParakeetContext < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + @params = Parakeet::Params.new + end + + def test_new + assert_instance_of Parakeet::Context, @parakeet + end + + def test_new_with_params + log_callback = Parakeet.instance_variable_get(:@log_callback) + user_data = Parakeet.instance_variable_get(:@log_callback_user_data) + begin + logs = "" + Parakeet.log_set proc {|level, message| logs << message}, nil + params = Parakeet::Context::Params.new(use_gpu: false) + parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0", params) + assert_instance_of Parakeet::Context, parakeet + assert_match /use gpu\s+=\s+0/, logs + ensure + Parakeet.log_set log_callback, user_data + end + end + + sub_test_case "full" do + def setup + super + @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15} + end + + def test_full + @parakeet.full @params, @samples, @samples.length + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, segments.first.text + end + + def test_full_without_length + @parakeet.full(@params, @samples) + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_enumerator + samples = @samples.each + @parakeet.full @params, samples, @samples.length + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_enumerator_without_length + samples = @samples.each + assert_raise ArgumentError do + @parakeet.full @params, samples + end + end + + def test_full_enumerator_with_too_large_length + samples = @samples.each.take(10).to_enum + assert_raise StopIteration do + @parakeet.full @params, samples, 11 + end + end + + def test_full_with_memory_view + samples = JFKReader.new(AUDIO) + @parakeet.full @params, samples + + segments = @parakeet.each_segment.to_a + assert_equal 1, segments.length + assert_match /ask not what your country can do for you, ask what you can do for your/, @parakeet.each_segment.first.text + end + + def test_full_with_memroy_view_gc + samples = JFKReader.new(AUDIO) + @parakeet.full(@params, samples) + GC.start + require "fiddle" + Fiddle::MemoryView.export samples do |view| + assert_equal 176000, view.to_s.unpack("#{view.format}*").length + end + end + end + + def test_transcribe + assert_nothing_raised do + @parakeet.transcribe AUDIO, @params + end + end + + def test_transcribe_with_pathname + assert_nothing_raised do + @parakeet.transcribe Pathname(AUDIO), @params + end + end + + def test_transcribe_with_nothing + assert_raise_message(/open/) do + @parakeet.transcribe "nothing", @params + end + end +end diff --git a/bindings/ruby/test/test_parakeet_context_params.rb b/bindings/ruby/test/test_parakeet_context_params.rb new file mode 100644 index 00000000000..fcd0f2410f7 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_context_params.rb @@ -0,0 +1,24 @@ +require_relative "helper" + +class TestParakeetContextParams < TestBase + def setup + @params = Parakeet::Context::Params.new + end + + def test_new + assert_instance_of Parakeet::Context::Params, @params + end + + def test_attributes + assert_true @params.use_gpu + assert_instance_of Integer, @params.gpu_device + end + + def test_attribute_writer + @params.use_gpu = false + assert_false @params.use_gpu + + @params.gpu_device = 2 + assert_equal 2, @params.gpu_device + end +end diff --git a/bindings/ruby/test/test_parakeet_model.rb b/bindings/ruby/test/test_parakeet_model.rb new file mode 100644 index 00000000000..5343b35ed8e --- /dev/null +++ b/bindings/ruby/test/test_parakeet_model.rb @@ -0,0 +1,21 @@ +require_relative "helper" + +class TestParakeetModel < TestBase + def test_model + parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + assert_instance_of Parakeet::Model, parakeet.model + end + + def test_attributes + parakeet = Parakeet::Context.new("test/fixtures/for-tests-ggml-parakeet-tdt.bin") + model = parakeet.model + + assert_equal 10, model.n_vocab + assert_equal 3200, model.n_audio_ctx + assert_equal 8, model.n_audio_state + assert_equal 2, model.n_audio_head + assert_equal 1, model.n_audio_layer + assert_equal 16, model.n_mels + assert_equal 0, model.ftype + end +end diff --git a/bindings/ruby/test/test_parakeet_params.rb b/bindings/ruby/test/test_parakeet_params.rb new file mode 100644 index 00000000000..dc651f7ab12 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_params.rb @@ -0,0 +1,78 @@ +require_relative "helper" +require "etc" + +class TestParakeetParams < TestBase + PARAM_NAMES = [ + :n_threads, + :offset_ms, + :duration_ms, + :no_context, + :audio_ctx + ] + + def setup + @params = Parakeet::Params.new + end + + def test_new + assert_instance_of Parakeet::Params, @params + end + + def test_n_threads + assert_equal [4, Etc.nprocessors].min, @params.n_threads + + @params.n_threads = 1 + assert_equal 1, @params.n_threads + end + + def test_offset_ms + assert_equal 0, @params.offset_ms + + @params.offset_ms = 10_000 + assert_equal 10_000, @params.offset_ms + end + + def test_duration_ms + assert_equal 0, @params.duration_ms + + @params.duration_ms = 60_000 + assert_equal 60_000, @params.duration_ms + end + + def test_no_context + assert_equal true, @params.no_context + + @params.no_context = false + assert_equal false, @params.no_context + end + + def test_audio_ctx + assert_equal 0, @params.audio_ctx + + @params.audio_ctx = 1 + assert_equal 1, @params.audio_ctx + end + + def test_new_with_kw_args + params = Parakeet::Params.new(n_threads: 1) + assert_equal 1, params.n_threads + assert_equal 0, params.offset_ms + end + + data(PARAM_NAMES.collect {|param| [param, param]}.to_h) + def test_new_with_kw_args_default_values(param) + default_value = @params.send(param) + value = case [param, default_value] + in [*, true | false] + !default_value + in [*, Integer] + default_value + 1 + end + params = Parakeet::Params.new(param => value) + assert_equal value, params.send(param) + + PARAM_NAMES.reject {|name| name == param}.each do |name| + assert_equal @params.send(name), params.send(name) + end + end +end diff --git a/bindings/ruby/test/test_parakeet_segment.rb b/bindings/ruby/test/test_parakeet_segment.rb new file mode 100644 index 00000000000..d5b99bd5ee6 --- /dev/null +++ b/bindings/ruby/test/test_parakeet_segment.rb @@ -0,0 +1,42 @@ +require_relative "helper" + +class TestParakeetSegment < TestBase + def setup + omit "Skip not to download large model" if ENV["CI"] + + @parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + @parakeet.transcribe AUDIO, Parakeet::Params.new + end + + def test_segment + whole_text = "" + @parakeet.each_segment do |segment| + assert_instance_of Parakeet::Segment, segment + assert_kind_of Integer, segment.start_time + assert segment.end_time >= segment.start_time + assert_kind_of String, segment.text + whole_text << segment.text + end + assert_match(/ask not what your country can do for you, ask what you can do for your country/, whole_text) + end + + def test_deconstruct_keys + segment = @parakeet.each_segment.first + expected = { + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text + } + assert_equal expected, segment.deconstruct_keys([:start_time, :end_time, :text]) + end + + def test_deconstruct_keys_with_nil + segment = @parakeet.each_segment.first + expected = { + start_time: segment.start_time, + end_time: segment.end_time, + text: segment.text + } + assert_equal expected, segment.deconstruct_keys(nil) + end +end diff --git a/bindings/ruby/test/test_parakeet_token.rb b/bindings/ruby/test/test_parakeet_token.rb new file mode 100644 index 00000000000..6f0b8b5a37c --- /dev/null +++ b/bindings/ruby/test/test_parakeet_token.rb @@ -0,0 +1,73 @@ +require_relative "helper" + +class TestParakeetToken < TestBase + ATTRS = %i[ + id + duration_idx + duration_value + frame_index + probability + log_probability + start_time + end_time + word_start? + text + ] + + def setup + omit "Skip not to download large model" if ENV["CI"] + + Whisper.instance_variable_set "@whisper", nil + GC.start + + parakeet = Parakeet::Context.new("parakeet-tdt-0.6b-v3-q4_0") + params = Parakeet::Params.new + parakeet.transcribe AUDIO, params + @segment = parakeet.each_segment.first + end + + def test_each_token + i = 0 + @segment.each_token do |token| + i += 1 + assert_instance_of Parakeet::Token, token + end + assert_equal 38, i + end + + def test_each_token_without_block + assert_instance_of Enumerator, @segment.each_token + end + + def test_token + token = @segment.each_token.first + + assert_instance_of Parakeet::Token, token + assert_instance_of Integer, token.id + assert_instance_of Integer, token.duration_idx + assert_instance_of Integer, token.duration_value + assert_instance_of Integer, token.frame_index + assert_instance_of Float, token.probability + assert_instance_of Float, token.log_probability + assert_instance_of Integer, token.start_time + assert_instance_of Integer, token.end_time + assert_instance_of String, token.text + end + + def test_text + assert_equal ["▁And", "▁so", ",", "▁my", "▁f", "ell", "ow", "▁Amer", "ic", "ans", ",", "▁a", "sk", "▁not", "▁what", "▁your", "▁co", "un", "tr", "y", "▁can", "▁do", "▁for", "▁you", ",", "▁a", "sk", "▁what", "▁you", "▁can", "▁do", "▁for", "▁your", "▁co", "un", "tr", "y", "."], + @segment.each_token.collect(&:text) + end + + def test_deconstruct_keys_with_nil + token = @segment.each_token.first + expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h + assert_equal expected, token.deconstruct_keys(nil) + end + + def test_deconstruct_keys_with_keys + token = @segment.each_token.first + expected = ATTRS.collect {|attr| [attr.to_s.sub(/\?\z/, "").intern, token.send(attr)]}.to_h + assert_equal expected, token.deconstruct_keys(expected.keys) + end +end diff --git a/bindings/ruby/test/test_params.rb b/bindings/ruby/test/test_params.rb index 094dba6f48e..ff5c28e9043 100644 --- a/bindings/ruby/test/test_params.rb +++ b/bindings/ruby/test/test_params.rb @@ -46,6 +46,8 @@ def setup def test_language @params.language = "en" assert_equal @params.language, "en" + GC.compact + assert_equal @params.language, "en" @params.language = "auto" assert_equal @params.language, "auto" end diff --git a/bindings/ruby/test/test_token.rb b/bindings/ruby/test/test_token.rb index e5834b1b480..a23f6813675 100644 --- a/bindings/ruby/test/test_token.rb +++ b/bindings/ruby/test/test_token.rb @@ -56,6 +56,17 @@ def test_text @segment.each_token.collect(&:text) end + def test_token_timestamps + params = Whisper::Params.new(token_timestamps: true) + whisper.transcribe(TestBase::AUDIO, params) + prev = -1 + whisper.each_segment.first.each_token do |token| + assert token.start_time >= prev + assert token.end_time >= token.start_time + prev = token.end_time + end + end + def test_deconstruct_keys_with_nil keys = %i[id tid probability log_probability pt ptsum t_dtw voice_length start_time end_time text] expected = keys.collect {|key| [key, @token.send(key)] }.to_h diff --git a/bindings/ruby/test/test_vad_context.rb b/bindings/ruby/test/test_vad_context.rb index 704916db6de..b4558d34faf 100644 --- a/bindings/ruby/test/test_vad_context.rb +++ b/bindings/ruby/test/test_vad_context.rb @@ -9,6 +9,25 @@ def test_initialize def test_detect context = Whisper::VAD::Context.new("silero-v6.2.0") segments = context.detect(AUDIO, Whisper::VAD::Params.new) + assert_segments segments + end + + def test_invalid_model_type + assert_raise TypeError do + Whisper::VAD::Context.new(Object.new) + end + end + + def test_allocate + vad = Whisper::VAD::Context.allocate + assert_raise do + vad.detect(AUDIO, Whisper::VAD::Params.new) + end + end + + private + + def assert_segments(segments) assert_instance_of Whisper::VAD::Segments, segments i = 0 @@ -35,16 +54,47 @@ def test_detect assert_equal 4, segments.length end - def test_invalid_model_type - assert_raise TypeError do - Whisper::VAD::Context.new(Object.new) + sub_test_case "from samples" do + def setup + super + @vad = Whisper::VAD::Context.new("silero-v6.2.0") + @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15} end - end - def test_allocate - vad = Whisper::VAD::Context.allocate - assert_raise do - vad.detect(AUDIO, Whisper::VAD::Params.new) + def test_segments_from_samples + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, @samples, @samples.length) + assert_segments segments + end + + def test_segments_from_samples_without_length + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, @samples) + assert_segments segments + end + + def test_segments_from_samples_enumerator + samples = @samples.each + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, samples, @samples.length) + assert_segments segments + end + + def test_segments_from_samples_enumerator_without_length + samples = @samples.each + assert_raise ArgumentError do + @vad.segments_from_samples(Whisper::VAD::Params.new, samples) + end + end + + def test_segments_from_samples_enumerator_with_too_large_length + samples = @samples.each.take(10).to_enum + assert_raise StopIteration do + @vad.segments_from_samples(Whisper::VAD::Params.new, samples, 11) + end + end + + def test_segments_from_samples_with_memory_view + samples = JFKReader.new(AUDIO) + segments = @vad.segments_from_samples(Whisper::VAD::Params.new, samples) + assert_segments segments end end end diff --git a/bindings/ruby/test/test_vad_segment.rb b/bindings/ruby/test/test_vad_segment.rb index 7348562cb15..6d66c27fd32 100644 --- a/bindings/ruby/test/test_vad_segment.rb +++ b/bindings/ruby/test/test_vad_segment.rb @@ -9,7 +9,7 @@ def test_initialize end assert_raise do - segments.end_time + segment.end_time end assert_raise do diff --git a/bindings/ruby/test/test_whisper.rb b/bindings/ruby/test/test_whisper.rb index 96e248aca3a..082547e7c08 100644 --- a/bindings/ruby/test/test_whisper.rb +++ b/bindings/ruby/test/test_whisper.rb @@ -1,6 +1,7 @@ require_relative "helper" require "stringio" require "etc" +require "pathname" # Exists to detect memory-related bug Whisper.log_set ->(level, buffer, user_data) {}, nil @@ -20,6 +21,15 @@ def test_whisper } end + def test_whisper_pathname + @whisper = Whisper::Context.new("base.en") + params = Whisper::Params.new + + @whisper.transcribe(Pathname(AUDIO), params) {|text| + assert_match(/ask not what your country can do for you, ask what you can do for your country/, text) + } + end + def test_transcribe_non_parallel @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new @@ -33,9 +43,20 @@ def test_transcribe_n_processors @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new - @whisper.transcribe(AUDIO, params, n_processors: 4) {|text| - assert_match(/what you can do for your country/i, text) - } + without_log_callback do + @whisper.transcribe(AUDIO, params, n_processors: 4) {|text| + assert_match(/what you can do for your country/i, text) + } + end + end + + private + + def without_log_callback + Whisper.log_set nil, nil + yield + ensure + Whisper.log_set ->(level, buffer, user_data) {}, nil end sub_test_case "After transcription" do @@ -128,6 +149,7 @@ def test_log_set } Whisper.log_set log_callback, user_data Whisper::Context.new("base.en") + sleep 0.1 # wait for logs dequeued assert logs.length > 30 logs.each do |log| @@ -207,9 +229,21 @@ def test_full_with_memory_view assert_match(/ask not what your country can do for you, ask what you can do for your country/, @whisper.each_segment.first.text) end + def test_full_with_memroy_view_gc + samples = JFKReader.new(AUDIO) + @whisper.full(@params, samples) + GC.start + require "fiddle" + Fiddle::MemoryView.export samples do |view| + assert_equal 176000, view.to_s.unpack("#{view.format}*").length + end + end + def test_full_parallel nprocessors = 2 - @whisper.full_parallel(@params, @samples, @samples.length, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, @samples, @samples.length, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join @@ -220,7 +254,9 @@ def test_full_parallel def test_full_parallel_with_memory_view nprocessors = 2 samples = JFKReader.new(AUDIO) - @whisper.full_parallel(@params, samples, nil, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, samples, nil, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join @@ -239,7 +275,9 @@ def test_full_parallel_without_length_and_n_processors def test_full_parallel_without_length nprocessors = 2 - @whisper.full_parallel(@params, @samples, nil, nprocessors) + without_log_callback do + @whisper.full_parallel(@params, @samples, nil, nprocessors) + end assert_equal nprocessors, @whisper.full_n_segments text = @whisper.each_segment.collect(&:text).join diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 2e05769a22c..301ecfcc13d 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -3,7 +3,7 @@ require_relative "extsources" Gem::Specification.new do |s| s.name = "whispercpp" s.authors = ["Georgi Gerganov", "Todd A. Fisher"] - s.version = '1.3.5' + s.version = '1.3.7' s.description = %q{High-performance inference of OpenAI's Whisper automatic speech recognition (ASR) model via Ruby} s.email = 'todd.fisher@gmail.com' s.extra_rdoc_files = ['LICENSE', 'README.md'] @@ -23,7 +23,7 @@ Gem::Specification.new do |s| s.test_files = s.files.select {|file| file.start_with? "test/"} s.extensions << 'ext/extconf.rb' - s.required_ruby_version = '>= 3.1.0' + s.required_ruby_version = '>= 3.3.0' #### Documentation and testing. s.homepage = 'https://github.com/ggml-org/whisper.cpp' diff --git a/build-xcframework.sh b/build-xcframework.sh index bbf2764d729..4d462bbf4f3 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -559,7 +559,7 @@ xcodebuild -create-xcframework \ -framework $(pwd)/build-ios-device/framework/whisper.framework \ -debug-symbols $(pwd)/build-ios-device/dSYMs/whisper.dSYM \ -framework $(pwd)/build-macos/framework/whisper.framework \ - -debug-symbols $(pwd)/build-macos/dSYMS/whisper.dSYM \ + -debug-symbols $(pwd)/build-macos/dSYMs/whisper.dSYM \ -framework $(pwd)/build-visionos/framework/whisper.framework \ -debug-symbols $(pwd)/build-visionos/dSYMs/whisper.dSYM \ -framework $(pwd)/build-visionos-sim/framework/whisper.framework \ diff --git a/ci/run.sh b/ci/run.sh index cbe28442e16..dca4476a0fa 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -50,6 +50,10 @@ fi CMAKE_EXTRA="-DWHISPER_FATAL_WARNINGS=ON" +if [[ "$(uname -m)" == "x86_64" ]]; then + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF" +fi + if [ ! -z ${GG_BUILD_METAL} ]; then CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" fi @@ -147,8 +151,15 @@ function gg_download_model { local cwd=`pwd` mkdir -p "$MNT/models" cd "$MNT/models" + set -x bash "$cwd/models/download-ggml-model.sh" ${model_name} . + local download_status=$? + set +x cd "$cwd" + if [ $download_status -ne 0 ]; then + echo "Error: failed to download model ${model_name}" + ret=1 + fi fi } diff --git a/close-issue.yml b/close-issue.yml index 276a217d450..f661de1cd45 100644 --- a/close-issue.yml +++ b/close-issue.yml @@ -15,7 +15,7 @@ jobs: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@v10 with: exempt-issue-labels: "refactor,help wanted,good first issue,research,bug,roadmap" days-before-issue-stale: 30 diff --git a/cmake/parakeet-config.cmake.in b/cmake/parakeet-config.cmake.in new file mode 100644 index 00000000000..aadb55c2d19 --- /dev/null +++ b/cmake/parakeet-config.cmake.in @@ -0,0 +1,30 @@ +set(PARAKEET_VERSION @WHISPER_INSTALL_VERSION@) +set(PARAKEET_BUILD_COMMIT @WHISPER_BUILD_COMMIT@) +set(PARAKEET_BUILD_NUMBER @WHISPER_BUILD_NUMBER@) +set(PARAKEET_SHARED_LIB @BUILD_SHARED_LIBS@) + +@PACKAGE_INIT@ + +set_and_check(PARAKEET_INCLUDE_DIR "@PACKAGE_PARAKEET_INCLUDE_INSTALL_DIR@") +set_and_check(PARAKEET_LIB_DIR "@PACKAGE_PARAKEET_LIB_INSTALL_DIR@") +set_and_check(PARAKEET_BIN_DIR "@PACKAGE_PARAKEET_BIN_INSTALL_DIR@") + +find_package(ggml REQUIRED HINTS ${PARAKEET_LIB_DIR}/cmake) + +find_library(parakeet_LIBRARY parakeet + REQUIRED + HINTS ${PARAKEET_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) + +add_library(parakeet UNKNOWN IMPORTED) +set_target_properties(parakeet + PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PARAKEET_INCLUDE_DIR}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${parakeet_LIBRARY}" + INTERFACE_COMPILE_FEATURES cxx_std_11 + POSITION_INDEPENDENT_CODE ON) + +check_required_components(parakeet) diff --git a/cmake/parakeet.pc.in b/cmake/parakeet.pc.in new file mode 100644 index 00000000000..5a25fbb2e42 --- /dev/null +++ b/cmake/parakeet.pc.in @@ -0,0 +1,10 @@ +prefix=@CMAKE_INSTALL_PREFIX@ +exec_prefix=${prefix} +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ +includedir=${prefix}/include + +Name: parakeet +Description: Port of NVIDIA's Parakeet model in C/C++ +Version: @PROJECT_VERSION@ +Libs: -L${libdir} -lggml -lggml-base -lparakeet +Cflags: -I${includedir} diff --git a/cmake/whisper-config.cmake.in b/cmake/whisper-config.cmake.in index 6a3fa22701f..b70c1e5af44 100644 --- a/cmake/whisper-config.cmake.in +++ b/cmake/whisper-config.cmake.in @@ -3,60 +3,25 @@ set(WHISPER_BUILD_COMMIT @WHISPER_BUILD_COMMIT@) set(WHISPER_BUILD_NUMBER @WHISPER_BUILD_NUMBER@) set(WHISPER_SHARED_LIB @BUILD_SHARED_LIBS@) -set(GGML_BLAS @GGML_BLAS@) -set(GGML_CUDA @GGML_CUDA@) -set(GGML_METAL @GGML_METAL@) -set(GGML_HIPBLAS @GGML_HIPBLAS@) -set(GGML_ACCELERATE @GGML_ACCELERATE@) - @PACKAGE_INIT@ set_and_check(WHISPER_INCLUDE_DIR "@PACKAGE_WHISPER_INCLUDE_INSTALL_DIR@") set_and_check(WHISPER_LIB_DIR "@PACKAGE_WHISPER_LIB_INSTALL_DIR@") set_and_check(WHISPER_BIN_DIR "@PACKAGE_WHISPER_BIN_INSTALL_DIR@") -# Ensure transient dependencies satisfied - -find_package(Threads REQUIRED) - -if (APPLE AND GGML_ACCELERATE) - find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) -endif() - -if (GGML_BLAS) - find_package(BLAS REQUIRED) -endif() - -if (GGML_CUDA) - find_package(CUDAToolkit REQUIRED) -endif() - -if (GGML_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) -endif() - -if (GGML_HIPBLAS) - find_package(hip REQUIRED) - find_package(hipblas REQUIRED) - find_package(rocblas REQUIRED) -endif() +find_package(ggml REQUIRED HINTS ${LLAMA_LIB_DIR}/cmake) find_library(whisper_LIBRARY whisper REQUIRED - HINTS ${WHISPER_LIB_DIR}) - -set(_whisper_link_deps "Threads::Threads" "@WHISPER_EXTRA_LIBS@") -set(_whisper_transient_defines "@WHISPER_TRANSIENT_DEFINES@") + HINTS ${WHISPER_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH +) add_library(whisper UNKNOWN IMPORTED) - set_target_properties(whisper PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${WHISPER_INCLUDE_DIR}" - INTERFACE_LINK_LIBRARIES "${_whisper_link_deps}" - INTERFACE_COMPILE_DEFINITIONS "${_whisper_transient_defines}" + INTERFACE_LINK_LIBRARIES "ggml::ggml;ggml::ggml-base;" IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" IMPORTED_LOCATION "${whisper_LIBRARY}" INTERFACE_COMPILE_FEATURES cxx_std_11 diff --git a/cmake/whisper.pc.in b/cmake/whisper.pc.in index 00ec7912014..200179d5d11 100644 --- a/cmake/whisper.pc.in +++ b/cmake/whisper.pc.in @@ -1,6 +1,6 @@ prefix=@CMAKE_INSTALL_PREFIX@ exec_prefix=${prefix} -libdir=${exec_prefix}/lib +libdir=${prefix}/@CMAKE_INSTALL_LIBDIR@ includedir=${prefix}/include Name: whisper diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index b202ca00b77..7aedb9df683 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -20,7 +20,7 @@ set(TARGET common) unset(COMMON_EXTRA_LIBS) -if (WHISPER_FFMPEG) +if (WHISPER_COMMON_FFMPEG) # As of cmake 3.27, there is no official cmake support for FindFFmpeg. # Consequnelty we added a FindFFmpeg.cmake script the cmake subfolder: # whisper.cpp does not need the full ffmpeg libs, just AVFORMAT AVCODEC AVUTIL SWRESAMPLE @@ -39,7 +39,7 @@ if (WHISPER_FFMPEG) message(STATUS "Found avformat ${AVFORMAT_VERSION}") include_directories(${FFMPEG_INCLUDE_DIRS}) - add_compile_definitions(WHISPER_FFMPEG) + add_compile_definitions(WHISPER_COMMON_FFMPEG) list(APPEND COMMON_EXTRA_LIBS ${FFMPEG_LIBRARIES}) @@ -107,6 +107,8 @@ else() add_subdirectory(server) add_subdirectory(quantize) add_subdirectory(vad-speech-segments) + add_subdirectory(parakeet-cli) + add_subdirectory(parakeet-quantize) if (WHISPER_SDL2) add_subdirectory(stream) add_subdirectory(command) diff --git a/examples/bench.wasm/emscripten.cpp b/examples/bench.wasm/emscripten.cpp index 083397db057..7e9f277f66e 100644 --- a/examples/bench.wasm/emscripten.cpp +++ b/examples/bench.wasm/emscripten.cpp @@ -45,7 +45,7 @@ void bench_main(size_t index) { fprintf(stderr, "\n"); fprintf(stderr, "If you wish, you can submit these results here:\n"); fprintf(stderr, "\n"); - fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n"); + fprintf(stderr, " https://github.com/ggml-org/whisper.cpp/issues/89\n"); fprintf(stderr, "\n"); fprintf(stderr, "Please include the following information:\n"); fprintf(stderr, "\n"); diff --git a/examples/bench/bench.cpp b/examples/bench/bench.cpp index 2d967f2caf4..84915c56a8a 100644 --- a/examples/bench/bench.cpp +++ b/examples/bench/bench.cpp @@ -85,33 +85,38 @@ static int whisper_bench_full(const whisper_params & params) { fprintf(stderr, "error: failed to set mel: %d\n", ret); return 3; } - // heat encoder - if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to encode: %d\n", ret); - return 4; - } whisper_token tokens[512]; memset(tokens, 0, sizeof(tokens)); - // prompt heat - if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to decode: %d\n", ret); - return 4; - } + // TODO: need 2 loops because of the current graph capture logic in the CUDA backend + // https://github.com/ggml-org/llama.cpp/pull/19754 + for (int h = 0; h < 2; ++h) { + // heat encoder + if (int ret = whisper_encode(ctx, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to encode: %d\n", ret); + return 4; + } - // text-generation heat - for (int i = 0; i < 256; i++) { - if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) { + // prompt heat + if (int ret = whisper_decode(ctx, tokens, 256, 0, params.n_threads) != 0) { fprintf(stderr, "error: failed to decode: %d\n", ret); return 4; } - } - // batched heat - if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) { - fprintf(stderr, "error: failed to decode: %d\n", ret); - return 4; + // text-generation heat + for (int i = 0; i < 256; i++) { + if (int ret = whisper_decode(ctx, tokens, 1, i, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode: %d\n", ret); + return 4; + } + } + + // batched heat + if (int ret = whisper_decode(ctx, tokens, 5, 0, params.n_threads) != 0) { + fprintf(stderr, "error: failed to decode: %d\n", ret); + return 4; + } } whisper_reset_timings(ctx); @@ -152,7 +157,7 @@ static int whisper_bench_full(const whisper_params & params) { fprintf(stderr, "\n"); fprintf(stderr, "If you wish, you can submit these results here:\n"); fprintf(stderr, "\n"); - fprintf(stderr, " https://github.com/ggerganov/whisper.cpp/issues/89\n"); + fprintf(stderr, " https://github.com/ggml-org/whisper.cpp/issues/89\n"); fprintf(stderr, "\n"); fprintf(stderr, "Please include the following information:\n"); fprintf(stderr, "\n"); diff --git a/examples/cli/cli.cpp b/examples/cli/cli.cpp index 9a54742fe1d..e505bf0e18d 100644 --- a/examples/cli/cli.cpp +++ b/examples/cli/cli.cpp @@ -77,6 +77,7 @@ struct whisper_params { bool log_score = false; bool use_gpu = true; bool flash_attn = true; + int32_t gpu_device = 0; bool suppress_nst = false; bool carry_initial_prompt = false; @@ -129,6 +130,10 @@ static char * requires_value_error(const std::string & arg) { } static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { + if (const char * env_device = std::getenv("WHISPER_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -146,6 +151,10 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params whisper_print_usage(argc, argv, params); exit(0); } + if (arg == "--version") { + fprintf(stdout, "whisper.cpp version: %s\n", whisper_version()); + exit(0); + } #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(ARGV_NEXT); } @@ -195,6 +204,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params else if (arg == "-dtw" || arg == "--dtw") { params.dtw = ARGV_NEXT; } else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } @@ -228,6 +238,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " --version show version information and exit\n"); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); @@ -276,6 +287,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false"); fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); @@ -731,18 +743,47 @@ static void output_json( if (full) { start_arr("tokens"); const int n = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < n; ++j) { - auto token = whisper_full_get_token_data(ctx, i, j); + + // Merge adjacent tokens whose bytes together form a + // single UTF-8 codepoint. Multi-byte characters (CJK + // in particular) can end up split across whisper + // tokens, which used to produce invalid UTF-8 in the + // JSON string. Refs issue #1798. + struct merged_token { + std::string text; + whisper_token_data data; + int64_t t1; + }; + std::vector<merged_token> merged; + merged.reserve(n); + for (int j = 0; j < n; ) { + auto tok = whisper_full_get_token_data(ctx, i, j); + merged_token m{ whisper_token_to_str(ctx, tok.id), tok, tok.t1 }; + ++j; + while (j < n && utf8_trailing_bytes_needed(m.text) > 0) { + auto tok_next = whisper_full_get_token_data(ctx, i, j); + m.text += whisper_token_to_str(ctx, tok_next.id); + if (tok_next.t1 > -1) { + m.t1 = tok_next.t1; + } + ++j; + } + merged.push_back(std::move(m)); + } + + const int nm = (int) merged.size(); + for (int j = 0; j < nm; ++j) { + const auto & mt = merged[j]; start_obj(nullptr); - value_s("text", whisper_token_to_str(ctx, token.id), false); - if(token.t0 > -1 && token.t1 > -1) { + value_s("text", mt.text.c_str(), false); + if (mt.data.t0 > -1 && mt.t1 > -1) { // If we have per-token timestamps, write them out - times_o(token.t0, token.t1, false); + times_o(mt.data.t0, mt.t1, false); } - value_i("id", token.id, false); - value_f("p", token.p, false); - value_f("t_dtw", token.t_dtw, true); - end_obj(j == (n - 1)); + value_i("id", mt.data.id, false); + value_f("p", mt.data.p, false); + value_f("t_dtw", mt.data.t_dtw, true); + end_obj(j == (nm - 1)); } end_arr(!params.diarize && !params.tinydiarize); } @@ -1003,6 +1044,7 @@ int main(int argc, char ** argv) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.gpu_device = params.gpu_device; cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) { diff --git a/examples/common-ggml.cpp b/examples/common-ggml.cpp index c42b644fedd..3f2eded86f7 100644 --- a/examples/common-ggml.cpp +++ b/examples/common-ggml.cpp @@ -73,6 +73,8 @@ bool ggml_common_quantize_0( case GGML_FTYPE_MOSTLY_IQ1_M: case GGML_FTYPE_MOSTLY_BF16: case GGML_FTYPE_MOSTLY_MXFP4: + case GGML_FTYPE_MOSTLY_NVFP4: + case GGML_FTYPE_MOSTLY_Q1_0: { fprintf(stderr, "%s: invalid model type %d\n", __func__, ftype); return false; @@ -213,6 +215,8 @@ bool ggml_common_quantize_0( case GGML_TYPE_TQ1_0: case GGML_TYPE_TQ2_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: + case GGML_TYPE_Q1_0: case GGML_TYPE_COUNT: { fprintf(stderr, "%s: unsupported quantization type %d (%s)\n", __func__, ttype, ggml_type_name((ggml_type) ttype)); diff --git a/examples/common-whisper.cpp b/examples/common-whisper.cpp index 6218a882eb5..b12481c013f 100644 --- a/examples/common-whisper.cpp +++ b/examples/common-whisper.cpp @@ -34,91 +34,35 @@ #include <cstring> #include <fstream> -#ifdef WHISPER_FFMPEG -// as implemented in ffmpeg_trancode.cpp only embedded in common lib if whisper built with ffmpeg support -extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector<uint8_t> & wav_data); +#ifdef WHISPER_COMMON_FFMPEG +// as implemented in ffmpeg-trancode.cpp only embedded in common lib if whisper built with ffmpeg support +extern bool ffmpeg_decode_audio(const std::string & ifname, std::vector<uint8_t> & wav_data, int out_sample_rate = WHISPER_SAMPLE_RATE); #endif -bool read_audio_data(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) { - std::vector<uint8_t> audio_data; // used for pipe input from stdin or ffmpeg decoding output - +// extract f32 PCM frames from an initialized decoder, downmix to mono and keep the stereo split +static bool read_audio_from_decoder(ma_decoder & decoder, std::vector<float> & pcmf32, std::vector<std::vector<float>> & pcmf32s, bool stereo) { ma_result result; - ma_decoder_config decoder_config; - ma_decoder decoder; - - decoder_config = ma_decoder_config_init(ma_format_f32, stereo ? 2 : 1, WHISPER_SAMPLE_RATE); - - if (fname == "-") { - #ifdef _WIN32 - _setmode(_fileno(stdin), _O_BINARY); - #endif - - uint8_t buf[1024]; - while (true) - { - const size_t n = fread(buf, 1, sizeof(buf), stdin); - if (n == 0) { - break; - } - audio_data.insert(audio_data.end(), buf, buf + n); - } - - if ((result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder)) != MA_SUCCESS) { - - fprintf(stderr, "Error: failed to open audio data from stdin (%s)\n", ma_result_description(result)); - - return false; - } - - fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, audio_data.size()); - } - else if (((result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder)) != MA_SUCCESS)) { -#if defined(WHISPER_FFMPEG) - if (ffmpeg_decode_audio(fname, audio_data) != 0) { - fprintf(stderr, "error: failed to ffmpeg decode '%s'\n", fname.c_str()); - - return false; - } - - if ((result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); - - return false; - } -#else - if ((result = ma_decoder_init_memory(fname.c_str(), fname.size(), &decoder_config, &decoder)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to read audio data as wav (%s)\n", ma_result_description(result)); - - return false; - } -#endif - } - ma_uint64 frame_count; ma_uint64 frames_read; if ((result = ma_decoder_get_length_in_pcm_frames(&decoder, &frame_count)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to retrieve the length of the audio data (%s)\n", ma_result_description(result)); - - return false; + fprintf(stderr, "error: failed to retrieve the length of the audio data (%s)\n", ma_result_description(result)); + return false; } pcmf32.resize(stereo ? frame_count*2 : frame_count); if ((result = ma_decoder_read_pcm_frames(&decoder, pcmf32.data(), frame_count, &frames_read)) != MA_SUCCESS) { - fprintf(stderr, "error: failed to read the frames of the audio data (%s)\n", ma_result_description(result)); - - return false; + fprintf(stderr, "error: failed to read the frames of the audio data (%s)\n", ma_result_description(result)); + return false; } if (stereo) { std::vector<float> stereo_data = pcmf32; pcmf32.resize(frame_count); - for (uint64_t i = 0; i < frame_count; i++) { pcmf32[i] = (stereo_data[2*i] + stereo_data[2*i + 1]); } - pcmf32s.resize(2); pcmf32s[0].resize(frame_count); pcmf32s[1].resize(frame_count); @@ -128,11 +72,111 @@ bool read_audio_data(const std::string & fname, std::vector<float>& pcmf32, std: } } - ma_decoder_uninit(&decoder); - return true; } +bool read_audio_data(const std::string & fname, std::vector<float> & pcmf32, std::vector<std::vector<float>> & pcmf32s, bool stereo) { + std::vector<uint8_t> audio_data; // used for pipe input from stdin or ffmpeg decoding output + + ma_result result; + ma_decoder_config decoder_config; + + struct decoder_guard { + ma_decoder decoder; + bool initialized = false; + ma_decoder * operator&() { return &decoder; } + ~decoder_guard() { + if (initialized) { + ma_decoder_uninit(&decoder); + } + } + }; + decoder_guard decoder{}; + + decoder_config = ma_decoder_config_init(ma_format_f32, stereo ? 2 : 1, WHISPER_SAMPLE_RATE); + + if (fname == "-") { +#ifdef _WIN32 + _setmode(_fileno(stdin), _O_BINARY); +#endif + + uint8_t buf[1024]; + while (true) + { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + audio_data.insert(audio_data.end(), buf, buf + n); + } + + result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); + if (result != MA_SUCCESS) { + fprintf(stderr, "%s: failed to open audio data from stdin (%s)\n", __func__, ma_result_description(result)); + return false; + } + decoder.initialized = true; + + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, audio_data.size()); + } else { + fprintf(stderr, "%s: reading audio data from '%s' ...\n", __func__, fname.c_str()); + + // first try miniaudio. if it fails (or skipped) - try ffmpeg + { + const char * skip = getenv("WHISPER_COMMON_MINIAUDIO_SKIP"); + if (!skip || strlen(skip) == 0 || strcmp(skip, "0") == 0) { + fprintf(stderr, "%s: trying to decode with miniaudio\n", __func__); + + result = ma_decoder_init_file(fname.c_str(), &decoder_config, &decoder); + if (result == MA_SUCCESS) { + decoder.initialized = true; + } + } else { + fprintf(stderr, "%s: skipping miniaudio\n", __func__); + } + } + +#if defined(WHISPER_COMMON_FFMPEG) + if (!decoder.initialized) { + fprintf(stderr, "%s: trying to decode with ffmpeg\n", __func__); + + if (ffmpeg_decode_audio(fname, audio_data) != 0) { + fprintf(stderr, "%s: failed to ffmpeg decode\n", __func__); + return false; + } + result = ma_decoder_init_memory(audio_data.data(), audio_data.size(), &decoder_config, &decoder); + if (result != MA_SUCCESS) { + fprintf(stderr, "%s: failed to read audio data as wav (%s)\n", __func__, ma_result_description(result)); + return false; + } + decoder.initialized = true; + } +#endif + + if (!decoder.initialized) { + fprintf(stderr, "%s: failed to read audio data\n", __func__); + return false; + } + } + + return read_audio_from_decoder(decoder.decoder, pcmf32, pcmf32s, stereo); +} + +// decode audio bytes already held in memory +bool read_audio_data(const char * buffer, size_t buffer_size, std::vector<float> & pcmf32, std::vector<std::vector<float>> & pcmf32s, bool stereo) { + ma_decoder_config decoder_config = ma_decoder_config_init(ma_format_f32, stereo ? 2 : 1, WHISPER_SAMPLE_RATE); + ma_decoder decoder; + + if (ma_decoder_init_memory(buffer, buffer_size, &decoder_config, &decoder) != MA_SUCCESS) { + fprintf(stderr, "error: failed to decode audio data from memory buffer\n"); + return false; + } + + bool ok = read_audio_from_decoder(decoder, pcmf32, pcmf32s, stereo); + ma_decoder_uninit(&decoder); + return ok; +} + // 500 -> 00:05.000 // 6000 -> 01:00.000 std::string to_timestamp(int64_t t, bool comma) { @@ -154,6 +198,34 @@ int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate) { return std::max(0, std::min((int) n_samples - 1, (int) ((t*whisper_sample_rate)/100))); } +int utf8_trailing_bytes_needed(const std::string & s) { + const int n = (int) s.size(); + int i = n - 1; + while (i >= 0 && ((unsigned char) s[i] & 0xC0) == 0x80) { + --i; + } + if (i < 0) { + return 0; + } + + const unsigned char c = (unsigned char) s[i]; + int expected; + if ((c & 0x80) == 0x00) { + expected = 1; + } else if ((c & 0xE0) == 0xC0) { + expected = 2; + } else if ((c & 0xF0) == 0xE0) { + expected = 3; + } else if ((c & 0xF8) == 0xF0) { + expected = 4; + } else { + return 0; + } + + const int have = n - i; + return have >= expected ? 0 : (expected - have); +} + bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id) { std::ofstream speak_file(path.c_str()); if (speak_file.fail()) { diff --git a/examples/common-whisper.h b/examples/common-whisper.h index 4134362150a..aec430d3635 100644 --- a/examples/common-whisper.h +++ b/examples/common-whisper.h @@ -14,11 +14,22 @@ bool read_audio_data( std::vector<std::vector<float>> & pcmf32s, bool stereo); +// decode audio bytes already held in memory (uploaded file, network buffer) +bool read_audio_data( + const char * buffer, + size_t buffer_size, + std::vector<float> & pcmf32, + std::vector<std::vector<float>> & pcmf32s, + bool stereo); + // convert timestamp to string, 6000 -> 01:00.000 std::string to_timestamp(int64_t t, bool comma = false); // given a timestamp get the sample int timestamp_to_sample(int64_t t, int n_samples, int whisper_sample_rate); +// Returns the number of trailing bytes still needed for s to end on a complete UTF-8 codepoint. +int utf8_trailing_bytes_needed(const std::string & s); + // write text to file, and call system("command voice_id file") bool speak_with_file(const std::string & command, const std::string & text, const std::string & path, int voice_id); diff --git a/examples/ffmpeg-transcode.cpp b/examples/ffmpeg-transcode.cpp index 1fae58a4ffa..7657af69823 100644 --- a/examples/ffmpeg-transcode.cpp +++ b/examples/ffmpeg-transcode.cpp @@ -1,368 +1,238 @@ -/* SPDX-License-Identifier: GPL-2.0 */ +#ifdef WHISPER_COMMON_FFMPEG -/* - * transcode.c - convert audio file to WAVE - * - * Copyright (C) 2019 Andrew Clayton <andrew@digital-domain.net> - * Copyright (C) 2024 William Tambellini <william.tambellini@gmail.com> - */ - -// Just for conveninent C++ API -#include <vector> #include <string> - -// C -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <stdbool.h> -#include <stdint.h> -#include <sys/types.h> -#include <sys/stat.h> -#include <fcntl.h> -#include <unistd.h> -#include <sys/mman.h> +#include <vector> +#include <cstdio> +#include <cstring> extern "C" { -#include <libavutil/opt.h> -#include <libavcodec/avcodec.h> #include <libavformat/avformat.h> +#include <libavcodec/avcodec.h> #include <libswresample/swresample.h> } -typedef uint64_t u64; -typedef int64_t s64; -typedef uint32_t u32; -typedef int32_t s32; -typedef uint16_t u16; -typedef int16_t s16; -typedef uint8_t u8; -typedef int8_t s8; - -#define WAVE_SAMPLE_RATE 16000 -#define AVIO_CTX_BUF_SZ 4096 - -static const char* ffmpegLog = getenv("FFMPEG_LOG"); -// Todo: add __FILE__ __LINE__ -#define LOG(...) \ - do { if (ffmpegLog) fprintf(stderr, __VA_ARGS__); } while(0) // C99 - -/* - * WAVE file header based on definition from - * https://gist.github.com/Jon-Schneider/8b7c53d27a7a13346a643dac9c19d34f - * - * We must ensure this structure doesn't have any holes or - * padding so we can just map it straight to the WAVE data. - */ -struct wave_hdr { - /* RIFF Header: "RIFF" */ - char riff_header[4]; - /* size of audio data + sizeof(struct wave_hdr) - 8 */ - int wav_size; - /* "WAVE" */ - char wav_header[4]; - - /* Format Header */ - /* "fmt " (includes trailing space) */ - char fmt_header[4]; - /* Should be 16 for PCM */ - int fmt_chunk_size; - /* Should be 1 for PCM. 3 for IEEE Float */ - s16 audio_format; - s16 num_channels; - int sample_rate; - /* - * Number of bytes per second - * sample_rate * num_channels * bit_depth/8 - */ - int byte_rate; - /* num_channels * bytes per sample */ - s16 sample_alignment; - /* bits per sample */ - s16 bit_depth; - - /* Data Header */ - /* "data" */ - char data_header[4]; - /* - * size of audio - * number of samples * num_channels * bit_depth/8 - */ - int data_bytes; -} __attribute__((__packed__)); - -struct audio_buffer { - u8 *ptr; - int size; /* size left in the buffer */ -}; - -static void set_wave_hdr(wave_hdr& wh, size_t size) { - memcpy(&wh.riff_header, "RIFF", 4); - wh.wav_size = size + sizeof(struct wave_hdr) - 8; - memcpy(&wh.wav_header, "WAVE", 4); - memcpy(&wh.fmt_header, "fmt ", 4); - wh.fmt_chunk_size = 16; - wh.audio_format = 1; - wh.num_channels = 1; - wh.sample_rate = WAVE_SAMPLE_RATE; - wh.sample_alignment = 2; - wh.bit_depth = 16; - wh.byte_rate = wh.sample_rate * wh.sample_alignment; - memcpy(&wh.data_header, "data", 4); - wh.data_bytes = size; +// Write a minimal WAV header into the output buffer. +// Returns the number of bytes written (44 for a standard PCM WAV header). +static size_t wav_header_write(uint8_t * buf, int num_channels, int sample_rate, int bits_per_sample, uint32_t data_size) { + // RIFF header + memcpy(buf, "RIFF", 4); + uint32_t chunk_size = 36 + data_size; + memcpy(buf + 4, &chunk_size, 4); + memcpy(buf + 8, "WAVE", 4); + + // fmt subchunk + memcpy(buf + 12, "fmt ", 4); + uint32_t subchunk1_size = 16; + memcpy(buf + 16, &subchunk1_size, 4); + uint16_t audio_format = 1; // PCM + memcpy(buf + 20, &audio_format, 2); + memcpy(buf + 22, &num_channels, 2); + memcpy(buf + 24, &sample_rate, 4); + + int bytes_per_sample = (bits_per_sample / 8) * num_channels; + int byte_rate = sample_rate * bytes_per_sample; + memcpy(buf + 28, &byte_rate, 4); + memcpy(buf + 32, &bytes_per_sample, 2); + memcpy(buf + 34, &bits_per_sample, 2); + + // data subchunk + memcpy(buf + 36, "data", 4); + memcpy(buf + 40, &data_size, 4); + + return 44; } -static void write_wave_hdr(int fd, size_t size) { - struct wave_hdr wh; - set_wave_hdr(wh, size); - write(fd, &wh, sizeof(struct wave_hdr)); -} +bool ffmpeg_decode_audio(const std::string & ifname, std::vector<uint8_t> & wav_data, int out_sample_rate) { + { + const char * verbose = getenv("WHISPER_COMMON_FFMPEG_VERBOSE"); + if (verbose && strcmp(verbose, "2") == 0) { + av_log_set_level(AV_LOG_DEBUG); + } else if (verbose && strcmp(verbose, "1") == 0) { + av_log_set_level(AV_LOG_VERBOSE); + } else { + av_log_set_level(AV_LOG_WARNING); + } + } -static int map_file(int fd, u8 **ptr, size_t *size) -{ - struct stat sb; + AVFormatContext * fmt_ctx = nullptr; + if (avformat_open_input(&fmt_ctx, ifname.c_str(), nullptr, nullptr) != 0) { + fprintf(stderr, "error: failed to open input file '%s'\n", ifname.c_str()); + return true; + } - fstat(fd, &sb); - *size = sb.st_size; + if (avformat_find_stream_info(fmt_ctx, nullptr) < 0) { + fprintf(stderr, "error: failed to find stream information\n"); + avformat_close_input(&fmt_ctx); + return true; + } - *ptr = (u8*)mmap(NULL, *size, PROT_READ|PROT_WRITE, MAP_PRIVATE, fd, 0); - if (*ptr == MAP_FAILED) { - perror("mmap"); - return -1; - } + // Find the first audio stream + int audio_stream_idx = -1; + for (unsigned int i = 0; i < fmt_ctx->nb_streams; i++) { + if (fmt_ctx->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { + audio_stream_idx = i; + break; + } + } - return 0; -} + if (audio_stream_idx == -1) { + fprintf(stderr, "error: failed to find an audio stream in '%s'\n", ifname.c_str()); + avformat_close_input(&fmt_ctx); + return true; + } -static int read_packet(void *opaque, u8 *buf, int buf_size) -{ - struct audio_buffer *audio_buf = (audio_buffer*)opaque; + AVStream * audio_stream = fmt_ctx->streams[audio_stream_idx]; - buf_size = FFMIN(buf_size, audio_buf->size); + // Open the decoder + const AVCodec * codec = avcodec_find_decoder(audio_stream->codecpar->codec_id); + if (!codec) { + fprintf(stderr, "error: failed to find decoder for codec id %d\n", audio_stream->codecpar->codec_id); + avformat_close_input(&fmt_ctx); + return true; + } - /* copy internal buffer data to buf */ - memcpy(buf, audio_buf->ptr, buf_size); - audio_buf->ptr += buf_size; - audio_buf->size -= buf_size; + AVCodecContext * codec_ctx = avcodec_alloc_context3(codec); + if (!codec_ctx) { + fprintf(stderr, "error: failed to allocate codec context\n"); + avformat_close_input(&fmt_ctx); + return true; + } - return buf_size; -} + if (avcodec_parameters_to_context(codec_ctx, audio_stream->codecpar) < 0) { + fprintf(stderr, "error: failed to copy codec parameters to context\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -static void convert_frame(struct SwrContext *swr, AVCodecContext *codec, - AVFrame *frame, s16 **data, int *size, bool flush) -{ - int nr_samples; - s64 delay; - u8 *buffer; - - delay = swr_get_delay(swr, codec->sample_rate); - nr_samples = av_rescale_rnd(delay + frame->nb_samples, - WAVE_SAMPLE_RATE, codec->sample_rate, - AV_ROUND_UP); - av_samples_alloc(&buffer, NULL, 1, nr_samples, AV_SAMPLE_FMT_S16, 0); - - /* - * !flush is used to check if we are flushing any remaining - * conversion buffers... - */ - nr_samples = swr_convert(swr, &buffer, nr_samples, - !flush ? (const u8 **)frame->data : NULL, - !flush ? frame->nb_samples : 0); - - *data = (s16*)realloc(*data, (*size + nr_samples) * sizeof(s16)); - memcpy(*data + *size, buffer, nr_samples * sizeof(s16)); - *size += nr_samples; - av_freep(&buffer); -} + if (avcodec_open2(codec_ctx, codec, nullptr) < 0) { + fprintf(stderr, "error: failed to open codec\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -static bool is_audio_stream(const AVStream *stream) -{ - if (stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) - return true; + // Setup resampler: convert to 16-bit signed PCM, mono, 16000 Hz + const enum AVSampleFormat out_sample_fmt = AV_SAMPLE_FMT_S16; - return false; -} + AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO; -// Return non zero on error, 0 on success -// audio_buffer: input memory -// data: decoded output audio data (wav file) -// size: size of output data -static int decode_audio(struct audio_buffer *audio_buf, s16 **data, int *size) -{ - LOG("decode_audio: input size: %d\n", audio_buf->size); - AVFormatContext *fmt_ctx; - AVIOContext *avio_ctx; - AVStream *stream; - AVCodecContext *codec; - AVPacket *packet; - AVFrame *frame; - struct SwrContext *swr; - u8 *avio_ctx_buffer; - unsigned int i; - int stream_index = -1; - int err; - const size_t errbuffsize = 1024; - char errbuff[errbuffsize]; - - fmt_ctx = avformat_alloc_context(); - avio_ctx_buffer = (u8*)av_malloc(AVIO_CTX_BUF_SZ); - LOG("Creating an avio context: AVIO_CTX_BUF_SZ=%d\n", AVIO_CTX_BUF_SZ); - avio_ctx = avio_alloc_context(avio_ctx_buffer, AVIO_CTX_BUF_SZ, 0, audio_buf, &read_packet, NULL, NULL); - fmt_ctx->pb = avio_ctx; - - // open the input stream and read header - err = avformat_open_input(&fmt_ctx, NULL, NULL, NULL); - if (err) { - LOG("Could not read audio buffer: %d: %s\n", err, av_make_error_string(errbuff, errbuffsize, err)); - return err; - } - - err = avformat_find_stream_info(fmt_ctx, NULL); - if (err < 0) { - LOG("Could not retrieve stream info from audio buffer: %d\n", err); - return err; - } - - for (i = 0; i < fmt_ctx->nb_streams; i++) { - if (is_audio_stream(fmt_ctx->streams[i])) { - stream_index = i; - break; - } - } - - if (stream_index == -1) { - LOG("Could not retrieve audio stream from buffer\n"); - return -1; - } - - stream = fmt_ctx->streams[stream_index]; - codec = avcodec_alloc_context3( - avcodec_find_decoder(stream->codecpar->codec_id)); - avcodec_parameters_to_context(codec, stream->codecpar); - err = avcodec_open2(codec, avcodec_find_decoder(codec->codec_id), - NULL); - if (err) { - LOG("Failed to open decoder for stream #%d in audio buffer\n", stream_index); - return err; - } - - /* prepare resampler */ - swr = swr_alloc(); - -#if LIBAVCODEC_VERSION_MAJOR > 60 - AVChannelLayout in_ch_layout = codec->ch_layout; - AVChannelLayout out_ch_layout = AV_CHANNEL_LAYOUT_MONO; - - /* Set the source audio layout as-is */ - av_opt_set_chlayout(swr, "in_chlayout", &in_ch_layout, 0); - av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0); - av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); - - /* Convert it into 16khz Mono */ - av_opt_set_chlayout(swr, "out_chlayout", &out_ch_layout, 0); - av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); - av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); -#else - av_opt_set_int(swr, "in_channel_count", codec->channels, 0); - av_opt_set_int(swr, "out_channel_count", 1, 0); - av_opt_set_int(swr, "in_channel_layout", codec->channel_layout, 0); - av_opt_set_int(swr, "out_channel_layout", AV_CH_LAYOUT_MONO, 0); - av_opt_set_int(swr, "in_sample_rate", codec->sample_rate, 0); - av_opt_set_int(swr, "out_sample_rate", WAVE_SAMPLE_RATE, 0); - av_opt_set_sample_fmt(swr, "in_sample_fmt", codec->sample_fmt, 0); - av_opt_set_sample_fmt(swr, "out_sample_fmt", AV_SAMPLE_FMT_S16, 0); -#endif - - swr_init(swr); - if (!swr_is_initialized(swr)) { - LOG("Resampler has not been properly initialized\n"); - return -1; - } - - packet=av_packet_alloc(); - if (!packet) { - LOG("Error allocating the packet\n"); - return -1; - } - frame = av_frame_alloc(); - if (!frame) { - LOG("Error allocating the frame\n"); - return -1; - } - - /* iterate through frames */ - *data = NULL; - *size = 0; - while (av_read_frame(fmt_ctx, packet) >= 0) { - avcodec_send_packet(codec, packet); - - err = avcodec_receive_frame(codec, frame); - if (err == AVERROR(EAGAIN)) - continue; - - convert_frame(swr, codec, frame, data, size, false); - } - /* Flush any remaining conversion buffers... */ - convert_frame(swr, codec, frame, data, size, true); - - av_packet_free(&packet); - av_frame_free(&frame); - swr_free(&swr); - //avio_context_free(); // todo? - avcodec_free_context(&codec); - avformat_close_input(&fmt_ctx); - avformat_free_context(fmt_ctx); - - if (avio_ctx) { - av_freep(&avio_ctx->buffer); - av_freep(&avio_ctx); - } - - return 0; -} + SwrContext * swr_ctx = nullptr; + if (swr_alloc_set_opts2(&swr_ctx, &out_ch_layout, out_sample_fmt, out_sample_rate, + &codec_ctx->ch_layout, codec_ctx->sample_fmt, codec_ctx->sample_rate, + 0, nullptr) < 0) { + fprintf(stderr, "error: failed to allocate swr context\n"); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; + } -// in mem decoding/conversion/resampling: -// ifname: input file path -// owav_data: in mem wav file. Can be forwarded as it to whisper/drwav -// return 0 on success -int ffmpeg_decode_audio(const std::string &ifname, std::vector<uint8_t>& owav_data) { - LOG("ffmpeg_decode_audio: %s\n", ifname.c_str()); - int ifd = open(ifname.c_str(), O_RDONLY); - if (ifd == -1) { - fprintf(stderr, "Couldn't open input file %s\n", ifname.c_str()); - return -1; + if (swr_init(swr_ctx) < 0) { + fprintf(stderr, "error: failed to initialize swr context\n"); + swr_free(&swr_ctx); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + return true; } - u8 *ibuf = NULL; - size_t ibuf_size; - int err = map_file(ifd, &ibuf, &ibuf_size); - if (err) { - LOG("Couldn't map input file %s\n", ifname.c_str()); - return err; + + // Decode and resample + AVPacket * packet = av_packet_alloc(); + AVFrame * frame = av_frame_alloc(); + + // Buffer to collect resampled output + std::vector<int16_t> pcm_data; + + // Max output samples per swr_convert call + const int max_out_samples = 16 * 1024; + std::vector<int16_t> out_buffer(max_out_samples); + + while (av_read_frame(fmt_ctx, packet) >= 0) { + if (packet->stream_index != audio_stream_idx) { + av_packet_unref(packet); + continue; + } + + int ret = avcodec_send_packet(codec_ctx, packet); + av_packet_unref(packet); + + if (ret < 0) { + continue; + } + + while (ret >= 0) { + ret = avcodec_receive_frame(codec_ctx, frame); + if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { + break; + } + if (ret < 0) { + break; + } + + // Resample + int out_samples = av_rescale_rnd(swr_get_delay(swr_ctx, out_sample_rate) + frame->nb_samples, + out_sample_rate, out_sample_rate, AV_ROUND_UP); + if (out_samples > (int)out_buffer.size()) { + out_buffer.resize(out_samples); + } + + const uint8_t * in_data[16] = {0}; + for (int p = 0; p < (int)codec_ctx->ch_layout.nb_channels && p < 16; p++) { + in_data[p] = frame->data[p]; + } + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + + int got_samples = swr_convert(swr_ctx, out_data, out_samples, in_data, frame->nb_samples); + if (got_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + got_samples); + } + } } - LOG("Mapped input file: %s size: %d\n", ibuf, (int) ibuf_size); - struct audio_buffer inaudio_buf; - inaudio_buf.ptr = ibuf; - inaudio_buf.size = ibuf_size; - - s16 *odata=NULL; - int osize=0; - - err = decode_audio(&inaudio_buf, &odata, &osize); - LOG("decode_audio returned %d \n", err); - if (err != 0) { - LOG("decode_audio failed\n"); - return err; + + // Flush the decoder + avcodec_send_packet(codec_ctx, nullptr); + while (avcodec_receive_frame(codec_ctx, frame) >= 0) { + int out_samples = av_rescale_rnd(swr_get_delay(swr_ctx, out_sample_rate) + frame->nb_samples, + out_sample_rate, out_sample_rate, AV_ROUND_UP); + if (out_samples > (int)out_buffer.size()) { + out_buffer.resize(out_samples); + } + const uint8_t * in_data[16] = {0}; + for (int p = 0; p < (int)codec_ctx->ch_layout.nb_channels && p < 16; p++) { + in_data[p] = frame->data[p]; + } + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + + int got_samples = swr_convert(swr_ctx, out_data, out_samples, in_data, frame->nb_samples); + if (got_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + got_samples); + } } - LOG("decode_audio output size: %d\n", osize); - - wave_hdr wh; - const size_t outdatasize = osize * sizeof(s16); - set_wave_hdr(wh, outdatasize); - owav_data.resize(sizeof(wave_hdr) + outdatasize); - // header: - memcpy(owav_data.data(), &wh, sizeof(wave_hdr)); - // the data: - memcpy(owav_data.data() + sizeof(wave_hdr), odata, osize* sizeof(s16)); - - return 0; + + // Flush the resampler + uint8_t * out_data[16] = {0}; + out_data[0] = (uint8_t *)out_buffer.data(); + int flush_samples = swr_convert(swr_ctx, out_data, max_out_samples, nullptr, 0); + if (flush_samples > 0) { + pcm_data.insert(pcm_data.end(), out_buffer.begin(), out_buffer.begin() + flush_samples); + } + + // Build WAV output + uint32_t data_size = pcm_data.size() * sizeof(int16_t); + wav_data.resize(44 + data_size); + + wav_header_write(wav_data.data(), 1, out_sample_rate, 16, data_size); + memcpy(wav_data.data() + 44, pcm_data.data(), data_size); + + // Cleanup + av_frame_free(&frame); + av_packet_free(&packet); + swr_free(&swr_ctx); + avcodec_free_context(&codec_ctx); + avformat_close_input(&fmt_ctx); + + return false; // success } + +#endif // WHISPER_COMMON_FFMPEG diff --git a/examples/miniaudio.h b/examples/miniaudio.h index c74bebeb3c7..24e676bb264 100644 --- a/examples/miniaudio.h +++ b/examples/miniaudio.h @@ -1,6 +1,6 @@ /* Audio playback and capture library. Choice of public domain or MIT-0. See license statements at the end of this file. -miniaudio - v0.11.22 - 2025-02-24 +miniaudio - v0.11.24 - 2026-01-17 David Reid - mackron@gmail.com @@ -12,18 +12,10 @@ GitHub: https://github.com/mackron/miniaudio /* 1. Introduction =============== -To use miniaudio, include "miniaudio.h": - - ```c - #include "miniaudio.h" - ``` - -The implementation is contained in "miniaudio.c". Just compile this like any other source file. You -can include miniaudio.c if you want to compile your project as a single translation unit: - - ```c - #include "miniaudio.c" - ``` +To use miniaudio, just include "miniaudio.h" like any other header and add "miniaudio.c" to your +source tree. If you don't want to add it to your source tree you can compile and link to it like +any other library. Note that ABI compatibility is not guaranteed between versions, even with bug +fix releases, so take care if compiling as a shared object. miniaudio includes both low level and high level APIs. The low level API is good for those who want to do all of their mixing themselves and only require a light weight interface to the underlying @@ -303,7 +295,7 @@ The engine encapsulates both the resource manager and the node graph to create a use high level API. The resource manager and node graph APIs are covered in more later sections of this manual. -The code below shows how you can initialize an engine using it's default configuration. +The code below shows how you can initialize an engine using its default configuration. ```c ma_result result; @@ -391,7 +383,7 @@ Sounds are not started by default. Start a sound with `ma_sound_start()` and sto `ma_sound_stop()`. When a sound is stopped, it is not rewound to the start. Use `ma_sound_seek_to_pcm_frame(&sound, 0)` to seek back to the start of a sound. By default, starting and stopping sounds happens immediately, but sometimes it might be convenient to schedule the sound -the be started and/or stopped at a specific time. This can be done with the following functions: +to be started and/or stopped at a specific time. This can be done with the following functions: ```c ma_sound_set_start_time_in_pcm_frames() @@ -463,6 +455,11 @@ is at the end, use `ma_sound_at_end()`. Looping of a sound can be controlled wit miniaudio should work cleanly out of the box without the need to download or install any dependencies. See below for platform-specific details. +This library has been designed to be added directly to your source tree which is the preferred way +of using it, but you can compile it as a normal library if that's your preference. Be careful if +compiling as a shared object because miniaudio is not ABI compatible between any release, including +bug fix releases. It's recommended you link statically. + Note that GCC and Clang require `-msse2`, `-mavx2`, etc. for SIMD optimizations. If you get errors about undefined references to `__sync_val_compare_and_swap_8`, `__atomic_load_8`, @@ -532,7 +529,7 @@ you'll need to disable run-time linking with `MA_NO_RUNTIME_LINKING` and link wi The Emscripten build emits Web Audio JavaScript directly and should compile cleanly out of the box. You cannot use `-std=c*` compiler flags, nor `-ansi`. -You can enable the use of AudioWorkets by defining `MA_ENABLE_AUDIO_WORKLETS` and then compiling +You can enable the use of AudioWorklets by defining `MA_ENABLE_AUDIO_WORKLETS` and then compiling with the following options: -sAUDIO_WORKLET=1 -sWASM_WORKERS=1 -sASYNCIFY @@ -881,7 +878,7 @@ read data within a certain range of the underlying data. To do this you can use This is useful if you have a sound bank where many sounds are stored in the same file and you want the data source to only play one of those sub-sounds. Note that once the range is set, everything -that takes a position, such as cursors and loop points, should always be relatvie to the start of +that takes a position, such as cursors and loop points, should always be relative to the start of the range. When the range is set, any previously defined loop point will be reset. Custom loop points can also be used with data sources. By default, data sources will loop after @@ -889,7 +886,7 @@ they reach the end of the data source, but if you need to loop at a specific loc the following: ```c - result = ma_data_set_loop_point_in_pcm_frames(pDataSource, loopBegInFrames, loopEndInFrames); + result = ma_data_source_set_loop_point_in_pcm_frames(pDataSource, loopBegInFrames, loopEndInFrames); if (result != MA_SUCCESS) { return result; // Failed to set the loop point. } @@ -3750,7 +3747,7 @@ extern "C" { #define MA_VERSION_MAJOR 0 #define MA_VERSION_MINOR 11 -#define MA_VERSION_REVISION 22 +#define MA_VERSION_REVISION 24 #define MA_VERSION_STRING MA_XSTRINGIFY(MA_VERSION_MAJOR) "." MA_XSTRINGIFY(MA_VERSION_MINOR) "." MA_XSTRINGIFY(MA_VERSION_REVISION) #if defined(_MSC_VER) && !defined(__clang__) @@ -3857,37 +3854,65 @@ typedef ma_uint16 wchar_t; #define MA_SIZE_MAX 0xFFFFFFFF /* When SIZE_MAX is not defined by the standard library just default to the maximum 32-bit unsigned integer. */ #endif +#define MA_UINT64_MAX (((ma_uint64)0xFFFFFFFF << 32) | (ma_uint64)0xFFFFFFFF) /* Weird shifting syntax is for VC6 compatibility. */ + /* Platform/backend detection. */ -#if defined(_WIN32) || defined(__COSMOPOLITAN__) +#if defined(_WIN32) #define MA_WIN32 #if defined(MA_FORCE_UWP) || (defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PC_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PC_APP) || (defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) #define MA_WIN32_UWP #elif defined(WINAPI_FAMILY) && (defined(WINAPI_FAMILY_GAMES) && WINAPI_FAMILY == WINAPI_FAMILY_GAMES) #define MA_WIN32_GDK + #elif defined(NXDK) + #define MA_WIN32_NXDK #else #define MA_WIN32_DESKTOP #endif + + /* The original Xbox. */ + #if defined(NXDK) /* <-- Add other Xbox compiler toolchains here, and then add a toolchain-specific define in case we need to discriminate between them later. */ + #define MA_XBOX + + #if defined(NXDK) + #define MA_XBOX_NXDK + #endif + #endif +#endif +#if defined(__MSDOS__) || defined(MSDOS) || defined(_MSDOS) || defined(__DOS__) + #define MA_DOS + + /* No threading allowed on DOS. */ + #ifndef MA_NO_THREADING + #define MA_NO_THREADING + #endif + + /* No runtime linking allowed on DOS. */ + #ifndef MA_NO_RUNTIME_LINKING + #define MA_NO_RUNTIME_LINKING + #endif #endif -#if !defined(_WIN32) /* If it's not Win32, assume POSIX. */ +#if !defined(MA_WIN32) && !defined(MA_DOS) /* If it's not Win32, assume POSIX. */ #define MA_POSIX - /* - Use the MA_NO_PTHREAD_IN_HEADER option at your own risk. This is intentionally undocumented. - You can use this to avoid including pthread.h in the header section. The downside is that it - results in some fixed sized structures being declared for the various types that are used in - miniaudio. The risk here is that these types might be too small for a given platform. This - risk is yours to take and no support will be offered if you enable this option. - */ - #ifndef MA_NO_PTHREAD_IN_HEADER - #include <pthread.h> /* Unfortunate #include, but needed for pthread_t, pthread_mutex_t and pthread_cond_t types. */ - typedef pthread_t ma_pthread_t; - typedef pthread_mutex_t ma_pthread_mutex_t; - typedef pthread_cond_t ma_pthread_cond_t; - #else - typedef ma_uintptr ma_pthread_t; - typedef union ma_pthread_mutex_t { char __data[40]; ma_uint64 __alignment; } ma_pthread_mutex_t; - typedef union ma_pthread_cond_t { char __data[48]; ma_uint64 __alignment; } ma_pthread_cond_t; + #if !defined(MA_NO_THREADING) + /* + Use the MA_NO_PTHREAD_IN_HEADER option at your own risk. This is intentionally undocumented. + You can use this to avoid including pthread.h in the header section. The downside is that it + results in some fixed sized structures being declared for the various types that are used in + miniaudio. The risk here is that these types might be too small for a given platform. This + risk is yours to take and no support will be offered if you enable this option. + */ + #ifndef MA_NO_PTHREAD_IN_HEADER + #include <pthread.h> /* Unfortunate #include, but needed for pthread_t, pthread_mutex_t and pthread_cond_t types. */ + typedef pthread_t ma_pthread_t; + typedef pthread_mutex_t ma_pthread_mutex_t; + typedef pthread_cond_t ma_pthread_cond_t; + #else + typedef ma_uintptr ma_pthread_t; + typedef union ma_pthread_mutex_t { char __data[40]; ma_uint64 __alignment; } ma_pthread_mutex_t; + typedef union ma_pthread_cond_t { char __data[48]; ma_uint64 __alignment; } ma_pthread_cond_t; + #endif #endif #if defined(__unix__) @@ -3914,8 +3939,11 @@ typedef ma_uint16 wchar_t; #if defined(__PROSPERO__) #define MA_PROSPERO #endif - #if defined(__NX__) - #define MA_NX + #if defined(__3DS__) + #define MA_3DS + #endif + #if defined(__SWITCH__) || defined(__NX__) + #define MA_SWITCH #endif #if defined(__BEOS__) || defined(__HAIKU__) #define MA_BEOS @@ -3925,12 +3953,13 @@ typedef ma_uint16 wchar_t; #endif #endif -#if defined(__has_c_attribute) - #if __has_c_attribute(fallthrough) - #define MA_FALLTHROUGH [[fallthrough]] - #endif +#if !defined(MA_FALLTHROUGH) && defined(__cplusplus) && __cplusplus >= 201703L + #define MA_FALLTHROUGH [[fallthrough]] #endif -#if !defined(MA_FALLTHROUGH) && defined(__has_attribute) && (defined(__clang__) || defined(__GNUC__)) +#if !defined(MA_FALLTHROUGH) && defined(__STDC_VERSION__) && __STDC_VERSION__ >= 202000L + #define MA_FALLTHROUGH [[fallthrough]] +#endif +#if !defined(MA_FALLTHROUGH) && defined(__has_attribute) #if __has_attribute(fallthrough) #define MA_FALLTHROUGH __attribute__((fallthrough)) #endif @@ -3967,7 +3996,7 @@ typedef ma_uint16 wchar_t; #define MA_NO_INLINE __attribute__((noinline)) #else #define MA_INLINE MA_GNUC_INLINE_HINT - #define MA_NO_INLINE __attribute__((noinline)) + #define MA_NO_INLINE #endif #elif defined(__WATCOMC__) #define MA_INLINE __inline @@ -4153,9 +4182,13 @@ typedef enum MA_CHANNEL_AUX_29 = 49, MA_CHANNEL_AUX_30 = 50, MA_CHANNEL_AUX_31 = 51, + + /* Count. */ + MA_CHANNEL_POSITION_COUNT, + + /* Aliases. */ MA_CHANNEL_LEFT = MA_CHANNEL_FRONT_LEFT, MA_CHANNEL_RIGHT = MA_CHANNEL_FRONT_RIGHT, - MA_CHANNEL_POSITION_COUNT = (MA_CHANNEL_AUX_31 + 1) } _ma_channel_position; /* Do not use `_ma_channel_position` directly. Use `ma_channel` instead. */ typedef enum @@ -4350,7 +4383,7 @@ typedef struct typedef struct { - ma_int32 state; + ma_uint32 state; } ma_lcg; @@ -6569,22 +6602,18 @@ This section contains the APIs for device playback and capture. Here is where yo ************************************************************************************************************************************************************/ #ifndef MA_NO_DEVICE_IO /* Some backends are only supported on certain platforms. */ -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) #define MA_SUPPORT_WASAPI #if defined(MA_WIN32_DESKTOP) /* DirectSound and WinMM backends are only supported on desktops. */ #define MA_SUPPORT_DSOUND #define MA_SUPPORT_WINMM - - /* Don't enable JACK here if compiling with Cosmopolitan. It'll be enabled in the Linux section below. */ - #if !defined(__COSMOPOLITAN__) - #define MA_SUPPORT_JACK /* JACK is technically supported on Windows, but I don't know how many people use it in practice... */ - #endif + #define MA_SUPPORT_JACK /* JACK is technically supported on Windows, but I don't know how many people use it in practice... */ #endif #endif #if defined(MA_UNIX) && !defined(MA_ORBIS) && !defined(MA_PROSPERO) #if defined(MA_LINUX) - #if !defined(MA_ANDROID) && !defined(__COSMOPOLITAN__) /* ALSA is not supported on Android. */ + #if !defined(MA_ANDROID) && !defined(MA_EMSCRIPTEN) /* ALSA is not supported on Android. */ #define MA_SUPPORT_ALSA #endif #endif @@ -7426,6 +7455,7 @@ struct ma_context ma_proc snd_pcm_hw_params_set_rate_resample; ma_proc snd_pcm_hw_params_set_rate; ma_proc snd_pcm_hw_params_set_rate_near; + ma_proc snd_pcm_hw_params_set_rate_minmax; ma_proc snd_pcm_hw_params_set_buffer_size_near; ma_proc snd_pcm_hw_params_set_periods_near; ma_proc snd_pcm_hw_params_set_access; @@ -7986,6 +8016,7 @@ struct ma_device /*AAudioStream**/ ma_ptr pStreamPlayback; /*AAudioStream**/ ma_ptr pStreamCapture; ma_mutex rerouteLock; + ma_atomic_bool32 isTearingDown; ma_aaudio_usage usage; ma_aaudio_content_type contentType; ma_aaudio_input_preset inputPreset; @@ -9644,7 +9675,7 @@ Parameters ---------- pBackends (out, optional) A pointer to the buffer that will receive the enabled backends. Set to NULL to retrieve the backend count. Setting - the capacity of the buffer to `MA_BUFFER_COUNT` will guarantee it's large enough for all backends. + the capacity of the buffer to `MA_BACKEND_COUNT` will guarantee it's large enough for all backends. backendCap (in) The capacity of the `pBackends` buffer. @@ -10489,6 +10520,7 @@ typedef struct ma_decoding_backend_vtable** ppCustomDecodingBackendVTables; ma_uint32 customDecodingBackendCount; void* pCustomDecodingBackendUserData; + ma_resampler_config resampling; } ma_resource_manager_config; MA_API ma_resource_manager_config ma_resource_manager_config_init(void); @@ -10816,6 +10848,7 @@ MA_API ma_result ma_node_graph_read_pcm_frames(ma_node_graph* pNodeGraph, void* MA_API ma_uint32 ma_node_graph_get_channels(const ma_node_graph* pNodeGraph); MA_API ma_uint64 ma_node_graph_get_time(const ma_node_graph* pNodeGraph); MA_API ma_result ma_node_graph_set_time(ma_node_graph* pNodeGraph, ma_uint64 globalTime); +MA_API ma_uint32 ma_node_graph_get_processing_size_in_frames(const ma_node_graph* pNodeGraph); @@ -11123,6 +11156,7 @@ typedef struct ma_bool8 isPitchDisabled; /* Pitching can be explicitly disabled with MA_SOUND_FLAG_NO_PITCH to optimize processing. */ ma_bool8 isSpatializationDisabled; /* Spatialization can be explicitly disabled with MA_SOUND_FLAG_NO_SPATIALIZATION. */ ma_uint8 pinnedListenerIndex; /* The index of the listener this node should always use for spatialization. If set to MA_LISTENER_INDEX_CLOSEST the engine will use the closest listener. */ + ma_resampler_config resampling; } ma_engine_node_config; MA_API ma_engine_node_config ma_engine_node_config_init(ma_engine* pEngine, ma_engine_node_type type, ma_uint32 flags); @@ -11137,7 +11171,7 @@ typedef struct ma_uint32 volumeSmoothTimeInPCMFrames; ma_mono_expansion_mode monoExpansionMode; ma_fader fader; - ma_linear_resampler resampler; /* For pitch shift. */ + ma_resampler resampler; /* For pitch shift. */ ma_spatializer spatializer; ma_panner panner; ma_gainer volumeGainer; /* This will only be used if volumeSmoothTimeInPCMFrames is > 0. */ @@ -11193,6 +11227,7 @@ typedef struct ma_uint64 loopPointEndInPCMFrames; ma_sound_end_proc endCallback; /* Fired when the sound reaches the end. Will be fired from the audio thread. Do not restart, uninitialize or otherwise change the state of the sound from here. Instead fire an event or set a variable to indicate to a different thread to change the start of the sound. Will not be fired in response to a scheduled stop with ma_sound_set_stop_time_*(). */ void* pEndCallbackUserData; + ma_resampler_config pitchResampling; #ifndef MA_NO_RESOURCE_MANAGER ma_resource_manager_pipeline_notifications initNotifications; #endif @@ -11211,7 +11246,10 @@ struct ma_sound MA_ATOMIC(4, ma_bool32) atEnd; ma_sound_end_proc endCallback; void* pEndCallbackUserData; - ma_bool8 ownsDataSource; + float* pProcessingCache; /* Will be null if pDataSource is null. */ + ma_uint32 processingCacheFramesRemaining; + ma_uint32 processingCacheCap; + ma_bool8 ownsDataSource; /* We're declaring a resource manager data source object here to save us a malloc when loading a @@ -11255,7 +11293,7 @@ typedef struct ma_log* pLog; /* When set to NULL, will use the context's log. */ ma_uint32 listenerCount; /* Must be between 1 and MA_ENGINE_MAX_LISTENERS. */ ma_uint32 channels; /* The number of channels to use when mixing and spatializing. When set to 0, will use the native channel count of the device. */ - ma_uint32 sampleRate; /* The sample rate. When set to 0 will use the native channel count of the device. */ + ma_uint32 sampleRate; /* The sample rate. When set to 0 will use the native sample rate of the device. */ ma_uint32 periodSizeInFrames; /* If set to something other than 0, updates will always be exactly this size. The underlying device may be a different size, but from the perspective of the mixer that won't matter.*/ ma_uint32 periodSizeInMilliseconds; /* Used if periodSizeInFrames is unset. */ ma_uint32 gainSmoothTimeInFrames; /* The number of frames to interpolate the gain of spatialized sounds across. If set to 0, will use gainSmoothTimeInMilliseconds. */ @@ -11269,6 +11307,8 @@ typedef struct ma_vfs* pResourceManagerVFS; /* A pointer to a pre-allocated VFS object to use with the resource manager. This is ignored if pResourceManager is not NULL. */ ma_engine_process_proc onProcess; /* Fired at the end of each call to ma_engine_read_pcm_frames(). For engine's that manage their own internal device (the default configuration), this will be fired from the audio thread, and you do not need to call ma_engine_read_pcm_frames() manually in order to trigger this. */ void* pProcessUserData; /* User data that's passed into onProcess. */ + ma_resampler_config resourceManagerResampling; /* The resampling config to use with the resource manager. */ + ma_resampler_config pitchResampling; /* The resampling config for the pitch and Doppler effects. You will typically want this to be a fast resampler. For high quality stuff, it's recommended that you pre-resample. */ } ma_engine_config; MA_API ma_engine_config ma_engine_config_init(void); @@ -11298,6 +11338,7 @@ struct ma_engine ma_mono_expansion_mode monoExpansionMode; ma_engine_process_proc onProcess; void* pProcessUserData; + ma_resampler_config pitchResamplingConfig; }; MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEngine); @@ -11358,8 +11399,12 @@ MA_API ma_engine* ma_sound_get_engine(const ma_sound* pSound); MA_API ma_data_source* ma_sound_get_data_source(const ma_sound* pSound); MA_API ma_result ma_sound_start(ma_sound* pSound); MA_API ma_result ma_sound_stop(ma_sound* pSound); -MA_API ma_result ma_sound_stop_with_fade_in_pcm_frames(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. */ -MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. */ +MA_API ma_result ma_sound_stop_with_fade_in_pcm_frames(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. If you want to restart the sound, first reset it with `ma_sound_reset_stop_time_and_fade()`. There are plans to make this less awkward in the future. */ +MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. If you want to restart the sound, first reset it with `ma_sound_reset_stop_time_and_fade()`. There are plans to make this less awkward in the future. */ +MA_API void ma_sound_reset_start_time(ma_sound* pSound); +MA_API void ma_sound_reset_stop_time(ma_sound* pSound); +MA_API void ma_sound_reset_fade(ma_sound* pSound); +MA_API void ma_sound_reset_stop_time_and_fade(ma_sound* pSound); /* Resets fades and scheduled stop time. Does not seek back to the start. */ MA_API void ma_sound_set_volume(ma_sound* pSound, float volume); MA_API float ma_sound_get_volume(const ma_sound* pSound); MA_API void ma_sound_set_pan(ma_sound* pSound, float pan); @@ -11419,11 +11464,11 @@ MA_API ma_bool32 ma_sound_is_looping(const ma_sound* pSound); MA_API ma_bool32 ma_sound_at_end(const ma_sound* pSound); MA_API ma_result ma_sound_seek_to_pcm_frame(ma_sound* pSound, ma_uint64 frameIndex); /* Just a wrapper around ma_data_source_seek_to_pcm_frame(). */ MA_API ma_result ma_sound_seek_to_second(ma_sound* pSound, float seekPointInSeconds); /* Abstraction to ma_sound_seek_to_pcm_frame() */ -MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap); -MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* pCursor); -MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* pLength); -MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor); -MA_API ma_result ma_sound_get_length_in_seconds(ma_sound* pSound, float* pLength); +MA_API ma_result ma_sound_get_data_format(const ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap); +MA_API ma_result ma_sound_get_cursor_in_pcm_frames(const ma_sound* pSound, ma_uint64* pCursor); +MA_API ma_result ma_sound_get_length_in_pcm_frames(const ma_sound* pSound, ma_uint64* pLength); +MA_API ma_result ma_sound_get_cursor_in_seconds(const ma_sound* pSound, float* pCursor); +MA_API ma_result ma_sound_get_length_in_seconds(const ma_sound* pSound, float* pLength); MA_API ma_result ma_sound_set_end_callback(ma_sound* pSound, ma_sound_end_proc callback, void* pUserData); MA_API ma_result ma_sound_group_init(ma_engine* pEngine, ma_uint32 flags, ma_sound_group* pParentGroup, ma_sound_group* pGroup); @@ -11544,16 +11589,22 @@ IMPLEMENTATION #endif #if !defined(MA_WIN32) -#include <sched.h> -#include <sys/time.h> /* select() (used for ma_sleep()). */ -#include <pthread.h> -#endif + #if !defined(MA_NO_THREADING) + #include <sched.h> + #include <pthread.h> /* For pthreads. */ + #endif -#ifdef MA_NX -#include <time.h> /* For nanosleep() */ + #include <sys/time.h> /* select() (used for ma_sleep()). */ + #include <time.h> /* For nanosleep() */ + #include <unistd.h> #endif -#include <sys/stat.h> /* For fstat(), etc. */ +/* For fstat(), etc. */ +#if defined(MA_XBOX_NXDK) + #include <stat.h> /* Suggestion for NXDK: Add a sys/stat.h wrapper for compatibility. */ +#else + #include <sys/stat.h> +#endif #ifdef MA_EMSCRIPTEN #include <emscripten/emscripten.h> @@ -11606,7 +11657,7 @@ IMPLEMENTATION #endif /* Intrinsics Support */ -#if (defined(MA_X64) || defined(MA_X86)) && !defined(__COSMOPOLITAN__) +#if defined(MA_X64) || defined(MA_X86) #if defined(_MSC_VER) && !defined(__clang__) /* MSVC. */ #if _MSC_VER >= 1400 && !defined(MA_NO_SSE2) /* 2005 */ @@ -11861,7 +11912,7 @@ static MA_INLINE ma_bool32 ma_has_neon(void) #endif #ifndef MA_RESTRICT - #if defined(__clang__) || defined(__GNUC__) || defined(_MSC_VER) + #if defined(__clang__) || defined(_MSC_VER) || (defined(__GNUC__) && (__GNUC__ > 2 || (__GNUC__ == 2 && __GNUC_MINOR__ >= 95))) #define MA_RESTRICT __restrict #else #define MA_RESTRICT @@ -11955,7 +12006,7 @@ static void ma_sleep__posix(ma_uint32 milliseconds) (void)milliseconds; MA_ASSERT(MA_FALSE); /* The Emscripten build should never sleep. */ #else - #if (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L) || defined(MA_NX) + #if (defined(_POSIX_C_SOURCE) && _POSIX_C_SOURCE >= 199309L) || defined(MA_SWITCH) struct timespec ts; ts.tv_sec = milliseconds / 1000; ts.tv_nsec = milliseconds % 1000 * 1000000; @@ -11997,7 +12048,7 @@ static MA_INLINE void ma_yield(void) #endif #endif #else - __asm__ __volatile__ ("pause"); + __asm__ __volatile__ ("rep; nop"); #endif #elif (defined(__arm__) && defined(__ARM_ARCH) && __ARM_ARCH >= 7) || defined(_M_ARM64) || (defined(_M_ARM) && _M_ARM >= 7) || defined(__ARM_ARCH_6K__) || defined(__ARM_ARCH_6T2__) /* ARM */ @@ -12020,7 +12071,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) { unsigned int prevState; - #if defined(_MSC_VER) + #if defined(_MSC_VER) && !defined(MA_XBOX_NXDK) { /* Older versions of Visual Studio don't support the "safe" versions of _controlfp_s(). I don't @@ -12043,7 +12094,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) } #elif defined(MA_X86) || defined(MA_X64) { - #if defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ + #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ { prevState = _mm_getcsr(); _mm_setcsr(prevState | MA_MM_DENORMALS_ZERO_MASK | MA_MM_FLUSH_ZERO_MASK); @@ -12067,7 +12118,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) static MA_INLINE void ma_restore_denormals(unsigned int prevState) { - #if defined(_MSC_VER) + #if defined(_MSC_VER) && !defined(MA_XBOX_NXDK) { /* Older versions of Visual Studio do not support _controlfp_s(). See ma_disable_denormals(). */ #if _MSC_VER <= 1200 @@ -12083,7 +12134,7 @@ static MA_INLINE void ma_restore_denormals(unsigned int prevState) } #elif defined(MA_X86) || defined(MA_X64) { - #if defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ + #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ { _mm_setcsr(prevState); } @@ -12719,6 +12770,29 @@ MA_API MA_NO_INLINE int ma_strcmp(const char* str1, const char* str2) return ((unsigned char*)str1)[0] - ((unsigned char*)str2)[0]; } +MA_API MA_NO_INLINE int ma_wcscmp(const wchar_t* str1, const wchar_t* str2) +{ + if (str1 == str2) return 0; + + /* These checks differ from the standard implementation. It's not important, but I prefer it just for sanity. */ + if (str1 == NULL) return -1; + if (str2 == NULL) return 1; + + for (;;) { + if (str1[0] == L'\0') { + break; + } + if (str1[0] != str2[0]) { + break; + } + + str1 += 1; + str2 += 1; + } + + return ((unsigned short*)str1)[0] - ((unsigned short*)str2)[0]; +} + MA_API MA_NO_INLINE int ma_strappend(char* dst, size_t dstSize, const char* srcA, const char* srcB) { int result; @@ -12736,6 +12810,22 @@ MA_API MA_NO_INLINE int ma_strappend(char* dst, size_t dstSize, const char* srcA return result; } +MA_API MA_NO_INLINE size_t ma_wcslen(const wchar_t* str) +{ + const wchar_t* end; + + if (str == NULL) { + return 0; + } + + end = str; + while (end[0] != '\0') { + end += 1; + } + + return end - str; +} + MA_API MA_NO_INLINE char* ma_copy_string(const char* src, const ma_allocation_callbacks* pAllocationCallbacks) { size_t sz; @@ -12758,7 +12848,7 @@ MA_API MA_NO_INLINE char* ma_copy_string(const char* src, const ma_allocation_ca MA_API MA_NO_INLINE wchar_t* ma_copy_string_w(const wchar_t* src, const ma_allocation_callbacks* pAllocationCallbacks) { - size_t sz = wcslen(src)+1; + size_t sz = ma_wcslen(src)+1; wchar_t* dst = (wchar_t*)ma_malloc(sz * sizeof(*dst), pAllocationCallbacks); if (dst == NULL) { return NULL; @@ -13189,7 +13279,7 @@ MA_API ma_result ma_fopen(FILE** ppFile, const char* pFilePath, const char* pOpe return MA_INVALID_ARGS; } -#if defined(_MSC_VER) && _MSC_VER >= 1400 +#if (defined(_MSC_VER) && _MSC_VER >= 1400) && !defined(MA_XBOX_NXDK) err = fopen_s(ppFile, pFilePath, pOpenMode); if (err != 0) { return ma_result_from_errno(err); @@ -13231,7 +13321,7 @@ _wfopen() isn't always available in all compilation environments. This can be reviewed as compatibility issues arise. The preference is to use _wfopen_s() and _wfopen() as opposed to the wcsrtombs() fallback, so if you notice your compiler not detecting this properly I'm happy to look at adding support. */ -#if defined(_WIN32) +#if defined(_WIN32) && !defined(MA_XBOX_NXDK) #if defined(_MSC_VER) || defined(__MINGW64__) || (!defined(__STRICT_ANSI__) && !defined(_NO_EXT_KEYS)) #define MA_HAS_WFOPEN #endif @@ -13247,29 +13337,34 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ return MA_INVALID_ARGS; } -#if defined(MA_HAS_WFOPEN) + #if defined(MA_HAS_WFOPEN) { /* Use _wfopen() on Windows. */ - #if defined(_MSC_VER) && _MSC_VER >= 1400 - errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); - if (err != 0) { - return ma_result_from_errno(err); + #if defined(_MSC_VER) && _MSC_VER >= 1400 + { + errno_t err = _wfopen_s(ppFile, pFilePath, pOpenMode); + if (err != 0) { + return ma_result_from_errno(err); + } } - #else - *ppFile = _wfopen(pFilePath, pOpenMode); - if (*ppFile == NULL) { - return ma_result_from_errno(errno); + #else + { + *ppFile = _wfopen(pFilePath, pOpenMode); + if (*ppFile == NULL) { + return ma_result_from_errno(errno); + } } - #endif + #endif + (void)pAllocationCallbacks; } -#else - /* - Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can - think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for - maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. - */ + #elif !defined(MA_XBOX_NXDK) && !defined(MA_DOS) /* If your compiler does not support wcsrtombs(), add it here. */ { + /* + Use fopen() on anything other than Windows. Requires a conversion. This is annoying because fopen() is locale specific. The only real way I can + think of to do this is with wcsrtombs(). Note that wcstombs() is apparently not thread-safe because it uses a static global mbstate_t object for + maintaining state. I've checked this with -std=c89 and it works, but if somebody get's a compiler error I'll look into improving compatibility. + */ mbstate_t mbs; size_t lenMB; const wchar_t* pFilePathTemp = pFilePath; @@ -13310,11 +13405,16 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ ma_free(pFilePathMB, pAllocationCallbacks); } + #else + { + /* Getting here means there is no way to open the file with a wide character string. */ + *ppFile = NULL; + } + #endif if (*ppFile == NULL) { return MA_ERROR; } -#endif return MA_SUCCESS; } @@ -13323,7 +13423,7 @@ MA_API ma_result ma_wfopen(FILE** ppFile, const wchar_t* pFilePath, const wchar_ static MA_INLINE void ma_copy_memory_64(void* dst, const void* src, ma_uint64 sizeInBytes) { -#if 0xFFFFFFFFFFFFFFFF <= MA_SIZE_MAX +#if MA_SIZE_MAX > 0xFFFFFFFF MA_COPY_MEMORY(dst, src, (size_t)sizeInBytes); #else while (sizeInBytes > 0) { @@ -13343,7 +13443,7 @@ static MA_INLINE void ma_copy_memory_64(void* dst, const void* src, ma_uint64 si static MA_INLINE void ma_zero_memory_64(void* dst, ma_uint64 sizeInBytes) { -#if 0xFFFFFFFFFFFFFFFF <= MA_SIZE_MAX +#if MA_SIZE_MAX > 0xFFFFFFFF MA_ZERO_MEMORY(dst, (size_t)sizeInBytes); #else while (sizeInBytes > 0) { @@ -13472,6 +13572,18 @@ static ma_result ma_allocation_callbacks_init_copy(ma_allocation_callbacks* pDst Logging **************************************************************************************************************************************************************/ +#ifndef ma_va_copy + #if !defined(_MSC_VER) || _MSC_VER >= 1800 + #if (defined(__GNUC__) && __GNUC__ < 3) + #define ma_va_copy(dst, src) ((dst) = (src)) /* This is untested. Not sure if this is correct for old GCC. */ + #else + #define ma_va_copy(dst, src) va_copy((dst), (src)) + #endif + #else + #define ma_va_copy(dst, src) ((dst) = (src)) + #endif +#endif + MA_API const char* ma_log_level_to_string(ma_uint32 logLevel) { switch (logLevel) @@ -13712,9 +13824,15 @@ MA_API ma_result ma_log_postv(ma_log* pLog, ma_uint32 level, const char* pFormat int length; char pFormattedMessageStack[1024]; char* pFormattedMessageHeap = NULL; + va_list args2; /* First try formatting into our fixed sized stack allocated buffer. If this is too small we'll fallback to a heap allocation. */ - length = vsnprintf(pFormattedMessageStack, sizeof(pFormattedMessageStack), pFormat, args); + ma_va_copy(args2, args); + { + length = vsnprintf(pFormattedMessageStack, sizeof(pFormattedMessageStack), pFormat, args2); + } + va_end(args2); + if (length < 0) { return MA_INVALID_OPERATION; /* An error occurred when trying to convert the buffer. */ } @@ -13755,17 +13873,10 @@ MA_API ma_result ma_log_postv(ma_log* pLog, ma_uint32 level, const char* pFormat char* pFormattedMessage = NULL; va_list args2; - #if _MSC_VER >= 1800 - { - va_copy(args2, args); - } - #else + ma_va_copy(args2, args); { - args2 = args; + formattedLen = ma_vscprintf(&pLog->allocationCallbacks, pFormat, args2); } - #endif - - formattedLen = ma_vscprintf(&pLog->allocationCallbacks, pFormat, args2); va_end(args2); if (formattedLen <= 0) { @@ -13964,7 +14075,7 @@ miniaudio's purposes. #define MA_LCG_A 48271 #define MA_LCG_C 0 -static ma_lcg g_maLCG = {MA_DEFAULT_LCG_SEED}; /* Non-zero initial seed. Use ma_seed() to use an explicit seed. */ +static ma_lcg g_maLCG = {MA_DEFAULT_LCG_SEED}; /* Non-zero initial seed. Use ma_lcg_seed() to use an explicit seed. */ static MA_INLINE void ma_lcg_seed(ma_lcg* pLCG, ma_int32 seed) { @@ -14013,7 +14124,7 @@ static MA_INLINE ma_int32 ma_lcg_rand_range_s32(ma_lcg* pLCG, ma_int32 lo, ma_in } - +#if 0 /* Currently unused. */ static MA_INLINE void ma_seed(ma_int32 seed) { ma_lcg_seed(&g_maLCG, seed); @@ -14038,6 +14149,7 @@ static MA_INLINE float ma_rand_f32(void) { return ma_lcg_rand_f32(&g_maLCG); } +#endif static MA_INLINE float ma_rand_range_f32(float lo, float hi) { @@ -14097,6 +14209,7 @@ Atomics **************************************************************************************************************************************************************/ /* c89atomic.h begin */ #ifndef ma_atomic_h +#define ma_atomic_h #if defined(__cplusplus) extern "C" { #endif @@ -14108,11 +14221,63 @@ extern "C" { #endif #endif typedef int ma_atomic_memory_order; -#define MA_ATOMIC_HAS_8 -#define MA_ATOMIC_HAS_16 -#define MA_ATOMIC_HAS_32 -#define MA_ATOMIC_HAS_64 -#if (defined(_MSC_VER) ) || defined(__WATCOMC__) || defined(__DMC__) +#if !defined(MA_ATOMIC_MODERN_MSVC) && \ + !defined(MA_ATOMIC_LEGACY_MSVC) && \ + !defined(MA_ATOMIC_LEGACY_MSVC_ASM) && \ + !defined(MA_ATOMIC_MODERN_GCC) && \ + !defined(MA_ATOMIC_LEGACY_GCC) && \ + !defined(MA_ATOMIC_LEGACY_GCC_ASM) + #if defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__) || defined(__BORLANDC__) + #if (defined(_MSC_VER) && _MSC_VER > 1600) + #define MA_ATOMIC_MODERN_MSVC + #else + #if defined(MA_X64) + #define MA_ATOMIC_LEGACY_MSVC + #else + #define MA_ATOMIC_LEGACY_MSVC_ASM + #endif + #endif + #elif (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 7))) || defined(__clang__) + #define MA_ATOMIC_MODERN_GCC + #else + #if defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 1)) + #define MA_ATOMIC_LEGACY_GCC + #else + #define MA_ATOMIC_LEGACY_GCC_ASM + #endif + #endif +#endif +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) + #include <intrin.h> + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + #define MA_ATOMIC_MSVC_ARM_INTRINSIC_NORETURN(dst, src, order, intrin, ma_atomicType, msvcType) \ + switch (order) \ + { \ + case ma_atomic_memory_order_relaxed: \ + { \ + intrin##_nf((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_consume: \ + case ma_atomic_memory_order_acquire: \ + { \ + intrin##_acq((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_release: \ + { \ + intrin##_rel((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + case ma_atomic_memory_order_acq_rel: \ + case ma_atomic_memory_order_seq_cst: \ + default: \ + { \ + intrin((volatile msvcType*)dst, (msvcType)src); \ + } break; \ + } #define MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, intrin, ma_atomicType, msvcType) \ ma_atomicType result; \ switch (order) \ @@ -14138,720 +14303,1501 @@ typedef int ma_atomic_memory_order; } break; \ } \ return result; - #define MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, expected, desired, order, intrin, ma_atomicType, msvcType) \ + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, 1, order, _InterlockedExchange, ma_atomic_flag, long); + } + #else + { + (void)order; + return (ma_atomic_flag)_InterlockedExchange((volatile long*)dst, (long)1); + } + #endif + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_NORETURN(dst, 0, order, _InterlockedExchange, ma_atomic_flag, long); + } + #else + { + (void)order; + _InterlockedExchange((volatile long*)dst, (long)0); + } + #endif + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + (void)order; + return (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, 0, 0); + } +#endif +#if defined(MA_ATOMIC_LEGACY_MSVC_ASM) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result = 0; + (void)order; + __asm { + mov ecx, dst + mov eax, 1 + xchg [ecx], eax + mov result, eax + } + return result; + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov dword ptr [esi], 0 + } + } else { + __asm { + mov esi, dst + mov eax, 0 + xchg [esi], eax + } + } + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, [esi] + mov result, eax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov eax, [esi] + lock add dword ptr [esp], 0 + mov result, eax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov eax, [esi] + mov result, eax + lock add dword ptr [esp], 0 + } + } + return result; + } +#endif +#if defined(MA_ATOMIC_MODERN_GCC) + #define ma_atomic_memory_order_relaxed __ATOMIC_RELAXED + #define ma_atomic_memory_order_consume __ATOMIC_CONSUME + #define ma_atomic_memory_order_acquire __ATOMIC_ACQUIRE + #define ma_atomic_memory_order_release __ATOMIC_RELEASE + #define ma_atomic_memory_order_acq_rel __ATOMIC_ACQ_REL + #define ma_atomic_memory_order_seq_cst __ATOMIC_SEQ_CST + typedef ma_uint32 ma_atomic_flag; + #define ma_atomic_flag_test_and_set_explicit(dst, order) __atomic_exchange_n(dst, 1, order) + #define ma_atomic_flag_clear_explicit(dst, order) __atomic_store_n(dst, 0, order) + #define ma_atomic_flag_load_explicit(dst, order) __atomic_load_n(dst, order) +#endif +#if defined(MA_ATOMIC_LEGACY_GCC) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, 1); + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + if (order > ma_atomic_memory_order_release) { + __sync_synchronize(); + } + __sync_lock_release(dst); + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + (void)order; + return __sync_val_compare_and_swap((ma_atomic_flag*)dst, 0, 0); + } +#endif +#if defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define ma_atomic_memory_order_relaxed 1 + #define ma_atomic_memory_order_consume 2 + #define ma_atomic_memory_order_acquire 3 + #define ma_atomic_memory_order_release 4 + #define ma_atomic_memory_order_acq_rel 5 + #define ma_atomic_memory_order_seq_cst 6 + #if defined(MA_X86) + #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addl $0, (%%esp)" ::: "memory") + #elif defined(MA_X64) + #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addq $0, (%%rsp)" ::: "memory") + #else + #error Unsupported architecture. + #endif + #define MA_ATOMIC_XCHG_GCC_X86(instructionSizeSuffix, result, dst, src) \ + __asm__ __volatile__( \ + "xchg"instructionSizeSuffix" %0, %1" \ + : "=r"(result), \ + "=m"(*dst) \ + : "0"(src), \ + "m"(*dst) \ + : "memory" \ + ) + #define MA_ATOMIC_LOAD_RELAXED_GCC_X86(instructionSizeSuffix, result, dst) \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + ) + #define MA_ATOMIC_LOAD_RELEASE_GCC_X86(instructionSizeSuffix, result, dst) \ + ma_atomic_thread_fence(ma_atomic_memory_order_release); \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + : "memory" \ + ) + #define MA_ATOMIC_LOAD_SEQ_CST_GCC_X86(instructionSizeSuffix, result, dst) \ + ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst); \ + __asm__ __volatile__( \ + "mov"instructionSizeSuffix" %1, %0" \ + : "=r"(result) \ + : "m"(*dst) \ + : "memory" \ + ); \ + ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst) + typedef ma_uint32 ma_atomic_flag; + static MA_INLINE ma_atomic_flag ma_atomic_flag_test_and_set_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + ma_atomic_flag result; + #if defined(MA_X86) || defined(MA_X64) + { + (void)order; + MA_ATOMIC_XCHG_GCC_X86("l", result, dst, 1); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + static MA_INLINE void ma_atomic_flag_clear_explicit(volatile ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__( + "movl $0, %0" + : "=m"(*dst) + ); + } else if (order == ma_atomic_memory_order_release) { + __asm__ __volatile__( + "movl $0, %0" + : "=m"(*dst) + : + : "memory" + ); + } else { + ma_atomic_flag tmp = 0; + __asm__ __volatile__( + "xchgl %0, %1" + : "=r"(tmp), + "=m"(*dst) + : "0"(tmp), + "m"(*dst) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + static MA_INLINE ma_atomic_flag ma_atomic_flag_load_explicit(volatile const ma_atomic_flag* dst, ma_atomic_memory_order order) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_atomic_flag result; + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("l", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("l", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("l", result, dst); + } + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } +#endif +#define ma_atomic_flag_test_and_set(dst) ma_atomic_flag_test_and_set_explicit(dst, ma_atomic_memory_order_acquire) +#define ma_atomic_flag_clear(dst) ma_atomic_flag_clear_explicit(dst, ma_atomic_memory_order_release) +typedef ma_atomic_flag ma_atomic_spinlock; +static MA_INLINE void ma_atomic_spinlock_lock(volatile ma_atomic_spinlock* pSpinlock) +{ + for (;;) { + if (ma_atomic_flag_test_and_set_explicit(pSpinlock, ma_atomic_memory_order_acquire) == 0) { + break; + } + while (ma_atomic_flag_load_explicit(pSpinlock, ma_atomic_memory_order_relaxed) == 1) { + } + } +} +static MA_INLINE void ma_atomic_spinlock_unlock(volatile ma_atomic_spinlock* pSpinlock) +{ + ma_atomic_flag_clear_explicit(pSpinlock, ma_atomic_memory_order_release); +} +ma_atomic_spinlock ma_atomic_global_lock; +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC_ASM) || defined(MA_ATOMIC_LEGACY_GCC) || defined(MA_ATOMIC_LEGACY_GCC_ASM) + #if defined(MA_X64) || (defined(MA_X86) && ((defined(__GNUC__) && defined(__i486__)) || (defined(_M_IX86) && _M_IX86 >= 400))) + #if defined(MA_ATOMIC_LEGACY_MSVC) && defined(MA_X64) + #else + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #endif + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_X64) || (defined(MA_X86) && ((defined(__GNUC__) && defined(__i586__)) || (defined(_M_IX86) && _M_IX86 >= 500))) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #else + #endif + #else + #endif + #if defined(MA_ARM32) || defined(MA_ARM64) + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_ARM64) || defined(__ARM_ARCH_7A__) || defined(__ARM_ARCH_7R__) || defined(__ARM_ARCH_6K__) || defined(__ARM_ARCH_6Z__) || defined(__ARM_ARCH_6ZK__) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #endif + #endif + #if defined(MA_ATOMIC_PPC32) || defined(MA_ATOMIC_PPC64) + #if (defined(__GNUC__) && (__GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 7))) && !defined(__clang__) + #else + #define MA_ATOMIC_IS_LOCK_FREE_8 1 + #define MA_ATOMIC_IS_LOCK_FREE_16 1 + #endif + #define MA_ATOMIC_IS_LOCK_FREE_32 1 + #if defined(MA_ATOMIC_PPC64) + #define MA_ATOMIC_IS_LOCK_FREE_64 1 + #endif + #endif + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_8(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_16(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_32(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + return 1; + #else + return 0; + #endif + } + static MA_INLINE ma_bool32 ma_atomic_is_lock_free_64(volatile void* ptr) + { + (void)ptr; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + return 1; + #else + return 0; + #endif + } +#endif +#define MA_ATOMIC_COMPARE_AND_SWAP_LOCK(sizeInBits, dst, expected, replacement) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + if (result == expected) { \ + *dst = replacement; \ + } \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_LOAD_EXPLICIT_LOCK(sizeInBits, ptr, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *ptr; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_STORE_EXPLICIT_LOCK(sizeInBits, dst, src, order) \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + *dst = src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock) +#define MA_ATOMIC_STORE_EXPLICIT_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, src) != oldValue); \ + (void)order +#define MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + *dst = src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, src) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_ADD_LOCK(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits result; \ + ma_atomic_spinlock_lock(&ma_atomic_global_lock); \ + { \ + result = *dst; \ + *dst += src; \ + (void)order; \ + } \ + ma_atomic_spinlock_unlock(&ma_atomic_global_lock); \ + return result +#define MA_ATOMIC_FETCH_ADD_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = oldValue + src; \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_AND_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue & src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_OR_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue | src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#define MA_ATOMIC_FETCH_XOR_CAS(sizeInBits, dst, src, order) \ + ma_uint##sizeInBits oldValue; \ + ma_uint##sizeInBits newValue; \ + do { \ + oldValue = ma_atomic_load_explicit_##sizeInBits(dst, ma_atomic_memory_order_relaxed); \ + newValue = (ma_uint##sizeInBits)(oldValue ^ src); \ + } while (ma_atomic_compare_and_swap_##sizeInBits(dst, oldValue, newValue) != oldValue); \ + (void)order; \ + return oldValue +#if defined(MA_ATOMIC_MODERN_MSVC) || defined(MA_ATOMIC_LEGACY_MSVC) + #define MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, expected, replacement, order, intrin, ma_atomicType, msvcType) \ ma_atomicType result; \ switch (order) \ { \ case ma_atomic_memory_order_relaxed: \ { \ - result = (ma_atomicType)intrin##_nf((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_nf((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_consume: \ case ma_atomic_memory_order_acquire: \ { \ - result = (ma_atomicType)intrin##_acq((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_acq((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_release: \ { \ - result = (ma_atomicType)intrin##_rel((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin##_rel((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ case ma_atomic_memory_order_acq_rel: \ case ma_atomic_memory_order_seq_cst: \ default: \ { \ - result = (ma_atomicType)intrin((volatile msvcType*)ptr, (msvcType)expected, (msvcType)desired); \ + result = (ma_atomicType)intrin((volatile msvcType*)ptr, (msvcType)expected, (msvcType)replacement); \ } break; \ } \ return result; - #define ma_atomic_memory_order_relaxed 0 - #define ma_atomic_memory_order_consume 1 - #define ma_atomic_memory_order_acquire 2 - #define ma_atomic_memory_order_release 3 - #define ma_atomic_memory_order_acq_rel 4 - #define ma_atomic_memory_order_seq_cst 5 - #if _MSC_VER < 1600 && defined(MA_X86) - #define MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + #define ma_atomic_compare_and_swap_8( dst, expected, replacement) (ma_uint8 )_InterlockedCompareExchange8((volatile char*)dst, (char)replacement, (char)expected) + #else + static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + #define ma_atomic_compare_and_swap_16(dst, expected, replacement) (ma_uint16)_InterlockedCompareExchange16((volatile short*)dst, (short)replacement, (short)expected) + #else + static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } #endif - #if _MSC_VER < 1600 - #undef MA_ATOMIC_HAS_8 - #undef MA_ATOMIC_HAS_16 + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + #define ma_atomic_compare_and_swap_32(dst, expected, replacement) (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, (long)replacement, (long)expected) + #else + static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } #endif - #if !defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #include <intrin.h> + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + #define ma_atomic_compare_and_swap_64(dst, expected, replacement) (ma_uint64)_InterlockedCompareExchange64((volatile ma_int64*)dst, (ma_int64)replacement, (ma_int64)expected) + #else + static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + #if defined(MA_ARM) { - ma_uint8 result = 0; - __asm { - mov ecx, dst - mov al, expected - mov dl, desired - lock cmpxchg [ecx], dl - mov result, al - } - return result; + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange8, ma_uint8, char); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) + #else { - ma_uint16 result = 0; - __asm { - mov ecx, dst - mov ax, expected - mov dx, desired - lock cmpxchg [ecx], dx - mov result, ax - } - return result; + (void)order; + return ma_atomic_compare_and_swap_8((volatile ma_uint8*)ptr, 0, 0); } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, ptr, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + #if defined(MA_ARM) { - ma_uint32 result = 0; - __asm { - mov ecx, dst - mov eax, expected - mov edx, desired - lock cmpxchg [ecx], edx - mov result, eax - } - return result; + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange16, ma_uint16, short); } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) + #else { - ma_uint32 resultEAX = 0; - ma_uint32 resultEDX = 0; - __asm { - mov esi, dst - mov eax, dword ptr expected - mov edx, dword ptr expected + 4 - mov ebx, dword ptr desired - mov ecx, dword ptr desired + 4 - lock cmpxchg8b qword ptr [esi] - mov resultEAX, eax - mov resultEDX, edx - } - return ((ma_uint64)resultEDX << 32) | resultEAX; + (void)order; + return ma_atomic_compare_and_swap_16((volatile ma_uint16*)ptr, 0, 0); } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, ptr, order); + } #endif - #else - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_compare_and_swap_8( dst, expected, desired) (ma_uint8 )_InterlockedCompareExchange8((volatile char*)dst, (char)desired, (char)expected) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_compare_and_swap_16(dst, expected, desired) (ma_uint16)_InterlockedCompareExchange16((volatile short*)dst, (short)desired, (short)expected) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_compare_and_swap_32(dst, expected, desired) (ma_uint32)_InterlockedCompareExchange((volatile long*)dst, (long)desired, (long)expected) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_compare_and_swap_64(dst, expected, desired) (ma_uint64)_InterlockedCompareExchange64((volatile ma_int64*)dst, (ma_int64)desired, (ma_int64)expected) - #endif - #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + #if defined(MA_ARM) { - ma_uint8 result = 0; - (void)order; - __asm { - mov ecx, dst - mov al, src - lock xchg [ecx], al - mov result, al - } - return result; + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange, ma_uint32, long); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #else { - ma_uint16 result = 0; (void)order; - __asm { - mov ecx, dst - mov ax, src - lock xchg [ecx], ax - mov result, ax - } - return result; + return ma_atomic_compare_and_swap_32((volatile ma_uint32*)ptr, 0, 0); } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, ptr, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange64, ma_uint64, long long); + } + #else { - ma_uint32 result = 0; (void)order; - __asm { - mov ecx, dst - mov eax, src - lock xchg [ecx], eax - mov result, eax - } - return result; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)ptr, 0, 0); } + #endif + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, ptr, order); + } #endif - #else - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange8, ma_uint8, char); + } #else + { (void)order; return (ma_uint8)_InterlockedExchange8((volatile char*)dst, (char)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange16, ma_uint16, short); + } #else + { (void)order; return (ma_uint16)_InterlockedExchange16((volatile short*)dst, (short)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange, ma_uint32, long); + } #else + { (void)order; return (ma_uint32)_InterlockedExchange((volatile long*)dst, (long)src); - #endif } - #endif - #if defined(MA_ATOMIC_HAS_64) && defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange64, ma_uint64, long long); - #else - (void)order; - return (ma_uint64)_InterlockedExchange64((volatile long long*)dst, (long long)src); #endif - } + } #else - #endif - #endif - #if defined(MA_ATOMIC_HAS_64) && !defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - do { - oldValue = *dst; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); } - #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + #if defined(MA_32BIT) { - ma_uint8 result = 0; - (void)order; - __asm { - mov ecx, dst - mov al, src - lock xadd [ecx], al - mov result, al - } - return result; + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #else { - ma_uint16 result = 0; - (void)order; - __asm { - mov ecx, dst - mov ax, src - lock xadd [ecx], ax - mov result, ax + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchange64, ma_uint64, long long); } - return result; + #else + { + (void)order; + return (ma_uint64)_InterlockedExchange64((volatile long long*)dst, (long long)src); + } + #endif } + #endif + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + #if defined(MA_ARM) { - ma_uint32 result = 0; - (void)order; - __asm { - mov ecx, dst - mov eax, src - lock xadd [ecx], eax - mov result, eax - } - return result; - } - #endif - #else - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - #if defined(MA_ARM) MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd8, ma_uint8, char); + } #else + { (void)order; return (ma_uint8)_InterlockedExchangeAdd8((volatile char*)dst, (char)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd16, ma_uint16, short); + } #else + { (void)order; return (ma_uint16)_InterlockedExchangeAdd16((volatile short*)dst, (short)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { #if defined(MA_ARM) + { MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd, ma_uint32, long); + } #else + { (void)order; return (ma_uint32)_InterlockedExchangeAdd((volatile long*)dst, (long)src); - #endif } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } #endif - #if defined(MA_ATOMIC_HAS_64) && defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + #if defined(MA_32BIT) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd64, ma_uint64, long long); + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); + } #else - (void)order; - return (ma_uint64)_InterlockedExchangeAdd64((volatile long long*)dst, (long long)src); - #endif + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedExchangeAdd64, ma_uint64, long long); + } + #else + { + (void)order; + return (ma_uint64)_InterlockedExchangeAdd64((volatile long long*)dst, (long long)src); + } + #endif } + #endif + } #else - #endif - #endif - #if defined(MA_ATOMIC_HAS_64) && !defined(MA_64BIT) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue + src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_MSVC_USE_INLINED_ASSEMBLY) - static MA_INLINE void __stdcall ma_atomic_thread_fence(ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_8(dst, (ma_uint8)(-(ma_int8)src), order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_16(dst, (ma_uint16)(-(ma_int16)src), order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_32(dst, (ma_uint32)(-(ma_int32)src), order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + return ma_atomic_fetch_add_explicit_64(dst, (ma_uint64)(-(ma_int64)src), order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) { - (void)order; - __asm { - lock add [esp], 0 - } + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd8, ma_uint8, char); } - #else - #if defined(MA_X64) - #define ma_atomic_thread_fence(order) __faststorefence(), (void)order - #elif defined(MA_ARM64) - #define ma_atomic_thread_fence(order) __dmb(_ARM64_BARRIER_ISH), (void)order #else - static MA_INLINE void ma_atomic_thread_fence(ma_atomic_memory_order order) - { - volatile ma_uint32 barrier = 0; - ma_atomic_fetch_add_explicit_32(&barrier, 0, order); - } - #endif - #endif - #define ma_atomic_compiler_fence() ma_atomic_thread_fence(ma_atomic_memory_order_seq_cst) - #define ma_atomic_signal_fence(order) ma_atomic_thread_fence(order) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange8, ma_uint8, char); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd16, ma_uint16, short); + } #else - (void)order; - return ma_atomic_compare_and_swap_8((volatile ma_uint8*)ptr, 0, 0); + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd, ma_uint32, long); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) + #else { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange16, ma_uint16, short); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd64, ma_uint64, long long); + } #else - (void)order; - return ma_atomic_compare_and_swap_16((volatile ma_uint16*)ptr, 0, 0); + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr8, ma_uint8, char); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) + #else { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange, ma_uint32, long); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr16, ma_uint16, short); + } #else - (void)order; - return ma_atomic_compare_and_swap_32((volatile ma_uint32*)ptr, 0, 0); + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr, ma_uint32, long); } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) + #else { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC_COMPARE_EXCHANGE(ptr, 0, 0, order, _InterlockedCompareExchange64, ma_uint64, long long); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr64, ma_uint64, long long); + } #else - (void)order; - return ma_atomic_compare_and_swap_64((volatile ma_uint64*)ptr, 0, 0); + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor8, ma_uint8, char); } - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) - { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue - src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #else + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) - { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue - src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor16, ma_uint16, short); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) - { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #else + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) - { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ARM) + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor, ma_uint32, long); } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #else { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd8, ma_uint8, char); + { + MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor64, ma_uint64, long long); + } #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue & src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } #endif + } + #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) + #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) + #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) + #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) + #if defined(MA_X64) + #define ma_atomic_thread_fence(order) __faststorefence(), (void)order + #elif defined(MA_ARM64) + #define ma_atomic_thread_fence(order) __dmb(_ARM64_BARRIER_ISH), (void)order + #else + static MA_INLINE void ma_atomic_thread_fence(ma_atomic_memory_order order) + { + volatile ma_uint32 barrier = 0; + ma_atomic_fetch_add_explicit_32(&barrier, 0, order); } #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #define ma_atomic_signal_fence(order) _ReadWriteBarrier(), (void)order +#endif +#if defined(MA_ATOMIC_LEGACY_MSVC_ASM) + static MA_INLINE ma_uint8 __stdcall ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd16, ma_uint16, short); + ma_uint8 result = 0; + __asm { + mov ecx, dst + mov al, expected + mov dl, replacement + lock cmpxchg [ecx], dl + mov result, al + } + return result; + } #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue & src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd, ma_uint32, long); + ma_uint16 result = 0; + __asm { + mov ecx, dst + mov ax, expected + mov dx, replacement + lock cmpxchg [ecx], dx + mov result, ax + } + return result; + } #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedAnd64, ma_uint64, long long); + ma_uint32 result = 0; + __asm { + mov ecx, dst + mov eax, expected + mov edx, replacement + lock cmpxchg [ecx], edx + mov result, eax + } + return result; + } #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor8, ma_uint8, char); + ma_uint32 resultEAX = 0; + ma_uint32 resultEDX = 0; + __asm { + mov esi, dst + mov eax, dword ptr expected + mov edx, dword ptr expected + 4 + mov ebx, dword ptr replacement + mov ecx, dword ptr replacement + 4 + lock cmpxchg8b qword ptr [esi] + mov resultEAX, eax + mov resultEDX, edx + } + return ((ma_uint64)resultEDX << 32) | resultEAX; + } #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; - #endif + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #endif + } + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor16, ma_uint16, short); + ma_uint8 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov al, [esi] + mov result, al + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov al, [esi] + lock add dword ptr [esp], 0 + mov result, al + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov al, [esi] + mov result, al + lock add dword ptr [esp], 0 + } + } + return result; + } #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, dst, order); + } #endif + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov ax, [esi] + mov result, ax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov ax, [esi] + lock add dword ptr [esp], 0 + mov result, ax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov ax, [esi] + mov result, ax + lock add dword ptr [esp], 0 + } + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor, ma_uint32, long); + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, dst, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, [esi] + mov result, eax + } + } else if (order <= ma_atomic_memory_order_release) { + __asm { + mov esi, dst + mov eax, [esi] + lock add dword ptr [esp], 0 + mov result, eax + } + } else { + __asm { + lock add dword ptr [esp], 0 + mov esi, dst + mov eax, [esi] + mov result, eax + lock add dword ptr [esp], 0 + } + } + return result; + } #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, dst, order); + } #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* dst, ma_atomic_memory_order order) + { + (void)order; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, 0, 0); + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov al, src + mov [esi], al + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + __asm { + mov esi, dst + mov al, src + xchg [esi], al + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov ax, src + mov [esi], ax + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + __asm { + mov esi, dst + mov ax, src + xchg [esi], ax + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm { + mov esi, dst + mov eax, src + mov [esi], eax + } + } else { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + __asm { + mov esi, dst + mov eax, src + xchg [esi], eax + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif + } + } + static MA_INLINE void __stdcall ma_atomic_store_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedXor64, ma_uint64, long long); + MA_ATOMIC_STORE_EXPLICIT_CAS(64, dst, src, order); + } #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(64, dst, src, order); + } #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; + (void)order; + __asm { + mov ecx, dst + mov al, src + lock xchg [ecx], al + mov result, al + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr8, ma_uint8, char); + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + (void)order; + __asm { + mov ecx, dst + mov ax, src + lock xchg [ecx], ax + mov result, ax + } + return result; + } #else - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue | src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; (void)order; - return oldValue; + __asm { + mov ecx, dst + mov eax, src + xchg [ecx], eax + mov result, eax + } + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr16, ma_uint16, short); + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; + (void)order; + __asm { + mov ecx, dst + mov al, src + lock xadd [ecx], al + mov result, al + } + return result; + } #else - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue | src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; (void)order; - return oldValue; + __asm { + mov ecx, dst + mov ax, src + lock xadd [ecx], ax + mov result, ax + } + return result; + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; + (void)order; + __asm { + mov ecx, dst + mov eax, src + lock xadd [ecx], eax + mov result, eax + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr, ma_uint32, long); + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); + } #else - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } #endif + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + ma_uint8 result = 0; + (void)order; + __asm { + mov ecx, dst + mov al, src + neg al + lock xadd [ecx], al + mov result, al + } + return result; } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + #else { - #if defined(MA_ARM) - MA_ATOMIC_MSVC_ARM_INTRINSIC(dst, src, order, _InterlockedOr64, ma_uint64, long long); + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, (ma_uint8)(-(ma_int8)src), order); + } + #endif + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + ma_uint16 result = 0; + (void)order; + __asm { + mov ecx, dst + mov ax, src + neg ax + lock xadd [ecx], ax + mov result, ax + } + return result; + } #else - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, (ma_uint16)(-(ma_int16)src), order); + } + #endif + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + ma_uint32 result = 0; (void)order; - return oldValue; + __asm { + mov ecx, dst + mov eax, src + neg eax + lock xadd [ecx], eax + mov result, eax + } + return result; + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, (ma_uint32)(-(ma_int32)src), order); + } #endif + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, (ma_uint64)(-(ma_int64)src), order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + static MA_INLINE ma_uint8 __stdcall ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + static MA_INLINE ma_uint16 __stdcall ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + static MA_INLINE ma_uint32 __stdcall ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + static MA_INLINE ma_uint64 __stdcall ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + static MA_INLINE void __stdcall ma_atomic_thread_fence(ma_atomic_memory_order order) + { + (void)order; + __asm { + lock add dword ptr [esp], 0 } - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_test_and_set_explicit_8( dst, order) ma_atomic_exchange_explicit_8 (dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_test_and_set_explicit_16(dst, order) ma_atomic_exchange_explicit_16(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_test_and_set_explicit_32(dst, order) ma_atomic_exchange_explicit_32(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_test_and_set_explicit_64(dst, order) ma_atomic_exchange_explicit_64(dst, 1, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - #define ma_atomic_clear_explicit_8( dst, order) ma_atomic_store_explicit_8 (dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_16) - #define ma_atomic_clear_explicit_16(dst, order) ma_atomic_store_explicit_16(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_32) - #define ma_atomic_clear_explicit_32(dst, order) ma_atomic_store_explicit_32(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_64) - #define ma_atomic_clear_explicit_64(dst, order) ma_atomic_store_explicit_64(dst, 0, order) - #endif - #if defined(MA_ATOMIC_HAS_8) - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_8(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_8(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) - #else - typedef ma_uint32 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_32(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_32(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_32(ptr, order) - #endif -#elif defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 7))) + } + #define ma_atomic_signal_fence(order) __asm {}; (void)order +#endif +#if defined(MA_ATOMIC_MODERN_GCC) #define MA_ATOMIC_HAS_NATIVE_COMPARE_EXCHANGE - #define MA_ATOMIC_HAS_NATIVE_IS_LOCK_FREE - #define ma_atomic_memory_order_relaxed __ATOMIC_RELAXED - #define ma_atomic_memory_order_consume __ATOMIC_CONSUME - #define ma_atomic_memory_order_acquire __ATOMIC_ACQUIRE - #define ma_atomic_memory_order_release __ATOMIC_RELEASE - #define ma_atomic_memory_order_acq_rel __ATOMIC_ACQ_REL - #define ma_atomic_memory_order_seq_cst __ATOMIC_SEQ_CST - #define ma_atomic_compiler_fence() __asm__ __volatile__("":::"memory") #define ma_atomic_thread_fence(order) __atomic_thread_fence(order) #define ma_atomic_signal_fence(order) __atomic_signal_fence(order) #define ma_atomic_is_lock_free_8(ptr) __atomic_is_lock_free(1, ptr) #define ma_atomic_is_lock_free_16(ptr) __atomic_is_lock_free(2, ptr) #define ma_atomic_is_lock_free_32(ptr) __atomic_is_lock_free(4, ptr) #define ma_atomic_is_lock_free_64(ptr) __atomic_is_lock_free(8, ptr) - #define ma_atomic_test_and_set_explicit_8( dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_16(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_32(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_test_and_set_explicit_64(dst, order) __atomic_exchange_n(dst, 1, order) - #define ma_atomic_clear_explicit_8( dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_16(dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_32(dst, order) __atomic_store_n(dst, 0, order) - #define ma_atomic_clear_explicit_64(dst, order) __atomic_store_n(dst, 0, order) #define ma_atomic_store_explicit_8( dst, src, order) __atomic_store_n(dst, src, order) #define ma_atomic_store_explicit_16(dst, src, order) __atomic_store_n(dst, src, order) #define ma_atomic_store_explicit_32(dst, src, order) __atomic_store_n(dst, src, order) @@ -14864,14 +15810,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_explicit_16(dst, src, order) __atomic_exchange_n(dst, src, order) #define ma_atomic_exchange_explicit_32(dst, src, order) __atomic_exchange_n(dst, src, order) #define ma_atomic_exchange_explicit_64(dst, src, order) __atomic_exchange_n(dst, src, order) - #define ma_atomic_compare_exchange_strong_explicit_8( dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 0, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, desired, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_8( dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 0, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, successOrder, failureOrder) __atomic_compare_exchange_n(dst, expected, replacement, 1, successOrder, failureOrder) #define ma_atomic_fetch_add_explicit_8( dst, src, order) __atomic_fetch_add(dst, src, order) #define ma_atomic_fetch_add_explicit_16(dst, src, order) __atomic_fetch_add(dst, src, order) #define ma_atomic_fetch_add_explicit_32(dst, src, order) __atomic_fetch_add(dst, src, order) @@ -14892,19 +15838,19 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_explicit_16(dst, src, order) __atomic_fetch_and(dst, src, order) #define ma_atomic_fetch_and_explicit_32(dst, src, order) __atomic_fetch_and(dst, src, order) #define ma_atomic_fetch_and_explicit_64(dst, src, order) __atomic_fetch_and(dst, src, order) - static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } - static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } - static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } #if defined(__clang__) @@ -14913,636 +15859,1134 @@ typedef int ma_atomic_memory_order; #pragma clang diagnostic ignored "-Watomic-alignment" #endif #endif - static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) { - __atomic_compare_exchange_n(dst, &expected, desired, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + __atomic_compare_exchange_n(dst, &expected, replacement, 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); return expected; } #if defined(__clang__) #pragma clang diagnostic pop #endif - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(dst, order) (ma_bool32)__atomic_test_and_set(dst, order) - #define ma_atomic_flag_clear_explicit(dst, order) __atomic_clear(dst, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) -#else - #define ma_atomic_memory_order_relaxed 1 - #define ma_atomic_memory_order_consume 2 - #define ma_atomic_memory_order_acquire 3 - #define ma_atomic_memory_order_release 4 - #define ma_atomic_memory_order_acq_rel 5 - #define ma_atomic_memory_order_seq_cst 6 - #define ma_atomic_compiler_fence() __asm__ __volatile__("":::"memory") - #if defined(__GNUC__) +#endif +#if defined(MA_ATOMIC_LEGACY_GCC) || defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define ma_atomic_signal_fence(order) __asm__ __volatile__("":::"memory") + #if defined(MA_ATOMIC_LEGACY_GCC) #define ma_atomic_thread_fence(order) __sync_synchronize(), (void)order + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + return __sync_val_compare_and_swap(dst, expected, replacement); + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return ma_atomic_compare_and_swap_8((ma_uint8*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, ptr, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return ma_atomic_compare_and_swap_16((ma_uint16*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, ptr, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return ma_atomic_compare_and_swap_32((ma_uint32*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, ptr, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return ma_atomic_compare_and_swap_64((ma_uint64*)ptr, 0, 0); + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, ptr, order); + } + #endif + } static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - if (order > ma_atomic_memory_order_acquire) { - __sync_synchronize(); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); } - return __sync_lock_test_and_set(dst, src); + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - do { - oldValue = *dst; - } while (__sync_val_compare_and_swap(dst, oldValue, src) != oldValue); - (void)order; - return oldValue; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + if (order > ma_atomic_memory_order_acquire) { + __sync_synchronize(); + } + return __sync_lock_test_and_set(dst, src); + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif } + #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) + #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) + #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) + #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) static MA_INLINE ma_uint8 ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_add(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_add(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_sub(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, (ma_uint8)(-(ma_int8)src), order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_sub(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, (ma_uint16)(-(ma_int16)src), order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_sub(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, (ma_uint32)(-(ma_int32)src), order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_sub(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_sub(dst, src); + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, (ma_uint64)(-(ma_int64)src), order); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_and(dst, src); + } + #else + { + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_or(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_or(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_or(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_or(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_or(dst, src); + } + #else + { + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_xor(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_xor(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_xor(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - (void)order; - return __sync_fetch_and_xor(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) + { + (void)order; + return __sync_fetch_and_xor(dst, src); + } + #else + { + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); + } + #endif } - static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + #elif defined(MA_ATOMIC_LEGACY_GCC_ASM) + #define MA_ATOMIC_CMPXCHG_GCC_X86(instructionSizeSuffix, result, dst, expected, replacement) \ + __asm__ __volatile__( \ + "lock; cmpxchg"instructionSizeSuffix" %2, %1" \ + : "=a"(result), \ + "=m"(*dst) \ + : "r"(replacement), \ + "0"(expected), \ + "m"(*dst) \ + : "cc", "memory") + #define MA_ATOMIC_XADD_GCC_X86(instructionSizeSuffix, result, dst, src) \ + __asm__ __volatile__( \ + "lock; xadd"instructionSizeSuffix" %0, %1" \ + : "=a"(result), \ + "=m"(*dst) \ + : "0"(src), \ + "m"(*dst) \ + : "cc", "memory") + static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 replacement) { - (void)order; - return __sync_fetch_and_and(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("b", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(8, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 replacement) { - (void)order; - return __sync_fetch_and_and(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("w", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(16, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 replacement) { - (void)order; - return __sync_fetch_and_and(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("l", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(32, dst, expected, replacement); + } + #endif } - static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 replacement) { - (void)order; - return __sync_fetch_and_and(dst, src); + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + #if defined(MA_X86) + { + ma_uint32 resultEAX; + ma_uint32 resultEDX; + __asm__ __volatile__( + "pushl %%ebx\n" + "movl %4, %%ebx\n" + "lock cmpxchg8b (%%edi)\n" + "popl %%ebx\n" + : "=a"(resultEAX), + "=d"(resultEDX) + : "a"((ma_uint32)(expected & 0xFFFFFFFF)), + "d"((ma_uint32)(expected >> 32)), + "r"((ma_uint32)(replacement & 0xFFFFFFFF)), + "c"((ma_uint32)(replacement >> 32)), + "D"(dst) + : "memory", "cc"); + result = ((ma_uint64)resultEDX << 32) | resultEAX; + } + #elif defined(MA_X64) + { + MA_ATOMIC_CMPXCHG_GCC_X86("q", result, dst, expected, replacement); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_COMPARE_AND_SWAP_LOCK(64, dst, expected, replacement); + } + #endif } - #define ma_atomic_compare_and_swap_8( dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_16(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_32(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #define ma_atomic_compare_and_swap_64(dst, expected, desired) __sync_val_compare_and_swap(dst, expected, desired) - #else - #if defined(MA_X86) - #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addl $0, (%%esp)" ::: "memory", "cc") - #elif defined(MA_X64) - #define ma_atomic_thread_fence(order) __asm__ __volatile__("lock; addq $0, (%%rsp)" ::: "memory", "cc") - #else - #error Unsupported architecture. Please submit a feature request. - #endif - static MA_INLINE ma_uint8 ma_atomic_compare_and_swap_8(volatile ma_uint8* dst, ma_uint8 expected, ma_uint8 desired) + static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* dst, ma_atomic_memory_order order) { - ma_uint8 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("b", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("b", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("b", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(8, dst, order); + } + #endif } - static MA_INLINE ma_uint16 ma_atomic_compare_and_swap_16(volatile ma_uint16* dst, ma_uint16 expected, ma_uint16 desired) + static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* dst, ma_atomic_memory_order order) { - ma_uint16 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("w", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("w", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("w", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(16, dst, order); + } + #endif + } + static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("l", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("l", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("l", result, dst); + } + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(32, dst, order); + } + #endif + } + static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* dst, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + #if defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + MA_ATOMIC_LOAD_RELAXED_GCC_X86("q", result, dst); + } else if (order <= ma_atomic_memory_order_release) { + MA_ATOMIC_LOAD_RELEASE_GCC_X86("q", result, dst); + } else { + MA_ATOMIC_LOAD_SEQ_CST_GCC_X86("q", result, dst); + } + } + #elif defined(MA_X86) + { + (void)order; + return ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, 0, 0); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_LOAD_EXPLICIT_LOCK(64, dst, order); + } + #endif + } + static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint8 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("b", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif + } + static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + { + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint16 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("w", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif } - static MA_INLINE ma_uint32 ma_atomic_compare_and_swap_32(volatile ma_uint32* dst, ma_uint32 expected, ma_uint32 desired) + static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 result; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint32 result; + (void)order; + #if defined(MA_X86) || defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("l", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif } - static MA_INLINE ma_uint64 ma_atomic_compare_and_swap_64(volatile ma_uint64* dst, ma_uint64 expected, ma_uint64 desired) + static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - volatile ma_uint64 result; - #if defined(MA_X86) - ma_uint32 resultEAX; - ma_uint32 resultEDX; - __asm__ __volatile__("push %%ebx; xchg %5, %%ebx; lock; cmpxchg8b %0; pop %%ebx" : "+m"(*dst), "=a"(resultEAX), "=d"(resultEDX) : "a"(expected & 0xFFFFFFFF), "d"(expected >> 32), "r"(desired & 0xFFFFFFFF), "c"(desired >> 32) : "cc"); - result = ((ma_uint64)resultEDX << 32) | resultEAX; - #elif defined(MA_X64) - __asm__ __volatile__("lock; cmpxchg %3, %0" : "+m"(*dst), "=a"(result) : "a"(expected), "d"(desired) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + ma_uint64 result; + (void)order; + #if defined(MA_X86) + { + MA_ATOMIC_EXCHANGE_EXPLICIT_CAS(64, dst, src, order); + } + #elif defined(MA_X64) + { + MA_ATOMIC_XCHG_GCC_X86("q", result, dst, src); + } + #else + { + #error Unsupported architecture. + } + #endif + return result; + } + #else + { + MA_ATOMIC_EXCHANGE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif } - static MA_INLINE ma_uint8 ma_atomic_exchange_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + static MA_INLINE void ma_atomic_store_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 result = 0; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movb %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgb %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(8, dst, src, order); + } + #endif } - static MA_INLINE ma_uint16 ma_atomic_exchange_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE void ma_atomic_store_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 result = 0; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movw %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgw %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(16, dst, src, order); + } + #endif } - static MA_INLINE ma_uint32 ma_atomic_exchange_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE void ma_atomic_store_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movl %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgl %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(32, dst, src, order); + } + #endif } - static MA_INLINE ma_uint64 ma_atomic_exchange_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE void ma_atomic_store_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 result; - (void)order; - #if defined(MA_X86) - do { - result = *dst; - } while (ma_atomic_compare_and_swap_64(dst, result, src) != result); - #elif defined(MA_X64) - __asm__ __volatile__("lock; xchg %1, %0" : "+m"(*dst), "=a"(result) : "a"(src)); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X64) + { + if (order == ma_atomic_memory_order_relaxed) { + __asm__ __volatile__ ( + "movq %1, %0" + : "=m"(*dst) + : "r"(src) + ); + } else { + __asm__ __volatile__ ( + "xchgq %1, %0" + : "=m"(*dst) + : "r"(src) + : "memory" + ); + } + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_CAS(64, dst, src, order); + } + #endif + } + #else + { + MA_ATOMIC_STORE_EXPLICIT_LOCK(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_add_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_8) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint8 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("b", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(8, dst, src, order); + } + #endif } static MA_INLINE ma_uint16 ma_atomic_fetch_add_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_16) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint16 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("w", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(16, dst, src, order); + } + #endif } static MA_INLINE ma_uint32 ma_atomic_fetch_add_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 result; - (void)order; - #if defined(MA_X86) || defined(MA_X64) - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - #else - #error Unsupported architecture. Please submit a feature request. - #endif - return result; + #if defined(MA_ATOMIC_IS_LOCK_FREE_32) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) || defined(MA_X64) + { + ma_uint32 result; + (void)order; + MA_ATOMIC_XADD_GCC_X86("l", result, dst, src); + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(32, dst, src, order); + } + #endif } static MA_INLINE ma_uint64 ma_atomic_fetch_add_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - #if defined(MA_X86) - ma_uint64 oldValue; - ma_uint64 newValue; - (void)order; - do { - oldValue = *dst; - newValue = oldValue + src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - return oldValue; - #elif defined(MA_X64) - ma_uint64 result; - (void)order; - __asm__ __volatile__("lock; xadd %1, %0" : "+m"(*dst), "=a"(result) : "a"(src) : "cc"); - return result; - #endif + #if defined(MA_ATOMIC_IS_LOCK_FREE_64) && (defined(MA_X86) || defined(MA_X64)) + { + #if defined(MA_X86) + { + MA_ATOMIC_FETCH_ADD_CAS(64, dst, src, order); + } + #elif defined(MA_X64) + { + ma_uint64 result; + MA_ATOMIC_XADD_GCC_X86("q", result, dst, src); + (void)order; + return result; + } + #else + { + #error Unsupported architecture. + } + #endif + } + #else + { + MA_ATOMIC_FETCH_ADD_LOCK(64, dst, src, order); + } + #endif } static MA_INLINE ma_uint8 ma_atomic_fetch_sub_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue - src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + return ma_atomic_fetch_add_explicit_8(dst, (ma_uint8)(-(ma_int8)src), order); } static MA_INLINE ma_uint16 ma_atomic_fetch_sub_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue - src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + return ma_atomic_fetch_add_explicit_16(dst, (ma_uint16)(-(ma_int16)src), order); } static MA_INLINE ma_uint32 ma_atomic_fetch_sub_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + return ma_atomic_fetch_add_explicit_32(dst, (ma_uint32)(-(ma_int32)src), order); } static MA_INLINE ma_uint64 ma_atomic_fetch_sub_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue - src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + return ma_atomic_fetch_add_explicit_64(dst, (ma_uint64)(-(ma_int64)src), order); } static MA_INLINE ma_uint8 ma_atomic_fetch_and_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue & src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_AND_CAS(8, dst, src, order); } static MA_INLINE ma_uint16 ma_atomic_fetch_and_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue & src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_AND_CAS(16, dst, src, order); } static MA_INLINE ma_uint32 ma_atomic_fetch_and_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_AND_CAS(32, dst, src, order); } static MA_INLINE ma_uint64 ma_atomic_fetch_and_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue & src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_AND_CAS(64, dst, src, order); } - static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_OR_CAS(8, dst, src, order); } - static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue ^ src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_OR_CAS(16, dst, src, order); } - static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_OR_CAS(32, dst, src, order); } - static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue ^ src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_OR_CAS(64, dst, src, order); } - static MA_INLINE ma_uint8 ma_atomic_fetch_or_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint8 ma_atomic_fetch_xor_explicit_8(volatile ma_uint8* dst, ma_uint8 src, ma_atomic_memory_order order) { - ma_uint8 oldValue; - ma_uint8 newValue; - do { - oldValue = *dst; - newValue = (ma_uint8)(oldValue | src); - } while (ma_atomic_compare_and_swap_8(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_XOR_CAS(8, dst, src, order); } - static MA_INLINE ma_uint16 ma_atomic_fetch_or_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint16 ma_atomic_fetch_xor_explicit_16(volatile ma_uint16* dst, ma_uint16 src, ma_atomic_memory_order order) { - ma_uint16 oldValue; - ma_uint16 newValue; - do { - oldValue = *dst; - newValue = (ma_uint16)(oldValue | src); - } while (ma_atomic_compare_and_swap_16(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_XOR_CAS(16, dst, src, order); } - static MA_INLINE ma_uint32 ma_atomic_fetch_or_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint32 ma_atomic_fetch_xor_explicit_32(volatile ma_uint32* dst, ma_uint32 src, ma_atomic_memory_order order) { - ma_uint32 oldValue; - ma_uint32 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_32(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_XOR_CAS(32, dst, src, order); } - static MA_INLINE ma_uint64 ma_atomic_fetch_or_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) + static MA_INLINE ma_uint64 ma_atomic_fetch_xor_explicit_64(volatile ma_uint64* dst, ma_uint64 src, ma_atomic_memory_order order) { - ma_uint64 oldValue; - ma_uint64 newValue; - do { - oldValue = *dst; - newValue = oldValue | src; - } while (ma_atomic_compare_and_swap_64(dst, oldValue, newValue) != oldValue); - (void)order; - return oldValue; + MA_ATOMIC_FETCH_XOR_CAS(64, dst, src, order); } + #else + #error Unsupported compiler. #endif - #define ma_atomic_signal_fence(order) ma_atomic_thread_fence(order) - static MA_INLINE ma_uint8 ma_atomic_load_explicit_8(volatile const ma_uint8* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_8((ma_uint8*)ptr, 0, 0); - } - static MA_INLINE ma_uint16 ma_atomic_load_explicit_16(volatile const ma_uint16* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_16((ma_uint16*)ptr, 0, 0); - } - static MA_INLINE ma_uint32 ma_atomic_load_explicit_32(volatile const ma_uint32* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_32((ma_uint32*)ptr, 0, 0); - } - static MA_INLINE ma_uint64 ma_atomic_load_explicit_64(volatile const ma_uint64* ptr, ma_atomic_memory_order order) - { - (void)order; - return ma_atomic_compare_and_swap_64((ma_uint64*)ptr, 0, 0); - } - #define ma_atomic_store_explicit_8( dst, src, order) (void)ma_atomic_exchange_explicit_8 (dst, src, order) - #define ma_atomic_store_explicit_16(dst, src, order) (void)ma_atomic_exchange_explicit_16(dst, src, order) - #define ma_atomic_store_explicit_32(dst, src, order) (void)ma_atomic_exchange_explicit_32(dst, src, order) - #define ma_atomic_store_explicit_64(dst, src, order) (void)ma_atomic_exchange_explicit_64(dst, src, order) - #define ma_atomic_test_and_set_explicit_8( dst, order) ma_atomic_exchange_explicit_8 (dst, 1, order) - #define ma_atomic_test_and_set_explicit_16(dst, order) ma_atomic_exchange_explicit_16(dst, 1, order) - #define ma_atomic_test_and_set_explicit_32(dst, order) ma_atomic_exchange_explicit_32(dst, 1, order) - #define ma_atomic_test_and_set_explicit_64(dst, order) ma_atomic_exchange_explicit_64(dst, 1, order) - #define ma_atomic_clear_explicit_8( dst, order) ma_atomic_store_explicit_8 (dst, 0, order) - #define ma_atomic_clear_explicit_16(dst, order) ma_atomic_store_explicit_16(dst, 0, order) - #define ma_atomic_clear_explicit_32(dst, order) ma_atomic_store_explicit_32(dst, 0, order) - #define ma_atomic_clear_explicit_64(dst, order) ma_atomic_store_explicit_64(dst, 0, order) - typedef ma_uint8 ma_atomic_flag; - #define ma_atomic_flag_test_and_set_explicit(ptr, order) (ma_bool32)ma_atomic_test_and_set_explicit_8(ptr, order) - #define ma_atomic_flag_clear_explicit(ptr, order) ma_atomic_clear_explicit_8(ptr, order) - #define ma_atomic_flag_load_explicit(ptr, order) ma_atomic_load_explicit_8(ptr, order) #endif #if !defined(MA_ATOMIC_HAS_NATIVE_COMPARE_EXCHANGE) - #if defined(MA_ATOMIC_HAS_8) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_8(volatile ma_uint8* dst, ma_uint8* expected, ma_uint8 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint8 expectedValue; - ma_uint8 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_8(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_8(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_8(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_16) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_16(volatile ma_uint16* dst, ma_uint16* expected, ma_uint16 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint16 expectedValue; - ma_uint16 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_16(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_16(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_16(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_32) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_32(volatile ma_uint32* dst, ma_uint32* expected, ma_uint32 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint32 expectedValue; - ma_uint32 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_32(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_32(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_32(expected, result, failureOrder); - return 0; - } - } - #endif - #if defined(MA_ATOMIC_HAS_64) - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_64(volatile ma_uint64* dst, volatile ma_uint64* expected, ma_uint64 desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) - { - ma_uint64 expectedValue; - ma_uint64 result; - (void)successOrder; - (void)failureOrder; - expectedValue = ma_atomic_load_explicit_64(expected, ma_atomic_memory_order_seq_cst); - result = ma_atomic_compare_and_swap_64(dst, expectedValue, desired); - if (result == expectedValue) { - return 1; - } else { - ma_atomic_store_explicit_64(expected, result, failureOrder); - return 0; - } - } - #endif - #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8 (dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, successOrder, failureOrder) - #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, successOrder, failureOrder) -#endif -#if !defined(MA_ATOMIC_HAS_NATIVE_IS_LOCK_FREE) - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_8(volatile void* ptr) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_8(volatile ma_uint8* dst, ma_uint8* expected, ma_uint8 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - (void)ptr; - return 1; + ma_uint8 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_8(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_16(volatile void* ptr) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_16(volatile ma_uint16* dst, ma_uint16* expected, ma_uint16 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - (void)ptr; - return 1; + ma_uint16 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_16(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_32(volatile void* ptr) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_32(volatile ma_uint32* dst, ma_uint32* expected, ma_uint32 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - (void)ptr; - return 1; + ma_uint32 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_32(dst, *expected, replacement); + if (result == *expected) { + return 1; + } else { + *expected = result; + return 0; + } } - static MA_INLINE ma_bool32 ma_atomic_is_lock_free_64(volatile void* ptr) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_64(volatile ma_uint64* dst, volatile ma_uint64* expected, ma_uint64 replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - (void)ptr; - #if defined(MA_64BIT) - return 1; - #else - #if defined(MA_X86) || defined(MA_X64) + ma_uint64 result; + (void)successOrder; + (void)failureOrder; + result = ma_atomic_compare_and_swap_64(dst, *expected, replacement); + if (result == *expected) { return 1; - #else + } else { + *expected = result; return 0; - #endif - #endif + } } + #define ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8 (dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, successOrder, failureOrder) + #define ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, successOrder, failureOrder) #endif #if defined(MA_64BIT) static MA_INLINE ma_bool32 ma_atomic_is_lock_free_ptr(volatile void** ptr) @@ -15561,17 +17005,17 @@ typedef int ma_atomic_memory_order; { return (void*)ma_atomic_exchange_explicit_64((volatile ma_uint64*)dst, (ma_uint64)src, order); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder); } - static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* desired) + static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* replacement) { - return (void*)ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, (ma_uint64)expected, (ma_uint64)desired); + return (void*)ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, (ma_uint64)expected, (ma_uint64)replacement); } #elif defined(MA_32BIT) static MA_INLINE ma_bool32 ma_atomic_is_lock_free_ptr(volatile void** ptr) @@ -15590,36 +17034,26 @@ typedef int ma_atomic_memory_order; { return (void*)ma_atomic_exchange_explicit_32((volatile ma_uint32*)dst, (ma_uint32)src, order); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder); } - static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) + static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_ptr(volatile void** dst, void** expected, void* replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { - return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder); + return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder); } - static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* desired) + static MA_INLINE void* ma_atomic_compare_and_swap_ptr(volatile void** dst, void* expected, void* replacement) { - return (void*)ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, (ma_uint32)expected, (ma_uint32)desired); + return (void*)ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, (ma_uint32)expected, (ma_uint32)replacement); } #else #error Unsupported architecture. #endif -#define ma_atomic_flag_test_and_set(ptr) ma_atomic_flag_test_and_set_explicit(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_flag_clear(ptr) ma_atomic_flag_clear_explicit(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_store_ptr(dst, src) ma_atomic_store_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_load_ptr(ptr) ma_atomic_load_explicit_ptr((volatile void**)ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_exchange_ptr(dst, src) ma_atomic_exchange_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_ptr(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_ptr((volatile void**)dst, (void**)expected, (void*)desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_ptr(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_ptr((volatile void**)dst, (void**)expected, (void*)desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_8( ptr) ma_atomic_test_and_set_explicit_8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_16(ptr) ma_atomic_test_and_set_explicit_16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_32(ptr) ma_atomic_test_and_set_explicit_32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_64(ptr) ma_atomic_test_and_set_explicit_64(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_8( ptr) ma_atomic_clear_explicit_8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_16(ptr) ma_atomic_clear_explicit_16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_32(ptr) ma_atomic_clear_explicit_32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_64(ptr) ma_atomic_clear_explicit_64(ptr, ma_atomic_memory_order_seq_cst) +#define ma_atomic_store_ptr(dst, src) ma_atomic_store_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) +#define ma_atomic_load_ptr(ptr) ma_atomic_load_explicit_ptr((volatile void**)ptr, ma_atomic_memory_order_seq_cst) +#define ma_atomic_exchange_ptr(dst, src) ma_atomic_exchange_explicit_ptr((volatile void**)dst, (void*)src, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_ptr(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_ptr((volatile void**)dst, (void**)expected, (void*)replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_ptr(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_ptr((volatile void**)dst, (void**)expected, (void*)replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_8( dst, src) ma_atomic_store_explicit_8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_16(dst, src) ma_atomic_store_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_32(dst, src) ma_atomic_store_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15632,14 +17066,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_16(dst, src) ma_atomic_exchange_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_32(dst, src) ma_atomic_exchange_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_64(dst, src) ma_atomic_exchange_explicit_64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_8( dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_16(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_8( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_16( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_32( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_64( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_8( dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_16(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_8( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_16( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_32( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_64( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_8( dst, src) ma_atomic_fetch_add_explicit_8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_16(dst, src) ma_atomic_fetch_add_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_32(dst, src) ma_atomic_fetch_add_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15660,14 +17094,6 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_16(dst, src) ma_atomic_fetch_and_explicit_16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_32(dst, src) ma_atomic_fetch_and_explicit_32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_64(dst, src) ma_atomic_fetch_and_explicit_64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_explicit_i8( ptr, order) (ma_int8 )ma_atomic_test_and_set_explicit_8( (ma_uint8* )ptr, order) -#define ma_atomic_test_and_set_explicit_i16(ptr, order) (ma_int16)ma_atomic_test_and_set_explicit_16((ma_uint16*)ptr, order) -#define ma_atomic_test_and_set_explicit_i32(ptr, order) (ma_int32)ma_atomic_test_and_set_explicit_32((ma_uint32*)ptr, order) -#define ma_atomic_test_and_set_explicit_i64(ptr, order) (ma_int64)ma_atomic_test_and_set_explicit_64((ma_uint64*)ptr, order) -#define ma_atomic_clear_explicit_i8( ptr, order) ma_atomic_clear_explicit_8( (ma_uint8* )ptr, order) -#define ma_atomic_clear_explicit_i16(ptr, order) ma_atomic_clear_explicit_16((ma_uint16*)ptr, order) -#define ma_atomic_clear_explicit_i32(ptr, order) ma_atomic_clear_explicit_32((ma_uint32*)ptr, order) -#define ma_atomic_clear_explicit_i64(ptr, order) ma_atomic_clear_explicit_64((ma_uint64*)ptr, order) #define ma_atomic_store_explicit_i8( dst, src, order) ma_atomic_store_explicit_8( (ma_uint8* )dst, (ma_uint8 )src, order) #define ma_atomic_store_explicit_i16(dst, src, order) ma_atomic_store_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_store_explicit_i32(dst, src, order) ma_atomic_store_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) @@ -15680,14 +17106,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_explicit_i16(dst, src, order) (ma_int16)ma_atomic_exchange_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_exchange_explicit_i32(dst, src, order) (ma_int32)ma_atomic_exchange_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) #define ma_atomic_exchange_explicit_i64(dst, src, order) (ma_int64)ma_atomic_exchange_explicit_64((ma_uint64*)dst, (ma_uint64)src, order) -#define ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)desired, successOrder, failureOrder) -#define ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, desired, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)desired, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_strong_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_8( (ma_uint8* )dst, (ma_uint8* )expected, (ma_uint8 )replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_16((ma_uint16*)dst, (ma_uint16*)expected, (ma_uint16)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_32((ma_uint32*)dst, (ma_uint32*)expected, (ma_uint32)replacement, successOrder, failureOrder) +#define ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, replacement, successOrder, failureOrder) ma_atomic_compare_exchange_weak_explicit_64((ma_uint64*)dst, (ma_uint64*)expected, (ma_uint64)replacement, successOrder, failureOrder) #define ma_atomic_fetch_add_explicit_i8( dst, src, order) (ma_int8 )ma_atomic_fetch_add_explicit_8( (ma_uint8* )dst, (ma_uint8 )src, order) #define ma_atomic_fetch_add_explicit_i16(dst, src, order) (ma_int16)ma_atomic_fetch_add_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_fetch_add_explicit_i32(dst, src, order) (ma_int32)ma_atomic_fetch_add_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) @@ -15708,14 +17134,6 @@ typedef int ma_atomic_memory_order; #define ma_atomic_fetch_and_explicit_i16(dst, src, order) (ma_int16)ma_atomic_fetch_and_explicit_16((ma_uint16*)dst, (ma_uint16)src, order) #define ma_atomic_fetch_and_explicit_i32(dst, src, order) (ma_int32)ma_atomic_fetch_and_explicit_32((ma_uint32*)dst, (ma_uint32)src, order) #define ma_atomic_fetch_and_explicit_i64(dst, src, order) (ma_int64)ma_atomic_fetch_and_explicit_64((ma_uint64*)dst, (ma_uint64)src, order) -#define ma_atomic_test_and_set_i8( ptr) ma_atomic_test_and_set_explicit_i8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i16(ptr) ma_atomic_test_and_set_explicit_i16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i32(ptr) ma_atomic_test_and_set_explicit_i32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_test_and_set_i64(ptr) ma_atomic_test_and_set_explicit_i64(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i8( ptr) ma_atomic_clear_explicit_i8( ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i16(ptr) ma_atomic_clear_explicit_i16(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i32(ptr) ma_atomic_clear_explicit_i32(ptr, ma_atomic_memory_order_seq_cst) -#define ma_atomic_clear_i64(ptr) ma_atomic_clear_explicit_i64(ptr, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i8( dst, src) ma_atomic_store_explicit_i8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i16(dst, src) ma_atomic_store_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_store_i32(dst, src) ma_atomic_store_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15728,14 +17146,14 @@ typedef int ma_atomic_memory_order; #define ma_atomic_exchange_i16(dst, src) ma_atomic_exchange_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_i32(dst, src) ma_atomic_exchange_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_i64(dst, src) ma_atomic_exchange_explicit_i64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i8( dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i16(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_i64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i8( dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i16(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i32(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_i64(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i8( dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i16(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_i64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_i64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i8( dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i8( dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i16(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i16(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i32(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_i64(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_i64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i8( dst, src) ma_atomic_fetch_add_explicit_i8( dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i16(dst, src) ma_atomic_fetch_add_explicit_i16(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_i32(dst, src) ma_atomic_fetch_add_explicit_i32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15812,28 +17230,28 @@ static MA_INLINE double ma_atomic_exchange_explicit_f64(volatile double* dst, do r.i = ma_atomic_exchange_explicit_64((volatile ma_uint64*)dst, x.i, order); return r.f; } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f32(volatile float* dst, float* expected, float desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f32(volatile float* dst, float* expected, float replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if32 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_strong_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f64(volatile double* dst, double* expected, double desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_strong_explicit_f64(volatile double* dst, double* expected, double replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if64 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_strong_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f32(volatile float* dst, float* expected, float desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f32(volatile float* dst, float* expected, float replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if32 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_weak_explicit_32((volatile ma_uint32*)dst, (ma_uint32*)expected, d.i, successOrder, failureOrder); } -static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f64(volatile double* dst, double* expected, double desired, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) +static MA_INLINE ma_bool32 ma_atomic_compare_exchange_weak_explicit_f64(volatile double* dst, double* expected, double replacement, ma_atomic_memory_order successOrder, ma_atomic_memory_order failureOrder) { ma_atomic_if64 d; - d.f = desired; + d.f = replacement; return ma_atomic_compare_exchange_weak_explicit_64((volatile ma_uint64*)dst, (ma_uint64*)expected, d.i, successOrder, failureOrder); } static MA_INLINE float ma_atomic_fetch_add_explicit_f32(volatile float* dst, float src, ma_atomic_memory_order order) @@ -15924,10 +17342,10 @@ static MA_INLINE double ma_atomic_fetch_and_explicit_f64(volatile double* dst, d #define ma_atomic_load_f64(ptr) (double)ma_atomic_load_explicit_f64(ptr, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_f32(dst, src) (float )ma_atomic_exchange_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_exchange_f64(dst, src) (double)ma_atomic_exchange_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_f32(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_f32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_strong_f64(dst, expected, desired) ma_atomic_compare_exchange_strong_explicit_f64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_f32(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_f32(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) -#define ma_atomic_compare_exchange_weak_f64(dst, expected, desired) ma_atomic_compare_exchange_weak_explicit_f64(dst, expected, desired, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_f32(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_f32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_strong_f64(dst, expected, replacement) ma_atomic_compare_exchange_strong_explicit_f64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_f32(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_f32(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) +#define ma_atomic_compare_exchange_weak_f64(dst, expected, replacement) ma_atomic_compare_exchange_weak_explicit_f64(dst, expected, replacement, ma_atomic_memory_order_seq_cst, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_f32(dst, src) ma_atomic_fetch_add_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_add_f64(dst, src) ma_atomic_fetch_add_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_sub_f32(dst, src) ma_atomic_fetch_sub_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) @@ -15938,39 +17356,24 @@ static MA_INLINE double ma_atomic_fetch_and_explicit_f64(volatile double* dst, d #define ma_atomic_fetch_xor_f64(dst, src) ma_atomic_fetch_xor_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_f32(dst, src) ma_atomic_fetch_and_explicit_f32(dst, src, ma_atomic_memory_order_seq_cst) #define ma_atomic_fetch_and_f64(dst, src) ma_atomic_fetch_and_explicit_f64(dst, src, ma_atomic_memory_order_seq_cst) -static MA_INLINE float ma_atomic_compare_and_swap_f32(volatile float* dst, float expected, float desired) +static MA_INLINE float ma_atomic_compare_and_swap_f32(volatile float* dst, float expected, float replacement) { ma_atomic_if32 r; ma_atomic_if32 e, d; e.f = expected; - d.f = desired; + d.f = replacement; r.i = ma_atomic_compare_and_swap_32((volatile ma_uint32*)dst, e.i, d.i); return r.f; } -static MA_INLINE double ma_atomic_compare_and_swap_f64(volatile double* dst, double expected, double desired) +static MA_INLINE double ma_atomic_compare_and_swap_f64(volatile double* dst, double expected, double replacement) { ma_atomic_if64 r; ma_atomic_if64 e, d; e.f = expected; - d.f = desired; + d.f = replacement; r.i = ma_atomic_compare_and_swap_64((volatile ma_uint64*)dst, e.i, d.i); return r.f; } -typedef ma_atomic_flag ma_atomic_spinlock; -static MA_INLINE void ma_atomic_spinlock_lock(volatile ma_atomic_spinlock* pSpinlock) -{ - for (;;) { - if (ma_atomic_flag_test_and_set_explicit(pSpinlock, ma_atomic_memory_order_acquire) == 0) { - break; - } - while (ma_atomic_flag_load_explicit(pSpinlock, ma_atomic_memory_order_relaxed) == 1) { - } - } -} -static MA_INLINE void ma_atomic_spinlock_unlock(volatile ma_atomic_spinlock* pSpinlock) -{ - ma_atomic_flag_clear_explicit(pSpinlock, ma_atomic_memory_order_release); -} #if defined(__clang__) || (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))) #pragma GCC diagnostic pop #endif @@ -16176,7 +17579,7 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority int result; pthread_attr_t* pAttr = NULL; -#if !defined(__EMSCRIPTEN__) && !defined(__3DS__) +#if !defined(MA_EMSCRIPTEN) && !defined(MA_3DS) && !defined(MA_SWITCH) /* Try setting the thread priority. It's not critical if anything fails here. */ pthread_attr_t attr; if (pthread_attr_init(&attr) == 0) { @@ -16208,9 +17611,18 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority } #endif - if (stackSize > 0) { - pthread_attr_setstacksize(&attr, stackSize); + #if defined(_POSIX_THREAD_ATTR_STACKSIZE) && _POSIX_THREAD_ATTR_STACKSIZE >= 0 + { + if (stackSize > 0) { + pthread_attr_setstacksize(&attr, stackSize); + } + } + #else + { + (void)stackSize; /* Suppress unused parameter warning. */ } + #endif + if (scheduler != -1) { int priorityMin = sched_get_priority_min(scheduler); @@ -16218,7 +17630,7 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority int priorityStep = (priorityMax - priorityMin) / 7; /* 7 = number of priorities supported by miniaudio. */ struct sched_param sched; - if (pthread_attr_getschedparam(&attr, &sched) == 0) { + if (priorityMin != -1 && priorityMax != -1 && pthread_attr_getschedparam(&attr, &sched) == 0) { if (priority == ma_thread_priority_idle) { sched.sched_priority = priorityMin; } else if (priority == ma_thread_priority_realtime) { @@ -16267,6 +17679,21 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority } if (result != 0) { + /* + There have been reports that attempting to create a realtime thread can sometimes fail. In this case, + fall back to a normal priority thread. + + I'm including a compile-time option here to disable this functionality for those who have a hard + requirement on realtime threads and would rather an explicit failure. + */ + #ifndef MA_NO_PTHREAD_REALTIME_PRIORITY_FALLBACK + { + if(result == EPERM && priority == ma_thread_priority_realtime) { + return ma_thread_create__posix(pThread, ma_thread_priority_normal, stackSize, entryProc, pData); + } + } + #endif + return ma_result_from_errno(result); } @@ -16538,7 +17965,7 @@ static ma_result ma_event_signal__win32(ma_event* pEvent) static ma_result ma_semaphore_init__win32(int initialValue, ma_semaphore* pSemaphore) { - *pSemaphore = CreateSemaphoreW(NULL, (LONG)initialValue, LONG_MAX, NULL); + *pSemaphore = CreateSemaphore(NULL, (LONG)initialValue, LONG_MAX, NULL); if (*pSemaphore == NULL) { return ma_result_from_GetLastError(GetLastError()); } @@ -17432,10 +18859,12 @@ static MA_INLINE ma_uint16 ma_job_extract_slot(ma_uint64 toc) return (ma_uint16)(toc & 0x0000FFFF); } +#if 0 /* Currently unused, but might make use of this later. */ static MA_INLINE ma_uint16 ma_job_extract_code(ma_uint64 toc) { return (ma_uint16)((toc & 0xFFFF0000) >> 16); } +#endif static MA_INLINE ma_uint64 ma_job_toc_to_allocation(ma_uint64 toc) { @@ -17900,6 +19329,13 @@ MA_API ma_result ma_job_queue_next(ma_job_queue* pQueue, ma_job* pJob) Dynamic Linking *******************************************************************************/ +/* Disable run-time linking on certain backends and platforms. */ +#ifndef MA_NO_RUNTIME_LINKING + #if defined(MA_EMSCRIPTEN) || defined(MA_ORBIS) || defined(MA_PROSPERO) || defined(MA_SWITCH) || defined(MA_DOS) + #define MA_NO_RUNTIME_LINKING + #endif +#endif + #ifdef MA_POSIX /* No need for dlfcn.h if we're not using runtime linking. */ #ifndef MA_NO_RUNTIME_LINKING @@ -17909,104 +19345,124 @@ Dynamic Linking MA_API ma_handle ma_dlopen(ma_log* pLog, const char* filename) { -#ifndef MA_NO_RUNTIME_LINKING - ma_handle handle; + #ifndef MA_NO_RUNTIME_LINKING + { + ma_handle handle; - ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading library: %s\n", filename); + ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading library: %s\n", filename); - #ifdef MA_WIN32 - /* From MSDN: Desktop applications cannot use LoadPackagedLibrary; if a desktop application calls this function it fails with APPMODEL_ERROR_NO_PACKAGE.*/ - #if !defined(MA_WIN32_UWP) || !(defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) - handle = (ma_handle)LoadLibraryA(filename); + #ifdef MA_WIN32 + /* From MSDN: Desktop applications cannot use LoadPackagedLibrary; if a desktop application calls this function it fails with APPMODEL_ERROR_NO_PACKAGE.*/ + #if !defined(MA_WIN32_UWP) || !(defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) + handle = (ma_handle)LoadLibraryA(filename); + #else + /* *sigh* It appears there is no ANSI version of LoadPackagedLibrary()... */ + WCHAR filenameW[4096]; + if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, sizeof(filenameW)) == 0) { + handle = NULL; + } else { + handle = (ma_handle)LoadPackagedLibrary(filenameW, 0); + } + #endif #else - /* *sigh* It appears there is no ANSI version of LoadPackagedLibrary()... */ - WCHAR filenameW[4096]; - if (MultiByteToWideChar(CP_UTF8, 0, filename, -1, filenameW, sizeof(filenameW)) == 0) { - handle = NULL; - } else { - handle = (ma_handle)LoadPackagedLibrary(filenameW, 0); - } + handle = (ma_handle)dlopen(filename, RTLD_NOW); #endif - #else - handle = (ma_handle)dlopen(filename, RTLD_NOW); - #endif - /* - I'm not considering failure to load a library an error nor a warning because seamlessly falling through to a lower-priority - backend is a deliberate design choice. Instead I'm logging it as an informational message. - */ - if (handle == NULL) { - ma_log_postf(pLog, MA_LOG_LEVEL_INFO, "Failed to load library: %s\n", filename); - } + /* + I'm not considering failure to load a library an error nor a warning because seamlessly falling through to a lower-priority + backend is a deliberate design choice. Instead I'm logging it as an informational message. + */ + if (handle == NULL) { + ma_log_postf(pLog, MA_LOG_LEVEL_INFO, "Failed to load library: %s\n", filename); + } - return handle; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)filename; - return NULL; -#endif + return handle; + } + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)filename; + return NULL; + } + #endif } MA_API void ma_dlclose(ma_log* pLog, ma_handle handle) { -#ifndef MA_NO_RUNTIME_LINKING - #ifdef MA_WIN32 - FreeLibrary((HMODULE)handle); - #else - /* Hack for Android bug (see https://github.com/android/ndk/issues/360). Calling dlclose() pre-API 28 may segfault. */ - #if !defined(MA_ANDROID) || (defined(__ANDROID_API__) && __ANDROID_API__ >= 28) + #ifndef MA_NO_RUNTIME_LINKING + { + #ifdef MA_WIN32 { - dlclose((void*)handle); + FreeLibrary((HMODULE)handle); } #else { - (void)handle; + /* Hack for Android bug (see https://github.com/android/ndk/issues/360). Calling dlclose() pre-API 28 may segfault. */ + #if !defined(MA_ANDROID) || (defined(__ANDROID_API__) && __ANDROID_API__ >= 28) + { + dlclose((void*)handle); + } + #else + { + (void)handle; + } + #endif } #endif - #endif - (void)pLog; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)handle; -#endif + (void)pLog; + } + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)handle; + } + #endif } MA_API ma_proc ma_dlsym(ma_log* pLog, ma_handle handle, const char* symbol) { -#ifndef MA_NO_RUNTIME_LINKING - ma_proc proc; + #ifndef MA_NO_RUNTIME_LINKING + { + ma_proc proc; - ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading symbol: %s\n", symbol); + ma_log_postf(pLog, MA_LOG_LEVEL_DEBUG, "Loading symbol: %s\n", symbol); -#ifdef _WIN32 - proc = (ma_proc)GetProcAddress((HMODULE)handle, symbol); -#else -#if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) - #pragma GCC diagnostic push - #pragma GCC diagnostic ignored "-Wpedantic" -#endif - proc = (ma_proc)dlsym((void*)handle, symbol); -#if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) - #pragma GCC diagnostic pop -#endif -#endif + #ifdef _WIN32 + { + proc = (ma_proc)GetProcAddress((HMODULE)handle, symbol); + } + #else + { + #if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wpedantic" + #endif + proc = (ma_proc)dlsym((void*)handle, symbol); + #if (defined(__GNUC__) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8))) || defined(__clang__) + #pragma GCC diagnostic pop + #endif + } + #endif - if (proc == NULL) { - ma_log_postf(pLog, MA_LOG_LEVEL_WARNING, "Failed to load symbol: %s\n", symbol); - } + if (proc == NULL) { + ma_log_postf(pLog, MA_LOG_LEVEL_WARNING, "Failed to load symbol: %s\n", symbol); + } - (void)pLog; /* It's possible for pContext to be unused. */ - return proc; -#else - /* Runtime linking is disabled. */ - (void)pLog; - (void)handle; - (void)symbol; - return NULL; -#endif + (void)pLog; /* It's possible for pContext to be unused. */ + return proc; + } + #else + { + /* Runtime linking is disabled. */ + (void)pLog; + (void)handle; + (void)symbol; + return NULL; + } + #endif } @@ -18020,13 +19476,6 @@ DEVICE I/O ************************************************************************************************************************************************************* ************************************************************************************************************************************************************/ -/* Disable run-time linking on certain backends and platforms. */ -#ifndef MA_NO_RUNTIME_LINKING - #if defined(MA_EMSCRIPTEN) || defined(MA_ORBIS) || defined(MA_PROSPERO) - #define MA_NO_RUNTIME_LINKING - #endif -#endif - #ifdef MA_APPLE #include <AvailabilityMacros.h> #endif @@ -18039,12 +19488,6 @@ DEVICE I/O #ifdef MA_POSIX #include <sys/types.h> - #include <unistd.h> - - /* No need for dlfcn.h if we're not using runtime linking. */ - #ifndef MA_NO_RUNTIME_LINKING - #include <dlfcn.h> - #endif #endif /* This must be set to at least 26. */ @@ -18299,7 +19742,7 @@ MA_API ma_bool32 ma_is_loopback_supported(ma_backend backend) -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) /* WASAPI error codes. */ #define MA_AUDCLNT_E_NOT_INITIALIZED ((HRESULT)0x88890001) #define MA_AUDCLNT_E_ALREADY_INITIALIZED ((HRESULT)0x88890002) @@ -18514,6 +19957,11 @@ typedef LONG (WINAPI * MA_PFN_RegCloseKey)(HKEY hKey); typedef LONG (WINAPI * MA_PFN_RegQueryValueExA)(HKEY hKey, const char* lpValueName, DWORD* lpReserved, DWORD* lpType, BYTE* lpData, DWORD* lpcbData); #endif /* MA_WIN32_DESKTOP */ +static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_PCM = {0x00000001, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; +static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_IEEE_FLOAT = {0x00000003, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; +/*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_ALAW = {0x00000006, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ +/*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_MULAW = {0x00000007, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ + MA_API size_t ma_strlen_WCHAR(const WCHAR* str) { size_t len = 0; @@ -18577,7 +20025,7 @@ Timing *******************************************************************************/ #if defined(MA_WIN32) && !defined(MA_POSIX) static LARGE_INTEGER g_ma_TimerFrequency; /* <-- Initialized to zero since it's static. */ - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { LARGE_INTEGER counter; @@ -18589,7 +20037,7 @@ Timing pTimer->counter = counter.QuadPart; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { LARGE_INTEGER counter; if (!QueryPerformanceCounter(&counter)) { @@ -18600,7 +20048,7 @@ Timing } #elif defined(MA_APPLE) && (MAC_OS_X_VERSION_MIN_REQUIRED < 101200) static ma_uint64 g_ma_TimerFrequency = 0; - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { mach_timebase_info_data_t baseTime; mach_timebase_info(&baseTime); @@ -18609,7 +20057,7 @@ Timing pTimer->counter = mach_absolute_time(); } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter = mach_absolute_time(); ma_uint64 oldTimeCounter = pTimer->counter; @@ -18634,15 +20082,15 @@ Timing #define MA_CLOCK_ID CLOCK_REALTIME #endif - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { struct timespec newTime; clock_gettime(MA_CLOCK_ID, &newTime); - pTimer->counter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec; + pTimer->counter = ((ma_int64)newTime.tv_sec * 1000000000) + newTime.tv_nsec; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter; ma_uint64 oldTimeCounter; @@ -18650,21 +20098,21 @@ Timing struct timespec newTime; clock_gettime(MA_CLOCK_ID, &newTime); - newTimeCounter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec; + newTimeCounter = ((ma_uint64)newTime.tv_sec * 1000000000) + newTime.tv_nsec; oldTimeCounter = pTimer->counter; return (newTimeCounter - oldTimeCounter) / 1000000000.0; } #else - static void ma_timer_init(ma_timer* pTimer) + static MA_INLINE void ma_timer_init(ma_timer* pTimer) { struct timeval newTime; gettimeofday(&newTime, NULL); - pTimer->counter = (newTime.tv_sec * 1000000) + newTime.tv_usec; + pTimer->counter = ((ma_int64)newTime.tv_sec * 1000000) + newTime.tv_usec; } - static double ma_timer_get_time_in_seconds(ma_timer* pTimer) + static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) { ma_uint64 newTimeCounter; ma_uint64 oldTimeCounter; @@ -18672,7 +20120,7 @@ Timing struct timeval newTime; gettimeofday(&newTime, NULL); - newTimeCounter = (newTime.tv_sec * 1000000) + newTime.tv_usec; + newTimeCounter = ((ma_uint64)newTime.tv_sec * 1000000) + newTime.tv_usec; oldTimeCounter = pTimer->counter; return (newTimeCounter - oldTimeCounter) / 1000000.0; @@ -19248,14 +20696,6 @@ static MA_INLINE void ma_device__set_state(ma_device* pDevice, ma_device_state n } -#if defined(MA_WIN32) - static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_PCM = {0x00000001, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; - static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_IEEE_FLOAT = {0x00000003, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}}; - /*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_ALAW = {0x00000006, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ - /*static GUID MA_GUID_KSDATAFORMAT_SUBTYPE_MULAW = {0x00000007, 0x0000, 0x0010, {0x80, 0x00, 0x00, 0xaa, 0x00, 0x38, 0x9b, 0x71}};*/ -#endif - - MA_API ma_uint32 ma_get_format_priority_index(ma_format format) /* Lower = better. */ { @@ -19967,7 +21407,7 @@ static ma_result ma_context_init__null(ma_context* pContext, const ma_context_co WIN32 COMMON *******************************************************************************/ -#if defined(MA_WIN32) +#if defined(MA_WIN32) && !defined(MA_XBOX) #if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) #define ma_CoInitializeEx(pContext, pvReserved, dwCoInit) ((pContext->win32.CoInitializeEx) ? ((MA_PFN_CoInitializeEx)pContext->win32.CoInitializeEx)(pvReserved, dwCoInit) : ((MA_PFN_CoInitialize)pContext->win32.CoInitialize)(pvReserved)) #define ma_CoUninitialize(pContext) ((MA_PFN_CoUninitialize)pContext->win32.CoUninitialize)() @@ -19982,7 +21422,7 @@ WIN32 COMMON #define ma_PropVariantClear(pContext, pvar) PropVariantClear(pvar) #endif -#if !defined(MAXULONG_PTR) && !defined(__WATCOMC__) +#if !defined(MAXULONG_PTR) && !defined(__WATCOMC__) && !defined(MA_XBOX_NXDK) typedef size_t DWORD_PTR; #endif @@ -20409,11 +21849,21 @@ typedef enum MA_AudioCategory_Other = 0 /* <-- miniaudio is only caring about Other. */ } MA_AUDIO_STREAM_CATEGORY; +typedef enum +{ + MA_AUDCLNT_STREAMOPTIONS_NONE, + MA_AUDCLNT_STREAMOPTIONS_RAW, + MA_AUDCLNT_STREAMOPTIONS_MATCH_FORMAT, + MA_AUDCLNT_STREAMOPTIONS_AMBISONICS, + MA_AUDCLNT_STREAMOPTIONS_POST_VOLUME_LOOPBACK +} MA_AUDCLNT_STREAMOPTIONS; + typedef struct { ma_uint32 cbSize; BOOL bIsOffload; MA_AUDIO_STREAM_CATEGORY eCategory; + MA_AUDCLNT_STREAMOPTIONS Options; } ma_AudioClientProperties; /* IUnknown */ @@ -21588,6 +23038,7 @@ static ma_result ma_context_get_MMDevice__wasapi(ma_context* pContext, ma_device { ma_IMMDeviceEnumerator* pDeviceEnumerator; HRESULT hr; + HRESULT CoInitializeResult; MA_ASSERT(pContext != NULL); MA_ASSERT(ppMMDevice != NULL); @@ -21601,12 +23052,17 @@ static ma_result ma_context_get_MMDevice__wasapi(ma_context* pContext, ma_device The community has reported that this seems to fix the crash. There are future plans to move all WASAPI operation over to a single thread to make everything safer, but in the meantime while we wait for that to come online I'm happy enough to use this hack instead. + + CoUninitialize should only be called if we successfully initialized. S_OK and S_FALSE both mean that we need to + call CoUninitialize since the internal ref count was increased. RPC_E_CHANGED_MODE means that CoInitializeEx was + called with a different COINIT value, and we don't call CoUninitialize in that case. Other errors are possible, + so we check for S_OK and S_FALSE specifically. */ - ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); + CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); { hr = ma_CoCreateInstance(pContext, &MA_CLSID_MMDeviceEnumerator, NULL, CLSCTX_ALL, &MA_IID_IMMDeviceEnumerator, (void**)&pDeviceEnumerator); - } - ma_CoUninitialize(pContext); + } + if (CoInitializeResult == S_OK || CoInitializeResult == S_FALSE) { ma_CoUninitialize(pContext); } if (FAILED(hr)) { /* <-- This is checking the call above to ma_CoCreateInstance(). */ ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[WASAPI] Failed to create IMMDeviceEnumerator.\n"); @@ -21950,7 +23406,7 @@ static ma_result ma_context_get_IAudioClient__wasapi(ma_context* pContext, ma_de pActivationParams = &activationParams; /* When requesting a specific device ID we need to use a special device ID. */ - MA_COPY_MEMORY(virtualDeviceID.wasapi, MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK, (wcslen(MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK) + 1) * sizeof(wchar_t)); /* +1 for the null terminator. */ + MA_COPY_MEMORY(virtualDeviceID.wasapi, MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK, (ma_wcslen(MA_VIRTUAL_AUDIO_DEVICE_PROCESS_LOOPBACK) + 1) * sizeof(wchar_t)); /* +1 for the null terminator. */ pDeviceID = &virtualDeviceID; } else { pActivationParams = NULL; /* No activation parameters required. */ @@ -26679,6 +28135,9 @@ typedef snd_pcm_channel_area_t ma_snd_pcm_channel_area_t; typedef snd_pcm_chmap_t ma_snd_pcm_chmap_t; typedef snd_pcm_state_t ma_snd_pcm_state_t; +/* snd_pcm_state_t */ +#define MA_SND_PCM_STATE_XRUN SND_PCM_STATE_XRUN + /* snd_pcm_stream_t */ #define MA_SND_PCM_STREAM_PLAYBACK SND_PCM_STREAM_PLAYBACK #define MA_SND_PCM_STREAM_CAPTURE SND_PCM_STREAM_CAPTURE @@ -26874,6 +28333,7 @@ typedef int (* ma_snd_pcm_hw_params_set_channels_minmax_proc) ( typedef int (* ma_snd_pcm_hw_params_set_rate_resample_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int val); typedef int (* ma_snd_pcm_hw_params_set_rate_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int val, int dir); typedef int (* ma_snd_pcm_hw_params_set_rate_near_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *val, int *dir); +typedef int (* ma_snd_pcm_hw_params_set_rate_minmax_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *min, int *mindir, unsigned int *max, int *maxdir); typedef int (* ma_snd_pcm_hw_params_set_buffer_size_near_proc)(ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, ma_snd_pcm_uframes_t *val); typedef int (* ma_snd_pcm_hw_params_set_periods_near_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, unsigned int *val, int *dir); typedef int (* ma_snd_pcm_hw_params_set_access_proc) (ma_snd_pcm_t *pcm, ma_snd_pcm_hw_params_t *params, ma_snd_pcm_access_t _access); @@ -28640,8 +30100,9 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co ma_snd_pcm_hw_params_get_format_mask_proc _snd_pcm_hw_params_get_format_mask = snd_pcm_hw_params_get_format_mask; ma_snd_pcm_hw_params_set_channels_proc _snd_pcm_hw_params_set_channels = snd_pcm_hw_params_set_channels; ma_snd_pcm_hw_params_set_channels_near_proc _snd_pcm_hw_params_set_channels_near = snd_pcm_hw_params_set_channels_near; + ma_snd_pcm_hw_params_set_channels_minmax_proc _snd_pcm_hw_params_set_channels_minmax = snd_pcm_hw_params_set_channels_minmax; ma_snd_pcm_hw_params_set_rate_resample_proc _snd_pcm_hw_params_set_rate_resample = snd_pcm_hw_params_set_rate_resample; - ma_snd_pcm_hw_params_set_rate_near _snd_pcm_hw_params_set_rate = snd_pcm_hw_params_set_rate; + ma_snd_pcm_hw_params_set_rate_proc _snd_pcm_hw_params_set_rate = snd_pcm_hw_params_set_rate; ma_snd_pcm_hw_params_set_rate_near_proc _snd_pcm_hw_params_set_rate_near = snd_pcm_hw_params_set_rate_near; ma_snd_pcm_hw_params_set_rate_minmax_proc _snd_pcm_hw_params_set_rate_minmax = snd_pcm_hw_params_set_rate_minmax; ma_snd_pcm_hw_params_set_buffer_size_near_proc _snd_pcm_hw_params_set_buffer_size_near = snd_pcm_hw_params_set_buffer_size_near; @@ -28693,9 +30154,9 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co ma_snd_pcm_info_proc _snd_pcm_info = snd_pcm_info; ma_snd_pcm_info_sizeof_proc _snd_pcm_info_sizeof = snd_pcm_info_sizeof; ma_snd_pcm_info_get_name_proc _snd_pcm_info_get_name = snd_pcm_info_get_name; - ma_snd_pcm_poll_descriptors _snd_pcm_poll_descriptors = snd_pcm_poll_descriptors; - ma_snd_pcm_poll_descriptors_count _snd_pcm_poll_descriptors_count = snd_pcm_poll_descriptors_count; - ma_snd_pcm_poll_descriptors_revents _snd_pcm_poll_descriptors_revents = snd_pcm_poll_descriptors_revents; + ma_snd_pcm_poll_descriptors_proc _snd_pcm_poll_descriptors = snd_pcm_poll_descriptors; + ma_snd_pcm_poll_descriptors_count_proc _snd_pcm_poll_descriptors_count = snd_pcm_poll_descriptors_count; + ma_snd_pcm_poll_descriptors_revents_proc _snd_pcm_poll_descriptors_revents = snd_pcm_poll_descriptors_revents; ma_snd_config_update_free_global_proc _snd_config_update_free_global = snd_config_update_free_global; pContext->alsa.snd_pcm_open = (ma_proc)_snd_pcm_open; @@ -28711,6 +30172,7 @@ static ma_result ma_context_init__alsa(ma_context* pContext, const ma_context_co pContext->alsa.snd_pcm_hw_params_set_rate_resample = (ma_proc)_snd_pcm_hw_params_set_rate_resample; pContext->alsa.snd_pcm_hw_params_set_rate = (ma_proc)_snd_pcm_hw_params_set_rate; pContext->alsa.snd_pcm_hw_params_set_rate_near = (ma_proc)_snd_pcm_hw_params_set_rate_near; + pContext->alsa.snd_pcm_hw_params_set_rate_minmax = (ma_proc)_snd_pcm_hw_params_set_rate_minmax; pContext->alsa.snd_pcm_hw_params_set_buffer_size_near = (ma_proc)_snd_pcm_hw_params_set_buffer_size_near; pContext->alsa.snd_pcm_hw_params_set_periods_near = (ma_proc)_snd_pcm_hw_params_set_periods_near; pContext->alsa.snd_pcm_hw_params_set_access = (ma_proc)_snd_pcm_hw_params_set_access; @@ -29436,7 +30898,7 @@ typedef void (* ma_pa_threaded_mainloop_unlock_proc) ( typedef void (* ma_pa_threaded_mainloop_wait_proc) (ma_pa_threaded_mainloop* m); typedef void (* ma_pa_threaded_mainloop_signal_proc) (ma_pa_threaded_mainloop* m, int wait_for_accept); typedef void (* ma_pa_threaded_mainloop_accept_proc) (ma_pa_threaded_mainloop* m); -typedef int (* ma_pa_threaded_mainloop_get_retval_proc) (ma_pa_threaded_mainloop* m); +typedef int (* ma_pa_threaded_mainloop_get_retval_proc) (const ma_pa_threaded_mainloop* m); typedef ma_pa_mainloop_api* (* ma_pa_threaded_mainloop_get_api_proc) (ma_pa_threaded_mainloop* m); typedef int (* ma_pa_threaded_mainloop_in_thread_proc) (ma_pa_threaded_mainloop* m); typedef void (* ma_pa_threaded_mainloop_set_name_proc) (ma_pa_threaded_mainloop* m, const char* name); @@ -29445,13 +30907,13 @@ typedef void (* ma_pa_context_unref_proc) ( typedef int (* ma_pa_context_connect_proc) (ma_pa_context* c, const char* server, ma_pa_context_flags_t flags, const ma_pa_spawn_api* api); typedef void (* ma_pa_context_disconnect_proc) (ma_pa_context* c); typedef void (* ma_pa_context_set_state_callback_proc) (ma_pa_context* c, ma_pa_context_notify_cb_t cb, void* userdata); -typedef ma_pa_context_state_t (* ma_pa_context_get_state_proc) (ma_pa_context* c); +typedef ma_pa_context_state_t (* ma_pa_context_get_state_proc) (const ma_pa_context* c); typedef ma_pa_operation* (* ma_pa_context_get_sink_info_list_proc) (ma_pa_context* c, ma_pa_sink_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_source_info_list_proc) (ma_pa_context* c, ma_pa_source_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_sink_info_by_name_proc) (ma_pa_context* c, const char* name, ma_pa_sink_info_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_context_get_source_info_by_name_proc)(ma_pa_context* c, const char* name, ma_pa_source_info_cb_t cb, void* userdata); typedef void (* ma_pa_operation_unref_proc) (ma_pa_operation* o); -typedef ma_pa_operation_state_t (* ma_pa_operation_get_state_proc) (ma_pa_operation* o); +typedef ma_pa_operation_state_t (* ma_pa_operation_get_state_proc) (const ma_pa_operation* o); typedef ma_pa_channel_map* (* ma_pa_channel_map_init_extend_proc) (ma_pa_channel_map* m, unsigned channels, ma_pa_channel_map_def_t def); typedef int (* ma_pa_channel_map_valid_proc) (const ma_pa_channel_map* m); typedef int (* ma_pa_channel_map_compatible_proc) (const ma_pa_channel_map* m, const ma_pa_sample_spec* ss); @@ -29460,12 +30922,12 @@ typedef void (* ma_pa_stream_unref_proc) ( typedef int (* ma_pa_stream_connect_playback_proc) (ma_pa_stream* s, const char* dev, const ma_pa_buffer_attr* attr, ma_pa_stream_flags_t flags, const ma_pa_cvolume* volume, ma_pa_stream* sync_stream); typedef int (* ma_pa_stream_connect_record_proc) (ma_pa_stream* s, const char* dev, const ma_pa_buffer_attr* attr, ma_pa_stream_flags_t flags); typedef int (* ma_pa_stream_disconnect_proc) (ma_pa_stream* s); -typedef ma_pa_stream_state_t (* ma_pa_stream_get_state_proc) (ma_pa_stream* s); +typedef ma_pa_stream_state_t (* ma_pa_stream_get_state_proc) (const ma_pa_stream* s); typedef const ma_pa_sample_spec* (* ma_pa_stream_get_sample_spec_proc) (ma_pa_stream* s); typedef const ma_pa_channel_map* (* ma_pa_stream_get_channel_map_proc) (ma_pa_stream* s); typedef const ma_pa_buffer_attr* (* ma_pa_stream_get_buffer_attr_proc) (ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_set_buffer_attr_proc) (ma_pa_stream* s, const ma_pa_buffer_attr* attr, ma_pa_stream_success_cb_t cb, void* userdata); -typedef const char* (* ma_pa_stream_get_device_name_proc) (ma_pa_stream* s); +typedef const char* (* ma_pa_stream_get_device_name_proc) (const ma_pa_stream* s); typedef void (* ma_pa_stream_set_write_callback_proc) (ma_pa_stream* s, ma_pa_stream_request_cb_t cb, void* userdata); typedef void (* ma_pa_stream_set_read_callback_proc) (ma_pa_stream* s, ma_pa_stream_request_cb_t cb, void* userdata); typedef void (* ma_pa_stream_set_suspended_callback_proc) (ma_pa_stream* s, ma_pa_stream_notify_cb_t cb, void* userdata); @@ -29473,15 +30935,15 @@ typedef void (* ma_pa_stream_set_moved_callback_proc) ( typedef int (* ma_pa_stream_is_suspended_proc) (const ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_flush_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_stream_drain_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); -typedef int (* ma_pa_stream_is_corked_proc) (ma_pa_stream* s); +typedef int (* ma_pa_stream_is_corked_proc) (const ma_pa_stream* s); typedef ma_pa_operation* (* ma_pa_stream_cork_proc) (ma_pa_stream* s, int b, ma_pa_stream_success_cb_t cb, void* userdata); typedef ma_pa_operation* (* ma_pa_stream_trigger_proc) (ma_pa_stream* s, ma_pa_stream_success_cb_t cb, void* userdata); typedef int (* ma_pa_stream_begin_write_proc) (ma_pa_stream* s, void** data, size_t* nbytes); typedef int (* ma_pa_stream_write_proc) (ma_pa_stream* s, const void* data, size_t nbytes, ma_pa_free_cb_t free_cb, int64_t offset, ma_pa_seek_mode_t seek); typedef int (* ma_pa_stream_peek_proc) (ma_pa_stream* s, const void** data, size_t* nbytes); typedef int (* ma_pa_stream_drop_proc) (ma_pa_stream* s); -typedef size_t (* ma_pa_stream_writable_size_proc) (ma_pa_stream* s); -typedef size_t (* ma_pa_stream_readable_size_proc) (ma_pa_stream* s); +typedef size_t (* ma_pa_stream_writable_size_proc) (const ma_pa_stream* s); +typedef size_t (* ma_pa_stream_readable_size_proc) (const ma_pa_stream* s); typedef struct { @@ -29777,9 +31239,10 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext, } /* Now we need to connect to the context. Everything is asynchronous so we need to wait for it to connect before returning. */ - result = ma_result_from_pulse(((ma_pa_context_connect_proc)pContext->pulse.pa_context_connect)((ma_pa_context*)pPulseContext, pServerName, (tryAutoSpawn) ? 0 : MA_PA_CONTEXT_NOAUTOSPAWN, NULL)); + result = ma_result_from_pulse(((ma_pa_context_connect_proc)pContext->pulse.pa_context_connect)((ma_pa_context*)pPulseContext, pServerName, (tryAutoSpawn) ? MA_PA_CONTEXT_NOFLAGS : MA_PA_CONTEXT_NOAUTOSPAWN, NULL)); if (result != MA_SUCCESS) { ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio context."); + ((ma_pa_context_unref_proc)pContext->pulse.pa_context_unref)((ma_pa_context*)(pPulseContext)); ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop)); return result; } @@ -29788,6 +31251,7 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext, result = ma_wait_for_pa_context_to_connect__pulse(pContext, pMainLoop, pPulseContext); if (result != MA_SUCCESS) { ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Waiting for connection failed."); + ((ma_pa_context_unref_proc)pContext->pulse.pa_context_unref)((ma_pa_context*)(pPulseContext)); ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop)); return result; } @@ -30510,7 +31974,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi const ma_pa_buffer_attr* pActualAttr = NULL; const ma_pa_channel_map* pActualChannelMap = NULL; ma_uint32 iChannel; - ma_pa_stream_flags_t streamFlags; + int streamFlags; MA_ASSERT(pDevice != NULL); MA_ZERO_OBJECT(&pDevice->pulse); @@ -30568,8 +32032,13 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi ss.channels = pDescriptorCapture->channels; } + /* PulseAudio has a maximum channel count of 32. We'll get a crash if this is exceeded. */ + if (ss.channels > 32) { + ss.channels = 32; + } + /* Use a default channel map. */ - ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, pConfig->pulse.channelMap); + ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, (ma_pa_channel_map_def_t)pConfig->pulse.channelMap); /* Use the requested sample rate if one was specified. */ if (pDescriptorCapture->sampleRate != 0) { @@ -30626,7 +32095,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi streamFlags |= MA_PA_STREAM_DONT_MOVE; } - error = ((ma_pa_stream_connect_record_proc)pDevice->pContext->pulse.pa_stream_connect_record)((ma_pa_stream*)pDevice->pulse.pStreamCapture, devCapture, &attr, streamFlags); + error = ((ma_pa_stream_connect_record_proc)pDevice->pContext->pulse.pa_stream_connect_record)((ma_pa_stream*)pDevice->pulse.pStreamCapture, devCapture, &attr, (ma_pa_stream_flags_t)streamFlags); if (error != MA_PA_OK) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio capture stream."); result = ma_result_from_pulse(error); @@ -30720,8 +32189,13 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi ss.channels = pDescriptorPlayback->channels; } + /* PulseAudio has a maximum channel count of 32. We'll get a crash if this is exceeded. */ + if (ss.channels > 32) { + ss.channels = 32; + } + /* Use a default channel map. */ - ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, pConfig->pulse.channelMap); + ((ma_pa_channel_map_init_extend_proc)pDevice->pContext->pulse.pa_channel_map_init_extend)(&cmap, ss.channels, (ma_pa_channel_map_def_t)pConfig->pulse.channelMap); /* Use the requested sample rate if one was specified. */ @@ -30783,7 +32257,7 @@ static ma_result ma_device_init__pulse(ma_device* pDevice, const ma_device_confi streamFlags |= MA_PA_STREAM_DONT_MOVE; } - error = ((ma_pa_stream_connect_playback_proc)pDevice->pContext->pulse.pa_stream_connect_playback)((ma_pa_stream*)pDevice->pulse.pStreamPlayback, devPlayback, &attr, streamFlags, NULL, NULL); + error = ((ma_pa_stream_connect_playback_proc)pDevice->pContext->pulse.pa_stream_connect_playback)((ma_pa_stream*)pDevice->pulse.pStreamPlayback, devPlayback, &attr, (ma_pa_stream_flags_t)streamFlags, NULL, NULL); if (error != MA_PA_OK) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio playback stream."); result = ma_result_from_pulse(error); @@ -31338,6 +32812,7 @@ typedef JackProcessCallback ma_JackProcessCallback; typedef JackBufferSizeCallback ma_JackBufferSizeCallback; typedef JackShutdownCallback ma_JackShutdownCallback; #define MA_JACK_DEFAULT_AUDIO_TYPE JACK_DEFAULT_AUDIO_TYPE +#define ma_JackNullOption JackNullOption #define ma_JackNoStartServer JackNoStartServer #define ma_JackPortIsInput JackPortIsInput #define ma_JackPortIsOutput JackPortIsOutput @@ -31352,6 +32827,7 @@ typedef int (* ma_JackProcessCallback) (ma_jack_nframes_t nframes, void* arg) typedef int (* ma_JackBufferSizeCallback)(ma_jack_nframes_t nframes, void* arg); typedef void (* ma_JackShutdownCallback) (void* arg); #define MA_JACK_DEFAULT_AUDIO_TYPE "32 bit float mono audio" +#define ma_JackNullOption 0 #define ma_JackNoStartServer 1 #define ma_JackPortIsInput 1 #define ma_JackPortIsOutput 2 @@ -31392,7 +32868,7 @@ static ma_result ma_context_open_client__jack(ma_context* pContext, ma_jack_clie maxClientNameSize = ((ma_jack_client_name_size_proc)pContext->jack.jack_client_name_size)(); /* Includes null terminator. */ ma_strncpy_s(clientName, ma_min(sizeof(clientName), maxClientNameSize), (pContext->jack.pClientName != NULL) ? pContext->jack.pClientName : "miniaudio", (size_t)-1); - pClient = ((ma_jack_client_open_proc)pContext->jack.jack_client_open)(clientName, (pContext->jack.tryStartServer) ? 0 : ma_JackNoStartServer, &status, NULL); + pClient = ((ma_jack_client_open_proc)pContext->jack.jack_client_open)(clientName, (pContext->jack.tryStartServer) ? ma_JackNullOption : ma_JackNoStartServer, &status, NULL); if (pClient == NULL) { return MA_FAILED_TO_OPEN_BACKEND_DEVICE; } @@ -36994,7 +38470,7 @@ OSS Backend #define MA_OSS_DEFAULT_DEVICE_NAME "/dev/dsp" -static int ma_open_temp_device__oss() +static int ma_open_temp_device__oss(void) { /* The OSS sample code uses "/dev/mixer" as the device for getting system properties so I'm going to do the same. */ int fd = open("/dev/mixer", O_RDONLY, 0); @@ -37834,25 +39310,30 @@ static void ma_stream_error_callback__aaudio(ma_AAudioStream* pStream, void* pUs (void)error; ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] ERROR CALLBACK: error=%d, AAudioStream_getState()=%d\n", error, ((MA_PFN_AAudioStream_getState)pDevice->pContext->aaudio.AAudioStream_getState)(pStream)); + /* When we get an error, we'll assume that the stream is in an erroneous state and needs to be restarted. From the documentation, we cannot do this from the error callback. Therefore we are going to use an event thread for the AAudio backend to do this cleanly and safely. */ - job = ma_job_init(MA_JOB_TYPE_DEVICE_AAUDIO_REROUTE); - job.data.device.aaudio.reroute.pDevice = pDevice; - - if (pStream == pDevice->aaudio.pStreamCapture) { - job.data.device.aaudio.reroute.deviceType = ma_device_type_capture; + if (ma_atomic_bool32_get(&pDevice->aaudio.isTearingDown)) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Tearing down device.\n"); } else { - job.data.device.aaudio.reroute.deviceType = ma_device_type_playback; - } - - result = ma_device_job_thread_post(&pDevice->pContext->aaudio.jobThread, &job); - if (result != MA_SUCCESS) { - ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Failed to post job for rerouting.\n"); - return; + job = ma_job_init(MA_JOB_TYPE_DEVICE_AAUDIO_REROUTE); + job.data.device.aaudio.reroute.pDevice = pDevice; + + if (pStream == pDevice->aaudio.pStreamCapture) { + job.data.device.aaudio.reroute.deviceType = ma_device_type_capture; + } else { + job.data.device.aaudio.reroute.deviceType = ma_device_type_playback; + } + + result = ma_device_job_thread_post(&pDevice->pContext->aaudio.jobThread, &job); + if (result != MA_SUCCESS) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Device Disconnected. Failed to post job for rerouting.\n"); + return; + } } } @@ -38169,7 +39650,7 @@ static ma_result ma_close_streams__aaudio(ma_device* pDevice) { MA_ASSERT(pDevice != NULL); - /* When re-routing, streams may have been closed and never re-opened. Hence the extra checks below. */ + /* When rerouting, streams may have been closed and never re-opened. Hence the extra checks below. */ if (pDevice->type == ma_device_type_capture || pDevice->type == ma_device_type_duplex) { ma_close_stream__aaudio(pDevice->pContext, (ma_AAudioStream*)pDevice->aaudio.pStreamCapture); pDevice->aaudio.pStreamCapture = NULL; @@ -38186,6 +39667,12 @@ static ma_result ma_device_uninit__aaudio(ma_device* pDevice) { MA_ASSERT(pDevice != NULL); + /* + Note: Closing the streams may cause a timeout error, which would then trigger rerouting in our error callback. + We must not schedule a reroute when device is getting destroyed. + */ + ma_atomic_bool32_set(&pDevice->aaudio.isTearingDown, MA_TRUE); + /* Wait for any rerouting to finish before attempting to close the streams. */ ma_mutex_lock(&pDevice->aaudio.rerouteLock); { @@ -38193,7 +39680,7 @@ static ma_result ma_device_uninit__aaudio(ma_device* pDevice) } ma_mutex_unlock(&pDevice->aaudio.rerouteLock); - /* Destroy re-routing lock. */ + /* Destroy rerouting lock. */ ma_mutex_uninit(&pDevice->aaudio.rerouteLock); return MA_SUCCESS; @@ -38429,17 +39916,22 @@ static ma_result ma_device_stop__aaudio(ma_device* pDevice) static ma_result ma_device_reinit__aaudio(ma_device* pDevice, ma_device_type deviceType) { + const ma_int32 maxAttempts = 4; /* Reasonable retry limit. */ + ma_result result; - int32_t retries = 0; + ma_int32 iAttempt; MA_ASSERT(pDevice != NULL); - /* - TODO: Stop retrying if main thread is about to uninit device. - */ - ma_mutex_lock(&pDevice->aaudio.rerouteLock); - { -error_disconnected: + /* We got disconnected! Retry a few times, until we find a connected device! */ + iAttempt = 0; + while (iAttempt++ < maxAttempts) { + /* Device tearing down? No need to reroute! */ + if (ma_atomic_bool32_get(&pDevice->aaudio.isTearingDown)) { + result = MA_SUCCESS; /* Caller should continue as normal. */ + break; + } + /* The first thing to do is close the streams. */ ma_close_streams__aaudio(pDevice); @@ -38495,14 +39987,16 @@ static ma_result ma_device_reinit__aaudio(ma_device* pDevice, ma_device_type dev result = ma_device_init_streams__aaudio(pDevice, &deviceConfig, &descriptorPlayback, &descriptorCapture); if (result != MA_SUCCESS) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_WARNING, "[AAudio] Failed to create stream after route change."); - goto done; + /* Reroute failed! */ + break; } result = ma_device_post_init(pDevice, deviceType, &descriptorPlayback, &descriptorCapture); if (result != MA_SUCCESS) { ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_WARNING, "[AAudio] Failed to initialize device after route change."); ma_close_streams__aaudio(pDevice); - goto done; + /* Reroute failed! */ + break; } /* We'll only ever do this in response to a reroute. */ @@ -38513,26 +40007,23 @@ static ma_result ma_device_reinit__aaudio(ma_device* pDevice, ma_device_type dev if (pDevice->aaudio.noAutoStartAfterReroute == MA_FALSE) { result = ma_device_start__aaudio(pDevice); if (result != MA_SUCCESS) { - /* We got disconnected! Retry a few times, until we find a connected device! */ - retries += 1; - if (retries <= 3) { - ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, retrying(%d)", retries); - goto error_disconnected; + if (iAttempt < maxAttempts) { + ma_log_postf(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, retrying(%d)", iAttempt); + } else { + ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change, giving up."); } - ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_INFO, "[AAudio] Failed to start stream after route change."); - goto done; } } else { - ma_device_stop(pDevice); /* Do a full device stop so we set internal state correctly. */ + ma_device_stop(pDevice); /* Do a full device stop so we set internal state correctly. */ } } - - result = MA_SUCCESS; - } -done: - /* Re-routing done */ - ma_mutex_unlock(&pDevice->aaudio.rerouteLock); + if (result == MA_SUCCESS) { + /* Reroute successful! */ + break; + } + } + return result; } @@ -38698,7 +40189,7 @@ static ma_result ma_context_init__aaudio(ma_context* pContext, const ma_context_ static ma_result ma_job_process__device__aaudio_reroute(ma_job* pJob) { - ma_result result; + ma_result result = MA_SUCCESS; ma_device* pDevice; MA_ASSERT(pJob != NULL); @@ -38706,19 +40197,22 @@ static ma_result ma_job_process__device__aaudio_reroute(ma_job* pJob) pDevice = (ma_device*)pJob->data.device.aaudio.reroute.pDevice; MA_ASSERT(pDevice != NULL); - /* Here is where we need to reroute the device. To do this we need to uninitialize the stream and reinitialize it. */ - result = ma_device_reinit__aaudio(pDevice, (ma_device_type)pJob->data.device.aaudio.reroute.deviceType); - if (result != MA_SUCCESS) { - /* - Getting here means we failed to reroute the device. The best thing I can think of here is to - just stop the device. - */ - ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[AAudio] Stopping device due to reroute failure."); - ma_device_stop(pDevice); - return result; + ma_mutex_lock(&pDevice->aaudio.rerouteLock); + { + /* Here is where we need to reroute the device. To do this we need to uninitialize the stream and reinitialize it. */ + result = ma_device_reinit__aaudio(pDevice, (ma_device_type)pJob->data.device.aaudio.reroute.deviceType); + if (result != MA_SUCCESS) { + /* + Getting here means we failed to reroute the device. The best thing I can think of here is to + just stop the device. + */ + ma_log_post(ma_device_get_log(pDevice), MA_LOG_LEVEL_ERROR, "[AAudio] Stopping device due to reroute failure."); + ma_device_stop(pDevice); + } } + ma_mutex_unlock(&pDevice->aaudio.rerouteLock); - return MA_SUCCESS; + return result; } #else /* Getting here means there is no AAudio backend so we need a no-op job implementation. */ @@ -40269,8 +41763,11 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const frameCount = pDevice->capture.internalPeriodSizeInFrames; } + /* + If this is called by the device has not yet been started we need to return early, making sure we output silence to + the output buffer. + */ if (ma_device_get_state(pDevice) != ma_device_state_started) { - /* Fill the output buffer with zero to avoid a noise sound */ for (int i = 0; i < outputCount; i += 1) { MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); } @@ -40292,7 +41789,9 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const if (outputCount > 0) { /* If it's a capture-only device, we'll need to output silence. */ if (pDevice->type == ma_device_type_capture) { - MA_ZERO_MEMORY(pOutputs[0].data, frameCount * pDevice->playback.internalChannels * sizeof(float)); + for (int i = 0; i < outputCount; i += 1) { + MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); + } } else { ma_device_process_pcm_frames_playback__webaudio(pDevice, frameCount, pDevice->webaudio.pIntermediaryBuffer); @@ -40302,6 +41801,14 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const pOutputs[0].data[frameCount*iChannel + iFrame] = pDevice->webaudio.pIntermediaryBuffer[iFrame*pDevice->playback.internalChannels + iChannel]; } } + + /* + Just above we output data to the first output buffer. Here we just make sure we're putting silence into any + remaining output buffers. + */ + for (int i = 1; i < outputCount; i += 1) { /* <-- Note that the counter starts at 1 instead of 0. */ + MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); + } } } @@ -40782,8 +42289,8 @@ static ma_result ma_context_uninit__webaudio(ma_context* pContext) /* Remove the global miniaudio object from window if there are no more references to it. */ EM_ASM({ if (typeof(window.miniaudio) !== 'undefined') { - miniaudio.unlock_event_types.map(function(event_type) { - document.removeEventListener(event_type, miniaudio.unlock, true); + window.miniaudio.unlock_event_types.map(function(event_type) { + document.removeEventListener(event_type, window.miniaudio.unlock, true); }); window.miniaudio.referenceCount -= 1; @@ -41236,13 +42743,13 @@ MA_API ma_result ma_device_post_init(ma_device* pDevice, ma_device_type deviceTy static ma_thread_result MA_THREADCALL ma_worker_thread(void* pData) { ma_device* pDevice = (ma_device*)pData; -#ifdef MA_WIN32 +#if defined(MA_WIN32) && !defined(MA_XBOX) HRESULT CoInitializeResult; #endif MA_ASSERT(pDevice != NULL); -#ifdef MA_WIN32 +#if defined(MA_WIN32) && !defined(MA_XBOX) CoInitializeResult = ma_CoInitializeEx(pDevice->pContext, NULL, MA_COINIT_VALUE); #endif @@ -41333,8 +42840,8 @@ static ma_thread_result MA_THREADCALL ma_worker_thread(void* pData) ma_event_signal(&pDevice->stopEvent); } -#ifdef MA_WIN32 - if (CoInitializeResult == S_OK) { +#if defined(MA_WIN32) && !defined(MA_XBOX) + if (CoInitializeResult == S_OK || CoInitializeResult == S_FALSE) { ma_CoUninitialize(pDevice->pContext); } #endif @@ -41358,67 +42865,92 @@ static ma_bool32 ma_device__is_initialized(ma_device* pDevice) static ma_result ma_context_uninit_backend_apis__win32(ma_context* pContext) { /* For some reason UWP complains when CoUninitialize() is called. I'm just not going to call it on UWP. */ -#if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) - if (pContext->win32.CoInitializeResult == S_OK) { - ma_CoUninitialize(pContext); - } + #if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) + { + /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + #if !defined(MA_XBOX) + { + if (pContext->win32.CoInitializeResult == S_OK || pContext->win32.CoInitializeResult == S_FALSE) { + ma_CoUninitialize(pContext); /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + } + } + #endif - #if defined(MA_WIN32_DESKTOP) - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hUser32DLL); - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL); - #endif + #if defined(MA_WIN32_DESKTOP) + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hUser32DLL); + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL); + #endif - ma_dlclose(ma_context_get_log(pContext), pContext->win32.hOle32DLL); -#else - (void)pContext; -#endif + ma_dlclose(ma_context_get_log(pContext), pContext->win32.hOle32DLL); + } + #else + { + (void)pContext; + } + #endif return MA_SUCCESS; } static ma_result ma_context_init_backend_apis__win32(ma_context* pContext) { -#if defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK) - #if defined(MA_WIN32_DESKTOP) - /* User32.dll */ - pContext->win32.hUser32DLL = ma_dlopen(ma_context_get_log(pContext), "user32.dll"); - if (pContext->win32.hUser32DLL == NULL) { - return MA_FAILED_TO_INIT_BACKEND; - } + /* + TODO: Reassess all of this stuff and move everything to the relevant backends. For example, I think + GetForegroundWindow() and GetDesktopWindow() are only used by the DirectSound backend. + */ + #if (defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_GDK)) && !defined(MA_XBOX) + { + #if defined(MA_WIN32_DESKTOP) + { + /* User32.dll */ + pContext->win32.hUser32DLL = ma_dlopen(ma_context_get_log(pContext), "user32.dll"); + if (pContext->win32.hUser32DLL == NULL) { + return MA_FAILED_TO_INIT_BACKEND; + } + + pContext->win32.GetForegroundWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetForegroundWindow"); + pContext->win32.GetDesktopWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetDesktopWindow"); - pContext->win32.GetForegroundWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetForegroundWindow"); - pContext->win32.GetDesktopWindow = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hUser32DLL, "GetDesktopWindow"); + /* Advapi32.dll */ + pContext->win32.hAdvapi32DLL = ma_dlopen(ma_context_get_log(pContext), "advapi32.dll"); + if (pContext->win32.hAdvapi32DLL == NULL) { + return MA_FAILED_TO_INIT_BACKEND; + } + + pContext->win32.RegOpenKeyExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegOpenKeyExA"); + pContext->win32.RegCloseKey = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegCloseKey"); + pContext->win32.RegQueryValueExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegQueryValueExA"); + } + #endif - /* Advapi32.dll */ - pContext->win32.hAdvapi32DLL = ma_dlopen(ma_context_get_log(pContext), "advapi32.dll"); - if (pContext->win32.hAdvapi32DLL == NULL) { + /* Ole32.dll */ + pContext->win32.hOle32DLL = ma_dlopen(ma_context_get_log(pContext), "ole32.dll"); + if (pContext->win32.hOle32DLL == NULL) { return MA_FAILED_TO_INIT_BACKEND; } - pContext->win32.RegOpenKeyExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegOpenKeyExA"); - pContext->win32.RegCloseKey = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegCloseKey"); - pContext->win32.RegQueryValueExA = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hAdvapi32DLL, "RegQueryValueExA"); + pContext->win32.CoInitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitialize"); + pContext->win32.CoInitializeEx = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitializeEx"); + pContext->win32.CoUninitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoUninitialize"); + pContext->win32.CoCreateInstance = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoCreateInstance"); + pContext->win32.CoTaskMemFree = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoTaskMemFree"); + pContext->win32.PropVariantClear = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "PropVariantClear"); + pContext->win32.StringFromGUID2 = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "StringFromGUID2"); + } + #else + { + (void)pContext; /* Unused. */ + } #endif - /* Ole32.dll */ - pContext->win32.hOle32DLL = ma_dlopen(ma_context_get_log(pContext), "ole32.dll"); - if (pContext->win32.hOle32DLL == NULL) { - return MA_FAILED_TO_INIT_BACKEND; + /* TODO: Remove this once the new single threaded backend system is in place in 0.12. */ + #if !defined(MA_XBOX) + { + pContext->win32.CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); } + #endif - pContext->win32.CoInitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitialize"); - pContext->win32.CoInitializeEx = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoInitializeEx"); - pContext->win32.CoUninitialize = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoUninitialize"); - pContext->win32.CoCreateInstance = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoCreateInstance"); - pContext->win32.CoTaskMemFree = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "CoTaskMemFree"); - pContext->win32.PropVariantClear = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "PropVariantClear"); - pContext->win32.StringFromGUID2 = (ma_proc)ma_dlsym(ma_context_get_log(pContext), pContext->win32.hOle32DLL, "StringFromGUID2"); -#else - (void)pContext; /* Unused. */ -#endif - - pContext->win32.CoInitializeResult = ma_CoInitializeEx(pContext, NULL, MA_COINIT_VALUE); return MA_SUCCESS; } #else @@ -44016,7 +45548,7 @@ static MA_INLINE void ma_pcm_s16_to_s32__reference(void* dst, const void* src, m ma_uint64 i; for (i = 0; i < count; i += 1) { - dst_s32[i] = src_s16[i] << 16; + dst_s32[i] = (ma_int32)src_s16[i] << 16; } (void)ditherMode; @@ -49347,15 +50879,15 @@ static /*__attribute__((noinline))*/ ma_result ma_gainer_process_pcm_frames_inte a += d; } } + + pFramesOut = ma_offset_ptr(pFramesOut, interpolatedFrameCount * sizeof(float)); + pFramesIn = ma_offset_ptr(pFramesIn, interpolatedFrameCount * sizeof(float)); } + frameCount -= interpolatedFrameCount; + /* Make sure the timer is updated. */ pGainer->t = (ma_uint32)ma_min(pGainer->t + interpolatedFrameCount, pGainer->config.smoothTimeInFrames); - - /* Adjust our arguments so the next part can work normally. */ - frameCount -= interpolatedFrameCount; - pFramesOut = ma_offset_ptr(pFramesOut, interpolatedFrameCount * sizeof(float)); - pFramesIn = ma_offset_ptr(pFramesIn, interpolatedFrameCount * sizeof(float)); } /* All we need to do here is apply the new gains using an optimized path. */ @@ -50783,13 +52315,16 @@ static float ma_calculate_angular_gain(ma_vec3f dirA, ma_vec3f dirB, float coneI MA_API ma_result ma_spatializer_process_pcm_frames(ma_spatializer* pSpatializer, ma_spatializer_listener* pListener, void* pFramesOut, const void* pFramesIn, ma_uint64 frameCount) { - ma_channel* pChannelMapIn = pSpatializer->pChannelMapIn; - ma_channel* pChannelMapOut = pListener->config.pChannelMapOut; + ma_channel* pChannelMapIn; + ma_channel* pChannelMapOut; - if (pSpatializer == NULL) { + if (pSpatializer == NULL || pListener == NULL) { return MA_INVALID_ARGS; } + pChannelMapIn = pSpatializer->pChannelMapIn; + pChannelMapOut = pListener->config.pChannelMapOut; + /* If we're not spatializing we need to run an optimized path. */ if (ma_atomic_load_i32(&pSpatializer->attenuationModel) == ma_attenuation_model_none) { if (ma_spatializer_listener_is_enabled(pListener)) { @@ -50834,23 +52369,17 @@ MA_API ma_result ma_spatializer_process_pcm_frames(ma_spatializer* pSpatializer, We'll need the listener velocity for doppler pitch calculations. The speed of sound is defined by the listener, so we'll grab that here too. */ - if (pListener != NULL) { - listenerVel = ma_spatializer_listener_get_velocity(pListener); - speedOfSound = pListener->config.speedOfSound; - } else { - listenerVel = ma_vec3f_init_3f(0, 0, 0); - speedOfSound = MA_DEFAULT_SPEED_OF_SOUND; - } + listenerVel = ma_spatializer_listener_get_velocity(pListener); + speedOfSound = pListener->config.speedOfSound; - if (pListener == NULL || ma_spatializer_get_positioning(pSpatializer) == ma_positioning_relative) { - /* There's no listener or we're using relative positioning. */ + if (ma_spatializer_get_positioning(pSpatializer) == ma_positioning_relative) { relativePos = ma_spatializer_get_position(pSpatializer); relativeDir = ma_spatializer_get_direction(pSpatializer); } else { /* - We've found a listener and we're using absolute positioning. We need to transform the - sound's position and direction so that it's relative to listener. Later on we'll use - this for determining the factors to apply to each channel to apply the panning effect. + We're using absolute positioning. We need to transform the sound's position and + direction so that it's relative to listener. Later on we'll use this for determining + the factors to apply to each channel to apply the panning effect. */ ma_spatializer_get_relative_position_and_direction(pSpatializer, pListener, &relativePos, &relativeDir); } @@ -52885,7 +54414,7 @@ static ma_bool32 ma_is_spatial_channel_position(ma_channel channelPosition) return MA_FALSE; } - if (channelPosition >= MA_CHANNEL_AUX_0 && channelPosition <= MA_CHANNEL_AUX_31) { + if (channelPosition >= MA_CHANNEL_AUX_0) { return MA_FALSE; } @@ -56408,8 +57937,12 @@ MA_API size_t ma_channel_map_to_string(const ma_channel* pChannelMap, ma_uint32 } /* Null terminate. Don't increment the length here. */ - if (pBufferOut != NULL && bufferCap > len + 1) { - pBufferOut[len] = '\0'; + if (pBufferOut != NULL) { + if (bufferCap > len) { + pBufferOut[len] = '\0'; + } else if (bufferCap > 0) { + pBufferOut[bufferCap - 1] = '\0'; + } } return len; @@ -56620,7 +58153,7 @@ MA_API ma_result ma_rb_init_ex(size_t subbufferSizeInBytes, size_t subbufferCoun Here is where we allocate our own buffer. We always want to align this to MA_SIMD_ALIGNMENT for future SIMD optimization opportunity. To do this we need to make sure the stride is a multiple of MA_SIMD_ALIGNMENT. */ - pRB->subbufferStrideInBytes = (pRB->subbufferSizeInBytes + (MA_SIMD_ALIGNMENT-1)) & ~MA_SIMD_ALIGNMENT; + pRB->subbufferStrideInBytes = ma_align(pRB->subbufferSizeInBytes, MA_SIMD_ALIGNMENT); bufferSizeInBytes = (size_t)pRB->subbufferCount*pRB->subbufferStrideInBytes; pRB->pBuffer = ma_aligned_malloc(bufferSizeInBytes, MA_SIMD_ALIGNMENT, &pRB->allocationCallbacks); @@ -59515,7 +61048,7 @@ MA_API ma_result ma_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo } -#if !defined(MA_USE_WIN32_FILEIO) && (defined(MA_WIN32) && defined(MA_WIN32_DESKTOP) && !defined(MA_NO_WIN32_FILEIO) && !defined(MA_POSIX)) +#if !defined(MA_USE_WIN32_FILEIO) && (defined(MA_WIN32) && (defined(MA_WIN32_DESKTOP) || defined(MA_WIN32_NXDK)) && !defined(MA_NO_WIN32_FILEIO) && !defined(MA_POSIX)) #define MA_USE_WIN32_FILEIO #endif @@ -59592,25 +61125,34 @@ static ma_result ma_default_vfs_open__win32(ma_vfs* pVFS, const char* pFilePath, static ma_result ma_default_vfs_open_w__win32(ma_vfs* pVFS, const wchar_t* pFilePath, ma_uint32 openMode, ma_vfs_file* pFile) { - HANDLE hFile; - DWORD dwDesiredAccess; - DWORD dwShareMode; - DWORD dwCreationDisposition; + #if !defined(MA_XBOX_NXDK) + { + HANDLE hFile; + DWORD dwDesiredAccess; + DWORD dwShareMode; + DWORD dwCreationDisposition; - (void)pVFS; + (void)pVFS; - /* Load some Win32 symbols dynamically so we can dynamically check for the existence of SetFilePointerEx. */ - ma_win32_fileio_init(); + /* Load some Win32 symbols dynamically so we can dynamically check for the existence of SetFilePointerEx. */ + ma_win32_fileio_init(); - ma_default_vfs__get_open_settings_win32(openMode, &dwDesiredAccess, &dwShareMode, &dwCreationDisposition); + ma_default_vfs__get_open_settings_win32(openMode, &dwDesiredAccess, &dwShareMode, &dwCreationDisposition); - hFile = CreateFileW(pFilePath, dwDesiredAccess, dwShareMode, NULL, dwCreationDisposition, FILE_ATTRIBUTE_NORMAL, NULL); - if (hFile == INVALID_HANDLE_VALUE) { - return ma_result_from_GetLastError(GetLastError()); - } + hFile = CreateFileW(pFilePath, dwDesiredAccess, dwShareMode, NULL, dwCreationDisposition, FILE_ATTRIBUTE_NORMAL, NULL); + if (hFile == INVALID_HANDLE_VALUE) { + return ma_result_from_GetLastError(GetLastError()); + } - *pFile = hFile; - return MA_SUCCESS; + *pFile = hFile; + return MA_SUCCESS; + } + #else + { + /* No CreateFileW() available. */ + return MA_NOT_IMPLEMENTED; + } + #endif } static ma_result ma_default_vfs_close__win32(ma_vfs* pVFS, ma_vfs_file file) @@ -59781,19 +61323,28 @@ static ma_result ma_default_vfs_tell__win32(ma_vfs* pVFS, ma_vfs_file file, ma_i static ma_result ma_default_vfs_info__win32(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo) { - BY_HANDLE_FILE_INFORMATION fi; - BOOL result; - (void)pVFS; - result = GetFileInformationByHandle((HANDLE)file, &fi); - if (result == 0) { - return ma_result_from_GetLastError(GetLastError()); - } + #if !defined(MA_XBOX_NXDK) + { + BY_HANDLE_FILE_INFORMATION fi; + BOOL result; - pInfo->sizeInBytes = ((ma_uint64)fi.nFileSizeHigh << 32) | ((ma_uint64)fi.nFileSizeLow); + result = GetFileInformationByHandle((HANDLE)file, &fi); + if (result == 0) { + return ma_result_from_GetLastError(GetLastError()); + } - return MA_SUCCESS; + pInfo->sizeInBytes = ((ma_uint64)fi.nFileSizeHigh << 32) | ((ma_uint64)fi.nFileSizeLow); + + return MA_SUCCESS; + } + #else + { + /* GetFileInformationByHandle() is unavailable. */ + return MA_NOT_IMPLEMENTED; + } + #endif } #else static ma_result ma_default_vfs_open__stdio(ma_vfs* pVFS, const char* pFilePath, ma_uint32 openMode, ma_vfs_file* pFile) @@ -60131,6 +61682,8 @@ static ma_result ma_default_vfs_tell(ma_vfs* pVFS, ma_vfs_file file, ma_int64* p static ma_result ma_default_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_info* pInfo) { + ma_result result; + if (pInfo == NULL) { return MA_INVALID_ARGS; } @@ -60142,10 +61695,42 @@ static ma_result ma_default_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_inf } #if defined(MA_USE_WIN32_FILEIO) - return ma_default_vfs_info__win32(pVFS, file, pInfo); + result = ma_default_vfs_info__win32(pVFS, file, pInfo); #else - return ma_default_vfs_info__stdio(pVFS, file, pInfo); + result = ma_default_vfs_info__stdio(pVFS, file, pInfo); #endif + + if (result == MA_NOT_IMPLEMENTED) { + /* Not implemented. Fall back to seek/tell/seek. */ + ma_int64 cursor; + ma_int64 sizeInBytes; + + result = ma_default_vfs_tell(pVFS, file, &cursor); + if (result != MA_SUCCESS) { + return result; + } + + result = ma_default_vfs_seek(pVFS, file, 0, ma_seek_origin_end); + if (result != MA_SUCCESS) { + return result; + } + + result = ma_default_vfs_tell(pVFS, file, &sizeInBytes); + if (result != MA_SUCCESS) { + return result; + } + + pInfo->sizeInBytes = sizeInBytes; + + result = ma_default_vfs_seek(pVFS, file, cursor, ma_seek_origin_start); + if (result != MA_SUCCESS) { + return result; + } + + MA_ASSERT(result == MA_SUCCESS); + } + + return result; } @@ -60324,6 +61909,8 @@ Decoding and Encoding Headers. These are auto-generated from a tool. **************************************************************************************************************************************************************/ #if !defined(MA_NO_WAV) && (!defined(MA_NO_DECODING) || !defined(MA_NO_ENCODING)) +#define MA_HAS_WAV + /* dr_wav_h begin */ #ifndef ma_dr_wav_h #define ma_dr_wav_h @@ -60333,8 +61920,8 @@ extern "C" { #define MA_DR_WAV_STRINGIFY(x) #x #define MA_DR_WAV_XSTRINGIFY(x) MA_DR_WAV_STRINGIFY(x) #define MA_DR_WAV_VERSION_MAJOR 0 -#define MA_DR_WAV_VERSION_MINOR 13 -#define MA_DR_WAV_VERSION_REVISION 18 +#define MA_DR_WAV_VERSION_MINOR 14 +#define MA_DR_WAV_VERSION_REVISION 4 #define MA_DR_WAV_VERSION_STRING MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MAJOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MINOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_REVISION) #include <stddef.h> #define MA_DR_WAVE_FORMAT_PCM 0x1 @@ -60350,8 +61937,9 @@ MA_API void ma_dr_wav_version(ma_uint32* pMajor, ma_uint32* pMinor, ma_uint32* p MA_API const char* ma_dr_wav_version_string(void); typedef enum { - ma_dr_wav_seek_origin_start, - ma_dr_wav_seek_origin_current + MA_DR_WAV_SEEK_SET, + MA_DR_WAV_SEEK_CUR, + MA_DR_WAV_SEEK_END } ma_dr_wav_seek_origin; typedef enum { @@ -60388,6 +61976,7 @@ MA_API ma_uint16 ma_dr_wav_fmt_get_format(const ma_dr_wav_fmt* pFMT); typedef size_t (* ma_dr_wav_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef size_t (* ma_dr_wav_write_proc)(void* pUserData, const void* pData, size_t bytesToWrite); typedef ma_bool32 (* ma_dr_wav_seek_proc)(void* pUserData, int offset, ma_dr_wav_seek_origin origin); +typedef ma_bool32 (* ma_dr_wav_tell_proc)(void* pUserData, ma_int64* pCursor); typedef ma_uint64 (* ma_dr_wav_chunk_proc)(void* pChunkUserData, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pReadSeekUserData, const ma_dr_wav_chunk_header* pChunkHeader, ma_dr_wav_container container, const ma_dr_wav_fmt* pFMT); typedef struct { @@ -60432,6 +62021,11 @@ typedef enum ma_dr_wav_metadata_type_list_info_genre = 1 << 15, ma_dr_wav_metadata_type_list_info_album = 1 << 16, ma_dr_wav_metadata_type_list_info_tracknumber = 1 << 17, + ma_dr_wav_metadata_type_list_info_location = 1 << 18, + ma_dr_wav_metadata_type_list_info_organization = 1 << 19, + ma_dr_wav_metadata_type_list_info_keywords = 1 << 20, + ma_dr_wav_metadata_type_list_info_medium = 1 << 21, + ma_dr_wav_metadata_type_list_info_description = 1 << 22, ma_dr_wav_metadata_type_list_all_info_strings = ma_dr_wav_metadata_type_list_info_software | ma_dr_wav_metadata_type_list_info_copyright | ma_dr_wav_metadata_type_list_info_title @@ -60440,7 +62034,12 @@ typedef enum | ma_dr_wav_metadata_type_list_info_date | ma_dr_wav_metadata_type_list_info_genre | ma_dr_wav_metadata_type_list_info_album - | ma_dr_wav_metadata_type_list_info_tracknumber, + | ma_dr_wav_metadata_type_list_info_tracknumber + | ma_dr_wav_metadata_type_list_info_location + | ma_dr_wav_metadata_type_list_info_organization + | ma_dr_wav_metadata_type_list_info_keywords + | ma_dr_wav_metadata_type_list_info_medium + | ma_dr_wav_metadata_type_list_info_description, ma_dr_wav_metadata_type_list_all_adtl = ma_dr_wav_metadata_type_list_label | ma_dr_wav_metadata_type_list_note | ma_dr_wav_metadata_type_list_labelled_cue_region, @@ -60457,8 +62056,8 @@ typedef struct { ma_uint32 cuePointId; ma_uint32 type; - ma_uint32 firstSampleByteOffset; - ma_uint32 lastSampleByteOffset; + ma_uint32 firstSampleOffset; + ma_uint32 lastSampleOffset; ma_uint32 sampleFraction; ma_uint32 playCount; } ma_dr_wav_smpl_loop; @@ -60493,7 +62092,7 @@ typedef struct ma_uint8 dataChunkId[4]; ma_uint32 chunkStart; ma_uint32 blockStart; - ma_uint32 sampleByteOffset; + ma_uint32 sampleOffset; } ma_dr_wav_cue_point; typedef struct { @@ -60595,6 +62194,7 @@ typedef struct ma_dr_wav_read_proc onRead; ma_dr_wav_write_proc onWrite; ma_dr_wav_seek_proc onSeek; + ma_dr_wav_tell_proc onTell; void* pUserData; ma_allocation_callbacks allocationCallbacks; ma_dr_wav_container container; @@ -60637,9 +62237,9 @@ typedef struct ma_bool8 isUnsigned; } aiff; } ma_dr_wav; -MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, ma_dr_wav_chunk_proc onChunk, void* pReadSeekTellUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_dr_wav_write_proc onWrite, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write_sequential(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_uint64 totalSampleCount, ma_dr_wav_write_proc onWrite, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_write_sequential_pcm_frames(ma_dr_wav* pWav, const ma_dr_wav_data_format* pFormat, ma_uint64 totalPCMFrameCount, ma_dr_wav_write_proc onWrite, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); @@ -60711,9 +62311,9 @@ MA_API ma_bool32 ma_dr_wav_init_memory_write(ma_dr_wav* pWav, void** ppData, siz MA_API ma_bool32 ma_dr_wav_init_memory_write_sequential(ma_dr_wav* pWav, void** ppData, size_t* pDataSize, const ma_dr_wav_data_format* pFormat, ma_uint64 totalSampleCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_wav_init_memory_write_sequential_pcm_frames(ma_dr_wav* pWav, void** ppData, size_t* pDataSize, const ma_dr_wav_data_format* pFormat, ma_uint64 totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_WAV_NO_CONVERSION_API -MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_WAV_NO_STDIO MA_API ma_int16* ma_dr_wav_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); MA_API float* ma_dr_wav_open_file_and_read_pcm_frames_f32(const char* filename, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks); @@ -60744,6 +62344,8 @@ MA_API ma_bool32 ma_dr_wav_fourcc_equal(const ma_uint8* a, const char* b); #endif /* MA_NO_WAV */ #if !defined(MA_NO_FLAC) && !defined(MA_NO_DECODING) +#define MA_HAS_FLAC + /* dr_flac_h begin */ #ifndef ma_dr_flac_h #define ma_dr_flac_h @@ -60753,8 +62355,8 @@ extern "C" { #define MA_DR_FLAC_STRINGIFY(x) #x #define MA_DR_FLAC_XSTRINGIFY(x) MA_DR_FLAC_STRINGIFY(x) #define MA_DR_FLAC_VERSION_MAJOR 0 -#define MA_DR_FLAC_VERSION_MINOR 12 -#define MA_DR_FLAC_VERSION_REVISION 43 +#define MA_DR_FLAC_VERSION_MINOR 13 +#define MA_DR_FLAC_VERSION_REVISION 3 #define MA_DR_FLAC_VERSION_STRING MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MAJOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MINOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_REVISION) #include <stddef.h> #if defined(_MSC_VER) && _MSC_VER >= 1700 @@ -60817,8 +62419,9 @@ typedef enum } ma_dr_flac_container; typedef enum { - ma_dr_flac_seek_origin_start, - ma_dr_flac_seek_origin_current + MA_DR_FLAC_SEEK_SET, + MA_DR_FLAC_SEEK_CUR, + MA_DR_FLAC_SEEK_END } ma_dr_flac_seek_origin; typedef struct { @@ -60841,8 +62444,9 @@ typedef struct typedef struct { ma_uint32 type; - const void* pRawData; ma_uint32 rawDataSize; + ma_uint64 rawDataOffset; + const void* pRawData; union { ma_dr_flac_streaminfo streaminfo; @@ -60888,12 +62492,14 @@ typedef struct ma_uint32 colorDepth; ma_uint32 indexColorCount; ma_uint32 pictureDataSize; + ma_uint64 pictureDataOffset; const ma_uint8* pPictureData; } picture; } data; } ma_dr_flac_metadata; typedef size_t (* ma_dr_flac_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef ma_bool32 (* ma_dr_flac_seek_proc)(void* pUserData, int offset, ma_dr_flac_seek_origin origin); +typedef ma_bool32 (* ma_dr_flac_tell_proc)(void* pUserData, ma_int64* pCursor); typedef void (* ma_dr_flac_meta_proc)(void* pUserData, ma_dr_flac_metadata* pMetadata); typedef struct { @@ -60905,6 +62511,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; void* pUserData; size_t unalignedByteCount; ma_dr_flac_cache_t unalignedCache; @@ -60964,10 +62571,10 @@ typedef struct ma_dr_flac_bs bs; ma_uint8 pExtraData[1]; } ma_dr_flac; -MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); MA_API void ma_dr_flac_close(ma_dr_flac* pFlac); MA_API ma_uint64 ma_dr_flac_read_pcm_frames_s32(ma_dr_flac* pFlac, ma_uint64 framesToRead, ma_int32* pBufferOut); MA_API ma_uint64 ma_dr_flac_read_pcm_frames_s16(ma_dr_flac* pFlac, ma_uint64 framesToRead, ma_int16* pBufferOut); @@ -60981,9 +62588,9 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata_w(const wchar_t* pFileName #endif MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_t dataSize, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_FLAC_NO_STDIO MA_API ma_int32* ma_dr_flac_open_file_and_read_pcm_frames_s32(const char* filename, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_int16* ma_dr_flac_open_file_and_read_pcm_frames_s16(const char* filename, unsigned int* channels, unsigned int* sampleRate, ma_uint64* totalPCMFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); @@ -61031,6 +62638,14 @@ MA_API ma_bool32 ma_dr_flac_next_cuesheet_track(ma_dr_flac_cuesheet_track_iterat #endif /* MA_NO_FLAC */ #if !defined(MA_NO_MP3) && !defined(MA_NO_DECODING) +#define MA_HAS_MP3 + +#ifndef MA_DR_MP3_NO_SIMD + #if (defined(MA_NO_NEON) && defined(MA_ARM)) || (defined(MA_NO_SSE2) && (defined(MA_X86) || defined(MA_X64))) + #define MA_DR_MP3_NO_SIMD + #endif +#endif + /* dr_mp3_h begin */ #ifndef ma_dr_mp3_h #define ma_dr_mp3_h @@ -61040,31 +62655,57 @@ extern "C" { #define MA_DR_MP3_STRINGIFY(x) #x #define MA_DR_MP3_XSTRINGIFY(x) MA_DR_MP3_STRINGIFY(x) #define MA_DR_MP3_VERSION_MAJOR 0 -#define MA_DR_MP3_VERSION_MINOR 6 -#define MA_DR_MP3_VERSION_REVISION 40 +#define MA_DR_MP3_VERSION_MINOR 7 +#define MA_DR_MP3_VERSION_REVISION 3 #define MA_DR_MP3_VERSION_STRING MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MAJOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MINOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_REVISION) #include <stddef.h> #define MA_DR_MP3_MAX_PCM_FRAMES_PER_MP3_FRAME 1152 #define MA_DR_MP3_MAX_SAMPLES_PER_FRAME (MA_DR_MP3_MAX_PCM_FRAMES_PER_MP3_FRAME*2) MA_API void ma_dr_mp3_version(ma_uint32* pMajor, ma_uint32* pMinor, ma_uint32* pRevision); MA_API const char* ma_dr_mp3_version_string(void); +#define MA_DR_MP3_MAX_BITRESERVOIR_BYTES 511 +#define MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE 2304 +#define MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE typedef struct { - int frame_bytes, channels, hz, layer, bitrate_kbps; + int frame_bytes, channels, sample_rate, layer, bitrate_kbps; } ma_dr_mp3dec_frame_info; typedef struct +{ + const ma_uint8 *buf; + int pos, limit; +} ma_dr_mp3_bs; +typedef struct +{ + const ma_uint8 *sfbtab; + ma_uint16 part_23_length, big_values, scalefac_compress; + ma_uint8 global_gain, block_type, mixed_block_flag, n_long_sfb, n_short_sfb; + ma_uint8 table_select[3], region_count[3], subblock_gain[3]; + ma_uint8 preflag, scalefac_scale, count1_table, scfsi; +} ma_dr_mp3_L3_gr_info; +typedef struct +{ + ma_dr_mp3_bs bs; + ma_uint8 maindata[MA_DR_MP3_MAX_BITRESERVOIR_BYTES + MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES]; + ma_dr_mp3_L3_gr_info gr_info[4]; + float grbuf[2][576], scf[40], syn[18 + 15][2*32]; + ma_uint8 ist_pos[2][39]; +} ma_dr_mp3dec_scratch; +typedef struct { float mdct_overlap[2][9*32], qmf_state[15*2*32]; int reserv, free_format_bytes; ma_uint8 header[4], reserv_buf[511]; + ma_dr_mp3dec_scratch scratch; } ma_dr_mp3dec; MA_API void ma_dr_mp3dec_init(ma_dr_mp3dec *dec); MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int mp3_bytes, void *pcm, ma_dr_mp3dec_frame_info *info); MA_API void ma_dr_mp3dec_f32_to_s16(const float *in, ma_int16 *out, size_t num_samples); typedef enum { - ma_dr_mp3_seek_origin_start, - ma_dr_mp3_seek_origin_current + MA_DR_MP3_SEEK_SET, + MA_DR_MP3_SEEK_CUR, + MA_DR_MP3_SEEK_END } ma_dr_mp3_seek_origin; typedef struct { @@ -61073,8 +62714,24 @@ typedef struct ma_uint16 mp3FramesToDiscard; ma_uint16 pcmFramesToDiscard; } ma_dr_mp3_seek_point; +typedef enum +{ + MA_DR_MP3_METADATA_TYPE_ID3V1, + MA_DR_MP3_METADATA_TYPE_ID3V2, + MA_DR_MP3_METADATA_TYPE_APE, + MA_DR_MP3_METADATA_TYPE_XING, + MA_DR_MP3_METADATA_TYPE_VBRI +} ma_dr_mp3_metadata_type; +typedef struct +{ + ma_dr_mp3_metadata_type type; + const void* pRawData; + size_t rawDataSize; +} ma_dr_mp3_metadata; typedef size_t (* ma_dr_mp3_read_proc)(void* pUserData, void* pBufferOut, size_t bytesToRead); typedef ma_bool32 (* ma_dr_mp3_seek_proc)(void* pUserData, int offset, ma_dr_mp3_seek_origin origin); +typedef ma_bool32 (* ma_dr_mp3_tell_proc)(void* pUserData, ma_int64* pCursor); +typedef void (* ma_dr_mp3_meta_proc)(void* pUserData, const ma_dr_mp3_metadata* pMetadata); typedef struct { ma_uint32 channels; @@ -61087,7 +62744,9 @@ typedef struct ma_uint32 sampleRate; ma_dr_mp3_read_proc onRead; ma_dr_mp3_seek_proc onSeek; + ma_dr_mp3_meta_proc onMeta; void* pUserData; + void* pUserDataMeta; ma_allocation_callbacks allocationCallbacks; ma_uint32 mp3FrameChannels; ma_uint32 mp3FrameSampleRate; @@ -61096,13 +62755,20 @@ typedef struct ma_uint8 pcmFrames[sizeof(float)*MA_DR_MP3_MAX_SAMPLES_PER_FRAME]; ma_uint64 currentPCMFrame; ma_uint64 streamCursor; + ma_uint64 streamLength; + ma_uint64 streamStartOffset; ma_dr_mp3_seek_point* pSeekPoints; ma_uint32 seekPointCount; + ma_uint32 delayInPCMFrames; + ma_uint32 paddingInPCMFrames; + ma_uint64 totalPCMFrameCount; + ma_bool32 isVBR; + ma_bool32 isCBR; size_t dataSize; size_t dataCapacity; size_t dataConsumed; ma_uint8* pData; - ma_bool32 atEnd : 1; + ma_bool32 atEnd; struct { const ma_uint8* pData; @@ -61110,9 +62776,12 @@ typedef struct size_t currentReadPos; } memory; } ma_dr_mp3; -MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init_memory_with_metadata(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_MP3_NO_STDIO +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata(ma_dr_mp3* pMP3, const char* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks); #endif @@ -61125,8 +62794,8 @@ MA_API ma_uint64 ma_dr_mp3_get_mp3_frame_count(ma_dr_mp3* pMP3); MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint64* pMP3FrameCount, ma_uint64* pPCMFrameCount); MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSeekPointCount, ma_dr_mp3_seek_point* pSeekPoints); MA_API ma_bool32 ma_dr_mp3_bind_seek_table(ma_dr_mp3* pMP3, ma_uint32 seekPointCount, ma_dr_mp3_seek_point* pSeekPoints); -MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); -MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); +MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API float* ma_dr_mp3_open_memory_and_read_pcm_frames_f32(const void* pData, size_t dataSize, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); MA_API ma_int16* ma_dr_mp3_open_memory_and_read_pcm_frames_s16(const void* pData, size_t dataSize, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks); #ifndef MA_DR_MP3_NO_STDIO @@ -61591,7 +63260,6 @@ static ma_result ma_decoder_init_custom_from_memory__internal(const void* pData, /* WAV */ #ifdef ma_dr_wav_h -#define MA_HAS_WAV typedef struct { @@ -61679,8 +63347,10 @@ static ma_bool32 ma_wav_dr_callback__seek(void* pUserData, int offset, ma_dr_wav MA_ASSERT(pWav != NULL); maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_wav_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_WAV_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_WAV_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; } result = pWav->onSeek(pWav->pReadSeekTellUserData, offset, maSeekOrigin); @@ -61690,6 +63360,26 @@ static ma_bool32 ma_wav_dr_callback__seek(void* pUserData, int offset, ma_dr_wav return MA_TRUE; } + +static ma_bool32 ma_wav_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_wav* pWav = (ma_wav*)pUserData; + ma_result result; + + MA_ASSERT(pWav != NULL); + MA_ASSERT(pCursor != NULL); + + if (pWav->onTell == NULL) { + return MA_FALSE; /* Not implemented. */ + } + + result = pWav->onTell(pWav->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; /* Failed to tell. */ + } + + return MA_TRUE; +} #endif static ma_result ma_wav_init_internal(const ma_decoding_backend_config* pConfig, ma_wav* pWav) @@ -61784,7 +63474,7 @@ MA_API ma_result ma_wav_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_p { ma_bool32 wavResult; - wavResult = ma_dr_wav_init(&pWav->dr, ma_wav_dr_callback__read, ma_wav_dr_callback__seek, pWav, pAllocationCallbacks); + wavResult = ma_dr_wav_init(&pWav->dr, ma_wav_dr_callback__read, ma_wav_dr_callback__seek, ma_wav_dr_callback__tell, pWav, pAllocationCallbacks); if (wavResult != MA_TRUE) { return MA_INVALID_FILE; } @@ -62275,7 +63965,6 @@ static ma_result ma_decoder_init_wav_from_memory__internal(const void* pData, si /* FLAC */ #ifdef ma_dr_flac_h -#define MA_HAS_FLAC typedef struct { @@ -62363,8 +64052,10 @@ static ma_bool32 ma_flac_dr_callback__seek(void* pUserData, int offset, ma_dr_fl MA_ASSERT(pFlac != NULL); maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_flac_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_FLAC_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_FLAC_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; } result = pFlac->onSeek(pFlac->pReadSeekTellUserData, offset, maSeekOrigin); @@ -62374,6 +64065,26 @@ static ma_bool32 ma_flac_dr_callback__seek(void* pUserData, int offset, ma_dr_fl return MA_TRUE; } + +static ma_bool32 ma_flac_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_flac* pFlac = (ma_flac*)pUserData; + ma_result result; + + MA_ASSERT(pFlac != NULL); + MA_ASSERT(pCursor != NULL); + + if (pFlac->onTell == NULL) { + return MA_FALSE; /* Not implemented. */ + } + + result = pFlac->onTell(pFlac->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; /* Failed to tell. */ + } + + return MA_TRUE; +} #endif static ma_result ma_flac_init_internal(const ma_decoding_backend_config* pConfig, ma_flac* pFlac) @@ -62425,7 +64136,7 @@ MA_API ma_result ma_flac_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_ #if !defined(MA_NO_FLAC) { - pFlac->dr = ma_dr_flac_open(ma_flac_dr_callback__read, ma_flac_dr_callback__seek, pFlac, pAllocationCallbacks); + pFlac->dr = ma_dr_flac_open(ma_flac_dr_callback__read, ma_flac_dr_callback__seek, ma_flac_dr_callback__tell, pFlac, pAllocationCallbacks); if (pFlac->dr == NULL) { return MA_INVALID_FILE; } @@ -62897,7 +64608,6 @@ static ma_result ma_decoder_init_flac_from_memory__internal(const void* pData, s /* MP3 */ #ifdef ma_dr_mp3_h -#define MA_HAS_MP3 typedef struct { @@ -62986,9 +64696,12 @@ static ma_bool32 ma_mp3_dr_callback__seek(void* pUserData, int offset, ma_dr_mp3 MA_ASSERT(pMP3 != NULL); - maSeekOrigin = ma_seek_origin_start; - if (origin == ma_dr_mp3_seek_origin_current) { - maSeekOrigin = ma_seek_origin_current; + if (origin == MA_DR_MP3_SEEK_SET) { + maSeekOrigin = ma_seek_origin_start; + } else if (origin == MA_DR_MP3_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; + } else { + maSeekOrigin = ma_seek_origin_current; } result = pMP3->onSeek(pMP3->pReadSeekTellUserData, offset, maSeekOrigin); @@ -62998,6 +64711,21 @@ static ma_bool32 ma_mp3_dr_callback__seek(void* pUserData, int offset, ma_dr_mp3 return MA_TRUE; } + +static ma_bool32 ma_mp3_dr_callback__tell(void* pUserData, ma_int64* pCursor) +{ + ma_mp3* pMP3 = (ma_mp3*)pUserData; + ma_result result; + + MA_ASSERT(pMP3 != NULL); + + result = pMP3->onTell(pMP3->pReadSeekTellUserData, pCursor); + if (result != MA_SUCCESS) { + return MA_FALSE; + } + + return MA_TRUE; +} #endif static ma_result ma_mp3_init_internal(const ma_decoding_backend_config* pConfig, ma_mp3* pMP3) @@ -63098,7 +64826,7 @@ MA_API ma_result ma_mp3_init(ma_read_proc onRead, ma_seek_proc onSeek, ma_tell_p { ma_bool32 mp3Result; - mp3Result = ma_dr_mp3_init(&pMP3->dr, ma_mp3_dr_callback__read, ma_mp3_dr_callback__seek, pMP3, pAllocationCallbacks); + mp3Result = ma_dr_mp3_init(&pMP3->dr, ma_mp3_dr_callback__read, ma_mp3_dr_callback__seek, ma_mp3_dr_callback__tell, NULL, pMP3, pAllocationCallbacks); if (mp3Result != MA_TRUE) { return MA_INVALID_FILE; } @@ -64557,11 +66285,9 @@ static ma_result ma_decoder_init__internal(ma_decoder_read_proc onRead, ma_decod We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(pConfig, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(pConfig, pDecoder); - if (result != MA_SUCCESS) { - onSeek(pDecoder, 0, ma_seek_origin_start); - } + onSeek(pDecoder, 0, ma_seek_origin_start); } /* @@ -64825,14 +66551,6 @@ MA_API ma_result ma_decoder_init_memory(const void* pData, size_t dataSize, cons /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. - */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -64997,14 +66715,16 @@ static ma_bool32 ma_path_extension_equal_w(const wchar_t* path, const wchar_t* e ext1 = extension; ext2 = ma_path_extension_w(path); -#if defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__) - return _wcsicmp(ext1, ext2) == 0; -#else - /* - I'm not aware of a wide character version of strcasecmp(). I'm therefore converting the extensions to multibyte strings and comparing those. This - isn't the most efficient way to do it, but it should work OK. - */ + #if (defined(_MSC_VER) || defined(__WATCOMC__) || defined(__DMC__)) && !defined(MA_XBOX_NXDK) + { + return _wcsicmp(ext1, ext2) == 0; + } + #elif !defined(MA_XBOX_NXDK) && !defined(MA_DOS) { + /* + I'm not aware of a wide character version of strcasecmp(). I'm therefore converting the extensions to multibyte strings and comparing those. This + isn't the most efficient way to do it, but it should work OK. + */ char ext1MB[4096]; char ext2MB[4096]; const wchar_t* pext1 = ext1; @@ -65024,7 +66744,13 @@ static ma_bool32 ma_path_extension_equal_w(const wchar_t* path, const wchar_t* e return strcasecmp(ext1MB, ext2MB) == 0; } -#endif + #else + { + /* Getting here means we don't have a way to do a case-sensitive comparison for wide strings. Fall back to a simple case-sensitive comparison. */ + /* TODO: Implement our own wchar_t-to-char conversion routine and then use the char* version for comparing. */ + return ma_wcscmp(ext1, ext2) == 0; + } + #endif } #endif /* MA_HAS_PATH_API */ @@ -65125,11 +66851,9 @@ MA_API ma_result ma_decoder_init_vfs(ma_vfs* pVFS, const char* pFilePath, const We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(&config, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(&config, pDecoder); - if (result != MA_SUCCESS) { - ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); - } + ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); } /* @@ -65258,11 +66982,9 @@ MA_API ma_result ma_decoder_init_vfs_w(ma_vfs* pVFS, const wchar_t* pFilePath, c We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(&config, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(&config, pDecoder); - if (result != MA_SUCCESS) { - ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); - } + ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); } /* @@ -65444,14 +67166,6 @@ MA_API ma_result ma_decoder_init_file(const char* pFilePath, const ma_decoder_co /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. - */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -65594,14 +67308,6 @@ MA_API ma_result ma_decoder_init_file_w(const wchar_t* pFilePath, const ma_decod /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. - */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -66119,10 +67825,18 @@ static ma_bool32 ma_encoder__internal_on_seek_wav(void* pUserData, int offset, m { ma_encoder* pEncoder = (ma_encoder*)pUserData; ma_result result; + ma_seek_origin maSeekOrigin; MA_ASSERT(pEncoder != NULL); - result = pEncoder->onSeek(pEncoder, offset, (origin == ma_dr_wav_seek_origin_start) ? ma_seek_origin_start : ma_seek_origin_current); + maSeekOrigin = ma_seek_origin_start; + if (origin == MA_DR_WAV_SEEK_CUR) { + maSeekOrigin = ma_seek_origin_current; + } else if (origin == MA_DR_WAV_SEEK_END) { + maSeekOrigin = ma_seek_origin_end; + } + + result = pEncoder->onSeek(pEncoder, offset, maSeekOrigin); if (result != MA_SUCCESS) { return MA_FALSE; } else { @@ -67644,7 +69358,7 @@ static MA_INLINE ma_uint32 ma_hash_getblock(const ma_uint32* blocks, int i) ma_uint32 block; /* Try silencing a sanitization warning about unaligned access by doing a memcpy() instead of assignment. */ - MA_COPY_MEMORY(&block, ma_offset_ptr(blocks, i * sizeof(block)), sizeof(block)); + MA_COPY_MEMORY(&block, ma_offset_ptr(blocks, i * (int) sizeof(block)), sizeof(block)); if (ma_is_little_endian()) { return block; @@ -67720,7 +69434,7 @@ static ma_uint32 ma_hash_string_32(const char* str) static ma_uint32 ma_hash_string_w_32(const wchar_t* str) { - return ma_hash_32(str, (int)wcslen(str) * sizeof(*str), MA_DEFAULT_HASH_SEED); + return ma_hash_32(str, (int)ma_wcslen(str) * sizeof(*str), MA_DEFAULT_HASH_SEED); } @@ -67880,6 +69594,7 @@ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_ return ma_resource_manager_data_buffer_node_find_min(pDataBufferNode->pChildHi); } +#if 0 /* Currently unused, but might make use of this later. */ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_buffer_node_find_inorder_predecessor(ma_resource_manager_data_buffer_node* pDataBufferNode) { MA_ASSERT(pDataBufferNode != NULL); @@ -67887,6 +69602,7 @@ static MA_INLINE ma_resource_manager_data_buffer_node* ma_resource_manager_data_ return ma_resource_manager_data_buffer_node_find_max(pDataBufferNode->pChildLo); } +#endif static ma_result ma_resource_manager_data_buffer_node_remove(ma_resource_manager* pResourceManager, ma_resource_manager_data_buffer_node* pDataBufferNode) { @@ -68237,6 +69953,7 @@ MA_API ma_resource_manager_config ma_resource_manager_config_init(void) config.decodedSampleRate = 0; config.jobThreadCount = 1; /* A single miniaudio-managed job thread by default. */ config.jobQueueCapacity = MA_JOB_TYPE_RESOURCE_MANAGER_QUEUE_CAPACITY; + config.resampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear); /* Format/channels/rate doesn't matter here. */ /* Flags. */ config.flags = 0; @@ -68490,6 +70207,7 @@ static ma_decoder_config ma_resource_manager__init_decoder_config(ma_resource_ma config.ppCustomBackendVTables = pResourceManager->config.ppCustomDecodingBackendVTables; config.customBackendCount = pResourceManager->config.customDecodingBackendCount; config.pCustomBackendUserData = pResourceManager->config.pCustomDecodingBackendUserData; + config.resampling = pResourceManager->config.resampling; return config; } @@ -69009,16 +70727,19 @@ static ma_result ma_resource_manager_data_buffer_node_acquire_critical_section(m /* Failed to post job. Probably ran out of memory. */ ma_log_postf(ma_resource_manager_get_log(pResourceManager), MA_LOG_LEVEL_ERROR, "Failed to post MA_JOB_TYPE_RESOURCE_MANAGER_LOAD_DATA_BUFFER_NODE job. %s.\n", ma_result_description(result)); - /* - Fences were acquired before posting the job, but since the job was not able to - be posted, we need to make sure we release them so nothing gets stuck waiting. - */ - if (pInitFence != NULL) { ma_fence_release(pInitFence); } - if (pDoneFence != NULL) { ma_fence_release(pDoneFence); } - if ((flags & MA_RESOURCE_MANAGER_DATA_SOURCE_FLAG_WAIT_INIT) != 0) { ma_resource_manager_inline_notification_uninit(pInitNotification); } else { + /* + Fences were acquired before posting the job, but since the job was not able to + be posted, we need to make sure we release them so nothing gets stuck waiting. + + In the WAIT_INIT case, these will have already been released in ma_job_process() + so we should only release fences in this branch. + */ + if (pInitFence != NULL) { ma_fence_release(pInitFence); } + if (pDoneFence != NULL) { ma_fence_release(pDoneFence); } + /* These will have been freed by the job thread, but with WAIT_INIT they will already have happened since the job has already been handled. */ ma_free(pFilePathCopy, &pResourceManager->config.allocationCallbacks); ma_free(pFilePathWCopy, &pResourceManager->config.allocationCallbacks); @@ -69812,13 +71533,13 @@ MA_API ma_result ma_resource_manager_data_buffer_get_data_format(ma_resource_man MA_API ma_result ma_resource_manager_data_buffer_get_cursor_in_pcm_frames(ma_resource_manager_data_buffer* pDataBuffer, ma_uint64* pCursor) { - /* We cannot be using the data source after it's been uninitialized. */ - MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); - if (pDataBuffer == NULL || pCursor == NULL) { return MA_INVALID_ARGS; } + /* We cannot be using the data source after it's been uninitialized. */ + MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); + *pCursor = 0; switch (ma_resource_manager_data_buffer_node_get_data_supply_type(pDataBuffer->pNode)) @@ -69852,13 +71573,13 @@ MA_API ma_result ma_resource_manager_data_buffer_get_cursor_in_pcm_frames(ma_res MA_API ma_result ma_resource_manager_data_buffer_get_length_in_pcm_frames(ma_resource_manager_data_buffer* pDataBuffer, ma_uint64* pLength) { - /* We cannot be using the data source after it's been uninitialized. */ - MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); - if (pDataBuffer == NULL || pLength == NULL) { return MA_INVALID_ARGS; } + /* We cannot be using the data source after it's been uninitialized. */ + MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); + if (ma_resource_manager_data_buffer_node_get_data_supply_type(pDataBuffer->pNode) == ma_resource_manager_data_supply_type_unknown) { return MA_BUSY; /* Still loading. */ } @@ -71213,8 +72934,6 @@ static ma_result ma_job_process__resource_manager__free_data_buffer_node(ma_job* return ma_resource_manager_post_job(pResourceManager, pJob); /* Out of order. */ } - ma_resource_manager_data_buffer_node_free(pResourceManager, pDataBufferNode); - /* The event needs to be signalled last. */ if (pJob->data.resourceManager.freeDataBufferNode.pDoneNotification != NULL) { ma_async_notification_signal(pJob->data.resourceManager.freeDataBufferNode.pDoneNotification); @@ -71225,6 +72944,9 @@ static ma_result ma_job_process__resource_manager__free_data_buffer_node(ma_job* } ma_atomic_fetch_add_32(&pDataBufferNode->executionPointer, 1); + + ma_resource_manager_data_buffer_node_free(pResourceManager, pDataBufferNode); + return MA_SUCCESS; } @@ -72097,6 +73819,15 @@ MA_API ma_result ma_node_graph_set_time(ma_node_graph* pNodeGraph, ma_uint64 glo return ma_node_set_time(&pNodeGraph->endpoint, globalTime); /* Global time is just the local time of the endpoint. */ } +MA_API ma_uint32 ma_node_graph_get_processing_size_in_frames(const ma_node_graph* pNodeGraph) +{ + if (pNodeGraph == NULL) { + return 0; + } + + return pNodeGraph->processingSizeInFrames; +} + #define MA_NODE_OUTPUT_BUS_FLAG_HAS_READ 0x01 /* Whether or not this bus ready to read more data. Only used on nodes with multiple output buses. */ @@ -73256,12 +74987,12 @@ MA_API ma_node_state ma_node_get_state_by_time_range(const ma_node* pNode, ma_ui its start time not having been reached yet. Also, the stop time may have also been reached in which case it'll be considered stopped. */ - if (ma_node_get_state_time(pNode, ma_node_state_started) > globalTimeBeg) { - return ma_node_state_stopped; /* Start time has not yet been reached. */ + if (ma_node_get_state_time(pNode, ma_node_state_stopped) < globalTimeBeg) { + return ma_node_state_stopped; /* End time is before the start of the range. */ } - if (ma_node_get_state_time(pNode, ma_node_state_stopped) <= globalTimeEnd) { - return ma_node_state_stopped; /* Stop time has been reached. */ + if (ma_node_get_state_time(pNode, ma_node_state_started) > globalTimeEnd) { + return ma_node_state_stopped; /* Start time is after the end of the range. */ } /* Getting here means the node is marked as started and is within its start/stop times. */ @@ -73341,14 +75072,14 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde return MA_INVALID_ARGS; /* Invalid output bus index. */ } + globalTimeBeg = globalTime; + globalTimeEnd = globalTime + frameCount; + /* Don't do anything if we're in a stopped state. */ - if (ma_node_get_state_by_time_range(pNode, globalTime, globalTime + frameCount) != ma_node_state_started) { + if (ma_node_get_state_by_time_range(pNode, globalTimeBeg, globalTimeEnd) != ma_node_state_started) { return MA_SUCCESS; /* We're in a stopped state. This is not an error - we just need to not read anything. */ } - - globalTimeBeg = globalTime; - globalTimeEnd = globalTime + frameCount; startTime = ma_node_get_state_time(pNode, ma_node_state_started); stopTime = ma_node_get_state_time(pNode, ma_node_state_stopped); @@ -73361,11 +75092,16 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde therefore need to offset it by a number of frames to accommodate. The same thing applies for the stop time. */ - timeOffsetBeg = (globalTimeBeg < startTime) ? (ma_uint32)(globalTimeEnd - startTime) : 0; + timeOffsetBeg = (globalTimeBeg < startTime) ? (ma_uint32)(startTime - globalTimeBeg) : 0; timeOffsetEnd = (globalTimeEnd > stopTime) ? (ma_uint32)(globalTimeEnd - stopTime) : 0; /* Trim based on the start offset. We need to silence the start of the buffer. */ if (timeOffsetBeg > 0) { + MA_ASSERT(timeOffsetBeg <= frameCount); + if (timeOffsetBeg > frameCount) { + timeOffsetBeg = frameCount; + } + ma_silence_pcm_frames(pFramesOut, timeOffsetBeg, ma_format_f32, ma_node_get_output_channels(pNode, outputBusIndex)); pFramesOut += timeOffsetBeg * ma_node_get_output_channels(pNode, outputBusIndex); frameCount -= timeOffsetBeg; @@ -73373,6 +75109,11 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde /* Trim based on the end offset. We don't need to silence the tail section because we'll just have a reduced value written to pFramesRead. */ if (timeOffsetEnd > 0) { + MA_ASSERT(timeOffsetEnd <= frameCount); + if (timeOffsetEnd > frameCount) { + timeOffsetEnd = frameCount; + } + frameCount -= timeOffsetEnd; } @@ -74787,12 +76528,20 @@ static void ma_sound_set_at_end(ma_sound* pSound, ma_bool32 atEnd) MA_ASSERT(pSound != NULL); ma_atomic_exchange_32(&pSound->atEnd, atEnd); + /* + When this function is called the state of the sound will not yet be in a stopped state. This makes it confusing + because an end callback will intuitively expect ma_sound_is_playing() to return false from inside the callback. + I'm therefore no longer firing the callback here and will instead fire it manually in the *next* processing step + when the state should be set to stopped as expected. + */ + #if 0 /* Fire any callbacks or events. */ if (atEnd) { if (pSound->endCallback != NULL) { pSound->endCallback(pSound->pEndCallbackUserData, pSound); } } + #endif } static ma_bool32 ma_sound_get_at_end(const ma_sound* pSound) @@ -74812,6 +76561,7 @@ MA_API ma_engine_node_config ma_engine_node_config_init(ma_engine* pEngine, ma_e config.isPitchDisabled = (flags & MA_SOUND_FLAG_NO_PITCH) != 0; config.isSpatializationDisabled = (flags & MA_SOUND_FLAG_NO_SPATIALIZATION) != 0; config.monoExpansionMode = pEngine->monoExpansionMode; + config.resampling = pEngine->pitchResamplingConfig; return config; } @@ -74838,7 +76588,7 @@ static void ma_engine_node_update_pitch_if_required(ma_engine_node* pEngineNode) if (isUpdateRequired) { float basePitch = (float)pEngineNode->sampleRate / ma_engine_get_sample_rate(pEngineNode->pEngine); - ma_linear_resampler_set_rate_ratio(&pEngineNode->resampler, basePitch * pEngineNode->oldPitch * pEngineNode->oldDopplerPitch); + ma_resampler_set_rate_ratio(&pEngineNode->resampler, basePitch * pEngineNode->oldPitch * pEngineNode->oldDopplerPitch); } } @@ -74857,22 +76607,6 @@ static ma_bool32 ma_engine_node_is_spatialization_enabled(const ma_engine_node* return !ma_atomic_load_explicit_32(&pEngineNode->isSpatializationDisabled, ma_atomic_memory_order_acquire); } -static ma_uint64 ma_engine_node_get_required_input_frame_count(const ma_engine_node* pEngineNode, ma_uint64 outputFrameCount) -{ - ma_uint64 inputFrameCount = 0; - - if (ma_engine_node_is_pitching_enabled(pEngineNode)) { - ma_result result = ma_linear_resampler_get_required_input_frame_count(&pEngineNode->resampler, outputFrameCount, &inputFrameCount); - if (result != MA_SUCCESS) { - inputFrameCount = 0; - } - } else { - inputFrameCount = outputFrameCount; /* No resampling, so 1:1. */ - } - - return inputFrameCount; -} - static ma_result ma_engine_node_set_volume(ma_engine_node* pEngineNode, float volume) { if (pEngineNode == NULL) { @@ -75014,7 +76748,7 @@ static void ma_engine_node_process_pcm_frames__general(ma_engine_node* pEngineNo ma_uint64 resampleFrameCountIn = framesAvailableIn; ma_uint64 resampleFrameCountOut = framesAvailableOut; - ma_linear_resampler_process_pcm_frames(&pEngineNode->resampler, pRunningFramesIn, &resampleFrameCountIn, pWorkingBuffer, &resampleFrameCountOut); + ma_resampler_process_pcm_frames(&pEngineNode->resampler, pRunningFramesIn, &resampleFrameCountIn, pWorkingBuffer, &resampleFrameCountOut); isWorkingBufferValid = MA_TRUE; framesJustProcessedIn = (ma_uint32)resampleFrameCountIn; @@ -75138,6 +76872,11 @@ static void ma_engine_node_process_pcm_frames__sound(ma_node* pNode, const float /* If we're marked at the end we need to stop the sound and do nothing. */ if (ma_sound_at_end(pSound)) { ma_sound_stop(pSound); + + if (pSound->endCallback != NULL) { + pSound->endCallback(pSound->pEndCallbackUserData, pSound); + } + *pFrameCountOut = 0; return; } @@ -75175,55 +76914,74 @@ static void ma_engine_node_process_pcm_frames__sound(ma_node* pNode, const float /* Keep reading until we've read as much as was requested or we reach the end of the data source. */ while (totalFramesRead < frameCount) { ma_uint32 framesRemaining = frameCount - totalFramesRead; - ma_uint32 framesToRead; ma_uint64 framesJustRead; ma_uint32 frameCountIn; ma_uint32 frameCountOut; const float* pRunningFramesIn; float* pRunningFramesOut; - /* - The first thing we need to do is read into the temporary buffer. We can calculate exactly - how many input frames we'll need after resampling. - */ - framesToRead = (ma_uint32)ma_engine_node_get_required_input_frame_count(&pSound->engineNode, framesRemaining); - if (framesToRead > tempCapInFrames) { - framesToRead = tempCapInFrames; - } + /* If there's any input frames sitting in the cache get those processed first. */ + if (pSound->processingCacheFramesRemaining > 0) { + pRunningFramesIn = pSound->pProcessingCache; + frameCountIn = pSound->processingCacheFramesRemaining; - result = ma_data_source_read_pcm_frames(pSound->pDataSource, temp, framesToRead, &framesJustRead); + pRunningFramesOut = ma_offset_pcm_frames_ptr_f32(ppFramesOut[0], totalFramesRead, ma_node_get_output_channels(pNode, 0)); + frameCountOut = framesRemaining; - /* If we reached the end of the sound we'll want to mark it as at the end and stop it. This should never be returned for looping sounds. */ - if (result == MA_AT_END) { - ma_sound_set_at_end(pSound, MA_TRUE); /* This will be set to false in ma_sound_start(). */ - } + ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut); - pRunningFramesOut = ma_offset_pcm_frames_ptr_f32(ppFramesOut[0], totalFramesRead, ma_node_get_output_channels(pNode, 0)); + MA_ASSERT(frameCountIn <= pSound->processingCacheFramesRemaining); + pSound->processingCacheFramesRemaining -= frameCountIn; - frameCountIn = (ma_uint32)framesJustRead; - frameCountOut = framesRemaining; + /* Move any remaining data in the cache down. */ + if (pSound->processingCacheFramesRemaining > 0) { + MA_MOVE_MEMORY(pSound->pProcessingCache, ma_offset_pcm_frames_ptr_f32(pSound->pProcessingCache, frameCountIn, dataSourceChannels), pSound->processingCacheFramesRemaining * ma_get_bytes_per_frame(ma_format_f32, dataSourceChannels)); + } + + totalFramesRead += (ma_uint32)frameCountOut; /* Safe cast. */ - /* Convert if necessary. */ - if (dataSourceFormat == ma_format_f32) { - /* Fast path. No data conversion necessary. */ - pRunningFramesIn = (float*)temp; - ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut); + if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { + break; /* Might have reached the end. */ + } } else { - /* Slow path. Need to do sample format conversion to f32. If we give the f32 buffer the same count as the first temp buffer, we're guaranteed it'll be large enough. */ - float tempf32[MA_DATA_CONVERTER_STACK_BUFFER_SIZE]; /* Do not do `MA_DATA_CONVERTER_STACK_BUFFER_SIZE/sizeof(float)` here like we've done in other places. */ - ma_convert_pcm_frames_format(tempf32, ma_format_f32, temp, dataSourceFormat, framesJustRead, dataSourceChannels, ma_dither_mode_none); + /* Getting here means there's nothing in the cache. Read more data from the data source. */ + if (dataSourceFormat == ma_format_f32) { + /* Fast path. No conversion to f32 necessary. */ + result = ma_data_source_read_pcm_frames(pSound->pDataSource, pSound->pProcessingCache, pSound->processingCacheCap, &framesJustRead); + } else { + /* Slow path. Need to convert to f32. */ + ma_uint64 totalFramesConverted = 0; + + while (totalFramesConverted < pSound->processingCacheCap) { + ma_uint64 framesConverted; + ma_uint32 framesToConvertThisIteration = pSound->processingCacheCap - (ma_uint32)totalFramesConverted; + if (framesToConvertThisIteration > tempCapInFrames) { + framesToConvertThisIteration = tempCapInFrames; + } - /* Now that we have our samples in f32 format we can process like normal. */ - pRunningFramesIn = tempf32; - ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut); - } + result = ma_data_source_read_pcm_frames(pSound->pDataSource, temp, framesToConvertThisIteration, &framesConverted); + if (result != MA_SUCCESS) { + break; + } - /* We should have processed all of our input frames since we calculated the required number of input frames at the top. */ - MA_ASSERT(frameCountIn == framesJustRead); - totalFramesRead += (ma_uint32)frameCountOut; /* Safe cast. */ + ma_convert_pcm_frames_format(ma_offset_pcm_frames_ptr_f32(pSound->pProcessingCache, totalFramesConverted, dataSourceChannels), ma_format_f32, temp, dataSourceFormat, framesConverted, dataSourceChannels, ma_dither_mode_none); + totalFramesConverted += framesConverted; + } - if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { - break; /* Might have reached the end. */ + framesJustRead = totalFramesConverted; + } + + MA_ASSERT(framesJustRead <= pSound->processingCacheCap); + pSound->processingCacheFramesRemaining = (ma_uint32)framesJustRead; + + /* If we reached the end of the sound we'll want to mark it as at the end and stop it. This should never be returned for looping sounds. */ + if (result == MA_AT_END) { + ma_sound_set_at_end(pSound, MA_TRUE); /* This will be set to false in ma_sound_start(). */ + } + + if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { + break; + } } } } @@ -75246,25 +77004,6 @@ static void ma_engine_node_process_pcm_frames__group(ma_node* pNode, const float ma_engine_node_process_pcm_frames__general((ma_engine_node*)pNode, ppFramesIn, pFrameCountIn, ppFramesOut, pFrameCountOut); } -static ma_result ma_engine_node_get_required_input_frame_count__group(ma_node* pNode, ma_uint32 outputFrameCount, ma_uint32* pInputFrameCount) -{ - ma_uint64 inputFrameCount; - - MA_ASSERT(pInputFrameCount != NULL); - - /* Our pitch will affect this calculation. We need to update it. */ - ma_engine_node_update_pitch_if_required((ma_engine_node*)pNode); - - inputFrameCount = ma_engine_node_get_required_input_frame_count((ma_engine_node*)pNode, outputFrameCount); - if (inputFrameCount > 0xFFFFFFFF) { - inputFrameCount = 0xFFFFFFFF; /* Will never happen because miniaudio will only ever process in relatively small chunks. */ - } - - *pInputFrameCount = (ma_uint32)inputFrameCount; - - return MA_SUCCESS; -} - static ma_node_vtable g_ma_engine_node_vtable__sound = { @@ -75278,7 +77017,7 @@ static ma_node_vtable g_ma_engine_node_vtable__sound = static ma_node_vtable g_ma_engine_node_vtable__group = { ma_engine_node_process_pcm_frames__group, - ma_engine_node_get_required_input_frame_count__group, + NULL, /* onGetRequiredInputFrameCount */ 1, /* Groups have one input bus. */ 1, /* Groups have one output bus. */ MA_NODE_FLAG_DIFFERENT_PROCESSING_RATES /* The engine node does resampling so should let miniaudio know about it. */ @@ -75324,9 +77063,10 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo ma_result result; size_t tempHeapSize; ma_node_config baseNodeConfig; - ma_linear_resampler_config resamplerConfig; + ma_resampler_config resamplerConfig; ma_spatializer_config spatializerConfig; ma_gainer_config gainerConfig; + ma_uint32 sampleRate; ma_uint32 channelsIn; ma_uint32 channelsOut; ma_channel defaultStereoChannelMap[2] = {MA_CHANNEL_SIDE_LEFT, MA_CHANNEL_SIDE_RIGHT}; /* <-- Consistent with the default channel map of a stereo listener. Means channel conversion can run on a fast path. */ @@ -75345,6 +77085,7 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo pHeapLayout->sizeInBytes = 0; + sampleRate = (pConfig->sampleRate > 0) ? pConfig->sampleRate : ma_engine_get_sample_rate(pConfig->pEngine); channelsIn = (pConfig->channelsIn != 0) ? pConfig->channelsIn : ma_engine_get_channels(pConfig->pEngine); channelsOut = (pConfig->channelsOut != 0) ? pConfig->channelsOut : ma_engine_get_channels(pConfig->pEngine); @@ -75364,10 +77105,13 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo /* Resmapler. */ - resamplerConfig = ma_linear_resampler_config_init(ma_format_f32, channelsIn, 1, 1); /* Input and output sample rates don't affect the calculation of the heap size. */ - resamplerConfig.lpfOrder = 0; + resamplerConfig = pConfig->resampling; + resamplerConfig.format = ma_format_f32; + resamplerConfig.channels = channelsIn; + resamplerConfig.sampleRateIn = sampleRate; + resamplerConfig.sampleRateOut = ma_engine_get_sample_rate(pConfig->pEngine); - result = ma_linear_resampler_get_heap_size(&resamplerConfig, &tempHeapSize); + result = ma_resampler_get_heap_size(&resamplerConfig, &tempHeapSize); if (result != MA_SUCCESS) { return result; /* Failed to retrieve the size of the heap for the resampler. */ } @@ -75435,7 +77179,7 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p ma_result result; ma_engine_node_heap_layout heapLayout; ma_node_config baseNodeConfig; - ma_linear_resampler_config resamplerConfig; + ma_resampler_config resamplerConfig; ma_fader_config faderConfig; ma_spatializer_config spatializerConfig; ma_panner_config pannerConfig; @@ -75510,10 +77254,13 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p */ /* We'll always do resampling first. */ - resamplerConfig = ma_linear_resampler_config_init(ma_format_f32, baseNodeConfig.pInputChannels[0], pEngineNode->sampleRate, ma_engine_get_sample_rate(pEngineNode->pEngine)); - resamplerConfig.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ + resamplerConfig = pConfig->resampling; + resamplerConfig.format = ma_format_f32; + resamplerConfig.channels = baseNodeConfig.pInputChannels[0]; + resamplerConfig.sampleRateIn = pEngineNode->sampleRate; + resamplerConfig.sampleRateOut = ma_engine_get_sample_rate(pEngineNode->pEngine); - result = ma_linear_resampler_init_preallocated(&resamplerConfig, ma_offset_ptr(pHeap, heapLayout.resamplerOffset), &pEngineNode->resampler); + result = ma_resampler_init_preallocated(&resamplerConfig, ma_offset_ptr(pHeap, heapLayout.resamplerOffset), &pEngineNode->resampler); if (result != MA_SUCCESS) { goto error1; } @@ -75572,7 +77319,7 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p /* No need for allocation callbacks here because we use a preallocated heap. */ error3: ma_spatializer_uninit(&pEngineNode->spatializer, NULL); -error2: ma_linear_resampler_uninit(&pEngineNode->resampler, NULL); +error2: ma_resampler_uninit(&pEngineNode->resampler, NULL); error1: ma_node_uninit(&pEngineNode->baseNode, NULL); error0: return result; } @@ -75621,7 +77368,7 @@ MA_API void ma_engine_node_uninit(ma_engine_node* pEngineNode, const ma_allocati } ma_spatializer_uninit(&pEngineNode->spatializer, pAllocationCallbacks); - ma_linear_resampler_uninit(&pEngineNode->resampler, pAllocationCallbacks); + ma_resampler_uninit(&pEngineNode->resampler, pAllocationCallbacks); /* Free the heap last. */ if (pEngineNode->_ownsHeap) { @@ -75643,8 +77390,12 @@ MA_API ma_sound_config ma_sound_config_init_2(ma_engine* pEngine) if (pEngine != NULL) { config.monoExpansionMode = pEngine->monoExpansionMode; + config.pitchResampling = pEngine->pitchResamplingConfig; } else { config.monoExpansionMode = ma_mono_expansion_mode_default; + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ } config.rangeEndInPCMFrames = ~((ma_uint64)0); @@ -75666,8 +77417,12 @@ MA_API ma_sound_group_config ma_sound_group_config_init_2(ma_engine* pEngine) if (pEngine != NULL) { config.monoExpansionMode = pEngine->monoExpansionMode; + config.pitchResampling = pEngine->pitchResamplingConfig; } else { config.monoExpansionMode = ma_mono_expansion_mode_default; + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ } return config; @@ -75679,8 +77434,12 @@ MA_API ma_engine_config ma_engine_config_init(void) ma_engine_config config; MA_ZERO_OBJECT(&config); - config.listenerCount = 1; /* Always want at least one listener. */ - config.monoExpansionMode = ma_mono_expansion_mode_default; + config.listenerCount = 1; /* Always want at least one listener. */ + config.monoExpansionMode = ma_mono_expansion_mode_default; + config.resourceManagerResampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear); + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ return config; } @@ -75761,6 +77520,7 @@ MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEng pEngine->defaultVolumeSmoothTimeInPCMFrames = engineConfig.defaultVolumeSmoothTimeInPCMFrames; pEngine->onProcess = engineConfig.onProcess; pEngine->pProcessUserData = engineConfig.pProcessUserData; + pEngine->pitchResamplingConfig = engineConfig.pitchResampling; ma_allocation_callbacks_init_copy(&pEngine->allocationCallbacks, &engineConfig.allocationCallbacks); #if !defined(MA_NO_RESOURCE_MANAGER) @@ -75943,6 +77703,7 @@ MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEng resourceManagerConfig.decodedSampleRate = ma_engine_get_sample_rate(pEngine); ma_allocation_callbacks_init_copy(&resourceManagerConfig.allocationCallbacks, &pEngine->allocationCallbacks); resourceManagerConfig.pVFS = engineConfig.pResourceManagerVFS; + resourceManagerConfig.resampling = engineConfig.resourceManagerResampling; /* The Emscripten build cannot use threads unless it's targeting pthreads. */ #if defined(MA_EMSCRIPTEN) && !defined(__EMSCRIPTEN_PTHREADS__) @@ -76668,13 +78429,32 @@ static ma_result ma_sound_init_from_data_source_internal(ma_engine* pEngine, con } + /* + When pulling data from a data source we need a processing cache to hold onto unprocessed input data from the data source + after doing resampling. + */ + if (pSound->pDataSource != NULL) { + pSound->processingCacheFramesRemaining = 0; + pSound->processingCacheCap = ma_node_graph_get_processing_size_in_frames(&pEngine->nodeGraph); + if (pSound->processingCacheCap == 0) { + pSound->processingCacheCap = 512; + } + + pSound->pProcessingCache = (float*)ma_calloc(pSound->processingCacheCap * ma_get_bytes_per_frame(ma_format_f32, engineNodeConfig.channelsIn), &pEngine->allocationCallbacks); + if (pSound->pProcessingCache == NULL) { + ma_engine_node_uninit(&pSound->engineNode, &pEngine->allocationCallbacks); + return MA_OUT_OF_MEMORY; + } + } + + /* Apply initial range and looping state to the data source if applicable. */ if (pConfig->rangeBegInPCMFrames != 0 || pConfig->rangeEndInPCMFrames != ~((ma_uint64)0)) { ma_data_source_set_range_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->rangeBegInPCMFrames, pConfig->rangeEndInPCMFrames); } if (pConfig->loopPointBegInPCMFrames != 0 || pConfig->loopPointEndInPCMFrames != ~((ma_uint64)0)) { - ma_data_source_set_range_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->loopPointBegInPCMFrames, pConfig->loopPointEndInPCMFrames); + ma_data_source_set_loop_point_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->loopPointBegInPCMFrames, pConfig->loopPointEndInPCMFrames); } ma_sound_set_looping(pSound, pConfig->isLooping || ((pConfig->flags & MA_SOUND_FLAG_LOOPING) != 0)); @@ -76736,6 +78516,7 @@ MA_API ma_result ma_sound_init_from_file_internal(ma_engine* pEngine, const ma_s result = ma_resource_manager_data_source_init_ex(pEngine->pResourceManager, &resourceManagerDataSourceConfig, pSound->pResourceManagerDataSource); if (result != MA_SUCCESS) { + ma_free(pSound->pResourceManagerDataSource, &pEngine->allocationCallbacks); goto done; } @@ -76904,6 +78685,11 @@ MA_API void ma_sound_uninit(ma_sound* pSound) */ ma_engine_node_uninit(&pSound->engineNode, &pSound->engineNode.pEngine->allocationCallbacks); + if (pSound->pProcessingCache != NULL) { + ma_free(pSound->pProcessingCache, &pSound->engineNode.pEngine->allocationCallbacks); + pSound->pProcessingCache = NULL; + } + /* Once the sound is detached from the group we can guarantee that it won't be referenced by the mixer thread which means it's safe for us to destroy the data source. */ #ifndef MA_NO_RESOURCE_MANAGER if (pSound->ownsDataSource) { @@ -76999,6 +78785,27 @@ MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_ui return ma_sound_stop_with_fade_in_pcm_frames(pSound, (fadeLengthInMilliseconds * sampleRate) / 1000); } +MA_API void ma_sound_reset_start_time(ma_sound* pSound) +{ + ma_sound_set_start_time_in_pcm_frames(pSound, 0); +} + +MA_API void ma_sound_reset_stop_time(ma_sound* pSound) +{ + ma_sound_set_stop_time_in_pcm_frames(pSound, ~(ma_uint64)0); +} + +MA_API void ma_sound_reset_fade(ma_sound* pSound) +{ + ma_sound_set_fade_in_pcm_frames(pSound, 0, 1, 0); +} + +MA_API void ma_sound_reset_stop_time_and_fade(ma_sound* pSound) +{ + ma_sound_reset_stop_time(pSound); + ma_sound_reset_fade(pSound); +} + MA_API void ma_sound_set_volume(ma_sound* pSound, float volume) { if (pSound == NULL) { @@ -77541,7 +79348,12 @@ MA_API ma_uint64 ma_sound_get_time_in_pcm_frames(const ma_sound* pSound) MA_API ma_uint64 ma_sound_get_time_in_milliseconds(const ma_sound* pSound) { - return ma_sound_get_time_in_pcm_frames(pSound) * 1000 / ma_engine_get_sample_rate(ma_sound_get_engine(pSound)); + ma_uint32 sampleRate = ma_engine_get_sample_rate(ma_sound_get_engine(pSound)); + if (sampleRate == 0) { + return 0; /* Prevent a division by zero. */ + } + + return ma_sound_get_time_in_pcm_frames(pSound) * 1000 / sampleRate; } MA_API void ma_sound_set_looping(ma_sound* pSound, ma_bool32 isLooping) @@ -77625,7 +79437,7 @@ MA_API ma_result ma_sound_seek_to_second(ma_sound* pSound, float seekPointInSeco return ma_sound_seek_to_pcm_frame(pSound, frameIndex); } -MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap) +MA_API ma_result ma_sound_get_data_format(const ma_sound* pSound, ma_format* pFormat, ma_uint32* pChannels, ma_uint32* pSampleRate, ma_channel* pChannelMap, size_t channelMapCap) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -77645,7 +79457,7 @@ MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, } if (pSampleRate != NULL) { - *pSampleRate = pSound->engineNode.resampler.config.sampleRateIn; + *pSampleRate = pSound->engineNode.resampler.sampleRateIn; } if (pChannelMap != NULL) { @@ -77658,7 +79470,7 @@ MA_API ma_result ma_sound_get_data_format(ma_sound* pSound, ma_format* pFormat, } } -MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* pCursor) +MA_API ma_result ma_sound_get_cursor_in_pcm_frames(const ma_sound* pSound, ma_uint64* pCursor) { ma_uint64 seekTarget; @@ -77680,7 +79492,7 @@ MA_API ma_result ma_sound_get_cursor_in_pcm_frames(ma_sound* pSound, ma_uint64* } } -MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* pLength) +MA_API ma_result ma_sound_get_length_in_pcm_frames(const ma_sound* pSound, ma_uint64* pLength) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -77694,7 +79506,7 @@ MA_API ma_result ma_sound_get_length_in_pcm_frames(ma_sound* pSound, ma_uint64* return ma_data_source_get_length_in_pcm_frames(pSound->pDataSource, pLength); } -MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor) +MA_API ma_result ma_sound_get_cursor_in_seconds(const ma_sound* pSound, float* pCursor) { ma_result result; ma_uint64 cursorInPCMFrames; @@ -77720,7 +79532,7 @@ MA_API ma_result ma_sound_get_cursor_in_seconds(ma_sound* pSound, float* pCursor return MA_SUCCESS; } -MA_API ma_result ma_sound_get_length_in_seconds(ma_sound* pSound, float* pLength) +MA_API ma_result ma_sound_get_length_in_seconds(const ma_sound* pSound, float* pLength) { if (pSound == NULL) { return MA_INVALID_ARGS; @@ -78539,12 +80351,12 @@ MA_PRIVATE ma_bool32 ma_dr_wav__seek_forward(ma_dr_wav_seek_proc onSeek, ma_uint ma_uint64 bytesRemainingToSeek = offset; while (bytesRemainingToSeek > 0) { if (bytesRemainingToSeek > 0x7FFFFFFF) { - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } bytesRemainingToSeek -= 0x7FFFFFFF; } else { - if (!onSeek(pUserData, (int)bytesRemainingToSeek, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, (int)bytesRemainingToSeek, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } bytesRemainingToSeek = 0; @@ -78555,17 +80367,17 @@ MA_PRIVATE ma_bool32 ma_dr_wav__seek_forward(ma_dr_wav_seek_proc onSeek, ma_uint MA_PRIVATE ma_bool32 ma_dr_wav__seek_from_start(ma_dr_wav_seek_proc onSeek, ma_uint64 offset, void* pUserData) { if (offset <= 0x7FFFFFFF) { - return onSeek(pUserData, (int)offset, ma_dr_wav_seek_origin_start); + return onSeek(pUserData, (int)offset, MA_DR_WAV_SEEK_SET); } - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_start)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_SET)) { return MA_FALSE; } offset -= 0x7FFFFFFF; for (;;) { if (offset <= 0x7FFFFFFF) { - return onSeek(pUserData, (int)offset, ma_dr_wav_seek_origin_current); + return onSeek(pUserData, (int)offset, MA_DR_WAV_SEEK_CUR); } - if (!onSeek(pUserData, 0x7FFFFFFF, ma_dr_wav_seek_origin_current)) { + if (!onSeek(pUserData, 0x7FFFFFFF, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } offset -= 0x7FFFFFFF; @@ -78588,7 +80400,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav__on_seek(ma_dr_wav_seek_proc onSeek, void* pUserD if (!onSeek(pUserData, offset, origin)) { return MA_FALSE; } - if (origin == ma_dr_wav_seek_origin_start) { + if (origin == MA_DR_WAV_SEEK_SET) { *pCursor = offset; } else { *pCursor += offset; @@ -78707,12 +80519,12 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_smpl_to_metadata_obj(ma_dr_wav__metadata_pa ma_uint8 smplLoopData[MA_DR_WAV_SMPL_LOOP_BYTES]; bytesJustRead = ma_dr_wav__metadata_parser_read(pParser, smplLoopData, sizeof(smplLoopData), &totalBytesRead); if (bytesJustRead == sizeof(smplLoopData)) { - pMetadata->data.smpl.pLoops[iSampleLoop].cuePointId = ma_dr_wav_bytes_to_u32(smplLoopData + 0); - pMetadata->data.smpl.pLoops[iSampleLoop].type = ma_dr_wav_bytes_to_u32(smplLoopData + 4); - pMetadata->data.smpl.pLoops[iSampleLoop].firstSampleByteOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 8); - pMetadata->data.smpl.pLoops[iSampleLoop].lastSampleByteOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 12); - pMetadata->data.smpl.pLoops[iSampleLoop].sampleFraction = ma_dr_wav_bytes_to_u32(smplLoopData + 16); - pMetadata->data.smpl.pLoops[iSampleLoop].playCount = ma_dr_wav_bytes_to_u32(smplLoopData + 20); + pMetadata->data.smpl.pLoops[iSampleLoop].cuePointId = ma_dr_wav_bytes_to_u32(smplLoopData + 0); + pMetadata->data.smpl.pLoops[iSampleLoop].type = ma_dr_wav_bytes_to_u32(smplLoopData + 4); + pMetadata->data.smpl.pLoops[iSampleLoop].firstSampleOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 8); + pMetadata->data.smpl.pLoops[iSampleLoop].lastSampleOffset = ma_dr_wav_bytes_to_u32(smplLoopData + 12); + pMetadata->data.smpl.pLoops[iSampleLoop].sampleFraction = ma_dr_wav_bytes_to_u32(smplLoopData + 16); + pMetadata->data.smpl.pLoops[iSampleLoop].playCount = ma_dr_wav_bytes_to_u32(smplLoopData + 20); } else { break; } @@ -78756,7 +80568,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__read_cue_to_metadata_obj(ma_dr_wav__metadata_par pMetadata->data.cue.pCuePoints[iCuePoint].dataChunkId[3] = cuePointData[11]; pMetadata->data.cue.pCuePoints[iCuePoint].chunkStart = ma_dr_wav_bytes_to_u32(cuePointData + 12); pMetadata->data.cue.pCuePoints[iCuePoint].blockStart = ma_dr_wav_bytes_to_u32(cuePointData + 16); - pMetadata->data.cue.pCuePoints[iCuePoint].sampleByteOffset = ma_dr_wav_bytes_to_u32(cuePointData + 20); + pMetadata->data.cue.pCuePoints[iCuePoint].sampleOffset = ma_dr_wav_bytes_to_u32(cuePointData + 20); } else { break; } @@ -79096,7 +80908,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse if (pParser->stage == ma_dr_wav__metadata_parser_stage_count) { ma_uint8 buffer[4]; size_t bytesJustRead; - if (!pParser->onSeek(pParser->pReadSeekUserData, 28, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, 28, MA_DR_WAV_SEEK_CUR)) { return bytesRead; } bytesRead += 28; @@ -79191,7 +81003,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse return bytesRead; } allocSizeNeeded += ma_dr_wav__strlen(buffer) + 1; - allocSizeNeeded += (size_t)pChunkHeader->sizeInBytes - MA_DR_WAV_BEXT_BYTES; + allocSizeNeeded += (size_t)pChunkHeader->sizeInBytes - MA_DR_WAV_BEXT_BYTES + 1; ma_dr_wav__metadata_request_extra_memory_for_stage_2(pParser, allocSizeNeeded, 1); pParser->metadataCount += 1; } else { @@ -79274,6 +81086,16 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_album); } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_tracknumber, "ITRK")) { subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_tracknumber); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_location, "IARL")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_location); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_organization, "ICMS")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_organization); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_keywords, "IKEY")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_keywords); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_medium, "IMED")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_medium); + } else if (ma_dr_wav__chunk_matches(allowedMetadataTypes, subchunkId, ma_dr_wav_metadata_type_list_info_description, "ISBJ")) { + subchunkBytesRead = ma_dr_wav__metadata_process_info_text_chunk(pParser, subchunkDataSize, ma_dr_wav_metadata_type_list_info_description); } else if ((allowedMetadataTypes & ma_dr_wav_metadata_type_unknown) != 0) { subchunkBytesRead = ma_dr_wav__metadata_process_unknown_chunk(pParser, subchunkId, subchunkDataSize, listType); } @@ -79281,13 +81103,13 @@ MA_PRIVATE ma_uint64 ma_dr_wav__metadata_process_chunk(ma_dr_wav__metadata_parse MA_DR_WAV_ASSERT(subchunkBytesRead <= subchunkDataSize); if (subchunkBytesRead < subchunkDataSize) { ma_uint64 bytesToSeek = subchunkDataSize - subchunkBytesRead; - if (!pParser->onSeek(pParser->pReadSeekUserData, (int)bytesToSeek, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, (int)bytesToSeek, MA_DR_WAV_SEEK_CUR)) { break; } bytesRead += bytesToSeek; } if ((subchunkDataSize % 2) == 1) { - if (!pParser->onSeek(pParser->pReadSeekUserData, 1, ma_dr_wav_seek_origin_current)) { + if (!pParser->onSeek(pParser->pReadSeekUserData, 1, MA_DR_WAV_SEEK_CUR)) { break; } bytesRead += 1; @@ -79324,7 +81146,7 @@ MA_API ma_uint16 ma_dr_wav_fmt_get_format(const ma_dr_wav_fmt* pFMT) return ma_dr_wav_bytes_to_u16(pFMT->subFormat); } } -MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pReadSeekUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pReadSeekTellUserData, const ma_allocation_callbacks* pAllocationCallbacks) { if (pWav == NULL || onRead == NULL || onSeek == NULL) { return MA_FALSE; @@ -79332,7 +81154,8 @@ MA_PRIVATE ma_bool32 ma_dr_wav_preinit(ma_dr_wav* pWav, ma_dr_wav_read_proc onRe MA_DR_WAV_ZERO_MEMORY(pWav, sizeof(*pWav)); pWav->onRead = onRead; pWav->onSeek = onSeek; - pWav->pUserData = pReadSeekUserData; + pWav->onTell = onTell; + pWav->pUserData = pReadSeekTellUserData; pWav->allocationCallbacks = ma_dr_wav_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); if (pWav->allocationCallbacks.onFree == NULL || (pWav->allocationCallbacks.onMalloc == NULL && pWav->allocationCallbacks.onRealloc == NULL)) { return MA_FALSE; @@ -79546,14 +81369,14 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p fmt.channelMask = ma_dr_wav_bytes_to_u32_ex(fmtext + 2, pWav->container); ma_dr_wav_bytes_to_guid(fmtext + 6, fmt.subFormat); } else { - if (pWav->onSeek(pWav->pUserData, fmt.extendedSize, ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, fmt.extendedSize, MA_DR_WAV_SEEK_CUR) == MA_FALSE) { return MA_FALSE; } } cursor += fmt.extendedSize; bytesReadSoFar += fmt.extendedSize; } - if (pWav->onSeek(pWav->pUserData, (int)(header.sizeInBytes - bytesReadSoFar), ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, (int)(header.sizeInBytes - bytesReadSoFar), MA_DR_WAV_SEEK_CUR) == MA_FALSE) { return MA_FALSE; } cursor += (header.sizeInBytes - bytesReadSoFar); @@ -79704,15 +81527,26 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p return MA_FALSE; } offset = ma_dr_wav_bytes_to_u32_ex(offsetAndBlockSizeData + 0, pWav->container); - if (ma_dr_wav__seek_forward(pWav->onSeek, offset, pWav->pUserData) == MA_FALSE) { - return MA_FALSE; - } - cursor += offset; - pWav->dataChunkDataPos = cursor; + pWav->dataChunkDataPos = cursor + offset; dataChunkSize = chunkSize; - if (sequential || !isProcessingMetadata) { - break; + if (dataChunkSize > offset) { + dataChunkSize -= offset; + } else { + dataChunkSize = 0; + } + if (sequential) { + if (foundChunk_fmt) { + if (ma_dr_wav__seek_forward(pWav->onSeek, offset, pWav->pUserData) == MA_FALSE) { + return MA_FALSE; + } + cursor += offset; + break; + } else { + return MA_FALSE; + } } else { + chunkSize += header.paddingSize; + chunkSize -= sizeof(offsetAndBlockSizeData); if (ma_dr_wav__seek_forward(pWav->onSeek, chunkSize, pWav->pUserData) == MA_FALSE) { break; } @@ -79776,6 +81610,17 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p pWav->pMetadata = metadataParser.pMetadata; pWav->metadataCount = metadataParser.metadataCount; } + if (pWav->onTell != NULL && pWav->onSeek != NULL) { + if (pWav->onSeek(pWav->pUserData, 0, MA_DR_WAV_SEEK_END) == MA_TRUE) { + ma_int64 fileSize; + if (pWav->onTell(pWav->pUserData, &fileSize)) { + if (dataChunkSize + pWav->dataChunkDataPos > (ma_uint64)fileSize) { + dataChunkSize = (ma_uint64)fileSize - pWav->dataChunkDataPos; + } + } + } else { + } + } if (dataChunkSize == 0xFFFFFFFF && (pWav->container == ma_dr_wav_container_riff || pWav->container == ma_dr_wav_container_rifx) && pWav->isSequentialWrite == MA_FALSE) { dataChunkSize = 0; for (;;) { @@ -79795,8 +81640,14 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p pWav->sampleRate = fmt.sampleRate; pWav->channels = fmt.channels; pWav->bitsPerSample = fmt.bitsPerSample; - pWav->bytesRemaining = dataChunkSize; pWav->translatedFormatTag = translatedFormatTag; + if (!ma_dr_wav__is_compressed_format_tag(translatedFormatTag)) { + ma_uint32 bytesPerFrame = ma_dr_wav_get_bytes_per_pcm_frame(pWav); + if (bytesPerFrame > 0) { + dataChunkSize -= (dataChunkSize % bytesPerFrame); + } + } + pWav->bytesRemaining = dataChunkSize; pWav->dataChunkDataSize = dataChunkSize; if (sampleCountFromFactChunk != 0) { pWav->totalPCMFrameCount = sampleCountFromFactChunk; @@ -79851,20 +81702,20 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init__internal(ma_dr_wav* pWav, ma_dr_wav_chunk_p #endif return MA_TRUE; } -MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_wav_init_ex(pWav, onRead, onSeek, NULL, pUserData, NULL, 0, pAllocationCallbacks); + return ma_dr_wav_init_ex(pWav, onRead, onSeek, onTell, NULL, pUserData, NULL, 0, pAllocationCallbacks); } -MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_chunk_proc onChunk, void* pReadSeekUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init_ex(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, ma_dr_wav_chunk_proc onChunk, void* pReadSeekTellUserData, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { - if (!ma_dr_wav_preinit(pWav, onRead, onSeek, pReadSeekUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, onRead, onSeek, onTell, pReadSeekTellUserData, pAllocationCallbacks)) { return MA_FALSE; } return ma_dr_wav_init__internal(pWav, onChunk, pChunkUserData, flags); } -MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_wav_init_with_metadata(ma_dr_wav* pWav, ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { - if (!ma_dr_wav_preinit(pWav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return MA_FALSE; } return ma_dr_wav_init__internal(pWav, NULL, NULL, flags | MA_DR_WAV_WITH_METADATA); @@ -80026,8 +81877,8 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ for (iLoop = 0; iLoop < pMetadata->data.smpl.sampleLoopCount; ++iLoop) { bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].cuePointId); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].type); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].firstSampleByteOffset); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].lastSampleByteOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].firstSampleOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].lastSampleOffset); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].sampleFraction); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.smpl.pLoops[iLoop].playCount); } @@ -80061,7 +81912,7 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ bytesWritten += ma_dr_wav__write_or_count(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].dataChunkId, 4); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].chunkStart); bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].blockStart); - bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].sampleByteOffset); + bytesWritten += ma_dr_wav__write_or_count_u32ne_to_le(pWav, pMetadata->data.cue.pCuePoints[iCuePoint].sampleOffset); } } break; case ma_dr_wav_metadata_type_acid: @@ -80147,15 +81998,20 @@ MA_PRIVATE size_t ma_dr_wav__write_or_count_metadata(ma_dr_wav* pWav, ma_dr_wav_ if (pMetadata->type & ma_dr_wav_metadata_type_list_all_info_strings) { const char* pID = NULL; switch (pMetadata->type) { - case ma_dr_wav_metadata_type_list_info_software: pID = "ISFT"; break; - case ma_dr_wav_metadata_type_list_info_copyright: pID = "ICOP"; break; - case ma_dr_wav_metadata_type_list_info_title: pID = "INAM"; break; - case ma_dr_wav_metadata_type_list_info_artist: pID = "IART"; break; - case ma_dr_wav_metadata_type_list_info_comment: pID = "ICMT"; break; - case ma_dr_wav_metadata_type_list_info_date: pID = "ICRD"; break; - case ma_dr_wav_metadata_type_list_info_genre: pID = "IGNR"; break; - case ma_dr_wav_metadata_type_list_info_album: pID = "IPRD"; break; - case ma_dr_wav_metadata_type_list_info_tracknumber: pID = "ITRK"; break; + case ma_dr_wav_metadata_type_list_info_software: pID = "ISFT"; break; + case ma_dr_wav_metadata_type_list_info_copyright: pID = "ICOP"; break; + case ma_dr_wav_metadata_type_list_info_title: pID = "INAM"; break; + case ma_dr_wav_metadata_type_list_info_artist: pID = "IART"; break; + case ma_dr_wav_metadata_type_list_info_comment: pID = "ICMT"; break; + case ma_dr_wav_metadata_type_list_info_date: pID = "ICRD"; break; + case ma_dr_wav_metadata_type_list_info_genre: pID = "IGNR"; break; + case ma_dr_wav_metadata_type_list_info_album: pID = "IPRD"; break; + case ma_dr_wav_metadata_type_list_info_tracknumber: pID = "ITRK"; break; + case ma_dr_wav_metadata_type_list_info_location: pID = "IARL"; break; + case ma_dr_wav_metadata_type_list_info_organization: pID = "ICMS"; break; + case ma_dr_wav_metadata_type_list_info_keywords: pID = "IKEY"; break; + case ma_dr_wav_metadata_type_list_info_medium: pID = "IMED"; break; + case ma_dr_wav_metadata_type_list_info_description: pID = "ISBJ"; break; default: break; } MA_DR_WAV_ASSERT(pID != NULL); @@ -80370,7 +82226,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav_init_write__internal(ma_dr_wav* pWav, const ma_dr } pWav->dataChunkDataSizeTargetWrite = initialDataChunkSize; if (pFormat->container == ma_dr_wav_container_riff) { - ma_uint32 chunkSizeRIFF = 28 + (ma_uint32)initialDataChunkSize; + ma_uint32 chunkSizeRIFF = 36 + (ma_uint32)initialDataChunkSize; runningPos += ma_dr_wav__write(pWav, "RIFF", 4); runningPos += ma_dr_wav__write_u32ne_to_le(pWav, chunkSizeRIFF); runningPos += ma_dr_wav__write(pWav, "WAVE", 4); @@ -80493,7 +82349,31 @@ MA_PRIVATE size_t ma_dr_wav__on_write_stdio(void* pUserData, const void* pData, } MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_stdio(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { - return fseek((FILE*)pUserData, offset, (origin == ma_dr_wav_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_WAV_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_WAV_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; +} +MA_PRIVATE ma_bool32 ma_dr_wav__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_WAV_ASSERT(pFileStdio != NULL); + MA_DR_WAV_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; } MA_API ma_bool32 ma_dr_wav_init_file(ma_dr_wav* pWav, const char* filename, const ma_allocation_callbacks* pAllocationCallbacks) { @@ -80502,7 +82382,7 @@ MA_API ma_bool32 ma_dr_wav_init_file(ma_dr_wav* pWav, const char* filename, cons MA_PRIVATE ma_bool32 ma_dr_wav_init_file__internal_FILE(ma_dr_wav* pWav, FILE* pFile, ma_dr_wav_chunk_proc onChunk, void* pChunkUserData, ma_uint32 flags, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; - result = ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_stdio, ma_dr_wav__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_stdio, ma_dr_wav__on_seek_stdio, ma_dr_wav__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; @@ -80639,25 +82519,26 @@ MA_PRIVATE size_t ma_dr_wav__on_read_memory(void* pUserData, void* pBufferOut, s MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + ma_int64 newCursor; MA_DR_WAV_ASSERT(pWav != NULL); - if (origin == ma_dr_wav_seek_origin_current) { - if (offset > 0) { - if (pWav->memoryStream.currentReadPos + offset > pWav->memoryStream.dataSize) { - return MA_FALSE; - } - } else { - if (pWav->memoryStream.currentReadPos < (size_t)-offset) { - return MA_FALSE; - } - } - pWav->memoryStream.currentReadPos += offset; + if (origin == MA_DR_WAV_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_WAV_SEEK_CUR) { + newCursor = (ma_int64)pWav->memoryStream.currentReadPos; + } else if (origin == MA_DR_WAV_SEEK_END) { + newCursor = (ma_int64)pWav->memoryStream.dataSize; } else { - if ((ma_uint32)offset <= pWav->memoryStream.dataSize) { - pWav->memoryStream.currentReadPos = offset; - } else { - return MA_FALSE; - } + MA_DR_WAV_ASSERT(!"Invalid seek origin"); + return MA_FALSE; + } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pWav->memoryStream.dataSize) { + return MA_FALSE; } + pWav->memoryStream.currentReadPos = (size_t)newCursor; return MA_TRUE; } MA_PRIVATE size_t ma_dr_wav__on_write_memory(void* pUserData, const void* pDataIn, size_t bytesToWrite) @@ -80691,25 +82572,34 @@ MA_PRIVATE size_t ma_dr_wav__on_write_memory(void* pUserData, const void* pDataI MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory_write(void* pUserData, int offset, ma_dr_wav_seek_origin origin) { ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + ma_int64 newCursor; MA_DR_WAV_ASSERT(pWav != NULL); - if (origin == ma_dr_wav_seek_origin_current) { - if (offset > 0) { - if (pWav->memoryStreamWrite.currentWritePos + offset > pWav->memoryStreamWrite.dataSize) { - offset = (int)(pWav->memoryStreamWrite.dataSize - pWav->memoryStreamWrite.currentWritePos); - } - } else { - if (pWav->memoryStreamWrite.currentWritePos < (size_t)-offset) { - offset = -(int)pWav->memoryStreamWrite.currentWritePos; - } - } - pWav->memoryStreamWrite.currentWritePos += offset; + if (origin == MA_DR_WAV_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_WAV_SEEK_CUR) { + newCursor = (ma_int64)pWav->memoryStreamWrite.currentWritePos; + } else if (origin == MA_DR_WAV_SEEK_END) { + newCursor = (ma_int64)pWav->memoryStreamWrite.dataSize; } else { - if ((ma_uint32)offset <= pWav->memoryStreamWrite.dataSize) { - pWav->memoryStreamWrite.currentWritePos = offset; - } else { - pWav->memoryStreamWrite.currentWritePos = pWav->memoryStreamWrite.dataSize; - } + MA_DR_WAV_ASSERT(!"Invalid seek origin"); + return MA_FALSE; + } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pWav->memoryStreamWrite.dataSize) { + return MA_FALSE; } + pWav->memoryStreamWrite.currentWritePos = (size_t)newCursor; + return MA_TRUE; +} +MA_PRIVATE ma_bool32 ma_dr_wav__on_tell_memory(void* pUserData, ma_int64* pCursor) +{ + ma_dr_wav* pWav = (ma_dr_wav*)pUserData; + MA_DR_WAV_ASSERT(pWav != NULL); + MA_DR_WAV_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)pWav->memoryStream.currentReadPos; return MA_TRUE; } MA_API ma_bool32 ma_dr_wav_init_memory(ma_dr_wav* pWav, const void* data, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) @@ -80721,7 +82611,7 @@ MA_API ma_bool32 ma_dr_wav_init_memory_ex(ma_dr_wav* pWav, const void* data, siz if (data == NULL || dataSize == 0) { return MA_FALSE; } - if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, pWav, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, ma_dr_wav__on_tell_memory, pWav, pAllocationCallbacks)) { return MA_FALSE; } pWav->memoryStream.data = (const ma_uint8*)data; @@ -80734,7 +82624,7 @@ MA_API ma_bool32 ma_dr_wav_init_memory_with_metadata(ma_dr_wav* pWav, const void if (data == NULL || dataSize == 0) { return MA_FALSE; } - if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, pWav, pAllocationCallbacks)) { + if (!ma_dr_wav_preinit(pWav, ma_dr_wav__on_read_memory, ma_dr_wav__on_seek_memory, ma_dr_wav__on_tell_memory, pWav, pAllocationCallbacks)) { return MA_FALSE; } pWav->memoryStream.data = (const ma_uint8*)data; @@ -80793,30 +82683,30 @@ MA_API ma_result ma_dr_wav_uninit(ma_dr_wav* pWav) } if (pWav->onSeek && !pWav->isSequentialWrite) { if (pWav->container == ma_dr_wav_container_riff) { - if (pWav->onSeek(pWav->pUserData, 4, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, 4, MA_DR_WAV_SEEK_SET)) { ma_uint32 riffChunkSize = ma_dr_wav__riff_chunk_size_riff(pWav->dataChunkDataSize, pWav->pMetadata, pWav->metadataCount); ma_dr_wav__write_u32ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 4, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 4, MA_DR_WAV_SEEK_SET)) { ma_uint32 dataChunkSize = ma_dr_wav__data_chunk_size_riff(pWav->dataChunkDataSize); ma_dr_wav__write_u32ne_to_le(pWav, dataChunkSize); } } else if (pWav->container == ma_dr_wav_container_w64) { - if (pWav->onSeek(pWav->pUserData, 16, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, 16, MA_DR_WAV_SEEK_SET)) { ma_uint64 riffChunkSize = ma_dr_wav__riff_chunk_size_w64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 8, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos - 8, MA_DR_WAV_SEEK_SET)) { ma_uint64 dataChunkSize = ma_dr_wav__data_chunk_size_w64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, dataChunkSize); } } else if (pWav->container == ma_dr_wav_container_rf64) { int ds64BodyPos = 12 + 8; - if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 0, MA_DR_WAV_SEEK_SET)) { ma_uint64 riffChunkSize = ma_dr_wav__riff_chunk_size_rf64(pWav->dataChunkDataSize, pWav->pMetadata, pWav->metadataCount); ma_dr_wav__write_u64ne_to_le(pWav, riffChunkSize); } - if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, ma_dr_wav_seek_origin_start)) { + if (pWav->onSeek(pWav->pUserData, ds64BodyPos + 8, MA_DR_WAV_SEEK_SET)) { ma_uint64 dataChunkSize = ma_dr_wav__data_chunk_size_rf64(pWav->dataChunkDataSize); ma_dr_wav__write_u64ne_to_le(pWav, dataChunkSize); } @@ -80863,7 +82753,7 @@ MA_API size_t ma_dr_wav_read_raw(ma_dr_wav* pWav, size_t bytesToRead, void* pBuf if (bytesToSeek > 0x7FFFFFFF) { bytesToSeek = 0x7FFFFFFF; } - if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, ma_dr_wav_seek_origin_current) == MA_FALSE) { + if (pWav->onSeek(pWav->pUserData, (int)bytesToSeek, MA_DR_WAV_SEEK_CUR) == MA_FALSE) { break; } bytesRead += bytesToSeek; @@ -80962,7 +82852,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav_seek_to_first_pcm_frame(ma_dr_wav* pWav) if (pWav->onWrite != NULL) { return MA_FALSE; } - if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, ma_dr_wav_seek_origin_start)) { + if (!pWav->onSeek(pWav->pUserData, (int)pWav->dataChunkDataPos, MA_DR_WAV_SEEK_SET)) { return MA_FALSE; } if (ma_dr_wav__is_compressed_format_tag(pWav->translatedFormatTag)) { @@ -81043,7 +82933,7 @@ MA_API ma_bool32 ma_dr_wav_seek_to_pcm_frame(ma_dr_wav* pWav, ma_uint64 targetFr } while (offset > 0) { int offset32 = ((offset > INT_MAX) ? INT_MAX : (int)offset); - if (!pWav->onSeek(pWav->pUserData, offset32, ma_dr_wav_seek_origin_current)) { + if (!pWav->onSeek(pWav->pUserData, offset32, MA_DR_WAV_SEEK_CUR)) { return MA_FALSE; } pWav->readCursorInPCMFrames += offset32 / bytesPerFrame; @@ -81169,12 +83059,12 @@ MA_API ma_uint64 ma_dr_wav_write_pcm_frames(ma_dr_wav* pWav, ma_uint64 framesToW MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_uint64 framesToRead, ma_int16* pBufferOut) { ma_uint64 totalFramesRead = 0; - static ma_int32 adaptationTable[] = { + static const ma_int32 adaptationTable[] = { 230, 230, 230, 230, 307, 409, 512, 614, 768, 614, 512, 409, 307, 230, 230, 230 }; - static ma_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; - static ma_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; + static const ma_int32 coeff1Table[] = { 256, 512, 0, 192, 240, 460, 392 }; + static const ma_int32 coeff2Table[] = { 0, -256, 0, 64, 0, -208, -232 }; MA_DR_WAV_ASSERT(pWav != NULL); MA_DR_WAV_ASSERT(framesToRead > 0); while (pWav->readCursorInPCMFrames < pWav->totalPCMFrameCount) { @@ -81193,7 +83083,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][0]; pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.cachedFrameCount = 2; - if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table)) { + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { return totalFramesRead; } } else { @@ -81215,7 +83105,8 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1]; pWav->msadpcm.cachedFrameCount = 2; - if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table) || + pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { return totalFramesRead; } } @@ -81252,6 +83143,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ if (pWav->channels == 1) { ma_int32 newSample0; ma_int32 newSample1; + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample0 += nibble0 * pWav->msadpcm.delta[0]; newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767); @@ -81276,6 +83170,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ } else { ma_int32 newSample0; ma_int32 newSample1; + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample0 += nibble0 * pWav->msadpcm.delta[0]; newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767); @@ -81285,6 +83182,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ } pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.prevFrames[0][1] = newSample0; + if (pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8; newSample1 += nibble1 * pWav->msadpcm.delta[1]; newSample1 = ma_dr_wav_clamp(newSample1, -32768, 32767); @@ -81307,11 +83207,11 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint { ma_uint64 totalFramesRead = 0; ma_uint32 iChannel; - static ma_int32 indexTable[16] = { + static const ma_int32 indexTable[16] = { -1, -1, -1, -1, 2, 4, 6, 8, -1, -1, -1, -1, 2, 4, 6, 8 }; - static ma_int32 stepTable[89] = { + static const ma_int32 stepTable[89] = { 7, 8, 9, 10, 11, 12, 13, 14, 16, 17, 19, 21, 23, 25, 28, 31, 34, 37, 41, 45, 50, 55, 60, 66, 73, 80, 88, 97, 107, 118, @@ -81334,7 +83234,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint } pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); if (header[2] >= ma_dr_wav_countof(stepTable)) { - pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, ma_dr_wav_seek_origin_current); + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, MA_DR_WAV_SEEK_CUR); pWav->ima.bytesRemainingInBlock = 0; return totalFramesRead; } @@ -81349,7 +83249,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint } pWav->ima.bytesRemainingInBlock = pWav->fmt.blockAlign - sizeof(header); if (header[2] >= ma_dr_wav_countof(stepTable) || header[6] >= ma_dr_wav_countof(stepTable)) { - pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, ma_dr_wav_seek_origin_current); + pWav->onSeek(pWav->pUserData, pWav->ima.bytesRemainingInBlock, MA_DR_WAV_SEEK_CUR); pWav->ima.bytesRemainingInBlock = 0; return totalFramesRead; } @@ -81424,7 +83324,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__ima(ma_dr_wav* pWav, ma_uint return totalFramesRead; } #ifndef MA_DR_WAV_NO_CONVERSION_API -static unsigned short g_ma_dr_wavAlawTable[256] = { +static const unsigned short ma_dr_wav_gAlawTable[256] = { 0xEA80, 0xEB80, 0xE880, 0xE980, 0xEE80, 0xEF80, 0xEC80, 0xED80, 0xE280, 0xE380, 0xE080, 0xE180, 0xE680, 0xE780, 0xE480, 0xE580, 0xF540, 0xF5C0, 0xF440, 0xF4C0, 0xF740, 0xF7C0, 0xF640, 0xF6C0, 0xF140, 0xF1C0, 0xF040, 0xF0C0, 0xF340, 0xF3C0, 0xF240, 0xF2C0, 0xAA00, 0xAE00, 0xA200, 0xA600, 0xBA00, 0xBE00, 0xB200, 0xB600, 0x8A00, 0x8E00, 0x8200, 0x8600, 0x9A00, 0x9E00, 0x9200, 0x9600, @@ -81442,7 +83342,7 @@ static unsigned short g_ma_dr_wavAlawTable[256] = { 0x0560, 0x0520, 0x05E0, 0x05A0, 0x0460, 0x0420, 0x04E0, 0x04A0, 0x0760, 0x0720, 0x07E0, 0x07A0, 0x0660, 0x0620, 0x06E0, 0x06A0, 0x02B0, 0x0290, 0x02F0, 0x02D0, 0x0230, 0x0210, 0x0270, 0x0250, 0x03B0, 0x0390, 0x03F0, 0x03D0, 0x0330, 0x0310, 0x0370, 0x0350 }; -static unsigned short g_ma_dr_wavMulawTable[256] = { +static const unsigned short ma_dr_wav_gMulawTable[256] = { 0x8284, 0x8684, 0x8A84, 0x8E84, 0x9284, 0x9684, 0x9A84, 0x9E84, 0xA284, 0xA684, 0xAA84, 0xAE84, 0xB284, 0xB684, 0xBA84, 0xBE84, 0xC184, 0xC384, 0xC584, 0xC784, 0xC984, 0xCB84, 0xCD84, 0xCF84, 0xD184, 0xD384, 0xD584, 0xD784, 0xD984, 0xDB84, 0xDD84, 0xDF84, 0xE104, 0xE204, 0xE304, 0xE404, 0xE504, 0xE604, 0xE704, 0xE804, 0xE904, 0xEA04, 0xEB04, 0xEC04, 0xED04, 0xEE04, 0xEF04, 0xF004, @@ -81462,11 +83362,11 @@ static unsigned short g_ma_dr_wavMulawTable[256] = { }; static MA_INLINE ma_int16 ma_dr_wav__alaw_to_s16(ma_uint8 sampleIn) { - return (short)g_ma_dr_wavAlawTable[sampleIn]; + return (short)ma_dr_wav_gAlawTable[sampleIn]; } static MA_INLINE ma_int16 ma_dr_wav__mulaw_to_s16(ma_uint8 sampleIn) { - return (short)g_ma_dr_wavMulawTable[sampleIn]; + return (short)ma_dr_wav_gMulawTable[sampleIn]; } MA_PRIVATE void ma_dr_wav__pcm_to_s16(ma_int16* pOut, const ma_uint8* pIn, size_t totalSampleCount, unsigned int bytesPerSample) { @@ -82529,6 +84429,10 @@ MA_PRIVATE ma_int16* ma_dr_wav__read_pcm_frames_and_close_s16(ma_dr_wav* pWav, u ma_int16* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(ma_int16)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(ma_int16); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -82563,6 +84467,10 @@ MA_PRIVATE float* ma_dr_wav__read_pcm_frames_and_close_f32(ma_dr_wav* pWav, unsi float* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(float)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -82597,6 +84505,10 @@ MA_PRIVATE ma_int32* ma_dr_wav__read_pcm_frames_and_close_s32(ma_dr_wav* pWav, u ma_int32* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(ma_int32)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(ma_int32); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -82625,7 +84537,7 @@ MA_PRIVATE ma_int32* ma_dr_wav__read_pcm_frames_and_close_s32(ma_dr_wav* pWav, u } return pSampleData; } -MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82637,12 +84549,12 @@ MA_API ma_int16* ma_dr_wav_open_and_read_pcm_frames_s16(ma_dr_wav_read_proc onRe if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_s16(&wav, channelsOut, sampleRateOut, totalFrameCountOut); } -MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82654,12 +84566,12 @@ MA_API float* ma_dr_wav_open_and_read_pcm_frames_f32(ma_dr_wav_read_proc onRead, if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_f32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); } -MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRead, ma_dr_wav_seek_proc onSeek, ma_dr_wav_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_wav wav; if (channelsOut) { @@ -82671,7 +84583,7 @@ MA_API ma_int32* ma_dr_wav_open_and_read_pcm_frames_s32(ma_dr_wav_read_proc onRe if (totalFrameCountOut) { *totalFrameCountOut = 0; } - if (!ma_dr_wav_init(&wav, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_wav_init(&wav, onRead, onSeek, onTell, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_wav__read_pcm_frames_and_close_s32(&wav, channelsOut, sampleRateOut, totalFrameCountOut); @@ -83979,7 +85891,7 @@ static MA_INLINE ma_uint32 ma_dr_flac__clz_lzcnt(ma_dr_flac_cache_t x) { ma_uint64 r; __asm__ __volatile__ ( - "lzcnt{ %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" + "rep; bsr{q %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" ); return (ma_uint32)r; } @@ -83987,11 +85899,11 @@ static MA_INLINE ma_uint32 ma_dr_flac__clz_lzcnt(ma_dr_flac_cache_t x) { ma_uint32 r; __asm__ __volatile__ ( - "lzcnt{l %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" + "rep; bsr{l %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" ); return r; } - #elif defined(MA_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 5) && !defined(__ARM_ARCH_6M__) && !defined(MA_64BIT) + #elif defined(MA_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 5) && !defined(__ARM_ARCH_6M__) && !(defined(__thumb__) && !defined(__thumb2__)) && !defined(MA_64BIT) { unsigned int r; __asm__ __volatile__ ( @@ -84106,23 +86018,23 @@ static ma_bool32 ma_dr_flac__seek_to_byte(ma_dr_flac_bs* bs, ma_uint64 offsetFro MA_DR_FLAC_ASSERT(offsetFromStart > 0); if (offsetFromStart > 0x7FFFFFFF) { ma_uint64 bytesRemaining = offsetFromStart; - if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_start)) { + if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } bytesRemaining -= 0x7FFFFFFF; while (bytesRemaining > 0x7FFFFFFF) { - if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_current)) { + if (!bs->onSeek(bs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } bytesRemaining -= 0x7FFFFFFF; } if (bytesRemaining > 0) { - if (!bs->onSeek(bs->pUserData, (int)bytesRemaining, ma_dr_flac_seek_origin_current)) { + if (!bs->onSeek(bs->pUserData, (int)bytesRemaining, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!bs->onSeek(bs->pUserData, (int)offsetFromStart, ma_dr_flac_seek_origin_start)) { + if (!bs->onSeek(bs->pUserData, (int)offsetFromStart, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } } @@ -86600,6 +88512,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; ma_dr_flac_meta_proc onMeta; ma_dr_flac_container container; void* pUserData; @@ -86728,11 +88641,12 @@ static void ma_dr_flac__free_from_callbacks(void* p, const ma_allocation_callbac pAllocationCallbacks->onFree(p, pAllocationCallbacks->pUserData); } } -static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, void* pUserDataMD, ma_uint64* pFirstFramePos, ma_uint64* pSeektablePos, ma_uint32* pSeekpointCount, ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, void* pUserDataMD, ma_uint64* pFirstFramePos, ma_uint64* pSeektablePos, ma_uint32* pSeekpointCount, ma_allocation_callbacks* pAllocationCallbacks) { ma_uint64 runningFilePos = 42; ma_uint64 seektablePos = 0; ma_uint32 seektableSize = 0; + (void)onTell; for (;;) { ma_dr_flac_metadata metadata; ma_uint8 isLastBlock = 0; @@ -86743,8 +88657,9 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea } runningFilePos += 4; metadata.type = blockType; - metadata.pRawData = NULL; metadata.rawDataSize = 0; + metadata.rawDataOffset = runningFilePos; + metadata.pRawData = NULL; switch (blockType) { case MA_DR_FLAC_METADATA_BLOCK_TYPE_APPLICATION: @@ -86944,53 +88859,124 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea return MA_FALSE; } if (onMeta) { - void* pRawData; - const char* pRunningData; - const char* pRunningDataEnd; - pRawData = ma_dr_flac__malloc_from_callbacks(blockSize, pAllocationCallbacks); - if (pRawData == NULL) { - return MA_FALSE; + ma_bool32 result = MA_TRUE; + ma_uint32 blockSizeRemaining = blockSize; + char* pMime = NULL; + char* pDescription = NULL; + void* pPictureData = NULL; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.type, 4) != 4) { + result = MA_FALSE; + goto done_flac; } - if (onRead(pUserData, pRawData, blockSize) != blockSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; + blockSizeRemaining -= 4; + metadata.data.picture.type = ma_dr_flac__be2host_32(metadata.data.picture.type); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.mimeLength, 4) != 4) { + result = MA_FALSE; + goto done_flac; } - metadata.pRawData = pRawData; - metadata.rawDataSize = blockSize; - pRunningData = (const char*)pRawData; - pRunningDataEnd = (const char*)pRawData + blockSize; - metadata.data.picture.type = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.mimeLength = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - if ((pRunningDataEnd - pRunningData) - 24 < (ma_int64)metadata.data.picture.mimeLength) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; + blockSizeRemaining -= 4; + metadata.data.picture.mimeLength = ma_dr_flac__be2host_32(metadata.data.picture.mimeLength); + pMime = (char*)ma_dr_flac__malloc_from_callbacks(metadata.data.picture.mimeLength + 1, pAllocationCallbacks); + if (pMime == NULL) { + result = MA_FALSE; + goto done_flac; } - metadata.data.picture.mime = pRunningData; pRunningData += metadata.data.picture.mimeLength; - metadata.data.picture.descriptionLength = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - if ((pRunningDataEnd - pRunningData) - 20 < (ma_int64)metadata.data.picture.descriptionLength) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; + if (blockSizeRemaining < metadata.data.picture.mimeLength || onRead(pUserData, pMime, metadata.data.picture.mimeLength) != metadata.data.picture.mimeLength) { + result = MA_FALSE; + goto done_flac; } - metadata.data.picture.description = pRunningData; pRunningData += metadata.data.picture.descriptionLength; - metadata.data.picture.width = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.height = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.colorDepth = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.indexColorCount = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.pictureDataSize = ma_dr_flac__be2host_32_ptr_unaligned(pRunningData); pRunningData += 4; - metadata.data.picture.pPictureData = (const ma_uint8*)pRunningData; - if (pRunningDataEnd - pRunningData < (ma_int64)metadata.data.picture.pictureDataSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); + blockSizeRemaining -= metadata.data.picture.mimeLength; + pMime[metadata.data.picture.mimeLength] = '\0'; + metadata.data.picture.mime = (const char*)pMime; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.descriptionLength, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.descriptionLength = ma_dr_flac__be2host_32(metadata.data.picture.descriptionLength); + pDescription = (char*)ma_dr_flac__malloc_from_callbacks(metadata.data.picture.descriptionLength + 1, pAllocationCallbacks); + if (pDescription == NULL) { + result = MA_FALSE; + goto done_flac; + } + if (blockSizeRemaining < metadata.data.picture.descriptionLength || onRead(pUserData, pDescription, metadata.data.picture.descriptionLength) != metadata.data.picture.descriptionLength) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= metadata.data.picture.descriptionLength; + pDescription[metadata.data.picture.descriptionLength] = '\0'; + metadata.data.picture.description = (const char*)pDescription; + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.width, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.width = ma_dr_flac__be2host_32(metadata.data.picture.width); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.height, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.height = ma_dr_flac__be2host_32(metadata.data.picture.height); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.colorDepth, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.colorDepth = ma_dr_flac__be2host_32(metadata.data.picture.colorDepth); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.indexColorCount, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.indexColorCount = ma_dr_flac__be2host_32(metadata.data.picture.indexColorCount); + if (blockSizeRemaining < 4 || onRead(pUserData, &metadata.data.picture.pictureDataSize, 4) != 4) { + result = MA_FALSE; + goto done_flac; + } + blockSizeRemaining -= 4; + metadata.data.picture.pictureDataSize = ma_dr_flac__be2host_32(metadata.data.picture.pictureDataSize); + if (blockSizeRemaining < metadata.data.picture.pictureDataSize) { + result = MA_FALSE; + goto done_flac; + } + metadata.data.picture.pictureDataOffset = runningFilePos + (blockSize - blockSizeRemaining); + #ifndef MA_DR_FLAC_NO_PICTURE_METADATA_MALLOC + pPictureData = ma_dr_flac__malloc_from_callbacks(metadata.data.picture.pictureDataSize, pAllocationCallbacks); + if (pPictureData != NULL) { + if (onRead(pUserData, pPictureData, metadata.data.picture.pictureDataSize) != metadata.data.picture.pictureDataSize) { + result = MA_FALSE; + goto done_flac; + } + } else + #endif + { + if (!onSeek(pUserData, metadata.data.picture.pictureDataSize, MA_DR_FLAC_SEEK_CUR)) { + result = MA_FALSE; + goto done_flac; + } + } + blockSizeRemaining -= metadata.data.picture.pictureDataSize; + (void)blockSizeRemaining; + metadata.data.picture.pPictureData = (const ma_uint8*)pPictureData; + if (metadata.data.picture.pictureDataOffset != 0 || metadata.data.picture.pPictureData != NULL) { + onMeta(pUserDataMD, &metadata); + } else { + } + done_flac: + ma_dr_flac__free_from_callbacks(pMime, pAllocationCallbacks); + ma_dr_flac__free_from_callbacks(pDescription, pAllocationCallbacks); + ma_dr_flac__free_from_callbacks(pPictureData, pAllocationCallbacks); + if (result != MA_TRUE) { return MA_FALSE; } - onMeta(pUserDataMD, &metadata); - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); } } break; case MA_DR_FLAC_METADATA_BLOCK_TYPE_PADDING: { if (onMeta) { metadata.data.padding.unused = 0; - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } else { onMeta(pUserDataMD, &metadata); @@ -87000,7 +88986,7 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea case MA_DR_FLAC_METADATA_BLOCK_TYPE_INVALID: { if (onMeta) { - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } } @@ -87009,12 +88995,15 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea { if (onMeta) { void* pRawData = ma_dr_flac__malloc_from_callbacks(blockSize, pAllocationCallbacks); - if (pRawData == NULL) { - return MA_FALSE; - } - if (onRead(pUserData, pRawData, blockSize) != blockSize) { - ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); - return MA_FALSE; + if (pRawData != NULL) { + if (onRead(pUserData, pRawData, blockSize) != blockSize) { + ma_dr_flac__free_from_callbacks(pRawData, pAllocationCallbacks); + return MA_FALSE; + } + } else { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { + return MA_FALSE; + } } metadata.pRawData = pRawData; metadata.rawDataSize = blockSize; @@ -87024,7 +89013,7 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea } break; } if (onMeta == NULL && blockSize > 0) { - if (!onSeek(pUserData, blockSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, blockSize, MA_DR_FLAC_SEEK_CUR)) { isLastBlock = MA_TRUE; } } @@ -87288,6 +89277,7 @@ typedef struct { ma_dr_flac_read_proc onRead; ma_dr_flac_seek_proc onSeek; + ma_dr_flac_tell_proc onTell; void* pUserData; ma_uint64 currentBytePos; ma_uint64 firstBytePos; @@ -87306,29 +89296,29 @@ static size_t ma_dr_flac_oggbs__read_physical(ma_dr_flac_oggbs* oggbs, void* buf } static ma_bool32 ma_dr_flac_oggbs__seek_physical(ma_dr_flac_oggbs* oggbs, ma_uint64 offset, ma_dr_flac_seek_origin origin) { - if (origin == ma_dr_flac_seek_origin_start) { + if (origin == MA_DR_FLAC_SEEK_SET) { if (offset <= 0x7FFFFFFF) { - if (!oggbs->onSeek(oggbs->pUserData, (int)offset, ma_dr_flac_seek_origin_start)) { + if (!oggbs->onSeek(oggbs->pUserData, (int)offset, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } oggbs->currentBytePos = offset; return MA_TRUE; } else { - if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_start)) { + if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } oggbs->currentBytePos = offset; - return ma_dr_flac_oggbs__seek_physical(oggbs, offset - 0x7FFFFFFF, ma_dr_flac_seek_origin_current); + return ma_dr_flac_oggbs__seek_physical(oggbs, offset - 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR); } } else { while (offset > 0x7FFFFFFF) { - if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, ma_dr_flac_seek_origin_current)) { + if (!oggbs->onSeek(oggbs->pUserData, 0x7FFFFFFF, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } oggbs->currentBytePos += 0x7FFFFFFF; offset -= 0x7FFFFFFF; } - if (!oggbs->onSeek(oggbs->pUserData, (int)offset, ma_dr_flac_seek_origin_current)) { + if (!oggbs->onSeek(oggbs->pUserData, (int)offset, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } oggbs->currentBytePos += offset; @@ -87354,7 +89344,7 @@ static ma_bool32 ma_dr_flac_oggbs__goto_next_page(ma_dr_flac_oggbs* oggbs, ma_dr continue; } if (header.serialNumber != oggbs->serialNumber) { - if (pageBodySize > 0 && !ma_dr_flac_oggbs__seek_physical(oggbs, pageBodySize, ma_dr_flac_seek_origin_current)) { + if (pageBodySize > 0 && !ma_dr_flac_oggbs__seek_physical(oggbs, pageBodySize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } continue; @@ -87416,7 +89406,7 @@ static ma_bool32 ma_dr_flac_oggbs__seek_to_next_packet(ma_dr_flac_oggbs* oggbs) } bytesToEndOfPacketOrPage += segmentSize; } - ma_dr_flac_oggbs__seek_physical(oggbs, bytesToEndOfPacketOrPage, ma_dr_flac_seek_origin_current); + ma_dr_flac_oggbs__seek_physical(oggbs, bytesToEndOfPacketOrPage, MA_DR_FLAC_SEEK_CUR); oggbs->bytesRemainingInPage -= bytesToEndOfPacketOrPage; if (atEndOfPage) { if (!ma_dr_flac_oggbs__goto_next_page(oggbs)) { @@ -87469,36 +89459,44 @@ static ma_bool32 ma_dr_flac__on_seek_ogg(void* pUserData, int offset, ma_dr_flac int bytesSeeked = 0; MA_DR_FLAC_ASSERT(oggbs != NULL); MA_DR_FLAC_ASSERT(offset >= 0); - if (origin == ma_dr_flac_seek_origin_start) { - if (!ma_dr_flac_oggbs__seek_physical(oggbs, (int)oggbs->firstBytePos, ma_dr_flac_seek_origin_start)) { + if (origin == MA_DR_FLAC_SEEK_SET) { + if (!ma_dr_flac_oggbs__seek_physical(oggbs, (int)oggbs->firstBytePos, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { return MA_FALSE; } - return ma_dr_flac__on_seek_ogg(pUserData, offset, ma_dr_flac_seek_origin_current); - } - MA_DR_FLAC_ASSERT(origin == ma_dr_flac_seek_origin_current); - while (bytesSeeked < offset) { - int bytesRemainingToSeek = offset - bytesSeeked; - MA_DR_FLAC_ASSERT(bytesRemainingToSeek >= 0); - if (oggbs->bytesRemainingInPage >= (size_t)bytesRemainingToSeek) { - bytesSeeked += bytesRemainingToSeek; - (void)bytesSeeked; - oggbs->bytesRemainingInPage -= bytesRemainingToSeek; - break; - } - if (oggbs->bytesRemainingInPage > 0) { - bytesSeeked += (int)oggbs->bytesRemainingInPage; - oggbs->bytesRemainingInPage = 0; - } - MA_DR_FLAC_ASSERT(bytesRemainingToSeek > 0); - if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { - return MA_FALSE; + return ma_dr_flac__on_seek_ogg(pUserData, offset, MA_DR_FLAC_SEEK_CUR); + } else if (origin == MA_DR_FLAC_SEEK_CUR) { + while (bytesSeeked < offset) { + int bytesRemainingToSeek = offset - bytesSeeked; + MA_DR_FLAC_ASSERT(bytesRemainingToSeek >= 0); + if (oggbs->bytesRemainingInPage >= (size_t)bytesRemainingToSeek) { + bytesSeeked += bytesRemainingToSeek; + (void)bytesSeeked; + oggbs->bytesRemainingInPage -= bytesRemainingToSeek; + break; + } + if (oggbs->bytesRemainingInPage > 0) { + bytesSeeked += (int)oggbs->bytesRemainingInPage; + oggbs->bytesRemainingInPage = 0; + } + MA_DR_FLAC_ASSERT(bytesRemainingToSeek > 0); + if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_fail_on_crc_mismatch)) { + return MA_FALSE; + } } + } else if (origin == MA_DR_FLAC_SEEK_END) { + return MA_FALSE; } return MA_TRUE; } +static ma_bool32 ma_dr_flac__on_tell_ogg(void* pUserData, ma_int64* pCursor) +{ + (void)pUserData; + (void)pCursor; + return MA_FALSE; +} static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 pcmFrameIndex) { ma_dr_flac_oggbs* oggbs = (ma_dr_flac_oggbs*)pFlac->_oggbs; @@ -87515,7 +89513,7 @@ static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 runningGranulePosition = 0; for (;;) { if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_recover_on_crc_mismatch)) { - ma_dr_flac_oggbs__seek_physical(oggbs, originalBytePos, ma_dr_flac_seek_origin_start); + ma_dr_flac_oggbs__seek_physical(oggbs, originalBytePos, MA_DR_FLAC_SEEK_SET); return MA_FALSE; } runningFrameBytePos = oggbs->currentBytePos - ma_dr_flac_ogg__get_page_header_size(&oggbs->currentPageHeader) - oggbs->pageDataSize; @@ -87534,7 +89532,7 @@ static ma_bool32 ma_dr_flac_ogg__seek_to_pcm_frame(ma_dr_flac* pFlac, ma_uint64 } } } - if (!ma_dr_flac_oggbs__seek_physical(oggbs, runningFrameBytePos, ma_dr_flac_seek_origin_start)) { + if (!ma_dr_flac_oggbs__seek_physical(oggbs, runningFrameBytePos, MA_DR_FLAC_SEEK_SET)) { return MA_FALSE; } if (!ma_dr_flac_oggbs__goto_next_page(oggbs, ma_dr_flac_ogg_recover_on_crc_mismatch)) { @@ -87629,7 +89627,7 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d if (mappingVersion[0] != 1) { return MA_FALSE; } - if (!onSeek(pUserData, 2, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, 2, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } if (onRead(pUserData, sig, 4) != 4) { @@ -87674,17 +89672,17 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d return MA_FALSE; } } else { - if (!onSeek(pUserData, bytesRemainingInPage, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, bytesRemainingInPage, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!onSeek(pUserData, bytesRemainingInPage, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, bytesRemainingInPage, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } } else { - if (!onSeek(pUserData, pageBodySize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, pageBodySize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } } @@ -87698,7 +89696,7 @@ static ma_bool32 ma_dr_flac__init_private__ogg(ma_dr_flac_init_info* pInit, ma_d return MA_TRUE; } #endif -static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD) +static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD) { ma_bool32 relaxed; ma_uint8 id[4]; @@ -87708,12 +89706,14 @@ static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_fla MA_DR_FLAC_ZERO_MEMORY(pInit, sizeof(*pInit)); pInit->onRead = onRead; pInit->onSeek = onSeek; + pInit->onTell = onTell; pInit->onMeta = onMeta; pInit->container = container; pInit->pUserData = pUserData; pInit->pUserDataMD = pUserDataMD; pInit->bs.onRead = onRead; pInit->bs.onSeek = onSeek; + pInit->bs.onTell = onTell; pInit->bs.pUserData = pUserData; ma_dr_flac__reset_cache(&pInit->bs); relaxed = container != ma_dr_flac_container_unknown; @@ -87736,7 +89736,7 @@ static ma_bool32 ma_dr_flac__init_private(ma_dr_flac_init_info* pInit, ma_dr_fla if (flags & 0x10) { headerSize += 10; } - if (!onSeek(pUserData, headerSize, ma_dr_flac_seek_origin_current)) { + if (!onSeek(pUserData, headerSize, MA_DR_FLAC_SEEK_CUR)) { return MA_FALSE; } pInit->runningFilePos += headerSize; @@ -87779,7 +89779,7 @@ static void ma_dr_flac__init_from_info(ma_dr_flac* pFlac, const ma_dr_flac_init_ pFlac->totalPCMFrameCount = pInit->totalPCMFrameCount; pFlac->container = pInit->container; } -static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, void* pUserDataMD, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac_init_info init; ma_uint32 allocationSize; @@ -87794,7 +89794,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on ma_allocation_callbacks allocationCallbacks; ma_dr_flac* pFlac; ma_dr_flac__init_cpu_caps(); - if (!ma_dr_flac__init_private(&init, onRead, onSeek, onMeta, container, pUserData, pUserDataMD)) { + if (!ma_dr_flac__init_private(&init, onRead, onSeek, onTell, onMeta, container, pUserData, pUserDataMD)) { return NULL; } if (pAllocationCallbacks != NULL) { @@ -87827,6 +89827,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on MA_DR_FLAC_ZERO_MEMORY(pOggbs, sizeof(*pOggbs)); pOggbs->onRead = onRead; pOggbs->onSeek = onSeek; + pOggbs->onTell = onTell; pOggbs->pUserData = pUserData; pOggbs->currentBytePos = init.oggFirstBytePos; pOggbs->firstBytePos = init.oggFirstBytePos; @@ -87841,15 +89842,17 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on if (init.hasMetadataBlocks) { ma_dr_flac_read_proc onReadOverride = onRead; ma_dr_flac_seek_proc onSeekOverride = onSeek; + ma_dr_flac_tell_proc onTellOverride = onTell; void* pUserDataOverride = pUserData; #ifndef MA_DR_FLAC_NO_OGG if (init.container == ma_dr_flac_container_ogg) { onReadOverride = ma_dr_flac__on_read_ogg; onSeekOverride = ma_dr_flac__on_seek_ogg; + onTellOverride = ma_dr_flac__on_tell_ogg; pUserDataOverride = (void*)pOggbs; } #endif - if (!ma_dr_flac__read_and_decode_metadata(onReadOverride, onSeekOverride, onMeta, pUserDataOverride, pUserDataMD, &firstFramePos, &seektablePos, &seekpointCount, &allocationCallbacks)) { + if (!ma_dr_flac__read_and_decode_metadata(onReadOverride, onSeekOverride, onTellOverride, onMeta, pUserDataOverride, pUserDataMD, &firstFramePos, &seektablePos, &seekpointCount, &allocationCallbacks)) { #ifndef MA_DR_FLAC_NO_OGG ma_dr_flac__free_from_callbacks(pOggbs, &allocationCallbacks); #endif @@ -87875,6 +89878,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on pOggbs = NULL; pFlac->bs.onRead = ma_dr_flac__on_read_ogg; pFlac->bs.onSeek = ma_dr_flac__on_seek_ogg; + pFlac->bs.onTell = ma_dr_flac__on_tell_ogg; pFlac->bs.pUserData = (void*)pInternalOggbs; pFlac->_oggbs = (void*)pInternalOggbs; } @@ -87894,7 +89898,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on pFlac->pSeekpoints = (ma_dr_flac_seekpoint*)((ma_uint8*)pFlac->pDecodedSamples + decodedSamplesAllocationSize); MA_DR_FLAC_ASSERT(pFlac->bs.onSeek != NULL); MA_DR_FLAC_ASSERT(pFlac->bs.onRead != NULL); - if (pFlac->bs.onSeek(pFlac->bs.pUserData, (int)seektablePos, ma_dr_flac_seek_origin_start)) { + if (pFlac->bs.onSeek(pFlac->bs.pUserData, (int)seektablePos, MA_DR_FLAC_SEEK_SET)) { ma_uint32 iSeekpoint; for (iSeekpoint = 0; iSeekpoint < seekpointCount; iSeekpoint += 1) { if (pFlac->bs.onRead(pFlac->bs.pUserData, pFlac->pSeekpoints + iSeekpoint, MA_DR_FLAC_SEEKPOINT_SIZE_IN_BYTES) == MA_DR_FLAC_SEEKPOINT_SIZE_IN_BYTES) { @@ -87907,7 +89911,7 @@ static ma_dr_flac* ma_dr_flac_open_with_metadata_private(ma_dr_flac_read_proc on break; } } - if (!pFlac->bs.onSeek(pFlac->bs.pUserData, (int)pFlac->firstFLACFramePosInBytes, ma_dr_flac_seek_origin_start)) { + if (!pFlac->bs.onSeek(pFlac->bs.pUserData, (int)pFlac->firstFLACFramePosInBytes, MA_DR_FLAC_SEEK_SET)) { ma_dr_flac__free_from_callbacks(pFlac, &allocationCallbacks); return NULL; } @@ -87950,8 +89954,31 @@ static size_t ma_dr_flac__on_read_stdio(void* pUserData, void* bufferOut, size_t } static ma_bool32 ma_dr_flac__on_seek_stdio(void* pUserData, int offset, ma_dr_flac_seek_origin origin) { - MA_DR_FLAC_ASSERT(offset >= 0); - return fseek((FILE*)pUserData, offset, (origin == ma_dr_flac_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_FLAC_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_FLAC_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; +} +static ma_bool32 ma_dr_flac__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_FLAC_ASSERT(pFileStdio != NULL); + MA_DR_FLAC_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; } MA_API ma_dr_flac* ma_dr_flac_open_file(const char* pFileName, const ma_allocation_callbacks* pAllocationCallbacks) { @@ -87960,7 +89987,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file(const char* pFileName, const ma_allocati if (ma_fopen(&pFile, pFileName, "rb") != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return NULL; @@ -87975,7 +90002,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_w(const wchar_t* pFileName, const ma_all if (ma_wfopen(&pFile, pFileName, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, (void*)pFile, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return NULL; @@ -87990,7 +90017,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata(const char* pFileName, ma_ if (ma_fopen(&pFile, pFileName, "rb") != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return pFlac; @@ -88005,7 +90032,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_file_with_metadata_w(const wchar_t* pFileName if (ma_wfopen(&pFile, pFileName, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return NULL; } - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_stdio, ma_dr_flac__on_seek_stdio, ma_dr_flac__on_tell_stdio, onMeta, ma_dr_flac_container_unknown, (void*)pFile, pUserData, pAllocationCallbacks); if (pFlac == NULL) { fclose(pFile); return pFlac; @@ -88033,24 +90060,34 @@ static size_t ma_dr_flac__on_read_memory(void* pUserData, void* bufferOut, size_ static ma_bool32 ma_dr_flac__on_seek_memory(void* pUserData, int offset, ma_dr_flac_seek_origin origin) { ma_dr_flac__memory_stream* memoryStream = (ma_dr_flac__memory_stream*)pUserData; + ma_int64 newCursor; MA_DR_FLAC_ASSERT(memoryStream != NULL); - MA_DR_FLAC_ASSERT(offset >= 0); - if (offset > (ma_int64)memoryStream->dataSize) { + if (origin == MA_DR_FLAC_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_FLAC_SEEK_CUR) { + newCursor = (ma_int64)memoryStream->currentReadPos; + } else if (origin == MA_DR_FLAC_SEEK_END) { + newCursor = (ma_int64)memoryStream->dataSize; + } else { + MA_DR_FLAC_ASSERT(!"Invalid seek origin"); return MA_FALSE; } - if (origin == ma_dr_flac_seek_origin_current) { - if (memoryStream->currentReadPos + offset <= memoryStream->dataSize) { - memoryStream->currentReadPos += offset; - } else { - return MA_FALSE; - } - } else { - if ((ma_uint32)offset <= memoryStream->dataSize) { - memoryStream->currentReadPos = offset; - } else { - return MA_FALSE; - } + newCursor += offset; + if (newCursor < 0) { + return MA_FALSE; } + if ((size_t)newCursor > memoryStream->dataSize) { + return MA_FALSE; + } + memoryStream->currentReadPos = (size_t)newCursor; + return MA_TRUE; +} +static ma_bool32 ma_dr_flac__on_tell_memory(void* pUserData, ma_int64* pCursor) +{ + ma_dr_flac__memory_stream* memoryStream = (ma_dr_flac__memory_stream*)pUserData; + MA_DR_FLAC_ASSERT(memoryStream != NULL); + MA_DR_FLAC_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)memoryStream->currentReadPos; return MA_TRUE; } MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) @@ -88060,7 +90097,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory(const void* pData, size_t dataSize, co memoryStream.data = (const ma_uint8*)pData; memoryStream.dataSize = dataSize; memoryStream.currentReadPos = 0; - pFlac = ma_dr_flac_open(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, &memoryStream, pAllocationCallbacks); + pFlac = ma_dr_flac_open(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, ma_dr_flac__on_tell_memory, &memoryStream, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -88085,7 +90122,7 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_ memoryStream.data = (const ma_uint8*)pData; memoryStream.dataSize = dataSize; memoryStream.currentReadPos = 0; - pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, onMeta, ma_dr_flac_container_unknown, &memoryStream, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open_with_metadata_private(ma_dr_flac__on_read_memory, ma_dr_flac__on_seek_memory, ma_dr_flac__on_tell_memory, onMeta, ma_dr_flac_container_unknown, &memoryStream, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -88103,21 +90140,21 @@ MA_API ma_dr_flac* ma_dr_flac_open_memory_with_metadata(const void* pData, size_ } return pFlac; } -MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, NULL, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, NULL, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, NULL, container, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, NULL, container, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onMeta, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, onMeta, ma_dr_flac_container_unknown, pUserData, pUserData, pAllocationCallbacks); } -MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_dr_flac* ma_dr_flac_open_with_metadata_relaxed(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, ma_dr_flac_meta_proc onMeta, ma_dr_flac_container container, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { - return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onMeta, container, pUserData, pUserData, pAllocationCallbacks); + return ma_dr_flac_open_with_metadata_private(onRead, onSeek, onTell, onMeta, container, pUserData, pUserData, pAllocationCallbacks); } MA_API void ma_dr_flac_close(ma_dr_flac* pFlac) { @@ -90345,57 +92382,42 @@ static type* ma_dr_flac__full_read_and_close_ ## extension (ma_dr_flac* pFlac, u { \ type* pSampleData = NULL; \ ma_uint64 totalPCMFrameCount; \ + type buffer[4096]; \ + ma_uint64 pcmFramesRead; \ + size_t sampleDataBufferSize = sizeof(buffer); \ \ MA_DR_FLAC_ASSERT(pFlac != NULL); \ \ - totalPCMFrameCount = pFlac->totalPCMFrameCount; \ - \ - if (totalPCMFrameCount == 0) { \ - type buffer[4096]; \ - ma_uint64 pcmFramesRead; \ - size_t sampleDataBufferSize = sizeof(buffer); \ + totalPCMFrameCount = 0; \ \ - pSampleData = (type*)ma_dr_flac__malloc_from_callbacks(sampleDataBufferSize, &pFlac->allocationCallbacks); \ - if (pSampleData == NULL) { \ - goto on_error; \ - } \ - \ - while ((pcmFramesRead = (ma_uint64)ma_dr_flac_read_pcm_frames_##extension(pFlac, sizeof(buffer)/sizeof(buffer[0])/pFlac->channels, buffer)) > 0) { \ - if (((totalPCMFrameCount + pcmFramesRead) * pFlac->channels * sizeof(type)) > sampleDataBufferSize) { \ - type* pNewSampleData; \ - size_t newSampleDataBufferSize; \ + pSampleData = (type*)ma_dr_flac__malloc_from_callbacks(sampleDataBufferSize, &pFlac->allocationCallbacks); \ + if (pSampleData == NULL) { \ + goto on_error; \ + } \ \ - newSampleDataBufferSize = sampleDataBufferSize * 2; \ - pNewSampleData = (type*)ma_dr_flac__realloc_from_callbacks(pSampleData, newSampleDataBufferSize, sampleDataBufferSize, &pFlac->allocationCallbacks); \ - if (pNewSampleData == NULL) { \ - ma_dr_flac__free_from_callbacks(pSampleData, &pFlac->allocationCallbacks); \ - goto on_error; \ - } \ + while ((pcmFramesRead = (ma_uint64)ma_dr_flac_read_pcm_frames_##extension(pFlac, sizeof(buffer)/sizeof(buffer[0])/pFlac->channels, buffer)) > 0) { \ + if (((totalPCMFrameCount + pcmFramesRead) * pFlac->channels * sizeof(type)) > sampleDataBufferSize) { \ + type* pNewSampleData; \ + size_t newSampleDataBufferSize; \ \ - sampleDataBufferSize = newSampleDataBufferSize; \ - pSampleData = pNewSampleData; \ + newSampleDataBufferSize = sampleDataBufferSize * 2; \ + pNewSampleData = (type*)ma_dr_flac__realloc_from_callbacks(pSampleData, newSampleDataBufferSize, sampleDataBufferSize, &pFlac->allocationCallbacks); \ + if (pNewSampleData == NULL) { \ + ma_dr_flac__free_from_callbacks(pSampleData, &pFlac->allocationCallbacks); \ + goto on_error; \ } \ \ - MA_DR_FLAC_COPY_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), buffer, (size_t)(pcmFramesRead*pFlac->channels*sizeof(type))); \ - totalPCMFrameCount += pcmFramesRead; \ - } \ - \ - \ - MA_DR_FLAC_ZERO_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), (size_t)(sampleDataBufferSize - totalPCMFrameCount*pFlac->channels*sizeof(type))); \ - } else { \ - ma_uint64 dataSize = totalPCMFrameCount*pFlac->channels*sizeof(type); \ - if (dataSize > (ma_uint64)MA_SIZE_MAX) { \ - goto on_error; \ - } \ - \ - pSampleData = (type*)ma_dr_flac__malloc_from_callbacks((size_t)dataSize, &pFlac->allocationCallbacks); \ - if (pSampleData == NULL) { \ - goto on_error; \ + sampleDataBufferSize = newSampleDataBufferSize; \ + pSampleData = pNewSampleData; \ } \ \ - totalPCMFrameCount = ma_dr_flac_read_pcm_frames_##extension(pFlac, pFlac->totalPCMFrameCount, pSampleData); \ + MA_DR_FLAC_COPY_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), buffer, (size_t)(pcmFramesRead*pFlac->channels*sizeof(type))); \ + totalPCMFrameCount += pcmFramesRead; \ } \ \ + \ + MA_DR_FLAC_ZERO_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), (size_t)(sampleDataBufferSize - totalPCMFrameCount*pFlac->channels*sizeof(type))); \ + \ if (sampleRateOut) *sampleRateOut = pFlac->sampleRate; \ if (channelsOut) *channelsOut = pFlac->channels; \ if (totalPCMFrameCountOut) *totalPCMFrameCountOut = totalPCMFrameCount; \ @@ -90410,7 +92432,7 @@ on_error: MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(s32, ma_int32) MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(s16, ma_int16) MA_DR_FLAC_DEFINE_FULL_READ_AND_CLOSE(f32, float) -MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90422,13 +92444,13 @@ MA_API ma_int32* ma_dr_flac_open_and_read_pcm_frames_s32(ma_dr_flac_read_proc on if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } return ma_dr_flac__full_read_and_close_s32(pFlac, channelsOut, sampleRateOut, totalPCMFrameCountOut); } -MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90440,13 +92462,13 @@ MA_API ma_int16* ma_dr_flac_open_and_read_pcm_frames_s16(ma_dr_flac_read_proc on if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } return ma_dr_flac__full_read_and_close_s16(pFlac, channelsOut, sampleRateOut, totalPCMFrameCountOut); } -MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRead, ma_dr_flac_seek_proc onSeek, ma_dr_flac_tell_proc onTell, void* pUserData, unsigned int* channelsOut, unsigned int* sampleRateOut, ma_uint64* totalPCMFrameCountOut, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_flac* pFlac; if (channelsOut) { @@ -90458,7 +92480,7 @@ MA_API float* ma_dr_flac_open_and_read_pcm_frames_f32(ma_dr_flac_read_proc onRea if (totalPCMFrameCountOut) { *totalPCMFrameCountOut = 0; } - pFlac = ma_dr_flac_open(onRead, onSeek, pUserData, pAllocationCallbacks); + pFlac = ma_dr_flac_open(onRead, onSeek, onTell, pUserData, pAllocationCallbacks); if (pFlac == NULL) { return NULL; } @@ -90680,12 +92702,9 @@ MA_API const char* ma_dr_mp3_version_string(void) #define MA_DR_MP3_NO_SIMD #endif #define MA_DR_MP3_OFFSET_PTR(p, offset) ((void*)((ma_uint8*)(p) + (offset))) -#define MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE 2304 #ifndef MA_DR_MP3_MAX_FRAME_SYNC_MATCHES #define MA_DR_MP3_MAX_FRAME_SYNC_MATCHES 10 #endif -#define MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES MA_DR_MP3_MAX_FREE_FORMAT_FRAME_SIZE -#define MA_DR_MP3_MAX_BITRESERVOIR_BYTES 511 #define MA_DR_MP3_SHORT_BLOCK_TYPE 2 #define MA_DR_MP3_STOP_BLOCK_TYPE 3 #define MA_DR_MP3_MODE_MONO 3 @@ -90735,7 +92754,7 @@ MA_API const char* ma_dr_mp3_version_string(void) #define MA_DR_MP3_VMUL_S(x, s) _mm_mul_ps(x, _mm_set1_ps(s)) #define MA_DR_MP3_VREV(x) _mm_shuffle_ps(x, x, _MM_SHUFFLE(0, 1, 2, 3)) typedef __m128 ma_dr_mp3_f4; -#if defined(_MSC_VER) || defined(MA_DR_MP3_ONLY_SIMD) +#if (defined(_MSC_VER) || defined(MA_DR_MP3_ONLY_SIMD)) && !defined(__clang__) #define ma_dr_mp3_cpuid __cpuid #else static __inline__ __attribute__((always_inline)) void ma_dr_mp3_cpuid(int CPUInfo[], const int InfoType) @@ -90851,11 +92870,6 @@ static __inline__ __attribute__((always_inline)) ma_int32 ma_dr_mp3_clip_int16_a #define MA_DR_MP3_FREE(p) free((p)) #endif typedef struct -{ - const ma_uint8 *buf; - int pos, limit; -} ma_dr_mp3_bs; -typedef struct { float scf[3*64]; ma_uint8 total_bands, stereo_bands, bitalloc[64], scfcod[64]; @@ -90864,22 +92878,6 @@ typedef struct { ma_uint8 tab_offset, code_tab_width, band_count; } ma_dr_mp3_L12_subband_alloc; -typedef struct -{ - const ma_uint8 *sfbtab; - ma_uint16 part_23_length, big_values, scalefac_compress; - ma_uint8 global_gain, block_type, mixed_block_flag, n_long_sfb, n_short_sfb; - ma_uint8 table_select[3], region_count[3], subblock_gain[3]; - ma_uint8 preflag, scalefac_scale, count1_table, scfsi; -} ma_dr_mp3_L3_gr_info; -typedef struct -{ - ma_dr_mp3_bs bs; - ma_uint8 maindata[MA_DR_MP3_MAX_BITRESERVOIR_BYTES + MA_DR_MP3_MAX_L3_FRAME_PAYLOAD_BYTES]; - ma_dr_mp3_L3_gr_info gr_info[4]; - float grbuf[2][576], scf[40], syn[18 + 15][2*32]; - ma_uint8 ist_pos[2][39]; -} ma_dr_mp3dec_scratch; static void ma_dr_mp3_bs_init(ma_dr_mp3_bs *bs, const ma_uint8 *data, int bytes) { bs->buf = data; @@ -91262,6 +93260,10 @@ static float ma_dr_mp3_L3_ldexp_q2(float y, int exp_q2) } while ((exp_q2 -= e) > 0); return y; } +#if (defined(__GNUC__) && (__GNUC__ >= 13)) && !defined(__clang__) + #pragma GCC diagnostic push + #pragma GCC diagnostic ignored "-Wstringop-overflow" +#endif static void ma_dr_mp3_L3_decode_scalefactors(const ma_uint8 *hdr, ma_uint8 *ist_pos, ma_dr_mp3_bs *bs, const ma_dr_mp3_L3_gr_info *gr, float *scf, int ch) { static const ma_uint8 g_scf_partitions[3][28] = { @@ -91320,7 +93322,10 @@ static void ma_dr_mp3_L3_decode_scalefactors(const ma_uint8 *hdr, ma_uint8 *ist_ scf[i] = ma_dr_mp3_L3_ldexp_q2(gain, iscf[i] << scf_shift); } } -static const float g_ma_dr_mp3_pow43[129 + 16] = { +#if (defined(__GNUC__) && (__GNUC__ >= 13)) && !defined(__clang__) + #pragma GCC diagnostic pop +#endif +static const float ma_dr_mp3_g_pow43[129 + 16] = { 0,-1,-2.519842f,-4.326749f,-6.349604f,-8.549880f,-10.902724f,-13.390518f,-16.000000f,-18.720754f,-21.544347f,-24.463781f,-27.473142f,-30.567351f,-33.741992f,-36.993181f, 0,1,2.519842f,4.326749f,6.349604f,8.549880f,10.902724f,13.390518f,16.000000f,18.720754f,21.544347f,24.463781f,27.473142f,30.567351f,33.741992f,36.993181f,40.317474f,43.711787f,47.173345f,50.699631f,54.288352f,57.937408f,61.644865f,65.408941f,69.227979f,73.100443f,77.024898f,81.000000f,85.024491f,89.097188f,93.216975f,97.382800f,101.593667f,105.848633f,110.146801f,114.487321f,118.869381f,123.292209f,127.755065f,132.257246f,136.798076f,141.376907f,145.993119f,150.646117f,155.335327f,160.060199f,164.820202f,169.614826f,174.443577f,179.305980f,184.201575f,189.129918f,194.090580f,199.083145f,204.107210f,209.162385f,214.248292f,219.364564f,224.510845f,229.686789f,234.892058f,240.126328f,245.389280f,250.680604f,256.000000f,261.347174f,266.721841f,272.123723f,277.552547f,283.008049f,288.489971f,293.998060f,299.532071f,305.091761f,310.676898f,316.287249f,321.922592f,327.582707f,333.267377f,338.976394f,344.709550f,350.466646f,356.247482f,362.051866f,367.879608f,373.730522f,379.604427f,385.501143f,391.420496f,397.362314f,403.326427f,409.312672f,415.320884f,421.350905f,427.402579f,433.475750f,439.570269f,445.685987f,451.822757f,457.980436f,464.158883f,470.357960f,476.577530f,482.817459f,489.077615f,495.357868f,501.658090f,507.978156f,514.317941f,520.677324f,527.056184f,533.454404f,539.871867f,546.308458f,552.764065f,559.238575f,565.731879f,572.243870f,578.774440f,585.323483f,591.890898f,598.476581f,605.080431f,611.702349f,618.342238f,625.000000f,631.675540f,638.368763f,645.079578f }; @@ -91330,7 +93335,7 @@ static float ma_dr_mp3_L3_pow_43(int x) int sign, mult = 256; if (x < 129) { - return g_ma_dr_mp3_pow43[16 + x]; + return ma_dr_mp3_g_pow43[16 + x]; } if (x < 1024) { @@ -91339,7 +93344,7 @@ static float ma_dr_mp3_L3_pow_43(int x) } sign = 2*x & 64; frac = (float)((x & 63) - sign) / ((x & ~63) + sign); - return g_ma_dr_mp3_pow43[16 + ((x + sign) >> 6)]*(1.f + frac*((4.f/3) + frac*(2.f/9)))*mult; + return ma_dr_mp3_g_pow43[16 + ((x + sign) >> 6)]*(1.f + frac*((4.f/3) + frac*(2.f/9)))*mult; } static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L3_gr_info *gr_info, const float *scf, int layer3gr_limit) { @@ -91409,7 +93414,7 @@ static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L *dst = one*ma_dr_mp3_L3_pow_43(lsb)*((ma_int32)bs_cache < 0 ? -1: 1); } else { - *dst = g_ma_dr_mp3_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; + *dst = ma_dr_mp3_g_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; } MA_DR_MP3_FLUSH_BITS(lsb ? 1 : 0); } @@ -91437,7 +93442,7 @@ static void ma_dr_mp3_L3_huffman(float *dst, ma_dr_mp3_bs *bs, const ma_dr_mp3_L for (j = 0; j < 2; j++, dst++, leaf >>= 4) { int lsb = leaf & 0x0F; - *dst = g_ma_dr_mp3_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; + *dst = ma_dr_mp3_g_pow43[16 + lsb - 16*(bs_cache >> 31)]*one; MA_DR_MP3_FLUSH_BITS(lsb ? 1 : 0); } MA_DR_MP3_CHECK_BITS; @@ -92245,7 +94250,6 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int int i = 0, igr, frame_size = 0, success = 1; const ma_uint8 *hdr; ma_dr_mp3_bs bs_frame[1]; - ma_dr_mp3dec_scratch scratch; if (mp3_bytes > 4 && dec->header[0] == 0xff && ma_dr_mp3_hdr_compare(dec->header, mp3)) { frame_size = ma_dr_mp3_hdr_frame_bytes(mp3, dec->free_format_bytes) + ma_dr_mp3_hdr_padding(mp3); @@ -92268,7 +94272,7 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int MA_DR_MP3_COPY_MEMORY(dec->header, hdr, MA_DR_MP3_HDR_SIZE); info->frame_bytes = i + frame_size; info->channels = MA_DR_MP3_HDR_IS_MONO(hdr) ? 1 : 2; - info->hz = ma_dr_mp3_hdr_sample_rate_hz(hdr); + info->sample_rate = ma_dr_mp3_hdr_sample_rate_hz(hdr); info->layer = 4 - MA_DR_MP3_HDR_GET_LAYER(hdr); info->bitrate_kbps = ma_dr_mp3_hdr_bitrate_kbps(hdr); ma_dr_mp3_bs_init(bs_frame, hdr + MA_DR_MP3_HDR_SIZE, frame_size - MA_DR_MP3_HDR_SIZE); @@ -92278,23 +94282,23 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int } if (info->layer == 3) { - int main_data_begin = ma_dr_mp3_L3_read_side_info(bs_frame, scratch.gr_info, hdr); + int main_data_begin = ma_dr_mp3_L3_read_side_info(bs_frame, dec->scratch.gr_info, hdr); if (main_data_begin < 0 || bs_frame->pos > bs_frame->limit) { ma_dr_mp3dec_init(dec); return 0; } - success = ma_dr_mp3_L3_restore_reservoir(dec, bs_frame, &scratch, main_data_begin); + success = ma_dr_mp3_L3_restore_reservoir(dec, bs_frame, &dec->scratch, main_data_begin); if (success && pcm != NULL) { for (igr = 0; igr < (MA_DR_MP3_HDR_TEST_MPEG1(hdr) ? 2 : 1); igr++, pcm = MA_DR_MP3_OFFSET_PTR(pcm, sizeof(ma_dr_mp3d_sample_t)*576*info->channels)) { - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); - ma_dr_mp3_L3_decode(dec, &scratch, scratch.gr_info + igr*info->channels, info->channels); - ma_dr_mp3d_synth_granule(dec->qmf_state, scratch.grbuf[0], 18, info->channels, (ma_dr_mp3d_sample_t*)pcm, scratch.syn[0]); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); + ma_dr_mp3_L3_decode(dec, &dec->scratch, dec->scratch.gr_info + igr*info->channels, info->channels); + ma_dr_mp3d_synth_granule(dec->qmf_state, dec->scratch.grbuf[0], 18, info->channels, (ma_dr_mp3d_sample_t*)pcm, dec->scratch.syn[0]); } } - ma_dr_mp3_L3_save_reservoir(dec, &scratch); + ma_dr_mp3_L3_save_reservoir(dec, &dec->scratch); } else { #ifdef MA_DR_MP3_ONLY_MP3 @@ -92305,15 +94309,15 @@ MA_API int ma_dr_mp3dec_decode_frame(ma_dr_mp3dec *dec, const ma_uint8 *mp3, int return ma_dr_mp3_hdr_frame_samples(hdr); } ma_dr_mp3_L12_read_scale_info(hdr, bs_frame, sci); - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); for (i = 0, igr = 0; igr < 3; igr++) { - if (12 == (i += ma_dr_mp3_L12_dequantize_granule(scratch.grbuf[0] + i, bs_frame, sci, info->layer | 1))) + if (12 == (i += ma_dr_mp3_L12_dequantize_granule(dec->scratch.grbuf[0] + i, bs_frame, sci, info->layer | 1))) { i = 0; - ma_dr_mp3_L12_apply_scf_384(sci, sci->scf + igr, scratch.grbuf[0]); - ma_dr_mp3d_synth_granule(dec->qmf_state, scratch.grbuf[0], 12, info->channels, (ma_dr_mp3d_sample_t*)pcm, scratch.syn[0]); - MA_DR_MP3_ZERO_MEMORY(scratch.grbuf[0], 576*2*sizeof(float)); + ma_dr_mp3_L12_apply_scf_384(sci, sci->scf + igr, dec->scratch.grbuf[0]); + ma_dr_mp3d_synth_granule(dec->qmf_state, dec->scratch.grbuf[0], 12, info->channels, (ma_dr_mp3d_sample_t*)pcm, dec->scratch.syn[0]); + MA_DR_MP3_ZERO_MEMORY(dec->scratch.grbuf[0], 576*2*sizeof(float)); pcm = MA_DR_MP3_OFFSET_PTR(pcm, sizeof(ma_dr_mp3d_sample_t)*384*info->channels); } if (bs_frame->pos > bs_frame->limit) @@ -92491,19 +94495,41 @@ static ma_allocation_callbacks ma_dr_mp3_copy_allocation_callbacks_or_defaults(c } static size_t ma_dr_mp3__on_read(ma_dr_mp3* pMP3, void* pBufferOut, size_t bytesToRead) { - size_t bytesRead = pMP3->onRead(pMP3->pUserData, pBufferOut, bytesToRead); + size_t bytesRead; + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pMP3->onRead != NULL); + if (bytesToRead == 0) { + return 0; + } + bytesRead = pMP3->onRead(pMP3->pUserData, pBufferOut, bytesToRead); pMP3->streamCursor += bytesRead; return bytesRead; } +static size_t ma_dr_mp3__on_read_clamped(ma_dr_mp3* pMP3, void* pBufferOut, size_t bytesToRead) +{ + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pMP3->onRead != NULL); + if (pMP3->streamLength == MA_UINT64_MAX) { + return ma_dr_mp3__on_read(pMP3, pBufferOut, bytesToRead); + } else { + ma_uint64 bytesRemaining; + bytesRemaining = (pMP3->streamLength - pMP3->streamCursor); + if (bytesToRead > bytesRemaining) { + bytesToRead = (size_t)bytesRemaining; + } + return ma_dr_mp3__on_read(pMP3, pBufferOut, bytesToRead); + } +} static ma_bool32 ma_dr_mp3__on_seek(ma_dr_mp3* pMP3, int offset, ma_dr_mp3_seek_origin origin) { MA_DR_MP3_ASSERT(offset >= 0); + MA_DR_MP3_ASSERT(origin == MA_DR_MP3_SEEK_SET || origin == MA_DR_MP3_SEEK_CUR); if (!pMP3->onSeek(pMP3->pUserData, offset, origin)) { return MA_FALSE; } - if (origin == ma_dr_mp3_seek_origin_start) { + if (origin == MA_DR_MP3_SEEK_SET) { pMP3->streamCursor = (ma_uint64)offset; - } else { + } else{ pMP3->streamCursor += offset; } return MA_TRUE; @@ -92513,18 +94539,18 @@ static ma_bool32 ma_dr_mp3__on_seek_64(ma_dr_mp3* pMP3, ma_uint64 offset, ma_dr_ if (offset <= 0x7FFFFFFF) { return ma_dr_mp3__on_seek(pMP3, (int)offset, origin); } - if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } offset -= 0x7FFFFFFF; while (offset > 0) { if (offset <= 0x7FFFFFFF) { - if (!ma_dr_mp3__on_seek(pMP3, (int)offset, ma_dr_mp3_seek_origin_current)) { + if (!ma_dr_mp3__on_seek(pMP3, (int)offset, MA_DR_MP3_SEEK_CUR)) { return MA_FALSE; } offset = 0; } else { - if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, ma_dr_mp3_seek_origin_current)) { + if (!ma_dr_mp3__on_seek(pMP3, 0x7FFFFFFF, MA_DR_MP3_SEEK_CUR)) { return MA_FALSE; } offset -= 0x7FFFFFFF; @@ -92532,7 +94558,18 @@ static ma_bool32 ma_dr_mp3__on_seek_64(ma_dr_mp3* pMP3, ma_uint64 offset, ma_dr_ } return MA_TRUE; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static void ma_dr_mp3__on_meta(ma_dr_mp3* pMP3, ma_dr_mp3_metadata_type type, const void* pRawData, size_t rawDataSize) +{ + if (pMP3->onMeta) { + ma_dr_mp3_metadata metadata; + MA_DR_MP3_ZERO_OBJECT(&metadata); + metadata.type = type; + metadata.pRawData = pRawData; + metadata.rawDataSize = rawDataSize; + pMP3->onMeta(pMP3->pUserDataMeta, &metadata); + } +} +static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { ma_uint32 pcmFramesRead = 0; MA_DR_MP3_ASSERT(pMP3 != NULL); @@ -92559,7 +94596,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d pMP3->pData = pNewData; pMP3->dataCapacity = newDataCap; } - bytesRead = ma_dr_mp3__on_read(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); + bytesRead = ma_dr_mp3__on_read_clamped(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); if (bytesRead == 0) { if (pMP3->dataSize == 0) { pMP3->atEnd = MA_TRUE; @@ -92578,16 +94615,20 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d return 0; } pcmFramesRead = ma_dr_mp3dec_decode_frame(&pMP3->decoder, pMP3->pData + pMP3->dataConsumed, (int)pMP3->dataSize, pPCMFrames, &info); - if (info.frame_bytes > 0) { - pMP3->dataConsumed += (size_t)info.frame_bytes; - pMP3->dataSize -= (size_t)info.frame_bytes; - } + pMP3->dataConsumed += (size_t)info.frame_bytes; + pMP3->dataSize -= (size_t)info.frame_bytes; if (pcmFramesRead > 0) { pcmFramesRead = ma_dr_mp3_hdr_frame_samples(pMP3->decoder.header); pMP3->pcmFramesConsumedInMP3Frame = 0; pMP3->pcmFramesRemainingInMP3Frame = pcmFramesRead; pMP3->mp3FrameChannels = info.channels; - pMP3->mp3FrameSampleRate = info.hz; + pMP3->mp3FrameSampleRate = info.sample_rate; + if (pMP3FrameInfo != NULL) { + *pMP3FrameInfo = info; + } + if (ppMP3FrameData != NULL) { + *ppMP3FrameData = pMP3->pData + pMP3->dataConsumed - (size_t)info.frame_bytes; + } break; } else if (info.frame_bytes == 0) { size_t bytesRead; @@ -92604,7 +94645,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d pMP3->pData = pNewData; pMP3->dataCapacity = newDataCap; } - bytesRead = ma_dr_mp3__on_read(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); + bytesRead = ma_dr_mp3__on_read_clamped(pMP3, pMP3->pData + pMP3->dataSize, (pMP3->dataCapacity - pMP3->dataSize)); if (bytesRead == 0) { pMP3->atEnd = MA_TRUE; return 0; @@ -92614,7 +94655,7 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__callbacks(ma_dr_mp3* pMP3, ma_d }; return pcmFramesRead; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { ma_uint32 pcmFramesRead = 0; ma_dr_mp3dec_frame_info info; @@ -92630,36 +94671,44 @@ static ma_uint32 ma_dr_mp3_decode_next_frame_ex__memory(ma_dr_mp3* pMP3, ma_dr_m pMP3->pcmFramesConsumedInMP3Frame = 0; pMP3->pcmFramesRemainingInMP3Frame = pcmFramesRead; pMP3->mp3FrameChannels = info.channels; - pMP3->mp3FrameSampleRate = info.hz; + pMP3->mp3FrameSampleRate = info.sample_rate; + if (pMP3FrameInfo != NULL) { + *pMP3FrameInfo = info; + } + if (ppMP3FrameData != NULL) { + *ppMP3FrameData = pMP3->memory.pData + pMP3->memory.currentReadPos; + } break; } else if (info.frame_bytes > 0) { pMP3->memory.currentReadPos += (size_t)info.frame_bytes; + pMP3->streamCursor += (size_t)info.frame_bytes; } else { break; } } pMP3->memory.currentReadPos += (size_t)info.frame_bytes; + pMP3->streamCursor += (size_t)info.frame_bytes; return pcmFramesRead; } -static ma_uint32 ma_dr_mp3_decode_next_frame_ex(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames) +static ma_uint32 ma_dr_mp3_decode_next_frame_ex(ma_dr_mp3* pMP3, ma_dr_mp3d_sample_t* pPCMFrames, ma_dr_mp3dec_frame_info* pMP3FrameInfo, const ma_uint8** ppMP3FrameData) { if (pMP3->memory.pData != NULL && pMP3->memory.dataSize > 0) { - return ma_dr_mp3_decode_next_frame_ex__memory(pMP3, pPCMFrames); + return ma_dr_mp3_decode_next_frame_ex__memory(pMP3, pPCMFrames, pMP3FrameInfo, ppMP3FrameData); } else { - return ma_dr_mp3_decode_next_frame_ex__callbacks(pMP3, pPCMFrames); + return ma_dr_mp3_decode_next_frame_ex__callbacks(pMP3, pPCMFrames, pMP3FrameInfo, ppMP3FrameData); } } static ma_uint32 ma_dr_mp3_decode_next_frame(ma_dr_mp3* pMP3) { MA_DR_MP3_ASSERT(pMP3 != NULL); - return ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames); + return ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames, NULL, NULL); } #if 0 static ma_uint32 ma_dr_mp3_seek_next_frame(ma_dr_mp3* pMP3) { ma_uint32 pcmFrameCount; MA_DR_MP3_ASSERT(pMP3 != NULL); - pcmFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFrameCount == 0) { return 0; } @@ -92669,33 +94718,252 @@ static ma_uint32 ma_dr_mp3_seek_next_frame(ma_dr_mp3* pMP3) return pcmFrameCount; } #endif -static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { + ma_dr_mp3dec_frame_info firstFrameInfo; + const ma_uint8* pFirstFrameData; + ma_uint32 firstFramePCMFrameCount; + ma_uint32 detectedMP3FrameCount = 0xFFFFFFFF; MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(onRead != NULL); ma_dr_mp3dec_init(&pMP3->decoder); pMP3->onRead = onRead; pMP3->onSeek = onSeek; + pMP3->onMeta = onMeta; pMP3->pUserData = pUserData; + pMP3->pUserDataMeta = pUserDataMeta; pMP3->allocationCallbacks = ma_dr_mp3_copy_allocation_callbacks_or_defaults(pAllocationCallbacks); if (pMP3->allocationCallbacks.onFree == NULL || (pMP3->allocationCallbacks.onMalloc == NULL && pMP3->allocationCallbacks.onRealloc == NULL)) { return MA_FALSE; } - if (ma_dr_mp3_decode_next_frame(pMP3) == 0) { + pMP3->streamCursor = 0; + pMP3->streamLength = MA_UINT64_MAX; + pMP3->streamStartOffset = 0; + pMP3->delayInPCMFrames = 0; + pMP3->paddingInPCMFrames = 0; + pMP3->totalPCMFrameCount = MA_UINT64_MAX; + #if 1 + if (onSeek != NULL && onTell != NULL) { + if (onSeek(pUserData, 0, MA_DR_MP3_SEEK_END)) { + ma_int64 streamLen; + int streamEndOffset = 0; + if (onTell(pUserData, &streamLen)) { + if (streamLen > 128) { + char id3[3]; + if (onSeek(pUserData, streamEndOffset - 128, MA_DR_MP3_SEEK_END)) { + if (onRead(pUserData, id3, 3) == 3 && id3[0] == 'T' && id3[1] == 'A' && id3[2] == 'G') { + streamEndOffset -= 128; + streamLen -= 128; + if (onMeta != NULL) { + ma_uint8 tag[128]; + tag[0] = 'T'; tag[1] = 'A'; tag[2] = 'G'; + if (onRead(pUserData, tag + 3, 125) == 125) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_ID3V1, tag, 128); + } + } + } else { + } + } else { + } + } else { + } + if (streamLen > 32) { + char ape[32]; + if (onSeek(pUserData, streamEndOffset - 32, MA_DR_MP3_SEEK_END)) { + if (onRead(pUserData, ape, 32) == 32 && ape[0] == 'A' && ape[1] == 'P' && ape[2] == 'E' && ape[3] == 'T' && ape[4] == 'A' && ape[5] == 'G' && ape[6] == 'E' && ape[7] == 'X') { + ma_uint32 tagSize = + ((ma_uint32)ape[24] << 0) | + ((ma_uint32)ape[25] << 8) | + ((ma_uint32)ape[26] << 16) | + ((ma_uint32)ape[27] << 24); + if (32 + tagSize < streamLen) { + streamEndOffset -= 32 + tagSize; + streamLen -= 32 + tagSize; + if (onMeta != NULL) { + if (onSeek(pUserData, streamEndOffset, MA_DR_MP3_SEEK_END)) { + size_t apeTagSize = (size_t)tagSize + 32; + ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(apeTagSize, pAllocationCallbacks); + if (pTagData != NULL) { + if (onRead(pUserData, pTagData, apeTagSize) == apeTagSize) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_APE, pTagData, apeTagSize); + } + ma_dr_mp3_free(pTagData, pAllocationCallbacks); + } + } + } + } else { + } + } + } + } else { + } + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + pMP3->streamLength = (ma_uint64)streamLen; + if (pMP3->memory.pData != NULL) { + pMP3->memory.dataSize = (size_t)pMP3->streamLength; + } + } else { + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + } + } else { + } + } else { + } + #endif + #if 1 + { + char header[10]; + if (onRead(pUserData, header, 10) == 10) { + if (header[0] == 'I' && header[1] == 'D' && header[2] == '3') { + ma_uint32 tagSize = + (((ma_uint32)header[6] & 0x7F) << 21) | + (((ma_uint32)header[7] & 0x7F) << 14) | + (((ma_uint32)header[8] & 0x7F) << 7) | + (((ma_uint32)header[9] & 0x7F) << 0); + if (header[5] & 0x10) { + tagSize += 10; + } + if (onMeta != NULL) { + size_t tagSizeWithHeader = 10 + tagSize; + ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(tagSizeWithHeader, pAllocationCallbacks); + if (pTagData != NULL) { + MA_DR_MP3_COPY_MEMORY(pTagData, header, 10); + if (onRead(pUserData, pTagData + 10, tagSize) == tagSize) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_ID3V2, pTagData, tagSizeWithHeader); + } + ma_dr_mp3_free(pTagData, pAllocationCallbacks); + } + } else { + if (onSeek != NULL) { + if (!onSeek(pUserData, tagSize, MA_DR_MP3_SEEK_CUR)) { + return MA_FALSE; + } + } else { + char discard[1024]; + while (tagSize > 0) { + size_t bytesToRead = tagSize; + if (bytesToRead > sizeof(discard)) { + bytesToRead = sizeof(discard); + } + if (onRead(pUserData, discard, bytesToRead) != bytesToRead) { + return MA_FALSE; + } + tagSize -= (ma_uint32)bytesToRead; + } + } + } + pMP3->streamStartOffset += 10 + tagSize; + pMP3->streamCursor = pMP3->streamStartOffset; + } else { + if (onSeek != NULL) { + if (!onSeek(pUserData, 0, MA_DR_MP3_SEEK_SET)) { + return MA_FALSE; + } + } else { + } + } + } else { + return MA_FALSE; + } + } + #endif + firstFramePCMFrameCount = ma_dr_mp3_decode_next_frame_ex(pMP3, (ma_dr_mp3d_sample_t*)pMP3->pcmFrames, &firstFrameInfo, &pFirstFrameData); + if (firstFramePCMFrameCount > 0) { + MA_DR_MP3_ASSERT(pFirstFrameData != NULL); + #if 1 + MA_DR_MP3_ASSERT(firstFrameInfo.frame_bytes > 0); + { + ma_dr_mp3_bs bs; + ma_dr_mp3_L3_gr_info grInfo[4]; + ma_dr_mp3_bs_init(&bs, pFirstFrameData + MA_DR_MP3_HDR_SIZE, firstFrameInfo.frame_bytes - MA_DR_MP3_HDR_SIZE); + if (MA_DR_MP3_HDR_IS_CRC(pFirstFrameData)) { + ma_dr_mp3_bs_get_bits(&bs, 16); + } + if (ma_dr_mp3_L3_read_side_info(&bs, grInfo, pFirstFrameData) >= 0) { + ma_bool32 isXing = MA_FALSE; + ma_bool32 isInfo = MA_FALSE; + const ma_uint8* pTagData; + const ma_uint8* pTagDataBeg; + pTagDataBeg = pFirstFrameData + MA_DR_MP3_HDR_SIZE + (bs.pos/8); + pTagData = pTagDataBeg; + isXing = (pTagData[0] == 'X' && pTagData[1] == 'i' && pTagData[2] == 'n' && pTagData[3] == 'g'); + isInfo = (pTagData[0] == 'I' && pTagData[1] == 'n' && pTagData[2] == 'f' && pTagData[3] == 'o'); + if (isXing || isInfo) { + ma_uint32 bytes = 0; + ma_uint32 flags = pTagData[7]; + pTagData += 8; + if (flags & 0x01) { + detectedMP3FrameCount = (ma_uint32)pTagData[0] << 24 | (ma_uint32)pTagData[1] << 16 | (ma_uint32)pTagData[2] << 8 | (ma_uint32)pTagData[3]; + pTagData += 4; + } + if (flags & 0x02) { + bytes = (ma_uint32)pTagData[0] << 24 | (ma_uint32)pTagData[1] << 16 | (ma_uint32)pTagData[2] << 8 | (ma_uint32)pTagData[3]; + (void)bytes; + pTagData += 4; + } + if (flags & 0x04) { + pTagData += 100; + } + if (flags & 0x08) { + pTagData += 4; + } + if (pTagData[0]) { + pTagData += 21; + if (pTagData - pFirstFrameData + 14 < firstFrameInfo.frame_bytes) { + int delayInPCMFrames; + int paddingInPCMFrames; + delayInPCMFrames = (( (ma_uint32)pTagData[0] << 4) | ((ma_uint32)pTagData[1] >> 4)) + (528 + 1); + paddingInPCMFrames = ((((ma_uint32)pTagData[1] & 0xF) << 8) | ((ma_uint32)pTagData[2] )) - (528 + 1); + if (paddingInPCMFrames < 0) { + paddingInPCMFrames = 0; + } + pMP3->delayInPCMFrames = (ma_uint32)delayInPCMFrames; + pMP3->paddingInPCMFrames = (ma_uint32)paddingInPCMFrames; + } + } + if (isXing) { + pMP3->isVBR = MA_TRUE; + } else if (isInfo) { + pMP3->isCBR = MA_TRUE; + } + if (onMeta != NULL) { + ma_dr_mp3_metadata_type metadataType = isXing ? MA_DR_MP3_METADATA_TYPE_XING : MA_DR_MP3_METADATA_TYPE_VBRI; + size_t tagDataSize; + tagDataSize = (size_t)firstFrameInfo.frame_bytes; + tagDataSize -= (size_t)(pTagDataBeg - pFirstFrameData); + ma_dr_mp3__on_meta(pMP3, metadataType, pTagDataBeg, tagDataSize); + } + pMP3->pcmFramesRemainingInMP3Frame = 0; + pMP3->streamStartOffset += (ma_uint32)(firstFrameInfo.frame_bytes); + pMP3->streamCursor = pMP3->streamStartOffset; + ma_dr_mp3dec_init(&pMP3->decoder); + } + } else { + } + } + #endif + } else { ma_dr_mp3__free_from_callbacks(pMP3->pData, &pMP3->allocationCallbacks); return MA_FALSE; } + if (detectedMP3FrameCount != 0xFFFFFFFF) { + pMP3->totalPCMFrameCount = detectedMP3FrameCount * firstFramePCMFrameCount; + } pMP3->channels = pMP3->mp3FrameChannels; pMP3->sampleRate = pMP3->mp3FrameSampleRate; return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_mp3_init(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, ma_dr_mp3_meta_proc onMeta, void* pUserData, const ma_allocation_callbacks* pAllocationCallbacks) { if (pMP3 == NULL || onRead == NULL) { return MA_FALSE; } MA_DR_MP3_ZERO_OBJECT(pMP3); - return ma_dr_mp3_init_internal(pMP3, onRead, onSeek, pUserData, pAllocationCallbacks); + return ma_dr_mp3_init_internal(pMP3, onRead, onSeek, onTell, onMeta, pUserData, pUserData, pAllocationCallbacks); } static size_t ma_dr_mp3__on_read_memory(void* pUserData, void* pBufferOut, size_t bytesToRead) { @@ -92716,29 +94984,39 @@ static size_t ma_dr_mp3__on_read_memory(void* pUserData, void* pBufferOut, size_ static ma_bool32 ma_dr_mp3__on_seek_memory(void* pUserData, int byteOffset, ma_dr_mp3_seek_origin origin) { ma_dr_mp3* pMP3 = (ma_dr_mp3*)pUserData; + ma_int64 newCursor; MA_DR_MP3_ASSERT(pMP3 != NULL); - if (origin == ma_dr_mp3_seek_origin_current) { - if (byteOffset > 0) { - if (pMP3->memory.currentReadPos + byteOffset > pMP3->memory.dataSize) { - byteOffset = (int)(pMP3->memory.dataSize - pMP3->memory.currentReadPos); - } - } else { - if (pMP3->memory.currentReadPos < (size_t)-byteOffset) { - byteOffset = -(int)pMP3->memory.currentReadPos; - } - } - pMP3->memory.currentReadPos += byteOffset; + if (origin == MA_DR_MP3_SEEK_SET) { + newCursor = 0; + } else if (origin == MA_DR_MP3_SEEK_CUR) { + newCursor = (ma_int64)pMP3->memory.currentReadPos; + } else if (origin == MA_DR_MP3_SEEK_END) { + newCursor = (ma_int64)pMP3->memory.dataSize; } else { - if ((ma_uint32)byteOffset <= pMP3->memory.dataSize) { - pMP3->memory.currentReadPos = byteOffset; - } else { - pMP3->memory.currentReadPos = pMP3->memory.dataSize; - } + MA_DR_MP3_ASSERT(!"Invalid seek origin"); + return MA_FALSE; } + newCursor += byteOffset; + if (newCursor < 0) { + return MA_FALSE; + } + if ((size_t)newCursor > pMP3->memory.dataSize) { + return MA_FALSE; + } + pMP3->memory.currentReadPos = (size_t)newCursor; return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3__on_tell_memory(void* pUserData, ma_int64* pCursor) { + ma_dr_mp3* pMP3 = (ma_dr_mp3*)pUserData; + MA_DR_MP3_ASSERT(pMP3 != NULL); + MA_DR_MP3_ASSERT(pCursor != NULL); + *pCursor = (ma_int64)pMP3->memory.currentReadPos; + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_memory_with_metadata(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) +{ + ma_bool32 result; if (pMP3 == NULL) { return MA_FALSE; } @@ -92749,7 +95027,21 @@ MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_ pMP3->memory.pData = (const ma_uint8*)pData; pMP3->memory.dataSize = dataSize; pMP3->memory.currentReadPos = 0; - return ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_memory, ma_dr_mp3__on_seek_memory, pMP3, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_memory, ma_dr_mp3__on_seek_memory, ma_dr_mp3__on_tell_memory, onMeta, pMP3, pUserDataMeta, pAllocationCallbacks); + if (result == MA_FALSE) { + return MA_FALSE; + } + if (pMP3->streamLength <= (ma_uint64)MA_SIZE_MAX) { + pMP3->memory.dataSize = (size_t)pMP3->streamLength; + } + if (pMP3->streamStartOffset > (ma_uint64)MA_SIZE_MAX) { + return MA_FALSE; + } + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_memory(ma_dr_mp3* pMP3, const void* pData, size_t dataSize, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_memory_with_metadata(pMP3, pData, dataSize, NULL, NULL, pAllocationCallbacks); } #ifndef MA_DR_MP3_NO_STDIO #include <stdio.h> @@ -92760,36 +95052,76 @@ static size_t ma_dr_mp3__on_read_stdio(void* pUserData, void* pBufferOut, size_t } static ma_bool32 ma_dr_mp3__on_seek_stdio(void* pUserData, int offset, ma_dr_mp3_seek_origin origin) { - return fseek((FILE*)pUserData, offset, (origin == ma_dr_mp3_seek_origin_current) ? SEEK_CUR : SEEK_SET) == 0; + int whence = SEEK_SET; + if (origin == MA_DR_MP3_SEEK_CUR) { + whence = SEEK_CUR; + } else if (origin == MA_DR_MP3_SEEK_END) { + whence = SEEK_END; + } + return fseek((FILE*)pUserData, offset, whence) == 0; } -MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +static ma_bool32 ma_dr_mp3__on_tell_stdio(void* pUserData, ma_int64* pCursor) +{ + FILE* pFileStdio = (FILE*)pUserData; + ma_int64 result; + MA_DR_MP3_ASSERT(pFileStdio != NULL); + MA_DR_MP3_ASSERT(pCursor != NULL); +#if defined(_WIN32) && !defined(NXDK) + #if defined(_MSC_VER) && _MSC_VER > 1200 + result = _ftelli64(pFileStdio); + #else + result = ftell(pFileStdio); + #endif +#else + result = ftell(pFileStdio); +#endif + *pCursor = result; + return MA_TRUE; +} +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata(ma_dr_mp3* pMP3, const char* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; FILE* pFile; + if (pMP3 == NULL) { + return MA_FALSE; + } + MA_DR_MP3_ZERO_OBJECT(pMP3); if (ma_fopen(&pFile, pFilePath, "rb") != MA_SUCCESS) { return MA_FALSE; } - result = ma_dr_mp3_init(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, ma_dr_mp3__on_tell_stdio, onMeta, (void*)pFile, pUserDataMeta, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; } return MA_TRUE; } -MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_bool32 ma_dr_mp3_init_file_with_metadata_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, ma_dr_mp3_meta_proc onMeta, void* pUserDataMeta, const ma_allocation_callbacks* pAllocationCallbacks) { ma_bool32 result; FILE* pFile; + if (pMP3 == NULL) { + return MA_FALSE; + } + MA_DR_MP3_ZERO_OBJECT(pMP3); if (ma_wfopen(&pFile, pFilePath, L"rb", pAllocationCallbacks) != MA_SUCCESS) { return MA_FALSE; } - result = ma_dr_mp3_init(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, (void*)pFile, pAllocationCallbacks); + result = ma_dr_mp3_init_internal(pMP3, ma_dr_mp3__on_read_stdio, ma_dr_mp3__on_seek_stdio, ma_dr_mp3__on_tell_stdio, onMeta, (void*)pFile, pUserDataMeta, pAllocationCallbacks); if (result != MA_TRUE) { fclose(pFile); return result; } return MA_TRUE; } +MA_API ma_bool32 ma_dr_mp3_init_file(ma_dr_mp3* pMP3, const char* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_file_with_metadata(pMP3, pFilePath, NULL, NULL, pAllocationCallbacks); +} +MA_API ma_bool32 ma_dr_mp3_init_file_w(ma_dr_mp3* pMP3, const wchar_t* pFilePath, const ma_allocation_callbacks* pAllocationCallbacks) +{ + return ma_dr_mp3_init_file_with_metadata_w(pMP3, pFilePath, NULL, NULL, pAllocationCallbacks); +} #endif MA_API void ma_dr_mp3_uninit(ma_dr_mp3* pMP3) { @@ -92859,17 +95191,38 @@ static ma_uint64 ma_dr_mp3_read_pcm_frames_raw(ma_dr_mp3* pMP3, ma_uint64 frames MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(pMP3->onRead != NULL); while (framesToRead > 0) { - ma_uint32 framesToConsume = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, framesToRead); + ma_uint32 framesToConsume; + if (pMP3->currentPCMFrame < pMP3->delayInPCMFrames) { + ma_uint32 framesToSkip = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, pMP3->delayInPCMFrames - pMP3->currentPCMFrame); + pMP3->currentPCMFrame += framesToSkip; + pMP3->pcmFramesConsumedInMP3Frame += framesToSkip; + pMP3->pcmFramesRemainingInMP3Frame -= framesToSkip; + } + framesToConsume = (ma_uint32)MA_DR_MP3_MIN(pMP3->pcmFramesRemainingInMP3Frame, framesToRead); + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX && pMP3->totalPCMFrameCount > pMP3->paddingInPCMFrames) { + if (pMP3->currentPCMFrame < (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames)) { + ma_uint64 framesRemainigToPadding = (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames) - pMP3->currentPCMFrame; + if (framesToConsume > framesRemainigToPadding) { + framesToConsume = (ma_uint32)framesRemainigToPadding; + } + } else { + break; + } + } if (pBufferOut != NULL) { - #if defined(MA_DR_MP3_FLOAT_OUTPUT) - float* pFramesOutF32 = (float*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(float) * totalFramesRead * pMP3->channels); - float* pFramesInF32 = (float*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(float) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); - MA_DR_MP3_COPY_MEMORY(pFramesOutF32, pFramesInF32, sizeof(float) * framesToConsume * pMP3->channels); - #else - ma_int16* pFramesOutS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(ma_int16) * totalFramesRead * pMP3->channels); - ma_int16* pFramesInS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(ma_int16) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); - MA_DR_MP3_COPY_MEMORY(pFramesOutS16, pFramesInS16, sizeof(ma_int16) * framesToConsume * pMP3->channels); - #endif + #if defined(MA_DR_MP3_FLOAT_OUTPUT) + { + float* pFramesOutF32 = (float*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(float) * totalFramesRead * pMP3->channels); + float* pFramesInF32 = (float*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(float) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); + MA_DR_MP3_COPY_MEMORY(pFramesOutF32, pFramesInF32, sizeof(float) * framesToConsume * pMP3->channels); + } + #else + { + ma_int16* pFramesOutS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(pBufferOut, sizeof(ma_int16) * totalFramesRead * pMP3->channels); + ma_int16* pFramesInS16 = (ma_int16*)MA_DR_MP3_OFFSET_PTR(&pMP3->pcmFrames[0], sizeof(ma_int16) * pMP3->pcmFramesConsumedInMP3Frame * pMP3->mp3FrameChannels); + MA_DR_MP3_COPY_MEMORY(pFramesOutS16, pFramesInS16, sizeof(ma_int16) * framesToConsume * pMP3->channels); + } + #endif } pMP3->currentPCMFrame += framesToConsume; pMP3->pcmFramesConsumedInMP3Frame += framesToConsume; @@ -92879,6 +95232,9 @@ static ma_uint64 ma_dr_mp3_read_pcm_frames_raw(ma_dr_mp3* pMP3, ma_uint64 frames if (framesToRead == 0) { break; } + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX && pMP3->totalPCMFrameCount > pMP3->paddingInPCMFrames && pMP3->currentPCMFrame >= (pMP3->totalPCMFrameCount - pMP3->paddingInPCMFrames)) { + break; + } MA_DR_MP3_ASSERT(pMP3->pcmFramesRemainingInMP3Frame == 0); if (ma_dr_mp3_decode_next_frame(pMP3) == 0) { break; @@ -92958,7 +95314,7 @@ static ma_bool32 ma_dr_mp3_seek_to_start_of_stream(ma_dr_mp3* pMP3) { MA_DR_MP3_ASSERT(pMP3 != NULL); MA_DR_MP3_ASSERT(pMP3->onSeek != NULL); - if (!ma_dr_mp3__on_seek(pMP3, 0, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek_64(pMP3, pMP3->streamStartOffset, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } ma_dr_mp3_reset(pMP3); @@ -93024,7 +95380,7 @@ static ma_bool32 ma_dr_mp3_seek_to_pcm_frame__seek_table(ma_dr_mp3* pMP3, ma_uin seekPoint.mp3FramesToDiscard = 0; seekPoint.pcmFramesToDiscard = 0; } - if (!ma_dr_mp3__on_seek_64(pMP3, seekPoint.seekPosInBytes, ma_dr_mp3_seek_origin_start)) { + if (!ma_dr_mp3__on_seek_64(pMP3, seekPoint.seekPosInBytes, MA_DR_MP3_SEEK_SET)) { return MA_FALSE; } ma_dr_mp3_reset(pMP3); @@ -93035,7 +95391,7 @@ static ma_bool32 ma_dr_mp3_seek_to_pcm_frame__seek_table(ma_dr_mp3* pMP3, ma_uin if (iMP3Frame == seekPoint.mp3FramesToDiscard-1) { pPCMFrames = (ma_dr_mp3d_sample_t*)pMP3->pcmFrames; } - pcmFramesRead = ma_dr_mp3_decode_next_frame_ex(pMP3, pPCMFrames); + pcmFramesRead = ma_dr_mp3_decode_next_frame_ex(pMP3, pPCMFrames, NULL, NULL); if (pcmFramesRead == 0) { return MA_FALSE; } @@ -93077,7 +95433,7 @@ MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint6 totalMP3FrameCount = 0; for (;;) { ma_uint32 pcmFramesInCurrentMP3Frame; - pcmFramesInCurrentMP3Frame = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3Frame = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3Frame == 0) { break; } @@ -93101,10 +95457,26 @@ MA_API ma_bool32 ma_dr_mp3_get_mp3_and_pcm_frame_count(ma_dr_mp3* pMP3, ma_uint6 MA_API ma_uint64 ma_dr_mp3_get_pcm_frame_count(ma_dr_mp3* pMP3) { ma_uint64 totalPCMFrameCount; - if (!ma_dr_mp3_get_mp3_and_pcm_frame_count(pMP3, NULL, &totalPCMFrameCount)) { + if (pMP3 == NULL) { return 0; } - return totalPCMFrameCount; + if (pMP3->totalPCMFrameCount != MA_UINT64_MAX) { + totalPCMFrameCount = pMP3->totalPCMFrameCount; + if (totalPCMFrameCount >= pMP3->delayInPCMFrames) { + totalPCMFrameCount -= pMP3->delayInPCMFrames; + } else { + } + if (totalPCMFrameCount >= pMP3->paddingInPCMFrames) { + totalPCMFrameCount -= pMP3->paddingInPCMFrames; + } else { + } + return totalPCMFrameCount; + } else { + if (!ma_dr_mp3_get_mp3_and_pcm_frame_count(pMP3, NULL, &totalPCMFrameCount)) { + return 0; + } + return totalPCMFrameCount; + } } MA_API ma_uint64 ma_dr_mp3_get_mp3_frame_count(ma_dr_mp3* pMP3) { @@ -93174,7 +95546,7 @@ MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSe MA_DR_MP3_ASSERT(pMP3->streamCursor >= pMP3->dataSize); mp3FrameInfo[iMP3Frame].bytePos = pMP3->streamCursor - pMP3->dataSize; mp3FrameInfo[iMP3Frame].pcmFrameIndex = runningPCMFrameCount; - pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3FrameIn == 0) { return MA_FALSE; } @@ -93198,7 +95570,7 @@ MA_API ma_bool32 ma_dr_mp3_calculate_seek_points(ma_dr_mp3* pMP3, ma_uint32* pSe } mp3FrameInfo[MA_DR_MP3_COUNTOF(mp3FrameInfo)-1].bytePos = pMP3->streamCursor - pMP3->dataSize; mp3FrameInfo[MA_DR_MP3_COUNTOF(mp3FrameInfo)-1].pcmFrameIndex = runningPCMFrameCount; - pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL); + pcmFramesInCurrentMP3FrameIn = ma_dr_mp3_decode_next_frame_ex(pMP3, NULL, NULL, NULL); if (pcmFramesInCurrentMP3FrameIn == 0) { pSeekPoints[iSeekPoint].seekPosInBytes = mp3FrameInfo[0].bytePos; pSeekPoints[iSeekPoint].pcmFrameIndex = nextTargetPCMFrame; @@ -93264,6 +95636,8 @@ static float* ma_dr_mp3__full_read_and_close_f32(ma_dr_mp3* pMP3, ma_dr_mp3_conf pNewFrames = (float*)ma_dr_mp3__realloc_from_callbacks(pFrames, (size_t)newFramesBufferSize, (size_t)oldFramesBufferSize, &pMP3->allocationCallbacks); if (pNewFrames == NULL) { ma_dr_mp3__free_from_callbacks(pFrames, &pMP3->allocationCallbacks); + pFrames = NULL; + totalFramesRead = 0; break; } pFrames = pNewFrames; @@ -93315,6 +95689,8 @@ static ma_int16* ma_dr_mp3__full_read_and_close_s16(ma_dr_mp3* pMP3, ma_dr_mp3_c pNewFrames = (ma_int16*)ma_dr_mp3__realloc_from_callbacks(pFrames, (size_t)newFramesBufferSize, (size_t)oldFramesBufferSize, &pMP3->allocationCallbacks); if (pNewFrames == NULL) { ma_dr_mp3__free_from_callbacks(pFrames, &pMP3->allocationCallbacks); + pFrames = NULL; + totalFramesRead = 0; break; } pFrames = pNewFrames; @@ -93336,18 +95712,18 @@ static ma_int16* ma_dr_mp3__full_read_and_close_s16(ma_dr_mp3* pMP3, ma_dr_mp3_c } return pFrames; } -MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API float* ma_dr_mp3_open_and_read_pcm_frames_f32(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_mp3 mp3; - if (!ma_dr_mp3_init(&mp3, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_mp3_init(&mp3, onRead, onSeek, onTell, NULL, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_mp3__full_read_and_close_f32(&mp3, pConfig, pTotalFrameCount); } -MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) +MA_API ma_int16* ma_dr_mp3_open_and_read_pcm_frames_s16(ma_dr_mp3_read_proc onRead, ma_dr_mp3_seek_proc onSeek, ma_dr_mp3_tell_proc onTell, void* pUserData, ma_dr_mp3_config* pConfig, ma_uint64* pTotalFrameCount, const ma_allocation_callbacks* pAllocationCallbacks) { ma_dr_mp3 mp3; - if (!ma_dr_mp3_init(&mp3, onRead, onSeek, pUserData, pAllocationCallbacks)) { + if (!ma_dr_mp3_init(&mp3, onRead, onSeek, onTell, NULL, pUserData, pAllocationCallbacks)) { return NULL; } return ma_dr_mp3__full_read_and_close_s16(&mp3, pConfig, pTotalFrameCount); diff --git a/examples/parakeet-cli/CMakeLists.txt b/examples/parakeet-cli/CMakeLists.txt new file mode 100644 index 00000000000..adb9aba38ef --- /dev/null +++ b/examples/parakeet-cli/CMakeLists.txt @@ -0,0 +1,8 @@ +set(TARGET parakeet-cli) +add_executable(${TARGET} parakeet-cli.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common parakeet ${FFMPEG_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) + +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-cli/README.md b/examples/parakeet-cli/README.md new file mode 100644 index 00000000000..ccb8404f542 --- /dev/null +++ b/examples/parakeet-cli/README.md @@ -0,0 +1,106 @@ +# whisper.cpp/examples/parakeet-cli + +This is an example of using the [Parakeet] model in whisper.cpp. + +### Download converted model +```console +$ hf download ggml-org/parakeet-GGUF parakeet-tdt-0.6b-v3-f16.bin --local-dir models +``` + +### Building +```console +$ cmake -B build -S . +$ cmake --build build --target parakeet-cli -j 12 +``` + +### Usage +```console +$ ./build/bin/parakeet-cli --help + +usage: ./build/bin/parakeet-cli [options] file0 file1 ... +supported audio formats: flac, mp3, ogg, wav + +options: + -h, --help [default] show this help message and exit + -t N, --threads N [4 ] number of threads to use during computation + -m, --model FILE [models/ggml-parakeet-tdt-0.6b-v3.bin] model path + -f, --file FILE [ ] input audio file + -ng, --no-gpu [false ] disable GPU + -dev N, --device N [0 ] GPU device to use + -ps, --print-segments [false ] print segment information +``` + +### Example +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. +``` + +To print segment information: +```console +$ ./build/bin/parakeet-cli -m models/parakeet-tdt-0.6b-v3-f16.bin -f samples/jfk.wav --print-segments +Processing audio (176000 samples, 11.00 seconds) +Processing audio: total_frames=1101, chunk_size=1101 +parakeet_decode: starting decode with n_frames=138 +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. + +Segments (1): +Segment 0: [0 -> 1101] "And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country." +Tokens [38]: + [ 0] id= 1976 frame= 3 dur_idx= 4 dur_val= 4 p=0.9996 plog=-15.6206 t0= 24 t1= 56 word_start=true "▁And" + [ 1] id= 547 frame= 7 dur_idx= 4 dur_val= 4 p=0.9999 plog=-18.7922 t0= 56 t1= 88 word_start=true "▁so" + [ 2] id= 7877 frame= 11 dur_idx= 2 dur_val= 2 p=0.8451 plog=-14.5929 t0= 88 t1= 88 word_start=false "," + [ 3] id= 1103 frame= 13 dur_idx= 3 dur_val= 3 p=0.9996 plog=-15.6127 t0= 104 t1= 128 word_start=true "▁my" + [ 4] id= 309 frame= 16 dur_idx= 1 dur_val= 1 p=0.9912 plog=-11.9635 t0= 128 t1= 136 word_start=true "▁f" + [ 5] id= 530 frame= 17 dur_idx= 2 dur_val= 2 p=1.0000 plog=-13.5239 t0= 136 t1= 152 word_start=false "ell" + [ 6] id= 596 frame= 19 dur_idx= 3 dur_val= 3 p=1.0000 plog=-16.3120 t0= 152 t1= 176 word_start=false "ow" + [ 7] id= 3213 frame= 22 dur_idx= 4 dur_val= 4 p=0.9999 plog=-10.1462 t0= 176 t1= 208 word_start=true "▁Amer" + [ 8] id= 404 frame= 26 dur_idx= 4 dur_val= 4 p=1.0000 plog=-25.0910 t0= 208 t1= 240 word_start=false "ic" + [ 9] id= 667 frame= 30 dur_idx= 4 dur_val= 4 p=1.0000 plog=-27.1707 t0= 240 t1= 272 word_start=false "ans" + [10] id= 7877 frame= 37 dur_idx= 4 dur_val= 4 p=0.9094 plog=-16.3405 t0= 272 t1= 272 word_start=false "," + [11] id= 279 frame= 41 dur_idx= 4 dur_val= 4 p=0.9980 plog=-19.7244 t0= 328 t1= 360 word_start=true "▁a" + [12] id= 583 frame= 45 dur_idx= 4 dur_val= 4 p=1.0000 plog=-24.5312 t0= 360 t1= 392 word_start=false "sk" + [13] id= 1491 frame= 53 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2991 t0= 424 t1= 456 word_start=true "▁not" + [14] id= 3470 frame= 65 dur_idx= 4 dur_val= 4 p=0.9995 plog=-16.7306 t0= 520 t1= 552 word_start=true "▁what" + [15] id= 3629 frame= 69 dur_idx= 2 dur_val= 2 p=0.8139 plog=-11.6486 t0= 552 t1= 568 word_start=true "▁your" + [16] id= 867 frame= 75 dur_idx= 1 dur_val= 1 p=0.9980 plog=-12.5265 t0= 600 t1= 608 word_start=true "▁co" + [17] id= 331 frame= 76 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.6697 t0= 608 t1= 624 word_start=false "un" + [18] id= 958 frame= 78 dur_idx= 2 dur_val= 2 p=1.0000 plog=-11.3621 t0= 624 t1= 640 word_start=false "tr" + [19] id= 7893 frame= 80 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.3245 t0= 640 t1= 656 word_start=false "y" + [20] id= 2059 frame= 82 dur_idx= 3 dur_val= 3 p=1.0000 plog=-17.7694 t0= 656 t1= 680 word_start=true "▁can" + [21] id= 458 frame= 85 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.2510 t0= 680 t1= 712 word_start=true "▁do" + [22] id= 509 frame= 89 dur_idx= 4 dur_val= 4 p=1.0000 plog=-23.0688 t0= 712 t1= 744 word_start=true "▁for" + [23] id= 1180 frame= 93 dur_idx= 4 dur_val= 4 p=0.9999 plog=-25.0567 t0= 744 t1= 776 word_start=true "▁you" + [24] id= 7877 frame= 98 dur_idx= 4 dur_val= 4 p=0.8820 plog=-14.2549 t0= 776 t1= 776 word_start=false "," + [25] id= 279 frame=102 dur_idx= 3 dur_val= 3 p=0.9992 plog=-16.8176 t0= 816 t1= 840 word_start=true "▁a" + [26] id= 583 frame=105 dur_idx= 4 dur_val= 4 p=1.0000 plog=-21.0352 t0= 840 t1= 872 word_start=false "sk" + [27] id= 3470 frame=109 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.4659 t0= 872 t1= 896 word_start=true "▁what" + [28] id= 1180 frame=112 dur_idx= 4 dur_val= 4 p=0.9997 plog=-17.6392 t0= 896 t1= 928 word_start=true "▁you" + [29] id= 2059 frame=116 dur_idx= 3 dur_val= 3 p=0.9999 plog=-15.5484 t0= 928 t1= 952 word_start=true "▁can" + [30] id= 458 frame=119 dur_idx= 2 dur_val= 2 p=1.0000 plog=-15.9953 t0= 952 t1= 968 word_start=true "▁do" + [31] id= 509 frame=121 dur_idx= 3 dur_val= 3 p=1.0000 plog=-15.9605 t0= 968 t1= 992 word_start=true "▁for" + [32] id= 3629 frame=124 dur_idx= 2 dur_val= 2 p=0.9994 plog=-12.2083 t0= 992 t1=1008 word_start=true "▁your" + [33] id= 867 frame=126 dur_idx= 2 dur_val= 2 p=0.9969 plog=-9.1252 t0=1008 t1=1024 word_start=true "▁co" + [34] id= 331 frame=128 dur_idx= 1 dur_val= 1 p=0.9999 plog=-12.6911 t0=1024 t1=1032 word_start=false "un" + [35] id= 958 frame=129 dur_idx= 1 dur_val= 1 p=1.0000 plog=-8.8885 t0=1032 t1=1040 word_start=false "tr" + [36] id= 7893 frame=130 dur_idx= 2 dur_val= 2 p=1.0000 plog=-14.1441 t0=1040 t1=1056 word_start=false "y" + [37] id= 7883 frame=132 dur_idx= 4 dur_val= 4 p=0.9567 plog=-11.5227 t0=1056 t1=1056 word_start=false "." +``` + +### Model conversion +Clone the original model from Hugging Face: +```console +$ git clone https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 +``` +Convert the model: +```console +(venv) $ python models/convert-parakeet-to-ggml.py \ + --model <path to cloned model> \ + --out-dir models \ + --out-name ggml-parakeet-tdt-0.6b-v3-f16.bin +``` + +[Parakeet]: https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3 diff --git a/examples/parakeet-cli/parakeet-cli.cpp b/examples/parakeet-cli/parakeet-cli.cpp new file mode 100644 index 00000000000..03ddc7f8b8c --- /dev/null +++ b/examples/parakeet-cli/parakeet-cli.cpp @@ -0,0 +1,243 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include <cstdio> +#include <string> +#include <thread> +#include <vector> +#include <cstring> +#include <fstream> + +// command-line parameters +struct parakeet_params { + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + + bool use_gpu = true; + int32_t gpu_device = 0; + + bool print_segments = false; + bool output_txt = false; + bool no_prints = false; + + std::string model = "models/ggml-parakeet-tdt-0.6b-v3.bin"; + std::string output_file = ""; + std::vector<std::string> fname_inp = {}; +}; + +static void parakeet_print_usage(int argc, char ** argv, const parakeet_params & params); + +static char * requires_value_error(const std::string & arg) { + fprintf(stderr, "error: argument %s requires value\n", arg.c_str()); + exit(1); +} + +static bool parakeet_params_parse(int argc, char ** argv, parakeet_params & params) { + if (const char * env_device = std::getenv("PARAKEET_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-"){ + params.fname_inp.push_back(arg); + continue; + } + + if (arg[0] != '-') { + params.fname_inp.push_back(arg); + continue; + } + + if (arg == "-h" || arg == "--help") { + parakeet_print_usage(argc, argv, params); + exit(0); + } + #define ARGV_NEXT (((i + 1) < argc) ? argv[++i] : requires_value_error(arg)) + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(ARGV_NEXT); } + else if (arg == "-m" || arg == "--model") { params.model = ARGV_NEXT; } + else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(ARGV_NEXT); } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(ARGV_NEXT); } + else if (arg == "-ps" || arg == "--print-segments") { params.print_segments = true; } + else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; } + else if (arg == "-of" || arg == "--output-file") { params.output_file = ARGV_NEXT; } + else if (arg == "-np" || arg == "--no-prints") { params.no_prints = true; } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + parakeet_print_usage(argc, argv, params); + exit(1); + } + } + + return true; +} + +static void parakeet_print_usage(int /*argc*/, char ** argv, const parakeet_params & params) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options] file0 file1 ...\n", argv[0]); + fprintf(stderr, "supported audio formats: flac, mp3, ogg, wav\n"); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -m, --model FILE [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f, --file FILE [%-7s] input audio file\n", ""); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device to use\n", params.gpu_device); + fprintf(stderr, " -ps, --print-segments [%-7s] print segment information\n", params.print_segments ? "true" : "false"); + fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false"); + fprintf(stderr, " -of, --output-file FILE [%-7s] output file path (without file extension)\n", ""); + fprintf(stderr, " -np, --no-prints [%-7s] do not print anything other than the results\n", params.no_prints ? "true" : "false"); + fprintf(stderr, "\n"); +} + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + bool * is_first = (bool *) user_data; + + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, *is_first, text_buf, sizeof(text_buf)); + printf("%s", text_buf); + fflush(stdout); + + *is_first = false; +} + +static void cb_log_disable(enum ggml_log_level , const char * , void * ) { } + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + parakeet_params params; + + if (parakeet_params_parse(argc, argv, params) == false) { + return 1; + } + + if (params.no_prints) { + parakeet_log_set(cb_log_disable, NULL); + } + + if (params.fname_inp.empty()) { + fprintf(stderr, "error: no input files specified\n"); + parakeet_print_usage(argc, argv, params); + return 1; + } + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + ctx_params.use_gpu = params.use_gpu; + ctx_params.gpu_device = params.gpu_device; + + if (!params.no_prints) { + fprintf(stderr, "Loading Parakeet model from: %s\n", params.model.c_str()); + } + + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(params.model.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "error: failed to load Parakeet model from '%s'\n", params.model.c_str()); + return 1; + } + + if (!params.no_prints) { + fprintf(stderr, "Successfully loaded Parakeet model\n"); + fprintf(stderr, "system_info: n_threads = %d / %d | %s\n", + params.n_threads, (int32_t) std::thread::hardware_concurrency(), parakeet_print_system_info()); + } + + // Process each input file + for (const auto & fname : params.fname_inp) { + if (!params.no_prints) { + fprintf(stderr, "\nProcessing file: %s\n", fname.c_str()); + } + + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + if (!read_audio_data(fname.c_str(), pcmf32, pcmf32s, false)) { + fprintf(stderr, "error: failed to read audio file '%s'\n", fname.c_str()); + continue; + } + + if (pcmf32.empty()) { + fprintf(stderr, "error: no audio data in file '%s'\n", fname.c_str()); + continue; + } + + bool is_first = true; + struct parakeet_full_params full_params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + full_params.n_threads = params.n_threads; + full_params.new_token_callback = token_callback; + full_params.new_token_callback_user_data = &is_first; + + const int mel_frames = (int)(pcmf32.size() / PARAKEET_HOP_LENGTH); + int ret = parakeet_full(pctx, full_params, pcmf32.data(), pcmf32.size()); + + if (ret != 0) { + fprintf(stderr, "error: failed to process audio file '%s'\n", fname.c_str()); + continue; + } + + printf("\n"); + + if (params.output_txt) { + const std::string fname_out = (!params.output_file.empty() ? params.output_file : fname) + ".txt"; + + std::ofstream fout(fname_out); + if (fout.is_open()) { + const int n_segments = parakeet_full_n_segments(pctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = parakeet_full_get_segment_text(pctx, i); + fout << text << "\n"; + } + fout.close(); + if (!params.no_prints) { + fprintf(stderr, "Output written to: %s\n", fname_out.c_str()); + } + } else { + fprintf(stderr, "error: failed to open '%s' for writing\n", fname_out.c_str()); + } + } + + if (!params.no_prints) { + parakeet_print_timings(pctx); + } + + if (params.print_segments) { + const int n_segments = parakeet_full_n_segments(pctx); + fprintf(stderr, "\nSegments (%d):\n", n_segments); + + for (int i = 0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text(pctx, i); + const int64_t t0 = parakeet_full_get_segment_t0(pctx, i); + const int64_t t1 = parakeet_full_get_segment_t1(pctx, i); + const int n_tokens = parakeet_full_n_tokens(pctx, i); + + fprintf(stderr, "Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + fprintf(stderr, "Tokens [%d]:\n", n_tokens); + + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data(pctx, i, j); + const char * token_str = parakeet_token_to_str(pctx, token_data.id); + + fprintf(stderr, " [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%s \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start ? "true": "false", + token_str); + } + } + } + } + + parakeet_free(pctx); + + return 0; +} diff --git a/examples/parakeet-quantize/CMakeLists.txt b/examples/parakeet-quantize/CMakeLists.txt new file mode 100644 index 00000000000..6b46da18d27 --- /dev/null +++ b/examples/parakeet-quantize/CMakeLists.txt @@ -0,0 +1,7 @@ +set(TARGET parakeet-quantize) +add_executable(${TARGET} parakeet-quantize.cpp) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE common parakeet ${CMAKE_THREAD_LIBS_INIT}) +install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/parakeet-quantize/parakeet-quantize.cpp b/examples/parakeet-quantize/parakeet-quantize.cpp new file mode 100644 index 00000000000..a5d9616420f --- /dev/null +++ b/examples/parakeet-quantize/parakeet-quantize.cpp @@ -0,0 +1,230 @@ +#include "ggml.h" +#include "ggml-backend.h" + +#include "common-ggml.h" + +#include <cassert> +#include <cstdio> +#include <cstring> +#include <fstream> +#include <string> +#include <vector> + +struct parakeet_hparams { + int32_t n_vocab = 0; + int32_t n_audio_ctx = 0; + int32_t n_audio_state = 0; + int32_t n_audio_head = 0; + int32_t n_audio_layer = 0; + int32_t n_mels = 0; + int32_t ftype = 0; + int32_t n_fft = 0; + int32_t subsampling_factor = 0; + int32_t n_subsampling_channels = 0; + int32_t n_conv_kernel = 0; + int32_t n_pred_dim = 0; + int32_t n_pred_layers = 0; + int32_t n_tdt_durations = 0; + int32_t n_max_tokens = 0; +}; + +static bool parakeet_model_quantize(const std::string & fname_inp, const std::string & fname_out, ggml_ftype ftype) { + printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str()); + + auto finp = std::ifstream(fname_inp, std::ios::binary); + if (!finp) { + fprintf(stderr, "%s: failed to open '%s' for reading\n", __func__, fname_inp.c_str()); + return false; + } + + auto fout = std::ofstream(fname_out, std::ios::binary); + if (!fout) { + fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname_out.c_str()); + return false; + } + + // magic + { + uint32_t magic; + finp.read((char *) &magic, sizeof(magic)); + if (magic != GGML_FILE_MAGIC) { + fprintf(stderr, "%s: invalid model file (bad magic)\n", __func__); + return false; + } + fout.write((char *) &magic, sizeof(magic)); + } + + // hparams + parakeet_hparams hparams; + { + finp.read((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + finp.read((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); + finp.read((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); + finp.read((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); + finp.read((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); + finp.read((char *) &hparams.n_mels, sizeof(hparams.n_mels)); + finp.read((char *) &hparams.ftype, sizeof(hparams.ftype)); + finp.read((char *) &hparams.n_fft, sizeof(hparams.n_fft)); + finp.read((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor)); + finp.read((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels)); + finp.read((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel)); + finp.read((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim)); + finp.read((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers)); + finp.read((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations)); + finp.read((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens)); + + const int32_t qntvr_src = hparams.ftype / GGML_QNT_VERSION_FACTOR; + const int32_t ftype_dst = GGML_QNT_VERSION * GGML_QNT_VERSION_FACTOR + ftype; + + fprintf(stderr, "%s: n_vocab = %d\n", __func__, hparams.n_vocab); + fprintf(stderr, "%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + fprintf(stderr, "%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + fprintf(stderr, "%s: n_mels = %d\n", __func__, hparams.n_mels); + fprintf(stderr, "%s: ftype (src) = %d\n", __func__, hparams.ftype); + fprintf(stderr, "%s: qntvr (src) = %d\n", __func__, qntvr_src); + fprintf(stderr, "%s: ftype (dst) = %d\n", __func__, ftype_dst); + fprintf(stderr, "%s: qntvr (dst) = %d\n", __func__, GGML_QNT_VERSION); + + fout.write((char *) &hparams.n_vocab, sizeof(hparams.n_vocab)); + fout.write((char *) &hparams.n_audio_ctx, sizeof(hparams.n_audio_ctx)); + fout.write((char *) &hparams.n_audio_state, sizeof(hparams.n_audio_state)); + fout.write((char *) &hparams.n_audio_head, sizeof(hparams.n_audio_head)); + fout.write((char *) &hparams.n_audio_layer, sizeof(hparams.n_audio_layer)); + fout.write((char *) &hparams.n_mels, sizeof(hparams.n_mels)); + fout.write((char *) &ftype_dst, sizeof(ftype_dst)); + fout.write((char *) &hparams.n_fft, sizeof(hparams.n_fft)); + fout.write((char *) &hparams.subsampling_factor, sizeof(hparams.subsampling_factor)); + fout.write((char *) &hparams.n_subsampling_channels, sizeof(hparams.n_subsampling_channels)); + fout.write((char *) &hparams.n_conv_kernel, sizeof(hparams.n_conv_kernel)); + fout.write((char *) &hparams.n_pred_dim, sizeof(hparams.n_pred_dim)); + fout.write((char *) &hparams.n_pred_layers, sizeof(hparams.n_pred_layers)); + fout.write((char *) &hparams.n_tdt_durations, sizeof(hparams.n_tdt_durations)); + fout.write((char *) &hparams.n_max_tokens, sizeof(hparams.n_max_tokens)); + } + + // mel filterbank + { + int32_t n_mel, n_fb; + finp.read((char *) &n_mel, sizeof(n_mel)); + fout.write((char *) &n_mel, sizeof(n_mel)); + finp.read((char *) &n_fb, sizeof(n_fb)); + fout.write((char *) &n_fb, sizeof(n_fb)); + + const size_t n = (size_t) n_mel * n_fb; + std::vector<float> buf(n); + finp.read((char *) buf.data(), n * sizeof(float)); + fout.write((char *) buf.data(), n * sizeof(float)); + } + + // window function + { + int32_t n_window; + finp.read((char *) &n_window, sizeof(n_window)); + fout.write((char *) &n_window, sizeof(n_window)); + + std::vector<float> buf(n_window); + finp.read((char *) buf.data(), n_window * sizeof(float)); + fout.write((char *) buf.data(), n_window * sizeof(float)); + } + + // TDT durations + { + std::vector<uint32_t> buf(hparams.n_tdt_durations); + finp.read((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + fout.write((char *) buf.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + } + + // vocab + { + int32_t n_tokens; + finp.read((char *) &n_tokens, sizeof(n_tokens)); + fout.write((char *) &n_tokens, sizeof(n_tokens)); + + for (int i = 0; i < n_tokens; ++i) { + int32_t len; + finp.read((char *) &len, sizeof(len)); + fout.write((char *) &len, sizeof(len)); + + std::string token(len, '\0'); + finp.read(&token[0], len); + fout.write(&token[0], len); + } + } + + // tensors — quantize 2D weights skipping tensors that must stay F32: + // ggml_ssm_conv / ggml_conv2d_dw CUDA kernels require F32 weights. + // pos_bias_u / pos_bias_v are declared F32 in the loader. + const std::vector<std::string> to_quant = { ".*" }; + std::vector<std::string> to_skip = { + // CUDA kernel constraints (ggml_ssm_conv / ggml_conv2d_dw require F32 weights) + "encoder\\.layers\\..+\\.conv\\.depthwise_conv\\.weight", + // Declared F32 in loader (pos_bias tensors) + "encoder\\.layers\\..+\\.self_attn\\.pos_bias_u", + "encoder\\.layers\\..+\\.self_attn\\.pos_bias_v", + }; + + // Prediction/joint tensors use n_pred_dim as their inner dimension. K-quant + // types (block size 256) cannot quantize 640 evenly, so keep them F32. For + // other types (Q8_0, Q4_0, block size 32) 640 is divisible and they can be + // quantized normally. The loader mirrors this logic at load time. + { + const ggml_type qtype = ggml_ftype_to_ggml_type(ftype); + const int32_t blck = ggml_blck_size(qtype); + if (blck > 1 && hparams.n_pred_dim % blck != 0) { + to_skip.push_back("decoder\\.prediction\\.embed\\.weight"); + to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_ih_l.*"); + to_skip.push_back("decoder\\.prediction\\.dec_rnn\\.lstm\\.weight_hh_l.*"); + to_skip.push_back("joint\\.pred\\.weight"); + to_skip.push_back("joint\\.joint_net\\.2\\.weight"); + } + } + + if (!ggml_common_quantize_0(finp, fout, ftype, to_quant, to_skip)) { + fprintf(stderr, "%s: failed to quantize tensors\n", __func__); + return false; + } + + finp.close(); + fout.close(); + + return true; +} + +int main(int argc, char ** argv) { + ggml_backend_load_all(); + + if (argc != 4) { + fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type\n", argv[0]); + ggml_print_ftypes(stderr); + return 1; + } + + // initialise F16 lookup tables + { + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + const std::string fname_inp = argv[1]; + const std::string fname_out = argv[2]; + const ggml_ftype ftype = ggml_parse_ftype(argv[3]); + + if (ftype == GGML_FTYPE_UNKNOWN) { + fprintf(stderr, "%s: invalid quantization type\n", argv[0]); + ggml_print_ftypes(stderr); + return 1; + } + + const int64_t t_start_us = ggml_time_us(); + + if (!parakeet_model_quantize(fname_inp, fname_out, ftype)) { + fprintf(stderr, "%s: failed to quantize model from '%s'\n", argv[0], fname_inp.c_str()); + return 1; + } + + printf("\n%s: quantize time = %8.2f ms\n", argv[0], (ggml_time_us() - t_start_us) / 1000.0f); + printf("%s: output model = %s\n", argv[0], fname_out.c_str()); + + return 0; +} diff --git a/examples/server/README.md b/examples/server/README.md index ffba5f4edf5..8d4c802b8bf 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -40,6 +40,7 @@ options: -l LANG, --language LANG [en ] spoken language ('auto' for auto-detect) -dl, --detect-language [false ] exit after automatically detecting language --prompt PROMPT [ ] initial prompt + --carry-initial-prompt [false ] always prepend initial prompt -m FNAME, --model FNAME [models/ggml-base.en.bin] model path -oved D, --ov-e-device DNAME [CPU ] the OpenVINO device used for encode inference -dtw MODEL --dtw MODEL [ ] compute token-level timestamps @@ -78,6 +79,8 @@ curl 127.0.0.1:8080/inference \ -F file="@<file-path>" \ -F temperature="0.0" \ -F temperature_inc="0.2" \ +-F prompt="<prompt>" \ +-F carry_initial_prompt="true" \ -F response_format="json" ``` diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 866ac4eafaa..b87ef27375f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -56,11 +56,11 @@ inline void signal_handler(int signal) { struct server_params { - std::string hostname = "127.0.0.1"; - std::string public_path = "examples/server/public"; - std::string request_path = ""; + std::string hostname = "127.0.0.1"; + std::string public_path = "examples/server/public"; + std::string request_path = ""; std::string inference_path = "/inference"; - std::string tmp_dir = "."; + std::string tmp_dir = "."; int32_t port = 8080; int32_t read_timeout = 600; @@ -87,49 +87,47 @@ struct whisper_params { float logprob_thold = -1.00f; float temperature = 0.00f; float temperature_inc = 0.20f; - float no_speech_thold = 0.6f; - - bool debug_mode = false; - bool translate = false; - bool detect_language = false; - bool diarize = false; - bool tinydiarize = false; - bool split_on_word = false; - bool no_fallback = false; - bool print_special = false; - bool print_colors = false; - bool print_realtime = false; - bool print_progress = false; - bool no_timestamps = false; - bool use_gpu = true; - bool flash_attn = true; - bool suppress_nst = false; - bool no_context = true; + float no_speech_thold = 0.6f; + + bool debug_mode = false; + bool translate = false; + bool detect_language = false; + bool diarize = false; + bool tinydiarize = false; + bool split_on_word = false; + bool no_fallback = false; + bool print_special = false; + bool print_colors = false; + bool print_realtime = false; + bool print_progress = false; + bool no_timestamps = false; + bool token_timestamps = true; + bool use_gpu = true; + bool flash_attn = true; + int32_t gpu_device = 0; + bool suppress_nst = false; + bool no_context = true; bool no_language_probabilities = false; - - std::string language = "en"; - std::string prompt = ""; - std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - std::string model = "models/ggml-base.en.bin"; - - std::string response_format = json_format; - - // [TDRZ] speaker turn string - std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line - + bool carry_initial_prompt = false; + + std::string language = "en"; + std::string prompt = ""; + std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + std::string model = "models/ggml-base.en.bin"; + std::string response_format = json_format; + std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line std::string openvino_encode_device = "CPU"; - - std::string dtw = ""; + std::string dtw = ""; // Voice Activity Detection (VAD) parameters - bool vad = false; - std::string vad_model = ""; - float vad_threshold = 0.5f; - int vad_min_speech_duration_ms = 250; + bool vad = false; + std::string vad_model = ""; + float vad_threshold = 0.5f; + int vad_min_speech_duration_ms = 250; int vad_min_silence_duration_ms = 100; - float vad_max_speech_duration_s = FLT_MAX; - int vad_speech_pad_ms = 30; - float vad_samples_overlap = 0.1f; + float vad_max_speech_duration_s = FLT_MAX; + int vad_speech_pad_ms = 30; + float vad_samples_overlap = 0.1f; }; void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params, const server_params& sparams) { @@ -137,50 +135,52 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, "usage: %s [options] \n", argv[0]); fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); - fprintf(stderr, " -h, --help [default] show this help message and exit\n"); - fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); - fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); - fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); - fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); - fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); - fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); - fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); - fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); - fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); - fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); - fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); - fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); - fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); - fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); - fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); - fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); - fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); - fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); - fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); - fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); - fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); - fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false"); - fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); - fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); - fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); - fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); - fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); - fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); - fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -p N, --processors N [%-7d] number of processors to use during computation\n", params.n_processors); + fprintf(stderr, " -ot N, --offset-t N [%-7d] time offset in milliseconds\n", params.offset_t_ms); + fprintf(stderr, " -on N, --offset-n N [%-7d] segment index offset\n", params.offset_n); + fprintf(stderr, " -d N, --duration N [%-7d] duration of audio to process in milliseconds\n", params.duration_ms); + fprintf(stderr, " -mc N, --max-context N [%-7d] maximum number of text context tokens to store\n", params.max_context); + fprintf(stderr, " -ml N, --max-len N [%-7d] maximum segment length in characters\n", params.max_len); + fprintf(stderr, " -sow, --split-on-word [%-7s] split on word rather than on token\n", params.split_on_word ? "true" : "false"); + fprintf(stderr, " -bo N, --best-of N [%-7d] number of best candidates to keep\n", params.best_of); + fprintf(stderr, " -bs N, --beam-size N [%-7d] beam size for beam search\n", params.beam_size); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); + fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); + fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false"); + fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false"); + fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pc, --print-colors [%-7s] print colors\n", params.print_colors ? "true" : "false"); + fprintf(stderr, " -pr, --print-realtime [%-7s] print output in realtime\n", params.print_realtime ? "true" : "false"); + fprintf(stderr, " -pp, --print-progress [%-7s] print progress\n", params.print_progress ? "true" : "false"); + fprintf(stderr, " -nt, --no-timestamps [%-7s] do not print timestamps\n", params.no_timestamps ? "true" : "false"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language ('auto' for auto-detect)\n", params.language.c_str()); + fprintf(stderr, " -dl, --detect-language [%-7s] exit after automatically detecting language\n", params.detect_language ? "true" : "false"); + fprintf(stderr, " --prompt PROMPT [%-7s] initial prompt\n", params.prompt.c_str()); + fprintf(stderr, " --carry-initial-prompt [%-7s] always prepend initial prompt\n", params.carry_initial_prompt ? "true" : "false"); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -oved D, --ov-e-device DNAME [%-7s] the OpenVINO device used for encode inference\n", params.openvino_encode_device.c_str()); // server params - fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); - fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str()); - fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port); - fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str()); - fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str()); - fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); - fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false"); - fprintf(stderr, " --tmp-dir, [%-7s] Temporary directory for ffmpeg transcoded files\n", sparams.tmp_dir.c_str()); - fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); - fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); - fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); - fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); - fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); + fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str()); + fprintf(stderr, " --host HOST, [%-7s] Hostname/ip-adress for the server\n", sparams.hostname.c_str()); + fprintf(stderr, " --port PORT, [%-7d] Port number for the server\n", sparams.port); + fprintf(stderr, " --public PATH, [%-7s] Path to the public folder\n", sparams.public_path.c_str()); + fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str()); + fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); + fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server\n", sparams.ffmpeg_converter ? "true" : "false"); + fprintf(stderr, " --tmp-dir, [%-7s] Temporary directory for ffmpeg transcoded files\n", sparams.tmp_dir.c_str()); + fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); + fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); + fprintf(stderr, " -ng, --no-gpu [%-7s] do not use gpu\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -dev N, --device N [%-7d] GPU device ID (default: 0)\n", params.gpu_device); + fprintf(stderr, " -fa, --flash-attn [%-7s] enable flash attention\n", params.flash_attn ? "true" : "false"); + fprintf(stderr, " -nfa, --no-flash-attn [%-7s] disable flash attention\n", params.flash_attn ? "false" : "true"); fprintf(stderr, " -nlp, --no-language-probabilities [%-7s] exclude language probabilities from verbose_json output\n", params.no_language_probabilities ? "true" : "false"); // Voice Activity Detection (VAD) parameters fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n"); @@ -188,16 +188,18 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str()); fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold); fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms); - fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms); - fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ? - std::string("FLT_MAX").c_str() : - std::to_string(params.vad_max_speech_duration_s).c_str()); + fprintf(stderr, " -vsd N, --vad-min-silence-duration-ms N [%-7d] VAD min silence duration (to split segments)\n", params.vad_min_silence_duration_ms); + fprintf(stderr, " -vmsd N, --vad-max-speech-duration-s N [%-7s] VAD max speech duration (auto-split longer)\n", params.vad_max_speech_duration_s == FLT_MAX ? std::string("FLT_MAX").c_str() : std::to_string(params.vad_max_speech_duration_s).c_str()); fprintf(stderr, " -vp N, --vad-speech-pad-ms N [%-7d] VAD speech padding (extend segments)\n", params.vad_speech_pad_ms); fprintf(stderr, " -vo N, --vad-samples-overlap N [%-7.2f] VAD samples overlap (seconds between segments)\n", params.vad_samples_overlap); fprintf(stderr, "\n"); } bool whisper_params_parse(int argc, char ** argv, whisper_params & params, server_params & sparams) { + if (const char * env_device = std::getenv("WHISPER_ARG_DEVICE")) { + params.gpu_device = std::stoi(env_device); + } + for (int i = 1; i < argc; i++) { std::string arg = argv[i]; @@ -205,62 +207,64 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve whisper_print_usage(argc, argv, params, sparams); exit(0); } - else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } - else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } - else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } - else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } - else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } - else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } - else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } - else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } - else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } - else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } - else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } - else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } - else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } - else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } - else if (arg == "-tr" || arg == "--translate") { params.translate = true; } - else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } - else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } - else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } - else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } - else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } - else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } - else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } - else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; } - else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } - else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } - else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } - else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } - else if ( arg == "--prompt") { params.prompt = argv[++i]; } - else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } - else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } - else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } - else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } - else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } - else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } - else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } - else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } - else if (arg == "-nlp" || arg == "--no-language-probabilities") { params.no_language_probabilities = true; } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); } + else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); } + else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); } + else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); } + else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); } + else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); } + else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); } + else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } + else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } + else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } + else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-debug" || arg == "--debug-mode") { params.debug_mode = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-di" || arg == "--diarize") { params.diarize = true; } + else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; } + else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; } + else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; } + else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; } + else if (arg == "-pr" || arg == "--print-realtime") { params.print_realtime = true; } + else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; } + else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; } + else if ( arg == "--prompt") { params.prompt = argv[++i]; } + else if ( arg == "--carry-initial-prompt") { params.carry_initial_prompt = true; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; } + else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-dev" || arg == "--device") { params.gpu_device = std::stoi(argv[++i]); } + else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-nfa" || arg == "--no-flash-attn") { params.flash_attn = false; } + else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } + else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } + else if (arg == "-nlp" || arg == "--no-language-probabilities") { params.no_language_probabilities = true; } // server params - else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } - else if ( arg == "--host") { sparams.hostname = argv[++i]; } - else if ( arg == "--public") { sparams.public_path = argv[++i]; } - else if ( arg == "--request-path") { sparams.request_path = argv[++i]; } - else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; } - else if ( arg == "--convert") { sparams.ffmpeg_converter = true; } - else if ( arg == "--tmp-dir") { sparams.tmp_dir = argv[++i]; } + else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } + else if ( arg == "--host") { sparams.hostname = argv[++i]; } + else if ( arg == "--public") { sparams.public_path = argv[++i]; } + else if ( arg == "--request-path") { sparams.request_path = argv[++i]; } + else if ( arg == "--inference-path") { sparams.inference_path = argv[++i]; } + else if ( arg == "--convert") { sparams.ffmpeg_converter = true; } + else if ( arg == "--tmp-dir") { sparams.tmp_dir = argv[++i]; } // Voice Activity Detection (VAD) - else if ( arg == "--vad") { params.vad = true; } - else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; } - else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); } - else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); } - else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(argv[++i]); } - else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); } - else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); } - else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); } + else if ( arg == "--vad") { params.vad = true; } + else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = argv[++i]; } + else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(argv[++i]); } + else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(argv[++i]); } + else if (arg == "-vsd" || arg == "--vad-min-silence-duration-ms") { params.vad_min_silence_duration_ms = std::stoi(argv[++i]); } + else if (arg == "-vmsd" || arg == "--vad-max-speech-duration-s") { params.vad_max_speech_duration_s = std::stof(argv[++i]); } + else if (arg == "-vp" || arg == "--vad-speech-pad-ms") { params.vad_speech_pad_ms = std::stoi(argv[++i]); } + else if (arg == "-vo" || arg == "--vad-samples-overlap") { params.vad_samples_overlap = std::stof(argv[++i]); } else { fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); whisper_print_usage(argc, argv, params, sparams); @@ -311,10 +315,10 @@ std::string generate_temp_filename(const std::string &path, const std::string &p return ss.str(); } -bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) { +bool convert_to_wav(const std::string & temp_filename, std::string & error_resp, bool stereo) { std::ostringstream cmd_stream; std::string converted_filename_temp = temp_filename + "_temp.wav"; - cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -y -ar 16000 -ac 1 -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1"; + cmd_stream << "ffmpeg -i \"" << temp_filename << "\" -y -ar 16000 -ac " << (stereo ? 2 : 1) << " -c:a pcm_s16le \"" << converted_filename_temp << "\" 2>&1"; std::string cmd = cmd_stream.str(); int status = std::system(cmd.c_str()); @@ -337,7 +341,7 @@ bool convert_to_wav(const std::string & temp_filename, std::string & error_resp) return true; } -std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { +std::string estimate_diarization_speaker(const std::vector<std::vector<float>> & pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { std::string speaker = ""; const int64_t n_samples = pcmf32s[0].size(); @@ -447,7 +451,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper } } -std::string output_str(struct whisper_context * ctx, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) { +std::string output_str(struct whisper_context * ctx, const whisper_params & params, const std::vector<std::vector<float>> & pcmf32s) { std::stringstream result; const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { @@ -519,6 +523,10 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.logprob_thold = std::stof(req.get_file_value("logprob_thold").content); } + if (req.has_file("no_speech_thold")) + { + params.no_speech_thold = std::stof(req.get_file_value("no_speech_thold").content); + } if (req.has_file("debug_mode")) { params.debug_mode = parse_str_to_bool(req.get_file_value("debug_mode").content); @@ -543,6 +551,12 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.no_timestamps = parse_str_to_bool(req.get_file_value("no_timestamps").content); } + if (req.has_file("token_timestamps")) + { + params.token_timestamps = parse_str_to_bool(req.get_file_value("token_timestamps").content); + } else { + params.token_timestamps = !params.no_timestamps; + } if (req.has_file("language")) { params.language = req.get_file_value("language").content; @@ -555,6 +569,10 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.prompt = req.get_file_value("prompt").content; } + if (req.has_file("carry_initial_prompt")) + { + params.carry_initial_prompt = parse_str_to_bool(req.get_file_value("carry_initial_prompt").content); + } if (req.has_file("response_format")) { params.response_format = req.get_file_value("response_format").content; @@ -643,6 +661,7 @@ int main(int argc, char ** argv) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.gpu_device = params.gpu_device; cparams.flash_attn = params.flash_attn; if (!params.dtw.empty()) { @@ -682,10 +701,10 @@ int main(int argc, char ** argv) { if (params.dtw == "large.v3") { cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3; } - if (params.dtw == "large.v3.turbo") { + if (params.dtw == "large.v3.turbo") { cparams.dtw_aheads_preset = WHISPER_AHEADS_LARGE_V3_TURBO; } - + if (cparams.dtw_aheads_preset == WHISPER_AHEADS_NONE) { fprintf(stderr, "error: unknown DTW preset '%s'\n", params.dtw.c_str()); return 3; @@ -740,13 +759,14 @@ int main(int argc, char ** argv) { <body> <h1>Whisper.cpp Server</h1> - <h2>/inference</h2> + <h2>)" + sparams.request_path + sparams.inference_path + R"(</h2> <pre> - curl 127.0.0.1:)" + std::to_string(sparams.port) + R"(/inference \ + curl 127.0.0.1:)" + std::to_string(sparams.port) + sparams.request_path + sparams.inference_path + R"( \ -H "Content-Type: multipart/form-data" \ -F file="@<file-path>" \ -F temperature="0.0" \ -F temperature_inc="0.2" \ + -F no_speech_thold="0.6" \ -F response_format="json" </pre> @@ -759,7 +779,7 @@ int main(int argc, char ** argv) { <div> <h2>Try it out</h2> - <form action="/inference" method="POST" enctype="multipart/form-data"> + <form action=")" + sparams.request_path + sparams.inference_path + R"(" method="POST" enctype="multipart/form-data"> <label for="file">Choose an audio file:</label> <input type="file" id="file" name="file" accept="audio/*" required><br> @@ -803,12 +823,13 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: no 'file' field in the request\n"); const std::string error_resp = "{\"error\":\"no 'file' field in the request\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } auto audio_file = req.get_file_value("file"); - // check non-required fields + whisper_params params = default_params; get_req_parameters(req, params); std::string filename{audio_file.filename}; @@ -827,8 +848,9 @@ int main(int argc, char ** argv) { temp_file.close(); std::string error_resp = "{\"error\":\"Failed to execute ffmpeg command.\"}"; - const bool is_converted = convert_to_wav(temp_filename, error_resp); + const bool is_converted = convert_to_wav(temp_filename, error_resp, params.diarize); if (!is_converted) { + res.status = 500; res.set_content(error_resp, "application/json"); return; } @@ -838,6 +860,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: failed to read WAV file '%s'\n", temp_filename.c_str()); const std::string error_resp = "{\"error\":\"failed to read WAV file\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); std::remove(temp_filename.c_str()); return; @@ -845,10 +868,10 @@ int main(int argc, char ** argv) { // remove temp file std::remove(temp_filename.c_str()); } else { - if (!::read_audio_data(audio_file.content, pcmf32, pcmf32s, params.diarize)) - { + if (!::read_audio_data(audio_file.content.data(), audio_file.content.size(), pcmf32, pcmf32s, params.diarize)) { fprintf(stderr, "error: failed to read audio data\n"); const std::string error_resp = "{\"error\":\"failed to read audio data\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } @@ -916,18 +939,19 @@ int main(int argc, char ** argv) { wparams.tdrz_enable = params.tinydiarize; // [TDRZ] wparams.initial_prompt = params.prompt.c_str(); + wparams.carry_initial_prompt = params.carry_initial_prompt; wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; wparams.temperature = params.temperature; - wparams.no_speech_thold = params.no_speech_thold; + wparams.no_speech_thold = params.no_speech_thold; wparams.temperature_inc = params.temperature_inc; wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; wparams.no_timestamps = params.no_timestamps; - wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format; + wparams.token_timestamps = params.token_timestamps; wparams.no_context = params.no_context; wparams.suppress_nst = params.suppress_nst; @@ -1031,7 +1055,7 @@ int main(int argc, char ** argv) { res.set_content(ss.str(), "text/vtt"); } else if (params.response_format == vjson_format) { /* try to match openai/whisper's Python format */ - std::string results = output_str(ctx, params, pcmf32s); + std::string results = output_str(ctx, params, pcmf32s); json jres = json{ {"task", params.translate ? "translate" : "transcribe"}, {"language", whisper_lang_str_full(whisper_full_lang_id(ctx))}, @@ -1066,6 +1090,14 @@ int main(int argc, char ** argv) { segment["end"] = whisper_full_get_segment_t1(ctx, i) * 0.01; } + if (params.diarize && pcmf32s.size() == 2) { + segment["speaker"] = estimate_diarization_speaker( + pcmf32s, + whisper_full_get_segment_t0(ctx, i), + whisper_full_get_segment_t1(ctx, i), + true); + } + float total_logprob = 0; const int n_tokens = whisper_full_n_tokens(ctx, i); for (int j = 0; j < n_tokens; ++j) { @@ -1075,10 +1107,29 @@ int main(int argc, char ** argv) { } segment["tokens"].push_back(token.id); - json word = json{{"word", whisper_full_get_token_text(ctx, i, j)}}; - if (!params.no_timestamps) { + std::string word_text = whisper_full_get_token_text(ctx, i, j); + int64_t word_t1 = token.t1; + + while (j + 1 < n_tokens && utf8_trailing_bytes_needed(word_text) > 0) { + const whisper_token_data next_token = whisper_full_get_token_data(ctx, i, j + 1); + // Keep verbose_json tokens free of EOT ids, matching the pre-merge server behavior. + if (next_token.id >= whisper_token_eot(ctx)) { + break; + } + + ++j; + segment["tokens"].push_back(next_token.id); + word_text += whisper_full_get_token_text(ctx, i, j); + if (next_token.t1 > -1) { + word_t1 = next_token.t1; + } + total_logprob += next_token.plog; + } + + json word = json{{"word", word_text}}; + if (!params.no_timestamps && params.token_timestamps) { word["start"] = token.t0 * 0.01; - word["end"] = token.t1 * 0.01; + word["end"] = word_t1 * 0.01; word["t_dtw"] = token.t_dtw; } word["probability"] = token.p; @@ -1108,9 +1159,6 @@ int main(int argc, char ** argv) { res.set_content(jres.dump(-1, ' ', false, json::error_handler_t::replace), "application/json"); } - - // reset params to their defaults - params = default_params; }); svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ std::lock_guard<std::mutex> lock(whisper_mutex); @@ -1119,6 +1167,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: no 'model' field in the request\n"); const std::string error_resp = "{\"error\":\"no 'model' field in the request\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } @@ -1127,6 +1176,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "error: 'model': %s not found!\n", model.c_str()); const std::string error_resp = "{\"error\":\"model not found!\"}"; + res.status = 400; res.set_content(error_resp, "application/json"); return; } diff --git a/examples/talk-llama/CMakeLists.txt b/examples/talk-llama/CMakeLists.txt index cac46705d6c..13b284ed0e9 100644 --- a/examples/talk-llama/CMakeLists.txt +++ b/examples/talk-llama/CMakeLists.txt @@ -20,20 +20,22 @@ if (WHISPER_SDL2) llama-io.cpp llama-kv-cache.cpp llama-kv-cache-iswa.cpp + llama-kv-cache-dsa.cpp llama-memory-recurrent.cpp llama-memory-hybrid.cpp + llama-memory-hybrid-iswa.cpp llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp llama-model-saver.cpp llama-model.cpp llama-quant.cpp - llama-sampling.cpp + llama-sampler.cpp llama-vocab.cpp unicode.cpp unicode-data.cpp ${SRC_MODELS}) - target_include_directories(${TARGET} PRIVATE ${SDL2_INCLUDE_DIRS}) + target_include_directories(${TARGET} PRIVATE . ${SDL2_INCLUDE_DIRS}) target_link_libraries(${TARGET} PRIVATE common common-sdl whisper ${SDL2_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) install(TARGETS ${TARGET} RUNTIME) diff --git a/examples/talk-llama/llama-adapter.cpp b/examples/talk-llama/llama-adapter.cpp index bdc24c2d6b1..3e0fe66afff 100644 --- a/examples/talk-llama/llama-adapter.cpp +++ b/examples/talk-llama/llama-adapter.cpp @@ -41,7 +41,7 @@ bool llama_adapter_cvec::init(const llama_model & model) { auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ hparams.n_layer*ggml_tensor_overhead(), + /*.mem_size =*/ hparams.n_layer()*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -61,9 +61,9 @@ bool llama_adapter_cvec::init(const llama_model & model) { }; // make tensors - tensors.reserve(hparams.n_layer); + tensors.reserve(hparams.n_layer()); tensors.push_back(nullptr); // there's never a tensor for layer 0 - for (size_t il = 1; il < hparams.n_layer; il++) { + for (size_t il = 1; il < hparams.n_layer(); il++) { ggml_backend_buffer_type_t buft = model.select_buft(il); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -121,7 +121,7 @@ bool llama_adapter_cvec::apply( layer_start = il_start; layer_end = il_end; - for (size_t il = 1; il < hparams.n_layer; il++) { + for (size_t il = 1; il < hparams.n_layer(); il++) { assert(tensors[il] != nullptr); const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present @@ -146,11 +146,9 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) { return nullptr; } -static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_lora & adapter) { +static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) { LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); - llama_model & model = adapter.model; - ggml_context * ctx_init; gguf_init_params meta_gguf_params = { /* .no_alloc = */ true, @@ -296,7 +294,7 @@ static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_l } // get extra buffer types of the CPU - // TODO: a more general solution for non-CPU extra buft should be imlpemented in the future + // TODO: a more general solution for non-CPU extra buft should be implemented in the future // ref: https://github.com/ggml-org/llama.cpp/pull/12593#pullrequestreview-2718659948 std::vector<ggml_backend_buffer_type_t> buft_extra; { @@ -413,17 +411,17 @@ static void llama_adapter_lora_init_impl(const char * path_lora, llama_adapter_l } } - // update number of nodes used - model.n_lora_nodes += adapter.get_n_nodes(); + // register adapter with model + model.loras.insert(&adapter); LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); } llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) { - llama_adapter_lora * adapter = new llama_adapter_lora(*model); + llama_adapter_lora * adapter = new llama_adapter_lora(model); try { - llama_adapter_lora_init_impl(path_lora, *adapter); + llama_adapter_lora_init_impl(*model, path_lora, *adapter); return adapter; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); @@ -474,9 +472,14 @@ int32_t llama_adapter_meta_val_str_by_index(const llama_adapter_lora * adapter, } void llama_adapter_lora_free(llama_adapter_lora * adapter) { - // update number of nodes used - GGML_ASSERT(adapter->model.n_lora_nodes >= adapter->get_n_nodes()); - adapter->model.n_lora_nodes -= adapter->get_n_nodes(); + if (adapter == nullptr) { + return; + } + + if (adapter->model != nullptr) { + adapter->model->loras.erase(adapter); + adapter->model = nullptr; + } delete adapter; } diff --git a/examples/talk-llama/llama-adapter.h b/examples/talk-llama/llama-adapter.h index 42d64a6e0b5..f0b1e50f816 100644 --- a/examples/talk-llama/llama-adapter.h +++ b/examples/talk-llama/llama-adapter.h @@ -39,6 +39,8 @@ struct llama_adapter_cvec { std::vector<ggml_tensor *> tensors; // per layer }; +using llama_adapter_cvec_ptr = std::shared_ptr<llama_adapter_cvec>; + // // llama_adapter_lora // @@ -59,7 +61,7 @@ struct llama_adapter_lora_weight { }; struct llama_adapter_lora { - llama_model & model; + llama_model * model = nullptr; // map tensor name to lora_a_b std::unordered_map<std::string, llama_adapter_lora_weight> ab_map; @@ -75,7 +77,7 @@ struct llama_adapter_lora { // activated lora (aLoRA) std::vector<llama_token> alora_invocation_tokens; - llama_adapter_lora(llama_model & model) : model(model) {} + explicit llama_adapter_lora(llama_model * model) : model(model) {} ~llama_adapter_lora() = default; llama_adapter_lora_weight * get_weight(ggml_tensor * w); @@ -86,3 +88,4 @@ struct llama_adapter_lora { }; using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>; +using llama_adapter_loras_ptr = std::unique_ptr<llama_adapter_loras>; diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp index f736ee67050..9f93d5bc7ce 100644 --- a/examples/talk-llama/llama-arch.cpp +++ b/examples/talk-llama/llama-arch.cpp @@ -3,7 +3,7 @@ #include "llama-impl.h" #include <map> -#include <set> +#include <vector> static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_CLIP, "clip" }, // dummy, only used by llama-quantize @@ -26,6 +26,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_NEO_BERT, "neo-bert" }, { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, { LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" }, + { LLM_ARCH_EUROBERT, "eurobert" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -37,6 +38,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, + { LLM_ARCH_QWEN35, "qwen35" }, + { LLM_ARCH_QWEN35MOE, "qwen35moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -52,6 +55,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_GEMMA3N, "gemma3n" }, + { LLM_ARCH_GEMMA4, "gemma4" }, + { LLM_ARCH_GEMMA4_ASSISTANT, "gemma4-assistant" }, { LLM_ARCH_GEMMA_EMBEDDING, "gemma-embedding" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, @@ -69,18 +74,23 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_ARCTIC, "arctic" }, { LLM_ARCH_DEEPSEEK, "deepseek" }, { LLM_ARCH_DEEPSEEK2, "deepseek2" }, + { LLM_ARCH_DEEPSEEK2OCR, "deepseek2-ocr" }, + { LLM_ARCH_DEEPSEEK32, "deepseek32" }, { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, + { LLM_ARCH_GLM_DSA, "glm-dsa" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_JAIS2, "jais2" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, { LLM_ARCH_EXAONE, "exaone" }, { LLM_ARCH_EXAONE4, "exaone4" }, + { LLM_ARCH_EXAONE_MOE, "exaone-moe" }, { LLM_ARCH_RWKV6, "rwkv6" }, { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" }, { LLM_ARCH_RWKV7, "rwkv7" }, @@ -100,6 +110,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" }, { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" }, { LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" }, + { LLM_ARCH_HUNYUAN_VL, "hunyuan_vl" }, { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, @@ -116,9 +127,16 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_RND1, "rnd1" }, { LLM_ARCH_PANGU_EMBED, "pangu-embedded" }, { LLM_ARCH_MISTRAL3, "mistral3" }, - { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_EAGLE3, "eagle3" }, + { LLM_ARCH_MISTRAL4, "mistral4" }, + { LLM_ARCH_PADDLEOCR, "paddleocr" }, + { LLM_ARCH_MIMO2, "mimo2" }, + { LLM_ARCH_STEP35, "step35" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, { LLM_ARCH_MAINCODER, "maincoder" }, + { LLM_ARCH_KIMI_LINEAR, "kimi-linear" }, + { LLM_ARCH_TALKIE, "talkie" }, + { LLM_ARCH_MELLUM, "mellum" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -153,6 +171,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, { LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" }, + { LLM_KV_EMBEDDING_LENGTH_PER_LAYER, "%s.embedding_length_per_layer_input" }, { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, @@ -160,6 +179,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" }, { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" }, { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" }, + { LLM_KV_SWIGLU_CLAMP_EXP, "%s.swiglu_clamp_exp" }, + { LLM_KV_SWIGLU_CLAMP_SHEXP, "%s.swiglu_clamp_shexp" }, { LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" }, { LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" }, { LLM_KV_EXPERT_COUNT, "%s.expert_count" }, @@ -173,8 +194,11 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_EXPERT_GROUP_SCALE, "%s.expert_group_scale" }, { LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" }, { LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" }, + { LLM_KV_MOE_LATENT_SIZE, "%s.moe_latent_size" }, { LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" }, { LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" }, + { LLM_KV_DEEPSTACK_MAPPING, "%s.deepstack_mapping" }, + { LLM_KV_HIDDEN_ACT, "%s.hidden_activation" }, { LLM_KV_POOLING_TYPE, "%s.pooling_type" }, { LLM_KV_LOGIT_SCALE, "%s.logit_scale" }, { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" }, @@ -190,6 +214,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, + { LLM_KV_FULL_ATTENTION_INTERVAL, "%s.full_attention_interval" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -213,26 +238,36 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, "%s.attention.sliding_window_pattern" }, { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, + { LLM_KV_ATTENTION_VALUE_SCALE, "%s.attention.value_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, - - { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, - { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, - { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, - { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, - { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, - { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, - { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, - { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, - { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, - { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, - { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, - { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, - { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, - { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, + { LLM_KV_ATTENTION_KEY_LENGTH_SWA, "%s.attention.key_length_swa" }, + { LLM_KV_ATTENTION_VALUE_LENGTH_SWA, "%s.attention.value_length_swa" }, + { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, + { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, + { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, + { LLM_KV_ATTENTION_SHARED_KV_LAYERS, "%s.attention.shared_kv_layers" }, + { LLM_KV_ATTENTION_RECURRENT_LAYERS, "%s.attention.recurrent_layers" }, + + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_DIMENSION_COUNT_SWA, "%s.rope.dimension_count_swa" }, + { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_FREQ_BASE_SWA, "%s.rope.freq_base_swa" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ALPHA, "%s.rope.scaling.alpha" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" }, + { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" }, + { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" }, { LLM_KV_SPLIT_NO, "split.no" }, { LLM_KV_SPLIT_COUNT, "split.count" }, @@ -245,6 +280,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_KDA_HEAD_DIM, "%s.kda.head_dim" }, + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, { LLM_KV_POSNET_EMBEDDING_LENGTH, "%s.posnet.embedding_length" }, @@ -255,44 +292,51 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, + { LLM_KV_TARGET_LAYERS, "%s.target_layers" }, + { LLM_KV_TARGET_HIDDEN_SIZE, "%s.target_hidden_size" }, + { LLM_KV_NORM_BEFORE_RESIDUAL, "%s.norm_before_residual" }, + { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, // sentence-transformers dense modules feature dims { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, - { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, - { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, - { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, - - { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, - { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, - { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, - { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, - { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, - { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, - { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, - { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, - { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, - { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, - { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, - { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, - { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, - { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, - { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, - { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, - { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, - { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" }, - { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, - { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, - { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, - { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, - { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, - { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, - { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, - { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, - { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, - { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, - { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, - { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_SEP, "tokenizer.ggml.add_sep_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat_template" }, + { LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, "tokenizer.ggml.normalizer.lowercase" }, + { LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, "tokenizer.ggml.normalizer.strip_accents" }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, + { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, + { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + { LLM_KV_TOKENIZER_SUPPRESS_TOKENS, "tokenizer.ggml.suppress_tokens" }, { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, @@ -332,6 +376,7 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_GATE_UP_EXPS, "blk.%d.ffn_gate_up_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, @@ -339,13 +384,19 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, { LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" }, { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" }, + { LLM_TENSOR_FFN_POST_NORM_1, "blk.%d.post_ffw_norm_1" }, + { LLM_TENSOR_FFN_POST_NORM_2, "blk.%d.post_ffw_norm_2" }, + { LLM_TENSOR_FFN_PRE_NORM_2, "blk.%d.pre_ffw_norm_2" }, { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + { LLM_TENSOR_FFN_LATENT_DOWN, "blk.%d.ffn_latent_down" }, + { LLM_TENSOR_FFN_LATENT_UP, "blk.%d.ffn_latent_up" }, { LLM_TENSOR_ATTN_NORM_2, "blk.%d.attn_norm_2" }, { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" }, + { LLM_TENSOR_LAYER_OUT_SCALE, "blk.%d.layer_output_scale" }, { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" }, { LLM_TENSOR_POS_EMBD, "position_embd" }, { LLM_TENSOR_FFN_ACT, "blk.%d.ffn.act" }, @@ -353,12 +404,14 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_CLS, "cls" }, { LLM_TENSOR_CLS_OUT, "cls.output" }, + { LLM_TENSOR_CLS_NORM, "cls.norm" }, { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_ALPHA, "blk.%d.ssm_alpha" }, { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, @@ -370,6 +423,15 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_CONV1D_Q, "blk.%d.ssm_conv1d_q" }, + { LLM_TENSOR_SSM_CONV1D_K, "blk.%d.ssm_conv1d_k" }, + { LLM_TENSOR_SSM_CONV1D_V, "blk.%d.ssm_conv1d_v" }, + { LLM_TENSOR_SSM_F_A, "blk.%d.ssm_f_a" }, + { LLM_TENSOR_SSM_F_B, "blk.%d.ssm_f_b" }, + { LLM_TENSOR_SSM_BETA, "blk.%d.ssm_beta" }, + { LLM_TENSOR_SSM_G_A, "blk.%d.ssm_g_a" }, + { LLM_TENSOR_SSM_G_B, "blk.%d.ssm_g_b" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" }, { LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" }, { LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" }, @@ -397,6 +459,8 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_NORM_EXPS, "blk.%d.ffn_norm_exps" }, { LLM_TENSOR_ATTN_K_B, "blk.%d.attn_k_b" }, { LLM_TENSOR_ATTN_V_B, "blk.%d.attn_v_b" }, + { LLM_TENSOR_NEXTN_PROJ_PRE, "nextn.pre_projection" }, + { LLM_TENSOR_NEXTN_PROJ_POST, "nextn.post_projection" }, { LLM_TENSOR_NEXTN_EH_PROJ, "blk.%d.nextn.eh_proj" }, { LLM_TENSOR_NEXTN_EMBED_TOKENS, "blk.%d.nextn.embed_tokens" }, { LLM_TENSOR_NEXTN_ENORM, "blk.%d.nextn.enorm" }, @@ -496,1771 +560,16 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = { { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" }, { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" }, { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, + { LLM_TENSOR_INDEXER_K_NORM, "blk.%d.indexer.k_norm" }, + { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, + { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, + { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, + { LLM_TENSOR_MASKED_EMBD_CENTROIDS, "masked_embd_centroids" }, + { LLM_TENSOR_MASKED_EMBD_ORDERING, "masked_embd_ordering" }, + { LLM_TENSOR_FC, "fc" }, + { LLM_TENSOR_D2T, "d2t" }, }; -static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) { - switch (arch) { - case LLM_ARCH_CLIP: - return {}; - case LLM_ARCH_LLAMA: - case LLM_ARCH_DECI: - case LLM_ARCH_MISTRAL3: - case LLM_ARCH_LLAMA_EMBED: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_ARCEE: - case LLM_ARCH_STARCODER2: - case LLM_ARCH_NEMOTRON: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_AFMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_LLAMA4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_BAICHUAN: - case LLM_ARCH_ORION: - case LLM_ARCH_XVERSE: - case LLM_ARCH_EXAONE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_FALCON: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GROK: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_ATTN_OUT_NORM, - }; - case LLM_ARCH_GPT2: - case LLM_ARCH_STARCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GPTNEOX: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_MPT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_ACT, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_REFACT: - case LLM_ARCH_QWEN2: - case LLM_ARCH_QWEN2VL: - case LLM_ARCH_INTERNLM2: - case LLM_ARCH_GRANITE: - case LLM_ARCH_ERNIE4_5: - case LLM_ARCH_SMOLLM3: - case LLM_ARCH_DREAM: - case LLM_ARCH_LLADA: - case LLM_ARCH_PANGU_EMBED: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_POS_EMBD, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - }; - case LLM_ARCH_NOMIC_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_NOMIC_BERT_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_NEO_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - }; - case LLM_ARCH_MODERN_BERT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_CLS, - LLM_TENSOR_CLS_OUT, - }; - case LLM_ARCH_JINA_BERT_V2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_LAYER_OUT_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_CLS, - }; - case LLM_ARCH_JINA_BERT_V3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_TOKEN_TYPES, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_LAYER_OUT_NORM, - }; - case LLM_ARCH_BLOOM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_STABLELM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_QWEN: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_QWEN2MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_QWEN3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_CLS_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_QWEN3MOE: - case LLM_ARCH_QWEN3VLMOE: - case LLM_ARCH_OLMOE: - case LLM_ARCH_LLADA_MOE: - case LLM_ARCH_RND1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_QWEN3NEXT: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_SSM_A_NOSCAN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_BETA_ALPHA, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_QWEN3VL: - case LLM_ARCH_CHAMELEON: - case LLM_ARCH_HUNYUAN_DENSE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHI2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHI3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PHIMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_PLAMO: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_PLAMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_SSM_DT_NORM, - LLM_TENSOR_SSM_B_NORM, - LLM_TENSOR_SSM_C_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_PLAMO3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_CODESHELL: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_MINICPM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXP, - LLM_TENSOR_FFN_DOWN_EXP, - LLM_TENSOR_FFN_UP_EXP, - }; - case LLM_ARCH_MINICPM3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FACTORS_LONG, - LLM_TENSOR_ROPE_FACTORS_SHORT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GEMMA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GEMMA2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_GEMMA3: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_GEMMA3N: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_PER_LAYER_TOKEN_EMBD, - LLM_TENSOR_PER_LAYER_MODEL_PROJ, - LLM_TENSOR_PER_LAYER_PROJ_NORM, - LLM_TENSOR_ALTUP_UNEMBD_PROJ, - LLM_TENSOR_ALTUP_PROJ, - LLM_TENSOR_PER_LAYER_INP_GATE, - LLM_TENSOR_PER_LAYER_PROJ, - LLM_TENSOR_PER_LAYER_POST_NORM, - LLM_TENSOR_ALTUP_CORRECT_COEF, - LLM_TENSOR_ALTUP_CORRECT_SCALE, - LLM_TENSOR_ALTUP_PREDICT_COEF, - LLM_TENSOR_ALTUP_ROUTER, - LLM_TENSOR_ALTUP_ROUTER_NORM, - LLM_TENSOR_LAUREL_L, - LLM_TENSOR_LAUREL_R, - LLM_TENSOR_LAUREL_POST_NORM, - }; - case LLM_ARCH_GEMMA_EMBEDDING: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DENSE_2_OUT, - LLM_TENSOR_DENSE_3_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_MAMBA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_MAMBA2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - }; - case LLM_ARCH_JAMBA: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_X, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_DT_NORM, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_B_NORM, - LLM_TENSOR_SSM_C_NORM, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_FALCON_H1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_COMMAND_R: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - }; - case LLM_ARCH_COHERE2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_DBRX: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_OUT_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_OLMO: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_OLMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_OPENELM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_ARCTIC: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM_EXPS, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_DEEPSEEK: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_ROT_EMBD, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_DEEPSEEK2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_A_NORM, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_A, - LLM_TENSOR_ATTN_Q_B, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_K_B, - LLM_TENSOR_ATTN_V_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_PLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_KV_A_MQA, - LLM_TENSOR_ATTN_KV_A_NORM, - LLM_TENSOR_ATTN_KV_B, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_CHATGLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_GLM4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_GLM4_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - }; - case LLM_ARCH_BITNET: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_SUB_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_SUB_NORM, - }; - case LLM_ARCH_T5: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DEC_OUTPUT_NORM, - LLM_TENSOR_DEC_ATTN_NORM, - LLM_TENSOR_DEC_ATTN_Q, - LLM_TENSOR_DEC_ATTN_K, - LLM_TENSOR_DEC_ATTN_V, - LLM_TENSOR_DEC_ATTN_OUT, - LLM_TENSOR_DEC_ATTN_REL_B, - LLM_TENSOR_DEC_CROSS_ATTN_NORM, - LLM_TENSOR_DEC_CROSS_ATTN_Q, - LLM_TENSOR_DEC_CROSS_ATTN_K, - LLM_TENSOR_DEC_CROSS_ATTN_V, - LLM_TENSOR_DEC_CROSS_ATTN_OUT, - LLM_TENSOR_DEC_CROSS_ATTN_REL_B, - LLM_TENSOR_DEC_FFN_NORM, - LLM_TENSOR_DEC_FFN_GATE, - LLM_TENSOR_DEC_FFN_DOWN, - LLM_TENSOR_DEC_FFN_UP, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_ENC_ATTN_NORM, - LLM_TENSOR_ENC_ATTN_Q, - LLM_TENSOR_ENC_ATTN_K, - LLM_TENSOR_ENC_ATTN_V, - LLM_TENSOR_ENC_ATTN_OUT, - LLM_TENSOR_ENC_ATTN_REL_B, - LLM_TENSOR_ENC_FFN_NORM, - LLM_TENSOR_ENC_FFN_GATE, - LLM_TENSOR_ENC_FFN_DOWN, - LLM_TENSOR_ENC_FFN_UP, - }; - case LLM_ARCH_T5ENCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ENC_OUTPUT_NORM, - LLM_TENSOR_ENC_ATTN_NORM, - LLM_TENSOR_ENC_ATTN_Q, - LLM_TENSOR_ENC_ATTN_K, - LLM_TENSOR_ENC_ATTN_V, - LLM_TENSOR_ENC_ATTN_OUT, - LLM_TENSOR_ENC_ATTN_REL_B, - LLM_TENSOR_ENC_FFN_NORM, - LLM_TENSOR_ENC_FFN_GATE, - LLM_TENSOR_ENC_FFN_DOWN, - LLM_TENSOR_ENC_FFN_UP, - }; - case LLM_ARCH_JAIS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - }; - case LLM_ARCH_NEMOTRON_H: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_NEMOTRON_H_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - // mamba(2) ssm layers - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - // attention layers - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - // dense FFN - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - // MoE FFN (for MoE layers) - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - // MoE shared expert layer - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_EXAONE4: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_POST_NORM, - }; - case LLM_ARCH_RWKV6: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_LERP_X, - LLM_TENSOR_TIME_MIX_LERP_W, - LLM_TENSOR_TIME_MIX_LERP_K, - LLM_TENSOR_TIME_MIX_LERP_V, - LLM_TENSOR_TIME_MIX_LERP_R, - LLM_TENSOR_TIME_MIX_LERP_G, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_FIRST, - LLM_TENSOR_TIME_MIX_DECAY, - LLM_TENSOR_TIME_MIX_DECAY_W1, - LLM_TENSOR_TIME_MIX_DECAY_W2, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_GATE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_CHANNEL_MIX_LERP_K, - LLM_TENSOR_CHANNEL_MIX_LERP_R, - LLM_TENSOR_CHANNEL_MIX_KEY, - LLM_TENSOR_CHANNEL_MIX_VALUE, - LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, - }; - case LLM_ARCH_RWKV6QWEN2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_LERP_X, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_FIRST, - LLM_TENSOR_TIME_MIX_DECAY, - LLM_TENSOR_TIME_MIX_DECAY_W1, - LLM_TENSOR_TIME_MIX_DECAY_W2, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_GATE, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_RWKV7: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_NORM_2, - LLM_TENSOR_TIME_MIX_W0, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_A0, - LLM_TENSOR_TIME_MIX_A1, - LLM_TENSOR_TIME_MIX_A2, - LLM_TENSOR_TIME_MIX_V0, - LLM_TENSOR_TIME_MIX_V1, - LLM_TENSOR_TIME_MIX_V2, - LLM_TENSOR_TIME_MIX_G1, - LLM_TENSOR_TIME_MIX_G2, - LLM_TENSOR_TIME_MIX_K_K, - LLM_TENSOR_TIME_MIX_K_A, - LLM_TENSOR_TIME_MIX_R_K, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_CHANNEL_MIX_LERP_K, - LLM_TENSOR_CHANNEL_MIX_KEY, - LLM_TENSOR_CHANNEL_MIX_VALUE, - }; - case LLM_ARCH_ARWKV7: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_TIME_MIX_W0, - LLM_TENSOR_TIME_MIX_W1, - LLM_TENSOR_TIME_MIX_W2, - LLM_TENSOR_TIME_MIX_A0, - LLM_TENSOR_TIME_MIX_A1, - LLM_TENSOR_TIME_MIX_A2, - LLM_TENSOR_TIME_MIX_V0, - LLM_TENSOR_TIME_MIX_V1, - LLM_TENSOR_TIME_MIX_V2, - LLM_TENSOR_TIME_MIX_G1, - LLM_TENSOR_TIME_MIX_G2, - LLM_TENSOR_TIME_MIX_K_K, - LLM_TENSOR_TIME_MIX_K_A, - LLM_TENSOR_TIME_MIX_R_K, - LLM_TENSOR_TIME_MIX_LERP_FUSED, - LLM_TENSOR_TIME_MIX_KEY, - LLM_TENSOR_TIME_MIX_VALUE, - LLM_TENSOR_TIME_MIX_RECEPTANCE, - LLM_TENSOR_TIME_MIX_LN, - LLM_TENSOR_TIME_MIX_OUTPUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GRANITE_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_GRANITE_HYBRID: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_SSM_IN, - LLM_TENSOR_SSM_CONV1D, - LLM_TENSOR_SSM_DT, - LLM_TENSOR_SSM_A, - LLM_TENSOR_SSM_D, - LLM_TENSOR_SSM_NORM, - LLM_TENSOR_SSM_OUT, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_WAVTOKENIZER_DEC: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_TOKEN_EMBD_NORM, - LLM_TENSOR_CONV1D, - LLM_TENSOR_CONVNEXT_DW, - LLM_TENSOR_CONVNEXT_NORM, - LLM_TENSOR_CONVNEXT_PW1, - LLM_TENSOR_CONVNEXT_PW2, - LLM_TENSOR_CONVNEXT_GAMMA, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_POS_NET_CONV1, - LLM_TENSOR_POS_NET_CONV2, - LLM_TENSOR_POS_NET_NORM, - LLM_TENSOR_POS_NET_NORM1, - LLM_TENSOR_POS_NET_NORM2, - LLM_TENSOR_POS_NET_ATTN_NORM, - LLM_TENSOR_POS_NET_ATTN_Q, - LLM_TENSOR_POS_NET_ATTN_K, - LLM_TENSOR_POS_NET_ATTN_V, - LLM_TENSOR_POS_NET_ATTN_OUT, - }; - case LLM_ARCH_BAILINGMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - }; - case LLM_ARCH_BAILINGMOE2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_EXP_PROBS_B, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_NEXTN_EH_PROJ, - LLM_TENSOR_NEXTN_EMBED_TOKENS, - LLM_TENSOR_NEXTN_ENORM, - LLM_TENSOR_NEXTN_HNORM, - LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, - LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, - LLM_TENSOR_LAYER_OUT_NORM, - }; - case LLM_ARCH_DOTS1: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_INP_SHEXP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_ERNIE4_5_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_HUNYUAN_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_SHEXP, - LLM_TENSOR_FFN_DOWN_SHEXP, - LLM_TENSOR_FFN_UP_SHEXP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_OPENAI_MOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_SINKS, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_LFM2: - return { - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SHORTCONV_CONV, - LLM_TENSOR_SHORTCONV_INPROJ, - LLM_TENSOR_SHORTCONV_OUTPROJ, - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM_LFM2, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_DENSE_2_OUT, - }; - case LLM_ARCH_LFM2MOE: - return { - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_SHORTCONV_CONV, - LLM_TENSOR_SHORTCONV_INPROJ, - LLM_TENSOR_SHORTCONV_OUTPROJ, - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM_LFM2, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_SMALLTHINKER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - }; - case LLM_ARCH_APERTUS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ROPE_FREQS, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_SEED_OSS: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_POST_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - case LLM_ARCH_GROVEMOE: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_GATE_CHEXPS, - LLM_TENSOR_FFN_DOWN_CHEXPS, - LLM_TENSOR_FFN_UP_CHEXPS, - }; - case LLM_ARCH_MINIMAX_M2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_COGVLM: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_QKV, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_VISEXP_ATTN_QKV, - LLM_TENSOR_VISEXP_ATTN_OUT, - LLM_TENSOR_VISEXP_FFN_GATE, - LLM_TENSOR_VISEXP_FFN_DOWN, - LLM_TENSOR_VISEXP_FFN_UP, - }; - case LLM_ARCH_MIMO2: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_SINKS, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - LLM_TENSOR_FFN_GATE_INP, - LLM_TENSOR_FFN_GATE_EXPS, - LLM_TENSOR_FFN_DOWN_EXPS, - LLM_TENSOR_FFN_UP_EXPS, - LLM_TENSOR_FFN_EXP_PROBS_B, - }; - case LLM_ARCH_GPTJ: - case LLM_ARCH_UNKNOWN: - return { - LLM_TENSOR_TOKEN_EMBD, - }; - case LLM_ARCH_MAINCODER: - return { - LLM_TENSOR_TOKEN_EMBD, - LLM_TENSOR_OUTPUT_NORM, - LLM_TENSOR_OUTPUT, - LLM_TENSOR_ATTN_NORM, - LLM_TENSOR_ATTN_Q, - LLM_TENSOR_ATTN_Q_NORM, - LLM_TENSOR_ATTN_K, - LLM_TENSOR_ATTN_K_NORM, - LLM_TENSOR_ATTN_V, - LLM_TENSOR_ATTN_OUT, - LLM_TENSOR_FFN_NORM, - LLM_TENSOR_FFN_GATE, - LLM_TENSOR_FFN_DOWN, - LLM_TENSOR_FFN_UP, - }; - default: - GGML_ABORT("unknown architecture for tensor mapping"); - } -} - // declare information about the model weight tensors: // - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight // - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator @@ -2272,19 +581,20 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) { // example: https://github.com/ggml-org/llama.cpp/pull/17548 // static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { - {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}}, - {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output - {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output - {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // do the norms on the first layer (not the input layer) + {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_OUTPUT_NORM_LFM2, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, {LLM_TENSOR_ROPE_FACTORS_LONG, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, {LLM_TENSOR_ROPE_FACTORS_SHORT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, @@ -2331,6 +641,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2359,6 +670,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // Kimi KDA - Conv tensors are 4D [d_conv, 1, d_inner, 1], reshaped to 2D at runtime + {LLM_TENSOR_SSM_CONV1D_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_CONV1D_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_F_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_F_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_BETA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_G_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2382,11 +702,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_PRE_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_POST_NORM_1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_POST_NORM_2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_LAYER_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_LAYER_OUT_SCALE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_Q_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_KV_A_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_SUB_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2401,14 +725,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_GATE_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}}, // altup / laurel (gemma 3n) - {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_PER_LAYER_MODEL_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_PER_LAYER_PROJ_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ALTUP_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_ALTUP_UNEMBD_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_PER_LAYER_INP_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2424,7 +749,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, - {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}}, + {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}}, {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -2448,14 +773,30 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = { {LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, - // NextN/MTP tensors are currently ignored (reserved for future MTP support) - // These tensors only exist in the last layer(s) and are treated as output tensors - {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, - {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, - {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_PRE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_PROJ_POST, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + // NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the + // last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so + // the model loader doesn't fault on the block index. + {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}}, + {LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + // Nemotron 3 Super + // latent projections feed ggml_mul_mat, the buft probe must use MUL_MAT to keep them on GPU + {LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_MASKED_EMBD_CENTROIDS, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + {LLM_TENSOR_MASKED_EMBD_ORDERING, {LLM_TENSOR_LAYER_INPUT, GGML_OP_NONE}}, + // eagle3 + {LLM_TENSOR_FC, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_D2T, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}}, }; LLM_KV::LLM_KV(llm_arch arch, const char * suffix) : arch(arch), suffix(suffix) {} @@ -2472,18 +813,13 @@ std::string LLM_KV::operator()(llm_kv kv) const { } LLM_TN_IMPL::LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid) - : arch(arch), tensor(tensor), suffix(suffix), bid(bid), xid(xid), - model_tensors(llm_get_tensor_names(arch)) {} + : arch(arch), tensor(tensor), suffix(suffix), bid(bid), xid(xid) {} std::string LLM_TN_IMPL::str() const { if (LLM_TENSOR_NAMES.find(tensor) == LLM_TENSOR_NAMES.end()) { GGML_ABORT("unknown tensor name for tensor id %d", static_cast<int>(tensor)); } - if (model_tensors.find(tensor) == model_tensors.end()) { - return LLM_TENSOR_NAMES.at(tensor); - } - std::string name = ::format(LLM_TENSOR_NAMES.at(tensor), bid, xid); if (suffix != nullptr) { name += "."; @@ -2493,6 +829,15 @@ std::string LLM_TN_IMPL::str() const { return name; } +std::vector<llm_arch> llm_arch_all() { + std::vector<llm_arch> ret; + ret.reserve(LLM_ARCH_NAMES.size()); + for (const auto & [arch, _] : LLM_ARCH_NAMES) { + ret.push_back(arch); + } + return ret; +} + const char * llm_arch_name(llm_arch arch) { auto it = LLM_ARCH_NAMES.find(arch); if (it == LLM_ARCH_NAMES.end()) { @@ -2540,6 +885,9 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: + case LLM_ARCH_KIMI_LINEAR: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return true; default: return false; @@ -2557,3 +905,45 @@ bool llm_arch_is_diffusion(const llm_arch & arch) { return false; } } + +bool llm_arch_supports_rs_rollback(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: + return true; + default: + return false; + } +} + +bool llm_arch_supports_sm_tensor(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_GROK: + case LLM_ARCH_MPT: + case LLM_ARCH_PLAMO2: + case LLM_ARCH_MINICPM3: + case LLM_ARCH_GEMMA3N: + case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: + case LLM_ARCH_JAMBA: + case LLM_ARCH_FALCON_H1: + case LLM_ARCH_OLMO2: + case LLM_ARCH_OLMOE: + case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK32: + case LLM_ARCH_GLM_DSA: + case LLM_ARCH_BITNET: + case LLM_ARCH_T5: + case LLM_ARCH_NEMOTRON_H: + case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_GRANITE_HYBRID: + case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: + case LLM_ARCH_MINIMAX_M2: + case LLM_ARCH_MISTRAL4: + case LLM_ARCH_KIMI_LINEAR: + return false; + default: + return true; + } +} diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 68ec6a18b18..c5245fb5891 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -4,6 +4,7 @@ #include <string> #include <set> +#include <vector> // // gguf constants (sync with gguf.py) @@ -30,6 +31,7 @@ enum llm_arch { LLM_ARCH_NEO_BERT, LLM_ARCH_JINA_BERT_V2, LLM_ARCH_JINA_BERT_V3, + LLM_ARCH_EUROBERT, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -41,6 +43,8 @@ enum llm_arch { LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, + LLM_ARCH_QWEN35, + LLM_ARCH_QWEN35MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, @@ -56,6 +60,8 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_GEMMA3, LLM_ARCH_GEMMA3N, + LLM_ARCH_GEMMA4, + LLM_ARCH_GEMMA4_ASSISTANT, LLM_ARCH_GEMMA_EMBEDDING, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, @@ -73,18 +79,23 @@ enum llm_arch { LLM_ARCH_ARCTIC, LLM_ARCH_DEEPSEEK, LLM_ARCH_DEEPSEEK2, + LLM_ARCH_DEEPSEEK2OCR, + LLM_ARCH_DEEPSEEK32, LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, + LLM_ARCH_GLM_DSA, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, + LLM_ARCH_JAIS2, LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON_H, LLM_ARCH_NEMOTRON_H_MOE, LLM_ARCH_EXAONE, LLM_ARCH_EXAONE4, + LLM_ARCH_EXAONE_MOE, LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV7, @@ -104,6 +115,7 @@ enum llm_arch { LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_HUNYUAN_DENSE, + LLM_ARCH_HUNYUAN_VL, LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, @@ -120,9 +132,16 @@ enum llm_arch { LLM_ARCH_RND1, LLM_ARCH_PANGU_EMBED, LLM_ARCH_MISTRAL3, + LLM_ARCH_MISTRAL4, + LLM_ARCH_PADDLEOCR, LLM_ARCH_MIMO2, + LLM_ARCH_STEP35, LLM_ARCH_LLAMA_EMBED, LLM_ARCH_MAINCODER, + LLM_ARCH_KIMI_LINEAR, + LLM_ARCH_TALKIE, + LLM_ARCH_MELLUM, + LLM_ARCH_EAGLE3, LLM_ARCH_UNKNOWN, }; @@ -157,6 +176,7 @@ enum llm_kv { LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, LLM_KV_EMBEDDING_LENGTH_OUT, + LLM_KV_EMBEDDING_LENGTH_PER_LAYER, LLM_KV_FEATURES_LENGTH, LLM_KV_BLOCK_COUNT, LLM_KV_LEADING_DENSE_BLOCK_COUNT, @@ -164,6 +184,8 @@ enum llm_kv { LLM_KV_EXPERT_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, + LLM_KV_SWIGLU_CLAMP_EXP, + LLM_KV_SWIGLU_CLAMP_SHEXP, LLM_KV_USE_PARALLEL_RESIDUAL, LLM_KV_TENSOR_DATA_LAYOUT, LLM_KV_EXPERT_COUNT, @@ -177,8 +199,11 @@ enum llm_kv { LLM_KV_EXPERT_GROUP_SCALE, LLM_KV_EXPERTS_PER_GROUP, LLM_KV_MOE_EVERY_N_LAYERS, + LLM_KV_MOE_LATENT_SIZE, LLM_KV_NEXTN_PREDICT_LAYERS, LLM_KV_NUM_DEEPSTACK_LAYERS, + LLM_KV_DEEPSTACK_MAPPING, + LLM_KV_HIDDEN_ACT, LLM_KV_POOLING_TYPE, LLM_KV_LOGIT_SCALE, LLM_KV_DECODER_START_TOKEN_ID, @@ -194,6 +219,7 @@ enum llm_kv { LLM_KV_EMBEDDING_SCALE, LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_INTERLEAVE_MOE_LAYER_STEP, + LLM_KV_FULL_ATTENTION_INTERVAL, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -217,18 +243,28 @@ enum llm_kv { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, + LLM_KV_ATTENTION_VALUE_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_KEY_LENGTH_SWA, + LLM_KV_ATTENTION_VALUE_LENGTH_SWA, + LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, + LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, + LLM_KV_ATTENTION_INDEXER_TOP_K, + LLM_KV_ATTENTION_SHARED_KV_LAYERS, + LLM_KV_ATTENTION_RECURRENT_LAYERS, LLM_KV_ROPE_DIMENSION_COUNT, + LLM_KV_ROPE_DIMENSION_COUNT_SWA, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, LLM_KV_ROPE_FREQ_BASE_SWA, LLM_KV_ROPE_SCALE_LINEAR, LLM_KV_ROPE_SCALING_TYPE, LLM_KV_ROPE_SCALING_FACTOR, + LLM_KV_ROPE_SCALING_ALPHA, LLM_KV_ROPE_SCALING_ATTN_FACTOR, LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, LLM_KV_ROPE_SCALING_FINETUNED, @@ -249,6 +285,8 @@ enum llm_kv { LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, + LLM_KV_KDA_HEAD_DIM, + LLM_KV_WKV_HEAD_SIZE, LLM_KV_TOKENIZER_MODEL, @@ -276,12 +314,15 @@ enum llm_kv { LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_CHAT_TEMPLATE, + LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, + LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_MID_ID, LLM_KV_TOKENIZER_FIM_PAD_ID, LLM_KV_TOKENIZER_FIM_REP_ID, LLM_KV_TOKENIZER_FIM_SEP_ID, + LLM_KV_TOKENIZER_SUPPRESS_TOKENS, LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, @@ -297,6 +338,10 @@ enum llm_kv { LLM_KV_CLASSIFIER_OUTPUT_LABELS, + LLM_KV_TARGET_LAYERS, + LLM_KV_TARGET_HIDDEN_SIZE, + LLM_KV_NORM_BEFORE_RESIDUAL, + LLM_KV_SHORTCONV_L_CACHE, LLM_KV_XIELU_ALPHA_N, @@ -345,6 +390,9 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_FFN_POST_NORM_1, + LLM_TENSOR_FFN_POST_NORM_2, + LLM_TENSOR_FFN_PRE_NORM_2, LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, LLM_TENSOR_FFN_UP, @@ -356,6 +404,7 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, @@ -363,9 +412,12 @@ enum llm_tensor { LLM_TENSOR_FFN_GATE_CHEXPS, LLM_TENSOR_FFN_UP_CHEXPS, LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_FFN_LATENT_DOWN, + LLM_TENSOR_FFN_LATENT_UP, LLM_TENSOR_ATTN_Q_NORM, LLM_TENSOR_ATTN_K_NORM, LLM_TENSOR_LAYER_OUT_NORM, + LLM_TENSOR_LAYER_OUT_SCALE, LLM_TENSOR_POST_ATTN_NORM, LLM_TENSOR_POST_MLP_NORM, LLM_TENSOR_PER_LAYER_TOKEN_EMBD, // gemma3n @@ -397,6 +449,16 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + LLM_TENSOR_SSM_ALPHA, // qwen3.5 + // Kimi Linear KDA (using SSM_ prefix for consistency) + LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight + LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight + LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight + LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A + LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient and qwen3.5 + LLM_TENSOR_SSM_G_A, // kimi: output gate projection A + LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, LLM_TENSOR_TIME_MIX_W2, @@ -473,6 +535,7 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, LLM_TENSOR_CONV1D, LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_NORM, @@ -497,14 +560,25 @@ enum llm_tensor { LLM_TENSOR_VISEXP_FFN_GATE, LLM_TENSOR_VISEXP_FFN_DOWN, LLM_TENSOR_VISEXP_FFN_UP, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_PROJ_PRE, + LLM_TENSOR_NEXTN_PROJ_POST, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, LLM_TENSOR_NEXTN_HNORM, LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + LLM_TENSOR_MASKED_EMBD_CENTROIDS, + LLM_TENSOR_MASKED_EMBD_ORDERING, + LLM_TENSOR_FC, + LLM_TENSOR_D2T, }; + enum llm_tensor_layer { LLM_TENSOR_LAYER_INPUT, LLM_TENSOR_LAYER_REPEATING, @@ -536,8 +610,6 @@ struct LLM_TN_IMPL { const int bid; const int xid; - const std::set<llm_tensor> model_tensors; - LLM_TN_IMPL(llm_arch arch, llm_tensor tensor, const char * suffix, int bid, int xid); std::string str() const; @@ -575,12 +647,16 @@ struct llm_tensor_info { ggml_op op; }; +std::vector<llm_arch> llm_arch_all(); + const char * llm_arch_name(llm_arch arch); llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); -bool llm_arch_is_recurrent(const llm_arch & arch); -bool llm_arch_is_hybrid (const llm_arch & arch); -bool llm_arch_is_diffusion(const llm_arch & arch); +bool llm_arch_is_recurrent (const llm_arch & arch); +bool llm_arch_is_hybrid (const llm_arch & arch); +bool llm_arch_is_diffusion (const llm_arch & arch); +bool llm_arch_supports_sm_tensor(const llm_arch & arch); +bool llm_arch_supports_rs_rollback(const llm_arch & arch); diff --git a/examples/talk-llama/llama-batch.cpp b/examples/talk-llama/llama-batch.cpp index 386fab04ac9..6bf76939cdd 100644 --- a/examples/talk-llama/llama-batch.cpp +++ b/examples/talk-llama/llama-batch.cpp @@ -394,11 +394,13 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t clear(); split_reset(); + const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd; + auto udata = std::make_shared<llama_ubatch::data_t>(); udata->token .resize(n_tokens); udata->embd .clear(); - udata->pos .resize(n_tokens); + udata->pos .resize(n_pos_all); udata->n_seq_id .resize(n_tokens); udata->seq_id .resize(n_tokens); udata->seq_id_unq.resize(0); diff --git a/examples/talk-llama/llama-batch.h b/examples/talk-llama/llama-batch.h index 8e6fac0efab..f77520e86c3 100644 --- a/examples/talk-llama/llama-batch.h +++ b/examples/talk-llama/llama-batch.h @@ -18,7 +18,7 @@ struct llama_ubatch { } // typical for M-RoPE cases: - // 0 - sequantial position of the tokens/embeddings in the sequence + // 0 - sequential position of the tokens/embeddings in the sequence // 1 - y position in the image // 2 - x position in the image // 3 - other diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index b54ebbd155d..6d822ec62d6 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -49,6 +49,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = { { "deepseek", LLM_CHAT_TEMPLATE_DEEPSEEK }, { "deepseek2", LLM_CHAT_TEMPLATE_DEEPSEEK_2 }, { "deepseek3", LLM_CHAT_TEMPLATE_DEEPSEEK_3 }, + { "deepseek-ocr", LLM_CHAT_TEMPLATE_DEEPSEEK_OCR }, { "command-r", LLM_CHAT_TEMPLATE_COMMAND_R }, { "llama3", LLM_CHAT_TEMPLATE_LLAMA_3 }, { "chatglm3", LLM_CHAT_TEMPLATE_CHATGLM_3 }, @@ -57,8 +58,11 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = { { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 }, + { "exaone-moe", LLM_CHAT_TEMPLATE_EXAONE_MOE }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, - { "granite", LLM_CHAT_TEMPLATE_GRANITE }, + { "granite", LLM_CHAT_TEMPLATE_GRANITE_3_X }, + { "granite-4.0", LLM_CHAT_TEMPLATE_GRANITE_4_0 }, + { "granite-4.1", LLM_CHAT_TEMPLATE_GRANITE_4_1 }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, { "megrez", LLM_CHAT_TEMPLATE_MEGREZ }, { "yandex", LLM_CHAT_TEMPLATE_YANDEX }, @@ -70,6 +74,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = { { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE }, { "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE }, + { "hunyuan-vl", LLM_CHAT_TEMPLATE_HUNYUAN_VL }, { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, @@ -137,6 +142,9 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("[gMASK]<sop>")) { return LLM_CHAT_TEMPLATE_CHATGLM_4; } else if (tmpl_contains("<|assistant|>") && tmpl_contains("<|user|>")) { + if (tmpl_contains("<|tool_declare|>")) { + return LLM_CHAT_TEMPLATE_EXAONE_MOE; + } return tmpl_contains("</s>") ? LLM_CHAT_TEMPLATE_FALCON_3 : LLM_CHAT_TEMPLATE_GLMEDGE; } else if (tmpl_contains("<|{{ item['role'] }}|>") && tmpl_contains("<|begin_of_image|>")) { return LLM_CHAT_TEMPLATE_GLMEDGE; @@ -186,7 +194,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { - return LLM_CHAT_TEMPLATE_GRANITE; + if (tmpl_contains("<tool_call>") || tmpl_contains("<tools>")) { + if (tmpl_contains("g4_default_system_message")) { + return LLM_CHAT_TEMPLATE_GRANITE_4_0; + } + return LLM_CHAT_TEMPLATE_GRANITE_4_1; + } + return LLM_CHAT_TEMPLATE_GRANITE_3_X; } else if (tmpl_contains("message['role'] + additional_special_tokens[0] + message['content'] + additional_special_tokens[1]")) { return LLM_CHAT_TEMPLATE_GIGACHAT; } else if (tmpl_contains("<|role_start|>")) { @@ -207,6 +221,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) { return LLM_CHAT_TEMPLATE_OPENAI_MOE; + } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_begin▁of▁sentence|>")) { + return LLM_CHAT_TEMPLATE_HUNYUAN_VL; } else if (tmpl_contains("<|hy_Assistant|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE; } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { @@ -229,7 +245,7 @@ int32_t llm_chat_apply_template( llm_chat_template tmpl, const std::vector<const llama_chat_message *> & chat, std::string & dest, bool add_ass) { - // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 + // Taken from the research: https://github.com/ggml-org/llama.cpp/issues/5527 std::stringstream ss; if (tmpl == LLM_CHAT_TEMPLATE_CHATML) { // chatml template @@ -544,6 +560,11 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << LU8("<|Assistant|>"); } + } else if (tmpl == LLM_CHAT_TEMPLATE_DEEPSEEK_OCR) { + for (auto message : chat) { + // no template + ss << message->content; + } } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_3) { // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct @@ -576,6 +597,22 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "[|assistant|]"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_MOE) { + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|system|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "user") { + ss << "<|user|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "assistant") { + ss << "<|assistant|>\n" << trim(message->content) << "<|endofturn|>\n"; + } else if (role == "tool") { + ss << "<|tool|>\n" << trim(message->content) << "<|endofturn|>\n"; + } + } + if (add_ass) { + ss << "<|assistant|>\n"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { // this template requires the model to have "\n\n" as EOT token for (size_t i = 0; i < chat.size(); i++) { @@ -591,8 +628,8 @@ int32_t llm_chat_apply_template( ss << "Assistant: " << trim(chat[i]->content) << "\n\n"; } } - } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE) { - // IBM Granite template + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_3_X) { + // IBM Granite 3.x template for (const auto & message : chat) { std::string role(message->role); ss << "<|start_of_role|>" << role << "<|end_of_role|>"; @@ -604,6 +641,34 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "<|start_of_role|>assistant<|end_of_role|>"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_0) { + // IBM Granite 4.0 template + for (const auto & message : chat) { + std::string role(message->role); + if (role == "assistant_tool_call") { + ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>"; + } else { + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>"; + } + } else if (tmpl == LLM_CHAT_TEMPLATE_GRANITE_4_1) { + // IBM Granite 4.1 template + for (const auto & message : chat) { + std::string role(message->role); + if (role == "assistant_tool_call") { + ss << "<|start_of_role|>assistant<|end_of_role|><|tool_call|>"; + } else { + ss << "<|start_of_role|>" << role << "<|end_of_role|>"; + } + ss << message->content << "<|end_of_text|>\n"; + } + if (add_ass) { + ss << "<|start_of_role|>assistant<|end_of_role|>"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template bool has_system = !chat.empty() && std::string(chat[0]->role) == "system"; @@ -778,6 +843,22 @@ int32_t llm_chat_apply_template( ss << "<|hy_User|>" << chat[i]->content << "<|hy_Assistant|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_VL) { + // tencent/HunyuanOCR & tencent/HunyuanVL + ss << "<|hy_begin▁of▁sentence|>"; + for (size_t i = 0; i < chat.size(); i++) { + std::string role(chat[i]->role); + if (i == 0 && role == "system") { + ss << chat[i]->content << "<|hy_place▁holder▁no▁3|>"; + continue; + } + + if (role == "user") { + ss << chat[i]->content << "<|hy_User|>"; + } else if (role == "assistant") { + ss << chat[i]->content << "<|hy_Assistant|>"; + } + } } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { // moonshotai/Kimi-K2-Instruct for (auto message : chat) { diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index e1f795249c8..dc37f919a96 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -28,6 +28,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_DEEPSEEK, LLM_CHAT_TEMPLATE_DEEPSEEK_2, LLM_CHAT_TEMPLATE_DEEPSEEK_3, + LLM_CHAT_TEMPLATE_DEEPSEEK_OCR, LLM_CHAT_TEMPLATE_COMMAND_R, LLM_CHAT_TEMPLATE_LLAMA_3, LLM_CHAT_TEMPLATE_CHATGLM_3, @@ -36,8 +37,11 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, LLM_CHAT_TEMPLATE_EXAONE_4, + LLM_CHAT_TEMPLATE_EXAONE_MOE, LLM_CHAT_TEMPLATE_RWKV_WORLD, - LLM_CHAT_TEMPLATE_GRANITE, + LLM_CHAT_TEMPLATE_GRANITE_3_X, + LLM_CHAT_TEMPLATE_GRANITE_4_0, + LLM_CHAT_TEMPLATE_GRANITE_4_1, LLM_CHAT_TEMPLATE_GIGACHAT, LLM_CHAT_TEMPLATE_MEGREZ, LLM_CHAT_TEMPLATE_YANDEX, @@ -50,6 +54,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_HUNYUAN_MOE, LLM_CHAT_TEMPLATE_OPENAI_MOE, LLM_CHAT_TEMPLATE_HUNYUAN_DENSE, + LLM_CHAT_TEMPLATE_HUNYUAN_VL, LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index f220010a1b4..168dbabd766 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -1,12 +1,16 @@ #include "llama-context.h" +#include "ggml.h" #include "llama-arch.h" +#include "llama-graph.h" #include "llama-impl.h" #include "llama-batch.h" #include "llama-io.h" #include "llama-memory.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-ext.h" +#include "llama.h" #include <cinttypes> #include <cmath> @@ -18,10 +22,20 @@ // llama_context // +static llm_graph_type ctx_type_to_graph_type(llama_context_type ctx_type) { + switch (ctx_type) { + case LLAMA_CONTEXT_TYPE_DEFAULT: return LLM_GRAPH_TYPE_DEFAULT; + case LLAMA_CONTEXT_TYPE_MTP : return LLM_GRAPH_TYPE_DECODER_MTP; + } + throw std::runtime_error("Unsupported ctx type"); +} + llama_context::llama_context( const llama_model & model, llama_context_params params) : model(model), + cvec(std::make_unique<llama_adapter_cvec>()), + loras(std::make_unique<llama_adapter_loras>()), balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) { // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, // may need to be backend-dependent @@ -37,17 +51,31 @@ llama_context::llama_context( throw std::runtime_error("n_seq_max must be <= " + std::to_string(LLAMA_MAX_SEQ)); } - cparams.n_threads = params.n_threads; - cparams.n_threads_batch = params.n_threads_batch; - cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; - cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; - cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; - cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; - cparams.embeddings = params.embeddings; - cparams.offload_kqv = params.offload_kqv; - cparams.no_perf = params.no_perf; - cparams.pooling_type = params.pooling_type; - cparams.warmup = false; + cparams.n_rs_seq = params.n_rs_seq; + if (cparams.n_rs_seq > 0 && !llm_arch_supports_rs_rollback(model.arch)) { + LLAMA_LOG_DEBUG("%s: n_rs_seq=%u requested but model arch does not support recurrent partial rollback; clamping to 0\n", + __func__, cparams.n_rs_seq); + cparams.n_rs_seq = 0; + } + + cparams.n_threads = params.n_threads; + cparams.n_threads_batch = params.n_threads_batch; + cparams.yarn_ext_factor = params.yarn_ext_factor >= 0.0f ? params.yarn_ext_factor : hparams.yarn_ext_factor; + cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor; + cparams.yarn_beta_fast = params.yarn_beta_fast >= 0.0f ? params.yarn_beta_fast : hparams.yarn_beta_fast; + cparams.yarn_beta_slow = params.yarn_beta_slow >= 0.0f ? params.yarn_beta_slow : hparams.yarn_beta_slow; + cparams.embeddings = params.embeddings; + cparams.embeddings_nextn = false; + cparams.embeddings_nextn_masked = false; + cparams.offload_kqv = params.offload_kqv; + cparams.no_perf = params.no_perf; + cparams.warmup = false; + + cparams.embeddings_layer_inp.resize(hparams.n_layer(), false); + embd_layer_inp.resize(hparams.n_layer()); + + cparams.ctx_type = params.ctx_type; + cparams.pooling_type = params.pooling_type; cparams.n_ctx = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx; cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; @@ -60,6 +88,27 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + cparams.ctx_other = nullptr; + + // TODO: more generic + if (model.arch == LLM_ARCH_GEMMA4_ASSISTANT) { + if (params.ctx_other == nullptr) { + // TODO: change from runtime_error to llama_exception to avoid printing error message + throw std::runtime_error("Gemma4Assistant requires ctx_other to be set (this warning is normal during memory fitting)"); + } + + cparams.ctx_other = params.ctx_other; + } + + if (model.arch == LLM_ARCH_EAGLE3) { + if (model.tok_embd == nullptr || model.output == nullptr) { + if (params.ctx_other == nullptr) { + throw std::runtime_error("EAGLE3 requires ctx_other to be set (this warning is normal during memory fitting)"); + } + cparams.ctx_other = params.ctx_other; + } + } + // Initialize backend samplers here so they are part of the sampling graph // before the reserve passes run later in this function. This avoids a later // re-reserve when graph nodes change. @@ -146,15 +195,25 @@ llama_context::llama_context( } cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED; + cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO; + + cparams.fused_gdn_ar = true; + cparams.fused_gdn_ch = true; + cparams.auto_fgdn = true; // with causal attention, the batch size is limited by the context size cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch; cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + cparams.n_outputs_max = params.n_outputs_max == 0 || llama_model_has_encoder(&model) ? cparams.n_batch : params.n_outputs_max; + cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; + // initialized later + cparams.pipeline_parallel = false; + { const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE"); graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable; @@ -193,6 +252,8 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: n_rs_seq = %u\n", __func__, cparams.n_rs_seq); + LLAMA_LOG_INFO("%s: n_outputs_max = %u\n", __func__, cparams.n_outputs_max); if (cparams.n_ctx_seq < hparams.n_ctx_train) { LLAMA_LOG_WARN("%s: n_ctx_seq (%u) < n_ctx_train (%u) -- the full capacity of the model will not be utilized\n", @@ -206,10 +267,10 @@ llama_context::llama_context( if (!hparams.vocab_only) { // GPU backends - for (auto * dev : model.devices) { - ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + for (const auto & dev : model.devices) { + ggml_backend_t backend = ggml_backend_dev_init(dev.dev, nullptr); if (backend == nullptr) { - throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev))); + throw std::runtime_error(format("failed to initialize %s backend", ggml_backend_dev_name(dev.dev))); } backends.emplace_back(backend); } @@ -249,11 +310,7 @@ llama_context::llama_context( // graph outputs buffer { - // resized during inference when a batch uses more outputs - // Create a dummy batch for initialization. - llama_batch dummy_batch = {}; - dummy_batch.n_tokens = 0; - if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) { + if (output_reserve(params.n_seq_max) < params.n_seq_max) { throw std::runtime_error("failed to reserve initial output buffer"); } @@ -266,9 +323,11 @@ llama_context::llama_context( // init the memory module if (!hparams.vocab_only) { llama_memory_params params_mem = { - /*.type_k =*/ params.type_k, - /*.type_v =*/ params.type_v, - /*.swa_full =*/ params.swa_full, + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + /*.swa_full =*/ params.swa_full, + /*.ctx_type =*/ cparams.ctx_type, + /*.mem_other =*/ llama_get_memory(cparams.ctx_other), }; memory.reset(model.create_memory(params_mem, cparams)); @@ -288,8 +347,8 @@ llama_context::llama_context( if (backend_type == GGML_BACKEND_DEVICE_TYPE_CPU && !model.devices.empty()) { // use the host buffer of the first device CPU for faster transfer of the intermediate state - auto * dev = model.devices[0]; - auto * host_buft = ggml_backend_dev_host_buffer_type(dev); + const auto & dev = model.devices[0]; + auto * host_buft = ggml_backend_dev_host_buffer_type(dev.dev); if (host_buft) { buft = host_buft; } @@ -302,21 +361,11 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size()); - const uint32_t n_seqs = cparams.n_seq_max; - const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - - const size_t max_nodes = this->graph_max_nodes(n_tokens); - - LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); - - gf_res_prev.reset(new llm_graph_result(max_nodes)); - gf_res_reserve.reset(new llm_graph_result(max_nodes)); - // TODO: move these checks to ggml_backend_sched // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary bool pipeline_parallel = model.n_devices() > 1 && - model.n_gpu_layers() > model.hparams.n_layer && + model.n_gpu_layers() > model.hparams.n_layer_all && model.split_mode() == LLAMA_SPLIT_MODE_LAYER && cparams.offload_kqv && !model.has_tensor_overrides(); @@ -327,6 +376,7 @@ llama_context::llama_context( auto dev_type = ggml_backend_dev_type(ggml_backend_get_device(backend.get())); if (dev_type == GGML_BACKEND_DEVICE_TYPE_CPU) { // ignore CPU backend + // TODO: should we ignore ACCEL types too? continue; } auto * dev = ggml_backend_get_device(backend.get()); @@ -340,177 +390,302 @@ llama_context::llama_context( } } - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload)); + cparams.pipeline_parallel = pipeline_parallel; - if (pipeline_parallel) { - LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get())); + if (cparams.pipeline_parallel) { + LLAMA_LOG_INFO("%s: pipeline parallelism enabled\n", __func__); + } + + sched_reserve(); + + if (!cparams.flash_attn) { + if (ggml_is_quantized(params.type_v)) { + throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); + } + } + } + + // Initialize the full vocabulary token ids for backend samplers. + { + const int n_vocab = model.vocab.n_tokens(); + + sampling.token_ids_full_vocab.resize(n_vocab); + for (int i = 0; i < n_vocab; ++i) { + sampling.token_ids_full_vocab[i] = i; } + } +} - llama_memory_context_ptr mctx; - if (memory) { - LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); - mctx = memory->init_full(); - if (!mctx) { - throw std::runtime_error("failed to initialize memory module"); +llama_context::~llama_context() { + if (!model.hparams.no_alloc) { + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; + + const size_t size_exp = backend_buf_exp_size[i]; + const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); + if (size_exp == size_act) { + LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", + __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); + } else { + LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", + __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); } } + } + ggml_opt_free(opt_ctx); +} - cross.v_embd.clear(); +void llama_context::sched_reserve() { + if (!sched_need_reserve) { + return; + } - // avoid reserving graphs with zero outputs - assume one output per sequence - n_outputs = n_seqs; + sched_need_reserve = false; - LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + LLAMA_LOG_INFO("%s: reserving ...\n", __func__); - // resolve automatic Flash Attention use - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { + synchronize(); + + const int64_t t_start_us = ggml_time_us(); + + const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); + + const size_t max_nodes = this->graph_max_nodes(n_tokens); + + LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); + + gf_res_prev.reset(new llm_graph_result(max_nodes)); + gf_res_reserve.reset(new llm_graph_result(max_nodes)); + + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload)); + + llama_memory_context_ptr mctx; + if (memory) { + LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__); + mctx = memory->init_full(); + if (!mctx) { + throw std::runtime_error("failed to initialize memory module"); + } + } + + // avoid reserving graphs with zero outputs - assume one output per sequence + const int n_outputs = n_seqs; + + LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); + + // resolve automatic Flash Attention use + if (cparams.auto_fa) { + auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); + if (!gf) { + throw std::runtime_error("failed to reserve graph for Flash Attention check"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; + bool fa_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_FLASH_ATTN_EXT) { + continue; + } + ggml_backend_dev_t device_fa = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_fa != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); + // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways + fa_device_mismatch = true; + break; + } + } + + if (fa_device_mismatch) { + cparams.flash_attn = false; + LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); + } else { + cparams.flash_attn = true; + LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + } + + cparams.auto_fa = false; + } + + if (cparams.auto_fgdn) { + LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net support:\n", __func__); + + if (cparams.fused_gdn_ar) { auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true); if (!gf) { - throw std::runtime_error("failed to split graph for Flash Attention check"); + throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)"); } - const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1; - bool fa_device_mismatch = false; + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1; + bool gdn_device_mismatch = false; for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { ggml_tensor * n = ggml_graph_node(gf, i); - if (n->op != GGML_OP_FLASH_ATTN_EXT) { + if (n->op != GGML_OP_GATED_DELTA_NET) { continue; } - ggml_backend_dev_t device_fa = ggml_backend_get_device( - ggml_backend_sched_get_tensor_backend(sched.get(), n)); + ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); - // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer - GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0); + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0); const int il = std::stoi(n->name + prefix_len); ggml_backend_dev_t device_kv = model.dev_layer(il); - if (device_fa != device_kv) { - LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor " - "is assigned to device %s (usually due to missing support)\n", - __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa)); - // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways - fa_device_mismatch = true; + if (device_gdn != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); + gdn_device_mismatch = true; break; } } - if (fa_device_mismatch) { - cparams.flash_attn = false; - LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__); - if (ggml_is_quantized(params.type_v)) { - throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention"); - } + + if (gdn_device_mismatch) { + cparams.fused_gdn_ar = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__); } else { - cparams.flash_attn = true; - LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__); + LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__); } } - // reserve worst-case graph - int n_splits_pp = -1; - int n_nodes_pp = -1; - - int n_splits_tg = -1; - int n_nodes_tg = -1; - - // reserve pp (prompt processing) graph first so that buffers are only allocated once - { - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), - model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); + if (cparams.fused_gdn_ch) { + // more than one token in the batch per sequence in order to take the chunked path + // note: n_outputs must match n_tokens for embedding models with mean/rank pooling, + // because build_pooling creates inp_mean with shape [n_tokens, n_seqs] and multiplies + // it with t_embd which is reduced to [n_outputs, ...] via out_ids. if n_outputs != n_tokens, + // the ggml_mul_mat assertion fails. + const uint32_t n_tokens_ch = 16*n_seqs; + auto * gf = graph_reserve(n_tokens_ch, n_seqs, n_tokens_ch, mctx.get(), true); if (!gf) { - if (pipeline_parallel) { - LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); - sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); - gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)"); + } + + const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1; + bool gdn_device_mismatch = false; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + ggml_tensor * n = ggml_graph_node(gf, i); + if (n->op != GGML_OP_GATED_DELTA_NET) { + continue; } - if (!gf) { - throw std::runtime_error("failed to allocate compute pp buffers"); + ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n)); + + GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0); + const int il = std::stoi(n->name + prefix_len); + ggml_backend_dev_t device_kv = model.dev_layer(il); + if (device_gdn != device_kv) { + LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor " + "is assigned to device %s (usually due to missing support)\n", + __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn)); + gdn_device_mismatch = true; + break; } } - n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_pp = ggml_graph_n_nodes(gf); + if (gdn_device_mismatch) { + cparams.fused_gdn_ch = false; + LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__); + } else { + LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__); + } } - // reserve with tg (token generation) graph to get the number of splits and nodes - { - auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); - if (!gf) { - throw std::runtime_error("failed to allocate compute tg buffers"); - } + cparams.auto_fgdn = false; + } - n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); - n_nodes_tg = ggml_graph_n_nodes(gf); - } + // reserve worst-case graph + int n_splits_pp = -1; + int n_nodes_pp = -1; - // reserve again with pp graph to avoid ggml-alloc reallocations during inference - { - // TODO: not sure if the following graph would be worster case for multi-stream KV caches: - // - // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); - // - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc); + int n_splits_tg = -1; + int n_nodes_tg = -1; + + const uint32_t n_outputs_pp = std::min(n_tokens, cparams.n_outputs_max); + + // reserve pp (prompt processing) graph first so that buffers are only allocated once + { + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), + model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr); + if (!gf) { + if (cparams.pipeline_parallel) { + LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__); + cparams.pipeline_parallel = false; + sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload)); + gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get()); + } if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); } } - for (size_t i = 0; i < backend_ptrs.size(); ++i) { - ggml_backend_t backend = backend_ptrs[i]; - ggml_backend_buffer_type_t buft = backend_buft[i]; - if (!model.hparams.no_alloc) { - backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); - } - if (backend_buf_exp_size[i] > 1) { - LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, - ggml_backend_buft_name(buft), - backend_buf_exp_size[i] / 1024.0 / 1024.0); - } - } + n_splits_pp = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_pp = ggml_graph_n_nodes(gf); + } - if (n_nodes_pp == n_nodes_tg) { - LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); - } else { - LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + // reserve with tg (token generation) graph to get the number of splits and nodes + { + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute tg buffers"); } - if (n_splits_pp == n_splits_tg) { - LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); - } else { - LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); - } + n_splits_tg = ggml_backend_sched_get_n_splits(sched.get()); + n_nodes_tg = ggml_graph_n_nodes(gf); } - // Initialize the full vocabulary token ids for backend samplers. + // reserve again with pp graph to avoid ggml-alloc reallocations during inference { - const int n_vocab = model.vocab.n_tokens(); + // TODO: not sure if the following graph would be worst case for multi-stream KV caches: + // + // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); + // + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_pp, mctx.get(), model.hparams.no_alloc); + if (!gf) { + throw std::runtime_error("failed to allocate compute pp buffers"); + } + } - sampling.token_ids_full_vocab.resize(n_vocab); - for (int i = 0; i < n_vocab; ++i) { - sampling.token_ids_full_vocab[i] = i; + for (size_t i = 0; i < backend_ptrs.size(); ++i) { + ggml_backend_t backend = backend_ptrs[i]; + ggml_backend_buffer_type_t buft = backend_buft[i]; + if (!model.hparams.no_alloc) { + backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend); + } + if (backend_buf_exp_size[i] > 1) { + LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__, + ggml_backend_buft_name(buft), + backend_buf_exp_size[i] / 1024.0 / 1024.0); } } -} -llama_context::~llama_context() { - if (!model.hparams.no_alloc) { - for (size_t i = 0; i < backend_ptrs.size(); ++i) { - ggml_backend_t backend = backend_ptrs[i]; - ggml_backend_buffer_type_t buft = backend_buft[i]; + if (n_nodes_pp == n_nodes_tg) { + LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp); + } else { + LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg); + } - const size_t size_exp = backend_buf_exp_size[i]; - const size_t size_act = ggml_backend_sched_get_buffer_size(sched.get(), backend); - if (size_exp == size_act) { - LLAMA_LOG_DEBUG("%s: %10s compute buffer size is %8.4f MiB, matches expectation of %8.4f MiB\n", - __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); - } else { - LLAMA_LOG_WARN("%s: %10s compute buffer size of %8.4f MiB, does not match expectation of %8.4f MiB\n", - __func__, ggml_backend_buft_name(buft), size_act / (1024.0*1024.0), size_exp / (1024.0*1024.0)); - } - } + if (n_splits_pp == n_splits_tg) { + LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp); + } else { + LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); } - ggml_opt_free(opt_ctx); + + const int64_t t_end_us = ggml_time_us(); + + LLAMA_LOG_INFO("%s: reserve took %.2f ms, sched copies = %d\n", + __func__, (t_end_us - t_start_us)/1000.0, ggml_backend_sched_get_n_copies(sched.get())); } void llama_context::synchronize() { + if (!sched) { + return; + } + ggml_backend_sched_synchronize(sched.get()); // FIXME: if multiple single tokens are evaluated without a synchronization, @@ -629,7 +804,9 @@ bool llama_context::memory_update(bool optimize) { const uint32_t n_seqs = cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); - auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); + const uint32_t n_outputs_max = std::min(n_tokens, cparams.n_outputs_max); + + auto * gf = graph_reserve(n_tokens, n_seqs, n_outputs_max, mctx.get()); if (!gf) { LLAMA_LOG_ERROR("%s: failed to reserve graph after the memory update\n", __func__); } @@ -645,7 +822,7 @@ enum llama_pooling_type llama_context::pooling_type() const { float * llama_context::get_logits() { output_reorder(); - return logits; + return logits.data; } int64_t llama_context::output_resolve_row(int32_t i) const { @@ -678,36 +855,15 @@ int64_t llama_context::output_resolve_row(int32_t i) const { } float * llama_context::get_logits_ith(int32_t i) { - int64_t j = -1; - output_reorder(); try { - if (logits == nullptr) { + if (logits.data == nullptr) { throw std::runtime_error("no logits"); } - // TODO: use output_resolve_row() - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - - return logits + j*model.vocab.n_tokens(); + const int64_t j = output_resolve_row(i); + return logits.data + j*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -721,45 +877,24 @@ float * llama_context::get_logits_ith(int32_t i) { float * llama_context::get_embeddings() { output_reorder(); - return embd; + return embd.data; } llama_token * llama_context::get_sampled_tokens() const{ - return sampling.sampled; + return sampling.sampled.data; } float * llama_context::get_embeddings_ith(int32_t i) { - int64_t j = -1; - output_reorder(); try { - if (embd == nullptr) { + if (embd.data == nullptr) { throw std::runtime_error("no embeddings"); } - // TODO: use output_resolve_row() - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - - const uint32_t n_embd_out = model.hparams.get_n_embd_out(); - return embd + j*n_embd_out; + const int64_t j = output_resolve_row(i); + const uint32_t n_embd_out = model.hparams.n_embd_out(); + return embd.data + j*n_embd_out; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -779,17 +914,61 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +float * llama_context::get_embeddings_nextn() { + output_reorder(); + + return embd_nextn.data; +} + +float * llama_context::get_embeddings_nextn_ith(int32_t i) { + output_reorder(); + + try { + if (embd_nextn.data == nullptr) { + throw std::runtime_error("no nextn embeddings"); + } + + const uint32_t n_embd = model.hparams.n_embd_out(); + + if (!cparams.embeddings_nextn_masked) { + // unmasked: nextn rows are stored densely, indexed by raw token position. + if (i < 0 || (size_t)(i + 1) * n_embd > embd_nextn.size) { + throw std::runtime_error(format("out of range [0, %zu)", embd_nextn.size / n_embd)); + } + return embd_nextn.data + (size_t) i * n_embd; + } + + const int64_t j = output_resolve_row(i); + return embd_nextn.data + j*n_embd; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid nextn embeddings id %d, reason: %s\n", __func__, i, err.what()); +#ifndef NDEBUG + GGML_ABORT("fatal error"); +#else + return nullptr; +#endif + } +} + +float * llama_context::get_embeddings_layer_inp(uint32_t lid) { + output_reorder(); + + GGML_ASSERT(lid < embd_layer_inp.size() && embd_layer_inp[lid].has_data()); + + return embd_layer_inp[lid].data; +} + llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); - if (sampling.sampled == nullptr) { + if (!sampling.sampled.has_data()) { return LLAMA_TOKEN_NULL; } try { const int64_t row = output_resolve_row(idx); - GGML_ASSERT(row < (int64_t) sampling.sampled_size); - return sampling.sampled[row]; + GGML_ASSERT(row < (int64_t) sampling.sampled.size); + return sampling.sampled.data[row]; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); return LLAMA_TOKEN_NULL; @@ -799,7 +978,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) { float * llama_context::get_sampled_probs_ith(int32_t idx) { output_reorder(); - if (sampling.probs == nullptr) { + if (!sampling.probs.has_data()) { return nullptr; } @@ -808,7 +987,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { return nullptr; } - return sampling.probs + row*model.vocab.n_tokens(); + return sampling.probs.data + row*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); return nullptr; @@ -818,7 +997,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { float * llama_context::get_sampled_logits_ith(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!sampling.logits.has_data()) { return nullptr; } @@ -827,7 +1006,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) { if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { return nullptr; } - return sampling.logits + row*model.vocab.n_tokens(); + return sampling.logits.data + row*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); return nullptr; @@ -839,13 +1018,14 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { try { const int64_t row = output_resolve_row(idx); - if (sampling.candidates != nullptr && + if (sampling.candidates.has_data() && (size_t) row < sampling.candidates_count.size() && sampling.candidates_count[row] > 0) { - return sampling.candidates + row*model.vocab.n_tokens(); + return sampling.candidates.data + row*model.vocab.n_tokens(); } } catch (const std::exception & err) { // fallback to full vocab list + GGML_UNUSED(err); } return sampling.token_ids_full_vocab.data(); @@ -854,7 +1034,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { size_t llama_context::get_sampled_candidates_count(int32_t idx) { output_reorder(); - if (sampling.candidates == nullptr) { + if (!sampling.candidates.has_data()) { return 0; } @@ -873,7 +1053,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) { size_t llama_context::get_sampled_logits_count(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!sampling.logits.has_data()) { return model.vocab.n_tokens(); } @@ -892,7 +1072,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) { size_t llama_context::get_sampled_probs_count(int32_t idx) { output_reorder(); - if (sampling.probs == nullptr) { + if (!sampling.probs.has_data()) { return 0; } @@ -940,9 +1120,11 @@ void llama_context::set_abort_callback(bool (*abort_callback)(void * data), void for (auto & backend : backends) { auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); - auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); - if (set_abort_callback_fn) { - set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + if (reg) { + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), this->abort_callback, this->abort_callback_data); + } } } } @@ -951,23 +1133,74 @@ void llama_context::set_embeddings(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); cparams.embeddings = value; + + // TODO: not sure yet if we want to reserve here + //sched_need_reserve = true; +} + +void llama_context::set_embeddings_nextn(bool value, bool masked) { + LLAMA_LOG_DEBUG("%s: value = %d, masked = %d\n", __func__, value, masked); + + cparams.embeddings_nextn = value; + cparams.embeddings_nextn_masked = masked; +} + +void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) { + LLAMA_LOG_DEBUG("%s: lid = %d, enable = %d\n", __func__, lid, enable); + + GGML_ASSERT(lid < model.hparams.n_layer()); + + cparams.embeddings_layer_inp[lid] = enable; + + // note: without this reserve, the draft acceptance drops to zero. not sure why - this is unexpected + sched_need_reserve = true; } void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.causal_attn == value) { + return; + } + cparams.causal_attn = value; + + sched_need_reserve = true; } void llama_context::set_warmup(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); + if (cparams.warmup == value) { + return; + } + cparams.warmup = value; + + // warmups are usually with small batches, so no need to reserve + //sched_need_reserve = true; } bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { + if (!sampler && sampling.samplers.count(seq_id) == 0) { + return true; + } + LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + if (sampler && model.split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + static bool warned = false; + if (!warned) { + LLAMA_LOG_WARN("%s: backend sampling not supported with SPLIT_MODE_TENSOR; using CPU\n", __func__); + warned = true; + } + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); + return false; + } + const bool can_offload = sampler && sampler->iface->backend_init && @@ -975,22 +1208,24 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { llama_sampler_chain_n(sampler) > 0; if (sampler && can_offload) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); - auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); - if (host_buft) { - buft = host_buft; - } + auto * buft = ggml_backend_dev_buffer_type(model.dev_output()); sampler->iface->backend_init(sampler, buft); sampling.samplers[seq_id] = sampler; + sched_need_reserve = true; + return true; } if (sampler && !can_offload) { LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); + if (sampling.samplers.count(seq_id) > 0) { + sched_need_reserve = true; + } + sampling.samplers.erase(seq_id); return false; @@ -998,37 +1233,56 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { sampling.samplers.erase(seq_id); + sched_need_reserve = true; + return true; } -void llama_context::set_adapter_lora( - llama_adapter_lora * adapter, - float scale) { - LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); +void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); + + if (adapters_lora_are_same(adapters, n_adapters, scales)) { + return; + } + + loras.reset(new llama_adapter_loras()); + + for (size_t i = 0; i < n_adapters; i ++) { + if (scales[i] != 0.0f) { + loras->insert({adapters[i], scales[i]}); + } + } - loras[adapter] = scale; + sched_need_reserve = true; } -bool llama_context::rm_adapter_lora( - llama_adapter_lora * adapter) { - LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); +bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); + + // Adapters with a zero scale are never added to `loras`, so also ignore them for the comparison. + size_t n_non_zero = 0; + + for (size_t i = 0; i < n_adapters; i ++) { + if (scales[i] == 0.0f) { + continue; + } + n_non_zero++; + + auto it = loras->find(adapters[i]); - auto pos = loras.find(adapter); - if (pos != loras.end()) { - loras.erase(pos); - return true; + if (it == loras->end() || it->second != scales[i]) { + return false; + } } - return false; -} - -void llama_context::clear_adapter_lora() { - LLAMA_LOG_DEBUG("%s: call\n", __func__); + if (n_non_zero != loras->size()) { + return false; + } - loras.clear(); + return true; } -bool llama_context::apply_adapter_cvec( +bool llama_context::set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -1036,7 +1290,11 @@ bool llama_context::apply_adapter_cvec( int32_t il_end) { LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end); - return cvec.apply(model, data, len, n_embd, il_start, il_end); + bool res = cvec->apply(model, data, len, n_embd, il_start, il_end); + + sched_need_reserve = true; + + return res; } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -1056,6 +1314,13 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll if (!graph_reuse_disable && res->can_reuse(gparams)) { //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); + // with pipeline parallelism, the previous graph_compute_async may still be running + // on the GPU. we must synchronize before set_inputs to avoid overwriting input tensors + // that the previous compute is still reading. + if (cparams.pipeline_parallel) { + ggml_backend_sched_synchronize(sched.get()); + } + n_reused++; } else { res->reset(); @@ -1086,6 +1351,7 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll { //const auto t_start_us = ggml_time_us(); + // FIXME this call causes a crash if any model inputs were not used in the graph and were therefore not allocated res->set_inputs(&ubatch); //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); @@ -1104,7 +1370,9 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll } int llama_context::encode(const llama_batch & batch_inp) { - GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + // MTP hook batches carry both token (next-token id) and embd (h_nextn row), + // so accept either present rather than requiring exactly one. + GGML_ASSERT(batch_inp.token || batch_inp.embd); if (batch_inp.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); @@ -1113,7 +1381,8 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd_inp(); + // eagle3/DFlash: features as encoder input, and non-draft paths fall back to model's input dim + const int64_t n_embd = hparams.n_embd_inp(); const int64_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 @@ -1138,10 +1407,12 @@ int llama_context::encode(const llama_batch & batch_inp) { // TODO: this clear of the buffer can easily be forgotten - need something better embd_seq.clear(); + sched_reserve(); + n_queued_tokens += n_tokens; // reserve output buffer - if (output_reserve(n_tokens, batch_inp) < n_tokens) { + if (output_reserve(n_tokens) < n_tokens) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); return -2; }; @@ -1173,20 +1444,21 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - auto * t_logits = res->get_logits(); - auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_logits = res->get_logits(); + auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr; // extract logits - if (logits && t_logits) { + if (logits.data && t_logits) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float)); } // extract embeddings - if (embd && t_embd) { + if (embd.data && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1194,11 +1466,11 @@ int llama_context::encode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); - const uint32_t n_embd_out = hparams.get_n_embd_out(); + GGML_ASSERT(embd.data != nullptr); + const uint32_t n_embd_out = hparams.n_embd_out(); - GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float)); + GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); } break; case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: @@ -1211,8 +1483,11 @@ int llama_context::encode(const llama_batch & batch_inp) { const llama_seq_id seq_id = ubatch.seq_id_unq[s]; const int32_t seq_idx = ubatch.seq_idx[seq_id]; - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + // use n_embd_out (not n_embd_inp) - the pooled embedding has the model's + // output dimension, which differs from input dimension for deepstack models (e.g. qwen3vl) + const uint32_t n_embd_out = hparams.n_embd_out(); + embd_seq_out[seq_id].resize(n_embd_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_RANK: @@ -1237,6 +1512,16 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + // extract nextn embeddings (hidden state before the final output norm) + if (embd_nextn.data && t_h_nextn && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd_out(); + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_nextn.size); + ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn.data, 0, n_tokens*n_embd*sizeof(float)); + } + // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { //cross.t_embd = t_embd; @@ -1246,7 +1531,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cross.n_embd = t_embd->ne[0]; cross.n_enc = t_embd->ne[1]; cross.v_embd.resize(cross.n_embd*cross.n_enc); - memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); + memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); const auto & batch = balloc->get_batch(); @@ -1286,11 +1571,10 @@ static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubat static void copy_tensor_async_ints( const std::map<llama_seq_id, ggml_tensor*> & tensor_map, - llama_token * sampled, - size_t sampled_size, + const buffer_view<llama_token> & sampled, const std::map<llama_seq_id, uint32_t> & seq_to_row, ggml_backend_sched_t sched) { - if (sampled == nullptr) { + if (!sampled.has_data()) { return; } @@ -1301,23 +1585,23 @@ static void copy_tensor_async_ints( } const uint32_t row = it->second; - GGML_ASSERT(row < sampled_size); + GGML_ASSERT(row < sampled.size); GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row])); + ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row])); } } static void copy_tensor_async_floats( const std::map<llama_seq_id, ggml_tensor*> & tensor_map, - float * dst, + const buffer_view<float> & dst, size_t stride, std::vector<uint32_t> & counts, const std::map<llama_seq_id, uint32_t> & seq_to_row, ggml_backend_sched_t sched) { - if (dst == nullptr) { + if (!dst.has_data()) { return; } @@ -1333,7 +1617,7 @@ static void copy_tensor_async_floats( GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - float * row_ptr = dst + (size_t) row * stride; + float * row_ptr = dst.data + (size_t) row * stride; ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); // Update the actual number of logits/probabilities that were written for this row. @@ -1343,12 +1627,12 @@ static void copy_tensor_async_floats( static void copy_tensor_async_candidates( const std::map<llama_seq_id, ggml_tensor*> & tensor_map, - llama_token * dst, + const buffer_view<llama_token> & dst, size_t stride, std::vector<uint32_t> & counts, const std::map<llama_seq_id, uint32_t> & seq_to_row, ggml_backend_sched_t sched) { - if (dst == nullptr) { + if (!dst.has_data()) { return; } @@ -1364,7 +1648,7 @@ static void copy_tensor_async_candidates( GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - llama_token * row_ptr = dst + (size_t) row * stride; + llama_token * row_ptr = dst.data + (size_t) row * stride; ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); // Update the actual number of candidates that were written. @@ -1372,8 +1656,27 @@ static void copy_tensor_async_candidates( } } +static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_seq_id, llama_sampler *> & samplers) { + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (!ubatch.output[i]) { + continue; + } + + // Check if the output token has at least one sequence without a backend sampler. + for (int32_t j = 0; j < ubatch.n_seq_id[i]; ++j) { + llama_seq_id seq_id = ubatch.seq_id[i][j]; + if (samplers.find(seq_id) == samplers.end()) { + return true; + } + } + } + return false; // all sequences use backend sampling +} + int llama_context::decode(const llama_batch & batch_inp) { - GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT + // MTP hook batches carry both token (next-token id) and embd (h_nextn row), + // so accept either present rather than requiring exactly one. + GGML_ASSERT(batch_inp.token || batch_inp.embd); if (!memory) { LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__); @@ -1451,6 +1754,8 @@ int llama_context::decode(const llama_batch & batch_inp) { embd_seq.clear(); output_swaps.clear(); + sched_reserve(); + bool did_optimize = false; // handle any pending shifts/copies @@ -1502,12 +1807,13 @@ int llama_context::decode(const llama_batch & batch_inp) { } // reserve output buffer - if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { + if (output_reserve(n_outputs_all) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); return -2; }; int64_t n_outputs_prev = 0; + int64_t n_tokens_prev = 0; do { const auto & ubatch = mctx->get_ubatch(); @@ -1529,7 +1835,8 @@ int llama_context::decode(const llama_batch & batch_inp) { } ggml_status status; - const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + + const auto * res = process_ubatch(ubatch, ctx_type_to_graph_type(cparams.ctx_type), mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the memory module @@ -1567,33 +1874,31 @@ int llama_context::decode(const llama_batch & batch_inp) { // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} - auto * t_logits = res->get_logits(); - auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_logits = res->get_logits(); + auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr; + auto * t_h_nextn = cparams.embeddings_nextn ? res->get_h_nextn() : nullptr; if (t_embd && res->get_embd_pooled()) { t_embd = res->get_embd_pooled(); } // extract logits - // For multi-sequence batches that mix backend samplers and CPU sampler - // this is currently inefficient as we copy all logits even for the - // backend sampled tokens. - if (logits && t_logits && n_outputs > 0) { + if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - float * logits_out = logits + n_outputs_prev*n_vocab; + float * logits_out = logits.data + n_outputs_prev*n_vocab; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size); ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); } } // extract embeddings - if (embd && t_embd && n_outputs > 0) { + if (embd.data && t_embd && n_outputs > 0) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1601,13 +1906,13 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); - const uint32_t n_embd_out = hparams.get_n_embd_out(); - float * embd_out = embd + n_outputs_prev*n_embd_out; + GGML_ASSERT(embd.data != nullptr); + const uint32_t n_embd_out = hparams.n_embd_out(); + float * embd_out = embd.data + n_outputs_prev*n_embd_out; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size); ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); } } break; @@ -1618,12 +1923,16 @@ int llama_context::decode(const llama_batch & batch_inp) { // extract sequence embeddings (cleared before processing each batch) auto & embd_seq_out = embd_seq; + // use n_embd_out (not n_embd_inp) - the pooled embedding has the model's + // output dimension, which differs from input dimension for deepstack models (e.g. qwen3vl) + const uint32_t n_embd_out = hparams.n_embd_out(); + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { const llama_seq_id seq_id = ubatch.seq_id_unq[s]; const int32_t seq_idx = ubatch.seq_idx[seq_id]; - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + embd_seq_out[seq_id].resize(n_embd_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd_out*seq_idx)*sizeof(float), n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_RANK: @@ -1648,16 +1957,34 @@ int llama_context::decode(const llama_batch & batch_inp) { } } - // This flag indicates whether a backend sampler has actually sampled a specific - // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings. - const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty(); + extract_layer_inputs(res, n_tokens_prev, ubatch.n_tokens); + + // extract nextn embeddings before + // only meaningful in LLAMA_POOLING_TYPE_NONE (per-token); other pooling modes are ignored. + { + const bool masked = cparams.embeddings_nextn_masked; + const int64_t n_rows = masked ? n_outputs : (int64_t) ubatch.n_tokens; + const int64_t offset = masked ? n_outputs_prev : n_tokens_prev; + + if (embd_nextn.data && t_h_nextn && n_rows > 0 && cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { + ggml_backend_t backend_h = ggml_backend_sched_get_tensor_backend(sched.get(), t_h_nextn); + GGML_ASSERT(backend_h != nullptr); + + const uint32_t n_embd = hparams.n_embd_out(); + float * embd_nextn_out = embd_nextn.data + offset*n_embd; + + GGML_ASSERT((offset + n_rows)*n_embd <= (int64_t) embd_nextn.size); + ggml_backend_tensor_get_async(backend_h, t_h_nextn, embd_nextn_out, 0, n_rows*n_embd*sizeof(float)); + } + } - if (has_samplers && has_sampled) { + // Copy backend sampling output if this ubatch produced any sampling tensors. + if (has_samplers && (!res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty())) { const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); const auto stride = n_vocab; // async copy the sampling data from the backend to the host - copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get()); + copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get()); copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); @@ -1665,6 +1992,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } n_outputs_prev += n_outputs; + n_tokens_prev += ubatch.n_tokens; } while (mctx->next()); // set to total number of outputs in the batch, for use in llama_get_logits_ith @@ -1727,7 +2055,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // output // -uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) { +uint32_t llama_context::output_reserve(int32_t n_outputs) { const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1735,10 +2063,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba const auto n_batch = cparams.n_batch; const auto n_vocab = vocab.n_tokens(); - const auto n_embd_out = hparams.get_n_embd_out(); + const auto n_embd = hparams.n_embd; + const auto n_embd_out = hparams.n_embd_out(); - bool has_logits = true; - bool has_embd = cparams.embeddings; + bool has_logits = true; + bool has_embd = cparams.embeddings; + bool has_embd_nextn = cparams.embeddings_nextn; // TODO: hacky enc-dec support if (model.arch == LLM_ARCH_T5) { @@ -1746,52 +2076,31 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba has_embd = true; } - // Check which sampling modes are needed for the current batch. - // TODO: avoid this branching by working with the worst-case - bool has_sampling = false; - bool cpu_logits = false; - - if (batch.logits) { - for (int32_t i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { - continue; - } - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - llama_seq_id seq_id = batch.seq_id[i][j]; - if (sampling.samplers.find(seq_id) != sampling.samplers.end()) { - has_sampling = true; - } else { - cpu_logits = true; - } - } - } - } else { - // When batch.logits is nullptr (when loading state with a dummy batch), - // allocate CPU logits. - cpu_logits = true; - } - size_t backend_float_count = 0; size_t backend_token_count = 0; + size_t embd_layer_inp_float_count = 0; - // Allocate CPU logits buffer only if needed by sequences in this batch - logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; + embd_nextn.size = has_embd_nextn ? n_embd_out*n_outputs_max : 0; - // TODO: avoid this branching by working with the worst-case - if (!has_sampling) { - sampling.logits_size = 0; - sampling.probs_size = 0; - sampling.sampled_size = 0; - sampling.candidates_size = 0; - } else { - sampling.logits_size = n_vocab*n_outputs_max; - sampling.probs_size = n_vocab*n_outputs_max; - sampling.sampled_size = n_outputs_max; - sampling.candidates_size = n_vocab*n_outputs_max; + if (has_embd_nextn && !cparams.embeddings_nextn_masked) { + // unmasked: nextn row exists for every token in the batch, not just + // those flagged via batch.logits[i] -> size by token count instead. + embd_nextn.size = (size_t) n_embd_out * n_batch; + } - backend_float_count = sampling.logits_size + sampling.probs_size; - backend_token_count = sampling.sampled_size + sampling.candidates_size; + for (bool enabled : cparams.embeddings_layer_inp) { + if (enabled) { + embd_layer_inp_float_count += (size_t) n_embd * n_batch; + } + } + + // Allocate backend sampling output buffers if there are backend samplers configured. + const bool has_sampling = !sampling.samplers.empty(); + if (has_sampling) { + backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs + backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates } if (output_ids.empty()) { @@ -1801,8 +2110,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits_size + embd_size + backend_float_count) * sizeof(float) + - ( backend_token_count) * sizeof(llama_token); + (logits.size + embd.size + embd_nextn.size + embd_layer_inp_float_count + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1816,8 +2125,12 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba // TODO: not needed? buf_output = nullptr; - logits = nullptr; - embd = nullptr; + logits.data = nullptr; + embd.data = nullptr; + embd_nextn.data = nullptr; + for (auto & layer_inp : embd_layer_inp) { + layer_inp = {nullptr, 0}; + } } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1832,39 +2145,44 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba LLAMA_LOG_ERROR("%s: failed to allocate output buffer of size %.2f MiB\n", __func__, new_size / (1024.0 * 1024.0)); return 0; } + ggml_backend_buffer_clear(buf_output.get(), 0); } float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - logits = nullptr; - embd = nullptr; - size_t offset = 0; uint8_t * base = (uint8_t *) output_base; - logits = (has_logits && cpu_logits) ? output_base : nullptr; - offset += logits_size * sizeof(float); + logits = has_logits ? buffer_view<float>{output_base, logits.size} : buffer_view<float>{nullptr, 0}; + offset += logits.size * sizeof(float); - embd = has_embd ? (float *) (base + offset) : nullptr; - offset += embd_size * sizeof(float); + embd = has_embd ? buffer_view<float>{(float *) (base + offset), embd.size} : buffer_view<float>{nullptr, 0}; + offset += embd.size * sizeof(float); - sampling.logits = nullptr; - sampling.probs = nullptr; - sampling.sampled = nullptr; - sampling.candidates = nullptr; + embd_nextn = has_embd_nextn ? buffer_view<float>{(float *) (base + offset), embd_nextn.size} : buffer_view<float>{nullptr, 0}; + offset += embd_nextn.size * sizeof(float); + + for (uint32_t il = 0; il < embd_layer_inp.size(); ++il) { + if (cparams.embeddings_layer_inp[il]) { + embd_layer_inp[il] = buffer_view<float>{(float *) (base + offset), (size_t) n_embd * n_batch}; + offset += embd_layer_inp[il].size * sizeof(float); + } else { + embd_layer_inp[il] = buffer_view<float>{nullptr, 0}; + } + } if (has_sampling) { - sampling.logits = (float *) (base + offset); - offset += sampling.logits_size * sizeof(float); + sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.logits.size * sizeof(float); - sampling.probs = (float *) (base + offset); - offset += sampling.probs_size * sizeof(float); + sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.probs.size * sizeof(float); - sampling.sampled = (llama_token *) (base + offset); - offset += sampling.sampled_size * sizeof(llama_token); + sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max}; + offset += sampling.sampled.size * sizeof(llama_token); - sampling.candidates = (llama_token *) (base + offset); - offset += sampling.candidates_size * sizeof(llama_token); + sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.candidates.size * sizeof(llama_token); // The count vectors keep track of the actual number of logits/probs/candidates // copied from the backend for each output row. @@ -1877,7 +2195,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); - std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); + std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL); + } else { + sampling.logits = {nullptr, 0}; + sampling.probs = {nullptr, 0}; + sampling.sampled = {nullptr, 0}; + sampling.candidates = {nullptr, 0}; + + sampling.logits_count.clear(); + sampling.probs_count.clear(); + sampling.candidates_count.clear(); } // set all ids as invalid (negative) @@ -1885,9 +2212,39 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba this->n_outputs = 0; + GGML_ASSERT(n_outputs_max <= cparams.n_outputs_max); + return n_outputs_max; } +void llama_context::extract_layer_inputs(const llm_graph_result * res, size_t token_offset, size_t n_tokens) { + for (uint32_t il = 0; il < cparams.embeddings_layer_inp.size(); ++il) { + if (!cparams.embeddings_layer_inp[il]) { + continue; + } + if (!embd_layer_inp[il].has_data()) { + GGML_ABORT("output layer input buffer not allocated"); + } + ggml_tensor * t = res->get_layer_inp((int) il); + if (!t) { + GGML_ABORT("layer input tensor not found"); + } + + const size_t nbytes = ggml_nbytes(t); + const size_t nfloats = nbytes / sizeof(float); + GGML_ASSERT(n_tokens > 0); + GGML_ASSERT(nfloats % n_tokens == 0); + + const size_t row_floats = nfloats / n_tokens; + const size_t dst_offset = token_offset * row_floats; + GGML_ASSERT(dst_offset + nfloats <= embd_layer_inp[il].size); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched.get(), t); + GGML_ASSERT(backend != nullptr); + ggml_backend_tensor_get_async(backend, t, embd_layer_inp[il].data + dst_offset, 0, nbytes); + } +} + void llama_context::output_reorder() { const uint64_t n_vocab = model.vocab.n_tokens(); const uint64_t n_embd = model.hparams.n_embd; @@ -1896,49 +2253,58 @@ void llama_context::output_reorder() { const uint64_t i0 = output_swaps[s].i0; const uint64_t i1 = output_swaps[s].i1; - if (logits_size > 0) { + if (logits.size > 0) { for (uint64_t k = 0; k < n_vocab; k++) { - std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); + std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]); } } - if (embd_size > 0) { + if (embd.size > 0) { for (uint64_t k = 0; k < n_embd; k++) { - std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); + std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]); } } - if (sampling.logits && sampling.logits_size > 0) { - for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); + if (embd_nextn.size > 0) { + for (uint64_t k = 0; k < n_embd; k++) { + std::swap(embd_nextn.data[i0*n_embd + k], embd_nextn.data[i1*n_embd + k]); } } - if (sampling.probs && sampling.probs_size > 0) { - for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]); + if (embd_layer_inp.size() > 0) { + for (int lid = 0; lid < (int) embd_layer_inp.size(); ++lid) { + if (embd_layer_inp[lid].size > 0) { + for (uint64_t k = 0; k < n_embd; ++k) { + std::swap(embd_layer_inp[lid].data[i0*n_embd + k], embd_layer_inp[lid].data[i1*n_embd + k]); + } + } } } - if (sampling.candidates && sampling.candidates_size > 0) { + if (!sampling.samplers.empty()) { + assert(sampling.logits.size > 0); + assert(sampling.probs.size > 0); + assert(sampling.candidates.size > 0); + assert(sampling.sampled.size > 0); + assert(sampling.logits_count.size() > 0); + assert(sampling.probs_count.size() > 0); + assert(sampling.candidates_count.size() > 0); + for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]); + std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]); } - } - - if (sampling.sampled && sampling.sampled_size > 0) { - std::swap(sampling.sampled[i0], sampling.sampled[i1]); - } - if (!sampling.logits_count.empty()) { - std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); - } + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]); + } - if (!sampling.probs_count.empty()) { - std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); - } + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]); + } - if (!sampling.candidates_count.empty()) { + std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); } } @@ -1951,11 +2317,13 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) { return std::max<uint32_t>(n_tokens * 40, 32u * model.n_tensors()); } uint32_t res = std::max<uint32_t>(1024u, 8u*model.n_tensors()); - res += model.n_lora_nodes; + for (const auto & lora : model.loras) { + res += lora->get_n_nodes(); + } return res; } @@ -1970,14 +2338,12 @@ ggml_cgraph * llama_context::graph_reserve( if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs - n_outputs = std::max(n_outputs, n_tokens); - LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); } ggml_backend_sched_reset(sched.get()); - // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that + // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that gf_res_prev->reset(); // store the n_outputs as it is, and restore it afterwards @@ -2000,7 +2366,7 @@ ggml_cgraph * llama_context::graph_reserve( auto * res = gf_res_reserve.get(); - const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx, ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -2037,8 +2403,8 @@ llm_graph_params llama_context::graph_params( /*.gtype =*/ gtype, /*.sched =*/ sched.get(), /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/ &loras, + /*.cvec =*/ cvec.get(), + /*.loras =*/ loras.get(), /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.samplers =*/ sampling.samplers, @@ -2085,16 +2451,9 @@ llm_graph_cb llama_context::graph_get_cb() const { ggml_set_name(cur, name); } - if (!cparams.offload_kqv) { - if (strcmp(name, "kqv_merged_cont") == 0) { - // all nodes between the KV store and the attention output are run on the CPU - ggml_backend_sched_set_tensor_backend(sched.get(), cur, backend_cpu); - } - } - // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends // FIXME: fix in ggml_backend_sched - const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer; + const bool full_offload = model.n_gpu_layers() > model.hparams.n_layer_all; if (ubatch.n_tokens < 32 || full_offload) { if (il != -1 && strcmp(name, "norm") == 0) { const auto & dev_layer = model.dev_layer(il); @@ -2116,28 +2475,281 @@ llm_graph_cb llama_context::graph_get_cb() const { class llama_io_write_dummy : public llama_io_write_i { public: - llama_io_write_dummy() = default; + llama_io_write_dummy(bool skip_tensors) : skip_tensors(skip_tensors) {} void write(const void * /* src */, size_t size) override { size_written += size; } - void write_tensor(const ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + void write_tensor(ggml_tensor * /* tensor */, size_t /* offset */, size_t size) override { + if (skip_tensors) { + return; + } + + size_written += size; + } + + size_t n_bytes() override { + return size_written; + } + +private: + const bool skip_tensors; + + size_t size_written = 0; +}; + +class llama_io_write_host : public llama_io_write_i { +public: + llama_io_write_host( + uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + + ~llama_io_write_host() { + // TODO: add backend support to batch tensor_get? or some other way to speed this up + for (const auto & winfo : winfos) { + ggml_backend_tensor_get(winfo.tensor, winfo.ptr, winfo.offset, winfo.size); + } + } + + void write(const void * src, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(ptr, src, size); + ptr += size; + size_written += size; + buf_size -= size; + } + + void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + + // save the write for later during destruction + winfos.push_back({tensor, ptr, size, offset}); + + ptr += size; + size_written += size; + buf_size -= size; + } + + size_t n_bytes() override { + return size_written; + } + +private: + uint8_t * ptr; + size_t buf_size = 0; + size_t size_written = 0; + + struct write_info { + ggml_tensor * tensor; + uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector<write_info> winfos; +}; + +class llama_io_read_host : public llama_io_read_i { +public: + llama_io_read_host(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + + ~llama_io_read_host() { + // flush the reads + for (const auto & rinfo : rinfos) { + ggml_backend_tensor_set(rinfo.tensor, rinfo.ptr, rinfo.offset, rinfo.size); + } + } + + void read(void * dst, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(dst, ptr, size); + ptr += size; + size_read += size; + buf_size -= size; + } + + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + + // save for later during destruction + rinfos.push_back({tensor, ptr, size, offset}); + + ptr += size; + size_read += size; + buf_size -= size; + } + + size_t n_bytes() override { + return size_read; + } + +private: + const uint8_t * ptr; + size_t buf_size = 0; + size_t size_read = 0; + + struct read_info { + ggml_tensor * tensor; + const uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector<read_info> rinfos; +}; + +class llama_io_write_file : public llama_io_write_i { +public: + llama_io_write_file(llama_file * f) : file(f) {} + + void write(const void * src, size_t size) override { + file->write_raw(src, size); size_written += size; } + void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + temp_buffer.resize(size); + ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); + write(temp_buffer.data(), temp_buffer.size()); + } + size_t n_bytes() override { return size_written; } -private: - size_t size_written = 0; -}; +private: + llama_file * file; + size_t size_written = 0; + std::vector<uint8_t> temp_buffer; +}; + +class llama_io_read_file : public llama_io_read_i { +public: + llama_io_read_file(llama_file * f) : file(f) {} + + void read(void * dst, size_t size) override { + file->read_raw(dst, size); + size_read += size; + } + + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + temp_buffer.resize(size); + read(temp_buffer.data(), size); + ggml_backend_tensor_set(tensor, temp_buffer.data(), offset, size); + } + + size_t n_bytes() override { + return size_read; + } + +private: + llama_file * file; + size_t size_read = 0; + std::vector<uint8_t> temp_buffer; +}; + +class llama_io_write_device : public llama_io_write_i { +public: + llama_io_write_device(uint8_t * p, size_t len, llama_memory_buffers & mbufs) : ptr(p), buf_size(len), mbufs(mbufs) { + } + + ~llama_io_write_device() { + llama_memory_buffers mbufs_new; + + for (const auto & winfo : winfos) { + auto * buft = ggml_backend_buffer_get_type(winfo.tensor->buffer); + + mbufs_new[buft].n_tensors++; + mbufs_new[buft].total_size += winfo.size; + } + + for (auto & [buft, mbuf] : mbufs_new) { + ggml_init_params params = { + /*.mem_size =*/ 2*mbuf.n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + mbuf.ctx.reset(ggml_init(params)); + + mbuf.org.reserve(mbuf.n_tensors); + mbuf.cpy.reserve(mbuf.n_tensors); + } + + for (const auto & winfo : winfos) { + auto * buft = ggml_backend_buffer_get_type(winfo.tensor->buffer); + + const int64_t n = winfo.size/ggml_element_size(winfo.tensor); + + auto & mbuf = mbufs_new[buft]; + + mbuf.org.push_back(ggml_view_1d (mbuf.ctx.get(), winfo.tensor, n, winfo.offset)); + mbuf.cpy.push_back(ggml_new_tensor_1d(mbuf.ctx.get(), winfo.tensor->type, n)); + } + + for (auto & [buft, mbuf] : mbufs_new) { + auto & mbuf_cur = mbufs[buft]; + + bool need_alloc = false; + + need_alloc = need_alloc || (!mbuf_cur.buf); + need_alloc = need_alloc || (mbuf_cur.org.size() != mbuf.org.size()); + need_alloc = need_alloc || (mbuf_cur.total_size != mbuf.total_size); + + if (!need_alloc) { + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + auto * org0 = mbuf_cur.org[i]; + auto * org1 = mbuf.org[i]; + + if (!ggml_are_same_shape(org0, org1)) { + need_alloc = true; + break; + } + + if (org0->view_src != org1->view_src || org0->view_offs != org1->view_offs) { + need_alloc = true; + break; + } + } + } + + if (need_alloc) { + if (!mbuf_cur.buf || mbuf_cur.total_size != mbuf.total_size) { + mbuf_cur = std::move(mbuf); + + mbuf_cur.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(mbuf_cur.ctx.get(), buft)); + + LLAMA_LOG_INFO("%s: allocated '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + } else { + //LLAMA_LOG_INFO("%s: reallocating tensors in '%s' buffer %.3f MiB\n", __func__, ggml_backend_buft_name(buft), mbuf.total_size/1024.0/1024.0); + + // save the old buffer and allocate the new tensors in it + auto buf = std::move(mbuf_cur.buf); + + mbuf_cur = std::move(mbuf); + + ggml_tallocr talloc = ggml_tallocr_new(buf.get()); + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_view_init(mbuf_cur.org[i]); + ggml_tallocr_alloc(&talloc, mbuf_cur.cpy[i]); + } + + mbuf_cur.buf = std::move(buf); + } + } -class llama_io_write_buffer : public llama_io_write_i { -public: - llama_io_write_buffer( - uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_tensor_copy(mbuf_cur.org[i], mbuf_cur.cpy[i]); + } + } + } void write(const void * src, size_t size) override { if (size > buf_size) { @@ -2149,14 +2761,9 @@ class llama_io_write_buffer : public llama_io_write_i { buf_size -= size; } - void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { - if (size > buf_size) { - throw std::runtime_error("unexpectedly reached end of buffer"); - } - ggml_backend_tensor_get(tensor, ptr, offset, size); - ptr += size; - size_written += size; - buf_size -= size; + void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + // save the write for later during destruction + winfos.push_back({tensor, ptr, size, offset}); } size_t n_bytes() override { @@ -2167,75 +2774,85 @@ class llama_io_write_buffer : public llama_io_write_i { uint8_t * ptr; size_t buf_size = 0; size_t size_written = 0; + + struct write_info { + ggml_tensor * tensor; + uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector<write_info> winfos; + + llama_memory_buffers & mbufs; }; -class llama_io_read_buffer : public llama_io_read_i { +class llama_io_read_device : public llama_io_read_i { public: - llama_io_read_buffer(const uint8_t * p, size_t len) : ptr(p), buf_size(len) {} + llama_io_read_device(const uint8_t * p, size_t len, const llama_memory_buffers & mbufs) : ptr(p), buf_size(len), mbufs(mbufs) { + } - const uint8_t * read(size_t size) override { - const uint8_t * base_ptr = ptr; - if (size > buf_size) { - throw std::runtime_error("unexpectedly reached end of buffer"); + ~llama_io_read_device() { + llama_memory_buffers mbufs_new; + + for (const auto & rinfo : rinfos) { + auto * buft = ggml_backend_buffer_get_type(rinfo.tensor->buffer); + + mbufs_new[buft].n_tensors++; + mbufs_new[buft].total_size += rinfo.size; } - ptr += size; - size_read += size; - buf_size -= size; - return base_ptr; - } - void read_to(void * dst, size_t size) override { - memcpy(dst, read(size), size); - } + for (auto & [buft, mbuf] : mbufs_new) { + ggml_init_params params = { + /*.mem_size =*/ mbuf.n_tensors*ggml_tensor_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; - size_t n_bytes() override { - return size_read; - } + mbuf.ctx.reset(ggml_init(params)); -private: - const uint8_t * ptr; - size_t buf_size = 0; - size_t size_read = 0; -}; + mbuf.org.reserve(mbuf.n_tensors); + } -class llama_io_write_file : public llama_io_write_i { -public: - llama_io_write_file(llama_file * f) : file(f) {} + for (const auto & rinfo : rinfos) { + auto * buft = ggml_backend_buffer_get_type(rinfo.tensor->buffer); - void write(const void * src, size_t size) override { - file->write_raw(src, size); - size_written += size; - } + const int64_t n = rinfo.size/ggml_element_size(rinfo.tensor); - void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) override { - temp_buffer.resize(size); - ggml_backend_tensor_get(tensor, temp_buffer.data(), offset, size); - write(temp_buffer.data(), temp_buffer.size()); - } + auto & mbuf = mbufs_new[buft]; - size_t n_bytes() override { - return size_written; - } + mbuf.org.push_back(ggml_view_1d(mbuf.ctx.get(), rinfo.tensor, n, rinfo.offset)); -private: - llama_file * file; - size_t size_written = 0; - std::vector<uint8_t> temp_buffer; -}; + ggml_backend_view_init(mbuf.org.back()); + } -class llama_io_read_file : public llama_io_read_i { -public: - llama_io_read_file(llama_file * f) : file(f) {} + for (auto & [buft, mbuf] : mbufs_new) { + const auto & mbuf_cur = mbufs.at(buft); - void read_to(void * dst, size_t size) override { - file->read_raw(dst, size); + if (!mbuf_cur.buf || mbuf_cur.n_tensors != mbuf.n_tensors || mbuf_cur.total_size != mbuf.total_size) { + GGML_ABORT("%s: memory buffer mismatch\n", __func__); + } + + for (size_t i = 0; i < mbuf_cur.org.size(); ++i) { + ggml_backend_tensor_copy(mbuf_cur.cpy[i], mbuf.org[i]); + } + } + + GGML_ASSERT(buf_size == 0); + } + + void read(void * dst, size_t size) override { + if (size > buf_size) { + throw std::runtime_error("unexpectedly reached end of buffer"); + } + memcpy(dst, ptr, size); + ptr += size; size_read += size; + buf_size -= size; } - const uint8_t * read(size_t size) override { - temp_buffer.resize(size); - read_to(temp_buffer.data(), size); - return temp_buffer.data(); + void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) override { + // save for later during destruction + rinfos.push_back({tensor, ptr, size, offset}); } size_t n_bytes() override { @@ -2243,13 +2860,23 @@ class llama_io_read_file : public llama_io_read_i { } private: - llama_file * file; + const uint8_t * ptr; + size_t buf_size = 0; size_t size_read = 0; - std::vector<uint8_t> temp_buffer; + + struct read_info { + ggml_tensor * tensor; + const uint8_t * ptr; + size_t size; + size_t offset; + }; + std::vector<read_info> rinfos; + + const llama_memory_buffers & mbufs; }; size_t llama_context::state_get_size() { - llama_io_write_dummy io; + llama_io_write_dummy io(false); try { return state_write_data(io); } catch (const std::exception & err) { @@ -2259,7 +2886,7 @@ size_t llama_context::state_get_size() { } size_t llama_context::state_get_data(uint8_t * dst, size_t size) { - llama_io_write_buffer io(dst, size); + llama_io_write_host io(dst, size); try { return state_write_data(io); } catch (const std::exception & err) { @@ -2269,7 +2896,7 @@ size_t llama_context::state_get_data(uint8_t * dst, size_t size) { } size_t llama_context::state_set_data(const uint8_t * src, size_t size) { - llama_io_read_buffer io(src, size); + llama_io_read_host io(src, size); try { return state_read_data(io); } catch (const std::exception & err) { @@ -2278,9 +2905,14 @@ size_t llama_context::state_set_data(const uint8_t * src, size_t size) { } } +static constexpr uint32_t io_magic = 0xaf143cd8; + size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags) { - llama_io_write_dummy io; + llama_io_write_dummy io(flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE); try { + io.write(&io_magic, sizeof(io_magic)); + io.write(&seq_id, sizeof(seq_id)); + return state_seq_write_data(io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what()); @@ -2289,9 +2921,18 @@ size_t llama_context::state_seq_get_size(llama_seq_id seq_id, llama_state_seq_fl } size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags) { - llama_io_write_buffer io(dst, size); + std::unique_ptr<llama_io_write_i> io; + if (flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) { + io = std::make_unique<llama_io_write_device>(dst, size, mem_storage[seq_id]); + } else { + io = std::make_unique<llama_io_write_host>(dst, size); + } + try { - return state_seq_write_data(io, seq_id, flags); + io->write(&io_magic, sizeof(io_magic)); + io->write(&seq_id, sizeof(seq_id)); + + return state_seq_write_data(*io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what()); return 0; @@ -2299,9 +2940,38 @@ size_t llama_context::state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, siz } size_t llama_context::state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags) { - llama_io_read_buffer io(src, size); + std::unique_ptr<llama_io_read_i> io; + if (flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) { + // create a temporary io to read the magic and the src seq_id + io = std::make_unique<llama_io_read_host>(src, size); + + uint32_t magic_read; + io->read(&magic_read, sizeof(magic_read)); + if (io_magic != magic_read) { + throw std::runtime_error("wrong sequence state magic"); + } + + llama_seq_id seq_id_read; + io->read(&seq_id_read, sizeof(seq_id_read)); + + GGML_ASSERT(mem_storage.find(seq_id_read) != mem_storage.end()); + + io = std::make_unique<llama_io_read_device>(src, size, mem_storage[seq_id_read]); + } else { + io = std::make_unique<llama_io_read_host>(src, size); + } + try { - return state_seq_read_data(io, seq_id, flags); + uint32_t magic_read; + io->read(&magic_read, sizeof(magic_read)); + if (io_magic != magic_read) { + throw std::runtime_error("wrong sequence state magic"); + } + + llama_seq_id seq_id_read; + io->read(&seq_id_read, sizeof(seq_id_read)); + + return state_seq_read_data(*io, seq_id, flags); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what()); return 0; @@ -2443,63 +3113,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { // TODO: add more model-specific info which should prevent loading the session file if not identical } - // write output ids - { - LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); - - const auto n_outputs = this->n_outputs; - const auto & output_ids = this->output_ids; - - std::vector<int32_t> w_output_pos; - - w_output_pos.resize(n_outputs); - - // build a more compact representation of the output ids - for (size_t i = 0; i < n_batch(); ++i) { - // map an output id to a position in the batch - int64_t pos = output_ids[i]; - if (pos >= 0) { - GGML_ASSERT(pos < n_outputs); - w_output_pos[pos] = i; - } - } - - io.write(&n_outputs, sizeof(n_outputs)); - - if (n_outputs) { - io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); - } - } - - // write logits - { - LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); - - const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens()); - - io.write(&logits_size, sizeof(logits_size)); - - if (logits_size) { - io.write(logits, logits_size * sizeof(float)); - } - } - - // write embeddings - { - LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); - - const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); - - io.write(&embd_size, sizeof(embd_size)); - - if (embd_size) { - io.write(embd, embd_size * sizeof(float)); - } - } - - // TODO: handle sampling buffers and samplers state ? - // https://github.com/ggml-org/llama.cpp/pull/17004 - if (memory != nullptr) { LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); memory->state_write(io); @@ -2525,73 +3138,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { // TODO: add more info which needs to be identical but which is not verified otherwise } - // read output ids - { - LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); - - auto n_outputs = this->n_outputs; - io.read_to(&n_outputs, sizeof(n_outputs)); - - // Create a dummy batch for state loading. - llama_batch dummy_batch = {}; - dummy_batch.n_tokens = 0; - if (n_outputs > output_reserve(n_outputs, dummy_batch)) { - throw std::runtime_error("could not reserve outputs"); - } - - std::vector<int32_t> output_pos; - - if (n_outputs) { - output_pos.resize(n_outputs); - io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); - - for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { - int32_t id = output_pos[i]; - if ((uint32_t) id >= n_batch()) { - throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); - } - this->output_ids[id] = i; - } - - this->n_outputs = n_outputs; - } - } - - // read logits - { - LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); - - uint64_t logits_size; - io.read_to(&logits_size, sizeof(logits_size)); - - if (this->logits_size < logits_size) { - throw std::runtime_error("logits buffer too small"); - } - - if (logits_size) { - io.read_to(this->logits, logits_size * sizeof(float)); - } - } - - // read embeddings - { - LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); - - uint64_t embd_size; - io.read_to(&embd_size, sizeof(embd_size)); - - if (this->embd_size < embd_size) { - throw std::runtime_error("embeddings buffer too small"); - } - - if (embd_size) { - io.read_to(this->embd, embd_size * sizeof(float)); - } - } - - // TODO: handle sampling buffers and samplers state ? - // https://github.com/ggml-org/llama.cpp/pull/17004 - if (memory) { LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); @@ -2646,7 +3192,7 @@ void llama_context::perf_reset() { n_reused = 0; } -std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const { +llama_memory_breakdown llama_context::memory_breakdown() const { std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret; for (const auto & [buft, size] : model.memory_breakdown()) { ret[buft].model += size; @@ -2724,6 +3270,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params llama_set_param(model->cls_b, param_filter, param_filter_ud); llama_set_param(model->cls_out, param_filter, param_filter_ud); llama_set_param(model->cls_out_b, param_filter, param_filter_ud); + llama_set_param(model->cls_norm, param_filter, param_filter_ud); for (struct llama_layer & layer : model->layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { @@ -2780,7 +3327,7 @@ void llama_context::opt_epoch_iter( } // reserve output buffer - if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { + if (output_reserve(n_outputs_all) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); GGML_ABORT("TODO: handle this error"); }; @@ -2798,7 +3345,7 @@ void llama_context::opt_epoch_iter( auto * res = gf_res_prev.get(); - const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + const auto gparams = graph_params(res, ubatch, mctx.get(), ctx_type_to_graph_type(cparams.ctx_type)); res->reset(); @@ -2815,7 +3362,7 @@ void llama_context::opt_epoch_iter( }; ctx_compute_opt = ggml_init(params); } - ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits()); + ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_inp_tokens(), res->get_logits()); ggml_opt_alloc(opt_ctx, train); res->set_inputs(&ubatch); @@ -2899,8 +3446,11 @@ llama_context_params llama_context_default_params() { /*.n_batch =*/ 2048, /*.n_ubatch =*/ 512, /*.n_seq_max =*/ 1, + /*.n_rs_seq =*/ 0, + /*.n_outputs_max =*/ 0, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, + /*.ctx_type =*/ LLAMA_CONTEXT_TYPE_DEFAULT, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, @@ -2927,6 +3477,7 @@ llama_context_params llama_context_default_params() { /*.kv_unified =*/ false, /*.sampler =*/ nullptr, /*.n_sampler =*/ 0, + /*.ctx_other =*/ nullptr, }; return result; @@ -2955,21 +3506,36 @@ llama_context * llama_init_from_model( params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; } - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_k)) { - const uint32_t blck_size = ggml_blck_size(params.type_k); - if (model->hparams.n_embd_head_k % blck_size != 0) { - LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", - __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k); + if (model->split_mode() == LLAMA_SPLIT_MODE_TENSOR) { + if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) { + LLAMA_LOG_INFO("%s: enabling flash_attn since it is required for SPLIT_MODE_TENSOR\n", __func__); + params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; + } + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_ENABLED) { + LLAMA_LOG_ERROR("%s: SPLIT_MODE_TENSOR requires flash_attn to be enabled\n", __func__); return nullptr; } } - if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO && ggml_is_quantized(params.type_v)) { + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_k)) { + const uint32_t blck_size = ggml_blck_size(params.type_k); + for (uint32_t il = 0; il < model->hparams.n_layer(); ++il) { + if (model->hparams.n_embd_head_k(il) % blck_size != 0) { + LLAMA_LOG_ERROR("%s: K cache type %s with block size %u does not divide n_embd_head_k=%u\n", + __func__, ggml_type_name(params.type_k), blck_size, model->hparams.n_embd_head_k(il)); + return nullptr; + } + } + } + + if (params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED && ggml_is_quantized(params.type_v)) { const uint32_t blck_size = ggml_blck_size(params.type_v); - if (model->hparams.n_embd_head_v % blck_size != 0) { - LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_k=%u\n", - __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v); - return nullptr; + for (uint32_t il = 0; il < model->hparams.n_layer(); ++il) { + if (model->hparams.n_embd_head_v(il) % blck_size != 0) { + LLAMA_LOG_ERROR("%s: V cache type %s with block size %u does not divide n_embd_head_v=%u\n", + __func__, ggml_type_name(params.type_v), blck_size, model->hparams.n_embd_head_v(il)); + return nullptr; + } } } @@ -2985,6 +3551,12 @@ llama_context * llama_init_from_model( model->hparams.pooling_type, params.pooling_type); } + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + model->hparams.n_layer_nextn == 0) { + LLAMA_LOG_WARN("%s: context type MTP requested but model doesn't contain MTP layers\n", __func__); + return nullptr; + } + try { auto * ctx = new llama_context(*model, params); return ctx; @@ -3026,6 +3598,10 @@ uint32_t llama_n_seq_max(const llama_context * ctx) { return ctx->n_seq_max(); } +uint32_t llama_n_rs_seq(const llama_context * ctx) { + return ctx->get_cparams().n_rs_seq; +} + const llama_model * llama_get_model(const llama_context * ctx) { return &ctx->get_model(); } @@ -3115,6 +3691,40 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +void llama_set_embeddings_nextn(llama_context * ctx, bool value, bool masked) { + ctx->set_embeddings_nextn(value, masked); +} + +void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool value) { + ctx->set_embeddings_layer_inp(lid, value); +} + +llama_memory_t llama_get_memory(const struct llama_context * ctx) { + if (!ctx) { + return nullptr; + } + + return ctx->get_memory(); +} + +float * llama_get_embeddings_nextn(llama_context * ctx) { + ctx->synchronize(); + + return ctx->get_embeddings_nextn(); +} + +float * llama_get_embeddings_nextn_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_embeddings_nextn_ith(i); +} + +float * llama_get_embeddings_layer_inp(llama_context * ctx, uint32_t lid) { + ctx->synchronize(); + + return ctx->get_embeddings_layer_inp(lid); +} + bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { return ctx->set_sampler(seq_id, smpl); } @@ -3161,37 +3771,43 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { return static_cast<uint32_t>(ctx->get_sampled_probs_count(i)); } -// llama adapter API - -int32_t llama_set_adapter_lora( - llama_context * ctx, - llama_adapter_lora * adapter, - float scale) { - ctx->set_adapter_lora(adapter, scale); - - return 0; +struct ggml_cgraph * llama_graph_reserve( + struct llama_context * ctx, + uint32_t n_tokens, + uint32_t n_seqs, + uint32_t n_outputs) { + auto memory = ctx->get_memory(); + llama_memory_context_ptr mctx; + if (memory) { + mctx = memory->init_full(); + } + return ctx->graph_reserve(n_tokens, n_seqs, n_outputs, mctx.get()); } -int32_t llama_rm_adapter_lora( +// llama adapter API + +int32_t llama_set_adapters_lora( llama_context * ctx, - llama_adapter_lora * adapter) { - bool res = ctx->rm_adapter_lora(adapter); + llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales) { + if (adapters == nullptr || scales == nullptr) { + GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call"); + } - return res ? 0 : -1; -} + ctx->set_adapters_lora(adapters, n_adapters, scales); -void llama_clear_adapter_lora(llama_context * ctx) { - ctx->clear_adapter_lora(); + return 0; } -int32_t llama_apply_adapter_cvec( +int32_t llama_set_adapter_cvec( llama_context * ctx, - const float * data, - size_t len, - int32_t n_embd, - int32_t il_start, - int32_t il_end) { - bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end); return res ? 0 : -1; } @@ -3200,10 +3816,6 @@ int32_t llama_apply_adapter_cvec( // memory // -llama_memory_t llama_get_memory(const struct llama_context * ctx) { - return ctx->get_memory(); -} - void llama_memory_clear(llama_memory_t mem, bool data) { if (!mem) { return; @@ -3390,7 +4002,6 @@ size_t llama_state_seq_get_data_ext(llama_context * ctx, uint8_t * dst, size_t s return ctx->state_seq_get_data(seq_id, dst, size, flags); } - size_t llama_state_seq_set_data_ext(llama_context * ctx, const uint8_t * src, size_t size, llama_seq_id seq_id, llama_state_seq_flags flags) { ctx->synchronize(); @@ -3477,142 +4088,6 @@ void llama_perf_context_reset(llama_context * ctx) { ctx->perf_reset(); } -void llama_memory_breakdown_print(const struct llama_context * ctx) { - const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices; - - std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown(); - - std::vector<std::array<std::string, 9>> table_data; - table_data.reserve(devices.size()); - const std::string template_header = "%s: | %s | %s %s %s %s %s %s %s |\n"; - const std::string template_gpu = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n"; - const std::string template_other = "%s: | %s | %s %s %s = %s + %s + %s %s |\n"; - - table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"}); - - constexpr size_t MiB = 1024 * 1024; - const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "}; - - // track seen buffer types to avoid double counting: - std::set<ggml_backend_buffer_type_t> seen_buffer_types; - - // accumulative memory breakdown for each device and for host: - std::vector<llama_memory_breakdown_data> mb_dev(devices.size()); - llama_memory_breakdown_data mb_host; - - for (const auto & buft_mb : memory_breakdown) { - ggml_backend_buffer_type_t buft = buft_mb.first; - const llama_memory_breakdown_data & mb = buft_mb.second; - if (ggml_backend_buft_is_host(buft)) { - mb_host.model += mb.model; - mb_host.context += mb.context; - mb_host.compute += mb.compute; - seen_buffer_types.insert(buft); - continue; - } - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (dev) { - int i_dev = -1; - for (size_t i = 0; i < devices.size(); i++) { - if (devices[i] == dev) { - i_dev = i; - break; - } - } - if (i_dev != -1) { - mb_dev[i_dev].model += mb.model; - mb_dev[i_dev].context += mb.context; - mb_dev[i_dev].compute += mb.compute; - seen_buffer_types.insert(buft); - continue; - } - } - } - - // print memory breakdown for each device: - for (size_t i = 0; i < devices.size(); i++) { - ggml_backend_dev_t dev = devices[i]; - llama_memory_breakdown_data mb = mb_dev[i]; - - const std::string name = ggml_backend_dev_name(dev); - std::string desc = ggml_backend_dev_description(dev); - for (const std::string & prefix : desc_prefixes_strip) { - if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) { - desc = desc.substr(prefix.length()); - } - } - - size_t free, total; - ggml_backend_dev_memory(dev, &free, &total); - - const size_t self = mb.model + mb.context + mb.compute; - const size_t unaccounted = total - self - free; - - table_data.push_back({ - template_gpu, - " - " + name + " (" + desc + ")", - std::to_string(total / MiB), - std::to_string(free / MiB), - std::to_string(self / MiB), - std::to_string(mb.model / MiB), - std::to_string(mb.context / MiB), - std::to_string(mb.compute / MiB), - std::to_string(unaccounted / MiB)}); - } - - // print memory breakdown for host: - { - const size_t self = mb_host.model + mb_host.context + mb_host.compute; - table_data.push_back({ - template_other, - " - Host", - "", // total - "", // free - std::to_string(self / MiB), - std::to_string(mb_host.model / MiB), - std::to_string(mb_host.context / MiB), - std::to_string(mb_host.compute / MiB), - ""}); // unaccounted - } - - // print memory breakdown for all remaining buffer types: - for (const auto & buft_mb : memory_breakdown) { - ggml_backend_buffer_type_t buft = buft_mb.first; - const llama_memory_breakdown_data & mb = buft_mb.second; - if (seen_buffer_types.count(buft) == 1) { - continue; - } - const std::string name = ggml_backend_buft_name(buft); - const size_t self = mb.model + mb.context + mb.compute; - table_data.push_back({ - template_other, - " - " + name, - "", // total - "", // free - std::to_string(self / MiB), - std::to_string(mb.model / MiB), - std::to_string(mb.context / MiB), - std::to_string(mb.compute / MiB), - ""}); // unaccounted - seen_buffer_types.insert(buft); - } - - for (size_t j = 1; j < table_data[0].size(); j++) { - size_t max_len = 0; - for (const auto & td : table_data) { - max_len = std::max(max_len, td[j].length()); - } - for (auto & td : table_data) { - td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' '); - } - } - for (const auto & td : table_data) { - LLAMA_LOG_INFO(td[0].c_str(), - __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(), - td[6].c_str(), td[7].c_str(), td[8].c_str()); - } -} - // // training // @@ -3643,3 +4118,15 @@ void llama_opt_epoch( callback_train, callback_eval); } + +// +// ext +// + +llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx) { + return ctx->memory_breakdown(); +} + +llama_context * llama_get_ctx_other(struct llama_context * ctx) { + return ctx->get_cparams().ctx_other; +} diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index b29edf4db21..853052be2ca 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -1,9 +1,12 @@ #pragma once #include "llama.h" +#include "llama-ext.h" #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" +#include "llama-impl.h" +#include "llama-memory.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -21,17 +24,21 @@ class llama_io_write_i; struct llama_memory_i; struct llama_memory_context_i; -// "memory" as in physical memory for a buffer type, in bytes -struct llama_memory_breakdown_data { - size_t model = 0; // memory allocated for the model - size_t context = 0; // memory allocated for the context - size_t compute = 0; // memory allocated for temporary compute buffers +// stores copy of the memory in device buffer. used for fast state save/load +struct llama_memory_buffer { + int n_tensors = 0; + size_t total_size = 0; - size_t total() const { - return model + context + compute; - } + ggml_backend_buffer_ptr buf; + + ggml_context_ptr ctx; + + std::vector<ggml_tensor *> org; + std::vector<ggml_tensor *> cpy; }; +using llama_memory_buffers = std::map<ggml_backend_buffer_type_t, llama_memory_buffer>; + struct llama_context { // init scheduler and compute buffers, reserve worst-case graphs llama_context( @@ -40,6 +47,14 @@ struct llama_context { ~llama_context(); + // reserve a new backend scheduler (if needed) + // for example, when: + // - changing loras + // - changing samplers + // - changing attention type + // - etc. + void sched_reserve(); + void synchronize(); const llama_model & get_model() const; @@ -70,6 +85,11 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + float * get_embeddings_nextn(); + float * get_embeddings_nextn_ith(int32_t i); + + float * get_embeddings_layer_inp(uint32_t lid); + llama_token * get_sampled_tokens() const; llama_token get_sampled_token_ith(int32_t idx); @@ -93,19 +113,16 @@ struct llama_context { void set_abort_callback(bool (*abort_callback)(void * data), void * abort_callback_data); void set_embeddings (bool value); + void set_embeddings_nextn(bool value, bool masked); + void set_embeddings_layer_inp(uint32_t lid, bool enable); void set_causal_attn(bool value); void set_warmup(bool value); - void set_adapter_lora( - llama_adapter_lora * adapter, - float scale); - - bool rm_adapter_lora( - llama_adapter_lora * adapter); + void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - void clear_adapter_lora(); + bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - bool apply_adapter_cvec( + bool set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -134,6 +151,7 @@ struct llama_context { size_t state_set_data(const uint8_t * src, size_t size); size_t state_seq_get_size(llama_seq_id seq_id, llama_state_seq_flags flags); + size_t state_seq_get_data(llama_seq_id seq_id, uint8_t * dst, size_t size, llama_state_seq_flags flags); size_t state_seq_set_data(llama_seq_id seq_id, const uint8_t * src, size_t size, llama_state_seq_flags flags); @@ -168,7 +186,7 @@ struct llama_context { llama_perf_context_data perf_get_data() const; void perf_reset(); - std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown() const; + llama_memory_breakdown memory_breakdown() const; // // training @@ -204,13 +222,17 @@ struct llama_context { // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. - uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch); + uint32_t output_reserve(int32_t n_outputs); void output_reorder(); // map the output row index `i` to batch index int64_t output_resolve_row(int32_t i) const; + // async-copy enabled layer-input tensors (per cparams.output_layer_inp) + // from backend into host-side embd_layer_inp buffers + void extract_layer_inputs(const llm_graph_result * res, size_t token_offset, size_t n_tokens); + // // graph // @@ -252,43 +274,45 @@ struct llama_context { const llama_model & model; - llama_cparams cparams; - llama_adapter_cvec cvec; - llama_adapter_loras loras; + llama_cparams cparams; + + llama_adapter_cvec_ptr cvec; + llama_adapter_loras_ptr loras; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr<llama_memory_i> memory; + llama_memory_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) - size_t logits_size = 0; // capacity (of floats) for logits - float * logits = nullptr; + buffer_view<float> logits = {nullptr, 0}; // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE - size_t embd_size = 0; // capacity (of floats) for embeddings - float * embd = nullptr; - - // TODO: simplify - struct sampling_info { - std::map<llama_seq_id, llama_sampler *> samplers; + buffer_view<float> embd = {nullptr, 0}; - float * logits = nullptr; - size_t logits_size = 0; + // hidden state required by the nextn layers (2-dimensional array: [n_outputs][n_embd]) + // populated only when cparams.embeddings_nextn is enabled and the model graph + // sets llm_graph_result::t_h_nextn + buffer_view<float> embd_nextn = {nullptr, 0}; - llama_token * sampled = nullptr; - size_t sampled_size = 0; + // host buffers for output layer input embeddings, per layer + // populated when cparams.output_layer_inp[il] is true + std::vector<buffer_view<float>> embd_layer_inp; - float * probs = nullptr; - size_t probs_size = 0; + struct sampling_info { + // !samplers.empty() to check if any samplers are active + std::map<llama_seq_id, llama_sampler *> samplers; - llama_token * candidates = nullptr; - size_t candidates_size = 0; + buffer_view<float> logits = {nullptr, 0}; + buffer_view<llama_token> sampled = {nullptr, 0}; + buffer_view<float> probs = {nullptr, 0}; + buffer_view<llama_token> candidates = {nullptr, 0}; std::vector<uint32_t> logits_count; std::vector<uint32_t> probs_count; std::vector<uint32_t> candidates_count; + // optimization std::vector<llama_token> token_ids_full_vocab; }; @@ -314,6 +338,8 @@ struct llama_context { ggml_backend_sched_ptr sched; + bool sched_need_reserve = true; + ggml_backend_t backend_cpu = nullptr; std::vector<ggml_backend_ptr> backends; @@ -339,6 +365,9 @@ struct llama_context { // host buffer for the model output (logits and embeddings) ggml_backend_buffer_ptr buf_output; + // keep copies of the per-sequence memory on the device + std::map<llama_seq_id, llama_memory_buffers> mem_storage; + bool has_evaluated_once = false; // env: LLAMA_GRAPH_REUSE_DISABLE diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index fcef8fa9760..2b109f909c0 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include <cstdint> +#include <vector> #define LLAMA_MAX_SEQ 256 @@ -12,6 +13,8 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback + uint32_t n_outputs_max; // max outputs supported by the context int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing @@ -27,16 +30,28 @@ struct llama_cparams { float yarn_beta_slow; bool embeddings; + bool embeddings_nextn; // also extract the hidden state before the final output norm + bool embeddings_nextn_masked; // extract for only rows where batch.logits != 0 bool causal_attn; bool offload_kqv; bool flash_attn; + bool auto_fa; + bool fused_gdn_ar; // use fused gated delta net (autoregressive) + bool fused_gdn_ch; // use fused gated delta net (chunked) + bool auto_fgdn; bool no_perf; - bool warmup; + bool warmup; // TODO: remove [TAG_LLAMA_GRAPH_NO_WARMUP] bool op_offload; bool kv_unified; + bool pipeline_parallel; + std::vector<bool> embeddings_layer_inp; // [n_layer()] extract input embeddings for layer + + enum llama_context_type ctx_type; enum llama_pooling_type pooling_type; ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; + + llama_context * ctx_other; }; diff --git a/examples/talk-llama/llama-ext.h b/examples/talk-llama/llama-ext.h new file mode 100644 index 00000000000..b744af52864 --- /dev/null +++ b/examples/talk-llama/llama-ext.h @@ -0,0 +1,120 @@ +#pragma once + +// this is a staging header for new llama.cpp API +// breaking changes and C++ are allowed. everything here should be considered WIP + +#include "llama.h" + +#include <cstdint> +#include <map> + +// Reserve a new compute graph. It is valid until the next call to llama_graph_reserve. +LLAMA_API struct ggml_cgraph * llama_graph_reserve( + struct llama_context * ctx, + uint32_t n_tokens, + uint32_t n_seqs, + uint32_t n_outputs); + +// Get the default ggml_type for a given ftype. +LLAMA_API ggml_type llama_ftype_get_default_type(llama_ftype ftype); + +struct quantize_state_impl; + +LLAMA_API quantize_state_impl * llama_quant_init( + const llama_model * model, + const llama_model_quantize_params * params); + +LLAMA_API void llama_quant_free(quantize_state_impl * qs); + +// Descriptor for constructing a mock model for quantization testing. +struct llama_quant_model_desc { + const char * architecture; + uint32_t n_embd; + uint32_t n_ff; + uint32_t n_layer; + uint32_t n_head; + uint32_t n_head_kv; + uint32_t n_expert; + uint32_t n_embd_head_k; + uint32_t n_embd_head_v; +}; + +// Create a mock model from a metadata descriptor (for testing). +// The returned model must be freed with llama_model_free(). +LLAMA_API llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc); + +// Returns true if this tensor should be quantized (based on name, dims, params). +LLAMA_API bool llama_quant_tensor_allows_quantization( + const quantize_state_impl * qs, + const ggml_tensor * tensor); + +// Compute quantization type assignments for a list of tensors. +// All tensors should be quantizable (use llama_quant_tensor_allows_quantization to filter). +// result_types: caller-allocated array of n_tensors elements, filled with assigned types. +LLAMA_API void llama_quant_compute_types( + quantize_state_impl * qs, + llama_ftype ftype, + ggml_tensor ** tensors, + ggml_type * result_types, + size_t n_tensors); + +// +// device memory querying +// + +// "memory" as in physical memory for a buffer type, in bytes +struct llama_memory_breakdown_data { + size_t model = 0; // memory allocated for the model + size_t context = 0; // memory allocated for the context + size_t compute = 0; // memory allocated for temporary compute buffers + + size_t total() const { + return model + context + compute; + } +}; + +struct llama_device_memory_data { + int64_t total; + int64_t free; + llama_memory_breakdown_data mb; +}; + +// TODO: convert to C-style data structure +using llama_memory_breakdown = std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data>; + +LLAMA_API int32_t llama_model_n_expert (const struct llama_model * model); +LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model); + +LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i); + +LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx); + +// Set whether the context outputs nextn embeddings or not +// If masked == true, output the embeddings only for the tokens with batch.logits != 0 +// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits +LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); + +// LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); +LLAMA_API float * llama_get_embeddings_nextn_ith(struct llama_context * ctx, int32_t i); + +// Set whether the context outputs the input embeddings of a specific layer +LLAMA_API void llama_set_embeddings_layer_inp(struct llama_context * ctx, uint32_t lid, bool value); + +// mirrors: +// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); +LLAMA_API float * llama_get_embeddings_layer_inp(struct llama_context * ctx, uint32_t lid); + +LLAMA_API llama_context * llama_get_ctx_other(struct llama_context * ctx); + +// +// model/context data extraction +// + +// returns pointer to the target-model layer indices +LLAMA_API const int32_t * llama_model_target_layer_ids (const struct llama_model * model); +// returns the number of extracted layers from target model +LLAMA_API uint32_t llama_model_target_layer_ids_n(const struct llama_model * model); diff --git a/examples/talk-llama/llama-grammar.cpp b/examples/talk-llama/llama-grammar.cpp index 64ea2fd00a9..badcbfd0fbb 100644 --- a/examples/talk-llama/llama-grammar.cpp +++ b/examples/talk-llama/llama-grammar.cpp @@ -2,11 +2,12 @@ #include "llama-impl.h" #include "llama-vocab.h" -#include "llama-sampling.h" +#include "llama-sampler.h" #include <cmath> #include <algorithm> #include <cstdint> +#include <set> #include <stdexcept> #define MAX_REPETITION_THRESHOLD 2000 @@ -454,6 +455,7 @@ const char * llama_grammar_parser::parse_sequence( bool is_nested) { size_t last_sym_start = rule.size(); const char * pos = src; + uint64_t n_prev_rules = 1; // use UINT64_MAX as the empty value because we aligned to the proper uint64_t type so -1 can't be used // (though it's technically the same as -1 now) @@ -481,6 +483,18 @@ const char * llama_grammar_parser::parse_sequence( // S' ::= S | llama_grammar_rule prev_rule(rule.begin() + last_sym_start, rule.end()); + // Calculate the total number of rules that will be generated by this repetition + uint64_t total_rules = 1; // Start with 1 for the original rule + if (!no_max && max_times > 0) { + total_rules = max_times; + } else if (min_times > 0) { + total_rules = min_times; + } + + if (n_prev_rules * total_rules >= MAX_REPETITION_THRESHOLD) { + throw std::runtime_error("number of rules that are going to be repeated multiplied by the new repetition exceeds sane defaults, please reduce the number of repetitions or rule complexity"); + } + if (min_times == 0) { rule.resize(last_sym_start); } else { @@ -508,12 +522,15 @@ const char * llama_grammar_parser::parse_sequence( if (n_opt > 0) { rule.push_back({LLAMA_GRETYPE_RULE_REF, last_rec_rule_id}); } + n_prev_rules *= total_rules; + GGML_ASSERT(n_prev_rules >= 1); }; while (*pos) { if (*pos == '"') { // literal string pos++; last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != '"') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -531,6 +548,7 @@ const char * llama_grammar_parser::parse_sequence( start_type = LLAMA_GRETYPE_CHAR_NOT; } last_sym_start = rule.size(); + n_prev_rules = 1; while (*pos != ']') { if (!*pos) { throw std::runtime_error("unexpected end of input"); @@ -561,6 +579,7 @@ const char * llama_grammar_parser::parse_sequence( auto token_pair = parse_token(vocab, pos); const char * token_end = token_pair.second; last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({type, token_pair.first}); pos = parse_space(token_end, is_nested); } else if (is_word_char(*pos)) { // rule reference @@ -568,12 +587,15 @@ const char * llama_grammar_parser::parse_sequence( uint32_t ref_rule_id = get_symbol_id(pos, name_end - pos); pos = parse_space(name_end, is_nested); last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); } else if (*pos == '(') { // grouping // parse nested alternates into synthesized rule pos = parse_space(pos + 1, true); + uint32_t n_rules_before = symbol_ids.size(); uint32_t sub_rule_id = generate_symbol_id(rule_name); pos = parse_alternates(pos, rule_name, sub_rule_id, true); + n_prev_rules = std::max(1u, (uint32_t)symbol_ids.size() - n_rules_before); last_sym_start = rule.size(); // output reference to synthesized rule rule.push_back({LLAMA_GRETYPE_RULE_REF, sub_rule_id}); @@ -583,6 +605,7 @@ const char * llama_grammar_parser::parse_sequence( pos = parse_space(pos + 1, is_nested); } else if (*pos == '.') { // any char last_sym_start = rule.size(); + n_prev_rules = 1; rule.push_back({LLAMA_GRETYPE_CHAR_ANY, 0}); pos = parse_space(pos + 1, is_nested); } else if (*pos == '*') { @@ -601,7 +624,7 @@ const char * llama_grammar_parser::parse_sequence( throw std::runtime_error(std::string("expecting an int at ") + pos); } const char * int_end = parse_int(pos); - uint64_t min_times = std::stoul(std::string(pos, int_end - pos)); + uint64_t min_times = std::stoull(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); uint64_t max_times = UINT64_MAX; // default: no max limit @@ -614,7 +637,7 @@ const char * llama_grammar_parser::parse_sequence( if (is_digit_char(*pos)) { const char * int_end = parse_int(pos); - max_times = std::stoul(std::string(pos, int_end - pos)); + max_times = std::stoull(std::string(pos, int_end - pos)); pos = parse_space(int_end, is_nested); } @@ -830,32 +853,54 @@ static bool llama_grammar_match_token( static void llama_grammar_advance_stack( const llama_grammar_rules & rules, const llama_grammar_stack & stack, - llama_grammar_stacks & new_stacks) { - if (stack.empty()) { - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { - new_stacks.emplace_back(stack); + llama_grammar_stacks & new_stacks) { + std::vector<llama_grammar_stack> todo; + todo.push_back(stack); + + auto stack_cmp = [](const llama_grammar_stack & a, const llama_grammar_stack & b) { + return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end(), + [](const llama_grammar_element * pa, const llama_grammar_element * pb) { + return pa < pb; // Compare pointer addresses + } + ); + }; + + std::set<llama_grammar_stack, decltype(stack_cmp)> seen(stack_cmp); + + while (!todo.empty()) { + llama_grammar_stack curr_stack = std::move(todo.back()); + todo.pop_back(); + + if (seen.find( curr_stack) != seen.end()) { + continue; } - return; - } + seen.insert(curr_stack); - const llama_grammar_element * pos = stack.back(); + if (curr_stack.empty()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { + new_stacks.emplace_back(std::move(curr_stack)); + } + continue; + } - switch (pos->type) { + const llama_grammar_element * pos = curr_stack.back(); + + switch (pos->type) { case LLAMA_GRETYPE_RULE_REF: { const size_t rule_id = static_cast<size_t>(pos->value); const llama_grammar_element * subpos = rules[rule_id].data(); do { // init new stack without the top (pos) - llama_grammar_stack new_stack(stack.begin(), stack.end() - 1); + llama_grammar_stack next_stack(curr_stack.begin(), curr_stack.end() - 1); if (!llama_grammar_is_end_of_sequence(pos + 1)) { // if this rule ref is followed by another element, add that to stack - new_stack.push_back(pos + 1); + next_stack.push_back(pos + 1); } if (!llama_grammar_is_end_of_sequence(subpos)) { // if alternate is nonempty, add to stack - new_stack.push_back(subpos); + next_stack.push_back(subpos); } - llama_grammar_advance_stack(rules, new_stack, new_stacks); + todo.push_back(std::move(next_stack)); while (!llama_grammar_is_end_of_sequence(subpos)) { // scan to end of alternate def subpos++; @@ -874,9 +919,9 @@ static void llama_grammar_advance_stack( case LLAMA_GRETYPE_CHAR_ANY: case LLAMA_GRETYPE_TOKEN: case LLAMA_GRETYPE_TOKEN_NOT: - if (std::find(new_stacks.begin(), new_stacks.end(), stack) == new_stacks.end()) { + if (std::find(new_stacks.begin(), new_stacks.end(), curr_stack) == new_stacks.end()) { // only add the stack if it's not a duplicate of one we already have - new_stacks.emplace_back(stack); + new_stacks.emplace_back(std::move(curr_stack)); } break; default: @@ -884,6 +929,7 @@ static void llama_grammar_advance_stack( // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on // those GGML_ABORT("fatal error"); + } } } @@ -1160,13 +1206,13 @@ struct llama_grammar * llama_grammar_init_impl( // if there is a grammar, parse it // rules will be empty (default) if there are parse errors if (!parser.parse(grammar_str) || parser.rules.empty()) { - fprintf(stderr, "%s: failed to parse grammar\n", __func__); + LLAMA_LOG_ERROR("failed to parse grammar\n"); return nullptr; } - // Ensure that there is a "root" node. - if (parser.symbol_ids.find("root") == parser.symbol_ids.end()) { - fprintf(stderr, "%s: grammar does not contain a 'root' symbol\n", __func__); + // Ensure that the grammar contains the start symbol + if (parser.symbol_ids.find(grammar_root) == parser.symbol_ids.end()) { + LLAMA_LOG_ERROR("grammar does not contain a '%s' symbol\n", grammar_root); return nullptr; } @@ -1195,7 +1241,7 @@ struct llama_grammar * llama_grammar_init_impl( continue; } if (llama_grammar_detect_left_recursion(vec_rules, i, &rules_visited, &rules_in_progress, &rules_may_be_empty)) { - LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu", i); + LLAMA_LOG_ERROR("unsupported grammar, left recursion detected for nonterminal at index %zu\n", i); return nullptr; } } diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index 374ff1ebf3a..7468bd9b79e 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -1,19 +1,86 @@ #include "llama-graph.h" #include "llama-impl.h" +#include "llama-model.h" #include "llama-batch.h" #include "llama-cparams.h" #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-dsa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" #include <cassert> #include <cmath> #include <cstring> +#include <numeric> +#include <sstream> #include <unordered_set> +// dedup helpers + +static ggml_tensor * build_attn_inp_kq_mask( + ggml_context * ctx, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + // flash attention requires an f16 mask + const auto type = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + ggml_tensor * res = ggml_new_tensor_4d(ctx, type, n_kv, n_tokens/n_stream, 1, n_stream); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_kq_mask"); + + return res; +} + +static bool can_reuse_kq_mask( + ggml_tensor * kq_mask, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + bool res = true; + + res &= (kq_mask->ne[0] == n_kv); + res &= (kq_mask->ne[1] == n_tokens/n_stream); + res &= (kq_mask->ne[2] == 1); + res &= (kq_mask->ne[3] == n_stream); + + return res; +} + +// impl + +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + if (!ggml_is_contiguous(cur)) { + res = ggml_cont_2d (ctx, cur, n, ggml_nelements(cur)/n); + } else { + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + } + res = ggml_mul_mat (ctx, rot, res); + ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -22,7 +89,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { } if (ubatch->embd) { - const int64_t n_embd = embd->ne[0]; + GGML_ASSERT(n_embd == embd->ne[0]); + const int64_t n_tokens = ubatch->n_tokens; ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd)); @@ -32,8 +100,41 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { bool res = true; - res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); - res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); + res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); + + return res; +} + +void llm_graph_input_embd_h::set_input(const llama_ubatch * ubatch) { + const int64_t n_tokens = ubatch->n_tokens; + + if (ubatch->token) { + ggml_backend_tensor_set(tokens, ubatch->token, 0, n_tokens*ggml_element_size(tokens)); + } else { + // note: mtmd embedding input goes through here + GGML_ASSERT(ubatch->embd); + GGML_ASSERT(n_embd == embd->ne[0]); + + ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h)); + } + + // TODO: extend llama_ubatch to differentiate between token embeddings and hidden states + // for now, we assume that the hidden state is always provided as an embedding + // ref: https://github.com/ggml-org/llama.cpp/pull/23643 + if (ubatch->embd) { + GGML_ASSERT(n_embd == h->ne[0]); + + ggml_backend_tensor_set(h, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(h)); + } +} + +bool llm_graph_input_embd_h::can_reuse(const llm_graph_params & params) { + bool res = true; + + res &= (!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); + res &= (!params.ubatch.embd) || (h && h->ne[1] == params.ubatch.n_tokens); return res; } @@ -96,11 +197,9 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { int32_t * data = (int32_t *) pos_bucket->data; - for (int h = 0; h < 1; ++h) { - for (int j = 0; j < n_tokens; ++j) { - for (int i = 0; i < n_tokens; ++i) { - data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); - } + for (int j = 0; j < n_tokens; ++j) { + for (int i = 0; i < n_tokens; ++i) { + data[j*n_tokens + i] = llama_relative_position_bucket(ubatch->pos[i], ubatch->pos[j], hparams.n_rel_attn_bkts, true); } } } @@ -148,7 +247,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) { } void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + if (cparams.embeddings && + (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) { + const int64_t n_tokens = ubatch->n_tokens; const int64_t n_seq_tokens = ubatch->n_seq_tokens; const int64_t n_seqs_unq = ubatch->n_seqs_unq; @@ -210,7 +312,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { const bool last = ( cparams.pooling_type == LLAMA_POOLING_TYPE_LAST || - (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token + (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL)) // qwen3 reranking & embedding models use last token ); for (int i = 0; i < n_tokens; ++i) { @@ -283,7 +385,8 @@ void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { } } -static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { +template <typename T> +static void print_mask(const T * data, int64_t n_tokens, int64_t n_kv, int64_t n_swa, llama_swa_type swa_type) { LLAMA_LOG_DEBUG("%s: === Attention mask ===\n", __func__); const char * swa_type_str = "unknown"; @@ -294,7 +397,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64 case LLAMA_SWA_TYPE_SYMMETRIC: swa_type_str = "LLAMA_SWA_TYPE_SYMMETRIC"; break; }; - LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swq_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); + LLAMA_LOG_DEBUG("%s: n_swa : %d, n_kv: %d, swa_type: %s\n", __func__, (int)n_swa, (int)n_kv, swa_type_str); LLAMA_LOG_DEBUG("%s: '0' = can attend, '∞' = masked\n", __func__); LLAMA_LOG_DEBUG("%s: Rows = query tokens, Columns = key/value tokens\n\n", __func__); @@ -307,7 +410,7 @@ static void print_mask(const float * data, int64_t n_tokens, int64_t n_kv, int64 for (int i = 0; i < std::min((int64_t)20, n_tokens); ++i) { LLAMA_LOG_DEBUG(" %2d ", i); for (int j = 0; j < std::min((int64_t)20, n_kv); ++j) { - float val = data[i * n_kv + j]; + float val = llama_cast<float>(data[i * n_kv + j]); if (val == -INFINITY) { LLAMA_LOG_DEBUG(" ∞"); } else { @@ -322,66 +425,59 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) { const int64_t n_kv = ubatch->n_tokens; const int64_t n_tokens = ubatch->n_tokens; - const auto fill_mask = [&](float * data, int n_swa, llama_swa_type swa_type) { - for (int h = 0; h < 1; ++h) { - for (int i1 = 0; i1 < n_tokens; ++i1) { - const llama_seq_id s1 = ubatch->seq_id[i1][0]; - const llama_pos p1 = ubatch->pos[i1]; + const auto fill_mask = [&](auto * data, int64_t ne, int n_swa, llama_swa_type swa_type) { + using T = std::remove_reference_t<decltype(*data)>; + std::fill(data, data + ne, llama_cast<T>(-INFINITY)); - const uint64_t idst = h*(n_kv*n_tokens) + i1*n_kv; + for (int i1 = 0; i1 < n_tokens; ++i1) { + const llama_seq_id s1 = ubatch->seq_id[i1][0]; + const llama_pos p1 = ubatch->pos[i1]; - for (int i0 = 0; i0 < n_tokens; ++i0) { - const llama_seq_id s0 = ubatch->seq_id[i0][0]; - const llama_pos p0 = ubatch->pos[i0]; + const uint64_t idst = i1*n_kv; - // mask different sequences - if (s0 != s1) { - continue; - } + for (int i0 = 0; i0 < n_tokens; ++i0) { + const llama_seq_id s0 = ubatch->seq_id[i0][0]; + const llama_pos p0 = ubatch->pos[i0]; - // mask future tokens - if (cparams.causal_attn && p0 > p1) { - continue; - } + // mask different sequences + if (s0 != s1) { + continue; + } - // apply SWA if any - if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { - continue; - } + // mask future tokens + if (cparams.causal_attn && p0 > p1) { + continue; + } - data[idst + i0] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + // apply SWA if any + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + continue; } + + data[idst + i0] = llama_cast<T>(hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f); } } - }; - - { - GGML_ASSERT(self_kq_mask); - GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); - - float * data = (float *) self_kq_mask->data; - - std::fill(data, data + ggml_nelements(self_kq_mask), -INFINITY); - - fill_mask(data, 0, LLAMA_SWA_TYPE_NONE); if (debug) { - print_mask(data, n_tokens, n_kv, 0, LLAMA_SWA_TYPE_NONE); + print_mask(data, n_tokens, n_kv, n_swa, swa_type); } + }; + + GGML_ASSERT(self_kq_mask); + GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer)); + if (self_kq_mask->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); + } else { + fill_mask((float *) self_kq_mask->data, ggml_nelements(self_kq_mask), 0, LLAMA_SWA_TYPE_NONE); } if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { GGML_ASSERT(self_kq_mask_swa); GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer)); - - float * data = (float *) self_kq_mask_swa->data; - - std::fill(data, data + ggml_nelements(self_kq_mask_swa), -INFINITY); - - fill_mask(data, hparams.n_swa, hparams.swa_type); - - if (debug) { - print_mask(data, n_tokens, n_kv, hparams.n_swa, hparams.swa_type); + if (self_kq_mask_swa->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); + } else { + fill_mask((float *) self_kq_mask_swa->data, ggml_nelements(self_kq_mask_swa), hparams.n_swa, hparams.swa_type); } } } @@ -391,6 +487,14 @@ void llm_graph_input_attn_kv::set_input(const llama_ubatch * ubatch) { mctx->set_input_v_idxs(self_v_idxs, ubatch); mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + + if (self_k_rot) { + mctx->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->set_input_v_rot(self_v_rot); + } } bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { @@ -403,22 +507,96 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); + + return res; +} + +void llm_graph_input_attn_k::set_input(const llama_ubatch * ubatch) { + mctx->set_input_k_idxs(self_k_idxs, ubatch); + + mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); +} + +bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast<const llama_kv_cache_context *>(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); + + return res; +} + +void llm_graph_input_attn_k_dsa::set_input(const llama_ubatch * ubatch) { + mctx->get_mla()->set_input_k_idxs(self_k_idxs_mla, ubatch); + + mctx->get_mla()->set_input_kq_mask(self_kq_mask_mla, ubatch, cparams.causal_attn); + + mctx->get_lid()->set_input_k_idxs(self_k_idxs_lid, ubatch); + + mctx->get_lid()->set_input_kq_mask(self_kq_mask_lid, ubatch, cparams.causal_attn); + + mctx->get_lid()->set_input_k_rot(self_k_rot_lid); +} + +bool llm_graph_input_attn_k_dsa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast<const llama_kv_cache_dsa_context *>(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs_mla->ne[0] == params.ubatch.n_tokens; + res &= self_k_idxs_lid->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(self_kq_mask_mla, mctx->get_mla(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_lid, mctx->get_lid(), params.ubatch, params.cparams); return res; } void llm_graph_input_attn_kv_iswa::set_input(const llama_ubatch * ubatch) { - mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); - mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); + mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); + } - mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + // the kq mask guards on its own buffer: shared cells leave idxs unbacked while the mask stays live + if (self_kq_mask && self_kq_mask->buffer) { + mctx->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); + mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + } - mctx->get_swa()->set_input_k_idxs(self_k_idxs_swa, ubatch); - mctx->get_swa()->set_input_v_idxs(self_v_idxs_swa, ubatch); + if (self_kq_mask_swa && self_kq_mask_swa->buffer) { + mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + } - mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); + if (self_k_rot) { + mctx->get_base()->set_input_k_rot(self_k_rot); + } + + if (self_v_rot) { + mctx->get_base()->set_input_v_rot(self_v_rot); + } + + if (self_k_rot_swa) { + mctx->get_swa()->set_input_k_rot(self_k_rot_swa); + } + + if (self_v_rot_swa) { + mctx->get_swa()->set_input_v_rot(self_v_rot_swa); + } } bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { @@ -428,17 +606,25 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { bool res = true; - res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + // base tensors may not be allocated if there are no non-SWA attention layers + if (self_k_idxs && self_k_idxs->buffer) { + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + } - res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; - //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + if (self_kq_mask && self_kq_mask->buffer) { + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + } - res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + // swa tensors may not be allocated if there are no SWA attention layers + if (self_k_idxs_swa && self_k_idxs_swa->buffer) { + res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + } - res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); - res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + if (self_kq_mask_swa && self_kq_mask_swa->buffer) { + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); + } return res; } @@ -452,10 +638,10 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer)); GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing - float * data = (float *) cross_kq_mask->data; - - for (int h = 0; h < 1; ++h) { + const auto fill_mask = [&](auto * data) { + using T = std::remove_reference_t<decltype(*data)>; for (int i = 0; i < n_tokens; ++i) { + GGML_ASSERT(!cross->seq_ids_enc.empty() && "llama_encode must be called first"); for (int j = 0; j < n_enc; ++j) { float f = -INFINITY; @@ -467,15 +653,15 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } } - data[h*(n_enc*n_tokens) + i*n_enc + j] = f; + data[i*n_enc + j] = llama_cast<T>(f); } } + }; - for (int i = n_tokens; i < n_tokens; ++i) { - for (int j = 0; j < n_enc; ++j) { - data[h*(n_enc*n_tokens) + i*n_enc + j] = -INFINITY; - } - } + if (cross_kq_mask->type == GGML_TYPE_F16) { + fill_mask((ggml_fp16_t *) cross_kq_mask->data); + } else { + fill_mask((float *) cross_kq_mask->data); } } @@ -485,6 +671,14 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + if (inp_attn->self_k_rot) { + mctx->get_attn()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + mctx->get_attn()->set_input_v_rot(inp_attn->self_v_rot); + } + const int64_t n_rs = mctx->get_recr()->get_n_rs(); if (inp_rs->s_copy) { @@ -508,8 +702,138 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + +// TODO: Hybrid input classes are a bit redundant. +// Instead of creating a hybrid input, the graph can simply create 2 separate inputs. +// Refactoring is required in the future. +void llm_graph_input_mem_hybrid_k::set_input(const llama_ubatch * ubatch) { + mctx->get_attn()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + + mctx->get_attn()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast<const llama_memory_hybrid_context *>(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); + + res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); + + res &= inp_rs->s_copy_main->ne[0] == params.ubatch.n_seqs; + res &= inp_rs->s_copy_extra->ne[0] == mctx->get_recr()->get_n_rs() - params.ubatch.n_seqs; + + res &= inp_rs->head == mctx->get_recr()->get_head(); + res &= inp_rs->rs_z == mctx->get_recr()->get_rs_z(); + + return res; +} + +void llm_graph_input_mem_hybrid_iswa::set_input(const llama_ubatch * ubatch) { + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + attn_ctx->get_base()->set_input_k_idxs(inp_attn->self_k_idxs, ubatch); + attn_ctx->get_base()->set_input_v_idxs(inp_attn->self_v_idxs, ubatch); + } + + if (inp_attn->self_kq_mask && inp_attn->self_kq_mask->buffer) { + attn_ctx->get_base()->set_input_kq_mask(inp_attn->self_kq_mask, ubatch, cparams.causal_attn); + } + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + attn_ctx->get_swa()->set_input_k_idxs(inp_attn->self_k_idxs_swa, ubatch); + attn_ctx->get_swa()->set_input_v_idxs(inp_attn->self_v_idxs_swa, ubatch); + } + + if (inp_attn->self_kq_mask_swa && inp_attn->self_kq_mask_swa->buffer) { + attn_ctx->get_swa()->set_input_kq_mask(inp_attn->self_kq_mask_swa, ubatch, cparams.causal_attn); + } + + if (inp_attn->self_k_rot) { + attn_ctx->get_base()->set_input_k_rot(inp_attn->self_k_rot); + } + + if (inp_attn->self_v_rot) { + attn_ctx->get_base()->set_input_v_rot(inp_attn->self_v_rot); + } + + if (inp_attn->self_k_rot_swa) { + attn_ctx->get_swa()->set_input_k_rot(inp_attn->self_k_rot_swa); + } + + if (inp_attn->self_v_rot_swa) { + attn_ctx->get_swa()->set_input_v_rot(inp_attn->self_v_rot_swa); + } + + const int64_t n_rs = mctx->get_recr()->get_n_rs(); + + if (inp_rs->s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(inp_rs->s_copy->buffer)); + int32_t * data = (int32_t *) inp_rs->s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + data[i] = mctx->get_recr()->s_copy(i); + } + } +} + +bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast<const llama_memory_hybrid_iswa_context *>(params.mctx); + + this->mctx = mctx; + + bool res = true; + + const auto * attn_ctx = mctx->get_attn(); + + // base tensors may not be allocated if there are no non-SWA attention layers + if (inp_attn->self_k_idxs && inp_attn->self_k_idxs->buffer) { + res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + } + + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); + + // swa tensors may not be allocated if there are no SWA attention layers + if (inp_attn->self_k_idxs_swa && inp_attn->self_k_idxs_swa->buffer) { + res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + } + + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -575,10 +899,15 @@ int64_t llm_graph_result::get_max_nodes() const { } void llm_graph_result::reset() { - t_tokens = nullptr; + t_inp_tokens = nullptr; + t_inp_embd = nullptr; t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + + t_layer_inp.resize(LLAMA_MAX_LAYERS); + std::fill(t_layer_inp.begin(), t_layer_inp.end(), nullptr); + t_sampled.clear(); t_sampled_probs.clear(); t_sampled_logits.clear(); @@ -607,7 +936,7 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { } } -void llm_graph_result::set_outputs() { +void llm_graph_result::set_outputs(const llm_graph_params & params) { if (t_logits != nullptr) { ggml_set_output(t_logits); } @@ -617,6 +946,18 @@ void llm_graph_result::set_outputs() { if (t_embd_pooled != nullptr) { ggml_set_output(t_embd_pooled); } + if (t_h_nextn != nullptr) { + ggml_set_output(t_h_nextn); + } + { + const auto & embeddings_layer_inp = params.cparams.embeddings_layer_inp; + for (size_t il = 0; il < embeddings_layer_inp.size(); ++il) { + if (embeddings_layer_inp[il]) { + GGML_ASSERT(t_layer_inp[il] != nullptr && "layer input tensor is null"); + ggml_set_output(t_layer_inp[il]); + } + } + } for (auto & [seq_id, t] : t_sampled) { if (t != nullptr) { ggml_set_output(t); @@ -690,14 +1031,15 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : cparams (params.cparams), ubatch (params.ubatch), n_embd (hparams.n_embd), - n_layer (hparams.n_layer), - n_rot (hparams.n_rot), + n_layer (hparams.n_layer()), + n_layer_nextn (hparams.n_layer_nextn), + n_rot (hparams.n_rot()), n_ctx (cparams.n_ctx), n_head (hparams.n_head()), n_head_kv (hparams.n_head_kv()), - n_embd_head_k (hparams.n_embd_head_k), + n_embd_head_k (hparams.n_embd_head_k()), n_embd_k_gqa (hparams.n_embd_k_gqa()), - n_embd_head_v (hparams.n_embd_head_v), + n_embd_head_v (hparams.n_embd_head_v()), n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), n_expert_used (cparams.warmup ? hparams.n_expert : hparams.n_expert_used), @@ -742,7 +1084,8 @@ ggml_tensor * llm_graph_context::build_cvec( ggml_tensor * llm_graph_context::build_lora_mm( ggml_tensor * w, - ggml_tensor * cur) const { + ggml_tensor * cur, + ggml_tensor * w_s) const { ggml_tensor * res = ggml_mul_mat(ctx0, w, cur); for (const auto & lora : *loras) { @@ -763,6 +1106,10 @@ ggml_tensor * llm_graph_context::build_lora_mm( res = ggml_add(ctx0, res, ab_cur); } + if (w_s) { + res = ggml_mul(ctx0, res, w_s); + } + return res; } @@ -829,6 +1176,84 @@ ggml_tensor * llm_graph_context::build_norm( return cur; } + +llm_graph_qkv llm_graph_context::build_qkv( + const llama_layer & layer, + ggml_tensor * cur, + int64_t n_embd_head, + int64_t n_head, + int64_t n_head_kv, + int il) const { + const int64_t n_embd_q = n_embd_head * n_head; + const int64_t n_embd_kv = n_embd_head * n_head_kv; + + ggml_tensor * Qcur, * Kcur, * Vcur; + + if (layer.wqkv) { + // fused QKV path + ggml_tensor * qkv = build_lora_mm(layer.wqkv, cur, layer.wqkv_s); + cb(qkv, "wqkv", il); + if (layer.wqkv_b) { + qkv = ggml_add(ctx0, qkv, layer.wqkv_b); + cb(qkv, "wqkv_b", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + qkv = ggml_clamp(ctx0, qkv, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(qkv, "wqkv_clamped", il); + } + Qcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], 0); + Kcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], + ggml_row_size(qkv->type, n_embd_q)); + Vcur = ggml_view_3d(ctx0, qkv, n_embd_head, n_head_kv, n_tokens, + ggml_row_size(qkv->type, n_embd_head), qkv->nb[1], + ggml_row_size(qkv->type, n_embd_q + n_embd_kv)); + } else { + // separate Q/K/V path + Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur, "Qcur", il); + if (layer.wq_b) { + Qcur = ggml_add(ctx0, Qcur, layer.wq_b); + cb(Qcur, "Qcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Qcur, "Qcur_clamped", il); + } + Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + cb(Kcur, "Kcur", il); + if (layer.wk_b) { + Kcur = ggml_add(ctx0, Kcur, layer.wk_b); + cb(Kcur, "Kcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Kcur, "Kcur_clamped", il); + } + Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + cb(Vcur, "Vcur", il); + if (layer.wv_b) { + Vcur = ggml_add(ctx0, Vcur, layer.wv_b); + cb(Vcur, "Vcur", il); + } + if (hparams.f_clamp_kqv > 0.0f) { + Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); + cb(Vcur, "Vcur_clamped", il); + } + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + return { Qcur, Kcur, Vcur }; +} + + ggml_tensor * llm_graph_context::build_ffn( ggml_tensor * cur, ggml_tensor * up, @@ -888,6 +1313,26 @@ ggml_tensor * llm_graph_context::build_ffn( switch (type_op) { case LLM_FFN_SILU: if (gate && type_gate == LLM_FFN_PAR) { + // Step35: HF clamps gate (after SiLU) and up before multiplication + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_shexp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_silu_clamped", il); + + tmp = ggml_clamp(ctx0, tmp, -limit, limit); + cb(tmp, "ffn_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, tmp); + cb(cur, "ffn_swiglu_limited", il); + type_gate = LLM_FFN_SEQ; + break; + } + } + cur = ggml_swiglu_split(ctx0, cur, tmp); cb(cur, "ffn_swiglu", il); type_gate = LLM_FFN_SEQ; @@ -951,8 +1396,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { - // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -984,11 +1429,14 @@ ggml_tensor * llm_graph_context::build_moe_ffn( int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in) const { + ggml_tensor * probs_in, + ggml_tensor * gate_up_exps, + ggml_tensor * up_exps_s, + ggml_tensor * gate_exps_s, + ggml_tensor * down_exps_s) const { return build_moe_ffn( cur, gate_inp, /* gate_inp_b */ nullptr, @@ -1000,11 +1448,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( n_expert_used, type_op, norm_w, - scale_w, w_scale, gating_op, il, - probs_in + probs_in, + gate_up_exps, + /* gate_up_exps_b */ nullptr, + up_exps_s, + gate_exps_s, + down_exps_s ); } @@ -1023,11 +1475,15 @@ ggml_tensor * llm_graph_context::build_moe_ffn( int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in) const { + ggml_tensor * probs_in, + ggml_tensor * gate_up_exps, + ggml_tensor * gate_up_exps_b, + ggml_tensor * up_exps_s, + ggml_tensor * gate_exps_s, + ggml_tensor * down_exps_s) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN @@ -1149,7 +1605,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens); } - if (scale_w) { + if (w_scale != 0.0f && w_scale != 1.0f) { weights = ggml_scale(ctx0, weights, w_scale); cb(weights, "ffn_moe_weights_scaled", il); } @@ -1166,30 +1622,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_weighted", il); } - ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); + ggml_tensor * up = nullptr; + ggml_tensor * experts = nullptr; - if (up_exps_b) { - up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); - cb(up, "ffn_moe_up_biased", il); - } + if (gate_up_exps) { + // merged gate_up path: one mul_mat_id, then split into gate and up views + ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens] + cb(gate_up, "ffn_moe_gate_up", il); - ggml_tensor * experts = nullptr; - if (gate_exps) { - cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + if (gate_up_exps_b) { + gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts); + cb(gate_up, "ffn_moe_gate_up_biased", il); + } + + // apply per-expert scale2 to merged gate_up (use up_exps_s since gate and up are fused) + if (up_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + gate_up = ggml_mul(ctx0, gate_up, s); + cb(gate_up, "ffn_moe_gate_up_scaled", il); + } + + const int64_t n_ff = gate_up->ne[0] / 2; + cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0); cb(cur, "ffn_moe_gate", il); + up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]); + cb(up, "ffn_moe_up", il); } else { - cur = up; - } + // separate gate and up path + up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + if (up_exps_b) { + up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); + cb(up, "ffn_moe_up_biased", il); + } + + // apply per-expert scale2 to up + if (up_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, up_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + up = ggml_mul(ctx0, up, s); + cb(up, "ffn_moe_up_scaled", il); + } + + if (gate_exps) { + cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate", il); + } else { + cur = up; + } - if (gate_exps_b) { - cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); - cb(cur, "ffn_moe_gate_biased", il); + if (gate_exps_b) { + cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); + cb(cur, "ffn_moe_gate_biased", il); + } + + // apply per-expert scale2 to gate + if (gate_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, gate_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + cur = ggml_mul(ctx0, cur, s); + cb(cur, "ffn_moe_gate_scaled", il); + } } + const bool has_gate = gate_exps || gate_up_exps; + switch (type_op) { case LLM_FFN_SILU: if (gate_exps) { + // Step35: per-layer clamp for routed experts + if (arch == LLM_ARCH_STEP35 && il >= 0) { + const float limit = hparams.swiglu_clamp_exp[il]; + constexpr float eps = 1e-6f; + if (limit > eps) { + ggml_tensor * gate_act = ggml_silu(ctx0, cur); + cb(gate_act, "ffn_moe_silu", il); + gate_act = ggml_clamp(ctx0, gate_act, -INFINITY, limit); + cb(gate_act, "ffn_moe_silu_clamped", il); + + up = ggml_clamp(ctx0, up, -limit, limit); + cb(up, "ffn_moe_up_clamped", il); + + cur = ggml_mul(ctx0, gate_act, up); + cb(cur, "ffn_moe_swiglu_limited", il); + break; + } + } + } + + if (has_gate) { cur = ggml_swiglu_split(ctx0, cur, up); cb(cur, "ffn_moe_swiglu", il); } else { @@ -1197,7 +1723,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_silu", il); } break; case LLM_FFN_GELU: - if (gate_exps) { + if (has_gate) { cur = ggml_geglu_split(ctx0, cur, up); cb(cur, "ffn_moe_geglu", il); } else { @@ -1213,7 +1739,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_swiglu_oai", il); } break; case LLM_FFN_RELU: - if (gate_exps) { + if (has_gate) { cur = ggml_reglu_split(ctx0, cur, up); cb(cur, "ffn_moe_reglu", il); } else { @@ -1221,7 +1747,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_relu", il); } break; case LLM_FFN_RELU_SQR: - if (gate_exps) { + if (has_gate) { // TODO: add support for gated squared relu GGML_ABORT("fatal error: gated squared relu not implemented"); } else { @@ -1241,11 +1767,22 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(experts, "ffn_moe_down_biased", il); } + // apply per-expert scale2 to down + if (down_exps_s) { + ggml_tensor * s = ggml_reshape_3d(ctx0, down_exps_s, 1, n_expert, 1); + s = ggml_repeat_4d(ctx0, s, 1, n_expert, n_tokens, 1); + s = ggml_get_rows(ctx0, s, selected_experts); // [1, n_expert_used, n_tokens] + experts = ggml_mul(ctx0, experts, s); + cb(experts, "ffn_moe_down_scaled", il); + } + if (!weight_before_ffn) { experts = ggml_mul(ctx0, experts, weights); - cb(cur, "ffn_moe_weighted", il); + cb(experts, "ffn_moe_weighted", il); } + ggml_build_forward_expand(gf, experts); + ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr }; assert(n_expert_used > 0); @@ -1265,6 +1802,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( for (uint32_t i = 1; i < hparams.n_expert_used; ++i) { moe_out = ggml_add(ctx0, moe_out, cur_experts[i]); + + ggml_build_forward_expand(gf, moe_out); } if (hparams.n_expert_used == 1) { @@ -1279,17 +1818,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn( // input embeddings with optional lora ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { - const int64_t n_embd = hparams.n_embd_inp(); + const int64_t n_embd_inp = hparams.n_embd_inp(); + const int64_t n_embd = hparams.n_embd; - auto inp = std::make_unique<llm_graph_input_embd>(); + assert(n_embd_inp >= n_embd); - ggml_tensor * cur = nullptr; + auto inp = std::make_unique<llm_graph_input_embd>(n_embd_inp); - if (ubatch.token) { - inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); - //cb(inp->tokens, "inp_tokens", -1); - ggml_set_input(inp->tokens); - res->t_tokens = inp->tokens; + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_inp, ubatch.n_tokens); + cb(inp->embd, "inp_embd", -1); + ggml_set_input(inp->embd); + + // select one of the 2 inputs, based on the batch contents + // ref: https://github.com/ggml-org/llama.cpp/pull/18550 + std::array<ggml_tensor *, 2> inps; + + // token embeddings path (ubatch.token != nullptr) + { + auto & cur = inps[0]; cur = ggml_get_rows(ctx0, tok_embd, inp->tokens); @@ -1310,19 +1861,41 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { cur = ggml_add(ctx0, cur, inpL_delta); } - } else { - inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens); - ggml_set_input(inp->embd); + + if (n_embd_inp != n_embd) { + cur = ggml_pad(ctx0, cur, hparams.n_embd_inp() - n_embd, 0, 0, 0); + } + } + + // vector embeddings path (ubatch.embd != nullptr) + { + auto & cur = inps[1]; cur = inp->embd; } + assert(ggml_are_same_shape (inps[0], inps[1])); + assert(ggml_are_same_stride(inps[0], inps[1])); + + ggml_tensor * cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1); + + if (n_embd_inp != n_embd) { + cur = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0); + } + + res->t_inp_embd = cur; + // For Granite architecture - if (hparams.f_embedding_scale != 0.0f) { + // NOTE: For deepstack models, only apply scale to token inputs (ie text-only input). + // Raw embeddings are assumed to be multimodal inputs that should not be scaled. + if (hparams.f_embedding_scale != 0.0f && (ubatch.token || hparams.n_deepstack_layers == 0)) { + if (!ggml_is_contiguous(cur)) { + cur = ggml_cont(ctx0, cur); + } cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale); } - cb(cur, "inp_embd", -1); + cb(cur, "embd", -1); res->add_input(std::move(inp)); @@ -1354,6 +1927,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { // this need to be 1x1xN for broadcasting cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens); ggml_set_input(cur); + ggml_set_name(cur, "attn_scale"); res->add_input(std::move(inp)); @@ -1362,8 +1936,8 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { ggml_tensor * llm_graph_context::build_inp_out_ids() const { // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls, - // but this would make the graph topology depend on the number of output tokens, which can interere with - // features that require constant topology such as pipline parallelism + // but this would make the graph topology depend on the number of output tokens, which can interfere with + // features that require constant topology such as pipeline parallelism // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471 //if (n_outputs < n_tokens) { // return nullptr; @@ -1421,7 +1995,7 @@ ggml_tensor * llm_graph_context::build_inp_cross_embd() const { //} const auto n_embd = !cross->v_embd.empty() ? cross->n_embd : hparams.n_embd_inp(); - const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; + const auto n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_enc); ggml_set_input(cur); @@ -1499,7 +2073,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * cur; - if (cparams.flash_attn && kq_b == nullptr) { + const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr; + if (use_flash_attn) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); if (v_trans) { @@ -1525,7 +2100,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( if (v_mla) { #if 0 // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. - // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient. + // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); cur = ggml_mul_mat(ctx0, v_mla, cur); #else @@ -1613,17 +2188,20 @@ ggml_tensor * llm_graph_context::build_attn_mha( llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() const { auto inp = std::make_unique<llm_graph_input_attn_no_cache>(hparams, cparams); + // flash attention requires an f16 mask + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + // note: there is no KV cache, so the number of KV values is equal to the number of tokens in the batch - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask); - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask_cnv = inp->self_kq_mask; if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens, 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, type_mask, n_tokens, n_tokens, 1, 1); ggml_set_input(inp->self_kq_mask_swa); - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } else { inp->self_kq_mask_swa = nullptr; inp->self_kq_mask_swa_cnv = nullptr; @@ -1636,6 +2214,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_no_cache * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -1669,7 +2248,7 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -1695,19 +2274,16 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); - ggml_set_input(inp->self_kq_mask); - - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); + inp->self_kq_mask_cnv = inp->self_kq_mask; } + inp->self_k_rot = mctx_cur->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->build_input_v_rot(ctx0); + return inp; } @@ -1723,14 +2299,26 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, ggml_tensor * kq_b, ggml_tensor * sinks, - ggml_tensor * v_mla, + ggml_tensor * v_mla, // TODO: remove float kq_scale, int il) const { + GGML_ASSERT(v_mla == nullptr); + + if (inp->self_k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, inp->self_k_rot); + k_cur = ggml_mul_mat_aux(ctx0, k_cur, inp->self_k_rot); + } + + if (inp->self_v_rot) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, inp->self_v_rot); + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced // expand k later to enable rope fusion which directly writes into k-v cache @@ -1758,11 +2346,107 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (inp->self_v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, inp->self_v_rot); + } + + if (wo) { + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators + cur = build_lora_mm(wo, cur); + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + if (wo_s) { + cur = ggml_mul(ctx0, cur, wo_s); + } + } else { + cur = build_lora_mm(wo, cur, wo_s); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + +static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl( + ggml_context * ctx0, + const llama_ubatch & ubatch, + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx_cur) { + + auto inp = std::make_unique<llm_graph_input_attn_k>(hparams, cparams, mctx_cur); + + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); + + inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur, ubatch, cparams); + inp->self_kq_mask_cnv = inp->self_kq_mask; + } + + return inp; +} + +llm_graph_input_attn_k * llm_graph_context::build_attn_inp_k() const { + const auto * mctx_cur = static_cast<const llama_kv_cache_context *>(mctx); + + auto inp = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur); + + return (llm_graph_input_attn_k *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx; + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + if (wo) { - cur = build_lora_mm(wo, cur); if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + cur = build_lora_mm(wo, cur); ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + if (wo_s) { + cur = ggml_mul(ctx0, cur, wo_s); + } + } else { + cur = build_lora_mm(wo, cur, wo_s); } } @@ -1773,10 +2457,87 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_k_dsa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * sinks, + ggml_tensor * v_mla, + ggml_tensor * top_k, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + // expand k later to enable rope fusion which directly writes into k-v cache + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, v_cur); + ggml_build_forward_expand(gf, k_cur); + + const auto * mctx_cur = inp->mctx->get_mla(); + + // store to KV cache + { + const auto & k_idxs = inp->get_k_idxs_mla(); + + ggml_build_forward_expand(gf, mctx_cur->cpy_k(ctx0, k_cur, k_idxs, il)); + } + + const auto & kq_mask = inp->get_kq_mask_mla(); + + // prepare new kq mask - starts filled with -INFINITY + ggml_tensor * kq_mask_all = ggml_fill(ctx0, kq_mask, -INFINITY); + + // reshape KQ mask into tensor with rows of size 1: + // [n_kv, n_batch, 1, n_stream] -> [1, n_kv, n_batch, n_stream] + kq_mask_all = ggml_view_4d(ctx0, kq_mask_all, 1, kq_mask_all->ne[0], kq_mask_all->ne[1], kq_mask_all->ne[3], kq_mask_all->nb[0], kq_mask_all->nb[1], kq_mask_all->nb[2], 0); + + // reshape top_k indices: [n_top_k, n_batch, 1, n_stream] -> [n_top_k, n_batch, n_stream, 1] + ggml_tensor * top_k_3d = ggml_view_4d(ctx0, top_k, top_k->ne[0], top_k->ne[1], top_k->ne[3], 1, top_k->nb[1], top_k->nb[2], top_k->ne[3]*top_k->nb[3], 0); + + // prepare zero-filled tensor with rows of size 1: [1, n_top_k, n_batch, n_stream] + // this will be our source of zero values for unmasking top k mask elements + ggml_tensor * zeros = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 1, top_k_3d->ne[0], top_k_3d->ne[1], top_k_3d->ne[2]); + zeros = ggml_fill(ctx0, zeros, 0.0f); + + // modify KQ mask by unmasking elements that are in top_k indices + // ggml_set_rows([1, n_kv, n_batch, n_stream], [1, n_top_k, n_batch, n_stream], [n_top_k, n_batch, n_stream, 1]) + ggml_tensor * kq_mask_top_k = ggml_set_rows(ctx0, kq_mask_all, zeros, top_k_3d); + + // reshape to restore the original shape of KQ mask: + // [1, n_kv, n_batch, n_stream] -> [n_kv, n_batch, 1, n_stream] + kq_mask_top_k = ggml_view_4d(ctx0, kq_mask_top_k, kq_mask_top_k->ne[1], kq_mask_top_k->ne[2], 1, kq_mask_top_k->ne[3], kq_mask_top_k->nb[2], kq_mask_top_k->nb[3], kq_mask_top_k->nb[3], 0); + + // combine with the original kq mask + kq_mask_top_k = ggml_add(ctx0, kq_mask_top_k, kq_mask); + + ggml_tensor * q = q_cur; + ggml_tensor * k = mctx_cur->get_k(ctx0, il); + ggml_tensor * v = ggml_view_4d(ctx0, k, v_cur->ne[0], k->ne[1], k->ne[2], k->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask_top_k, sinks, v_mla, kq_scale, il); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur, wo_s); + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; +} + ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -1785,6 +2546,23 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * v_mla, float kq_scale, int il) const { + const bool is_swa = hparams.is_swa(il); + + auto * k_rot = is_swa ? inp->self_k_rot_swa : inp->self_k_rot; + auto * v_rot = is_swa ? inp->self_v_rot_swa : inp->self_v_rot; + + if (k_rot) { + q_cur = ggml_mul_mat_aux(ctx0, q_cur, k_rot); + if (k_cur) { + k_cur = ggml_mul_mat_aux(ctx0, k_cur, k_rot); + } + } + if (v_rot) { + if (v_cur) { + v_cur = ggml_mul_mat_aux(ctx0, v_cur, v_rot); + } + } + // these nodes are added to the graph together so that they are not reordered // by doing so, the number of splits in the graph is reduced ggml_build_forward_expand(gf, q_cur); @@ -1799,8 +2577,6 @@ ggml_tensor * llm_graph_context::build_attn( const auto * mctx_iswa = inp->mctx; - const bool is_swa = hparams.is_swa(il); - const auto * mctx_cur = is_swa ? mctx_iswa->get_swa() : mctx_iswa->get_base(); // optionally store to KV cache @@ -1825,8 +2601,12 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, sinks, v_mla, kq_scale, il); cb(cur, "kqv_out", il); + if (v_rot) { + cur = ggml_mul_mat_aux(ctx0, cur, v_rot); + } + if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -1845,10 +2625,13 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { const int32_t n_enc = !cross->v_embd.empty() ? cross->n_enc : hparams.n_ctx_train; - inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_enc, n_tokens, 1, 1); + // flash attention requires an f16 mask + const auto type_mask = cparams.flash_attn ? GGML_TYPE_F16 : GGML_TYPE_F32; + + inp->cross_kq_mask = ggml_new_tensor_4d(ctx0, type_mask, n_enc, n_tokens, 1, 1); ggml_set_input(inp->cross_kq_mask); - inp->cross_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->cross_kq_mask, GGML_TYPE_F16) : inp->cross_kq_mask; + inp->cross_kq_mask_cnv = inp->cross_kq_mask; return (llm_graph_input_attn_cross *) res->add_input(std::move(inp)); } @@ -1857,6 +2640,7 @@ ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_cross * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, ggml_tensor * k_cur, ggml_tensor * v_cur, @@ -1881,7 +2665,7 @@ ggml_tensor * llm_graph_context::build_attn( cb(cur, "kqv_out", il); if (wo) { - cur = build_lora_mm(wo, cur); + cur = build_lora_mm(wo, cur, wo_s); } if (wo_b) { @@ -1895,6 +2679,34 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const { + const auto * mctx_cur = static_cast<const llama_kv_cache_dsa_context *>(mctx); + + auto inp = std::make_unique<llm_graph_input_attn_k_dsa>(hparams, cparams, mctx_cur); + + { + inp->self_k_idxs_mla = mctx_cur->get_mla()->build_input_k_idxs(ctx0, ubatch); + + inp->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, mctx_cur->get_mla(), ubatch, cparams); + inp->self_kq_mask_mla_cnv = inp->self_kq_mask_mla; + } + + { + inp->self_k_idxs_lid = mctx_cur->get_lid()->build_input_k_idxs(ctx0, ubatch); + + // ensure F32 mask + auto cparams_copy = cparams; + cparams_copy.flash_attn = false; + + inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams_copy); + inp->self_kq_mask_lid_cnv = inp->self_kq_mask_lid; + + inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0); + } + + return (llm_graph_input_attn_k_dsa *) res->add_input(std::move(inp)); +} + // TODO: maybe separate the inner implementation into a separate function // like with the non-sliding window equivalent // once sliding-window hybrid caches are a thing. @@ -1903,38 +2715,30 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur); - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - { - const auto n_kv = mctx_cur->get_base()->get_n_kv(); - inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); - ggml_set_input(inp->self_kq_mask); - ggml_set_name(inp->self_kq_mask, "self_kq_mask"); - - inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; - ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); + inp->self_kq_mask = build_attn_inp_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); + inp->self_kq_mask_cnv = inp->self_kq_mask; } { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA"); - const auto n_kv = mctx_cur->get_swa()->get_n_kv(); - inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); - ggml_set_input(inp->self_kq_mask_swa); - ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); - - inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; - ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); + inp->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); + inp->self_kq_mask_swa_cnv = inp->self_kq_mask_swa; } + inp->self_k_rot = mctx_cur->get_base()->build_input_k_rot(ctx0); + inp->self_v_rot = mctx_cur->get_base()->build_input_v_rot(ctx0); + + inp->self_k_rot_swa = mctx_cur->get_swa()->build_input_k_rot(ctx0); + inp->self_v_rot_swa = mctx_cur->get_swa()->build_input_v_rot(ctx0); + return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); } @@ -1950,7 +2754,8 @@ ggml_tensor * llm_graph_context::build_rs( int32_t rs_zero, const llm_graph_get_rows_fn & get_state_rows) const { - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, rs_size); + GGML_UNUSED(rs_size); + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, s->ne[1]); // Clear a single state which will then be copied to the other cleared states. // Note that this is a no-op when the view is zero-sized. @@ -1968,7 +2773,7 @@ ggml_tensor * llm_graph_context::build_rs( ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, - ggml_view_1d(ctx0, s, state_size*(n_rs - n_seqs), (rs_head + n_seqs)*state_size*ggml_element_size(s)))); + ggml_view_2d(ctx0, s, state_size, (n_rs - n_seqs), s->nb[1], (rs_head + n_seqs)*s->nb[1]))); return output_states; } @@ -2068,10 +2873,53 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +llm_graph_input_mem_hybrid_k * llm_graph_context::build_inp_mem_hybrid_k() const { + const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx); + + auto inp_rs = build_rs_inp_impl (ctx0, ubatch, mctx_cur->get_recr()); + auto inp_attn = build_attn_inp_k_impl(ctx0, ubatch, hparams, cparams, mctx_cur->get_attn()); + + auto inp = std::make_unique<llm_graph_input_mem_hybrid_k>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_k *) res->add_input(std::move(inp)); +} + +llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() const { + const auto * mctx_cur = static_cast<const llama_memory_hybrid_iswa_context *>(mctx); + + auto inp_rs = build_rs_inp_impl(ctx0, ubatch, mctx_cur->get_recr()); + + // build iswa attention input + const auto * attn_ctx = mctx_cur->get_attn(); + + auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx); + + { + inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask = build_attn_inp_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); + inp_attn->self_kq_mask_cnv = inp_attn->self_kq_mask; + } + + { + inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); + inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); + + inp_attn->self_kq_mask_swa = build_attn_inp_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); + inp_attn->self_kq_mask_swa_cnv = inp_attn->self_kq_mask_swa; + } + + auto inp = std::make_unique<llm_graph_input_mem_hybrid_iswa>(cparams, std::move(inp_attn), std::move(inp_rs), mctx_cur); + + return (llm_graph_input_mem_hybrid_iswa *) res->add_input(std::move(inp)); +} + void llm_graph_context::build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const { - if (!cparams.embeddings || !(dense_2 || dense_3)) { + if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) { return; } ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; @@ -2080,6 +2928,9 @@ void llm_graph_context::build_dense_out( if (dense_2) { cur = ggml_mul_mat(ctx0, dense_2, cur); } + if (dense_2_b) { + cur = ggml_add(ctx0, cur, dense_2_b); + } if (dense_3) { cur = ggml_mul_mat(ctx0, dense_3, cur); } @@ -2093,7 +2944,8 @@ void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const { + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const { if (!cparams.embeddings) { return; } @@ -2132,8 +2984,15 @@ void llm_graph_context::build_pooling( } break; case LLAMA_POOLING_TYPE_RANK: { - ggml_tensor * inp_cls = build_inp_cls(); - cur = ggml_get_rows(ctx0, inp, inp_cls); + if (arch == LLM_ARCH_MODERN_BERT) { + // modern bert gte reranker builds mean first then applies prediction head and classifier + // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411 + ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } else { + ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } // classification head // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 @@ -2142,7 +3001,15 @@ void llm_graph_context::build_pooling( if (cls_b) { cur = ggml_add(ctx0, cur, cls_b); } - cur = ggml_tanh(ctx0, cur); + if (arch == LLM_ARCH_MODERN_BERT) { + cur = ggml_gelu(ctx0, cur); + } else { + cur = ggml_tanh(ctx0, cur); + } + if (cls_norm) { + // head norm + cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1); + } } // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en @@ -2157,7 +3024,7 @@ void llm_graph_context::build_pooling( } // softmax for qwen3 reranker - if (arch == LLM_ARCH_QWEN3) { + if (arch == LLM_ARCH_QWEN3 || arch == LLM_ARCH_QWEN3VL) { cur = ggml_soft_max(ctx0, cur); } } break; @@ -2178,6 +3045,9 @@ void llm_graph_context::build_sampling() const { return; } + std::array<ggml_tensor *, 2> outs; + outs[0] = res->t_logits; + auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers); res->add_input(std::move(inp_sampling)); @@ -2198,14 +3068,14 @@ void llm_graph_context::build_sampling() const { // add a dummy row of logits // this trick makes the graph static, regardless of which samplers are activated // this is important in order to minimize graph reallocations - // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550) ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); for (const auto & [seq_id, sampler] : samplers) { const auto it = seq_to_logit_row.find(seq_id); // inactive samplers always work on the first row - const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0; + const auto row_idx = it != seq_to_logit_row.end() ? it->second : 0; + const int i_out = it != seq_to_logit_row.end() ? 1 : 0; ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); ggml_format_name(logits_seq, "logits_seq_%d", seq_id); @@ -2222,22 +3092,26 @@ void llm_graph_context::build_sampling() const { if (data.sampled != nullptr) { res->t_sampled[seq_id] = data.sampled; - ggml_build_forward_expand(gf, data.sampled); + outs[1] = data.sampled; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.probs != nullptr) { res->t_sampled_probs[seq_id] = data.probs; - ggml_build_forward_expand(gf, data.probs); + outs[1] = data.probs; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.logits != nullptr) { res->t_sampled_logits[seq_id] = data.logits; - ggml_build_forward_expand(gf, data.logits); + outs[1] = data.logits; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } if (data.candidates != nullptr) { res->t_candidates[seq_id] = data.candidates; - ggml_build_forward_expand(gf, data.candidates); + outs[1] = data.candidates; + ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); } } diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index 503ffd695aa..cc5cfe51dcd 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -17,22 +17,27 @@ struct ggml_context; struct ggml_tensor; struct llama_cparams; +struct llama_layer; struct llama_memory_context_i; class llama_kv_cache_context; +class llama_kv_cache_dsa_context; class llama_kv_cache_iswa_context; class llama_memory_recurrent_context; class llama_memory_hybrid_context; +class llama_memory_hybrid_iswa_context; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { LLM_GRAPH_TYPE_DEFAULT, LLM_GRAPH_TYPE_ENCODER, LLM_GRAPH_TYPE_DECODER, + LLM_GRAPH_TYPE_DECODER_MTP, }; -enum llm_ffn_op_type { +enum llm_ffn_op_type : int { + LLM_FFN_NONE = 0, // sentinel: unset; archs must assign before use LLM_FFN_SILU, LLM_FFN_GELU, LLM_FFN_RELU, @@ -105,7 +110,7 @@ using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>; class llm_graph_input_embd : public llm_graph_input_i { public: - llm_graph_input_embd() = default; + llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {} virtual ~llm_graph_input_embd() = default; void set_input(const llama_ubatch * ubatch) override; @@ -114,6 +119,25 @@ class llm_graph_input_embd : public llm_graph_input_i { ggml_tensor * tokens = nullptr; // I32 [n_batch] ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + + const int64_t n_embd = 0; +}; + +// similar to llm_graph_input_embd but with an additional hidden state input +class llm_graph_input_embd_h : public llm_graph_input_i { +public: + llm_graph_input_embd_h(int64_t n_embd) : n_embd(n_embd) {} + virtual ~llm_graph_input_embd_h() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * tokens = nullptr; // I32 [n_batch] + ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] + ggml_tensor * h = nullptr; // F32 [n_embd, n_batch] + + const int64_t n_embd = 0; }; class llm_graph_input_pos : public llm_graph_input_i { @@ -269,10 +293,10 @@ class llm_graph_input_attn_no_cache : public llm_graph_input_i { ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } // n_tokens == n_batch - ggml_tensor * self_kq_mask = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_tokens, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_tokens, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_tokens, n_batch/n_stream, 1, n_stream] const llama_hparams hparams; const llama_cparams cparams; @@ -302,8 +326,12 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + // note: assumes v_rot^2 == I + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; // note: these have to be copies because in order to be able to reuse a graph, its inputs // need to carry these parameters with them. otherwise, they can point to freed @@ -314,6 +342,77 @@ class llm_graph_input_attn_kv : public llm_graph_input_i { const llama_kv_cache_context * mctx; }; +// V-less input for the KV cache +// ref: https://github.com/ggml-org/llama.cpp/pull/19067 +class llm_graph_input_attn_k : public llm_graph_input_i { +public: + llm_graph_input_attn_k( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs() const { return self_k_idxs; } + + ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } + + ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_context * mctx; +}; + +class llm_graph_input_attn_k_dsa : public llm_graph_input_i { +public: + llm_graph_input_attn_k_dsa( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_dsa_context * mctx) : + hparams(hparams), + cparams(cparams), + mctx(mctx) { + } + ~llm_graph_input_attn_k_dsa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + ggml_tensor * get_k_idxs_mla() const { return self_k_idxs_mla; } + ggml_tensor * get_k_idxs_lid() const { return self_k_idxs_lid; } + + ggml_tensor * get_kq_mask_mla() const { return self_kq_mask_mla_cnv; } + ggml_tensor * get_kq_mask_lid() const { return self_kq_mask_lid; } + + ggml_tensor * self_k_idxs_mla = nullptr; // I64 [n_batch] + ggml_tensor * self_k_idxs_lid = nullptr; // I64 [n_batch] + + ggml_tensor * self_kq_mask_mla = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + ggml_tensor * self_k_rot_lid = nullptr; + + const llama_hparams hparams; + const llama_cparams cparams; + + const llama_kv_cache_dsa_context * mctx; +}; + class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_iswa( @@ -343,10 +442,16 @@ class llm_graph_input_attn_kv_iswa : public llm_graph_input_i { ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + + ggml_tensor * self_k_rot = nullptr; + ggml_tensor * self_v_rot = nullptr; + + ggml_tensor * self_k_rot_swa = nullptr; + ggml_tensor * self_v_rot_swa = nullptr; const llama_hparams hparams; const llama_cparams cparams; @@ -363,8 +468,8 @@ class llm_graph_input_attn_cross : public llm_graph_input_i { ggml_tensor * get_kq_mask_cross() const { return cross_kq_mask_cnv; } - ggml_tensor * cross_kq_mask = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] - ggml_tensor * cross_kq_mask_cnv = nullptr; // F32 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1] + ggml_tensor * cross_kq_mask_cnv = nullptr; // F32/F16 [n_outputs_enc, n_batch, 1, 1] const llama_cross * cross = nullptr; }; @@ -397,6 +502,62 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_mem_hybrid_k : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_k( + const llama_cparams & cparams, + std::unique_ptr<llm_graph_input_attn_k> inp_attn, + std::unique_ptr<llm_graph_input_rs> inp_rs, + const llama_memory_hybrid_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_k() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr<llm_graph_input_attn_k> inp_attn; + std::unique_ptr<llm_graph_input_rs> inp_rs; + + llm_graph_input_attn_k * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_context * mctx; +}; + +class llm_graph_input_mem_hybrid_iswa : public llm_graph_input_i { +public: + llm_graph_input_mem_hybrid_iswa( + const llama_cparams & cparams, + std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn, + std::unique_ptr<llm_graph_input_rs> inp_rs, + const llama_memory_hybrid_iswa_context * mctx) : + inp_attn(std::move(inp_attn)), + inp_rs(std::move(inp_rs)), + cparams(cparams), + mctx(mctx) { } + virtual ~llm_graph_input_mem_hybrid_iswa() = default; + + void set_input(const llama_ubatch * ubatch) override; + + bool can_reuse(const llm_graph_params & params) override; + + std::unique_ptr<llm_graph_input_attn_kv_iswa> inp_attn; + std::unique_ptr<llm_graph_input_rs> inp_rs; + + llm_graph_input_attn_kv_iswa * get_attn() const { return inp_attn.get(); } + llm_graph_input_rs * get_recr() const { return inp_rs.get(); } + + const llama_cparams cparams; + + const llama_memory_hybrid_iswa_context * mctx; +}; + class llm_graph_input_sampling : public llm_graph_input_i { public: llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) : @@ -477,7 +638,8 @@ struct llm_graph_params { ubatch.n_seqs_unq == other.ubatch.n_seqs_unq && ( (!ubatch.token && !other.ubatch.token) || - (!ubatch.embd && !other.ubatch.embd) + (!ubatch.embd && !other.ubatch.embd) || + (ubatch.token && other.ubatch.token && ubatch.embd && other.ubatch.embd) ); // when we split the batch using "equal_seqs" we have to verify that the participating sequences are the same @@ -537,10 +699,13 @@ class llm_graph_result { virtual ~llm_graph_result() = default; - ggml_tensor * get_tokens() const { return t_tokens; } + ggml_tensor * get_inp_tokens() const { return t_inp_tokens; } ggml_tensor * get_logits() const { return t_logits; } ggml_tensor * get_embd() const { return t_embd; } ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + ggml_tensor * get_h_nextn() const { return t_h_nextn; } + + ggml_tensor * get_layer_inp(int il) const { return t_layer_inp[il]; } ggml_cgraph * get_gf() const { return gf; } ggml_context * get_ctx() const { return ctx_compute.get(); } @@ -550,7 +715,7 @@ class llm_graph_result { void reset(); void set_inputs(const llama_ubatch * ubatch); - void set_outputs(); + void set_outputs(const llm_graph_params & params); // try to update the existing graph result using the new graph parameters in order to reuse it // this can only be done if we determine that the resulting graph using the new graph parameters @@ -564,15 +729,19 @@ class llm_graph_result { void set_params(const llm_graph_params & params); // important graph nodes - ggml_tensor * t_tokens = nullptr; + ggml_tensor * t_inp_tokens = nullptr; + ggml_tensor * t_inp_embd = nullptr; // [n_embd_inp, n_tokens] ggml_tensor * t_logits = nullptr; ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + ggml_tensor * t_h_nextn = nullptr; // [n_embd, n_outputs] hidden state before final output norm + + std::vector<ggml_tensor *> t_layer_inp; - std::map<llama_seq_id, ggml_tensor*> t_sampled_logits; - std::map<llama_seq_id, ggml_tensor*> t_candidates; - std::map<llama_seq_id, ggml_tensor*> t_sampled; - std::map<llama_seq_id, ggml_tensor*> t_sampled_probs; + std::map<llama_seq_id, ggml_tensor *> t_sampled_logits; + std::map<llama_seq_id, ggml_tensor *> t_candidates; + std::map<llama_seq_id, ggml_tensor *> t_sampled; + std::map<llama_seq_id, ggml_tensor *> t_sampled_probs; std::vector<llm_graph_input_ptr> inputs; @@ -604,6 +773,12 @@ using llm_graph_result_ptr = std::unique_ptr<llm_graph_result>; // used in build_rs to properly order writes and avoid unnecessary copies using llm_graph_get_rows_fn = std::function<ggml_tensor * (ggml_context *, ggml_tensor * states, ggml_tensor * ids)>; +struct llm_graph_qkv { + ggml_tensor * q; // [n_embd_head, n_head, n_tokens] + ggml_tensor * k; // [n_embd_head, n_head_kv, n_tokens] + ggml_tensor * v; // [n_embd_head, n_head_kv, n_tokens] +}; + struct llm_graph_context { const llm_arch arch; @@ -613,6 +788,7 @@ struct llm_graph_context { const int64_t n_embd; const int64_t n_layer; + const int64_t n_layer_nextn; const int64_t n_rot; const int64_t n_ctx; // user-specified context size (can be different from n_ctx_train) const int64_t n_head; @@ -671,10 +847,11 @@ struct llm_graph_context { ggml_tensor * cur, int il) const; - // do mat_mul, while optionally apply lora + // do mat_mul, while optionally apply lora and per-tensor scale ggml_tensor * build_lora_mm( ggml_tensor * w, - ggml_tensor * cur) const; + ggml_tensor * cur, + ggml_tensor * w_s = nullptr) const; // do mat_mul_id, while optionally apply lora ggml_tensor * build_lora_mm_id( @@ -689,6 +866,17 @@ struct llm_graph_context { llm_norm_type type, int il) const; + + // compute Q, K, V projections with optional bias and reshape + // supports both fused wqkv and separate wq/wk/wv paths + llm_graph_qkv build_qkv( + const llama_layer & layer, + ggml_tensor * cur, + int64_t n_embd_head, + int64_t n_head, + int64_t n_head_kv, + int il) const; + ggml_tensor * build_ffn( ggml_tensor * cur, ggml_tensor * up, @@ -717,11 +905,14 @@ struct llm_graph_context { int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in = nullptr) const; + ggml_tensor * probs_in = nullptr, + ggml_tensor * gate_up_exps = nullptr, + ggml_tensor * up_exps_s = nullptr, + ggml_tensor * gate_exps_s = nullptr, + ggml_tensor * down_exps_s = nullptr) const; ggml_tensor * build_moe_ffn( ggml_tensor * cur, @@ -738,11 +929,15 @@ struct llm_graph_context { int64_t n_expert_used, llm_ffn_op_type type_op, bool norm_w, - bool scale_w, float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in = nullptr) const; + ggml_tensor * probs_in = nullptr, + ggml_tensor * gate_up_exps = nullptr, + ggml_tensor * gate_up_exps_b = nullptr, + ggml_tensor * up_exps_s = nullptr, + ggml_tensor * gate_exps_s = nullptr, + ggml_tensor * down_exps_s = nullptr) const; // // inputs @@ -781,6 +976,7 @@ struct llm_graph_context { llm_graph_input_attn_no_cache * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -796,12 +992,46 @@ struct llm_graph_context { llm_graph_input_attn_kv * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] // TODO: remove + float kq_scale, + int il) const; + + llm_graph_input_attn_k * build_attn_inp_k() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * sinks, // [n_head_q] + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; + + llm_graph_input_attn_k_dsa * build_attn_inp_k_dsa() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_k_dsa * inp, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] ggml_tensor * kq_b, ggml_tensor * sinks, // [n_head_q] ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + ggml_tensor * top_k, // [n_indexer_top_k, n_tokens] float kq_scale, int il) const; @@ -812,6 +1042,7 @@ struct llm_graph_context { llm_graph_input_attn_kv_iswa * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional @@ -827,6 +1058,7 @@ struct llm_graph_context { llm_graph_input_attn_cross * inp, ggml_tensor * wo, ggml_tensor * wo_b, + ggml_tensor * wo_s, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] @@ -880,6 +1112,9 @@ struct llm_graph_context { // llm_graph_input_mem_hybrid * build_inp_mem_hybrid() const; + llm_graph_input_mem_hybrid_k * build_inp_mem_hybrid_k() const; + + llm_graph_input_mem_hybrid_iswa * build_inp_mem_hybrid_iswa() const; // // pooling @@ -889,7 +1124,8 @@ struct llm_graph_context { ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const; + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const; // // sampling (backend sampling) @@ -903,6 +1139,7 @@ struct llm_graph_context { void build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const; }; diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index c847ef91b7a..2bf57687382 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -7,19 +7,39 @@ void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) { if (dense_first) { - for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0); + for (uint32_t il = 0; il < n_layer(); ++il) { + is_swa_impl[il] = n_pattern == 0 || (il % n_pattern != 0); } } else { - for (uint32_t il = 0; il < n_layer; ++il) { - swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + for (uint32_t il = 0; il < n_layer(); ++il) { + is_swa_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); } } + + for (uint32_t il = n_layer(); il < n_layer_all; ++il) { + is_swa_impl[il] = false; + } +} + +void llama_hparams::set_recr_pattern(uint32_t n_pattern, bool dense_first) { + if (dense_first) { + for (uint32_t il = 0; il < n_layer(); ++il) { + is_recr_impl[il] = n_pattern == 0 || (il % n_pattern != 0); + } + } else { + for (uint32_t il = 0; il < n_layer(); ++il) { + is_recr_impl[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1)); + } + } + + for (uint32_t il = n_layer(); il < n_layer_all; ++il) { + is_recr_impl[il] = false; + } } bool llama_hparams::is_swa_any() const { - for (uint32_t il = 0; il < n_layer; ++il) { - if (swa_layers[il]) { + for (uint32_t il = 0; il < n_layer_all; ++il) { + if (is_swa_impl[il]) { return true; } } @@ -28,7 +48,7 @@ bool llama_hparams::is_swa_any() const { } uint32_t llama_hparams::n_head(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_head_arr[il]; } @@ -36,7 +56,7 @@ uint32_t llama_hparams::n_head(uint32_t il) const { } uint32_t llama_hparams::n_head_kv(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_head_kv_arr[il]; } @@ -44,7 +64,7 @@ uint32_t llama_hparams::n_head_kv(uint32_t il) const { } uint32_t llama_hparams::n_ff(uint32_t il) const { - if (il < n_layer) { + if (il < n_layer_all) { return n_ff_arr[il]; } @@ -62,7 +82,19 @@ uint32_t llama_hparams::n_gqa(uint32_t il) const { return n_head/n_head_kv; } +uint32_t llama_hparams::n_rot(uint32_t il) const { + if (il < n_layer_all) { + return is_swa(il) ? n_rot_swa : n_rot_full; + } + + GGML_ABORT("fatal error"); +} + uint32_t llama_hparams::n_embd_inp() const { + if (n_embd_inp_impl > 0) { + return n_embd_inp_impl; + } + uint32_t n_embd_inp = n_embd; if (n_deepstack_layers > 0) { @@ -72,25 +104,41 @@ uint32_t llama_hparams::n_embd_inp() const { return n_embd_inp; } -uint32_t llama_hparams::get_n_embd_out() const { - return n_embd_out > 0 ? n_embd_out : n_embd; +uint32_t llama_hparams::n_embd_out() const { + return n_embd_out_impl > 0 ? n_embd_out_impl : n_embd; +} + +uint32_t llama_hparams::n_embd_head_k(uint32_t il) const { + if (il < n_layer_all) { + return is_swa(il) ? n_embd_head_k_swa : n_embd_head_k_full; + } + + GGML_ABORT("fatal error"); +} + +uint32_t llama_hparams::n_embd_head_v(uint32_t il) const { + if (il < n_layer_all) { + return is_swa(il) ? n_embd_head_v_swa : n_embd_head_v_full; + } + + GGML_ABORT("fatal error"); } uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { const uint32_t n_head_kv = this->n_head_kv(il); - return n_embd_head_k * n_head_kv; + return n_embd_head_k(il) * n_head_kv; } uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { const uint32_t n_head_kv = this->n_head_kv(il); - return n_embd_head_v * n_head_kv; + return n_embd_head_v(il) * n_head_kv; } bool llama_hparams::is_n_embd_k_gqa_variable() const { const uint32_t val = n_embd_k_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { if (val != n_embd_k_gqa(il)) { return true; } @@ -101,7 +149,7 @@ bool llama_hparams::is_n_embd_k_gqa_variable() const { bool llama_hparams::is_n_embd_v_gqa_variable() const { const uint32_t val = n_embd_v_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { if (val != n_embd_v_gqa(il)) { return true; } @@ -112,7 +160,7 @@ bool llama_hparams::is_n_embd_v_gqa_variable() const { uint32_t llama_hparams::n_embd_k_gqa_max() const { uint32_t val = n_embd_k_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { val = std::max(val, n_embd_k_gqa(il)); } @@ -121,7 +169,7 @@ uint32_t llama_hparams::n_embd_k_gqa_max() const { uint32_t llama_hparams::n_embd_v_gqa_max() const { uint32_t val = n_embd_v_gqa(); - for (uint32_t il = 0; il < n_layer; ++il) { + for (uint32_t il = 0; il < n_layer_all; ++il) { val = std::max(val, n_embd_v_gqa(il)); } @@ -139,6 +187,13 @@ uint32_t llama_hparams::n_embd_r() const { return n_embd * (n_shortconv_l_cache - 1); } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Conv state for Q, K, V: 3 * (d_conv - 1) * n_head * head_dim + const uint32_t d_inner = n_head() * n_embd_head_kda; // 32 * 128 = 4096 + return 3 * (ssm_d_conv > 0 ? ssm_d_conv - 1 : 3) * d_inner; + } + // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size @@ -151,16 +206,23 @@ uint32_t llama_hparams::n_embd_s() const { return n_embd * wkv_head_size; } + if (n_embd_head_kda != 0) { + // for Kimi KDA layers + // Full recurrent state: head_dim * head_dim * n_head + // h tensor shape for delta attention: [head_dim, head_dim, n_head] + return n_embd_head_kda * n_embd_head_kda * n_head(); // 128 * 128 * 32 = 524288 + } + // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } -bool llama_hparams::is_recurrent(uint32_t il) const { - if (il < n_layer) { - return recurrent_layer_arr[il]; +bool llama_hparams::is_recr(uint32_t il) const { + if (il < n_layer_all) { + return is_recr_impl[il]; } - GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer); + GGML_ABORT("%s: il (%u) out of bounds (n_layer_all: %u)\n", __func__, il, n_layer_all); } uint32_t llama_hparams::n_pos_per_embd() const { @@ -168,11 +230,26 @@ uint32_t llama_hparams::n_pos_per_embd() const { } bool llama_hparams::is_swa(uint32_t il) const { - if (il < n_layer) { - return swa_layers[il]; + if (il < n_layer_all) { + return is_swa_impl[il]; } - GGML_ABORT("fatal error"); + GGML_ABORT("%s: il (%u) out of bounds (n_layer_all: %u)\n", __func__, il, n_layer_all); +} + +bool llama_hparams::is_mla() const { + assert((n_embd_head_k_mla_impl == 0 && n_embd_head_v_mla_impl == 0) || + (n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0)); + + return n_embd_head_k_mla_impl != 0 && n_embd_head_v_mla_impl != 0; +} + +uint32_t llama_hparams::n_embd_head_k_mla() const { + return is_mla() ? n_embd_head_k_mla_impl : n_embd_head_k(); +} + +uint32_t llama_hparams::n_embd_head_v_mla() const { + return is_mla() ? n_embd_head_v_mla_impl : n_embd_head_v(); } bool llama_hparams::has_kv(uint32_t il) const { @@ -188,52 +265,8 @@ bool llama_hparams::has_kv(uint32_t il) const { return true; } -uint32_t llama_hparams::n_layer_kv() const { - uint32_t res = 0; - - for (uint32_t il = 0; il < n_layer; ++il) { - if (has_kv(il)) { - res++; - } - } - - return res; -} - -bool llama_hparams::is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { - assert(p0 >= 0 && p1 >= 0); - - switch (swa_type) { - case LLAMA_SWA_TYPE_NONE: - { - } break; - case LLAMA_SWA_TYPE_STANDARD: - { - if (p1 - p0 >= (int32_t) n_swa) { - return true; - } - } break; - case LLAMA_SWA_TYPE_CHUNKED: - { - const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; - - if (p0 < pos_chunk_start) { - return true; - } - } break; - case LLAMA_SWA_TYPE_SYMMETRIC: - { - const int32_t half_n_swa = (int32_t) n_swa / 2; - const int32_t pos_diff = p1 - p0; - - // Mask if outside the symmetric window - if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { - return true; - } - } break; - } - - return false; +uint32_t llama_hparams::n_layer() const { + return n_layer_all - n_layer_nextn; } bool llama_hparams::use_mrope() const { diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index 7ae3ec292ef..d045059a63e 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -3,6 +3,7 @@ #include "llama.h" #include <array> +#include <cassert> // bump if necessary #define LLAMA_MAX_LAYERS 512 @@ -22,6 +23,9 @@ enum llama_swa_type { LLAMA_SWA_TYPE_SYMMETRIC = 3, }; +// forward declaration; full definition in llama-graph.h +enum llm_ffn_op_type : int; + struct llama_hparams_posnet { uint32_t n_embd; uint32_t n_layer; @@ -33,27 +37,40 @@ struct llama_hparams_convnext { }; struct llama_hparams { + // note: use the `_impl` suffix to avoid name conflict between members and getters + // for example: n_embd_out() vs n_embd_out_impl + bool vocab_only; bool no_alloc; bool rope_finetuned; bool use_par_res; bool swin_norm; + bool norm_before_residual = false; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_embd_features = 0; - uint32_t n_layer; - int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache - uint32_t n_rot; - uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads - uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head + uint32_t n_layer_all; + uint32_t n_layer_nextn = 0; uint32_t n_expert = 0; uint32_t n_expert_used = 0; uint32_t n_rel_attn_bkts = 0; + // TODO: this needs to be reworked + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + + // different head size for full_attention and SWA layers + uint32_t n_embd_head_k_full; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads + uint32_t n_embd_head_v_full; // dimension of values (d_v) aka n_embd_head + uint32_t n_embd_head_k_swa; + uint32_t n_embd_head_v_swa; + + // different RoPE dimensions for full_attention and SWA layers + uint32_t n_rot_full; + uint32_t n_rot_swa; + // note: deepseek2 using MLA converts into MQA with larger heads, then decompresses to MHA - uint32_t n_embd_head_k_mla = 0; - uint32_t n_embd_head_v_mla = 0; + uint32_t n_embd_head_k_mla_impl = 0; + uint32_t n_embd_head_v_mla_impl = 0; // for WavTokenizer struct llama_hparams_posnet posnet; @@ -82,7 +99,7 @@ struct llama_hparams { bool expert_weights_norm = false; uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE; uint32_t moe_every_n_layers = 0; - uint32_t nextn_predict_layers = 0; + uint32_t moe_latent_size = 0; float f_norm_eps; float f_norm_rms_eps; @@ -108,6 +125,7 @@ struct llama_hparams { float rope_freq_base_train_swa = 10000.0f; float rope_freq_scale_train; float rope_freq_scale_train_swa = 1.0f; + float rope_scaling_alpha = 0.0f; // NTK-aware alpha for XDRoPE uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; @@ -123,11 +141,15 @@ struct llama_hparams { llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; // the size of the sliding window (0 - no SWA) uint32_t n_swa = 0; - // if swa_layers[il] == 1, then layer il is SWA - // if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA) + + // if is_swa_impl[il] == 1, then layer il is SWA + // if is_swa_impl[il] == 0, then layer il is dense (i.e. non-SWA) // by default, all layers are dense // note: using uint32_t type for compatibility reason - std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers; + std::array<uint32_t, LLAMA_MAX_LAYERS> is_swa_impl; + + // for hybrid state space models + std::array<uint32_t, LLAMA_MAX_LAYERS> is_recr_impl; // for State Space Models uint32_t ssm_d_conv = 0; @@ -136,8 +158,8 @@ struct llama_hparams { uint32_t ssm_dt_rank = 0; uint32_t ssm_n_group = 0; - // for hybrid state space models - std::array<bool, LLAMA_MAX_LAYERS> recurrent_layer_arr; + // for Kimi Linear KDA + uint32_t n_embd_head_kda = 0; bool ssm_dt_b_c_rms = false; @@ -154,6 +176,8 @@ struct llama_hparams { float f_attn_out_scale = 0.0f; uint32_t attn_temp_length = 0; + float f_attn_value_scale = 0.0f; + bool causal_attn = true; bool use_alibi = false; bool attn_soft_cap = false; @@ -162,8 +186,11 @@ struct llama_hparams { // for Classifiers uint32_t n_cls_out = 1; + // input embedding dimension (0 = use n_embd) + uint32_t n_embd_inp_impl = 0; + // output embedding dimension (0 = use n_embd) - uint32_t n_embd_out = 0; + uint32_t n_embd_out_impl = 0; // llama4 smallthinker uint32_t n_moe_layer_step = 0; @@ -190,11 +217,30 @@ struct llama_hparams { std::array<float, LLAMA_MAX_LAYERS> xielu_beta; std::array<float, LLAMA_MAX_LAYERS> xielu_eps; + // DSA (deepseek sparse attention) + uint32_t indexer_n_head = 0; + uint32_t indexer_head_size = 0; + uint32_t indexer_top_k = 0; + // qwen3vl deepstack + // When parsed from GGUF, this implies the first N layers consume the first + // N deepstack embeddings. Use deepstack_mapping_arr if you need a more + // complex mapping. If using deepstack_mapping_arr, also make sure to set + // n_deepstack_layers to the number of unique deepstack layers so that + // n_embd_imp is accurate (see granite.cpp). + // TODO: can be expressed via the `new n_embd_inp_impl` and remove this param uint32_t n_deepstack_layers = 0; + // deepstack layer array (Granite4 Vision) + // -1 => no deepstack + // >=0 => input embedding index for deepstack injection + std::array<int32_t, LLAMA_MAX_LAYERS> deepstack_mapping_arr; + + // gemma4 per-layer embedding + uint32_t n_embd_per_layer = 0; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) - // ref: https://github.com/ggerganov/llama.cpp/pull/8141 + // ref: https://github.com/ggml-org/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; uint32_t dec_n_layer = 0; @@ -202,6 +248,19 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + + // Resolved FFN gated activation flavor for archs that read + // `<arch>.hidden_activation` from the GGUF (e.g. ModernBert derivatives). + // Defaults to LLM_FFN_NONE (sentinel = 0); the mapping from the GGUF + // string to a real op is done at hparam-load time via + // llm_ffn_op_type_from_string() in llama-model.cpp, mirroring how + // rope_scaling_type_train is handled. + enum llm_ffn_op_type llm_ffn_op; + + // Step35: optional per-layer clamps for (Swi)GLU + std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_exp; // clamping for expert FFN + std::array<float, LLAMA_MAX_LAYERS> swiglu_clamp_shexp; // shared expert + // this value n_pattern means that every nth layer is dense (i.e. non-SWA) // dense_first means whether the pattern is start with a dense layer // note that if n_pattern == 0, all layers are SWA @@ -226,6 +285,13 @@ struct llama_hparams { // return true if one of the layers is SWA bool is_swa_any() const; + bool is_swa(uint32_t il) const; + + void set_recr_pattern(uint32_t n_pattern, bool dense_first = false); + + // whether or not the given layer is recurrent (for hybrid models) + bool is_recr(uint32_t il) const; + uint32_t n_head(uint32_t il = 0) const; uint32_t n_head_kv(uint32_t il = 0) const; @@ -234,11 +300,17 @@ struct llama_hparams { uint32_t n_gqa(uint32_t il = 0) const; + uint32_t n_rot(uint32_t il = 0) const; + // dimension of main + auxiliary input embeddings uint32_t n_embd_inp() const; // dimension of output embeddings - uint32_t get_n_embd_out() const; + uint32_t n_embd_out() const; + + // dimension of key/value embeddings for each head (per layer) + uint32_t n_embd_head_k(uint32_t il = 0) const; + uint32_t n_embd_head_v(uint32_t il = 0) const; // dimension of key embeddings across all k-v heads uint32_t n_embd_k_gqa(uint32_t il = 0) const; @@ -261,22 +333,59 @@ struct llama_hparams { // dimension of the recurrent state embeddings uint32_t n_embd_s() const; - // whether or not the given layer is recurrent (for hybrid models) - bool is_recurrent(uint32_t il) const; - uint32_t n_pos_per_embd() const; - bool is_swa(uint32_t il) const; + // note: currently only support if either all or none of the layers are MLA + bool is_mla() const; + + uint32_t n_embd_head_k_mla() const; + uint32_t n_embd_head_v_mla() const; bool has_kv(uint32_t il) const; - // number of layers for which has_kv() returns true - uint32_t n_layer_kv() const; + // number of effective layers (excludes nextn layers) + uint32_t n_layer() const; // note that this function uses different SWA parameters from those in the hparams + // note: inlined on purpose for performance reasons // TODO: think of a better place for this function // TODO: pack the SWA params in a struct? - static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1); + static bool is_masked_swa(uint32_t n_swa, llama_swa_type swa_type, llama_pos p0, llama_pos p1) { + assert(p0 >= 0 && p1 >= 0); + + switch (swa_type) { + case LLAMA_SWA_TYPE_NONE: + { + } break; + case LLAMA_SWA_TYPE_STANDARD: + { + if (p1 - p0 >= (int32_t) n_swa) { + return true; + } + } break; + case LLAMA_SWA_TYPE_CHUNKED: + { + const llama_pos pos_chunk_start = (p1 / n_swa) * n_swa; + + if (p0 < pos_chunk_start) { + return true; + } + } break; + case LLAMA_SWA_TYPE_SYMMETRIC: + { + const int32_t half_n_swa = (int32_t) n_swa / 2; + const int32_t pos_diff = p1 - p0; + + // Mask if outside the symmetric window + if (pos_diff < -half_n_swa || pos_diff > half_n_swa) { + return true; + } + } break; + } + + return false; + } + bool use_mrope() const; }; diff --git a/examples/talk-llama/llama-impl.cpp b/examples/talk-llama/llama-impl.cpp index 8e3e7b223a6..b3a94b946d2 100644 --- a/examples/talk-llama/llama-impl.cpp +++ b/examples/talk-llama/llama-impl.cpp @@ -100,18 +100,18 @@ std::string format(const char * fmt, ...) { std::string llama_format_tensor_shape(const std::vector<int64_t> & ne) { char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); + snprintf(buf, sizeof(buf), "%6" PRId64, ne.at(0)); for (size_t i = 1; i < ne.size(); i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %6" PRId64, ne.at(i)); } return buf; } std::string llama_format_tensor_shape(const struct ggml_tensor * t) { char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + snprintf(buf, sizeof(buf), "%6" PRId64, t->ne[0]); for (int i = 1; i < GGML_MAX_DIMS; i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %6" PRId64, t->ne[i]); } return buf; } @@ -128,7 +128,7 @@ static std::string gguf_data_to_str(enum gguf_type type, const void * data, int case GGUF_TYPE_INT64: return std::to_string(((const int64_t *)data)[i]); case GGUF_TYPE_FLOAT32: return std::to_string(((const float *)data)[i]); case GGUF_TYPE_FLOAT64: return std::to_string(((const double *)data)[i]); - case GGUF_TYPE_BOOL: return ((const bool *)data)[i] ? "true" : "false"; + case GGUF_TYPE_BOOL: return ((const int8_t *)data)[i] != 0 ? "true" : "false"; default: return format("unknown type %d", type); } } diff --git a/examples/talk-llama/llama-impl.h b/examples/talk-llama/llama-impl.h index c3391e79f51..7923c3f7ed5 100644 --- a/examples/talk-llama/llama-impl.h +++ b/examples/talk-llama/llama-impl.h @@ -3,6 +3,7 @@ #include "ggml.h" // for ggml_log_level #include <string> +#include <type_traits> #include <vector> #ifdef __GNUC__ @@ -40,6 +41,19 @@ struct no_init { no_init() = default; }; +template <typename dst_t, typename src_t> +static inline dst_t llama_cast(src_t v) { + if constexpr (std::is_same_v<src_t, dst_t>) { + return v; + } else if constexpr (std::is_same_v<src_t, ggml_fp16_t> && std::is_same_v<dst_t, float>) { + return ggml_fp16_to_fp32(v); + } else if constexpr (std::is_same_v<src_t, float> && std::is_same_v<dst_t, ggml_fp16_t>) { + return ggml_fp32_to_fp16(v); + } else { + static_assert(std::is_same_v<dst_t, void>, "unsupported type combination"); + } +} + struct time_meas { time_meas(int64_t & t_acc, bool disable = false); ~time_meas(); @@ -49,6 +63,16 @@ struct time_meas { int64_t & t_acc; }; +template <typename T> +struct buffer_view { + T * data; + size_t size = 0; + + bool has_data() const { + return data && size > 0; + } +}; + void replace_all(std::string & s, const std::string & search, const std::string & replace); // TODO: rename to llama_format ? @@ -60,4 +84,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t); std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i); -#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FATTN "__fattn__" +#define LLAMA_TENSOR_NAME_FGDN_AR "__fgdn_ar__" +#define LLAMA_TENSOR_NAME_FGDN_CH "__fgdn_ch__" diff --git a/examples/talk-llama/llama-io.cpp b/examples/talk-llama/llama-io.cpp index 7ad70d16334..5ec4634943f 100644 --- a/examples/talk-llama/llama-io.cpp +++ b/examples/talk-llama/llama-io.cpp @@ -1,5 +1,7 @@ #include "llama-io.h" +#include <vector> + void llama_io_write_i::write_string(const std::string & str) { uint32_t str_size = str.size(); @@ -9,7 +11,10 @@ void llama_io_write_i::write_string(const std::string & str) { void llama_io_read_i::read_string(std::string & str) { uint32_t str_size; - read_to(&str_size, sizeof(str_size)); + read(&str_size, sizeof(str_size)); + + std::vector<char> buf(str_size); + read(buf.data(), str_size); - str.assign((const char *) read(str_size), str_size); + str.assign(buf.data(), str_size); } diff --git a/examples/talk-llama/llama-io.h b/examples/talk-llama/llama-io.h index ce9216b83b1..f276af4fb96 100644 --- a/examples/talk-llama/llama-io.h +++ b/examples/talk-llama/llama-io.h @@ -12,7 +12,7 @@ class llama_io_write_i { virtual ~llama_io_write_i() = default; virtual void write(const void * src, size_t size) = 0; - virtual void write_tensor(const ggml_tensor * tensor, size_t offset, size_t size) = 0; + virtual void write_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0; // bytes written so far virtual size_t n_bytes() = 0; @@ -25,8 +25,8 @@ class llama_io_read_i { llama_io_read_i() = default; virtual ~llama_io_read_i() = default; - virtual const uint8_t * read(size_t size) = 0; - virtual void read_to(void * dst, size_t size) = 0; + virtual void read(void * dst, size_t size) = 0; + virtual void read_tensor(ggml_tensor * tensor, size_t offset, size_t size) = 0; // bytes read so far virtual size_t n_bytes() = 0; diff --git a/examples/talk-llama/llama-kv-cache-dsa.cpp b/examples/talk-llama/llama-kv-cache-dsa.cpp new file mode 100644 index 00000000000..916ab653756 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-dsa.cpp @@ -0,0 +1,261 @@ +#include "llama-kv-cache-dsa.h" + +#include "llama-impl.h" +#include "llama-batch.h" +#include "llama-model.h" + +#include <algorithm> +#include <cassert> + +// +// llama_kv_cache_dsa +// + +llama_kv_cache_dsa::llama_kv_cache_dsa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse) : + hparams_lid(model.hparams), n_stream(unified ? 1 : n_seq_max) { + + LLAMA_LOG_INFO("%s: creating main KV cache, size = %u cells\n", __func__, kv_size); + + kv_mla = std::make_unique<llama_kv_cache>( + model, model.hparams, type_k, type_v, + v_trans, offload, unified, kv_size, n_seq_max, n_pad, + n_swa, swa_type, nullptr, filter, reuse, nullptr); + + // we use llama_kv_cache for caching indexer keys + // by hand-tweaking some hparams we fool it to create + // indexer key cache tensors with correct dimensions + // https://github.com/ggml-org/llama.cpp/pull/21149#discussion_r3015940823 + + // DSA lightning indexer uses MQA with single key head + std::fill(hparams_lid.n_head_kv_arr.begin(), hparams_lid.n_head_kv_arr.end(), 1); + hparams_lid.n_embd_head_k_full = model.hparams.indexer_head_size; + hparams_lid.rope_type = LLAMA_ROPE_TYPE_NEOX; + + LLAMA_LOG_INFO("%s: creating indexer KV cache, size = %u cells\n", __func__, kv_size); + + kv_lid = std::make_unique<llama_kv_cache>( + model, hparams_lid, type_k, type_v, + v_trans, offload, unified, kv_size, n_seq_max, n_pad, + n_swa, swa_type, nullptr, filter, reuse, nullptr); +} + +void llama_kv_cache_dsa::clear(bool data) { + kv_mla->clear(data); + kv_lid->clear(data); +} + +bool llama_kv_cache_dsa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + bool res = true; + + res = res & kv_mla->seq_rm(seq_id, p0, p1); + res = res & kv_lid->seq_rm(seq_id, p0, p1); + + return res; +} + +void llama_kv_cache_dsa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_mla->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_lid->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_dsa::seq_keep(llama_seq_id seq_id) { + kv_mla->seq_keep(seq_id); + kv_lid->seq_keep(seq_id); +} + +void llama_kv_cache_dsa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_mla->seq_add(seq_id, p0, p1, shift); + kv_lid->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_dsa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_mla->seq_div(seq_id, p0, p1, d); + kv_lid->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_dsa::seq_pos_min(llama_seq_id seq_id) const { + return kv_mla->seq_pos_min(seq_id); +} + +llama_pos llama_kv_cache_dsa::seq_pos_max(llama_seq_id seq_id) const { + return kv_mla->seq_pos_max(seq_id); +} + +std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_dsa::memory_breakdown() const { + std::map<ggml_backend_buffer_type_t, size_t> mb = kv_mla->memory_breakdown(); + for (const auto & buft_size : kv_lid->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) { + GGML_UNUSED(embd_all); + + do { + balloc.split_reset(); + + std::vector<llama_ubatch> ubatches; + while (true) { + auto ubatch = n_stream == 1 ? balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + auto sinfos_mla = kv_mla->prepare(ubatches); + if (sinfos_mla.empty()) { + break; + } + + auto sinfos_lid = kv_lid->prepare(ubatches); + if (sinfos_lid.empty()) { + break; + } + + assert(sinfos_mla.size() == sinfos_lid.size()); + + return std::make_unique<llama_kv_cache_dsa_context>( + this, std::move(sinfos_mla), std::move(sinfos_lid), std::move(ubatches)); + } while (false); + + return std::make_unique<llama_kv_cache_dsa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_full() { + return std::make_unique<llama_kv_cache_dsa_context>(this); +} + +llama_memory_context_ptr llama_kv_cache_dsa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique<llama_kv_cache_dsa_context>(this, lctx, optimize); +} + +bool llama_kv_cache_dsa::get_can_shift() const { + return kv_mla->get_can_shift() && + kv_lid->get_can_shift() && + kv_mla->get_size() == kv_lid->get_size(); +} + +void llama_kv_cache_dsa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + kv_mla->state_write(io, seq_id, flags); + kv_lid->state_write(io, seq_id, flags); +} + +void llama_kv_cache_dsa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + kv_mla->state_read(io, seq_id, flags); + kv_lid->state_read(io, seq_id, flags); +} + +llama_kv_cache * llama_kv_cache_dsa::get_mla() const { + return kv_mla.get(); +} + +llama_kv_cache * llama_kv_cache_dsa::get_lid() const { + return kv_lid.get(); +} + +// +// llama_kv_cache_dsa_context +// + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context(llama_memory_status status) : status(status) {} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv) : + ctx_mla(kv->get_mla()->init_full()), + ctx_lid(kv->get_lid()->init_full()), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + llama_context * lctx, + bool optimize) : + ctx_mla(kv->get_mla()->init_update(lctx, optimize)), + ctx_lid(kv->get_lid()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context::llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + slot_info_vec_t sinfos_mla, + slot_info_vec_t sinfos_lid, + std::vector<llama_ubatch> ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. not sure if this is ideal + ctx_mla(new llama_kv_cache_context(kv->get_mla(), std::move(sinfos_mla), this->ubatches)), + ctx_lid(new llama_kv_cache_context(kv->get_lid(), std::move(sinfos_lid), this->ubatches)), + status(llama_memory_status_combine(ctx_mla->get_status(), ctx_lid->get_status())) { +} + +llama_kv_cache_dsa_context:: ~llama_kv_cache_dsa_context() = default; + +bool llama_kv_cache_dsa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_mla->next(); + ctx_lid->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_dsa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_mla->apply(); + res = res & ctx_lid->apply(); + + return res; +} + +llama_memory_status llama_kv_cache_dsa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_dsa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return ubatches[i_next]; +} + +const llama_kv_cache_context * llama_kv_cache_dsa_context::get_mla() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast<const llama_kv_cache_context *>(ctx_mla.get()); +} + +const llama_kv_cache_context * llama_kv_cache_dsa_context::get_lid() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return static_cast<const llama_kv_cache_context *>(ctx_lid.get()); +} diff --git a/examples/talk-llama/llama-kv-cache-dsa.h b/examples/talk-llama/llama-kv-cache-dsa.h new file mode 100644 index 00000000000..e2b330993b8 --- /dev/null +++ b/examples/talk-llama/llama-kv-cache-dsa.h @@ -0,0 +1,138 @@ +#pragma once + +#include "llama-kv-cache.h" + +#include <vector> + +// +// llama_kv_cache_dsa +// + +// utilizes two instances of llama_kv_cache: +// - the first instance is for caching key tensors of the model, +// - the second instance is for caching lightning indexer key tensors + +class llama_kv_cache_dsa : public llama_memory_i { +public: + llama_kv_cache_dsa( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + bool unified, + uint32_t kv_size, + uint32_t n_seq_max, + uint32_t n_pad, + uint32_t n_swa, + llama_swa_type swa_type, + const layer_filter_cb & filter, + const layer_reuse_cb & reuse); + + ~llama_kv_cache_dsa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_kv_cache_dsa specific API + // + + llama_kv_cache * get_mla() const; + llama_kv_cache * get_lid() const; + +private: + // we keep indexer KV cache hparams instance here as llama_kv_cache stores only reference to it + llama_hparams hparams_lid; + const uint32_t n_stream = 1; + + std::unique_ptr<llama_kv_cache> kv_mla; + std::unique_ptr<llama_kv_cache> kv_lid; +}; + +class llama_kv_cache_dsa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // used for errors + llama_kv_cache_dsa_context(llama_memory_status status); + + // used to create a full-cache context + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv); + + // used to create an update context + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + llama_context * lctx, + bool optimize); + + // used to create a batch processing context from a batch + llama_kv_cache_dsa_context( + llama_kv_cache_dsa * kv, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_ik, + std::vector<llama_ubatch> ubatches); + + virtual ~llama_kv_cache_dsa_context(); + + // + // llama_memory_context_i + // + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_dsa_context specific API + // + + const llama_kv_cache_context * get_mla() const; + const llama_kv_cache_context * get_lid() const; + +private: + //llama_kv_cache_dsa * kv; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector<llama_ubatch> ubatches; + + const llama_memory_context_ptr ctx_mla; + const llama_memory_context_ptr ctx_lid; + + const llama_memory_status status; +}; diff --git a/examples/talk-llama/llama-kv-cache-iswa.cpp b/examples/talk-llama/llama-kv-cache-iswa.cpp index 3a34102a23d..aa1b1b72ebe 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-iswa.cpp @@ -23,8 +23,10 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : hparams(model.hparams), unified(unified) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : hparams(model.hparams), unified(unified) { // chain filters const layer_filter_cb filter_base = [&](int32_t il) { @@ -59,17 +61,27 @@ llama_kv_cache_iswa::llama_kv_cache_iswa( LLAMA_LOG_INFO("%s: creating non-SWA KV cache, size = %u cells\n", __func__, size_base); + llama_memory_t mem_other_base = nullptr; + if (mem_other) { + mem_other_base = static_cast<llama_kv_cache_iswa *>(mem_other)->get_base(); + } + + llama_memory_t mem_other_swa = nullptr; + if (mem_other) { + mem_other_swa = static_cast<llama_kv_cache_iswa *>(mem_other)->get_swa(); + } + kv_base = std::make_unique<llama_kv_cache>( - model, type_k, type_v, + model, hparams, type_k, type_v, v_trans, offload, unified, size_base, n_seq_max, n_pad, - 0, LLAMA_SWA_TYPE_NONE, filter_base, reuse); + 0, LLAMA_SWA_TYPE_NONE, mem_other_base, filter_base, reuse, share); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique<llama_kv_cache>( - model, type_k, type_v, + model, hparams, type_k, type_v, v_trans, offload, unified, size_swa, n_seq_max, n_pad, - hparams.n_swa, hparams.swa_type, filter_swa, reuse); + hparams.n_swa, hparams.swa_type, mem_other_swa, filter_swa, reuse, share); } void llama_kv_cache_iswa::clear(bool data) { @@ -218,7 +230,9 @@ llama_memory_context_ptr llama_kv_cache_iswa::init_update(llama_context * lctx, } bool llama_kv_cache_iswa::get_can_shift() const { - return kv_base->get_size() == kv_swa->get_size(); + return kv_base->get_can_shift() && + kv_swa->get_can_shift() && + kv_base->get_size() == kv_swa->get_size(); } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { diff --git a/examples/talk-llama/llama-kv-cache-iswa.h b/examples/talk-llama/llama-kv-cache-iswa.h index 70ab22f0d60..dfafc1ef510 100644 --- a/examples/talk-llama/llama-kv-cache-iswa.h +++ b/examples/talk-llama/llama-kv-cache-iswa.h @@ -25,8 +25,10 @@ class llama_kv_cache_iswa : public llama_memory_i { uint32_t n_seq_max, uint32_t n_ubatch, uint32_t n_pad, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache_iswa() = default; diff --git a/examples/talk-llama/llama-kv-cache.cpp b/examples/talk-llama/llama-kv-cache.cpp index 3186242d60f..2802103bdd8 100644 --- a/examples/talk-llama/llama-kv-cache.cpp +++ b/examples/talk-llama/llama-kv-cache.cpp @@ -13,12 +13,73 @@ #include <map> #include <stdexcept> +static bool ggml_is_power_of_2(int n) { + return (n & (n - 1)) == 0; +} + +// orthonormal Walsh-Hadamard rotation matrix +// note: res^2 == I +static void ggml_gen_hadamard(ggml_tensor * tensor) { + assert(tensor->type == GGML_TYPE_F32); + + const int n = tensor->ne[0]; + + assert(ggml_is_power_of_2(n)); + assert(tensor->ne[1] == n); + assert(tensor->ne[2] == 1); + assert(tensor->ne[3] == 1); + + std::vector<float> data_f32; + + float * data = (float *) tensor->data; + + if (tensor->type != GGML_TYPE_F32) { + data_f32.resize(n*n); + data = data_f32.data(); + } + + data[0*n + 0] = 1.0 / sqrtf(n); + + for (int s = 1; s < n; s *= 2) { + for (int i = 0; i < s; i++) { + for (int j = 0; j < s; j++) { + const float val = data[i*n + j]; + + data[(i + s)*n + (j )] = val; + data[(i )*n + (j + s)] = val; + data[(i + s)*n + (j + s)] = -val; + } + } + } + + if (tensor->type != GGML_TYPE_F32) { + ggml_quantize_chunk(tensor->type, data, tensor->data, 0, 1, n*n, nullptr); + } +} + +static ggml_tensor * ggml_mul_mat_aux( + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * rot) { + const auto n = rot->ne[0]; + + ggml_tensor * res; + + res = ggml_reshape_2d(ctx, cur, n, ggml_nelements(cur)/n); + res = ggml_mul_mat (ctx, rot, res); + ggml_mul_mat_set_hint(res, GGML_HINT_SRC0_IS_HADAMARD); + res = ggml_reshape_4d(ctx, res, cur->ne[0], cur->ne[1], cur->ne[2], cur->ne[3]); + + return res; +} + // // llama_kv_cache // llama_kv_cache::llama_kv_cache( const llama_model & model, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -29,14 +90,30 @@ llama_kv_cache::llama_kv_cache( uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse) : - model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + const layer_reuse_cb & reuse, + const layer_share_cb & share) : + model(model), hparams(hparams), v_trans(v_trans), + n_seq_max(n_seq_max), n_stream(unified ? 1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type), + other(static_cast<llama_kv_cache *>(mem_other)), + v_cells_impl(other ? other->v_cells_impl : std::make_shared<llama_kv_cells_vec>()), + v_cells(*v_cells_impl) { + + // shared cells view the source cache's K/V tensors, so the cell count + // follows the source allocation: a fitted target can be smaller than the + // draft default and oversized views would overflow the source tensors + if (other) { + const uint32_t size_other = other->get_size(); + if (kv_size != size_other) { + LLAMA_LOG_WARN("%s: kv_size = %u overridden to %u to match the shared source cache\n", __func__, kv_size, size_other); + kv_size = size_other; + } + } GGML_ASSERT(kv_size % n_pad == 0); - const uint32_t n_layer_kv = hparams.n_layer_kv(); + const uint32_t n_layer = hparams.n_layer_all; // define a comparator for the buft -> ctx map to ensure that the order is well-defined: struct ggml_backend_buft_comparator { @@ -51,7 +128,7 @@ llama_kv_cache::llama_kv_cache( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_kv*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -97,7 +174,9 @@ llama_kv_cache::llama_kv_cache( __func__, hparams.n_embd_v_gqa_max()); } - for (uint32_t il = 0; il < hparams.n_layer; il++) { + const bool is_mla = hparams.is_mla(); + + for (uint32_t il = 0; il < n_layer; il++) { if (!hparams.has_kv(il)) { LLAMA_LOG_DEBUG("%s: layer %3d: does not have KV cache\n", __func__, il); continue; @@ -108,6 +187,36 @@ llama_kv_cache::llama_kv_cache( continue; } + if (share && other) { + const int32_t il_share = share(il); + + if (il_share >= 0) { + const auto & layer_share = other->layers[other->map_layer_ids[il_share]]; + + LLAMA_LOG_WARN("%s: layer %3d: sharing with layer %d. k = %p, v = %p\n", __func__, il, il_share, + layer_share.k->data, layer_share.v->data); + + map_layer_ids[il] = layers.size(); + + layers.push_back(layer_share); + layers.back().il = il; + + continue; + } + } + + if (n_embd_head_k_all == 0) { + n_embd_head_k_all = (int32_t) hparams.n_embd_head_k(il); + } else if (n_embd_head_k_all > 0 && n_embd_head_k_all != (int32_t) hparams.n_embd_head_k(il)) { + n_embd_head_k_all = -1; + } + + if (n_embd_head_v_all == 0) { + n_embd_head_v_all = (int32_t) hparams.n_embd_head_v(il); + } else if (n_embd_head_v_all > 0 && n_embd_head_v_all != (int32_t) hparams.n_embd_head_v(il)) { + n_embd_head_v_all = -1; + } + // [TAG_V_CACHE_VARIABLE] const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max(); @@ -130,18 +239,21 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + const bool has_k = true; + const bool has_v = !is_mla; - ggml_format_name(k, "cache_k_l%d", il); - ggml_format_name(v, "cache_v_l%d", il); + ggml_tensor * k = has_k ? ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream) : nullptr; + ggml_tensor * v = has_v ? ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream) : nullptr; + + has_k && ggml_format_name(k, "cache_k_l%d", il); + has_v && ggml_format_name(v, "cache_v_l%d", il); std::vector<ggml_tensor *> k_stream; std::vector<ggml_tensor *> v_stream; for (uint32_t s = 0; s < n_stream; ++s) { - k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); - v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); + k_stream.push_back(has_k ? ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2]) : nullptr); + v_stream.push_back(has_v ? ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2]) : nullptr); } map_layer_ids[il] = layers.size(); @@ -152,7 +264,7 @@ llama_kv_cache::llama_kv_cache( if (reuse) { LLAMA_LOG_DEBUG("%s: reusing layers:\n", __func__); - for (uint32_t il = 0; il < hparams.n_layer; il++) { + for (uint32_t il = 0; il < n_layer; il++) { const int32_t il_reuse = reuse(il); if (il_reuse < 0) { @@ -176,7 +288,7 @@ llama_kv_cache::llama_kv_cache( // allocate tensors and initialize the buffers to avoid NaNs in the padding for (auto & [buft, ctx] : ctx_map) { ggml_backend_buffer_t buf; - if (model.hparams.no_alloc) { + if (hparams.no_alloc) { buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != nullptr; t = ggml_get_next_tensor(ctx.get(), t)) { t->buffer = buf; // set dummy buffer for KV cache so that the backend scheduler won't try to allocate it @@ -204,6 +316,62 @@ llama_kv_cache::llama_kv_cache( ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + n_embd_head_k_all = other->n_embd_head_k_all; + n_embd_head_v_all = other->n_embd_head_v_all; + + attn_rot_k = other->attn_rot_k; + attn_rot_v = other->attn_rot_v; + } else { + const char * LLAMA_ATTN_ROT_DISABLE = getenv("LLAMA_ATTN_ROT_DISABLE"); + const bool attn_rot_disable = LLAMA_ATTN_ROT_DISABLE ? atoi(LLAMA_ATTN_ROT_DISABLE) : false; + if (attn_rot_disable) { + LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); + } + + attn_rot_k = + !attn_rot_disable && + n_embd_head_k_all > 0 && + ggml_is_quantized(type_k) && + hparams.n_embd_head_k() % 64 == 0; + + // always create Hadamard rotation tensors for DeepSeek V3.2 DSA lightning indexer + if (model.arch == LLM_ARCH_DEEPSEEK32 && hparams.n_embd_head_k_full == hparams.indexer_head_size) { + attn_rot_k = true; + } + + attn_rot_v = + !attn_rot_disable && + n_embd_head_v_all > 0 && + ggml_is_quantized(type_v) && + hparams.n_embd_head_v() % 64 == 0; + } + + LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all); + LLAMA_LOG_INFO("%s: attn_rot_v = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_v, n_embd_head_v_all); + + // pre-compute the haramard matrices and keep them in host memory + // TODO: in the future, we can make copies in the backend buffers to avoid host -> device transfers + if (attn_rot_k || attn_rot_v) { + for (int64_t n = 64; n <= std::max(n_embd_head_k_all, n_embd_head_v_all); n *= 2) { + attn_rot_hadamard[n] = std::vector<float>(n*n); + + ggml_init_params params = { + /* .mem_size = */ 1*ggml_tensor_overhead(), + /* .mem_buffer = */ nullptr, + /* .no_alloc = */ true, + }; + + ggml_context_ptr ctx { ggml_init(params) }; + + ggml_tensor * tmp = ggml_new_tensor_2d(ctx.get(), GGML_TYPE_F32, n, n); + tmp->data = attn_rot_hadamard[n].data(); + + ggml_gen_hadamard(tmp); + } + } + const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG"); debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; } @@ -222,6 +390,11 @@ void llama_kv_cache::clear(bool data) { } bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return true; + } + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); if (p0 < 0) { @@ -285,6 +458,11 @@ bool llama_kv_cache::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { } void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); @@ -372,6 +550,11 @@ void llama_kv_cache::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, ll } void llama_kv_cache::seq_keep(llama_seq_id seq_id) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -394,6 +577,11 @@ void llama_kv_cache::seq_keep(llama_seq_id seq_id) { } void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_add() is only supported for n_pos_per_embd() == 1"); @@ -439,6 +627,11 @@ void llama_kv_cache::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, ll } void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); GGML_ASSERT(hparams.n_pos_per_embd() == 1 && "seq_div() is only supported for n_pos_per_embd() == 1"); @@ -473,6 +666,11 @@ void llama_kv_cache::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, in } llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return other->seq_pos_min(seq_id); + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -481,6 +679,11 @@ llama_pos llama_kv_cache::seq_pos_min(llama_seq_id seq_id) const { } llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return other->seq_pos_max(seq_id); + } + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); const auto & cells = v_cells[seq_to_stream[seq_id]]; @@ -578,7 +781,7 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ break; } - // remeber the position that we found + // remember the position that we found res.push_back(sinfo_new); // store the old state of the cells in the recovery stack @@ -621,6 +824,11 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector<llama_ } bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_copy_info & sc_info) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return true; + } + bool updated = false; auto * sched = lctx->get_sched(); @@ -647,7 +855,10 @@ bool llama_kv_cache::update(llama_context * lctx, bool do_shift, const stream_co const auto & layer = layers[il]; ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]); - ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + + if (layer.v_stream[ssrc]) { + ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + } } } } @@ -852,7 +1063,7 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, const llama_seq_id seq_id_cell = cells.seq_get(idx); // SWA mask - if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { can_use = true; } } @@ -893,6 +1104,11 @@ llama_kv_cache::slot_info llama_kv_cache::find_slot(const llama_ubatch & ubatch, } void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & ubatch) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; @@ -966,6 +1182,13 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & } bool llama_kv_cache::get_can_shift() const { + // Step35 uses per-layer RoPE dims; K-shift assumes a single global n_rot. + if (model.arch == LLM_ARCH_STEP35) { + return false; + } + if (hparams.n_pos_per_embd() > 1) { + return false; + } return true; } @@ -989,6 +1212,14 @@ bool llama_kv_cache::get_has_shift() const { return result; } +ggml_type llama_kv_cache::type_k() const { + return layers[0].k->type; +} + +ggml_type llama_kv_cache::type_v() const { + return layers[0].v->type; +} + uint32_t llama_kv_cache::get_n_kv(const slot_info & sinfo) const { uint32_t result = 0; @@ -1018,8 +1249,8 @@ ggml_tensor * llama_kv_cache::get_k(ggml_context * ctx, int32_t il, uint32_t n_k const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; return ggml_view_4d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns, - ggml_row_size(k->type, hparams.n_embd_head_k), + hparams.n_embd_head_k(il), hparams.n_head_kv(il), n_kv, ns, + ggml_row_size(k->type, hparams.n_embd_head_k(il)), ggml_row_size(k->type, n_embd_k_gqa), ggml_row_size(k->type, n_embd_k_gqa*kv_size), ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); @@ -1041,8 +1272,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k if (!v_trans) { // note: v->nb[1] <= v->nb[2] return ggml_view_4d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + hparams.n_embd_head_v(il), hparams.n_head_kv(il), n_kv, ns, + ggml_row_size(v->type, hparams.n_embd_head_v(il)), // v->nb[1] ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); @@ -1050,8 +1281,8 @@ ggml_tensor * llama_kv_cache::get_v(ggml_context * ctx, int32_t il, uint32_t n_k // note: v->nb[1] > v->nb[2] return ggml_view_4d(ctx, v, - n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, - ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v(il), ns, + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v(il)), // v->nb[1] ggml_row_size(v->type, kv_size), // v->nb[2] ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); @@ -1174,6 +1405,47 @@ ggml_tensor * llama_kv_cache::build_input_v_idxs(ggml_context * ctx, const llama return v_idxs; } +ggml_tensor * llama_kv_cache::build_input_k_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_k) { + int nrot = 64; + + // TODO: investigate if using the smallest rotation matrix is beneficial also for K (similar as for V) + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4141323088 + do { + nrot *= 2; + } while (n_embd_head_k_all % nrot == 0); + nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_k_rot"); + } + + return res; +} + +ggml_tensor * llama_kv_cache::build_input_v_rot(ggml_context * ctx) const { + ggml_tensor * res = nullptr; + + if (attn_rot_v) { + int nrot = 64; + // using smaller rotation matrices for V seems beneficial + // ref: https://github.com/ggml-org/llama.cpp/pull/21038#issuecomment-4146397570 + //do { + // nrot *= 2; + //} while (hparams.n_embd_head_v() % nrot == 0); + //nrot /= 2; + + res = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nrot, nrot); + ggml_set_input(res); + ggml_set_name(res, "attn_inp_v_rot"); + } + + return res; +} + void llama_kv_cache::set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const { const uint32_t n_tokens = ubatch->n_tokens; GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); @@ -1237,90 +1509,247 @@ void llama_kv_cache::set_input_k_shift(ggml_tensor * dst) const { } } -void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { - const uint32_t n_tokens = ubatch->n_tokens; +struct args_set_input_kq_mask { + const llama_hparams & hparams; + const llama_ubatch * ubatch; - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - float * data = (float *) dst->data; + const std::vector<llama_kv_cells> & v_cells; + const std::vector<uint32_t> & seq_to_stream; - const int64_t n_kv = dst->ne[0]; - const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch + uint32_t n_swa; + llama_swa_type swa_type; - GGML_ASSERT(n_tokens%n_stream == 0); + int64_t n_kv; + int64_t n_stream; + int64_t n_tps; +}; - // n_tps == n_tokens_per_stream - const int64_t n_tps = n_tokens/n_stream; +template<typename T, bool causal, bool swa, bool is_2d, bool alibi> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { + //const auto & hparams = args.hparams; + const auto & ubatch = args.ubatch; - std::fill(data, data + ggml_nelements(dst), -INFINITY); - - // Use only the previous KV cells of the correct sequence for each token of the ubatch. - // It's assumed that if a token in the batch has multiple sequences, they are equivalent. - // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch: - // Causal mask: - // xxx------- - // xxxx------ - // xxxxx----- - // Non-causal mask: - // xxxxx----- - // xxxxx----- - // xxxxx----- - // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 - // TODO: optimize this section - for (uint32_t h = 0; h < 1; ++h) { - for (uint32_t s = 0; s < n_stream; ++s) { - for (uint32_t ii = 0; ii < n_tps; ++ii) { - const uint32_t i = s*n_tps + ii; + const auto & v_cells = args.v_cells; + const auto & seq_to_stream = args.seq_to_stream; - const llama_seq_id seq_id = ubatch->seq_id[i][0]; + const uint32_t n_swa = args.n_swa; + const llama_swa_type swa_type = args.swa_type; - const auto & cells = v_cells[seq_to_stream[seq_id]]; + const int64_t n_kv = args.n_kv; + const int64_t n_stream = args.n_stream; + const int64_t n_tps = args.n_tps; - const llama_pos p1 = ubatch->pos[i]; + const T mask_keep = llama_cast<T>(0.0f); + const T mask_drop = llama_cast<T>(-INFINITY); - // for M-RoPE - const bool is_2d = ubatch->is_pos_2d(); - const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; - const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; + // the min position in the batch for each sequence + llama_pos seq_pos_min[LLAMA_MAX_SEQ]; + std::fill(seq_pos_min, seq_pos_min + LLAMA_MAX_SEQ, INT32_MAX); - const uint64_t idst = n_kv*(h*n_stream*n_tps + s*n_tps + ii); + for (uint32_t i = 0; i < ubatch->n_tokens; ++i) { + const llama_seq_id seq_id = ubatch->seq_id[i][0]; - for (uint32_t j = 0; j < n_kv; ++j) { - if (cells.is_empty(j)) { - continue; - } + seq_pos_min[seq_id] = std::min(seq_pos_min[seq_id], ubatch->pos[i]); + } + + for (uint32_t s = 0; s < n_stream; ++s) { + // bookkeeping of the KQ mask cells that could change for other tokens of the same sequence + std::unordered_map<llama_seq_id, uint32_t> seq_srct; + std::unordered_map<llama_seq_id, std::vector<uint32_t>> seq_idxs; + + for (uint32_t ii = 0; ii < n_tps; ++ii) { + const uint32_t i = s*n_tps + ii; + + const llama_seq_id seq_id = ubatch->seq_id[i][0]; + + const auto & cells = v_cells.at(seq_to_stream[seq_id]); - // mask the token if not the same sequence - if (!cells.seq_has(j, seq_id)) { - continue; + llama_pos p0 = -1; + const llama_pos p1 = ubatch->pos[i]; + + // for M-RoPE + const llama_pos p1_x = is_2d ? ubatch->pos[i + ubatch->n_tokens*2] : 0; + const llama_pos p1_y = is_2d ? ubatch->pos[i + ubatch->n_tokens] : 0; + + const uint64_t idst = n_kv*i; + + // for tokens of the same sequence, the mask is mostly the same, so we can reuse it + // the only cells that could change are the ones that are with similar positions as the + // ones in the batch (i.e. due to causal masking, SWA, etc.) + // keep track of those cells and shortcut the loop to save time + // note: this optimization is not compatible with Alibi position encoding + // ref: https://github.com/ggml-org/llama.cpp/pull/18842 + bool prev = false; + + auto & idxs = seq_idxs[seq_id]; + + if (!alibi) { + if (seq_srct.find(seq_id) != seq_srct.end()) { + const uint32_t srct = seq_srct[seq_id]; + + const uint64_t idst_prev = n_kv*srct; + + std::copy(data + idst_prev, data + idst_prev + n_kv, data + idst); + + prev = true; + } else { + idxs.clear(); + idxs.reserve(ubatch->n_tokens + n_swa + 32); + + seq_srct[seq_id] = i; + } + } + + for (uint32_t jj = 0; jj < n_kv; ++jj) { + uint32_t j = jj; + + // we have an exiting mask for this sequence -> update just seq_idxs + if (!alibi) { + if (prev) { + if (jj >= idxs.size()) { + break; + } + + j = idxs[jj]; } + } + + if (cells.is_empty(j)) { + goto skip; + } + + // mask the token if not the same sequence + if (!cells.seq_has(j, seq_id)) { + goto skip; + } + + p0 = cells.pos_get(j); - const llama_pos p0 = cells.pos_get(j); + if (!alibi) { + if (!prev) { + // record all cells for which: p0 >= seq_pos_min[seq_id] - n_swa - 32 + if (p0 + (int32_t) (n_swa + 32) >= seq_pos_min[seq_id]) { + idxs.push_back(j); + } + } + } + if (causal) { // mask future tokens - if (causal_attn && p0 > p1) { - continue; + if (p0 > p1) { + goto skip; } // M-RoPE causal mask - if (causal_attn && is_2d && p0 == p1) { - const auto & p0_ext = cells.ext_get(j); - if (p0_ext.is_2d_gt(p1_x, p1_y)) { - continue; + if (is_2d) { + if (p0 == p1) { + const auto & p0_ext = cells.ext_get(j); + + if (p0_ext.is_2d_gt(p1_x, p1_y)) { + goto skip; + } } } + } - // apply SWA if any - if (is_masked_swa(p0, p1)) { - continue; + // apply SWA if any + if (swa) { + if (llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1)) { + goto skip; } + } - data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; + if (alibi) { + data[idst + j] = llama_cast<T>(static_cast<float>(-std::abs(p0 - p1))); + } else { + data[idst + j] = mask_keep; } + + continue; +skip: + data[idst + j] = mask_drop; } } } } +template<typename T, bool causal, bool swa, bool is_2d> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { + const bool alibi = args.hparams.use_alibi; + if (alibi) { + set_input_kq_mask_impl<T, causal, swa, is_2d, true> (args, data); + } else { + set_input_kq_mask_impl<T, causal, swa, is_2d, false>(args, data); + } +} + +template<typename T, bool causal, bool swa> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { + const bool is_2d = args.ubatch->is_pos_2d(); + if (is_2d) { + set_input_kq_mask_impl<T, causal, swa, true> (args, data); + } else { + set_input_kq_mask_impl<T, causal, swa, false>(args, data); + } +} + +template<typename T, bool causal> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data) { + const bool swa = args.swa_type != LLAMA_SWA_TYPE_NONE; + if (swa) { + set_input_kq_mask_impl<T, causal, true> (args, data); + } else { + set_input_kq_mask_impl<T, causal, false>(args, data); + } +} + +template<typename T> +static void set_input_kq_mask_impl(const args_set_input_kq_mask & args, T * data, bool causal_attn) { + if (causal_attn) { + set_input_kq_mask_impl<T, true> (args, data); + } else { + set_input_kq_mask_impl<T, false>(args, data); + } +} + +void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const { + const uint32_t n_tokens = ubatch->n_tokens; + + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const int64_t n_kv = dst->ne[0]; + const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch + + GGML_ASSERT(n_tokens%n_stream == 0); + + // n_tps == n_tokens_per_stream + const int64_t n_tps = n_tokens/n_stream; + + //const int64_t t_start = ggml_time_us(); + + const args_set_input_kq_mask args = { + /*.hparams =*/ hparams, + /*.ubatch =*/ ubatch, + /*.v_cells =*/ v_cells, + /*.seq_to_stream =*/ seq_to_stream, + /*.n_swa =*/ n_swa, + /*.swa_type =*/ swa_type, + /*.n_kv =*/ n_kv, + /*.n_stream =*/ n_stream, + /*.n_tps =*/ n_tps, + }; + + if (dst->type == GGML_TYPE_F16) { + set_input_kq_mask_impl<ggml_fp16_t>(args, (ggml_fp16_t *) dst->data, causal_attn); + } else { + set_input_kq_mask_impl<float>(args, (float *) dst->data, causal_attn); + } + + //const int64_t t_end = ggml_time_us(); + + //LLAMA_LOG_ERROR("%s: kq mask time: %0.3f ms\n", __func__, (t_end - t_start)/1000.0); +} + void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { const int64_t n_tokens = ubatch->n_tokens; @@ -1346,6 +1775,24 @@ void llama_kv_cache::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch } } +void llama_kv_cache::set_input_k_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + +void llama_kv_cache::set_input_v_rot(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + const auto n_rot = dst->ne[0]; + GGML_ASSERT(attn_rot_hadamard.count(dst->ne[0])); + + memcpy(dst->data, attn_rot_hadamard.at(n_rot).data(), ggml_nbytes(dst)); +} + size_t llama_kv_cache::total_size() const { size_t size = 0; @@ -1370,7 +1817,7 @@ size_t llama_kv_cache::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & layer : layers) { - size_v_bytes += ggml_nbytes(layer.v); + size_v_bytes += layer.v ? ggml_nbytes(layer.v) : 0; } return size_v_bytes; @@ -1381,9 +1828,11 @@ ggml_tensor * llama_kv_cache::build_rope_shift( ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, - float freq_scale) const { + float freq_scale, + uint32_t il) const { const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; const auto & yarn_ext_factor = cparams.yarn_ext_factor; @@ -1391,7 +1840,7 @@ ggml_tensor * llama_kv_cache::build_rope_shift( const auto & yarn_beta_slow = cparams.yarn_beta_slow; const auto & yarn_attn_factor = cparams.yarn_attn_factor; - const auto & n_rot = hparams.n_rot; + const auto & n_rot = hparams.n_rot(il); const auto & rope_type = hparams.rope_type == LLAMA_ROPE_TYPE_MROPE || hparams.rope_type == LLAMA_ROPE_TYPE_IMROPE // @ngxson : this is a workaround // for M-RoPE, we want to rotate the whole vector when doing KV shift @@ -1399,17 +1848,22 @@ ggml_tensor * llama_kv_cache::build_rope_shift( // ref: https://github.com/ggml-org/llama.cpp/pull/13870 ? LLAMA_ROPE_TYPE_NEOX : hparams.rope_type; - ggml_tensor * tmp; if (ggml_is_quantized(cur->type)) { // dequantize to f32 -> RoPE -> quantize back tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + // rotate back + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_rope_ext(ctx, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + // rotate fwd + tmp = ggml_mul_mat_aux(ctx, tmp, rot); + tmp = ggml_cpy(ctx, tmp, cur); } else { // we rotate only the first n_rot dimensions @@ -1430,6 +1884,9 @@ class llm_graph_input_k_shift : public llm_graph_input_i { ggml_tensor * k_shift; // I32 [kv_size*n_stream] + // note: assumes k_rot^2 == I + ggml_tensor * k_rot = nullptr; + const llama_kv_cache * kv_self; }; @@ -1439,20 +1896,26 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { if (k_shift) { kv_self->set_input_k_shift(k_shift); } + + if (k_rot) { + kv_self->set_input_k_rot(k_rot); + } } ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + GGML_ASSERT(!other); + auto * ctx = res->get_ctx(); auto * gf = res->get_gf(); - const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - auto inp = std::make_unique<llm_graph_input_k_shift>(this); inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); ggml_set_input(inp->k_shift); + inp->k_rot = build_input_k_rot(ctx); + const auto & cparams = lctx->get_cparams(); for (const auto & layer : layers) { @@ -1461,6 +1924,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co const int64_t n_head_kv = hparams.n_head_kv(il); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const auto n_rot = hparams.n_rot(il); + const auto n_embd_head_k = hparams.n_embd_head_k(il); + const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0; + const float freq_base_l = model.get_rope_freq_base (cparams, il); const float freq_scale_l = model.get_rope_freq_scale(cparams, il); @@ -1468,12 +1935,12 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, get_size()*n_stream, + n_rot, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); + ggml_row_size(layer.k->type, n_embd_nope)); - ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, inp->k_rot, rope_factors, freq_base_l, freq_scale_l, il); ggml_build_forward_expand(gf, cur); } @@ -1483,11 +1950,12 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co return gf; } -bool llama_kv_cache::is_masked_swa(llama_pos p0, llama_pos p1) const { - return llama_hparams::is_masked_swa(n_swa, swa_type, p0, p1); -} - void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); io.write(&n_stream, sizeof(n_stream)); @@ -1504,7 +1972,19 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla uint32_t cell_range_begin = cells.size(); for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + bool add_cell = true; + + add_cell = add_cell && !cells.is_empty(i); + add_cell = add_cell && (seq_id == -1 || cells.seq_has(i, seq_id)); + + // check the cell is not SWA-masked + if (add_cell && seq_id != -1) { + const bool is_masked = llama_hparams::is_masked_swa(n_swa, swa_type, cells.pos_get(i), cells.seq_pos_max(seq_id)); + + add_cell = !is_masked; + } + + if (add_cell) { ++cell_count; if (cell_range_begin == cells.size()) { cell_range_begin = i; @@ -1541,19 +2021,24 @@ void llama_kv_cache::state_write(llama_io_write_i & io, llama_seq_id seq_id, lla } void llama_kv_cache::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + // TODO: refactor [TAG_KV_CACHE_SHARE_CELLS] + if (other) { + return; + } + GGML_UNUSED(flags); GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); uint32_t n_stream_cur; - io.read_to(&n_stream_cur, sizeof(n_stream_cur)); + io.read(&n_stream_cur, sizeof(n_stream_cur)); if (n_stream_cur != n_stream) { throw std::runtime_error("n_stream mismatch"); } for (uint32_t s = 0; s < n_stream; ++s) { uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + io.read(&cell_count, sizeof(cell_count)); if (cell_count == 0) { continue; @@ -1599,8 +2084,10 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t io.write(&pos, sizeof(pos)); io.write(&n_seq_id, sizeof(n_seq_id)); - // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it - // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350 + if (hparams.n_pos_per_embd() > 1) { + const llama_kv_cell_ext ext = cells.ext_get(i); + io.write(&ext, sizeof(ext)); + } for (const auto & seq_id : seq_ids) { io.write(&seq_id, sizeof(seq_id)); @@ -1618,8 +2105,6 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t io.write(&v_trans, sizeof(v_trans)); io.write(&n_layer, sizeof(n_layer)); - std::vector<uint8_t> tmp_buf; - // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (const auto & layer : layers) { @@ -1637,7 +2122,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Read each range of cells of k_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; @@ -1652,6 +2137,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1661,7 +2149,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); - // Read each range of cells of v_size length each into tmp_buf and write out + // Read each range of cells of v_size length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; @@ -1678,6 +2166,9 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[cr.strm]; + if (!v) { + continue; + } // Write value type const int32_t v_type_i = (int32_t) v->type; @@ -1692,7 +2183,7 @@ void llama_kv_cache::state_write_data(llama_io_write_i & io, const cell_ranges_t // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out + // Read each range of cells of v_size_el length and write out for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; @@ -1722,18 +2213,26 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id != 1) { LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); return false; } + if (hparams.n_pos_per_embd() > 1) { + llama_kv_cell_ext ext; + io.read(&ext, sizeof(ext)); + + ubatch.pos[i + ubatch.n_tokens] = ext.y; + ubatch.pos[i + ubatch.n_tokens*2] = ext.x; + } + // read the sequence id, but directly discard it - we will use dest_seq_id instead { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); } ubatch.pos[i] = pos; @@ -1743,7 +2242,7 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 sinfo = find_slot(ubatch, false); if (sinfo.empty()) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + LLAMA_LOG_ERROR("%s: failed to find %d available cells in kv cache\n", __func__, cell_count); return false; } @@ -1775,14 +2274,20 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); cells.pos_set(i, pos); + if (hparams.n_pos_per_embd() > 1) { + llama_kv_cell_ext ext; + io.read(&ext, sizeof(ext)); + cells.ext_set(i, ext); + } + for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) { LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max); @@ -1815,8 +2320,8 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 uint32_t v_trans; uint32_t n_layer; - io.read_to(&v_trans, sizeof(v_trans)); - io.read_to(&n_layer, sizeof(n_layer)); + io.read(&v_trans, sizeof(v_trans)); + io.read(&n_layer, sizeof(n_layer)); if (n_layer != layers.size()) { LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, (uint32_t) layers.size()); @@ -1843,7 +2348,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read type of key int32_t k_type_i_ref; - io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + io.read(&k_type_i_ref, sizeof(k_type_i_ref)); const int32_t k_type_i = (int32_t) k->type; if (k_type_i != k_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); @@ -1852,7 +2357,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read row size of key uint64_t k_size_row_ref; - io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + io.read(&k_size_row_ref, sizeof(k_size_row_ref)); const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); @@ -1862,13 +2367,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 if (cell_count) { if (sinfo.is_contiguous()) { // Fast path: contiguous cells, single memcpy - ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), sinfo.head() * k_size_row, cell_count * k_size_row); + io.read_tensor(k, sinfo.head() * k_size_row, cell_count * k_size_row); } else { // Slow path: scatter to non-contiguous positions - const void * src = io.read(cell_count * k_size_row); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = sinfo.idxs[0][i] * k_size_row; - ggml_backend_tensor_set(k, (const char*)src + i * k_size_row, dst_offset, k_size_row); + io.read_tensor(k, dst_offset, k_size_row); } } } @@ -1881,10 +2385,13 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + io.read(&v_type_i_ref, sizeof(v_type_i_ref)); const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); @@ -1893,7 +2400,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read row size of value uint64_t v_size_row_ref; - io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + io.read(&v_size_row_ref, sizeof(v_size_row_ref)); const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); @@ -1903,13 +2410,12 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 if (cell_count) { if (sinfo.is_contiguous()) { // Fast path: contiguous cells, single memcpy - ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), sinfo.head() * v_size_row, cell_count * v_size_row); + io.read_tensor(v, sinfo.head() * v_size_row, cell_count * v_size_row); } else { // Slow path: scatter to non-contiguous positions - const void * src = io.read(cell_count * v_size_row); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = sinfo.idxs[0][i] * v_size_row; - ggml_backend_tensor_set(v, (const char*)src + i * v_size_row, dst_offset, v_size_row); + io.read_tensor(v, dst_offset, v_size_row); } } } @@ -1922,10 +2428,13 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); auto * v = layer.v_stream[strm]; + if (!v) { + continue; + } // Read type of value int32_t v_type_i_ref; - io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + io.read(&v_type_i_ref, sizeof(v_type_i_ref)); const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); @@ -1934,7 +2443,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read element size of value uint32_t v_size_el_ref; - io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + io.read(&v_size_el_ref, sizeof(v_size_el_ref)); const size_t v_size_el = ggml_type_size(v->type); if (v_size_el != v_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); @@ -1943,7 +2452,7 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 // Read GQA embedding size uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + io.read(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); if (n_embd_v_gqa != n_embd_v_gqa_ref) { LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); return false; @@ -1955,15 +2464,14 @@ bool llama_kv_cache::state_read_data(llama_io_read_i & io, uint32_t strm, uint32 const uint32_t h = sinfo.head(); for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { const size_t dst_offset = (h + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + io.read_tensor(v, dst_offset, cell_count * v_size_el); } } else { // Slow path: scatter to non-contiguous positions for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const void * src = io.read(cell_count * v_size_el); for (uint32_t i = 0; i < cell_count; ++i) { const size_t dst_offset = (sinfo.idxs[0][i] + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(v, (const char*)src + i * v_size_el, dst_offset, v_size_el); + io.read_tensor(v, dst_offset, v_size_el); } } } @@ -2055,6 +2563,14 @@ uint32_t llama_kv_cache_context::get_n_kv() const { return n_kv; } +ggml_type llama_kv_cache_context::type_k() const { + return kv->type_k(); +} + +ggml_type llama_kv_cache_context::type_v() const { + return kv->type_v(); +} + ggml_tensor * llama_kv_cache_context::get_k(ggml_context * ctx, int32_t il) const { return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } @@ -2079,6 +2595,14 @@ ggml_tensor * llama_kv_cache_context::build_input_v_idxs(ggml_context * ctx, con return kv->build_input_v_idxs(ctx, ubatch); } +ggml_tensor * llama_kv_cache_context::build_input_k_rot(ggml_context * ctx) const { + return kv->build_input_k_rot(ctx); +} + +ggml_tensor * llama_kv_cache_context::build_input_v_rot(ggml_context * ctx) const { + return kv->build_input_v_rot(ctx); +} + void llama_kv_cache_context::set_input_k_shift(ggml_tensor * dst) const { kv->set_input_k_shift(dst); } @@ -2098,3 +2622,11 @@ void llama_kv_cache_context::set_input_kq_mask(ggml_tensor * dst, const llama_ub void llama_kv_cache_context::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { kv->set_input_pos_bucket(dst, ubatch); } + +void llama_kv_cache_context::set_input_k_rot(ggml_tensor * dst) const { + kv->set_input_k_rot(dst); +} + +void llama_kv_cache_context::set_input_v_rot(ggml_tensor * dst) const { + kv->set_input_v_rot(dst); +} diff --git a/examples/talk-llama/llama-kv-cache.h b/examples/talk-llama/llama-kv-cache.h index 0c4ed648456..3d68f98c142 100644 --- a/examples/talk-llama/llama-kv-cache.h +++ b/examples/talk-llama/llama-kv-cache.h @@ -93,8 +93,12 @@ class llama_kv_cache : public llama_memory_i { using slot_info_vec_t = std::vector<slot_info>; + // TODO: refactor the memory instances to not depend on `llama_model` + // instead pass all necessary info (e.g. hparams, dev layers, arch, etc.) directly + // likely through `struct llama_memory_params` llama_kv_cache( const llama_model & model, + const llama_hparams & hparams, ggml_type type_k, ggml_type type_v, bool v_trans, @@ -105,8 +109,10 @@ class llama_kv_cache : public llama_memory_i { uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type, + llama_memory_t mem_other, const layer_filter_cb & filter, - const layer_reuse_cb & reuse); + const layer_reuse_cb & reuse, + const layer_share_cb & share); ~llama_kv_cache() = default; @@ -152,6 +158,9 @@ class llama_kv_cache : public llama_memory_i { bool get_has_shift() const; + ggml_type type_k() const; + ggml_type type_v() const; + // // graph_build API // @@ -191,6 +200,9 @@ class llama_kv_cache : public llama_memory_i { ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; @@ -199,6 +211,9 @@ class llama_kv_cache : public llama_memory_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: const llama_model & model; const llama_hparams & hparams; @@ -226,6 +241,18 @@ class llama_kv_cache : public llama_memory_i { // SWA const uint32_t n_swa = 0; + // env: LLAMA_ATTN_ROT_DISABLE + bool attn_rot_k = false; + bool attn_rot_v = false; + + // if all layers participating in the cache have constant head size, the value is stored here + // otherwise the value is -1 + int32_t n_embd_head_k_all = 0; + int32_t n_embd_head_v_all = 0; + + // pre-computed hadamard martrices + std::unordered_map<int64_t, std::vector<float>> attn_rot_hadamard; + // env: LLAMA_KV_CACHE_DEBUG int debug = 0; @@ -239,7 +266,12 @@ class llama_kv_cache : public llama_memory_i { // note: this is not part of the KV state and it's only used to speed-up the find_slot() method std::vector<uint32_t> v_heads; - std::vector<llama_kv_cells> v_cells; + // TODO: temporary until we refactor to be able to share the same cells between 2 kv caches [TAG_KV_CACHE_SHARE_CELLS] + llama_kv_cache * other; + + std::shared_ptr<llama_kv_cells_vec> v_cells_impl; + + llama_kv_cells_vec & v_cells; // maps from a sequence id to a stream id std::vector<uint32_t> seq_to_stream; @@ -257,16 +289,16 @@ class llama_kv_cache : public llama_memory_i { size_t size_k_bytes() const; size_t size_v_bytes() const; - bool is_masked_swa(llama_pos p0, llama_pos p1) const; - ggml_tensor * build_rope_shift( const llama_cparams & cparams, ggml_context * ctx, ggml_tensor * cur, ggml_tensor * shift, + ggml_tensor * rot, ggml_tensor * factors, float freq_base, - float freq_scale) const; + float freq_scale, + uint32_t il) const; ggml_cgraph * build_graph_shift( llm_graph_result * res, @@ -329,12 +361,15 @@ class llama_kv_cache_context : public llama_memory_context_i { uint32_t get_n_kv() const; + ggml_type type_k() const; + ggml_type type_v() const; + // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * ctx, int32_t il) const; // store k_cur and v_cur in the cache based on the provided head location - // note: the heads in k_cur and v_cur should be layed out contiguously in memory + // note: the heads in k_cur and v_cur should be laid out contiguously in memory // - k_cur [n_embd_head_k, n_head_k, n_tokens] // - k_idxs [n_tokens] // - v_cur [n_embd_head_v, n_head_v, n_tokens] @@ -348,6 +383,9 @@ class llama_kv_cache_context : public llama_memory_context_i { ggml_tensor * build_input_k_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; ggml_tensor * build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const; + ggml_tensor * build_input_k_rot(ggml_context * ctx) const; + ggml_tensor * build_input_v_rot(ggml_context * ctx) const; + void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch) const; @@ -355,6 +393,9 @@ class llama_kv_cache_context : public llama_memory_context_i { void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; + void set_input_k_rot(ggml_tensor * dst) const; + void set_input_v_rot(ggml_tensor * dst) const; + private: llama_memory_status status; diff --git a/examples/talk-llama/llama-kv-cells.h b/examples/talk-llama/llama-kv-cells.h index 10063bf4272..fddd31a0b21 100644 --- a/examples/talk-llama/llama-kv-cells.h +++ b/examples/talk-llama/llama-kv-cells.h @@ -531,3 +531,5 @@ class llama_kv_cells { } } }; + +using llama_kv_cells_vec = std::vector<llama_kv_cells>; diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.cpp b/examples/talk-llama/llama-memory-hybrid-iswa.cpp new file mode 100644 index 00000000000..c7d4bcd413e --- /dev/null +++ b/examples/talk-llama/llama-memory-hybrid-iswa.cpp @@ -0,0 +1,285 @@ +#include "llama-memory-hybrid-iswa.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +// +// llama_memory_hybrid_iswa +// + +llama_memory_hybrid_iswa::llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + uint32_t n_rs_seq, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn, + const layer_filter_cb & filter_recr) : + hparams(model.hparams), + mem_attn(new llama_kv_cache_iswa( + model, + type_k, + type_v, + v_trans, + offload, + swa_full, + unified, + kv_size, + n_seq_max, + n_ubatch, + n_pad, + nullptr, + filter_attn == nullptr ? + [&](int32_t il) { return !hparams.is_recr(il); } + : filter_attn, + nullptr, + nullptr + )), + mem_recr(new llama_memory_recurrent( + model, + type_r, + type_s, + offload, + rs_size, + n_seq_max, + n_rs_seq, + filter_recr == nullptr ? + [&](int32_t il) { return hparams.is_recr(il); } + : filter_recr + )) {} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) { + do { + balloc.split_reset(); + + // follow the recurrent pattern for creating the ubatch splits + std::vector<llama_ubatch> ubatches; + + while (true) { + llama_ubatch ubatch; + + if (embd_all) { + // if all tokens are output, split by sequence + ubatch = balloc.split_seq(n_ubatch); + } else { + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_base()->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } + } + + if (ubatch.n_tokens == 0) { + break; + } + + ubatches.push_back(std::move(ubatch)); // NOLINT + } + + if (balloc.get_n_used() < balloc.get_n_tokens()) { + // failed to find a suitable split + break; + } + + // prepare the recurrent batches first + if (!mem_recr->prepare(ubatches)) { + // TODO: will the recurrent cache be in an undefined context at this point? + LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); + return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + // prepare the attention cache (iswa version returns both base and swa slot infos) + auto sinfos_base = mem_attn->get_base()->prepare(ubatches); + if (sinfos_base.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention base ubatches\n", __func__); + return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + auto sinfos_swa = mem_attn->get_swa()->prepare(ubatches); + if (sinfos_swa.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention swa ubatches\n", __func__); + return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique<llama_memory_hybrid_iswa_context>( + this, std::move(sinfos_base), std::move(sinfos_swa), std::move(ubatches)); + } while(false); + + return std::make_unique<llama_memory_hybrid_iswa_context>(LLAMA_MEMORY_STATUS_FAILED_PREPARE); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_full() { + return std::make_unique<llama_memory_hybrid_iswa_context>(this); +} + +llama_memory_context_ptr llama_memory_hybrid_iswa::init_update(llama_context * lctx, bool optimize) { + return std::make_unique<llama_memory_hybrid_iswa_context>(this, lctx, optimize); +} + +bool llama_memory_hybrid_iswa::get_can_shift() const { + // Shifting is trivially supported for recurrent + return mem_attn->get_can_shift(); +} + +void llama_memory_hybrid_iswa::clear(bool data) { + mem_attn->clear(data); + mem_recr->clear(data); +} + +bool llama_memory_hybrid_iswa::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. + if (!mem_recr->seq_rm(seq_id, p0, p1)) { + return false; + } + return mem_attn->seq_rm(seq_id, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + mem_attn->seq_cp(seq_id_src, seq_id_dst, p0, p1); + mem_recr->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_memory_hybrid_iswa::seq_keep(llama_seq_id seq_id) { + mem_attn->seq_keep(seq_id); + mem_recr->seq_keep(seq_id); +} + +void llama_memory_hybrid_iswa::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + mem_attn->seq_add(seq_id, p0, p1, shift); + mem_recr->seq_add(seq_id, p0, p1, shift); +} + +void llama_memory_hybrid_iswa::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + mem_attn->seq_div(seq_id, p0, p1, d); + mem_recr->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(mem_attn->seq_pos_min(seq_id), mem_recr->seq_pos_min(seq_id)); +} + +llama_pos llama_memory_hybrid_iswa::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id)); +} + +std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid_iswa::memory_breakdown() const { + std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown(); + for (const auto & buft_size : mem_recr->memory_breakdown()) { + mb[buft_size.first] += buft_size.second; + } + return mb; +} + +void llama_memory_hybrid_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { + mem_attn->state_write(io, seq_id, flags); + mem_recr->state_write(io, seq_id, flags); +} + +void llama_memory_hybrid_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { + mem_attn->state_read(io, seq_id, flags); + mem_recr->state_read(io, seq_id, flags); +} + +llama_kv_cache_iswa * llama_memory_hybrid_iswa::get_mem_attn() const { + return mem_attn.get(); +} + +llama_memory_recurrent * llama_memory_hybrid_iswa::get_mem_recr() const { + return mem_recr.get(); +} + +// +// llama_memory_hybrid_iswa_context +// + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_status status) : status(status) {} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem) : + ctx_attn(mem->get_mem_attn()->init_full()), + ctx_recr(mem->get_mem_recr()->init_full()), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize) : + ctx_attn(mem->get_mem_attn()->init_update(lctx, optimize)), + ctx_recr(mem->get_mem_recr()->init_update(lctx, optimize)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +llama_memory_hybrid_iswa_context::llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector<llama_ubatch> ubatches) : + ubatches(std::move(ubatches)), + // note: here we copy the ubatches. not sure if this is ideal + ctx_attn(new llama_kv_cache_iswa_context(mem->get_mem_attn(), std::move(sinfos_base), std::move(sinfos_swa), this->ubatches)), + ctx_recr(new llama_memory_recurrent_context(mem->get_mem_recr(), this->ubatches)), + status(llama_memory_status_combine(ctx_attn->get_status(), ctx_recr->get_status())) { +} + +bool llama_memory_hybrid_iswa_context::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + ctx_attn->next(); + ctx_recr->next(); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_memory_hybrid_iswa_context::apply() { + assert(!llama_memory_status_is_fail(status)); + + bool res = true; + + res = res & ctx_attn->apply(); + res = res & ctx_recr->apply(); + + return res; +} + +llama_memory_status llama_memory_hybrid_iswa_context::get_status() const { + return status; +} + +const llama_ubatch & llama_memory_hybrid_iswa_context::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_iswa_context * llama_memory_hybrid_iswa_context::get_attn() const { + return static_cast<const llama_kv_cache_iswa_context *>(ctx_attn.get()); +} + +const llama_memory_recurrent_context * llama_memory_hybrid_iswa_context::get_recr() const { + return static_cast<const llama_memory_recurrent_context *>(ctx_recr.get()); +} diff --git a/examples/talk-llama/llama-memory-hybrid-iswa.h b/examples/talk-llama/llama-memory-hybrid-iswa.h new file mode 100644 index 00000000000..c9d3f9f57c5 --- /dev/null +++ b/examples/talk-llama/llama-memory-hybrid-iswa.h @@ -0,0 +1,141 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache-iswa.h" +#include "llama-memory.h" +#include "llama-memory-recurrent.h" + +#include <memory> +#include <vector> + +// +// llama_memory_hybrid_iswa +// + +// utilizes instances of llama_memory_recurrent and llama_kv_cache_iswa to +// support models where each layer may be either attention-based (with SWA support) or recurrent + +class llama_memory_hybrid_iswa : public llama_memory_i { +public: + llama_memory_hybrid_iswa( + const llama_model & model, + /* attn */ + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool swa_full, + uint32_t kv_size, + uint32_t n_ubatch, + uint32_t n_pad, + /* recurrent */ + ggml_type type_r, + ggml_type type_s, + uint32_t rs_size, + /* common */ + uint32_t n_seq_max, + uint32_t n_rs_seq, + bool offload, + bool unified, + /* layer filters */ + const layer_filter_cb & filter_attn = nullptr, + const layer_filter_cb & filter_recr = nullptr); + + ~llama_memory_hybrid_iswa() = default; + + // + // llama_memory_i + // + + llama_memory_context_ptr init_batch( + llama_batch_allocr & balloc, + uint32_t n_ubatch, + bool embd_all) override; + + llama_memory_context_ptr init_full() override; + + llama_memory_context_ptr init_update(llama_context * lctx, bool optimize) override; + + bool get_can_shift() const override; + + void clear(bool data) override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) override; + + // + // llama_memory_hybrid_iswa specific API + // + + llama_kv_cache_iswa * get_mem_attn() const; + llama_memory_recurrent * get_mem_recr() const; + +private: + const llama_hparams & hparams; + + const std::unique_ptr<llama_kv_cache_iswa> mem_attn; + const std::unique_ptr<llama_memory_recurrent> mem_recr; +}; + +class llama_memory_hybrid_iswa_context : public llama_memory_context_i { +public: + using slot_info_vec_t = llama_kv_cache::slot_info_vec_t; + + // init failure + explicit llama_memory_hybrid_iswa_context(llama_memory_status status); + + // init full + explicit llama_memory_hybrid_iswa_context(llama_memory_hybrid_iswa * mem); + + // init update + explicit llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + llama_context * lctx, + bool optimize); + + // init success + llama_memory_hybrid_iswa_context( + llama_memory_hybrid_iswa * mem, + slot_info_vec_t sinfos_base, + slot_info_vec_t sinfos_swa, + std::vector<llama_ubatch> ubatches); + + ~llama_memory_hybrid_iswa_context() = default; + + bool next() override; + bool apply() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_memory_hybrid_iswa_context + // + + const llama_kv_cache_iswa_context * get_attn() const; + const llama_memory_recurrent_context * get_recr() const; + +private: + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector<llama_ubatch> ubatches; + + const llama_memory_context_ptr ctx_attn; + const llama_memory_context_ptr ctx_recr; + + const llama_memory_status status; +}; diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index a1b45e4a3cc..f2d49cbce54 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -24,6 +24,7 @@ llama_memory_hybrid::llama_memory_hybrid( uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ @@ -32,6 +33,7 @@ llama_memory_hybrid::llama_memory_hybrid( hparams(model.hparams), mem_attn(new llama_kv_cache( model, + model.hparams, type_k, type_v, v_trans, @@ -42,9 +44,11 @@ llama_memory_hybrid::llama_memory_hybrid( n_pad, n_swa, swa_type, + nullptr, filter_attn == nullptr ? - [&](int32_t il) { return !hparams.is_recurrent(il); } + [&](int32_t il) { return !hparams.is_recr(il); } : filter_attn, + nullptr, nullptr )), mem_recr(new llama_memory_recurrent( @@ -54,8 +58,9 @@ llama_memory_hybrid::llama_memory_hybrid( offload, rs_size, n_seq_max, + n_rs_seq, filter_recr == nullptr ? - [&](int32_t il) { return hparams.is_recurrent(il); } + [&](int32_t il) { return hparams.is_recr(il); } : filter_recr )) {} @@ -73,9 +78,15 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + if (mem_recr->n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // Use non-sequential split when KV cache is unified (needed for hellaswag/winogrande/multiple-choice) + const bool unified = (mem_attn->get_n_stream() == 1); + ubatch = balloc.split_equal(n_ubatch, !unified); + } } if (ubatch.n_tokens == 0) { diff --git a/examples/talk-llama/llama-memory-hybrid.h b/examples/talk-llama/llama-memory-hybrid.h index 558cafdf984..484eafb7499 100644 --- a/examples/talk-llama/llama-memory-hybrid.h +++ b/examples/talk-llama/llama-memory-hybrid.h @@ -34,6 +34,7 @@ class llama_memory_hybrid : public llama_memory_i { uint32_t rs_size, /* common */ uint32_t n_seq_max, + uint32_t n_rs_seq, bool offload, bool unified, /* layer filters */ diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 812bf253049..6a4892fb471 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -1,5 +1,6 @@ #include "llama-memory-recurrent.h" +#include "ggml-backend.h" #include "llama-impl.h" #include "llama-io.h" #include "llama-batch.h" @@ -23,13 +24,17 @@ llama_memory_recurrent::llama_memory_recurrent( bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter) : hparams(model.hparams), n_seq_max(n_seq_max) { - const int32_t n_layer = hparams.n_layer; + const int32_t n_layer = hparams.n_layer(); head = 0; size = mem_size; used = 0; + this->n_rs_seq = n_rs_seq; + rs_idx.assign(n_seq_max, 0); + cells.clear(); cells.resize(mem_size); @@ -91,8 +96,9 @@ llama_memory_recurrent::llama_memory_recurrent( throw std::runtime_error("failed to create ggml context for rs cache"); } - ggml_tensor * r = ggml_new_tensor_1d(ctx, type_r, hparams.n_embd_r()*mem_size); - ggml_tensor * s = ggml_new_tensor_1d(ctx, type_s, hparams.n_embd_s()*mem_size); + const uint32_t n_rows = mem_size * (1 + n_rs_seq); + ggml_tensor * r = ggml_new_tensor_2d(ctx, type_r, hparams.n_embd_r(), n_rows); + ggml_tensor * s = ggml_new_tensor_2d(ctx, type_s, hparams.n_embd_s(), n_rows); ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(s, "cache_s_l%d", i); r_l[i] = r; @@ -114,8 +120,8 @@ llama_memory_recurrent::llama_memory_recurrent( const size_t memory_size_r = size_r_bytes(); const size_t memory_size_s = size_s_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, - (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs %2u rs_seq), R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), mem_size, n_layer, n_seq_max, n_rs_seq, ggml_type_name(type_r), (float)memory_size_r / (1024.0f * 1024.0f), ggml_type_name(type_s), (float)memory_size_s / (1024.0f * 1024.0f)); } @@ -137,10 +143,11 @@ void llama_memory_recurrent::clear(bool data) { ggml_backend_buffer_clear(buf.get(), 0); } } + + std::fill(rs_idx.begin(), rs_idx.end(), 0); } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -151,6 +158,15 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits<llama_pos>::max(); } + const bool rm_all = p0 == 0 && p1 == std::numeric_limits<llama_pos>::max(); + if (rm_all) { + if (seq_id >= 0) { + set_rs_idx(seq_id, 0); + } else { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } + } + // models like Mamba or RWKV can't have a state partially erased at the end // of the sequence because their state isn't preserved for previous tokens if (seq_id >= (int64_t) size) { @@ -160,10 +176,16 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const auto & cell = cells[tail_id]; - // partial intersection is invalid if it includes the final pos + auto & cell = cells[tail_id]; + + // partial rollback via per-token snapshot index (bounded by n_rs_seq) if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); + const llama_pos rollback = cell.pos - (p0 - 1); + if (rollback >= 1 && rollback <= (llama_pos) n_rs_seq) { + set_rs_idx(seq_id, (uint32_t) rollback); + cell.pos = p0 - 1; + return true; + } return false; } // invalidate tails which will be cleared @@ -367,6 +389,13 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } +void llama_memory_recurrent::set_rs_idx(llama_seq_id seq_id, uint32_t idx) { + if (seq_id < 0 || (size_t) seq_id >= rs_idx.size()) { + return; + } + rs_idx[seq_id] = (idx > n_rs_seq) ? n_rs_seq : idx; +} + std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const { std::map<ggml_backend_buffer_type_t, size_t> ret; for (const auto & [_, buf] : ctxs_bufs) { @@ -387,9 +416,15 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - // TODO: non-sequential equal split can be done if using unified KV cache - // for simplicity, we always use sequential equal split for now - ubatch = balloc.split_equal(n_ubatch, true); + if (n_rs_seq > 0) { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: recurrent state rollback does not support equal splits + ubatch = balloc.split_seq(n_ubatch); + } else { + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); + } } if (ubatch.n_tokens == 0) { @@ -702,6 +737,7 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq GGML_UNUSED(flags); std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive + std::vector<std::pair<uint32_t, uint32_t>> cell_ranges_data; // logical source row ranges uint32_t cell_count = 0; // Count the number of cells with the specified seq_id @@ -711,6 +747,35 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq const auto & cell = cells[i]; if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { ++cell_count; + uint32_t rs_idx_cur = 0; + + if (n_rs_seq != 0) { + if (seq_id != -1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < rs_idx.size()); + rs_idx_cur = rs_idx[seq_id]; + } else { + bool has_rs_idx = false; + for (const llama_seq_id cell_seq_id : cell.seq_id) { + GGML_ASSERT(cell_seq_id >= 0 && (size_t) cell_seq_id < rs_idx.size()); + + const uint32_t seq_rs_idx = rs_idx[cell_seq_id]; + if (!has_rs_idx) { + rs_idx_cur = seq_rs_idx; + has_rs_idx = true; + } else if (rs_idx_cur != seq_rs_idx) { + GGML_ABORT("cannot write shared recurrent state with different rollback indices"); + } + } + } + } + + const uint32_t cell_id = rs_idx_cur * size + (cell.src >= 0 ? cell.src : (int32_t) i); + if (cell_ranges_data.empty() || cell_ranges_data.back().second != cell_id) { + cell_ranges_data.emplace_back(cell_id, cell_id + 1); + } else { + cell_ranges_data.back().second++; + } + if (cell_range_begin == size) { cell_range_begin = i; } @@ -725,6 +790,10 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq cell_ranges.emplace_back(cell_range_begin, size); } + if ((flags & LLAMA_STATE_SEQ_FLAGS_ON_DEVICE) && cell_ranges.size() > 1) { + GGML_ABORT("cannot save/load multiple ranges of cells to/from device memory\n"); + } + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count uint32_t cell_count_check = 0; for (const auto & range : cell_ranges) { @@ -732,17 +801,23 @@ void llama_memory_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq } GGML_ASSERT(cell_count == cell_count_check); + cell_count_check = 0; + for (const auto & range : cell_ranges_data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + io.write(&cell_count, sizeof(cell_count)); state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + state_write_data(io, cell_ranges_data); } void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { GGML_UNUSED(flags); uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + io.read(&cell_count, sizeof(cell_count)); bool res = true; @@ -757,6 +832,14 @@ void llama_memory_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_i } throw std::runtime_error("failed to restore kv cache"); } + + if (n_rs_seq != 0) { + if (seq_id == -1) { + std::fill(rs_idx.begin(), rs_idx.end(), 0); + } else { + set_rs_idx(seq_id, 0); + } + } } void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id) const { @@ -780,28 +863,27 @@ void llama_memory_recurrent::state_write_meta(llama_io_write_i & io, const std:: void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const { const uint32_t s_trans = 0; - const uint32_t n_layer = hparams.n_layer; + const uint32_t n_layer = hparams.n_layer(); io.write(&s_trans, sizeof(s_trans)); - io.write(&n_layer, sizeof(n_layer)); + io.write(&n_layer, sizeof(n_layer)); - std::vector<uint8_t> tmp_buf; - - // Iterate and write all the keys first, each row is a cell + // Iterate and write all the R tensors first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (r_l[il] == nullptr) continue; - // Write key type + // Write R tensor type const int32_t r_type_i = (int32_t)r_l[il]->type; io.write(&r_type_i, sizeof(r_type_i)); - // Write row size of key + // Write row size of R tensor const uint64_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); io.write(&r_size_row, sizeof(r_size_row)); - // Read each range of cells of k_size length each into tmp_buf and write out + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * r_size_row; @@ -814,15 +896,16 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) if (s_l[il] == nullptr) continue; - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); - // Write row size of value + // Write row size of S tensor const uint64_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); io.write(&s_size_row, sizeof(s_size_row)); - // Read each range of cells of s_size length each into tmp_buf and write out + // Write each logical cell row range. With pending recurrent rollback, + // the logical current state may live in a rollback snapshot plane. for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * s_size_row; @@ -830,7 +913,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: } } } else { - // When v is transposed, we also need the element size and get the element ranges from each row + // When S tensor is transposed, we also need the element size and get the element ranges from each row const uint32_t mem_size = size; for (uint32_t il = 0; il < n_layer; ++il) { // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) @@ -838,7 +921,7 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_s = hparams.n_embd_s(); - // Write value type + // Write S tensor type const int32_t s_type_i = (int32_t)s_l[il]->type; io.write(&s_type_i, sizeof(s_type_i)); @@ -849,9 +932,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Write GQA embedding size io.write(&n_embd_s, sizeof(n_embd_s)); - // For each row, we get the element values of each cell + // For each row, we get the element values of each logical cell for (uint32_t j = 0; j < n_embd_s; ++j) { - // Read each range of cells of v_size_el length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * mem_size) * s_size_el; @@ -880,8 +962,8 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); if (n_seq_id != 0) { LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); @@ -921,20 +1003,17 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell llama_pos pos; uint32_t n_seq_id; - io.read_to(&pos, sizeof(pos)); - io.read_to(&n_seq_id, sizeof(n_seq_id)); + io.read(&pos, sizeof(pos)); + io.read(&n_seq_id, sizeof(n_seq_id)); cell.pos = pos; for (uint32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id; - io.read_to(&seq_id, sizeof(seq_id)); + io.read(&seq_id, sizeof(seq_id)); - // TODO: llama_memory_recurrent should have a notion of max sequences - //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - if (seq_id < 0) { - //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + if (seq_id < 0 || (uint32_t) seq_id >= this->n_seq_max) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, this->n_seq_max); return false; } @@ -965,11 +1044,11 @@ bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { uint32_t s_trans; uint32_t n_layer; - io.read_to(&s_trans, sizeof(s_trans)); - io.read_to(&n_layer, sizeof(n_layer)); + io.read(&s_trans, sizeof(s_trans)); + io.read(&n_layer, sizeof(n_layer)); - if (n_layer != hparams.n_layer) { - LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + if (n_layer != hparams.n_layer()) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer()); return false; } if (cell_count > size) { @@ -988,7 +1067,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of key int32_t r_type_i_ref; - io.read_to(&r_type_i_ref, sizeof(r_type_i_ref)); + io.read(&r_type_i_ref, sizeof(r_type_i_ref)); const int32_t r_type_i = (int32_t) r_l[il]->type; if (r_type_i != r_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched r type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il); @@ -997,7 +1076,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of key uint64_t r_size_row_ref; - io.read_to(&r_size_row_ref, sizeof(r_size_row_ref)); + io.read(&r_size_row_ref, sizeof(r_size_row_ref)); const size_t r_size_row = ggml_row_size(r_l[il]->type, hparams.n_embd_r()); if (r_size_row != r_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched r row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il); @@ -1006,7 +1085,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the keys for the whole cell range - ggml_backend_tensor_set(r_l[il], io.read(cell_count * r_size_row), head * r_size_row, cell_count * r_size_row); + io.read_tensor(r_l[il], head * r_size_row, cell_count * r_size_row); } } @@ -1017,7 +1096,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of value int32_t s_type_i_ref; - io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + io.read(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; if (s_type_i != s_type_i_ref) { @@ -1027,7 +1106,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of value uint64_t s_size_row_ref; - io.read_to(&s_size_row_ref, sizeof(s_size_row_ref)); + io.read(&s_size_row_ref, sizeof(s_size_row_ref)); const size_t s_size_row = ggml_row_size(s_l[il]->type, hparams.n_embd_s()); if (s_size_row != s_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched s row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il); @@ -1036,7 +1115,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the values for the whole cell range - ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_row), head * s_size_row, cell_count * s_size_row); + io.read_tensor(s_l[il], head * s_size_row, cell_count * s_size_row); } } } else { @@ -1049,7 +1128,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read type of value int32_t s_type_i_ref; - io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + io.read(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; if (s_type_i != s_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); @@ -1058,7 +1137,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read element size of value uint32_t s_size_el_ref; - io.read_to(&s_size_el_ref, sizeof(s_size_el_ref)); + io.read(&s_size_el_ref, sizeof(s_size_el_ref)); const size_t s_size_el = ggml_type_size(s_l[il]->type); if (s_size_el != s_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched s element size (%zu != %zu, layer %d)\n", __func__, s_size_el, (size_t) s_size_el_ref, il); @@ -1067,7 +1146,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // Read state embedding size uint32_t n_embd_s_ref; - io.read_to(&n_embd_s_ref, sizeof(n_embd_s_ref)); + io.read(&n_embd_s_ref, sizeof(n_embd_s_ref)); if (n_embd_s != n_embd_s_ref) { LLAMA_LOG_ERROR("%s: mismatched s embedding size (%u != %u, layer %d)\n", __func__, n_embd_s, n_embd_s_ref, il); return false; @@ -1077,7 +1156,7 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_s; ++j) { const size_t dst_offset = (head + j * size) * s_size_el; - ggml_backend_tensor_set(s_l[il], io.read(cell_count * s_size_el), dst_offset, cell_count * s_size_el); + io.read_tensor(s_l[il], dst_offset, cell_count * s_size_el); } } } @@ -1163,5 +1242,21 @@ ggml_tensor * llama_memory_recurrent_context::get_s_l(int32_t il) const { } int32_t llama_memory_recurrent_context::s_copy(int i) const { - return mem->cells[i + mem->head].src0; + const uint32_t cell_idx = i + mem->head; + const int32_t src0 = mem->cells[cell_idx].src0; + + if (mem->n_rs_seq == 0) { + return src0; + } + + uint32_t idx = 0; + if (!mem->cells[cell_idx].seq_id.empty()) { + const llama_seq_id seq = *mem->cells[cell_idx].seq_id.begin(); + if (seq >= 0 && (size_t) seq < mem->rs_idx.size()) { + idx = mem->rs_idx[seq]; + // reset rollback idx + mem->rs_idx[seq] = 0; + } + } + return (int32_t)(idx * mem->size) + src0; } diff --git a/examples/talk-llama/llama-memory-recurrent.h b/examples/talk-llama/llama-memory-recurrent.h index 47f01d73912..b13b7b748f5 100644 --- a/examples/talk-llama/llama-memory-recurrent.h +++ b/examples/talk-llama/llama-memory-recurrent.h @@ -23,6 +23,7 @@ class llama_memory_recurrent : public llama_memory_i { bool offload, uint32_t mem_size, uint32_t n_seq_max, + uint32_t n_rs_seq, const layer_filter_cb & filter); ~llama_memory_recurrent() = default; @@ -69,6 +70,14 @@ class llama_memory_recurrent : public llama_memory_i { uint32_t size = 0; // total number of cells, shared across all sequences uint32_t used = 0; // used cells (i.e. at least one seq_id) + // number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups + uint32_t n_rs_seq = 0; + + // per-seq rollback index + std::vector<uint32_t> rs_idx; + + void set_rs_idx(llama_seq_id seq_id, uint32_t idx); + // computed before each graph build uint32_t n = 0; diff --git a/examples/talk-llama/llama-memory.h b/examples/talk-llama/llama-memory.h index 4a157b91fdb..db825396645 100644 --- a/examples/talk-llama/llama-memory.h +++ b/examples/talk-llama/llama-memory.h @@ -1,6 +1,7 @@ #pragma once #include "llama.h" +#include "llama-graph.h" #include <map> #include <memory> @@ -20,6 +21,10 @@ struct llama_memory_params { // use full-size SWA cache bool swa_full; + + llama_context_type ctx_type; + + llama_memory_t mem_other; }; enum llama_memory_status { @@ -73,6 +78,8 @@ struct llama_memory_i { // return negative value to indicate that the layer il should not reuse memory using layer_reuse_cb = std::function<int32_t(int32_t il)>; + using layer_share_cb = std::function<int32_t(int32_t il)>; + virtual ~llama_memory_i() = default; // split the input batch into a set of ubatches and verify that they can fit into the cache diff --git a/examples/talk-llama/llama-mmap.cpp b/examples/talk-llama/llama-mmap.cpp index 2da857b3aae..ed572da7fb5 100644 --- a/examples/talk-llama/llama-mmap.cpp +++ b/examples/talk-llama/llama-mmap.cpp @@ -40,6 +40,14 @@ #include <TargetConditionals.h> #endif +#ifdef _WIN32 +# define llama_mmap_ftell _ftelli64 +# define llama_mmap_fseek _fseeki64 +#else +# define llama_mmap_ftell ftello +# define llama_mmap_fseek fseeko +#endif + // TODO: consider moving to llama-impl.h if needed in more places #if defined(_WIN32) static std::string llama_format_win_err(DWORD err) { @@ -86,6 +94,14 @@ struct llama_file::impl { seek(0, SEEK_SET); } + impl(FILE * file) : owns_fp(false) { + fp = file; + fp_win32 = (HANDLE) _get_osfhandle(_fileno(fp)); + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + size_t tell() const { LARGE_INTEGER li; li.QuadPart = 0; @@ -159,7 +175,7 @@ struct llama_file::impl { } ~impl() { - if (fp) { + if (fp && owns_fp) { std::fclose(fp); } } @@ -209,9 +225,16 @@ struct llama_file::impl { seek(0, SEEK_SET); } + impl(FILE * file) : fname("(file*)"), owns_fp(false) { + fp = file; + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + size_t tell() const { if (fd == -1) { - long ret = std::ftell(fp); + off_t ret = llama_mmap_ftell(fp); if (ret == -1) { throw std::runtime_error(format("ftell error: %s", strerror(errno))); } @@ -229,7 +252,7 @@ struct llama_file::impl { void seek(size_t offset, int whence) const { off_t ret = 0; if (fd == -1) { - ret = std::fseek(fp, (long) offset, whence); + ret = llama_mmap_fseek(fp, offset, whence); } else { ret = lseek(fd, offset, whence); } @@ -244,11 +267,14 @@ struct llama_file::impl { } errno = 0; if (fd == -1) { - std::size_t ret = std::fread(ptr, len, 1, fp); + const size_t curr_off = tell(); + const size_t to_read = std::min(len, size - curr_off); + + std::size_t ret = std::fread(ptr, to_read, 1, fp); if (ferror(fp)) { throw std::runtime_error(format("read error: %s", strerror(errno))); } - if (ret != 1) { + if (to_read > 0 && ret != 1) { throw std::runtime_error("unexpectedly reached end of file"); } } else { @@ -262,7 +288,8 @@ struct llama_file::impl { continue; // Interrupted by signal, retry } // Fallback to std::fread in case the DMA controller cannot access the buffer - if (errno == EFAULT) { + if (errno == EFAULT || errno == EINVAL) { + LLAMA_LOG_WARN("%s: Falling back to buffered IO due to %s\n", __func__, strerror(errno)); auto curr_off = tell(); close(fd); fd = -1; @@ -349,7 +376,7 @@ struct llama_file::impl { ~impl() { if (fd != -1) { close(fd); - } else { + } else if (owns_fp) { std::fclose(fp); } } @@ -365,10 +392,14 @@ struct llama_file::impl { FILE * fp{}; size_t size{}; + bool owns_fp = true; }; llama_file::llama_file(const char * fname, const char * mode, const bool use_direct_io) : pimpl(std::make_unique<impl>(fname, mode, use_direct_io)) {} + +llama_file::llama_file(FILE * file) : pimpl(std::make_unique<impl>(file)) {} + llama_file::~llama_file() = default; size_t llama_file::tell() const { return pimpl->tell(); } @@ -381,6 +412,9 @@ int llama_file::file_id() const { #ifdef _WIN32 return _fileno(pimpl->fp); #else + if (pimpl->fd != -1) { + return pimpl->fd; + } #if defined(fileno) return fileno(pimpl->fp); #else @@ -497,6 +531,8 @@ struct llama_mmap::impl { } } #elif defined(_WIN32) + HANDLE hMapping = nullptr; + impl(struct llama_file * file, size_t prefetch, bool numa) { GGML_UNUSED(numa); @@ -504,7 +540,7 @@ struct llama_mmap::impl { HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); - HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); if (hMapping == NULL) { DWORD error = GetLastError(); @@ -513,9 +549,9 @@ struct llama_mmap::impl { addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); DWORD error = GetLastError(); - CloseHandle(hMapping); if (addr == NULL) { + CloseHandle(hMapping); throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); } @@ -547,9 +583,17 @@ struct llama_mmap::impl { } ~impl() { - if (!UnmapViewOfFile(addr)) { - LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); + if (hMapping) { + if (addr) { + if (!UnmapViewOfFile(addr)) { + LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } + if (!CloseHandle(hMapping)) { + LLAMA_LOG_WARN("warning: CloseHandle failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } } } #else @@ -611,9 +655,9 @@ struct llama_mlock::impl { char* errmsg = std::strerror(errno); bool suggest = (errno == ENOMEM); -#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) - // visionOS/tvOS dont't support RLIMIT_MEMLOCK - // Skip resource limit checks on visionOS/tvOS +#if defined(TARGET_OS_VISION) || defined(TARGET_OS_TV) || defined(_AIX) || defined(__HAIKU__) + // visionOS/tvOS/Haiku don't support RLIMIT_MEMLOCK + // Skip resource limit checks on these platforms suggest = false; #else struct rlimit lock_limit; diff --git a/examples/talk-llama/llama-mmap.h b/examples/talk-llama/llama-mmap.h index 29ce4d24685..b7d5c61e95f 100644 --- a/examples/talk-llama/llama-mmap.h +++ b/examples/talk-llama/llama-mmap.h @@ -15,6 +15,7 @@ using llama_mlocks = std::vector<std::unique_ptr<llama_mlock>>; struct llama_file { llama_file(const char * fname, const char * mode, bool use_direct_io = false); + llama_file(FILE * file); ~llama_file(); size_t tell() const; diff --git a/examples/talk-llama/llama-model-loader.cpp b/examples/talk-llama/llama-model-loader.cpp index e66febaa021..474cabdfc09 100644 --- a/examples/talk-llama/llama-model-loader.cpp +++ b/examples/talk-llama/llama-model-loader.cpp @@ -1,11 +1,17 @@ #include "llama-model-loader.h" +#include "ggml-alloc.h" #include "ggml.h" +#include "gguf.h" +#include "llama-hparams.h" +#include <algorithm> #include <array> #include <cinttypes> +#include <cstdint> #include <cstring> #include <future> +#include <regex> static const size_t kiB = 1024; static const size_t MiB = 1024*kiB; @@ -30,12 +36,14 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; + case LLAMA_FTYPE_MOSTLY_Q1_0: return "Q1_0"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE"; + case LLAMA_FTYPE_MOSTLY_NVFP4: return "NVFP4"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; @@ -138,7 +146,7 @@ namespace GGUFMeta { const enum gguf_type arr_type = gguf_get_arr_type(ctx, k); return ArrayInfo { arr_type, - size_t(gguf_get_arr_n(ctx, k)), + gguf_get_arr_n(ctx, k), arr_type == GGUF_TYPE_STRING ? nullptr : gguf_get_arr_data(ctx, k), }; } @@ -262,7 +270,7 @@ namespace GGUFMeta { template<typename T> typename std::enable_if<std::is_integral<T>::value, bool>::type llama_model_loader::get_arr_n(const std::string & key, T & result, bool required) { - const int kid = gguf_find_key(meta.get(), key.c_str()); + const int kid = gguf_find_key(metadata, key.c_str()); if (kid < 0) { if (required) { @@ -272,7 +280,7 @@ namespace GGUFMeta { } struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid); + GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(metadata, kid); result = arr_info.length; @@ -289,7 +297,7 @@ namespace GGUFMeta { template<typename T> bool llama_model_loader::get_arr(const std::string & key, std::vector<T> & result, bool required) { - const gguf_context * ctx = meta.get(); + const gguf_context * ctx = metadata; const int kid = gguf_find_key(ctx, key.c_str()); if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { @@ -330,7 +338,7 @@ namespace GGUFMeta { template<typename T, size_t N_MAX> bool llama_model_loader::get_arr(const std::string & key, std::array<T, N_MAX> & result, bool required) { - const gguf_context * ctx = meta.get(); + const gguf_context * ctx = metadata; const int kid = gguf_find_key(ctx, key.c_str()); if (kid < 0 || gguf_get_kv_type(ctx, kid) != GGUF_TYPE_ARRAY) { @@ -344,6 +352,7 @@ namespace GGUFMeta { GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx, kid); switch (arr_info.gt) { + case GGUF_TYPE_BOOL: case GGUF_TYPE_UINT32: case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value) || (std::is_same<T, uint32_t>::value)); break; @@ -365,7 +374,14 @@ namespace GGUFMeta { result[i] = value; } } else { - std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + if (arr_info.gt == GGUF_TYPE_BOOL) { + const int8_t * values = (const int8_t *) arr_info.data; + std::transform(values, values + arr_info.length, result.begin(), [](int8_t x) { + return static_cast<T>(x != 0); + }); + } else { + std::copy((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length, result.begin()); + } } return true; @@ -377,6 +393,8 @@ namespace GGUFMeta { } template bool llama_model_loader::get_arr<std::vector<std::string>>(enum llm_kv kid, std::vector<std::string> & result, bool required); + template bool llama_model_loader::get_arr<std::array<int32_t, 512>>(enum llm_kv kid, std::array<int32_t, 512> & result, bool required); + template bool llama_model_loader::get_arr<std::vector<int32_t>>(enum llm_kv kid, std::vector<int32_t> & result, bool required); template<typename T> bool llama_model_loader::get_key(const std::string & key, T & result, bool required) { @@ -385,7 +403,7 @@ namespace GGUFMeta { const struct llama_model_kv_override * override = it != kv_overrides.end() ? &it->second : nullptr; - const bool found = GGUFMeta::GKV<T>::set(meta.get(), key, result, override); + const bool found = GGUFMeta::GKV<T>::set(metadata, key, result, override); if (required && !found) { throw std::runtime_error(format("key not found in model: %s", key.c_str())); @@ -419,7 +437,7 @@ namespace GGUFMeta { // get array of n <= N_MAX elements, or a single element repeated n times template<typename T, size_t N_MAX> bool llama_model_loader::get_key_or_arr(const std::string & key, std::array<T, N_MAX> & result, uint32_t n, bool required) { - const int kid = gguf_find_key(meta.get(), key.c_str()); + const int kid = gguf_find_key(metadata, key.c_str()); if (kid < 0) { if (required) { @@ -429,12 +447,12 @@ namespace GGUFMeta { } if (n > N_MAX) { - throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", (uint32_t) n, (uint32_t) N_MAX, key.c_str())); + throw std::runtime_error(format("n > N_MAX: %u > %u for key %s", n, (uint32_t) N_MAX, key.c_str())); } - if (gguf_get_kv_type(meta.get(), kid) == GGUF_TYPE_ARRAY) { + if (gguf_get_kv_type(metadata, kid) == GGUF_TYPE_ARRAY) { struct GGUFMeta::ArrayInfo arr_info = - GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta.get(), kid); + GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(metadata, kid); if (n != arr_info.length) { throw std::runtime_error(format("key %s has wrong array length; expected %u, got %u", key.c_str(), n, (uint32_t) arr_info.length)); @@ -465,7 +483,7 @@ namespace GGUFMeta { bool llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) { const std::string key = llm_kv(kid); - const int id = gguf_find_key(meta.get(), key.c_str()); + const int id = gguf_find_key(metadata, key.c_str()); if (id < 0) { if (required) { @@ -475,7 +493,7 @@ namespace GGUFMeta { } // throw and error if type is an array - if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) { + if (gguf_get_kv_type(metadata, id) == GGUF_TYPE_ARRAY) { if (required) { throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str())); } @@ -486,20 +504,25 @@ namespace GGUFMeta { } // TODO: this is not very clever - figure out something better - template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr<std::array<int, 4>> (enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required); - template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr<std::array<float, 512>>(enum llm_kv kid, std::array<float, 512> & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( + struct gguf_context * meta, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, const std::string & fname, std::vector<std::string> & splits, + FILE * file, bool use_mmap, bool use_direct_io, bool check_tensors, bool no_alloc, const llama_model_kv_override * param_overrides_p, - const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) { + const llama_model_tensor_buft_override * param_tensor_buft_overrides_p) + : metadata(meta), set_tensor_data(set_tensor_data), set_tensor_data_ud(set_tensor_data_ud) { int trace = 0; if (getenv("LLAMA_TRACE")) { trace = atoi(getenv("LLAMA_TRACE")); @@ -513,133 +536,175 @@ llama_model_loader::llama_model_loader( tensor_buft_overrides = param_tensor_buft_overrides_p; - // Load the main GGUF - struct ggml_context * ctx = NULL; - struct gguf_init_params params = { - /*.no_alloc = */ true, - /*.ctx = */ &ctx, - }; - - meta.reset(gguf_init_from_file(fname.c_str(), params)); - if (!meta) { - throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str())); - } - - get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); - llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + if (!fname.empty()) { + // Load the main GGUF + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; - files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); - contexts.emplace_back(ctx); + metadata_ptr.reset(gguf_init_from_file(fname.c_str(), params)); + metadata = metadata_ptr.get(); + if (metadata == nullptr) { + throw std::runtime_error(format("%s: failed to load model from %s", __func__, fname.c_str())); + } - use_direct_io = use_direct_io && files.back()->has_direct_io(); + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); - // Disable mmap in case Direct I/O is enabled and available - if (use_direct_io && use_mmap) { - use_mmap = false; - LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); - } + files.emplace_back(new llama_file(fname.c_str(), "rb", use_direct_io)); + contexts.emplace_back(ctx); - // Save tensors data offset of the main file. - // For subsidiary files, `meta` tensor data offset must not be used, - // so we build a unified tensors index for weights. - for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { - std::string tensor_name = std::string(cur->name); - // make sure there is no duplicated tensor names - if (weights_map.find(tensor_name) != weights_map.end()) { - throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); - } - n_elements += ggml_nelements(cur); - n_bytes += ggml_nbytes(cur); - weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, meta.get(), cur)); - } - uint16_t n_split = 0; - get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); - - // Load additional GGML contexts - if (n_split > 1) { - // make sure the main file is loaded first - uint16_t idx = 0; - const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); - get_key(kv_split_no, idx); - if (idx != 0) { - throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); - } + if (use_mmap && use_direct_io) { + if (files.back()->has_direct_io()) { + LLAMA_LOG_WARN("%s: direct I/O is enabled, disabling mmap\n", __func__); + use_mmap = false; + } else { + LLAMA_LOG_WARN("%s: direct I/O is not available, using mmap\n", __func__); + use_direct_io = false; - // generate list of splits if needed - if (splits.empty()) { - splits = llama_get_list_splits(fname, idx, n_split); + // reopen file using std::fopen for mmap + files.pop_back(); + files.emplace_back(new llama_file(fname.c_str(), "rb", false)); + } } - // in case user give a custom list of splits, check if it matches the expected number - if (n_split != (uint16_t)splits.size()) { - throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); - } + // Save tensors data offset of the main file. + // For subsidiary files, `meta` tensor data offset must not be used, + // so we build a unified tensors index for weights. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, metadata, cur)); + } + uint16_t n_split = 0; + get_key(llm_kv(LLM_KV_SPLIT_COUNT), n_split, false); + + // Load additional GGML contexts + if (n_split > 1) { + // make sure the main file is loaded first + uint16_t idx = 0; + const std::string kv_split_no = llm_kv(LLM_KV_SPLIT_NO); + get_key(kv_split_no, idx); + if (idx != 0) { + throw std::runtime_error(format("illegal split file idx: %d (file: %s), model must be loaded with the first split", idx, fname.c_str())); + } - if (trace > 0) { - LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); - } + // generate list of splits if needed + if (splits.empty()) { + splits = llama_get_list_splits(fname, idx, n_split); + } - // load other splits - for (idx = 1; idx < n_split; idx++) { - const char * fname_split = splits[idx].c_str(); + // in case user give a custom list of splits, check if it matches the expected number + if (n_split != (uint16_t)splits.size()) { + throw std::runtime_error(format("invalid split count, given: %zu splits, but expected %d", splits.size(), n_split)); + } - struct gguf_init_params split_params = { - /*.no_alloc = */ true, - /*.ctx = */ &ctx, - }; - gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; - if (!ctx_gguf) { - throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split)); + if (trace > 0) { + LLAMA_LOG_INFO("%s: loading additional %d GGUFs\n", __func__, n_split); } - // check idx - { - const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); - if (kid < 0) { - throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); + // load other splits + for (idx = 1; idx < n_split; idx++) { + const char * fname_split = splits[idx].c_str(); + + struct gguf_init_params split_params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; + gguf_context_ptr ctx_gguf { gguf_init_from_file(fname_split, split_params) }; + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load GGUF split from %s", __func__, fname_split)); } - int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); - if (idx_gguf != idx) { - throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + + // check idx + { + const int kid = gguf_find_key(ctx_gguf.get(), kv_split_no.c_str()); + if (kid < 0) { + throw std::runtime_error(format("missing key %s in GGUF split %s", kv_split_no.c_str(), fname_split)); + } + int idx_gguf = gguf_get_val_u16(ctx_gguf.get(), kid); + if (idx_gguf != idx) { + throw std::runtime_error(format("invalid split file idx: %d (file: %s), expected %d", idx_gguf, fname_split, idx)); + } + } + + files.emplace_back(new llama_file(fname_split, "rb", use_direct_io)); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the shard. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur)); } } - files.emplace_back(new llama_file(fname_split, "rb", use_direct_io)); - contexts.emplace_back(ctx); + get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); - // Save tensors data offset info of the shard. - for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { - std::string tensor_name = std::string(cur->name); - // make sure there is no duplicated tensor names - if (weights_map.find(tensor_name) != weights_map.end()) { - throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); + // sanity check + { + const int n_tensors_loaded = (int) weights_map.size(); + if (n_tensors != n_tensors_loaded) { + throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); } - n_elements += ggml_nelements(cur); - n_bytes += ggml_nbytes(cur); - weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), idx, ctx_gguf.get(), cur)); } + + LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); } + } else if (file != nullptr) { + struct ggml_context * ctx = NULL; + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx, + }; - get_key(llm_kv(LLM_KV_SPLIT_TENSORS_COUNT), n_tensors); + metadata_ptr.reset(gguf_init_from_file_ptr(file, params)); + metadata = metadata_ptr.get(); + if (metadata == nullptr) { + throw std::runtime_error(format("%s: failed to load model from file pointer", __func__)); + } - // sanity check - { - const int n_tensors_loaded = (int) weights_map.size(); - if (n_tensors != n_tensors_loaded) { - throw std::runtime_error(format("corrupted model: %d tensors expected but %d found", n_tensors, n_tensors_loaded)); + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); + + files.emplace_back(new llama_file(file)); + contexts.emplace_back(ctx); + + // Save tensors data offset info of the main file. + for (ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) { + std::string tensor_name = std::string(cur->name); + // make sure there is no duplicated tensor names + if (weights_map.find(tensor_name) != weights_map.end()) { + throw std::runtime_error(format("invalid model: tensor '%s' is duplicated", ggml_get_name(cur))); } + n_elements += ggml_nelements(cur); + n_bytes += ggml_nbytes(cur); + weights_map.emplace(tensor_name, llama_tensor_weight(files.back().get(), 0, metadata, cur)); } - - LLAMA_LOG_INFO("%s: additional %d GGUFs metadata loaded.\n", __func__, n_split - 1); + } else { + get_key(llm_kv(LLM_KV_GENERAL_ARCHITECTURE), arch_name, false); + llm_kv = LLM_KV(llm_arch_from_string(arch_name)); } - n_kv = gguf_get_n_kv(meta.get()); + n_kv = gguf_get_n_kv(metadata); n_tensors = weights_map.size(); - fver = (enum llama_fver) gguf_get_version(meta.get()); + fver = (enum llama_fver) gguf_get_version(metadata); LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", - __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + __func__, n_kv, n_tensors, fname.empty() ? "(file*)" : fname.c_str(), llama_file_version_name(fver)); // determine file type based on the number of tensors for each quantization and print meta data // TODO: make optional @@ -695,6 +760,8 @@ llama_model_loader::llama_model_loader( case GGML_TYPE_IQ4_NL: ftype = LLAMA_FTYPE_MOSTLY_IQ4_NL; break; case GGML_TYPE_IQ4_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ4_XS; break; case GGML_TYPE_IQ3_S: ftype = LLAMA_FTYPE_MOSTLY_IQ3_S; break; + case GGML_TYPE_NVFP4: ftype = LLAMA_FTYPE_MOSTLY_NVFP4; break; + case GGML_TYPE_Q1_0: ftype = LLAMA_FTYPE_MOSTLY_Q1_0; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); @@ -715,14 +782,14 @@ llama_model_loader::llama_model_loader( LLAMA_LOG_INFO("%s: Dumping metadata keys/values. Note: KV overrides do not apply in this output.\n", __func__); for (int i = 0; i < n_kv; i++) { - const char * name = gguf_get_key(meta.get(), i); - const enum gguf_type type = gguf_get_kv_type(meta.get(), i); + const char * name = gguf_get_key(metadata, i); + const enum gguf_type type = gguf_get_kv_type(metadata, i); const std::string type_name = type == GGUF_TYPE_ARRAY - ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(meta.get(), i)), gguf_get_arr_n(meta.get(), i)) + ? format("%s[%s,%zu]", gguf_type_name(type), gguf_type_name(gguf_get_arr_type(metadata, i)), gguf_get_arr_n(metadata, i)) : gguf_type_name(type); - std::string value = gguf_kv_to_str(meta.get(), i); + std::string value = gguf_kv_to_str(metadata, i); const size_t MAX_VALUE_LEN = 40; if (value.size() > MAX_VALUE_LEN) { value = format("%s...", value.substr(0, MAX_VALUE_LEN - 3).c_str()); @@ -824,15 +891,388 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri return cur; } -struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags) { - LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str()); - const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED)); +// checks if the weight tensor can be used with the specified buffer type and device +static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + GGML_ASSERT(w != nullptr); + + if (op == GGML_OP_NONE) { + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + switch (op) { + case GGML_OP_GET_ROWS: + { + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_get_rows(ctx, w, b); + } break; + case GGML_OP_MUL_MAT: + { + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } break; + case GGML_OP_MUL_MAT_ID: + { + const int n_expert_used = hparams.n_expert_used; + GGML_ASSERT(n_expert_used > 0); + ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_mul_mat_id(ctx, w, b, ids); + } break; + case GGML_OP_ADD: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_add(ctx, a, w); + } break; + case GGML_OP_ADD_ID: + { + const int n_expert_used = hparams.n_expert_used; + GGML_ASSERT(n_expert_used > 0); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); + ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); + op_tensor = ggml_add_id(ctx, a, w, c); + } break; + case GGML_OP_MUL: + { + ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); + op_tensor = ggml_mul(ctx, a, w); + } break; + case GGML_OP_DIV: + { + ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); + op_tensor = ggml_div(ctx, a, w); + } break; + case GGML_OP_ROPE: + { + const int n_embd_head = hparams.n_embd_head_v(); + const int n_head = hparams.n_head(); + ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); + ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); + op_tensor = ggml_rope_ext( + ctx, a, b, w, + 0, 0, 0, 0, 0, + 0, 0, 0, 0 + ); + + } break; + case GGML_OP_SSM_CONV: + { + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); + op_tensor = ggml_ssm_conv(ctx, conv_x, w); + } break; + case GGML_OP_SSM_SCAN: + { + // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2 + const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1; + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); + } break; + case GGML_OP_RWKV_WKV6: + { + // FIXME + const int64_t S = 123; + const int64_t H = 123; + const int64_t n_tokens = 123; + const int64_t n_seqs = 123; + ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * tf = w; + ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); + ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); + op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); + } break; + case GGML_OP_IM2COL: + { + const int n_embd_inp = hparams.n_embd_inp(); + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_inp, w->ne[1], 1, 1); + op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); + } break; + case GGML_OP_SCALE: + { + op_tensor = ggml_scale(ctx, w, 1.0f); + } break; + default: + GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + + return op_supported; +} + +// find the first buffer type in the list that can use the tensor +static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t * buft_list) { + GGML_ASSERT(!buft_list->empty()); + for (const auto & cur : *buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { + return cur_buft; + } + } + + return nullptr; +} + +struct ggml_tensor * llama_model_loader::create_tensor( + const llama_hparams & hparams, const buft_list_t * buft_list_cpu, const buft_list_t * buft_list_input, const buft_list_t * buft_list_output, + const buft_list_t * buft_list_layer, const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) { + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + // one ggml context per buffer type + int max_n_tensors = n_tensors; + max_n_tensors += 1; // duplicated output tensor + max_n_tensors += hparams.n_layer()*2; // duplicated rope freq tensors + if (files.empty()) { + max_n_tensors += hparams.n_layer()*256; // this should be well above what any model actually uses + } + const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; + + ggml_init_params params = { + /*.mem_size =*/ ctx_size, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ctx_map.emplace(buft, ctx); + + return ctx; + } + return it->second.get(); + }; + + auto buft_for_tensor = [&](ggml_tensor * t_meta) -> ggml_backend_buffer_type_t { + if (!t_meta) { + if (flags & TENSOR_NOT_REQUIRED) { + return nullptr; + } + throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); + } + + // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops + // the tensor is duplicated + // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor + llm_tensor tn_tensor = tn.tensor; + if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && (flags & TENSOR_DUPLICATED)) { + tn_tensor = LLM_TENSOR_OUTPUT; + } + + llm_tensor_info info; + try { + info = llm_tensor_info_for(tn_tensor); + } catch (const std::out_of_range & e) { + throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); + } + + // skip unused tensors + if (info.op == GGML_OP_NONE || (flags & TENSOR_SKIP)) { + const size_t nbytes = ggml_nbytes(t_meta); + LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); + + size_data -= nbytes; + n_created++; + + return nullptr; + } + + // tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID + ggml_op op; + bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; + if (bias) { + if (info.op == GGML_OP_MUL_MAT_ID) { + op = GGML_OP_ADD_ID; + } else { + op = GGML_OP_ADD; + } + } else { + op = info.op; + } + + // sanity checks + if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { + if (tn.bid != -1) { + GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); + } + } else { + if (tn.bid == -1) { + GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str()); + } + } + + // select the buffer type for this tensor + const buft_list_t * buft_list; + switch (info.layer) { + case LLM_TENSOR_LAYER_INPUT: + buft_list = buft_list_input; + break; + case LLM_TENSOR_LAYER_OUTPUT: + buft_list = buft_list_output; + break; + case LLM_TENSOR_LAYER_REPEATING: + GGML_ASSERT(buft_list_layer != nullptr); + buft_list = buft_list_layer; + break; + default: + GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); + } + + ggml_backend_buffer_type_t buft = nullptr; + + // check overrides + if (tensor_buft_overrides) { + std::string tensor_name = tn.str(); + for (const auto * overrides = tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { + std::regex pattern(overrides->pattern); + if (std::regex_search(tensor_name, pattern)) { + if (overrides->buft == ggml_backend_cpu_buffer_type()) { + // when overriding to a CPU buffer, consider the extra buffer types + buft = select_weight_buft(hparams, t_meta, op, buft_list_cpu); + if (use_mmap) { + static std::once_flag once; + std::call_once(once, [] { + LLAMA_LOG_WARN("llama_model_loader: tensor overrides to CPU are used with mmap enabled - consider using --no-mmap for better performance\n"); + }); + } + } else { + buft = overrides->buft; + } + + LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", + tensor_name.c_str(), + ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), + ggml_backend_buft_name(buft)); + break; + } + } + } + + if (!buft) { + buft = select_weight_buft(hparams, t_meta, op, buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); + } + } + + // avoid using a host buffer when using mmap + auto * buft_dev = ggml_backend_buft_get_device(buft); + if (use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!cpu_dev) { + throw std::runtime_error("no CPU backend found"); + } + buft = ggml_backend_dev_buffer_type(cpu_dev); + } + + if (buft != buft_list->front().second) { + if (n_tensors_moved == 0) { + first_tensor_moved_name = t_meta->name; + first_tensor_moved_type_name = ggml_type_name(t_meta->type); + first_moved_from_buft = buft_list->front().second; + first_moved_to_buft = buft; + } + n_tensors_moved++; + } + + return buft; + }; + + if (files.empty()) { + if (flags & TENSOR_SKIP_IF_VIRTUAL) { + return nullptr; + } + ggml_type type = GGML_TYPE_F32; + const int64_t tid = gguf_find_tensor(metadata, tn.str().c_str()); + if (tid != -1) { + type = gguf_get_tensor_type(metadata, tid); + } + + // for tensors that are not required some of the dimensions can be invalid: + if (flags & TENSOR_NOT_REQUIRED) { + for (size_t dim = 0; dim < ne.size(); dim++) { + if (ne.begin()[dim] <= 0) { + return nullptr; + } + } + } + + ggml_tensor t_meta; + memset(&t_meta, 0, sizeof(ggml_tensor)); + t_meta.type = type; + for (size_t dim = 0; dim < GGML_MAX_DIMS; dim++) { + t_meta.ne[dim] = dim < ne.size() ? ne.begin()[dim] : 1; + GGML_ASSERT(t_meta.ne[dim] >= 1); + t_meta.nb[dim] = dim == 0 ? ggml_type_size(type) : t_meta.ne[dim-1]*t_meta.nb[dim-1]; + GGML_ASSERT(t_meta.nb[dim] >= 1); + } + ggml_set_name(&t_meta, tn.str().c_str()); + + ggml_backend_buffer_type_t buft = buft_for_tensor(&t_meta); + GGML_ASSERT(buft != nullptr); + ggml_context * ctx = ctx_for_buft(buft); + ggml_tensor * ret = ggml_dup_tensor(ctx, &t_meta); + ggml_set_name(ret, tn.str().c_str()); + return ret; + } + + ggml_tensor * t_meta = get_tensor_meta(tn.str().c_str()); + ggml_backend_buffer_type_t buft = buft_for_tensor(t_meta); + if (buft == nullptr) { + return nullptr; // return type is ggml_tensor * + } + ggml_context * ctx = ctx_for_buft(buft); + + // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one + if (flags & TENSOR_DUPLICATED) { + ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str()); + if (t) { + return t; + } + } + + LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, tn.str().c_str()); + const struct ggml_tensor * cur = check_tensor_dims(tn.str(), ne, !(flags & TENSOR_NOT_REQUIRED)); if (cur == NULL) { return NULL; } - bool duplicated = flags & TENSOR_DUPLICATED; + const bool duplicated = flags & TENSOR_DUPLICATED; struct ggml_tensor * tensor = ggml_dup_tensor(ctx, cur); ggml_set_name(tensor, ggml_get_name(cur)); @@ -844,7 +1284,6 @@ struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx } return tensor; - } struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required) { @@ -875,9 +1314,21 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte return tensor; } -void llama_model_loader::done_getting_tensors() const { - if (n_created != n_tensors) { - throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); +void llama_model_loader::done_getting_tensors(bool partial) const { + if (n_created > n_tensors) { + throw std::runtime_error(format("%s: too many tensors created; expected %d, got %d", __func__, n_tensors, n_created)); + } + if (n_created < n_tensors) { + if (!partial) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + LLAMA_LOG_INFO("%s: partial load — used %d of %d tensors in the file (rest belong to a sibling model on the same .gguf)\n", + __func__, n_created, n_tensors); + } + if (n_tensors_moved > 0) { + LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %zu others) cannot be used with preferred buffer type %s, using %s instead\n", + __func__, first_tensor_moved_name.c_str(), first_tensor_moved_type_name.c_str(), n_tensors_moved - 1, + ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft)); } } @@ -960,6 +1411,12 @@ bool llama_model_loader::load_all_data( llama_mlocks * lmlocks, llama_progress_callback progress_callback, void * progress_callback_user_data) { + if (files.empty()) { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + set_tensor_data(t, set_tensor_data_ud); + } + return true; + } GGML_ASSERT(size_data != 0 && "call init_mappings() first"); std::vector<no_init<uint8_t>> read_buf; diff --git a/examples/talk-llama/llama-model-loader.h b/examples/talk-llama/llama-model-loader.h index 65953dd3d5a..c476026d3e5 100644 --- a/examples/talk-llama/llama-model-loader.h +++ b/examples/talk-llama/llama-model-loader.h @@ -4,17 +4,22 @@ #include "llama-impl.h" #include "llama-arch.h" +#include "llama-hparams.h" #include "llama-mmap.h" #include "ggml-cpp.h" #include <cstddef> +#include <cstring> #include <map> #include <stdexcept> #include <unordered_map> using llama_buf_map = std::unordered_map<uint32_t, ggml_backend_buffer_t>; +// lists of buffer types used for each layer +using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>; + enum llama_fver { GGUF_FILE_VERSION_V1 = 1, GGUF_FILE_VERSION_V2 = 2, @@ -58,9 +63,10 @@ struct llama_model_loader { } }; - static const int TENSOR_NOT_REQUIRED = 1 << 0; - static const int TENSOR_DUPLICATED = 1 << 1; - static const int TENSOR_SKIP = 1 << 2; + static const int TENSOR_NOT_REQUIRED = 1 << 0; + static const int TENSOR_DUPLICATED = 1 << 1; + static const int TENSOR_SKIP = 1 << 2; + static const int TENSOR_SKIP_IF_VIRTUAL = 1 << 3; int n_kv = 0; int n_tensors = 0; @@ -84,7 +90,10 @@ struct llama_model_loader { std::unordered_map<std::string, llama_model_kv_override> kv_overrides; const llama_model_tensor_buft_override * tensor_buft_overrides; - gguf_context_ptr meta; + gguf_context_ptr metadata_ptr; + struct gguf_context * metadata; // either metadata_ptr.get() or externally set + llama_model_set_tensor_data_t set_tensor_data; + void * set_tensor_data_ud; std::vector<ggml_context_ptr> contexts; std::string arch_name; @@ -94,9 +103,29 @@ struct llama_model_loader { size_t size_data = 0; std::vector<std::pair<size_t, size_t>> mmaps_used; + // define a comparator for the buft -> ctx map to ensure that the order is well-defined: + struct ggml_backend_buft_comparator { + bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { + return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; + } + }; + + std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map; + + // track tensors that had to be moved for debugging: + size_t n_tensors_moved = 0; + std::string first_tensor_moved_name; + std::string first_tensor_moved_type_name; + ggml_backend_buffer_type_t first_moved_from_buft = nullptr; + ggml_backend_buffer_type_t first_moved_to_buft = nullptr; + llama_model_loader( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, const std::string & fname, std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme + FILE * file, bool use_mmap, bool use_direct_io, bool check_tensors, @@ -149,11 +178,13 @@ struct llama_model_loader { const struct ggml_tensor * check_tensor_dims(const std::string & name, const std::vector<int64_t> & ne, bool required) const; - struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags = 0); + struct ggml_tensor * create_tensor( + const llama_hparams & hparams, const buft_list_t * buft_list_cpu, const buft_list_t * buft_list_input, const buft_list_t * buft_list_output, + const buft_list_t * buft_list_layer, const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags); struct ggml_tensor * create_tensor_as_view(struct ggml_context * ctx, struct ggml_tensor * base, const std::string & name, const std::initializer_list<int64_t> & ne, size_t offset, bool required = true); - void done_getting_tensors() const; + void done_getting_tensors(bool partial = false) const; void init_mappings(bool prefetch = true, llama_mlocks * mlock_mmaps = nullptr); diff --git a/examples/talk-llama/llama-model-saver.cpp b/examples/talk-llama/llama-model-saver.cpp index ae27c71ce23..67d4a9df0f0 100644 --- a/examples/talk-llama/llama-model-saver.cpp +++ b/examples/talk-llama/llama-model-saver.cpp @@ -1,20 +1,50 @@ #include "llama-model-saver.h" +#include "ggml.h" #include "gguf.h" +#include "llama-arch.h" #include "llama.h" #include "llama-hparams.h" #include "llama-model.h" #include "llama-vocab.h" +#include <cstdint> #include <string> -llama_model_saver::llama_model_saver(const struct llama_model & model) : model(model), llm_kv(model.arch) { - gguf_ctx = gguf_init_empty(); +bool llama_model_saver_supports_arch(llm_arch arch) { + switch (arch) { + case LLM_ARCH_PLAMO3: + case LLM_ARCH_GEMMA3: + case LLM_ARCH_GEMMA3N: + case LLM_ARCH_COHERE2: + case LLM_ARCH_OLMO2: + case LLM_ARCH_BITNET: + case LLM_ARCH_T5: + case LLM_ARCH_EXAONE_MOE: + case LLM_ARCH_AFMOE: + case LLM_ARCH_APERTUS: + case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: + case LLM_ARCH_MELLUM: + return false; + default: + return true; + } +} + +llama_model_saver::llama_model_saver(const struct llama_model * model) : + gguf_ctx(gguf_init_empty()), gguf_ctx_owned(true), model(model), llm_kv(model->arch) { + GGML_ASSERT(llama_model_saver_supports_arch(model->arch)); } +llama_model_saver::llama_model_saver(enum llm_arch arch, struct gguf_context * gguf_ctx) : + gguf_ctx(gguf_ctx == nullptr ? gguf_init_empty() : gguf_ctx), gguf_ctx_owned(gguf_ctx == nullptr), model(nullptr), llm_kv(arch) {} + llama_model_saver::~llama_model_saver() { - gguf_free(gguf_ctx); + if (gguf_ctx_owned) { + gguf_free(gguf_ctx); + } } void llama_model_saver::add_kv(const enum llm_kv key, const uint32_t value) { @@ -46,7 +76,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const char value) { template <typename Container> void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, const bool per_layer) { - const size_t n_values = per_layer ? size_t(model.hparams.n_layer) : value.size(); + GGML_ASSERT(model != nullptr || !per_layer); + const size_t n_values = per_layer ? size_t(model->hparams.n_layer()) : value.size(); GGML_ASSERT(n_values <= value.size()); if (n_values == 0) { @@ -73,6 +104,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, c gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT8, value.data(), n_values); } else if (std::is_same<typename Container::value_type, uint32_t>::value) { gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_UINT32, value.data(), n_values); + } else if (std::is_same<typename Container::value_type, bool>::value) { + gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_BOOL, value.data(), n_values); } else if (std::is_same<typename Container::value_type, int32_t>::value) { gguf_set_arr_data(gguf_ctx, llm_kv(key).c_str(), GGUF_TYPE_INT32, value.data(), n_values); } else if (std::is_same<typename Container::value_type, float>::value) { @@ -83,6 +116,8 @@ void llama_model_saver::add_kv(const enum llm_kv key, const Container & value, c GGML_ABORT("fatal error"); } } +// instantiate for external usage: +template void llama_model_saver::add_kv<std::vector<uint32_t>>(const enum llm_kv, const std::vector<uint32_t> &, const bool); void llama_model_saver::add_kv(const enum llm_kv key, const std::vector<std::string> & value) { std::vector<const char *> tmp(value.size()); @@ -97,44 +132,66 @@ void llama_model_saver::add_tensor(const struct ggml_tensor * tensor) { return; } if (gguf_find_tensor(gguf_ctx, tensor->name) >= 0) { - GGML_ASSERT(std::string(tensor->name) == "rope_freqs.weight"); // FIXME + const std::string tensor_name = tensor->name; + GGML_ASSERT( + tensor_name == "rope_freqs.weight" || tensor_name == "rope_factors_long.weight" || + tensor_name == "rope_factors_short.weight"); // FIXME return; } gguf_add_tensor(gguf_ctx, tensor); } void llama_model_saver::add_kv_from_model() { - const llama_hparams & hparams = model.hparams; - const llama_vocab & vocab = model.vocab; + const llama_hparams & hparams = model->hparams; + const llama_vocab & vocab = model->vocab; const int32_t n_vocab = vocab.n_tokens(); std::vector<std::string> tokens(n_vocab); std::vector<float> scores(n_vocab); std::vector<int32_t> token_types(n_vocab); - for (int32_t id = 0; id < n_vocab; ++id) { - const llama_vocab::token_data & token_data = vocab.get_token_data(id); - - tokens[id] = token_data.text; - scores[id] = token_data.score; - - switch(token_data.attr) { - case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; - case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; - case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break; - case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; - case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; - case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; - case LLAMA_TOKEN_ATTR_UNDEFINED: - default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; + if (vocab.get_type() != LLAMA_VOCAB_TYPE_NONE) { + for (int32_t id = 0; id < n_vocab; ++id) { + const llama_vocab::token_data & token_data = vocab.get_token_data(id); + + tokens[id] = token_data.text; + scores[id] = token_data.score; + + // FIXME should this be treated as flags? + switch(token_data.attr) { + case LLAMA_TOKEN_ATTR_UNKNOWN: token_types[id] = LLAMA_TOKEN_TYPE_UNKNOWN; break; + case LLAMA_TOKEN_ATTR_UNUSED: token_types[id] = LLAMA_TOKEN_TYPE_UNUSED; break; + case LLAMA_TOKEN_ATTR_NORMAL: token_types[id] = LLAMA_TOKEN_TYPE_NORMAL; break; + case LLAMA_TOKEN_ATTR_CONTROL: token_types[id] = LLAMA_TOKEN_TYPE_CONTROL; break; + case LLAMA_TOKEN_ATTR_USER_DEFINED: token_types[id] = LLAMA_TOKEN_TYPE_USER_DEFINED; break; + case LLAMA_TOKEN_ATTR_BYTE: token_types[id] = LLAMA_TOKEN_TYPE_BYTE; break; + // case LLAMA_TOKEN_ATTR_NORMALIZED: ??? + // case LLAMA_TOKEN_ATTR_LSTRIP: ??? + // case LLAMA_TOKEN_ATTR_RSTRIP: ??? + case LLAMA_TOKEN_ATTR_UNDEFINED: + default: token_types[id] = LLAMA_TOKEN_TYPE_UNDEFINED; break; + } } } // add_kv(LLM_KV_GENERAL_TYPE, ???); - add_kv(LLM_KV_GENERAL_ARCHITECTURE, model.arch_name()); + add_kv(LLM_KV_GENERAL_ARCHITECTURE, model->arch_name()); // add_kv(LLM_KV_GENERAL_QUANTIZATION_VERSION, ???); // add_kv(LLM_KV_GENERAL_ALIGNMENT, ???); - add_kv(LLM_KV_GENERAL_NAME, model.name); + // add_kv(LLM_KV_GENERAL_FILE_TYPE, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_SEQUENCE, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TOP_K, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TOP_P, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIN_P, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_TEMP, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, ???); + // add_kv(LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, ???); + add_kv(LLM_KV_GENERAL_NAME, model->name); // add_kv(LLM_KV_GENERAL_AUTHOR, ???); // add_kv(LLM_KV_GENERAL_VERSION, ???); // add_kv(LLM_KV_GENERAL_URL, ???); @@ -146,24 +203,39 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - if (hparams.n_embd_out > 0) { - add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out); + if (hparams.n_embd_out_impl > 0) { + add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl); } - add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); + add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer_all); add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); add_kv(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); + add_kv(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_chexp); + add_kv(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp); + add_kv(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp); add_kv(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); // add_kv(LLM_KV_TENSOR_DATA_LAYOUT, ???); add_kv(LLM_KV_EXPERT_COUNT, hparams.n_expert); add_kv(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); add_kv(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + add_kv(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups); + add_kv(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used); add_kv(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + add_kv(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm); + add_kv(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + add_kv(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); + add_kv(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); + add_kv(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers); + add_kv(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn); + add_kv(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers); + add_kv(LLM_KV_DEEPSTACK_MAPPING, hparams.deepstack_mapping_arr); add_kv(LLM_KV_POOLING_TYPE, uint32_t(hparams.pooling_type)); add_kv(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); add_kv(LLM_KV_DECODER_START_TOKEN_ID, hparams.dec_start_token_id); + add_kv(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer); add_kv(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping); + add_kv(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping); add_kv(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping); add_kv(LLM_KV_SWIN_NORM, hparams.swin_norm); add_kv(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers); @@ -171,26 +243,51 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); add_kv(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); add_kv(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); + add_kv(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count); + add_kv(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + // add_kv(LLM_KV_FULL_ATTENTION_INTERVAL, ???); // saved as LLM_KV_ATTENTION_RECURRENT_LAYERS instead add_kv(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, true); add_kv(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, true); add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); add_kv(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); - add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k); - add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full); add_kv(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); add_kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + add_kv(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + add_kv(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); add_kv(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); add_kv(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); add_kv(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + add_kv(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + add_kv(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + add_kv(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + add_kv(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate); add_kv(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + // add_kv(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, ???); add_kv(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); + add_kv(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale); + add_kv(LLM_KV_ATTENTION_VALUE_SCALE, hparams.f_attn_value_scale); + add_kv(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length); + add_kv(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + add_kv(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + add_kv(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + add_kv(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + add_kv(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + add_kv(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + add_kv(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, true); const float rope_scaling_factor = hparams.rope_freq_scale_train == 1.0f ? 0.0f : 1.0f/hparams.rope_freq_scale_train; - add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot); + add_kv(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full); + add_kv(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa); + add_kv(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections); add_kv(LLM_KV_ROPE_FREQ_BASE, hparams.rope_freq_base_train); + add_kv(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); // add_kv(LLM_KV_ROPE_SCALE_LINEAR, rope_scaling_factor); // old name add_kv(LLM_KV_ROPE_SCALING_TYPE, llama_rope_scaling_type_name(hparams.rope_scaling_type_train)); add_kv(LLM_KV_ROPE_SCALING_FACTOR, rope_scaling_factor); @@ -198,6 +295,10 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn); add_kv(LLM_KV_ROPE_SCALING_FINETUNED, hparams.rope_finetuned); add_kv(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + add_kv(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor); + add_kv(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor); + add_kv(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast); + add_kv(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow); // TODO: implement split file support // add_kv(LLM_KV_SPLIT_NO, ???); @@ -208,8 +309,11 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); add_kv(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); add_kv(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + add_kv(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); add_kv(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms); + add_kv(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + add_kv(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); add_kv(LLM_KV_TOKENIZER_MODEL, vocab.get_tokenizer_model()); @@ -247,32 +351,59 @@ void llama_model_saver::add_kv_from_model() { // TODO: implement LoRA support // add_kv(LLM_KV_ADAPTER_TYPE, ???); // add_kv(LLM_KV_ADAPTER_LORA_ALPHA, ???); + // add_kv(LLM_KV_ADAPTER_LORA_TASK_NAME, ???); + // add_kv(LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, ???); + // add_kv(LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, ???); + + add_kv(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); + add_kv(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); + + add_kv(LLM_KV_CONVNEXT_EMBEDDING_LENGTH, hparams.convnext.n_embd); + add_kv(LLM_KV_CONVNEXT_BLOCK_COUNT, hparams.convnext.n_layer); + + add_kv(LLM_KV_CLASSIFIER_OUTPUT_LABELS, model->classifier_labels); + + add_kv(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + + add_kv(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n); + add_kv(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p); + add_kv(LLM_KV_XIELU_BETA, hparams.xielu_beta); + add_kv(LLM_KV_XIELU_EPS, hparams.xielu_eps); // deprecated // add_kv(LLM_KV_TOKENIZER_PREFIX_ID, ???); // add_kv(LLM_KV_TOKENIZER_SUFFIX_ID, ???); // add_kv(LLM_KV_TOKENIZER_MIDDLE_ID, ???); + + add_kv(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in); + add_kv(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out); + add_kv(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in); + add_kv(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out); } void llama_model_saver::add_tensors_from_model() { - if (std::string(model.output->name) != std::string(model.tok_embd->name)) { - add_tensor(model.tok_embd); // some models use the same tensor for tok_embd and output + if (model->output != nullptr && + std::string(model->output->name) != std::string(model->tok_embd->name)) { + add_tensor(model->tok_embd); // some models use the same tensor for tok_embd and output } - add_tensor(model.type_embd); - add_tensor(model.pos_embd); - add_tensor(model.tok_norm); - add_tensor(model.tok_norm_b); - add_tensor(model.output_norm); - add_tensor(model.output_norm_b); - add_tensor(model.output); - add_tensor(model.output_b); - add_tensor(model.output_norm_enc); - add_tensor(model.cls); - add_tensor(model.cls_b); - add_tensor(model.cls_out); - add_tensor(model.cls_out_b); - - for (const struct llama_layer & layer : model.layers) { + add_tensor(model->type_embd); + add_tensor(model->pos_embd); + add_tensor(model->tok_norm); + add_tensor(model->tok_norm_b); + add_tensor(model->output_norm); + add_tensor(model->output_norm_b); + add_tensor(model->output); + add_tensor(model->output_b); + add_tensor(model->output_norm_enc); + add_tensor(model->output_s); + add_tensor(model->output_in_s); + add_tensor(model->cls); + add_tensor(model->cls_b); + add_tensor(model->cls_out); + add_tensor(model->cls_out_b); + add_tensor(model->cls_norm); + + for (const struct llama_layer & layer : model->layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { add_tensor(reinterpret_cast<const struct ggml_tensor * const *>(&layer)[i]); } @@ -283,3 +414,6 @@ void llama_model_saver::save(const std::string & path_model) { gguf_write_to_file(gguf_ctx, path_model.c_str(), false); } +void llama_model_saver::save(FILE * file) { + gguf_write_to_file_ptr(gguf_ctx, file, false); +} diff --git a/examples/talk-llama/llama-model-saver.h b/examples/talk-llama/llama-model-saver.h index a5a434c3069..36a715e2b6b 100644 --- a/examples/talk-llama/llama-model-saver.h +++ b/examples/talk-llama/llama-model-saver.h @@ -1,16 +1,22 @@ #pragma once +#include "gguf.h" #include "llama.h" #include "llama-arch.h" #include <vector> +// FIXME temporary function for better error messages +bool llama_model_saver_supports_arch(llm_arch arch); + struct llama_model_saver { struct gguf_context * gguf_ctx = nullptr; - const struct llama_model & model; + const bool gguf_ctx_owned; + const struct llama_model * model; const struct LLM_KV llm_kv; - llama_model_saver(const struct llama_model & model); + llama_model_saver(const struct llama_model * model); + llama_model_saver(enum llm_arch arch, struct gguf_context * gguf_ctx); ~llama_model_saver(); void add_kv(enum llm_kv key, uint32_t value); @@ -34,4 +40,5 @@ struct llama_model_saver { void add_tensors_from_model(); void save(const std::string & path_model); + void save(FILE * file); }; diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index f6cea8f8db4..7281ed79f10 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -1,5 +1,8 @@ #include "llama-model.h" +#include "llama-arch.h" +#include "llama-ext.h" +#include "llama-hparams.h" #include "llama-impl.h" #include "llama-mmap.h" #include "llama-cparams.h" @@ -7,23 +10,676 @@ #include "llama-kv-cache.h" #include "llama-kv-cache-iswa.h" +#include "llama-kv-cache-dsa.h" #include "llama-memory-hybrid.h" +#include "llama-memory-hybrid-iswa.h" #include "llama-memory-recurrent.h" -#include "ggml-cpp.h" - #include "models/models.h" +#include "ggml.h" +#include "ggml-cpp.h" + #include <algorithm> #include <cassert> #include <cfloat> +#include <cstdint> #include <cstring> #include <cmath> #include <functional> #include <map> +#include <numeric> #include <regex> #include <sstream> #include <stdexcept> +#include <string> +#include <vector> + +static llama_model * llama_model_mapping(llm_arch arch, const llama_model_params & params) { + switch (arch) { + case LLM_ARCH_LLAMA: + return new llama_model_llama(params); + case LLM_ARCH_LLAMA4: + return new llama_model_llama4(params); + case LLM_ARCH_LLAMA_EMBED: + return new llama_model_llama_embed(params); + case LLM_ARCH_MAINCODER: + return new llama_model_maincoder(params); + case LLM_ARCH_TALKIE: + return new llama_model_talkie(params); + case LLM_ARCH_DECI: + return new llama_model_deci(params); + case LLM_ARCH_BAICHUAN: + return new llama_model_baichuan(params); + case LLM_ARCH_FALCON: + return new llama_model_falcon(params); + case LLM_ARCH_GROK: + return new llama_model_grok(params); + case LLM_ARCH_STARCODER: + return new llama_model_starcoder(params); + case LLM_ARCH_REFACT: + return new llama_model_refact(params); + case LLM_ARCH_BERT: + return new llama_model_bert(params); + case LLM_ARCH_JINA_BERT_V2: + return new llama_model_jina_bert_v2(params); + case LLM_ARCH_JINA_BERT_V3: + return new llama_model_jina_bert_v3(params); + case LLM_ARCH_NOMIC_BERT: + return new llama_model_nomic_bert(params); + case LLM_ARCH_NOMIC_BERT_MOE: + return new llama_model_nomic_bert_moe(params); + case LLM_ARCH_MODERN_BERT: + return new llama_model_modern_bert(params); + case LLM_ARCH_NEO_BERT: + return new llama_model_neo_bert(params); + case LLM_ARCH_EUROBERT: + return new llama_model_eurobert(params); + case LLM_ARCH_BLOOM: + return new llama_model_bloom(params); + case LLM_ARCH_MPT: + return new llama_model_mpt(params); + case LLM_ARCH_STABLELM: + return new llama_model_stablelm(params); + case LLM_ARCH_MELLUM: + return new llama_model_mellum(params); + case LLM_ARCH_QWEN: + return new llama_model_qwen(params); + case LLM_ARCH_QWEN2: + return new llama_model_qwen2(params); + case LLM_ARCH_DREAM: + return new llama_model_dream(params); + case LLM_ARCH_LLADA: + return new llama_model_llada(params); + case LLM_ARCH_LLADA_MOE: + return new llama_model_llada_moe(params); + case LLM_ARCH_RND1: + return new llama_model_rnd1(params); + case LLM_ARCH_QWEN2VL: + return new llama_model_qwen2vl(params); + case LLM_ARCH_QWEN2MOE: + return new llama_model_qwen2moe(params); + case LLM_ARCH_QWEN3: + return new llama_model_qwen3(params); + case LLM_ARCH_QWEN3MOE: + return new llama_model_qwen3moe(params); + case LLM_ARCH_QWEN3VL: + return new llama_model_qwen3vl(params); + case LLM_ARCH_QWEN3VLMOE: + return new llama_model_qwen3vlmoe(params); + case LLM_ARCH_PHI2: + return new llama_model_phi2(params); + case LLM_ARCH_PHI3: + return new llama_model_phi3(params); + case LLM_ARCH_PHIMOE: + return new llama_model_phimoe(params); + case LLM_ARCH_PLAMO: + return new llama_model_plamo(params); + case LLM_ARCH_PLAMO2: + return new llama_model_plamo2(params); + case LLM_ARCH_PLAMO3: + return new llama_model_plamo3(params); + case LLM_ARCH_GPT2: + return new llama_model_gpt2(params); + case LLM_ARCH_CODESHELL: + return new llama_model_codeshell(params); + case LLM_ARCH_ORION: + return new llama_model_orion(params); + case LLM_ARCH_INTERNLM2: + return new llama_model_internlm2(params); + case LLM_ARCH_MINICPM3: + return new llama_model_minicpm3(params); + case LLM_ARCH_GEMMA: + return new llama_model_gemma(params); + case LLM_ARCH_GEMMA2: + return new llama_model_gemma2(params); + case LLM_ARCH_GEMMA3: + return new llama_model_gemma3(params); + case LLM_ARCH_GEMMA3N: + return new llama_model_gemma3n(params); + case LLM_ARCH_GEMMA4: + return new llama_model_gemma4(params); + case LLM_ARCH_GEMMA4_ASSISTANT: + return new llama_model_gemma4_assistant(params); + case LLM_ARCH_GEMMA_EMBEDDING: + return new llama_model_gemma_embedding(params); + case LLM_ARCH_STARCODER2: + return new llama_model_starcoder2(params); + case LLM_ARCH_MAMBA: + return new llama_model_mamba(params); + case LLM_ARCH_MAMBA2: + return new llama_model_mamba2(params); + case LLM_ARCH_JAMBA: + return new llama_model_jamba(params); + case LLM_ARCH_XVERSE: + return new llama_model_xverse(params); + case LLM_ARCH_COMMAND_R: + return new llama_model_command_r(params); + case LLM_ARCH_COHERE2: + return new llama_model_cohere2(params); + case LLM_ARCH_DBRX: + return new llama_model_dbrx(params); + case LLM_ARCH_OLMO: + return new llama_model_olmo(params); + case LLM_ARCH_OLMO2: + return new llama_model_olmo2(params); + case LLM_ARCH_OLMOE: + return new llama_model_olmoe(params); + case LLM_ARCH_OPENELM: + return new llama_model_openelm(params); + case LLM_ARCH_GPTNEOX: + return new llama_model_gptneox(params); + case LLM_ARCH_ARCTIC: + return new llama_model_arctic(params); + case LLM_ARCH_DEEPSEEK: + return new llama_model_deepseek(params); + case LLM_ARCH_DEEPSEEK2: + return new llama_model_deepseek2(params); + case LLM_ARCH_DEEPSEEK2OCR: + return new llama_model_deepseek2ocr(params); + case LLM_ARCH_DEEPSEEK32: + return new llama_model_deepseek32(params); + case LLM_ARCH_GLM_DSA: + return new llama_model_glm_dsa(params); + case LLM_ARCH_MISTRAL4: + return new llama_model_mistral4(params); + case LLM_ARCH_CHATGLM: + return new llama_model_chatglm(params); + case LLM_ARCH_GLM4: + return new llama_model_glm4(params); + case LLM_ARCH_GLM4_MOE: + return new llama_model_glm4_moe(params); + case LLM_ARCH_BITNET: + return new llama_model_bitnet(params); + case LLM_ARCH_T5: + return new llama_model_t5(params); + case LLM_ARCH_T5ENCODER: + return new llama_model_t5encoder(params); + case LLM_ARCH_JAIS: + return new llama_model_jais(params); + case LLM_ARCH_JAIS2: + return new llama_model_jais2(params); + case LLM_ARCH_NEMOTRON: + return new llama_model_nemotron(params); + case LLM_ARCH_NEMOTRON_H: + return new llama_model_nemotron_h(params); + case LLM_ARCH_NEMOTRON_H_MOE: + return new llama_model_nemotron_h_moe(params); + case LLM_ARCH_EXAONE: + return new llama_model_exaone(params); + case LLM_ARCH_EXAONE4: + return new llama_model_exaone4(params); + case LLM_ARCH_EXAONE_MOE: + return new llama_model_exaone_moe(params); + case LLM_ARCH_RWKV6: + return new llama_model_rwkv6(params); + case LLM_ARCH_RWKV6QWEN2: + return new llama_model_rwkv6qwen2(params); + case LLM_ARCH_RWKV7: + return new llama_model_rwkv7(params); + case LLM_ARCH_ARWKV7: + return new llama_model_arwkv7(params); + case LLM_ARCH_GRANITE: + return new llama_model_granite(params); + case LLM_ARCH_GRANITE_MOE: + return new llama_model_granite_moe(params); + case LLM_ARCH_MINICPM: + return new llama_model_minicpm(params); + case LLM_ARCH_GRANITE_HYBRID: + return new llama_model_granite_hybrid(params); + case LLM_ARCH_CHAMELEON: + return new llama_model_chameleon(params); + case LLM_ARCH_WAVTOKENIZER_DEC: + return new llama_model_wavtokenizer_dec(params); + case LLM_ARCH_PLM: + return new llama_model_plm(params); + case LLM_ARCH_BAILINGMOE: + return new llama_model_bailingmoe(params); + case LLM_ARCH_BAILINGMOE2: + return new llama_model_bailingmoe2(params); + case LLM_ARCH_SEED_OSS: + return new llama_model_seed_oss(params); + case LLM_ARCH_DOTS1: + return new llama_model_dots1(params); + case LLM_ARCH_ARCEE: + return new llama_model_arcee(params); + case LLM_ARCH_AFMOE: + return new llama_model_afmoe(params); + case LLM_ARCH_ERNIE4_5: + return new llama_model_ernie4_5(params); + case LLM_ARCH_ERNIE4_5_MOE: + return new llama_model_ernie4_5_moe(params); + case LLM_ARCH_PADDLEOCR: + return new llama_model_paddleocr(params); + case LLM_ARCH_HUNYUAN_MOE: + return new llama_model_hunyuan_moe(params); + case LLM_ARCH_HUNYUAN_VL: + return new llama_model_hunyuan_vl(params); + case LLM_ARCH_HUNYUAN_DENSE: + return new llama_model_hunyuan_dense(params); + case LLM_ARCH_SMOLLM3: + return new llama_model_smollm3(params); + case LLM_ARCH_OPENAI_MOE: + return new llama_model_openai_moe(params); + case LLM_ARCH_FALCON_H1: + return new llama_model_falcon_h1(params); + case LLM_ARCH_LFM2: + return new llama_model_lfm2(params); + case LLM_ARCH_LFM2MOE: + return new llama_model_lfm2moe(params); + case LLM_ARCH_SMALLTHINKER: + return new llama_model_smallthinker(params); + case LLM_ARCH_GROVEMOE: + return new llama_model_grovemoe(params); + case LLM_ARCH_APERTUS: + return new llama_model_apertus(params); + case LLM_ARCH_MINIMAX_M2: + return new llama_model_minimax_m2(params); + case LLM_ARCH_COGVLM: + return new llama_model_cogvlm(params); + case LLM_ARCH_PANGU_EMBED: + return new llama_model_pangu_embed(params); + case LLM_ARCH_QWEN3NEXT: + return new llama_model_qwen3next(params); + case LLM_ARCH_QWEN35: + return new llama_model_qwen35(params); + case LLM_ARCH_QWEN35MOE: + return new llama_model_qwen35moe(params); + case LLM_ARCH_MISTRAL3: + return new llama_model_mistral3(params); + case LLM_ARCH_EAGLE3: + return new llama_model_eagle3(params); + case LLM_ARCH_MIMO2: + return new llama_model_mimo2(params); + case LLM_ARCH_KIMI_LINEAR: + return new llama_model_kimi_linear(params); + case LLM_ARCH_STEP35: + return new llama_model_step35(params); + default: + throw std::runtime_error(std::string("unsupported model architecture: '") + llm_arch_name(arch) + "'"); + } + +} + +llama_model * llama_model_create(llm_arch arch, const llama_model_params & params) { + llama_model * model = llama_model_mapping(arch, params); + + if (model != nullptr) { + model->arch = arch; + auto & devices = model->devices; + if (!devices.empty() && devices[0].is_meta && !llm_arch_supports_sm_tensor(arch)) { + throw std::runtime_error(std::string("LLAMA_SPLIT_MODE_TENSOR not implemented for architecture '") + llm_arch_name(arch) + "'"); + } + } + + return model; +} + +llama_model * llama_model_create(llama_model_loader & ml, const llama_model_params & params) { + llm_arch arch = ml.get_arch(); + if (arch == LLM_ARCH_UNKNOWN) { + throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); + } + + return llama_model_create(arch, params); +} + +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata) { + const llama_meta_device_get_split_state_userdata * ud = (const llama_meta_device_get_split_state_userdata *) userdata; + const llama_hparams & hparams = ud->model->hparams; + const std::string tensor_name = tensor->name; + + const std::regex pattern_q_weight ("blk\\.\\d*\\.attn_q.weight"); + const std::regex pattern_kv_weight ("blk\\.\\d*\\.attn_(k|v).weight"); + const std::regex pattern_qkv_weight ("blk\\.\\d*\\.attn_qkv.weight"); + const std::regex pattern_q_bias ("blk\\.\\d*\\.attn_q\\.bias"); + const std::regex pattern_kv_bias ("blk\\.\\d*\\.attn_(k|v)\\.bias"); + const std::regex pattern_qkv_bias ("blk\\.\\d*\\.attn_qkv.bias"); + const std::regex pattern_qk_norm ("blk\\.\\d*\\.attn_(q|k)_norm\\.weight"); + const std::regex pattern_kv_cache ("cache_(k|v)_l\\d*"); + const std::regex pattern_attn_sinks ("blk\\.\\d*\\.attn_sinks.weight"); + const std::regex pattern_attn_out_weight ("blk\\.\\d*\\.attn_output.weight"); + const std::regex pattern_attn_out_bias ("blk\\.\\d*\\.attn_output.bias"); + const std::regex pattern_attn_gate_weight("blk\\.\\d*\\.attn_gate.weight"); + + const std::regex pattern_ssm_dt ("blk\\.\\d*\\.ssm_dt.bias"); + const std::regex pattern_ssm_a ("blk\\.\\d*\\.ssm_a"); + const std::regex pattern_ssm_alpha ("blk\\.\\d*\\.ssm_alpha.weight"); + const std::regex pattern_ssm_beta ("blk\\.\\d*\\.ssm_beta.weight"); + const std::regex pattern_ssm_beta_alpha ("blk\\.\\d*\\.ssm_ba.weight"); + const std::regex pattern_r_cache ("cache_r_l\\d*"); + const std::regex pattern_s_cache ("cache_s_l\\d*"); + const std::regex pattern_ssm_conv1d ("blk\\.\\d*\\.ssm_conv1d.weight"); + const std::regex pattern_ssm_out_weight ("blk\\.\\d*\\.ssm_out.weight"); + + const std::regex pattern_ffn_up_gate_weight("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.weight"); + const std::regex pattern_ffn_up_gate_bias ("blk\\.\\d*\\.ffn_(up|gate)(_exps)?.bias"); + const std::regex pattern_ffn_gate_up_weight("blk\\.\\d*\\.ffn_gate_up(_exps)?.weight"); + const std::regex pattern_ffn_down_weight ("blk\\.\\d*\\.ffn_down(_exps)?.weight"); + const std::regex pattern_ffn_down_bias ("blk\\.\\d*\\.ffn_down.bias"); + const std::regex pattern_ffn_down_exps_bias("blk\\.\\d*\\.ffn_down_exps.bias"); + + const std::regex pattern_output_weight("output\\.weight"); + const std::regex pattern_output_bias ("output\\.bias"); + + struct tensor_config { + ggml_backend_meta_split_axis axis; + + const ggml_tensor * tensor_axis_0; + + uint32_t il; + size_t rotation; // when assigning tensor slices, rotate how the rounding is done for more even allocation + }; + + auto get_tensor_config_impl = [&]( + const ggml_backend_meta_split_axis axis, const std::string & suffix = "", const std::string & suffix_fallback = "") -> tensor_config { + // the layers in a tensor can be inhomogeneous, if the pattern is cleanly divided by the number of GPUs there can be aliasing effects, + // count only the same type of previous layers to avoid this + auto get_il_eff = [&](const size_t il){ + size_t ret = 0; + const bool il_is_recr = hparams.is_recr(il); + const bool il_is_swa = hparams.is_swa(il); + for (size_t il_prev = 0; il_prev < il; il_prev++) { + ret += hparams.is_recr(il_prev) == il_is_recr && hparams.is_swa(il_prev) == il_is_swa; + } + return ret; + }; + + uint32_t il; + std::string prefix; + size_t rotation; + if (tensor_name.substr(0, 4) == "blk.") { + const size_t length_prefix = tensor_name.find('.', 4); + GGML_ASSERT(length_prefix != std::string::npos); + prefix = tensor_name.substr(0, length_prefix + 1); + il = std::stoull(tensor_name.substr(4, length_prefix)); + rotation = get_il_eff(il) % ud->n_devices; + } else if (tensor_name.substr(0, 6) == "cache_") { + const size_t layer_index_start = tensor_name.find("_l", 6); + GGML_ASSERT(layer_index_start != std::string::npos); + il = std::stoull(tensor_name.substr(layer_index_start + 2)); + prefix = "blk." + std::to_string(il) + "."; + rotation = get_il_eff(il) % ud->n_devices; + } else { + il = 0; + rotation = hparams.n_layer() % ud->n_devices; + } + const ggml_tensor * tensor_axis_0 = suffix.empty() ? tensor : ud->model->get_tensor((prefix + suffix).c_str()); + if (tensor_axis_0 == nullptr) { + GGML_ASSERT(!suffix_fallback.empty()); + tensor_axis_0 = ud->model->get_tensor((prefix + suffix_fallback).c_str()); + } + GGML_ASSERT(tensor_axis_0 != nullptr); + return {axis, tensor_axis_0, il, rotation}; + }; + + auto get_tensor_config = [&]() -> tensor_config { + // standard attention + if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_kv_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_q_bias) || std::regex_match(tensor_name, pattern_kv_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight", "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_qkv_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); + } + if ( std::regex_match(tensor_name, pattern_qkv_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight", "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_qk_norm)) { + return get_tensor_config_impl(tensor->ne[1] == 1 ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_kv_cache) || std::regex_match(tensor_name, pattern_attn_sinks)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "attn_output.weight"); + } + if (std::regex_match(tensor_name, pattern_attn_out_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + if (std::regex_match(tensor_name, pattern_attn_out_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + } + + if (std::regex_match(tensor_name, pattern_attn_gate_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "attn_output.weight", "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta) || + std::regex_match(tensor_name, pattern_ssm_beta_alpha)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_r_cache) || std::regex_match(tensor_name, pattern_s_cache)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_conv1d)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ssm_out.weight"); + } + if (std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + + // FFN + if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_up_gate_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_down_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0, "ffn_down.weight", "ffn_down_exps.weight"); + } + if (std::regex_match(tensor_name, pattern_ffn_down_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + } + if (std::regex_match(tensor_name, pattern_ffn_down_exps_bias)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_PARTIAL); + } + + // output + if (std::regex_match(tensor_name, pattern_output_weight)) { + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_1); + } + if (std::regex_match(tensor_name, pattern_output_bias)) { + const ggml_tensor * output_weight = ud->model->get_tensor("output.weight"); + GGML_ASSERT(output_weight != nullptr); + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_0); + } + + // everything else + return get_tensor_config_impl(GGML_BACKEND_SPLIT_AXIS_MIRRORED); + }; + + auto get_split_segments = [&](int axis, uint32_t il) -> std::vector<std::pair<int64_t, uint32_t>> { + if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + + // both Qwen 3 Next and Qwen 3.5 support n_v_heads > n_k_heads but the broadcasting pattern is different: + // - Qwen 3 Next: [k0_v0, k0_v1, k1_v2, k1_v3] (this is the default split pattern) + // - Qwen 3.5: [k0_v0, k1_v1, k0_v2, k1_v3] (needs segmenting of V on the scale of K to get the correct pattern) + if (ud->model->arch == LLM_ARCH_QWEN3NEXT) { + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { + GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); + return {{key_dim, 2}, {value_dim, 1}}; + } + } else { + const int64_t head_ratio = n_v_heads / n_k_heads; + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_ssm_conv1d)) { + GGML_ASSERT(tensor->ne[axis] == 2*key_dim + value_dim); + return {{key_dim, 2 + head_ratio}}; + } + if (std::regex_match(tensor_name, pattern_attn_gate_weight) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return {{key_dim, head_ratio}}; + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || + std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { + return {{n_k_heads, head_ratio}}; + } + if (std::regex_match(tensor_name, pattern_r_cache)) { + return {{key_dim * (hparams.ssm_d_conv - 1), 2 + head_ratio}}; + } + if (std::regex_match(tensor_name, pattern_s_cache)) { + return {{n_k_heads * head_v_dim * head_v_dim, head_ratio}}; + } + } + + // the FFN is the same for Qwen 3 Next and Qwen 3.5: + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + const int64_t n_ff_exp = hparams.n_ff_exp; + GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); + return {{n_ff_exp, 2}}; + } + return {{tensor->ne[axis], 1}}; + } + + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(il); + GGML_ASSERT(hparams.n_embd_k_gqa() == n_embd_gqa); + GGML_ASSERT(tensor->ne[axis] == n_embd + 2*n_embd_gqa); + return {{n_embd, 1}, {n_embd_gqa, 2}}; + } + if (std::regex_match(tensor_name, pattern_ffn_gate_up_weight)) { + const int64_t n_ff_exp = hparams.n_ff_exp; + GGML_ASSERT(tensor->ne[axis] == 2*n_ff_exp); + return {{n_ff_exp, 2}}; + } + return {{tensor->ne[axis], 1}}; + }; + + auto get_split_granularity = [&](int64_t blck_size, uint32_t il, const std::vector<std::pair<int64_t, uint32_t>> & segments) -> std::vector<int64_t> { + // for better performance it may make sense to round up blck_size to a higher power of 2 so that more efficient kernels can be used + if (hparams.is_recr(il)) { + // linear attention + const int64_t head_dim = hparams.ssm_d_state; + const int64_t blck_size_perf = std::lcm(blck_size, 128); + const int64_t granularity_qkv = std::lcm(blck_size_perf, head_dim); + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_attn_gate_weight) || + std::regex_match(tensor_name, pattern_ssm_conv1d) || std::regex_match(tensor_name, pattern_ssm_out_weight)) { + return std::vector<int64_t>(segments.size(), granularity_qkv); + } + if (std::regex_match(tensor_name, pattern_ssm_dt) || std::regex_match(tensor_name, pattern_ssm_a) || + std::regex_match(tensor_name, pattern_ssm_alpha) || std::regex_match(tensor_name, pattern_ssm_beta)) { + return std::vector<int64_t>(segments.size(), granularity_qkv / head_dim); + } + if (std::regex_match(tensor_name, pattern_ssm_beta_alpha)) { + return std::vector<int64_t>(segments.size(), 2 * (granularity_qkv / head_dim)); + } + if (std::regex_match(tensor_name, pattern_r_cache)) { + return std::vector<int64_t>(segments.size(), granularity_qkv * (hparams.ssm_d_conv - 1)); + } + if (std::regex_match(tensor_name, pattern_s_cache)) { + return std::vector<int64_t>(segments.size(), granularity_qkv * head_dim); + } + } else { + // regular attention + const uint32_t n_gqa = hparams.n_gqa(il); + const uint32_t n_embd_q = n_gqa * hparams.n_embd_head_k(il); + + // to handle head sizes like 80, only increase granularity while it doesn't cause underutilization + int64_t blck_size_perf = blck_size; + while (blck_size_perf < 128 && blck_size_perf*ud->n_devices < n_embd_q) { + blck_size_perf *= 2; + } + + if (std::regex_match(tensor_name, pattern_attn_sinks)) { + GGML_ASSERT(segments.size() == 1); + return {std::lcm(n_embd_q, blck_size_perf)/n_embd_q * n_gqa}; + } + + const int64_t granularity_q = std::lcm(n_embd_q, blck_size_perf); + if (std::regex_match(tensor_name, pattern_q_weight) || std::regex_match(tensor_name, pattern_q_bias)) { + GGML_ASSERT(segments.size() == 1); + // some models have Q gate tensors, for those cases the granularity needs to be doubled: + if (ud->model->arch == LLM_ARCH_QWEN3NEXT || ud->model->arch == LLM_ARCH_QWEN35 || ud->model->arch == LLM_ARCH_QWEN35MOE) { + return {std::lcm(2*n_embd_q, blck_size_perf)}; + } + return {granularity_q}; + } + if (std::regex_match(tensor_name, pattern_attn_out_weight)) { + GGML_ASSERT(segments.size() == 1); + return {granularity_q}; + } + + const int64_t granularity_kv = granularity_q / n_gqa; + if (std::regex_match(tensor_name, pattern_kv_weight) || + std::regex_match(tensor_name, pattern_kv_bias) || + std::regex_match(tensor_name, pattern_kv_cache)) { + GGML_ASSERT(segments.size() == 1); + return {granularity_kv}; + } + if (std::regex_match(tensor_name, pattern_qkv_weight) || std::regex_match(tensor_name, pattern_qkv_bias)) { + GGML_ASSERT(segments.size() == 2); + return {granularity_q, granularity_kv}; + } + } + + // FFN + if (std::regex_match(tensor_name, pattern_ffn_up_gate_weight) || std::regex_match(tensor_name, pattern_ffn_up_gate_bias) || + std::regex_match(tensor_name, pattern_ffn_gate_up_weight) || std::regex_match(tensor_name, pattern_ffn_down_weight)) { + const int64_t blck_size_perf = std::lcm(blck_size, 128); + GGML_ASSERT(segments.size() == 1); + return {blck_size_perf}; + } + + // everything else + GGML_ASSERT(segments.size() == 1); + return {1}; + }; + + ggml_backend_meta_split_state split_state; + memset(&split_state, 0, sizeof(split_state)); + tensor_config tc = get_tensor_config(); + split_state.axis = tc.axis; + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + const int64_t blck_size = ggml_blck_size(tc.tensor_axis_0->type); + const float * tensor_split = ud->model->tensor_split(); + std::vector<float> tensor_split_scan; + tensor_split_scan.reserve(ud->n_devices); + for (size_t j = 0; j < ud->n_devices; j++) { + tensor_split_scan.push_back(tensor_split == nullptr ? 0.0f : tensor_split[(j + tc.rotation) % ud->n_devices]); + if (j > 0) { + tensor_split_scan[j] += tensor_split_scan[j - 1]; + } + } + const std::vector<std::pair<int64_t, uint32_t>> segments = get_split_segments(split_state.axis, tc.il); + const std::vector<int64_t> granularity = get_split_granularity(blck_size, tc.il, segments); + for (size_t is = 0; is < segments.size(); is++) { + const int64_t ne_s = segments[is].first; + const uint32_t nr_s = segments[is].second; + const int64_t g_s = granularity[is]; + int64_t low = 0; + size_t j = 0; + for (; j < ud->n_devices - 1; j++) { + int64_t high = tensor_split_scan.back() == 0.0f ? + ne_s * (j+1)/ud->n_devices : ne_s * tensor_split_scan[j]/tensor_split_scan.back(); + if (high % g_s != 0) { + high -= high % g_s; + } + split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = high - low; + low = high; + } + split_state.ne[is*ud->n_devices + (j + tc.rotation) % ud->n_devices] = ne_s - low; + split_state.nr[is] = nr_s; + } + split_state.n_segments = segments.size(); + } else { + memset(split_state.ne, 0, sizeof(split_state.ne)); + split_state.nr[0] = 1; + split_state.n_segments = 1; + } + return split_state; + GGML_UNUSED(userdata); +} const char * llm_type_name(llm_type type) { switch (type) { @@ -60,6 +716,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_0_3B: return "0.3B"; case LLM_TYPE_0_5B: return "0.5B"; case LLM_TYPE_0_6B: return "0.6B"; + case LLM_TYPE_0_8B: return "0.8B"; case LLM_TYPE_1B: return "1B"; case LLM_TYPE_1_2B: return "1.2B"; case LLM_TYPE_1_3B: return "1.3B"; @@ -89,6 +746,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_26B: return "26B"; case LLM_TYPE_27B: return "27B"; case LLM_TYPE_30B: return "30B"; + case LLM_TYPE_31B: return "31B"; case LLM_TYPE_32B: return "32B"; case LLM_TYPE_34B: return "34B"; case LLM_TYPE_35B: return "35B"; @@ -120,19 +778,30 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_A13B: return "A13B"; case LLM_TYPE_7B_A1B: return "7B.A1B"; case LLM_TYPE_8B_A1B: return "8B.A1B"; + case LLM_TYPE_12B_A2_5B: return "12B.A2.5B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; + case LLM_TYPE_24B_A2B: return "24B.A2B"; + case LLM_TYPE_26B_A4B: return "26B.A4B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_35B_A3B: return "35B.A3B"; + case LLM_TYPE_48B_A3B: return "48B.A3B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_120B_A12B: return "120B.A12B"; + case LLM_TYPE_122B_A10B: return "122B.A10B"; + case LLM_TYPE_196B_A11B: return "196B.A11B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; + case LLM_TYPE_397B_A17B: return "397B.A17B"; + case LLM_TYPE_685B_A37B: return "685B.A37B"; + case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -168,191 +837,59 @@ static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::st return LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; } -// checks if the weight tensor can be used with the specified buffer type and device -static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { - GGML_ASSERT(w != nullptr); +// Maps the GGUF `<arch>.hidden_activation` string to the FFN op type used by the +// graph builders. Only gated activations that map cleanly to llm_ffn_op_type are +// listed; unrecognized values fall back to GeGLU, which matches the historical +// default for ModernBert-style architectures. +static const std::map<std::string, llm_ffn_op_type> LLM_FFN_OP_TYPES_FROM_STRING = { + { "gelu", LLM_FFN_GEGLU }, + { "geglu", LLM_FFN_GEGLU }, + { "silu", LLM_FFN_SWIGLU }, + { "swish", LLM_FFN_SWIGLU }, + { "swiglu", LLM_FFN_SWIGLU }, + { "relu", LLM_FFN_RELU }, + { "reglu", LLM_FFN_REGLU }, +}; - if (op == GGML_OP_NONE) { - return true; +llm_ffn_op_type llm_ffn_op_type_from_string(const std::string & name, llm_ffn_op_type fallback) { + const auto it = LLM_FFN_OP_TYPES_FROM_STRING.find(name); + if (it != LLM_FFN_OP_TYPES_FROM_STRING.end()) { + return it->second; } + return fallback; +} - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead()*8, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - ggml_context_ptr ctx_ptr { ggml_init(params) }; - if (!ctx_ptr) { - throw std::runtime_error(format("failed to create ggml context")); +// CPU: ACCEL -> GPU host -> CPU extra -> CPU +static buft_list_t make_cpu_buft_list(const std::vector<llama_device> & devices, bool use_extra_bufts, bool no_host) { + buft_list_t buft_list; + + // add ACCEL buffer types + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + auto * buft = ggml_backend_dev_buffer_type(dev); + // skip + if (buft != ggml_backend_cpu_buffer_type()) { + buft_list.emplace_back(dev, buft); + } + } } - ggml_context * ctx = ctx_ptr.get(); - ggml_tensor * op_tensor = nullptr; - - switch (op) { - case GGML_OP_GET_ROWS: - { - ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); - op_tensor = ggml_get_rows(ctx, w, b); - } break; - case GGML_OP_MUL_MAT: - { - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], 512, w->ne[2], w->ne[3]); - op_tensor = ggml_mul_mat(ctx, w, b); - } break; - case GGML_OP_MUL_MAT_ID: - { - int n_expert_used = hparams.n_expert_used; - ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); - ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); - op_tensor = ggml_mul_mat_id(ctx, w, b, ids); - } break; - case GGML_OP_ADD: - { - ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); - op_tensor = ggml_add(ctx, a, w); - } break; - case GGML_OP_ADD_ID: - { - int n_expert_used = hparams.n_expert_used; - ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512); - ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512); - op_tensor = ggml_add_id(ctx, a, w, c); - } break; - case GGML_OP_MUL: - { - ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]); - op_tensor = ggml_mul(ctx, a, w); - } break; - case GGML_OP_DIV: - { - ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, w->ne[0]); - op_tensor = ggml_div(ctx, a, w); - } break; - case GGML_OP_ROPE: - { - int n_embd_head = hparams.n_embd_head_v; - int n_head = hparams.n_head(); - ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_embd_head, n_head, 512); - ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 512); - op_tensor = ggml_rope_ext( - ctx, a, b, w, - 0, 0, 0, 0, 0, - 0, 0, 0, 0 - ); - - } break; - case GGML_OP_SSM_CONV: - { - const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 3; - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); - op_tensor = ggml_ssm_conv(ctx, conv_x, w); - } break; - case GGML_OP_SSM_SCAN: - { - // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2 - const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; - const int64_t n_head = w->ne[1]; - const int64_t head_dim = hparams.ssm_d_inner / n_head; - const int64_t n_group = hparams.ssm_n_group ? hparams.ssm_n_group : 1; - const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 3; - ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); - ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); - } break; - case GGML_OP_RWKV_WKV6: - { - // FIXME - const int64_t S = 123; - const int64_t H = 123; - const int64_t n_tokens = 123; - const int64_t n_seqs = 123; - ggml_tensor * k = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * v = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * r = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * tf = w; - ggml_tensor * td = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, S, H, n_tokens); - ggml_tensor * state = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, S, n_seqs, S, H); - op_tensor = ggml_rwkv_wkv6(ctx, k, v, r, tf, td, state); - } break; - case GGML_OP_IM2COL: - { - const int n_embd_inp = hparams.n_embd_inp(); - ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd_inp, w->ne[1], 1, 1); - op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16); - } break; - case GGML_OP_SCALE: - { - op_tensor = ggml_scale(ctx, w, 1.0f); - } break; - default: - GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name); - } - - // create a temporary dummy buffer for the weight so that supports_op can check the buffer type - GGML_ASSERT(w->buffer == nullptr); - w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); - bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - ggml_backend_buffer_free(w->buffer); - w->buffer = nullptr; - - return op_supported; -} - -// lists of buffer types used for each layer -using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>; - -// find the first buffer type in the list that can use the tensor -static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hparams, ggml_tensor * tensor, ggml_op op, const buft_list_t & buft_list) { - GGML_ASSERT(!buft_list.empty()); - for (const auto & cur : buft_list) { - ggml_backend_dev_t cur_dev = cur.first; - ggml_backend_buffer_type_t cur_buft = cur.second; - if (weight_buft_supported(hparams, tensor, op, cur_buft, cur_dev)) { - return cur_buft; - } - } - - return nullptr; -} - -// CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices, bool use_extra_bufts, bool no_host) { - buft_list_t buft_list; - - // add ACCEL buffer types - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - ggml_backend_dev_t dev = ggml_backend_dev_get(i); - if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { - auto * buft = ggml_backend_dev_buffer_type(dev); - // skip - if (buft != ggml_backend_cpu_buffer_type()) { - buft_list.emplace_back(dev, buft); - } - } - } - - // add a host buffer type - // storing the tensors in a host buffer is useful when the processing of large batches - // is offloaded to a GPU device, since it reduces the time spent on data transfers - // generally, this will be done using the first device in the list - // a better approach would be to handle this on a weight-by-weight basis using the offload_op - // function of the device to determine if it would benefit from being stored in a host buffer - if (!no_host) { - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); - if (buft) { - buft_list.emplace_back(dev, buft); - break; - } - } - } + // add a host buffer type + // storing the tensors in a host buffer is useful when the processing of large batches + // is offloaded to a GPU device, since it reduces the time spent on data transfers + // generally, this will be done using the first device in the list + // a better approach would be to handle this on a weight-by-weight basis using the offload_op + // function of the device to determine if it would benefit from being stored in a host buffer + if (!no_host) { + for (const auto & dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev.dev); + if (buft) { + buft_list.emplace_back(dev.dev, buft); + break; + } + } + } // add extra buffer types if (use_extra_bufts) { @@ -415,14 +952,16 @@ static buft_list_t make_gpu_buft_list(ggml_backend_dev_t dev, llama_split_mode s // add the device extra buffer type (if any) ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) - ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); - - if (ggml_backend_dev_get_extra_bufts_fn) { - ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); - while (extra_bufts && *extra_bufts) { - buft_list.emplace_back(dev, *extra_bufts); - ++extra_bufts; + if (reg) { + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(reg, "ggml_backend_dev_get_extra_bufts"); + + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(dev, *extra_bufts); + ++extra_bufts; + } } } @@ -446,7 +985,7 @@ struct llama_model::impl { llama_mlocks mlock_bufs; llama_mlocks mlock_mmaps; - // contexts where the model tensors metadata is stored as well ass the corresponding buffers: + // contexts where the model tensors metadata is stored as well as the corresponding buffers: std::vector<std::pair<ggml_context_ptr, std::vector<ggml_backend_buffer_ptr>>> ctxs_bufs; buft_list_t cpu_buft_list; @@ -468,22 +1007,19 @@ llama_model::llama_model(const llama_model_params & params) : params(params), pi pimpl->has_tensor_overrides = params.tensor_buft_overrides && params.tensor_buft_overrides[0].pattern; } -llama_model::~llama_model() = default; +llama_model::~llama_model() { + for (auto * lora : loras) { + delete lora; + } +} -void llama_model::load_stats(llama_model_loader & ml) { +void llama_model_base::load_stats(llama_model_loader & ml) { pimpl->n_elements = ml.n_elements; pimpl->n_bytes = ml.n_bytes; } -void llama_model::load_arch(llama_model_loader & ml) { - arch = ml.get_arch(); - if (arch == LLM_ARCH_UNKNOWN) { - throw std::runtime_error("unknown model architecture: '" + ml.get_arch_name() + "'"); - } -} - -void llama_model::load_hparams(llama_model_loader & ml) { - const gguf_context * ctx = ml.meta.get(); +void llama_model_base::load_hparams(llama_model_loader & ml) { + const gguf_context * ctx = ml.metadata; // get metadata as string for (int i = 0; i < gguf_get_n_kv(ctx); i++) { @@ -507,15 +1043,25 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); - ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out, false); - ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out_impl, false); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn, false); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); + ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer_all); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); ml.get_key(LLM_KV_EXPERT_GROUP_COUNT, hparams.n_expert_groups, false); ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); + if (arch == LLM_ARCH_HUNYUAN_VL || arch == LLM_ARCH_HUNYUAN_DENSE) { + if (hparams.n_expert <= 1) { + hparams.n_expert = 0; + hparams.n_expert_used = 0; + } + } + if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { - ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); + ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); @@ -542,26 +1088,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); - std::fill( - hparams.recurrent_layer_arr.begin(), - hparams.recurrent_layer_arr.end(), - llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.is_swa_impl.begin(), hparams.is_swa_impl.end(), 0); + std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), llm_arch_is_recurrent(ml.get_arch()) ? 1 : 0); std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); - std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); - std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); - ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); + std::fill(hparams.swiglu_clamp_exp.begin(), hparams.swiglu_clamp_exp.end(), 0.0f); + std::fill(hparams.swiglu_clamp_shexp.begin(), hparams.swiglu_clamp_shexp.end(), 0.0f); + + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer(), false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer(), false); + + // Populate deepstack_mapping_arr - initialized to -1 (no deepstack) + std::fill(hparams.deepstack_mapping_arr.begin(), hparams.deepstack_mapping_arr.end(), -1); // n_head_kv is optional, default to n_head hparams.n_head_kv_arr = hparams.n_head_arr; - ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer, false); + ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_arr, hparams.n_layer(), false); bool rope_finetuned = false; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); @@ -589,37 +1138,45 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_ALPHA, hparams.rope_scaling_alpha, false); // non-transformer models do not have attention heads if (hparams.n_head() > 0) { // gpt-neox n_rot = rotary_pct * (n_embd / n_head) // gpt-j n_rot = rotary_dim - hparams.n_embd_head_k = hparams.n_embd / hparams.n_head(); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + hparams.n_embd_head_k_full = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false); - hparams.n_embd_head_v = hparams.n_embd / hparams.n_head(); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); + hparams.n_embd_head_v_full = hparams.n_embd / hparams.n_head(); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false); // sanity check for n_rot (optional) - hparams.n_rot = hparams.n_embd_head_k; + hparams.n_rot_full = hparams.n_embd_head_k_full; - ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false); + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot_full, false); if (arch == LLM_ARCH_LLAMA || arch == LLM_ARCH_DECI || arch == LLM_ARCH_FALCON || arch == LLM_ARCH_LLAMA_EMBED) { - if (hparams.n_rot != hparams.n_embd_head_k) { - throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k)); + if (hparams.n_rot_full != hparams.n_embd_head_k_full) { + throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot_full, hparams.n_embd_head_k_full)); } } } else { - hparams.n_rot = 0; - hparams.n_embd_head_k = 0; - hparams.n_embd_head_v = 0; + hparams.n_rot_full = 0; + hparams.n_embd_head_k_full = 0; + hparams.n_embd_head_v_full = 0; } - // for differentiating model types - uint32_t n_vocab = 0; - ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + // head size and n_rot for SWA layers + { + hparams.n_embd_head_k_swa = hparams.n_embd_head_k_full; + hparams.n_embd_head_v_swa = hparams.n_embd_head_v_full; + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa, false); + + hparams.n_rot_swa = hparams.n_rot_full; + ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT_SWA, hparams.n_rot_swa, false); + } // for classifier models ml.get_arr(LLM_KV_CLASSIFIER_OUTPUT_LABELS, classifier_labels, false); @@ -627,7331 +1184,1052 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_cls_out = classifier_labels.size(); } - // arch-specific KVs - switch (arch) { - case LLM_ARCH_LLAMA: - case LLM_ARCH_LLAMA_EMBED: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // per-arch hparams + load_arch_hparams(ml); - if (hparams.n_expert == 8) { - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8x7B; break; - case 56: type = LLM_TYPE_8x22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } else { - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B - case 22: type = LLM_TYPE_1B; break; - case 26: type = LLM_TYPE_3B; break; - case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B - case 30: type = LLM_TYPE_256M; break; // smoldocling 256M - // granite uses a vocab with len 49152 - case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; - case 36: type = LLM_TYPE_8B; break; // granite - case 40: type = LLM_TYPE_13B; break; - case 48: type = LLM_TYPE_34B; break; - case 60: type = LLM_TYPE_30B; break; - case 80: type = hparams.n_head() == hparams.n_head_kv() ? LLM_TYPE_65B : LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } - } break; - case LLM_ARCH_LLAMA4: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa == 0) { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; // always use rope - } else { - hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; - hparams.n_swa = 8192; - hparams.n_attn_temp_floor_scale = 8192; - hparams.f_attn_temp_scale = 0.1f; - hparams.f_attn_temp_offset = 1.0f; - hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } + pimpl->n_bytes = ml.n_bytes; - switch (hparams.n_expert) { - case 0: { - // MobileLLM (no MoE) - switch (hparams.n_embd) { - case 2048: type = LLM_TYPE_140M; break; - case 4096: type = LLM_TYPE_360M; break; - case 6144: type = LLM_TYPE_950M; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case 16: type = LLM_TYPE_17B_16E; break; - case 128: type = LLM_TYPE_17B_128E; break; - default: type = LLM_TYPE_UNKNOWN; - } + pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); - hparams.use_kq_norm = type != LLM_TYPE_17B_128E; - } break; - case LLM_ARCH_ARCEE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (hparams.f_max_alibi_bias > 0.0f) { + hparams.use_alibi = true; + } - // Arcee uses the same structure as Llama - switch (hparams.n_layer) { - case 36: type = LLM_TYPE_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_AFMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - - // Set up interleaved sliding window attention (ISWA) - // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4) - if (hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(4); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + hparams.rope_type = llama_model_rope_type(this); +} - // Default to sigmoid if not set - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } +void llama_model_base::load_vocab(llama_model_loader & ml) { + const auto kv = LLM_KV(arch); - switch (hparams.n_layer) { - case 56: type = LLM_TYPE_6B; break; - case 32: type = LLM_TYPE_26B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DECI: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 80: type = LLM_TYPE_70B; break; - case 162: type = LLM_TYPE_405B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MINICPM: - { - // Backward-compatible defaults for older MiniCPM GGUFs - hparams.f_embedding_scale = 12.0f; - hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer)); - hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f; + vocab.load(ml, kv); +} - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +bool llama_model_base::load_tensors(llama_model_loader & ml) { + const auto & split_mode = params.split_mode; + const auto & use_mlock = params.use_mlock; + const auto & tensor_split = params.tensor_split; - // Optional KV reads, override defaults if present in newer GGUF exports - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /*required=*/false); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /*required=*/false); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /*required=*/false); + const int n_layer_all = hparams.n_layer_all; + const int n_gpu_layers = this->n_gpu_layers(); - // MiniCPM uses rope by default, unlike Granite which uses it as a switch - hparams.rope_finetuned = true; + const bool use_mmap_buffer = true; - switch (hparams.n_layer) { - case 52: type = LLM_TYPE_1B; break; - case 40: type = LLM_TYPE_2B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MINICPM3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + this->ml = &ml; // to be used by create_tensor() and load_arch_tensors() - switch (hparams.n_layer) { - case 62: type = LLM_TYPE_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GROK: - { - // defaults for old GGUFs - hparams.yarn_beta_fast = 8.0f; - hparams.f_logit_scale = 0.5773502691896257f; - hparams.f_embedding_scale = 78.38367176906169f; - hparams.f_attn_out_scale = 0.08838834764831845f; - hparams.f_attn_logit_softcapping = 30.0f; - hparams.f_router_logit_softcapping = 30.0f; - // no final_logit_softcapping in grok-1 - hparams.f_final_logit_softcapping = 0.0f; - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); - ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false); - ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); - ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); - - switch (hparams.n_layer) { - case 64: type = LLM_TYPE_314B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_FALCON: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n", + __func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false"); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 60: type = LLM_TYPE_40B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_BAICHUAN: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } + // build a list of buffer types for the CPU and GPU devices + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); + for (const auto & dev : devices) { + buft_list_t buft_list = make_gpu_buft_list(dev.dev, split_mode, tensor_split); + // add CPU buffer types as a fallback + buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); + pimpl->gpu_buft_list.emplace(dev.dev, std::move(buft_list)); + } - if (type == LLM_TYPE_13B) { - // TODO: become GGUF KV parameter - hparams.f_max_alibi_bias = 8.0f; - } - } break; - case LLM_ARCH_STARCODER: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 36: type = LLM_TYPE_3B; break; - case 42: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_15B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_REFACT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_1B; break; - default: type = LLM_TYPE_UNKNOWN; - } + ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } - // TODO: become GGUF KV parameter - hparams.f_max_alibi_bias = 8.0f; - } break; - case LLM_ARCH_BERT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); - - switch (hparams.n_layer) { - case 3: - type = LLM_TYPE_17M; break; // bge-micro - case 6: - type = LLM_TYPE_22M; break; // MiniLM-L6 - case 12: - switch (hparams.n_embd) { - case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small - case 768: type = LLM_TYPE_109M; break; // bge-base - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - type = LLM_TYPE_335M; break; // bge-large - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MODERN_BERT: - { - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - uint32_t swa_period = 3; - hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + // calculate the split points + bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); + std::vector<float> splits(n_devices()); + if (all_zero) { + // default split, by free memory + for (size_t i = 0; i < n_devices(); ++i) { + ggml_backend_dev_t dev = devices[i].dev; + size_t total; + size_t free; + ggml_backend_dev_memory(dev, &free, &total); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); - - switch (hparams.n_layer) { - case 12: - type = LLM_TYPE_47M; break; // granite-embedding-small - case 22: - type = LLM_TYPE_149M; break; // modern-bert-base - case 28: - type = LLM_TYPE_395M; break; // modern-bert-large - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JINA_BERT_V2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); - hparams.f_max_alibi_bias = 8.0f; - - switch (hparams.n_layer) { - case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small - case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JINA_BERT_V3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); - - switch (hparams.n_layer) { - case 24: - type = LLM_TYPE_558M; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); - ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); - - if (hparams.n_layer == 12 && hparams.n_embd == 768) { - if (arch == LLM_ARCH_NOMIC_BERT) { - type = LLM_TYPE_137M; - } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { - type = LLM_TYPE_475M; - } - } - } break; - case LLM_ARCH_NEO_BERT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + // devices can return 0 bytes for free and total memory if they do not + // have any to report. in this case, we will use the host memory as a fallback + // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 + if (free == 0 && total == 0) { + ggml_backend_dev_memory(cpu_dev, &free, &total); + } + splits[i] = free; + } + } else { + std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); + } - if (hparams.n_layer == 28) { - type = LLM_TYPE_250M; - } - } break; - case LLM_ARCH_BLOOM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 30: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } + // sum and normalize the splits to get the split points + float split_sum = 0.0f; + for (size_t i = 0; i < n_devices(); ++i) { + split_sum += splits[i]; + splits[i] = split_sum; + } + for (size_t i = 0; i < n_devices(); ++i) { + splits[i] /= split_sum; + } - // TODO: become GGUF KV parameter - hparams.f_max_alibi_bias = 8.0f; - } break; - case LLM_ARCH_MPT: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 48: type = LLM_TYPE_30B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_STABLELM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_3B; break; - case 40: type = LLM_TYPE_12B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const int i_gpu_start = std::max(n_layer_all + 1 - n_gpu_layers, 0); + const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, n_layer_all + 1); + auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { + const bool is_swa = il < n_layer_all && hparams.is_swa(il); + if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); + return {cpu_dev, &pimpl->cpu_buft_list}; + } + const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); + auto * dev = devices.at(layer_gpu).dev; + LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); + return {dev, &pimpl->gpu_buft_list.at(dev)}; + }; - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN2VL: - { - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - } - // fall through - case LLM_ARCH_QWEN2: - { - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; - case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; - case 32: type = LLM_TYPE_7B; break; - case 36: type = LLM_TYPE_3B; break; - case 40: type = hparams.n_head() == 20 ? LLM_TYPE_4B : LLM_TYPE_13B; break; - case 48: type = LLM_TYPE_14B; break; - case 64: type = LLM_TYPE_32B; break; - case 80: type = LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DREAM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - // Dream models are primarily 7B with 28 layers - switch (hparams.n_layer) { - case 28: - type = LLM_TYPE_7B; - break; - default: - type = LLM_TYPE_UNKNOWN; - } - // Set non-causal attention for diffusion models - hparams.causal_attn = false; - } - break; - case LLM_ARCH_LLADA: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion - switch (hparams.n_layer) { - case 32: - type = LLM_TYPE_8B; - break; - default: - type = LLM_TYPE_UNKNOWN; - } - // Set non-causal attention for diffusion models - hparams.causal_attn = false; - } - break; - case LLM_ARCH_LLADA_MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - // diffusion language model uses non-causal attention - hparams.causal_attn = false; - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_A1_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_RND1: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + // assign the input layer + // there is very little benefit to offloading the input layer, so always keep it on the CPU + pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - // Set non-causal attention for diffusion models - hparams.causal_attn = false; - } break; - case LLM_ARCH_QWEN2MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_A2_7B; break; - case 28: type = LLM_TYPE_57B_A14B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3: - { - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; - case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; - case 40: type = LLM_TYPE_14B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MAINCODER: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_1B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3VL: - { - ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 28: type = LLM_TYPE_1_7B; break; - case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + // assign the repeating layers to the devices according to the splits + pimpl->dev_layer.resize(n_layer_all); + for (int il = 0; il < n_layer_all; ++il) { + pimpl->dev_layer[il] = get_layer_buft_list(il); + } - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - case 94: type = LLM_TYPE_235B_A22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3VLMOE: - { - ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - case 94: type = LLM_TYPE_235B_A22B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PHI2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + // assign the output layer + pimpl->dev_output = get_layer_buft_list(n_layer_all); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PHI3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_3B; break; - case 40: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } + // create tensors for the weights + { + // TODO: move to a separate function + const auto tn = LLM_TN(arch); - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + const int64_t n_expert = hparams.n_expert; + const int64_t n_expert_used = hparams.n_expert_used; - if (found_swa && hparams.n_swa > 0) { - LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n", - __func__, "https://github.com/ggml-org/llama.cpp/pull/13676"); + if (n_expert > 0 && n_expert_used == 0) { + throw std::runtime_error("model has expert layers but no expert layers are used"); + } - // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern` - hparams.swa_type = LLAMA_SWA_TYPE_NONE; + layers.resize(n_layer_all); - hparams.n_swa = 0; - hparams.set_swa_pattern(1); - } - } break; - case LLM_ARCH_PHIMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // call the per-model loading function + load_arch_tensors(ml); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_16x3_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLAMO: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // generic pass: load optional per-tensor/per-expert ".scale" tensors (e.g. NVFP4 scale2) + // this avoids having to add scale loading to every architecture + for (int i = 0; i < n_layer_all; ++i) { + auto & layer = layers[i]; - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLAMO2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // attention weight scales (per-tensor, shape {1}) + if (!layer.wq_s && layer.wq) { + layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_s && layer.wk) { + layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_s && layer.wv) { + layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_s && layer.wo) { + layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_s && layer.wqkv) { + layer.wqkv_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_s && layer.wqkv_gate) { + layer.wqkv_gate_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } - // Load Mamba SSM parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + // dense FFN weight scales (per-tensor, shape {1}) + if (!layer.ffn_gate_s && layer.ffn_gate) { + layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_s && layer.ffn_down) { + layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_s && layer.ffn_up) { + layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; - } + // MoE expert weight scales (per-expert, shape {n_expert}) + if (!layer.ffn_gate_exps_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_s && layer.ffn_down_exps) { + layer.ffn_down_exps_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_s && layer.ffn_up_exps) { + layer.ffn_up_exps_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_1B; break; - case 32: - if (hparams.n_embd == 2048) { - type = LLM_TYPE_2B; - } else if (hparams.n_embd == 4096) { - type = LLM_TYPE_8B; - } - break; - default: type = LLM_TYPE_UNKNOWN; - } + // recurrent / linear-attention weight scales (per-tensor, shape {1}) + if (!layer.ssm_in_s && layer.ssm_in) { + layer.ssm_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_out_s && layer.ssm_out) { + layer.ssm_out_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_s && layer.ssm_alpha) { + layer.ssm_alpha_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_s && layer.ssm_beta) { + layer.ssm_beta_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.eh_proj_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } - // Load attention parameters - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); - } break; - case LLM_ARCH_PLAMO3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - uint32_t swa_period = 8; - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + // input scales + if (!layer.wq_in_s && layer.wq) { + layer.wq_in_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wk_in_s && layer.wk) { + layer.wk_in_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wv_in_s && layer.wv) { + layer.wv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wo_in_s && layer.wo) { + layer.wo_in_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_in_s && layer.wqkv) { + layer.wqkv_in_s = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.wqkv_gate_in_s && layer.wqkv_gate) { + layer.wqkv_gate_in_s = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_in_s && layer.ffn_gate) { + layer.ffn_gate_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_in_s && layer.ffn_down) { + layer.ffn_down_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_in_s && layer.ffn_up) { + layer.ffn_up_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_exps_in_s && layer.ffn_gate_exps) { + layer.ffn_gate_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_exps_in_s && layer.ffn_down_exps) { + layer.ffn_down_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_exps_in_s && layer.ffn_up_exps) { + layer.ffn_up_exps_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "input_scale", i), {n_expert}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_gate_shexp_in_s && layer.ffn_gate_shexp) { + layer.ffn_gate_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_down_shexp_in_s && layer.ffn_down_shexp) { + layer.ffn_down_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ffn_up_shexp_in_s && layer.ffn_up_shexp) { + layer.ffn_up_shexp_in_s = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_in_in_s && layer.ssm_in) { + layer.ssm_in_in_s = create_tensor(tn(LLM_TENSOR_SSM_IN, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_out_in_s && layer.ssm_out) { + layer.ssm_out_in_s = create_tensor(tn(LLM_TENSOR_SSM_OUT, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_alpha_in_s && layer.ssm_alpha) { + layer.ssm_alpha_in_s = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.ssm_beta_in_s && layer.ssm_beta) { + layer.ssm_beta_in_s = create_tensor(tn(LLM_TENSOR_SSM_BETA, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.eh_proj_in_s && layer.nextn.eh_proj) { + layer.nextn.eh_proj_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + if (!layer.nextn.shared_head_head_in_s && layer.nextn.shared_head_head) { + layer.nextn.shared_head_head_in_s = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "input_scale", i), {1}, TENSOR_NOT_REQUIRED); + } + } + // output scales + if (output && output->type == GGML_TYPE_NVFP4) { + // weight scale + if (!output_s) { + output_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "scale"), {1}, TENSOR_NOT_REQUIRED); + } + // input scale + if (!output_in_s) { + output_in_s = create_tensor(tn(LLM_TENSOR_OUTPUT, "input_scale"), {1}, TENSOR_NOT_REQUIRED); + } + } + } + ml.done_getting_tensors(); - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_2B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GPT2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 12: type = LLM_TYPE_SMALL; break; - case 24: type = LLM_TYPE_MEDIUM; break; - case 36: type = LLM_TYPE_LARGE; break; - case 48: type = LLM_TYPE_XL; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_CODESHELL: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 42: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_ORION: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + GGML_ASSERT(!(output && tok_embd && + strcmp(output->name, tok_embd->name) == 0 && + output->type == GGML_TYPE_NVFP4)); + // populate tensors_by_name + for (auto & [_, ctx_ptr] : ml.ctx_map) { + for (auto * cur = ggml_get_first_tensor(ctx_ptr.get()); cur != NULL; cur = ggml_get_next_tensor(ctx_ptr.get(), cur)) { + tensors_by_name.emplace_back(ggml_get_name(cur), cur); + } + } - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_INTERNLM2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 48: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GEMMA: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); + pimpl->mappings.reserve(ml.mappings.size()); - switch (hparams.n_layer) { - case 18: type = LLM_TYPE_2B; break; - case 28: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GEMMA2: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; // default value of gemma 2 - hparams.set_swa_pattern(2); - hparams.attn_soft_cap = true; - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_2B; break; - case 42: type = LLM_TYPE_9B; break; - case 46: type = LLM_TYPE_27B; break; - default: type = LLM_TYPE_UNKNOWN; - } - - // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 - hparams.f_attention_scale = type == LLM_TYPE_27B - ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) - : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); - } break; - case LLM_ARCH_GEMMA3: - { - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(6); + // create the backend buffers + std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_buf_maps; + ctx_buf_maps.reserve(ml.ctx_map.size()); - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + // Ensure we have enough capacity for the maximum backend buffer we will potentially create + const size_t n_max_backend_buffer = ml.ctx_map.size() * ml.files.size(); + pimpl->ctxs_bufs.reserve(n_max_backend_buffer); - hparams.f_final_logit_softcapping = 0.0f; - ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 18: type = LLM_TYPE_270M; break; - case 26: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_8B; break; // Rnj-1 - case 34: type = LLM_TYPE_4B; break; - case 48: type = LLM_TYPE_12B; break; - case 62: type = LLM_TYPE_27B; break; - default: type = LLM_TYPE_UNKNOWN; - } + for (auto & [buft, ctx_ptr] : ml.ctx_map) { + ggml_context * ctx = ctx_ptr.get(); - // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 - hparams.f_attention_scale = type == LLM_TYPE_27B - ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) - : 1.0f / std::sqrt(float(hparams.n_embd_head_k)); - } break; - case LLM_ARCH_GEMMA3N: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(5); + // skip contexts without tensors + if (ggml_get_first_tensor(ctx) == nullptr) { + continue; + } - hparams.n_layer_kv_from_start = 20; - hparams.f_attention_scale = 1.0f; + llama_buf_map buf_map; + buf_map.reserve(n_max_backend_buffer); - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // check if it is possible to use buffer_from_host_ptr with this buffer type + ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); + if (!dev) { + // FIXME: workaround for CPU backend buft having a NULL device + dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (!dev) { + throw std::runtime_error(format("%s: no CPU backend found", __func__)); + } + } + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; + bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_E2B; break; - case 35: type = LLM_TYPE_E4B; break; - default: type = LLM_TYPE_UNKNOWN; + std::vector<ggml_backend_buffer_ptr> bufs; + if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { + GGML_ASSERT(!ml.no_alloc); + for (uint32_t idx = 0; idx < ml.files.size(); idx++) { + // only the mmap region containing the tensors in the model is mapped to the backend buffer + // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, + // then we could just use metal for all layers + // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size + void * addr = nullptr; + size_t first, last; // NOLINT + ml.get_mapping_range(&first, &last, &addr, idx, ctx); + if (first >= last) { + continue; } - } break; - case LLM_ARCH_GEMMA_EMBEDDING: - { - hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; - hparams.set_swa_pattern(6); - - hparams.causal_attn = false; // embeddings do not use causal attention + const size_t max_size = ggml_get_max_tensor_size(ctx); + ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); + if (buf == nullptr) { + throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); + } + bufs.emplace_back(buf); + buf_map.emplace(idx, buf); + } + } else { + ggml_backend_buffer_t buf; + if (ml.no_alloc) { + buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = buf; // set dummy buffer for weights so that the backend scheduler won't try to allocate them + } + } else { + buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer + } + if (buf == nullptr) { + throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); + } + if (use_mlock && ggml_backend_buffer_is_host(buf)) { + pimpl->mlock_bufs.emplace_back(new llama_mlock); + auto & mlock_buf = pimpl->mlock_bufs.back(); + mlock_buf->init (ggml_backend_buffer_get_base(buf)); + mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); + } + bufs.emplace_back(buf); + for (uint32_t idx = 0; idx < ml.files.size(); idx++) { + buf_map.emplace(idx, buf); + } + } - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + for (auto & buf : bufs) { + // indicate that this buffer contains weights + // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight + ggml_backend_buffer_set_usage(buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } - //applied only if model converted with --sentence-transformers-dense-modules - ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); - ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); - ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); - ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); - GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); - GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); + ctx_buf_maps.emplace_back(ctx, buf_map); + } - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_0_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k)); + if (llama_supports_gpu_offload()) { + const int n_gpu = std::min(n_gpu_layers, n_layer_all); - } break; - case LLM_ARCH_STARCODER2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_3B; break; - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_15B; break; - case 52: type = LLM_TYPE_20B; break; // granite - case 88: type = LLM_TYPE_34B; break; // granite - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MAMBA: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 24: - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_SMALL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 48: - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_MEDIUM; break; - case 1536: type = LLM_TYPE_LARGE; break; - case 2048: type = LLM_TYPE_XL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 64: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MAMBA2: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 24: - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_SMALL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 48: - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_MEDIUM; break; - case 1536: type = LLM_TYPE_LARGE; break; - case 2048: type = LLM_TYPE_XL; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 64: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_JAMBA: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + int n_repeating = n_gpu; + if (n_repeating > 0) { + LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); + n_repeating--; + } + LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; - } + const int max_backend_supported_layers = n_layer_all + 1; + const int max_offloadable_layers = n_layer_all + 1; - switch (hparams.n_layer) { - // TODO: Jamba layers are a bit heterogenous, so naming this is hard. - case 12: // 900M 8x???M - case 32: // 51B 16x?B - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_XVERSE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - case 80: type = LLM_TYPE_65B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_COMMAND_R: - { - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_35B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_COHERE2: - { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(4); - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DBRX: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); + } - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_16x12B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OLMO: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); - - switch (hparams.n_layer) { - case 22: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_7B; break; - case 80: type = LLM_TYPE_70B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OLMO2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // print memory requirements per buffer type + for (auto & [_, bufs] : pimpl->ctxs_bufs) { + for (auto & buf: bufs) { + LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", + __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); + } + } - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(4); + if (ml.no_alloc) { + return true; + } - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - } + // load tensor data + for (auto & [ctx, buf_map] : ctx_buf_maps) { + if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { + return false; + } + } - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_1B; break; - case 32: type = LLM_TYPE_7B; break; - case 40: type = LLM_TYPE_13B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_SEED_OSS: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 64: type = LLM_TYPE_36B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OLMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_A1_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OPENELM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 16: type = LLM_TYPE_270M; break; - case 20: type = LLM_TYPE_450M; break; - case 28: type = LLM_TYPE_1B; break; - case 36: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GPTNEOX: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); - switch (hparams.n_layer) { - case 6: - switch (hparams.n_ff()) { - case 512: type = LLM_TYPE_14M; break; - case 2048: type = LLM_TYPE_70M; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 12: - switch (hparams.n_ff()) { - case 3072: type = LLM_TYPE_160M; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 16: - switch (hparams.n_ff()) { - case 8192: type = LLM_TYPE_1B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - switch (hparams.n_ff()) { - case 4096: type = LLM_TYPE_410M; break; - case 8192: type = LLM_TYPE_1_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 32: - switch (hparams.n_ff()) { - case 10240: type = LLM_TYPE_2_8B; break; - case 16384: type = LLM_TYPE_6_9B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 36: - switch (hparams.n_ff()) { - case 20480: type = LLM_TYPE_12B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 44: - switch (hparams.n_ff()) { - case 24576: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_ARCTIC: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (use_mmap_buffer) { + for (auto & mapping : ml.mappings) { + pimpl->mappings.emplace_back(std::move(mapping)); + } + } - if (hparams.n_expert == 128) { - switch (hparams.n_layer) { - case 35: type = LLM_TYPE_10B_128x3_66B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } else { - type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DEEPSEEK: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); - - switch (hparams.n_ff_exp) { - case 1408: type = LLM_TYPE_16B; break; - case 1792: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DEEPSEEK2: - { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - if (!is_lite) { - ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); - } - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla, false); - ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - // for compatibility with existing DeepSeek V2 and V2.5 GGUFs - // that have no expert_gating_func model parameter set - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; - } + return true; +} - if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { - // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] - // cancel the factor from the convert script - hparams.rope_yarn_log_mul /= 0.1f; - } +ggml_tensor * llama_model_base::create_tensor(llama_model_loader & ml, const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) { + const buft_list_t * buft_list_layer = tn.bid == -1 ? nullptr : pimpl->dev_layer.at(tn.bid).buft_list; + return ml.create_tensor( + hparams, &pimpl->cpu_buft_list, pimpl->dev_input.buft_list, pimpl->dev_output.buft_list, buft_list_layer, + tn, ne, flags); +} - // (optional) temperature tuning - used by mistral-large - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); +std::string llama_model::arch_name() const { + return llm_arch_name(arch); +} - hparams.f_attn_temp_offset = 0.0f; +std::string llama_model::type_name() const { + return llm_type_name(type); +} - switch (hparams.n_layer) { - case 27: type = LLM_TYPE_16B; break; - case 60: type = LLM_TYPE_236B; break; - case 61: type = LLM_TYPE_671B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PLM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_1_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_CHATGLM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 28: { - if (hparams.n_head(0) == 16) { - type = LLM_TYPE_1_5B; - } else { - type = LLM_TYPE_6B; - } - } break; - case 40: { - if (hparams.n_head(0) == 24) { - type = LLM_TYPE_4B; - } else { - type = LLM_TYPE_9B; - } - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GLM4: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - switch (hparams.n_layer) { - case 40: type = LLM_TYPE_9B; break; - case 61: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GLM4_MOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - - // MoE parameters - ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); - ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - - // Expert gating function (GLM-4.5 uses sigmoid) - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { - hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; - } +std::string llama_model::desc() const { + return pimpl->desc_str; +} - // NextN/MTP parameters - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); +size_t llama_model::size() const { + return pimpl->n_bytes; +} - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; +size_t llama_model::n_tensors() const { + return tensors_by_name.size(); +} - switch (hparams.n_layer) { - case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) - case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open - case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_BITNET: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +size_t llama_model::n_devices() const { + return devices.size(); +} - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_T5: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); +const float * llama_model::tensor_split() const { + return params.tensor_split; +} - uint32_t dec_start_token_id; - if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { - hparams.dec_start_token_id = dec_start_token_id; - } +uint32_t llama_model::n_gpu_layers() const { + // note: plus 1 for the "output" layer + return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer_all + 1; +} - hparams.dec_n_layer = hparams.n_layer; - ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); - - switch (hparams.n_layer) { - case 6: type = LLM_TYPE_60M; break; // t5-small - case 8: type = LLM_TYPE_80M; break; // flan-t5-small - case 12: - switch (hparams.n_ff()) { - case 3072: type = LLM_TYPE_220M; break; // t5-base - case 2048: type = LLM_TYPE_250M; break; // flan-t5-base - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - switch (hparams.n_ff()) { - case 4096: type = LLM_TYPE_770M; break; // t5-large - case 2816: type = LLM_TYPE_780M; break; // flan-t5-large - case 16384: type = LLM_TYPE_3B; break; // t5-3b - case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl - case 65536: type = LLM_TYPE_11B; break; // t5-11b - case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_T5ENCODER: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); - type = LLM_TYPE_UNKNOWN; - } break; - case LLM_ARCH_JAIS: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1_3B; break; - case 40: type = LLM_TYPE_13B; break; - /* TODO: add variants */ - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_NEMOTRON: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_4B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_NEMOTRON_H: - case LLM_ARCH_NEMOTRON_H_MOE: - { - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // A layer is recurrent IFF the n_head_kv value is set to 0 and - // the n_ff value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); - } +llama_split_mode llama_model::split_mode() const { + return params.split_mode; +} - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const { + std::map<ggml_backend_buffer_type_t, size_t> ret; + for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) { + if (hparams.no_alloc) { + GGML_ASSERT(bufs.size() == 1); + ggml_backend_buffer_t buf = bufs[0].get(); + GGML_ASSERT(ggml_backend_buffer_get_base(buf) == nullptr); + ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf); + ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft); + } else { + for (const auto & buf : bufs) { + // GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base + ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); + } + } + } + return ret; +} - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); +uint64_t llama_model::n_elements() const { + return pimpl->n_elements; +} - switch (hparams.n_layer) { - case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B - case 56: type = LLM_TYPE_9B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_EXAONE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); +void llama_model::print_info() const { + const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_EXAONE4: - { - if (hparams.n_layer == 64) { // 32B - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; - hparams.set_swa_pattern(4); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } + auto print_f = [](const std::function<int32_t(uint32_t)> & f, uint32_t n) { + bool is_var = false; - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + std::vector<int32_t> v; + for (uint32_t i = 0; i < n; ++i) { + v.push_back(f(i)); + if (v[i] != v[0]) { + is_var = true; + } + } - switch (hparams.n_layer) { - case 30: type = LLM_TYPE_1_2B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_RWKV6: - case LLM_ARCH_RWKV6QWEN2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); - ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); - ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); - ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); - ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); - ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_1_6B; break; - case 32: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 61: type = LLM_TYPE_14B; break; - case 64: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_RWKV7: - case LLM_ARCH_ARWKV7: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); - ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); - ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); - ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); - ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); - ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); - ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); - - switch (hparams.n_layer) { - case 12: - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_190M; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 24: - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_450M; break; - case 2048: type = LLM_TYPE_1_5B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 28: - switch (hparams.n_embd) { - case 1536: type = LLM_TYPE_1_5B; break; - case 3584: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 32: - switch (hparams.n_embd) { - case 2560: type = LLM_TYPE_2_9B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - case 61: - switch (hparams.n_embd) { - case 4096: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale); - - // Granite uses rope_finetuned as a switch for rope, so default to true - bool rope_finetuned = true; - ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); - hparams.rope_finetuned = rope_finetuned; - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_3B; break; - case 40: type = LLM_TYPE_3B; break; - // Add additional layer/vocab/etc checks here for other model sizes - default: type = LLM_TYPE_UNKNOWN; - } + std::stringstream ss; - // For Granite MoE Shared - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); - } break; - case LLM_ARCH_GRANITE_HYBRID: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false); - ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /* required */ false); - ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /* required */ false); - ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, /* required */ false); - - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // Granite uses rope_finetuned as a switch for rope, so default to true - bool rope_finetuned = true; - ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); - hparams.rope_finetuned = rope_finetuned; - - // A layer is recurrent IFF the n_head_kv value is set to 0 - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + if (is_var) { + ss << "["; + for (uint32_t i = 0; i < n; ++i) { + ss << v[i]; + if (i < n - 1) { + ss << ", "; } + } + ss << "]"; + } else { + ss << v[0]; + } - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + return ss.str(); + }; - switch (hparams.n_embd) { - case 768: type = LLM_TYPE_350M; break; - case 1536: type = (hparams.n_embd == 2048 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; - case 2048: case 2560: type = LLM_TYPE_3B; break; - case 4096: type = LLM_TYPE_32B; break; - default: type = LLM_TYPE_UNKNOWN; - } + // hparams + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); + LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); + LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); - // For Granite MoE Shared - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); - } break; - case LLM_ARCH_CHAMELEON: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default - ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_7B; break; - case 48: type = LLM_TYPE_34B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_WAVTOKENIZER_DEC: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); - ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); - ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); - } break; - case LLM_ARCH_BAILINGMOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - - switch (hparams.n_layer) { - case 28: type = LLM_TYPE_16B; break; - case 88: type = LLM_TYPE_290B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_BAILINGMOE2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); - - // TODO: when MTP is implemented, this should probably be updated if needed - hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; - - switch (hparams.n_layer) { - case 20: type = LLM_TYPE_16B_A1B; break; - case 21: type = LLM_TYPE_16B_A1B; break; - case 32: type = LLM_TYPE_100B_A6B; break; - case 33: type = LLM_TYPE_100B_A6B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_DOTS1: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); - switch (hparams.n_layer) { - case 62: type = LLM_TYPE_142B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_ERNIE4_5: - case LLM_ARCH_ERNIE4_5_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - if (arch == LLM_ARCH_ERNIE4_5_MOE) { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - } - - switch (hparams.n_layer) { - case 18: type = LLM_TYPE_0_3B; break; - case 28: type = LLM_TYPE_21B_A3B; break; - case 54: type = LLM_TYPE_300B_A47B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_FALCON_H1: - { - // Common parameters - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - // SSM parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true); - - switch (hparams.n_layer) { - case 36: - type = LLM_TYPE_0_5B; break; - case 24: - type = LLM_TYPE_1_5B; break; - case 66: - type = LLM_TYPE_1B; break; - case 32: - type = LLM_TYPE_3B; break; - case 44: - type = LLM_TYPE_7B; break; - case 72: - type = LLM_TYPE_34B; break; - default: - type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_HUNYUAN_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_A13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_HUNYUAN_DENSE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_embd) { - case 1024: type = LLM_TYPE_0_5B; break; - case 2048: type = LLM_TYPE_1_8B; break; - case 3072: type = LLM_TYPE_4B; break; - case 4096: type = LLM_TYPE_7B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_SMOLLM3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - hparams.n_no_rope_layer_step = 4; - - switch (hparams.n_layer) { - case 36: type = LLM_TYPE_3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_OPENAI_MOE: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.set_swa_pattern(2); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - - switch (hparams.n_layer) { - case 24: type = LLM_TYPE_20B; break; - case 36: type = LLM_TYPE_120B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_LFM2: - { - ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; - } - hparams.n_layer_dense_lead = hparams.n_layer; - switch (hparams.n_ff()) { - case 4608: type = LLM_TYPE_350M; break; - case 6912: type = LLM_TYPE_700M; break; - case 8192: type = LLM_TYPE_1_2B; break; - case 10752: type = LLM_TYPE_2_6B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_LFM2MOE: - { - ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); - - for (uint32_t il = 0; il < hparams.n_layer; ++il) { - hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; - } - - type = LLM_TYPE_8B_A1B; - } break; - case LLM_ARCH_SMALLTHINKER: - { - const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); - - if (found_swa && hparams.n_swa > 0) { - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.n_swa = 4096; - hparams.set_swa_pattern(4, true); - - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); - } else { - hparams.swa_type = LLAMA_SWA_TYPE_NONE; - hparams.n_no_rope_layer_step = hparams.n_layer; - } - - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (!hparams.vocab_only) { + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_embd_out = %u\n", __func__, hparams.n_embd_out()); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer()); + LLAMA_LOG_INFO("%s: n_layer_all = %u\n", __func__, hparams.n_layer_all); + LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot_full); + LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); + LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); + LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k_full); + LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v_full); + LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); + LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); + LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); + LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: f_attn_value_scale = %.4f\n", __func__, hparams.f_attn_value_scale); + LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer_all).c_str()); + LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); + LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); + LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); + LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); + LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); + LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); + LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); + LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); + LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); + LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); + LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); + LLAMA_LOG_INFO("%s: n_embd_head_k_swa = %u\n", __func__, hparams.n_embd_head_k_swa); + LLAMA_LOG_INFO("%s: n_embd_head_v_swa = %u\n", __func__, hparams.n_embd_head_v_swa); + LLAMA_LOG_INFO("%s: n_rot_swa = %u\n", __func__, hparams.n_rot_swa); + } + LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); + LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul); + LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + if (arch == LLM_ARCH_GRANITE && + std::any_of(hparams.deepstack_mapping_arr.begin(), + hparams.deepstack_mapping_arr.end(), + [](const auto & entry) { return entry >= 0; })) { + LLAMA_LOG_INFO("%s: deepstack_mapping_arr = %s\n", __func__, + print_f([&](uint32_t il) { return hparams.deepstack_mapping_arr[il]; }, + hparams.n_layer_all).c_str()); + } + // MRoPE (Multi-axis Rotary Position Embedding) sections + if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { + LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); + } + if (!classifier_labels.empty()) { + LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_4B; break; - case 52: type = LLM_TYPE_20B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_GROVEMOE: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp); - ml.get_key(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); - ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_30B_A3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_APERTUS: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); - ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); - - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_8B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MINIMAX_M2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + size_t i = 0; + for (const auto & label : classifier_labels) { + LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); + } + } - switch (hparams.n_layer) { - case 62: type = LLM_TYPE_230B_A10B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_COGVLM: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 32: type = LLM_TYPE_13B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_PANGU_EMBED: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 - case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_QWEN3NEXT: - { - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); - ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - - // Load linear attention (gated delta net) parameters - ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); - ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); - ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); - ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); - ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); - - // Mark recurrent layers (linear attention layers) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval" - } + if (arch == LLM_ARCH_MAMBA || + arch == LLM_ARCH_MAMBA2 || + arch == LLM_ARCH_JAMBA || + arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || + arch == LLM_ARCH_NEMOTRON_H || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); + LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); + LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); + LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); + LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); + } - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_80B_A3B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MISTRAL3: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); + if (pimpl->n_elements >= 1e12) { + LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); + } else if (pimpl->n_elements >= 1e9) { + LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); + } else if (pimpl->n_elements >= 1e6) { + LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); + } else { + LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); + } - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); - ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f); + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - hparams.f_attn_temp_offset = 0.0f; + if (arch == LLM_ARCH_DEEPSEEK) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + } - // TODO: maybe add n_attn_temp_floor_scale as a separate KV? - if (hparams.f_attn_temp_scale != 0.0f) { - hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn; - if (hparams.n_attn_temp_floor_scale == 0) { - throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling"); - } - } + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_DEEPSEEK32 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); + LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); + LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla()); + LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla()); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - switch (hparams.n_layer) { - case 26: type = LLM_TYPE_3B; break; - case 34: type = LLM_TYPE_8B; break; - case 40: type = LLM_TYPE_14B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - case LLM_ARCH_MIMO2: - { - ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (arch == LLM_ARCH_QWEN2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + if (arch == LLM_ARCH_MELLUM || + arch == LLM_ARCH_QWEN3MOE || + arch == LLM_ARCH_OPENAI_MOE || + arch == LLM_ARCH_QWEN3VLMOE || + arch == LLM_ARCH_RND1) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + } - ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); - ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); - ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer); + if (arch == LLM_ARCH_MINICPM || + arch == LLM_ARCH_GRANITE || + arch == LLM_ARCH_GRANITE_MOE || + arch == LLM_ARCH_GRANITE_HYBRID || + arch == LLM_ARCH_NEMOTRON_H_MOE) { + LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); + LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); + LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + } - switch (hparams.n_layer) { - case 48: type = LLM_TYPE_310B_A15B; break; - default: type = LLM_TYPE_UNKNOWN; - } - } break; - default: throw std::runtime_error("unsupported model architecture"); - } + if (arch == LLM_ARCH_BAILINGMOE) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + } - pimpl->n_bytes = ml.n_bytes; + if (arch == LLM_ARCH_BAILINGMOE2) { + LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); + LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); + LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); + LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + LLAMA_LOG_INFO("%s: n_layer_nextn = %d\n", __func__, hparams.n_layer_nextn); + } - pimpl->desc_str = arch_name() + " " + type_name() + " " + ml.ftype_name(); + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); + } - if (hparams.f_max_alibi_bias > 0.0f) { - hparams.use_alibi = true; + if (arch == LLM_ARCH_GROVEMOE) { + LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); + LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); + LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); + LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); + } } - hparams.rope_type = llama_model_rope_type(this); + vocab.print_info(); } -void llama_model::load_vocab(llama_model_loader & ml) { - const auto kv = LLM_KV(arch); - - vocab.load(ml, kv); +ggml_backend_dev_t llama_model::dev_layer(int il) const { + return pimpl->dev_layer.at(il).dev; } -bool llama_model::load_tensors(llama_model_loader & ml) { - const auto & split_mode = params.split_mode; - const auto & use_mlock = params.use_mlock; - const auto & tensor_split = params.tensor_split; - - const int n_layer = hparams.n_layer; - const int n_gpu_layers = this->n_gpu_layers(); - - const bool use_mmap_buffer = true; - - LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s, direct_io = %s)\n", - __func__, ml.use_mmap ? "true" : "false", ml.use_direct_io ? "true" : "false"); - - // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); - for (auto * dev : devices) { - buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); - // add CPU buffer types as a fallback - buft_list.insert(buft_list.end(), pimpl->cpu_buft_list.begin(), pimpl->cpu_buft_list.end()); - pimpl->gpu_buft_list.emplace(dev, std::move(buft_list)); - } - - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - - // calculate the split points - bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + n_devices(), [](float x) { return x == 0.0f; }); - std::vector<float> splits(n_devices()); - if (all_zero) { - // default split, by free memory - for (size_t i = 0; i < n_devices(); ++i) { - ggml_backend_dev_t dev = devices[i]; - size_t total; - size_t free; - ggml_backend_dev_memory(dev, &free, &total); - - // devices can return 0 bytes for free and total memory if they do not - // have any to report. in this case, we will use the host memory as a fallback - // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 - if (free == 0 && total == 0) { - ggml_backend_dev_memory(cpu_dev, &free, &total); - } - splits[i] = free; - } - } else { - std::copy(tensor_split, tensor_split + n_devices(), splits.begin()); - } - - // sum and normalize the splits to get the split points - float split_sum = 0.0f; - for (size_t i = 0; i < n_devices(); ++i) { - split_sum += splits[i]; - splits[i] = split_sum; - } - for (size_t i = 0; i < n_devices(); ++i) { - splits[i] /= split_sum; - } +ggml_backend_dev_t llama_model::dev_output() const { + return pimpl->dev_output.dev; +} - const int i_gpu_start = std::max(int(hparams.n_layer) + 1 - n_gpu_layers, 0); - const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, int(n_layer) + 1); - auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev { - const bool is_swa = il < int(hparams.n_layer) && hparams.is_swa(il); - if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) { - LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa); - return {cpu_dev, &pimpl->cpu_buft_list}; - } - const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin(); - auto * dev = devices.at(layer_gpu); - LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa); - return {dev, &pimpl->gpu_buft_list.at(dev)}; +template<typename F> +static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead()*8, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, }; - // assign the input layer - // there is very little benefit to offloading the input layer, so always keep it on the CPU - pimpl->dev_input = { cpu_dev, &pimpl->cpu_buft_list }; - - // assign the repeating layers to the devices according to the splits - pimpl->dev_layer.resize(n_layer); - for (int il = 0; il < n_layer; ++il) { - pimpl->dev_layer[il] = get_layer_buft_list(il); + ggml_context_ptr ctx { ggml_init(params) }; + if (!ctx) { + throw std::runtime_error(format("failed to create ggml context")); } - // assign the output layer - pimpl->dev_output = get_layer_buft_list(n_layer); - - // one ggml context per buffer type - int max_n_tensors = ml.n_tensors; - max_n_tensors += 1; // duplicated output tensor - max_n_tensors += n_layer*2; // duplicated rope freq tensors - const size_t ctx_size = ggml_tensor_overhead()*max_n_tensors; - - // define a comparator for the buft -> ctx map to ensure that the order is well-defined: - struct ggml_backend_buft_comparator { - bool operator()(const ggml_backend_buffer_type_t & lhs, const ggml_backend_buffer_type_t & rhs) const { - return strcmp(ggml_backend_buft_name(lhs), ggml_backend_buft_name(rhs)) < 0; - } - }; - std::map<ggml_backend_buffer_type_t, ggml_context_ptr, ggml_backend_buft_comparator> ctx_map; - - auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { - auto it = ctx_map.find(buft); - if (it == ctx_map.end()) { - ggml_init_params params = { - /*.mem_size =*/ ctx_size, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context * ctx = ggml_init(params); - if (!ctx) { - throw std::runtime_error(format("failed to create ggml context")); - } - - ctx_map.emplace(buft, ctx); - - return ctx; - } - return it->second.get(); - }; - - const auto TENSOR_DUPLICATED = llama_model_loader::TENSOR_DUPLICATED; - const auto TENSOR_NOT_REQUIRED = llama_model_loader::TENSOR_NOT_REQUIRED; - const auto TENSOR_SKIP = llama_model_loader::TENSOR_SKIP; - - // create tensors for the weights - { - // note: cast to int64_t since we will use these for the tensor dimensions - const int64_t n_head = hparams.n_head(); - const int64_t n_head_kv = hparams.n_head_kv(); - const int64_t n_embd = hparams.n_embd; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_head_v = hparams.n_embd_head_v; - const int64_t n_ff = hparams.n_ff(); - const int64_t n_embd_gqa = n_embd_v_gqa; - const int64_t n_vocab = vocab.n_tokens(); - const int64_t n_token_types = vocab.n_token_types(); - const int64_t n_rot = hparams.n_rot; - const int64_t n_expert = hparams.n_expert; - const int64_t n_expert_used = hparams.n_expert_used; - const int64_t n_ctx_train = hparams.n_ctx_train; - - if (n_expert > 0 && hparams.n_expert_used == 0) { - throw std::runtime_error("model has expert layers but no expert layers are used"); + ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; + ggml_tensor * op_tensor = fn(ctx.get()); + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op_tensor->src[i] != nullptr) { + assert(op_tensor->src[i]->buffer == nullptr); + op_tensor->src[i]->buffer = buf.get(); } + } - int n_moved_tensors = 0; - ggml_tensor * first_moved_tensor = nullptr; - ggml_backend_buffer_type_t first_moved_from_buft = nullptr; - ggml_backend_buffer_type_t first_moved_to_buft = nullptr; - - auto create_tensor = [&](const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) -> ggml_tensor * { - ggml_tensor * t_meta = ml.get_tensor_meta(tn.str().c_str()); - - if (!t_meta) { - if (flags & TENSOR_NOT_REQUIRED) { - return nullptr; - } - throw std::runtime_error(format("missing tensor '%s'", tn.str().c_str())); - } - - // some models use the token embedding tensor as the output, but since these are used in different layers and with different ops - // the tensor is duplicated - // to handle this, we check if the tensor is duplicated, and if so, we assume that it is being loaded as the output tensor - llm_tensor tn_tensor = tn.tensor; - if (tn.tensor == LLM_TENSOR_TOKEN_EMBD && flags & TENSOR_DUPLICATED) { - tn_tensor = LLM_TENSOR_OUTPUT; - } - - llm_tensor_info info; - try { - info = llm_tensor_info_for(tn_tensor); - } catch (const std::out_of_range & e) { - throw std::runtime_error(format("missing tensor info mapping for %s", tn.str().c_str())); - } - - // skip unused tensors - if (info.op == GGML_OP_NONE || flags & TENSOR_SKIP) { - const size_t nbytes = ggml_nbytes(t_meta); - LLAMA_LOG_WARN("model has unused tensor %s (size = %zu bytes) -- ignoring\n", tn.str().c_str(), nbytes); - - ml.size_data -= nbytes; - ml.n_created++; - - return nullptr; - } + bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - // tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID - ggml_op op; - bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0; - if (bias) { - if (info.op == GGML_OP_MUL_MAT_ID) { - op = GGML_OP_ADD_ID; - } else { - op = GGML_OP_ADD; - } - } else { - op = info.op; - } + return op_supported; +} - // sanity checks - if (info.layer == LLM_TENSOR_LAYER_INPUT || info.layer == LLM_TENSOR_LAYER_OUTPUT) { - if (tn.bid != -1) { - GGML_ABORT("input/output layer tensor %s used with a layer number", tn.str().c_str()); - } - } else { - if (tn.bid == -1) { - GGML_ABORT("repeating layer tensor %s used without a layer number", tn.str().c_str()); - } - } - - // select the buffer type for this tensor - buft_list_t * buft_list; - switch (info.layer) { - case LLM_TENSOR_LAYER_INPUT: - buft_list = pimpl->dev_input.buft_list; - break; - case LLM_TENSOR_LAYER_OUTPUT: - buft_list = pimpl->dev_output.buft_list; - break; - case LLM_TENSOR_LAYER_REPEATING: - buft_list = pimpl->dev_layer.at(tn.bid).buft_list; - break; - default: - GGML_ABORT("invalid layer %d for tensor %s", info.layer, tn.str().c_str()); - } - - ggml_backend_buffer_type_t buft = nullptr; - - // check overrides - if (ml.tensor_buft_overrides) { - std::string tensor_name = tn.str(); - for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) { - std::regex pattern(overrides->pattern); - if (std::regex_search(tensor_name, pattern)) { - if (overrides->buft == ggml_backend_cpu_buffer_type()) { - // when overriding to a CPU buffer, consider the extra buffer types - buft = select_weight_buft(hparams, t_meta, op, pimpl->cpu_buft_list); - } else { - buft = overrides->buft; - } - - LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n", - tensor_name.c_str(), - ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type), - ggml_backend_buft_name(buft)); - break; - } - } - } - - if (!buft) { - buft = select_weight_buft(hparams, t_meta, op, *buft_list); - if (!buft) { - throw std::runtime_error(format("failed to find a compatible buffer type for tensor %s", tn.str().c_str())); - } - } - - // avoid using a host buffer when using mmap - auto * buft_dev = ggml_backend_buft_get_device(buft); - if (ml.use_mmap && buft_dev && buft == ggml_backend_dev_host_buffer_type(buft_dev)) { - auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (!cpu_dev) { - throw std::runtime_error("no CPU backend found"); - } - buft = ggml_backend_dev_buffer_type(cpu_dev); - } - - if (buft != buft_list->front().second) { - n_moved_tensors++; - if (!first_moved_tensor) { - first_moved_tensor = t_meta; - first_moved_from_buft = buft_list->front().second; - first_moved_to_buft = buft; - } - } - - ggml_context * ctx = ctx_for_buft(buft); - - // if duplicated, check if the original tensor was allocated in the same buffer type context and avoid creating a new one - if (flags & TENSOR_DUPLICATED) { - ggml_tensor * t = ggml_get_tensor(ctx, tn.str().c_str()); - if (t) { - return t; - } - } - return ml.create_tensor(ctx, tn, ne, flags); - }; - - layers.resize(n_layer); - - // TODO: move to a separate function - const auto tn = LLM_TN(arch); - switch (arch) { - case LLM_ARCH_LLAMA: - case LLM_ARCH_REFACT: - case LLM_ARCH_MINICPM: - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_MISTRAL3: - case LLM_ARCH_LLAMA_EMBED: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - if (n_expert == 0) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // optional MLP bias - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - // For Granite MoE Shared - if (hparams.n_ff_shexp > 0) { - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); - } - } - } - } break; - case LLM_ARCH_LLADA: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = - create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - // Use separate Q, K, V projections without bias, matching LLaDALlamaBlock - layer.wq = - create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); - // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false - layer.wo = - create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot / 2 }, - TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); - - // optional MLP bias - layer.ffn_gate_b = - create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = - create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); - } - } - break; - case LLM_ARCH_LLADA_MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe"); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - } - } break; - case LLM_ARCH_LLAMA4: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - if (is_moe_layer) { - int n_ff_exp = hparams.n_ff_exp; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert - const int64_t n_ff_shexp = n_ff_exp; - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - } else { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_DECI: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); - const int64_t n_ff = hparams.n_ff(i); - const int64_t n_head = hparams.n_head(i); - const int64_t n_head_kv = hparams.n_head_kv(i); - - if (n_head_kv == 0 && n_head > 0) { - // linear attention for DeciLMCausalModel - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - } - else if (n_head_kv > 0) { - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - } - - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - if (n_ff > 0) { - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - } - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - if (n_ff > 0) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - // optional MLP bias - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_MINICPM3: - { - const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; - - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); - - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); - - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - } break; - case LLM_ARCH_GROK: - { - if (n_expert == 0) { - throw std::runtime_error("Grok model cannot have zero experts"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - if (!layer.ffn_post_norm) { - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } - } break; - case LLM_ARCH_DBRX: - { - if (n_expert == 0) { - throw std::runtime_error("DBRX model cannot have zero experts"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } - } break; - case LLM_ARCH_BAICHUAN: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_FALCON: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_STARCODER: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - // needs to be on GPU - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_BERT: - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - case LLM_ARCH_JINA_BERT_V3: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); - - if (arch == LLM_ARCH_BERT) { - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); - - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - } - - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (!layer.wqkv) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); - layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); - - if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - } else { - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - if (arch == LLM_ARCH_NOMIC_BERT) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - } - } - - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); - layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_MODERN_BERT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for(int i = 0; i < n_layer; ++i) { - auto& layer = layers[i]; - - if ( i != 0 ) { - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - } else{ - // layer 0 uses identity - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - } - - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - } - - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - - } break; - case LLM_ARCH_NEO_BERT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - - output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_JINA_BERT_V2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings - type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings - - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); // LayerNorm - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); //LayerNorm bias - - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); - cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; // JinaBertLayer - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens - - layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm - layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i); - ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str()); - const int64_t n_ffn_up = t_ffn_up ? t_ffn_up->ne[1] : n_ff; - - GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2); - layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); - layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_BLOOM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_MPT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - // AWQ ScaleActivation layer - layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_STABLELM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional bias tensors, present in Stable LM 2 1.6B - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - - // optional q and k layernorms, present in StableLM 2 12B - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - - // optional FFN norm, not present in StableLM 2 12B which uses parallel residual - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0); - } - } break; - case LLM_ARCH_QWEN2: - case LLM_ARCH_QWEN2VL: - case LLM_ARCH_DREAM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN2MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0 for QWEN2MOE"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE"); - } - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; - - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); - } - } break; - case LLM_ARCH_QWEN3: - case LLM_ARCH_QWEN3VL: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // output rerank head - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN3MOE: - case LLM_ARCH_QWEN3VLMOE: - case LLM_ARCH_RND1: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); - } - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - } - } break; - case LLM_ARCH_PHI2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_PHI3: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); - - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - } break; - case LLM_ARCH_PHIMOE: - { - const int64_t n_embd_head = n_embd / n_head; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, TENSOR_NOT_REQUIRED); - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - } - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - } break; - case LLM_ARCH_PLAMO: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_PLAMO2: - { - // mamba parameters - const uint32_t d_conv = hparams.ssm_d_conv; - const uint32_t d_state = hparams.ssm_d_state; - const uint32_t num_heads = hparams.ssm_dt_rank; - const uint32_t intermediate_size = hparams.ssm_d_inner; - const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); - - // attention parameters - const uint32_t qk_dim = hparams.n_embd_head_k; - const uint32_t v_dim = hparams.n_embd_head_v; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - bool is_mamba_layer = hparams.is_recurrent(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (is_mamba_layer) { - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); - - layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); - - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); - - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); - - layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); - layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); - layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); - } else { - const int64_t num_attention_heads = hparams.n_head(i); - const int64_t q_num_heads = num_attention_heads; - const int64_t num_key_value_heads = hparams.n_head_kv(i); - const int64_t k_num_heads = num_key_value_heads; - const int64_t v_num_heads = num_key_value_heads; - const int64_t q_proj_dim = q_num_heads * qk_dim; - const int64_t k_proj_dim = k_num_heads * qk_dim; - const int64_t v_proj_dim = v_num_heads * v_dim; - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); - } - - // All layers have post-attention norm, FFN norm, and FFN tensors - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); - } - } break; - case LLM_ARCH_PLAMO3: - { - const int64_t head_dim_q = hparams.n_embd_head_k; - const int64_t head_dim_v = hparams.n_embd_head_v; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - const int64_t num_attention_heads = hparams.n_head(i); - const int64_t num_key_value_heads = hparams.n_head_kv(i); - const int64_t q_proj_dim = num_attention_heads * head_dim_q; - const int64_t k_proj_dim = num_key_value_heads * head_dim_q; - const int64_t v_proj_dim = num_key_value_heads * head_dim_v; - const int64_t n_ff_cur = hparams.n_ff(i); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), - {n_embd,q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur * 2}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); - } - } break; - case LLM_ARCH_GPT2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_CODESHELL: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if tok embd is NULL, init from output - if (tok_embd == NULL) { - tok_embd = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_ORION: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_INTERNLM2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_GEMMA: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA3: - case LLM_ARCH_GEMMA_EMBEDDING: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // Dense linear weights - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); - dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); - - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_GEMMA3N: - { - const int64_t n_altup = hparams.n_altup; - const int64_t laurel_rank = hparams.laurel_rank; - const int64_t n_embd_altup = hparams.n_embd_altup; - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - tok_embd_per_layer = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); - - altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); - per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight"), {n_embd, n_embd_altup * n_layer}, 0); - per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight"), {n_embd_altup}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - - // altup & laurel - layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); - layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); - layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); - layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); - layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); - layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0); - layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); - layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); - layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); - layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); - layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_STARCODER2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // optional bias tensors - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0); - } - } break; - case LLM_ARCH_MAMBA: - { - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; - - // only an expansion factor of 2 is supported for now - if (2 * n_embd != d_inner) { - throw std::runtime_error("only an expansion factor of 2 is supported for now"); - } - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); - - layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); - - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } - } break; - case LLM_ARCH_MAMBA2: - { - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_head = hparams.ssm_dt_rank; - const int64_t n_group = hparams.ssm_n_group; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; - - // only an expansion factor of 2 is supported for now - GGML_ASSERT(2 * n_embd == d_inner); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); - - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); - - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } - } break; - case LLM_ARCH_JAMBA: - { - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t dt_rank = hparams.ssm_dt_rank; - - // only an expansion factor of 2 is supported for now - GGML_ASSERT(2 * n_embd == d_inner); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - const int64_t n_head_kv = hparams.n_head_kv(i); - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); - - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (n_head_kv == 0) { - // Mamba layer - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); - - layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); - - layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}, 0); - - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); - - layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}, 0); - layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } else { - // Attention layers - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - } - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - - if (layer.ffn_gate_inp) { - // MoE - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } else { - // FFN (no MoE) - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_GRANITE_HYBRID: - { - // mamba2 Mixer SSM params - // NOTE: int64_t for tensor dimensions - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_ssm_head = hparams.ssm_dt_rank; - const int64_t n_group = hparams.ssm_n_group; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; - - // only an expansion factor of 2 is supported for now - GGML_ASSERT(2 * n_embd == d_inner); - - // embeddings - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.is_recurrent(i)) { - // ssm layers - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); - - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); - - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } else { - // attention layers (with optional bias) - const int64_t n_head_i = hparams.n_head(i); - const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - } - - // feed forward (w/ optional biases) - if (n_expert > 0) { - // MoE FFN - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - // For Granite MoE Shared - if (hparams.n_ff_shexp > 0) { - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); - } - } else { - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } - } break; - case LLM_ARCH_XVERSE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_COMMAND_R: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // init output from the input tok embed - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (n_layer >= 64){ - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); - } - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_COHERE2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - // init output from the input tok embed - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, - TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); - } - } - break; - case LLM_ARCH_OLMO: // adapted from LLM_ARCH_LLAMA with norm params removed - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_OLMO2: - { - const int64_t n_embd_head = n_embd / n_head; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_SEED_OSS: - { - const uint32_t head_dim = hparams.n_embd_head_k; - const int64_t n_qo_dim = n_head * head_dim; - const int64_t n_kv_dim = n_head_kv * head_dim; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_qo_dim}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_kv_dim}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_kv_dim}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); - - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_qo_dim}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_kv_dim}, TENSOR_NOT_REQUIRED); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - } - } break; - - case LLM_ARCH_OLMOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } - } break; - case LLM_ARCH_OPENELM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // init output from the input tok embed - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - const int64_t n_head = hparams.n_head(i); - const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head; - const int64_t n_ff = hparams.n_ff(i); - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_GPTNEOX: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_ARCTIC: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - } - } break; - case LLM_ARCH_DEEPSEEK: - { - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // try to load output.weight, if not found, use token_embd (tied embeddings) - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_DEEPSEEK2: - { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); - - const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); - - // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k_mla = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; - const int64_t n_embd_head_v_mla = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; - - const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; - - const int64_t q_lora_rank = hparams.n_lora_q; - const int64_t kv_lora_rank = hparams.n_lora_kv; - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // try to load output.weight, if not found, use token_embd (tied embeddings) - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - if (!output) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - if (!is_lite) { - layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); - } - - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - - if (!is_lite) { - layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); - layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); - } else { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); - } - - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); - - // note: only old legacy GGUF files will have the unsplit wkv_b tensor in - if (is_mla) { - layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); - layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); - } else { - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); - } - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_PLM: - { - const int64_t n_embd_head_qk_rope = hparams.n_rot; - const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; - const int64_t kv_lora_rank = hparams.n_lora_kv; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); - layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); - layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_BITNET: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wq_scale = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wk_scale = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv_scale = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.wo_scale = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_scale = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_scale = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_scale = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_T5: - { - const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0); - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - // n_layer: number of encoder_layers - // dec_n_layer: number of decoder_layers - const int dec_n_layer = hparams.dec_n_layer; - if (dec_n_layer > n_layer) { - layers.resize(dec_n_layer); - } - - // load encoder layers - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); - - layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - // load decoder layers - for (int i = 0; i < dec_n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); - - layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0); - // this tensor seems to be unused in HF transformers implementation - layer.attn_rel_b_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); - - layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_T5ENCODER: - { - const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); - - layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); - - layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_JAIS: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); - } - } break; - case LLM_ARCH_CHATGLM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - } - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - } - } break; - case LLM_ARCH_GLM4: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - } - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); - - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_GLM4_MOE: - { - const int64_t n_expert = hparams.n_expert; - const int64_t n_expert_used = hparams.n_expert_used; - const int64_t n_expert_shared = hparams.n_expert_shared; - - GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); - GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - // Load ALL tensors including NextN layer to satisfy total tensor count - // but only PROCESS up to last layer (skipping final NextN layer) in forward pass - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); - - // GLM-style attention with bias terms - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED | flags); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED | flags); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED | flags); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); - - // K/Q norm tensors (optional for GLM-4.5 355B variant) - layer.attn_q_norm = create_tensor( - tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); - layer.attn_k_norm = create_tensor( - tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags); - - // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead - // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE - const bool use_moe = (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead); - - if (use_moe) { - // MoE layers - layer.ffn_gate_inp = - create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags); - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - layer.ffn_gate_exps = create_tensor( - tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); - layer.ffn_down_exps = create_tensor( - tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags); - layer.ffn_up_exps = create_tensor( - tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); - - // Shared expert - if (n_expert_shared > 0) { - const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; - layer.ffn_gate_shexp = create_tensor( - tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); - layer.ffn_down_shexp = create_tensor( - tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags); - layer.ffn_up_shexp = create_tensor( - tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); - } - } else { - // Dense layers (first k layers) - GLM uses separate gate/up projections - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags); - } - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - - // Optional tensors - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); - } - } - } - break; - case LLM_ARCH_NEMOTRON: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - // optional MLP bias - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_NEMOTRON_H: - case LLM_ARCH_NEMOTRON_H_MOE: - { - // mamba2 Mixer SSM params - // NOTE: int64_t for tensor dimensions - const int64_t d_conv = hparams.ssm_d_conv; - const int64_t d_inner = hparams.ssm_d_inner; - const int64_t d_state = hparams.ssm_d_state; - const int64_t n_ssm_head = hparams.ssm_dt_rank; - const int64_t n_group = hparams.ssm_n_group; - const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; - - // embeddings - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - { - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed, duplicated to allow offloading - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // all blocks use the attn norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.is_recurrent(i)) { - // ssm layers - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); - - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); - - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); - - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); - - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); - - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); - } else if (hparams.n_ff(i) == 0) { - // attention layers (with optional bias) - const int64_t n_head_i = hparams.n_head(i); - const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); - const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - } else { - if (n_expert != 0) { - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_shexp = hparams.n_ff_shexp; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); - - // MoE branch - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - - } else { - // mlp layers - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); - } - } - } - } break; - case LLM_ARCH_EXAONE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_EXAONE4: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - } - } break; - case LLM_ARCH_RWKV6: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int time_mix_extra_dim = hparams.time_mix_extra_dim; - const int time_decay_extra_dim = hparams.time_decay_extra_dim; - const int head_size = hparams.wkv_head_size; - const int attn_hidden_size = n_embd; - const int ffn_size = hparams.n_ff_arr[0]; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); - - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); - - layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); - layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, TENSOR_NOT_REQUIRED); - GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL)); - - layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0); - layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); - layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); - layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); - - layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); - layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); - layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0); - - layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); - layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); - layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0); - } - - } break; - case LLM_ARCH_RWKV6QWEN2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int time_mix_extra_dim = hparams.time_mix_extra_dim; - const int time_decay_extra_dim = hparams.time_decay_extra_dim; - const int head_size = hparams.wkv_head_size; - const int attn_hidden_size = n_embd; - const int n_head_kv = hparams.n_head_kv(); - int attn_key_value_size; - if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) { - attn_key_value_size = attn_hidden_size; - } else { - attn_key_value_size = n_head_kv * head_size; - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); - - layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); - - layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, TENSOR_NOT_REQUIRED); - layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); - layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); - layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); - // optional bias tensors - layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); - layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); - layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, TENSOR_NOT_REQUIRED); - - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_RWKV7: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // Block 0, LN0 - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int n_lora_decay = hparams.n_lora_decay; - const int n_lora_iclr = hparams.n_lora_iclr; - const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; - const int n_lora_gate = hparams.n_lora_gate; - const int attn_hidden_size = n_embd; - const int ffn_size = hparams.n_ff_arr[0]; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - - layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); - layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); - - layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); - - layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); - layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); - - if (i == 0) { - // actually not used - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); - } else { - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); - } - - layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0); - layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0); - - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); - - layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); - - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - - layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); - layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); - - layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); - layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); - } - - } break; - case LLM_ARCH_ARWKV7: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - const int n_lora_decay = hparams.n_lora_decay; - const int n_lora_iclr = hparams.n_lora_iclr; - const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; - const int n_lora_gate = hparams.n_lora_gate; - const int attn_hidden_size = n_embd; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); - layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); - layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); - - layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); - layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); - - if (i == 0) { - // actually not used - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); - } else { - layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); - layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); - layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); - } - - layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED); - layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED); - - try { - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); - } catch(std::runtime_error & e) { - // ARWKV models may not have gate tensors - layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); - } - - layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); - layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); - - layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); - layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); - - layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - } break; - case LLM_ARCH_CHAMELEON: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_WAVTOKENIZER_DEC: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0); - - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0); - conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); - - // posnet - { - const int64_t n_embd = hparams.posnet.n_embd; - - for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) { - auto & layer = layers[i].posnet; - - // posnet: - // - // - resnet - // - resnet - // - attn - // - resnet - // - resnet - // - norm - // - switch (i) { - case 0: - case 1: - case 3: - case 4: - { - layer.norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0); - layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", i), {1, n_embd}, 0); - - layer.conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0); - layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", i), {1, n_embd}, 0); - - layer.norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0); - layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", i), {1, n_embd}, 0); - - layer.conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0); - layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", i), {1, n_embd}, 0); - } break; - case 2: - { - layer.attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); - layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); - - layer.attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", i), {1, n_embd}, 0); - - layer.attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", i), {1, n_embd}, 0); - - layer.attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", i), {1, n_embd}, 0); - - layer.attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", i), {1, n_embd, n_embd}, 0); - layer.attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", i), {1, n_embd}, 0); - } break; - case 5: - { - layer.norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); - layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); - } break; - default: GGML_ABORT("unknown posnet layer"); - }; - } - } - - GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); - - tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {hparams.posnet.n_embd}, 0); - tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {hparams.posnet.n_embd}, 0); - - // convnext - { - const int64_t n_embd = hparams.convnext.n_embd; - - for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) { - auto & layer = layers[i].convnext; - - layer.dw = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "weight", i), {7, 1, n_embd}, 0); - layer.dw_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "bias", i), {1, n_embd}, 0); - - layer.norm = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "weight", i), {n_embd}, 0); - layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "bias", i), {n_embd}, 0); - - layer.pw1 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "weight", i), {n_embd, n_ff}, 0); - layer.pw1_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "bias", i), {n_ff}, 0); - - layer.pw2 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "weight", i), {n_ff, n_embd}, 0); - layer.pw2_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "bias", i), {n_embd}, 0); - - layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); - } - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); - } - - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); - } break; - case LLM_ARCH_BAILINGMOE: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } break; - case LLM_ARCH_BAILINGMOE2: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); - - for (int i = 0; i < n_layer; ++i) { - int flags = 0; - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { - // skip all tensors in the NextN layers - flags |= TENSOR_SKIP; - } - - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); - - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); - - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); - - if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers - const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); - } else { // Dense layers - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); - } - - // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers - if (hparams.nextn_predict_layers > 0 && static_cast<uint32_t>(i) >= n_layer - hparams.nextn_predict_layers) { - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags); - layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags); - } - } - } break; - case LLM_ARCH_DOTS1: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (i < (int) hparams.n_layer_dense_lead) { - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } else { - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - - if (n_expert == 0) { - throw std::runtime_error("n_expert must be > 0"); - } - if (n_expert_used == 0) { - throw std::runtime_error("n_expert_used must be > 0"); - } - - // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert branch - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); - } - } - } break; - case LLM_ARCH_ARCEE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_AFMOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp; - const int64_t n_expert_shared = hparams.n_expert_shared; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - // dual attention normalization - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - // attention projections - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // Q/K normalization - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - - // attention gating - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - - // dual ffn normalization - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); - - if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { - // MoE layers - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - - // grouped expert weights - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // shared expert - if (n_expert_shared > 0) { - const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); - } - } else { - // Dense layers - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_ERNIE4_5: - case LLM_ARCH_ERNIE4_5_MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (arch == LLM_ARCH_ERNIE4_5_MOE && static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers - int n_ff_exp = hparams.n_ff_exp; - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); - - // Shared expert (if present) - if (hparams.n_ff_shexp > 0) { - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); - } - } else { // Dense layers - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } - } break; - case LLM_ARCH_FALCON_H1: - { - // Common - const int64_t hidden_size = hparams.n_embd; // hidden_size - - // mamba2 Mixer SSM params - const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size - const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups - const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size - const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand - const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads - const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; - const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; - - // attn params - const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head - const int64_t attn_num_key_value_head = hparams.n_head_kv(0); - - // ffn params - const int64_t ffn_intermediate_size = hparams.n_ff(0); - - // embeddings - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0); - - // output - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED); - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - /*SSM LAYERS*/ - // ssm in - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0); - // ssm 1d conv - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0); - layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, TENSOR_NOT_REQUIRED); - // ssm_dt - layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0); - // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); - // ssm_norm - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, TENSOR_NOT_REQUIRED); - // out_proj - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); - - /*ATTENTION LAYERS*/ - // attention layers (with optional bias) - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, n_embd_head_k * attn_num_attention_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_k}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_v}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * n_embd_head_k}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * n_embd_head_v}, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); - - - // feed forward (w/ optional biases) - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, i), {hidden_size}, 0); - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0); - - layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); - layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); - layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_HUNYUAN_MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); - } - } break; - case LLM_ARCH_HUNYUAN_DENSE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - } - } break; - case LLM_ARCH_SMOLLM3: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_OPENAI_MOE: - { - const int64_t n_ff_exp = hparams.n_ff_exp; - - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); - - layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - // bias - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); - layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); - layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0); - layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); - } - } break; - case LLM_ARCH_LFM2: - case LLM_ARCH_LFM2MOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - const bool is_moe_layer = i >= static_cast<int>(hparams.n_layer_dense_lead); - - // ffn/moe is same for transformer and conv layers - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - if (is_moe_layer) { - GGML_ASSERT(n_expert && n_expert_used); - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - } else { // dense - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - - // for operator_norm - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - if (!hparams.is_recurrent(i)) { - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, hparams.n_embd_k_gqa(i)}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, hparams.n_embd_v_gqa(i)}, 0); - - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - } else { - layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); - layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); - layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); - } - } - - // for LFM2-ColBert-350M - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED); - } break; - case LLM_ARCH_SMALLTHINKER: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER"); - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp; - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); - } - } break; - case LLM_ARCH_GROVEMOE: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for GROVEMOE"); - GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for GROVEMOE"); - GGML_ASSERT(hparams.n_group_experts > 0 && "n_group_experts must be > 0 for GROVEMOE"); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - - // MoE branch - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - const int64_t n_ff_chexp = hparams.n_ff_chexp ? hparams.n_ff_chexp : n_embd_head_k; - const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; - - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); - - layer.ffn_gate_chexps = create_tensor(tn(LLM_TENSOR_FFN_GATE_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); - layer.ffn_down_chexps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_CHEXPS, "weight", i), {n_ff_chexp, n_embd, n_chunk_expert}, 0); - layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); - } - } break; - case LLM_ARCH_APERTUS: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - // optional bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); - - // Q and K layernorms for Apertus - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_MINIMAX_M2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k * n_head}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_k_gqa}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); - } - } break; - case LLM_ARCH_COGVLM: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.visexp_attn_wqkv = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); - layer.visexp_attn_wo = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - - layer.visexp_ffn_gate = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.visexp_ffn_down = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.visexp_ffn_up = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_PANGU_EMBED: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - // weight tensors - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - // bias tensors - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0); - layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { - layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } else { - layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - } - - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - case LLM_ARCH_QWEN3NEXT: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); - - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); - } - - const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; - - // Calculate dimensions from hyperparameters - const int64_t head_k_dim = hparams.ssm_d_state; - const int64_t head_v_dim = hparams.ssm_d_state; - const int64_t n_k_heads = hparams.ssm_n_group; - const int64_t n_v_heads = hparams.ssm_dt_rank; - const int64_t key_dim = head_k_dim * n_k_heads; - const int64_t value_dim = head_v_dim * n_v_heads; - const int64_t conv_dim = key_dim * 2 + value_dim; - - // Calculate projection sizes - const int64_t qkvz_dim = key_dim * 2 + value_dim * 2; - const int64_t ba_dim = n_v_heads * 2; - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); - - if (!hparams.is_recurrent(i)) { - // Attention layers - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); - - // Q/K normalization for attention layers - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); - } else { - // Linear attention (gated delta net) specific tensors - // Create tensors with calculated dimensions - // note: ssm_in is used by legacy GGUF - layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); - layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); - layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); - layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); - layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); - layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); - } - - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); - - // Shared experts - layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); - layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); - layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0); - layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); - } - } break; - case LLM_ARCH_MIMO2: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); - uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); - uint32_t n_head = hparams.n_head(i); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0); - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - - // non-MoE branch - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - - // MoE branch - int64_t n_ff_exp = hparams.n_ff_exp; - layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); - layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); - } - } break; - case LLM_ARCH_MAINCODER: - { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); - - // output - output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); - // if output is NULL, init from the input tok embed - if (output == NULL) { - output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); - } - - for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); - - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); - - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); - } - } break; - default: - throw std::runtime_error("unknown architecture"); - } - - if (n_moved_tensors > 0) { - LLAMA_LOG_DEBUG("%s: tensor '%s' (%s) (and %d others) cannot be used with preferred buffer type %s, using %s instead\n", - __func__, first_moved_tensor->name, ggml_type_name(first_moved_tensor->type), n_moved_tensors - 1, - ggml_backend_buft_name(first_moved_from_buft), ggml_backend_buft_name(first_moved_to_buft)); - } - } - - ml.done_getting_tensors(); - - ml.init_mappings(true, use_mlock ? &pimpl->mlock_mmaps : nullptr); - pimpl->mappings.reserve(ml.mappings.size()); - - // create the backend buffers - std::vector<std::pair<ggml_context *, llama_buf_map>> ctx_buf_maps; - ctx_buf_maps.reserve(ctx_map.size()); - - // Ensure we have enough capacity for the maximum backend buffer we will potentially create - const size_t n_max_backend_buffer = ctx_map.size() * ml.files.size(); - pimpl->ctxs_bufs.reserve(n_max_backend_buffer); - - for (auto & [buft, ctx_ptr] : ctx_map) { - ggml_context * ctx = ctx_ptr.get(); - - // skip contexts without tensors - if (ggml_get_first_tensor(ctx) == nullptr) { - continue; - } - - llama_buf_map buf_map; - buf_map.reserve(n_max_backend_buffer); - - // check if it is possible to use buffer_from_host_ptr with this buffer type - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (!dev) { - // FIXME: workaround for CPU backend buft having a NULL device - dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (!dev) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - } - ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); - bool buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr; - bool is_default_buft = buft == ggml_backend_dev_buffer_type(dev); - - std::vector<ggml_backend_buffer_ptr> bufs; - if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported && is_default_buft) { - GGML_ASSERT(!ml.no_alloc); - for (uint32_t idx = 0; idx < ml.files.size(); idx++) { - // only the mmap region containing the tensors in the model is mapped to the backend buffer - // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, - // then we could just use metal for all layers - // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size - void * addr = nullptr; - size_t first, last; // NOLINT - ml.get_mapping_range(&first, &last, &addr, idx, ctx); - if (first >= last) { - continue; - } - const size_t max_size = ggml_get_max_tensor_size(ctx); - ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size); - if (buf == nullptr) { - throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); - } - bufs.emplace_back(buf); - buf_map.emplace(idx, buf); - } - } else { - ggml_backend_buffer_t buf; - if (ml.no_alloc) { - buf = ggml_backend_buft_alloc_buffer(buft, /*size =*/ 0); // dummy buffer - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { - t->buffer = buf; // set dummy buffer for weights so that the backend scheduler won't try to allocate them - } - } else { - buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); // real buffer - } - if (buf == nullptr) { - throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft))); - } - if (use_mlock && ggml_backend_buffer_is_host(buf)) { - pimpl->mlock_bufs.emplace_back(new llama_mlock); - auto & mlock_buf = pimpl->mlock_bufs.back(); - mlock_buf->init (ggml_backend_buffer_get_base(buf)); - mlock_buf->grow_to(ggml_backend_buffer_get_size(buf)); - } - bufs.emplace_back(buf); - for (uint32_t idx = 0; idx < ml.files.size(); idx++) { - buf_map.emplace(idx, buf); - } - } - pimpl->ctxs_bufs.emplace_back(std::move(ctx_ptr), std::move(bufs)); - - for (auto & buf : buf_map) { - // indicate that this buffer contains weights - // this is used by ggml_backend_sched to improve op scheduling: ops that use a weight are preferably scheduled to the backend that contains the weight - ggml_backend_buffer_set_usage(buf.second, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - } - - ctx_buf_maps.emplace_back(ctx, buf_map); - } - - if (llama_supports_gpu_offload()) { - const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); - - int n_repeating = n_gpu; - if (n_repeating > 0) { - LLAMA_LOG_INFO("%s: offloading output layer to GPU\n", __func__); - n_repeating--; - } - LLAMA_LOG_INFO("%s: offloading %d repeating layers to GPU\n", __func__, n_repeating); - - const int max_backend_supported_layers = hparams.n_layer + 1; - const int max_offloadable_layers = hparams.n_layer + 1; - - LLAMA_LOG_INFO("%s: offloaded %d/%d layers to GPU\n", __func__, std::min(n_gpu_layers, max_offloadable_layers), max_backend_supported_layers); - } - - // print memory requirements per buffer type - for (auto & [_, bufs] : pimpl->ctxs_bufs) { - for (auto & buf: bufs) { - LLAMA_LOG_INFO("%s: %12s model buffer size = %8.2f MiB\n", - __func__, ggml_backend_buffer_name(buf.get()), ggml_backend_buffer_get_size(buf.get()) / 1024.0 / 1024.0); - } - } - - // populate tensors_by_name - for (auto & [ctx, _] : pimpl->ctxs_bufs) { - for (auto * cur = ggml_get_first_tensor(ctx.get()); cur != NULL; cur = ggml_get_next_tensor(ctx.get(), cur)) { - tensors_by_name.emplace_back(ggml_get_name(cur), cur); - } - } - - if (ml.no_alloc) { - return true; - } - - // load tensor data - for (auto & [ctx, buf_map] : ctx_buf_maps) { - if (!ml.load_all_data(ctx, buf_map, use_mlock ? &pimpl->mlock_mmaps : NULL, params.progress_callback, params.progress_callback_user_data)) { - return false; - } - } - - if (use_mmap_buffer) { - for (auto & mapping : ml.mappings) { - pimpl->mappings.emplace_back(std::move(mapping)); - } - } - - return true; -} - -std::string llama_model::arch_name() const { - return llm_arch_name(arch); -} - -std::string llama_model::type_name() const { - return llm_type_name(type); -} - -std::string llama_model::desc() const { - return pimpl->desc_str; -} - -size_t llama_model::size() const { - return pimpl->n_bytes; -} - -size_t llama_model::n_tensors() const { - return tensors_by_name.size(); -} - -size_t llama_model::n_devices() const { - return devices.size(); -} - -uint32_t llama_model::n_gpu_layers() const { - return params.n_gpu_layers >= 0 ? params.n_gpu_layers : hparams.n_layer + 1; -} - -llama_split_mode llama_model::split_mode() const { - return params.split_mode; -} - -std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const { - std::map<ggml_backend_buffer_type_t, size_t> ret; - for (const auto & [ctx, bufs] : pimpl->ctxs_bufs) { - if (hparams.no_alloc) { - GGML_ASSERT(bufs.size() == 1); - ggml_backend_buffer_t buf = bufs[0].get(); - GGML_ASSERT(ggml_backend_buffer_get_base(buf) == nullptr); - ggml_backend_buffer_type_t buft = ggml_backend_buffer_get_type(buf); - ret[buft] += ggml_backend_alloc_ctx_tensors_from_buft_size(ctx.get(), buft); - } else { - for (const auto & buf : bufs) { - // GGML_ASSERT(ggml_backend_buffer_get_base(buf.get()) != nullptr); // multi_buffer does not have a defined base - ret[ggml_backend_buffer_get_type(buf.get())] += ggml_backend_buffer_get_size(buf.get()); - } - } - } - return ret; -} - -uint64_t llama_model::n_elements() const { - return pimpl->n_elements; -} - -void llama_model::print_info() const { - const std::string rope_scaling_type = llama_rope_scaling_type_name(hparams.rope_scaling_type_train); - - auto print_f = [](const std::function<uint32_t(uint32_t)> & f, uint32_t n) { - bool is_var = false; - - std::vector<uint32_t> v; - for (uint32_t i = 0; i < n; ++i) { - v.push_back(f(i)); - if (v[i] != v[0]) { - is_var = true; - } - } - - std::stringstream ss; - - if (is_var) { - ss << "["; - for (uint32_t i = 0; i < n; ++i) { - ss << v[i]; - if (i < n - 1) { - ss << ", "; - } - } - ss << "]"; - } else { - ss << v[0]; - } - - return ss.str(); - }; - - // hparams - LLAMA_LOG_INFO("%s: arch = %s\n", __func__, arch_name().c_str()); - LLAMA_LOG_INFO("%s: vocab_only = %d\n", __func__, hparams.vocab_only); - LLAMA_LOG_INFO("%s: no_alloc = %d\n", __func__, hparams.no_alloc); - - if (!hparams.vocab_only) { - LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_embd_inp = %u\n", __func__, hparams.n_embd_inp()); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_head = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_head_kv = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_head_kv(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); - LLAMA_LOG_INFO("%s: n_swa = %u\n", __func__, hparams.n_swa); - LLAMA_LOG_INFO("%s: is_swa_any = %u\n", __func__, hparams.is_swa_any()); - LLAMA_LOG_INFO("%s: n_embd_head_k = %u\n", __func__, hparams.n_embd_head_k); - LLAMA_LOG_INFO("%s: n_embd_head_v = %u\n", __func__, hparams.n_embd_head_v); - LLAMA_LOG_INFO("%s: n_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_k_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_k_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_embd_v_gqa = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_embd_v_gqa(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); - LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); - LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); - LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); - LLAMA_LOG_INFO("%s: f_logit_scale = %.1e\n", __func__, hparams.f_logit_scale); - LLAMA_LOG_INFO("%s: f_attn_scale = %.1e\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff = %s\n", __func__, print_f([&](uint32_t il) { return hparams.n_ff(il); }, hparams.n_layer).c_str()); - LLAMA_LOG_INFO("%s: n_expert = %u\n", __func__, hparams.n_expert); - LLAMA_LOG_INFO("%s: n_expert_used = %u\n", __func__, hparams.n_expert_used); - LLAMA_LOG_INFO("%s: n_expert_groups = %d\n", __func__, hparams.n_expert_groups); - LLAMA_LOG_INFO("%s: n_group_used = %d\n", __func__, hparams.n_group_used); - LLAMA_LOG_INFO("%s: causal attn = %d\n", __func__, hparams.causal_attn); - LLAMA_LOG_INFO("%s: pooling type = %d\n", __func__, hparams.pooling_type); - LLAMA_LOG_INFO("%s: rope type = %d\n", __func__, hparams.rope_type); - LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); - LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); - LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); - LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); - } - LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); - LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul); - LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); - // MRoPE (Multi-axis Rotary Position Embedding) sections - if (const auto & s = hparams.rope_sections; s[0] || s[1] || s[2] || s[3]) { - LLAMA_LOG_INFO("%s: mrope sections = [%d, %d, %d, %d]\n", __func__, s[0], s[1], s[2], s[3]); - } - if (!classifier_labels.empty()) { - LLAMA_LOG_INFO("%s: n_cls_out = %u\n", __func__, hparams.n_cls_out); - - size_t i = 0; - for (auto label : classifier_labels) { - LLAMA_LOG_INFO("%s: cls_label[%2zu] = %s\n", __func__, i++, label.c_str()); - } - } - } - - if (arch == LLM_ARCH_MAMBA || - arch == LLM_ARCH_MAMBA2 || - arch == LLM_ARCH_JAMBA || - arch == LLM_ARCH_FALCON_H1 || - arch == LLM_ARCH_PLAMO2 || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_QWEN3NEXT || - arch == LLM_ARCH_NEMOTRON_H || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); - LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); - LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); - LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); - LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); - LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); - } - - LLAMA_LOG_INFO("%s: model type = %s\n", __func__, type_name().c_str()); - if (pimpl->n_elements >= 1e12) { - LLAMA_LOG_INFO("%s: model params = %.2f T\n", __func__, pimpl->n_elements*1e-12); - } else if (pimpl->n_elements >= 1e9) { - LLAMA_LOG_INFO("%s: model params = %.2f B\n", __func__, pimpl->n_elements*1e-9); - } else if (pimpl->n_elements >= 1e6) { - LLAMA_LOG_INFO("%s: model params = %.2f M\n", __func__, pimpl->n_elements*1e-6); - } else { - LLAMA_LOG_INFO("%s: model params = %.2f K\n", __func__, pimpl->n_elements*1e-3); - } - - // general kv - LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, name.c_str()); - - if (arch == LLM_ARCH_DEEPSEEK) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - } - - if (arch == LLM_ARCH_DEEPSEEK2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); - LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); - LLAMA_LOG_INFO("%s: n_embd_head_k_mla = %d\n", __func__, hparams.n_embd_head_k_mla); - LLAMA_LOG_INFO("%s: n_embd_head_v_mla = %d\n", __func__, hparams.n_embd_head_v_mla); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } - - if (arch == LLM_ARCH_QWEN2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } - - if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE || arch == LLM_ARCH_QWEN3VLMOE || arch == LLM_ARCH_RND1) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - } - - if (arch == LLM_ARCH_MINICPM || - arch == LLM_ARCH_GRANITE || - arch == LLM_ARCH_GRANITE_MOE || - arch == LLM_ARCH_GRANITE_HYBRID || - arch == LLM_ARCH_NEMOTRON_H_MOE) { - LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale); - LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale); - LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - } - - if (arch == LLM_ARCH_BAILINGMOE) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - } - - if (arch == LLM_ARCH_BAILINGMOE2) { - LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp); - LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared); - LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); - LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - LLAMA_LOG_INFO("%s: nextn_predict_layers = %d\n", __func__, hparams.nextn_predict_layers); - } - - if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); - } - - if (arch == LLM_ARCH_GROVEMOE) { - LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); - LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp); - LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts); - LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale); - } - - vocab.print_info(); -} - -ggml_backend_dev_t llama_model::dev_layer(int il) const { - return pimpl->dev_layer.at(il).dev; -} - -ggml_backend_dev_t llama_model::dev_output() const { - return pimpl->dev_output.dev; -} - -template<typename F> -static bool buft_supported(ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev, F & fn) { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead()*8, - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, - }; - - ggml_context_ptr ctx { ggml_init(params) }; - if (!ctx) { - throw std::runtime_error(format("failed to create ggml context")); - } - - ggml_backend_buffer_ptr buf { ggml_backend_buft_alloc_buffer(buft, 0) }; - ggml_tensor * op_tensor = fn(ctx.get()); - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (op_tensor->src[i] != nullptr) { - assert(op_tensor->src[i]->buffer == nullptr); - op_tensor->src[i]->buffer = buf.get(); - } - } - - bool op_supported = ggml_backend_dev_supports_op(dev, op_tensor); - - return op_supported; -} - -template<typename F> -static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) { - for (const auto & cur : buft_list) { - ggml_backend_dev_t cur_dev = cur.first; - ggml_backend_buffer_type_t cur_buft = cur.second; - if (buft_supported(cur_buft, cur_dev, fn)) { - return cur_buft; - } - } - - throw std::runtime_error(format("no suitable buffer type found")); -} - -ggml_backend_buffer_type_t llama_model::select_buft(int il) const { - return ::select_buft( - *pimpl->dev_layer.at(il).buft_list, - [&](ggml_context * ctx) { - ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); - ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); - return ggml_add(ctx, cur, layer_dir); - }); -} - -bool llama_model::has_tensor_overrides() const { - return pimpl->has_tensor_overrides; -} - -const ggml_tensor * llama_model::get_tensor(const char * name) const { - auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(), - [name](const std::pair<std::string, ggml_tensor *> & it) { - return it.first == name; - }); - if (it == tensors_by_name.end()) { - return nullptr; - } - - return it->second; -} - -float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const { - return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; -} - -float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const { - return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; -} - -ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const { - const uint32_t n_ctx_seq = cparams.n_ctx_seq; - - // choose long/short freq factors based on the context size - if (layers[il].rope_freqs != nullptr) { - return layers[il].rope_freqs; - } - - if (n_ctx_seq > hparams.n_ctx_orig_yarn) { - return layers[il].rope_long; - } - - return layers[il].rope_short; -} - -llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const { - llama_memory_i * res; - - switch (arch) { - // Models that need specific instantiation should be handled in the - // switch statement - case LLM_ARCH_BERT: - case LLM_ARCH_JINA_BERT_V2: - case LLM_ARCH_JINA_BERT_V3: - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - case LLM_ARCH_NEO_BERT: - case LLM_ARCH_WAVTOKENIZER_DEC: - case LLM_ARCH_MODERN_BERT: - case LLM_ARCH_GEMMA_EMBEDDING: - case LLM_ARCH_DREAM: - case LLM_ARCH_LLADA: - case LLM_ARCH_LLADA_MOE: - case LLM_ARCH_RND1: - { - res = nullptr; - } break; - // Models that need standard caching should rely on recurrent/hybrid - // checks - default: - { - if (llm_arch_is_recurrent(arch)) { - res = new llama_memory_recurrent( - *this, - GGML_TYPE_F32, - GGML_TYPE_F32, - cparams.offload_kqv, - std::max((uint32_t) 1, cparams.n_seq_max), - cparams.n_seq_max, - nullptr); - } else if (llm_arch_is_hybrid(arch)) { - - // The main difference between hybrid architectures is the - // layer filters, so pick the right one here - llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; - llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; - if (arch == LLM_ARCH_FALCON_H1) { - filter_attn = [&](int32_t) { return true; }; - filter_recr = [&](int32_t) { return true; }; - } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { - filter_attn = [&](int32_t il) { - return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0; - }; - filter_recr = [&](int32_t il) { - return hparams.is_recurrent(il) && hparams.n_ff(il) == 0; - }; - } - - res = new llama_memory_hybrid( - /* model */ *this, - /* attn_type_k */ params.type_k, - /* attn_type_v */ params.type_v, - /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ 1, - /* attn_n_swa */ hparams.n_swa, - /* attn_swa_type */ hparams.swa_type, - /* recurrent_type_k */ GGML_TYPE_F32, - /* recurrent_type_v */ GGML_TYPE_F32, - /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), - /* n_seq_max */ cparams.n_seq_max, - /* offload */ cparams.offload_kqv, - /* unified */ cparams.kv_unified, - /* filter_attn */ std::move(filter_attn), - /* filter_recr */ std::move(filter_recr)); - } else { - llama_memory_i::layer_reuse_cb reuse = nullptr; - - if (arch == LLM_ARCH_GEMMA3N) { - reuse = [&](int32_t il) { - if (il >= (int32_t) hparams.n_layer_kv_from_start) { - return (int32_t) hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); - } - - return -1; - }; - } - - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - GGML_ASSERT(hparams.is_swa_any()); - - res = new llama_kv_cache_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - cparams.n_ubatch, - 1, - nullptr, - reuse); - } else { - GGML_ASSERT(!hparams.is_swa_any()); - - res = new llama_kv_cache( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - cparams.kv_unified, - cparams.n_ctx_seq, - cparams.n_seq_max, - 1, - hparams.n_swa, - hparams.swa_type, - nullptr, - nullptr); - } - } - } - } - - return res; -} - -ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { - std::unique_ptr<llm_graph_context> llm; - - switch (arch) { - case LLM_ARCH_LLAMA: - { - llm = std::make_unique<llm_build_llama<false>>(*this, params); - } break; - case LLM_ARCH_LLAMA4: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique<llm_build_llama<false>>(*this, params); - } else { - llm = std::make_unique<llm_build_llama_iswa>(*this, params); - } - } break; - case LLM_ARCH_LLAMA_EMBED: - { - llm = std::make_unique<llm_build_llama<true>>(*this, params); - } break; - case LLM_ARCH_MAINCODER: - { - llm = std::make_unique<llm_build_maincoder>(*this, params); - } break; - case LLM_ARCH_DECI: - { - llm = std::make_unique<llm_build_deci>(*this, params); - } break; - case LLM_ARCH_BAICHUAN: - { - llm = std::make_unique<llm_build_baichuan>(*this, params); - } break; - case LLM_ARCH_FALCON: - { - llm = std::make_unique<llm_build_falcon>(*this, params); - } break; - case LLM_ARCH_GROK: - { - llm = std::make_unique<llm_build_grok>(*this, params); - } break; - case LLM_ARCH_STARCODER: - { - llm = std::make_unique<llm_build_starcoder>(*this, params); - } break; - case LLM_ARCH_REFACT: - { - llm = std::make_unique<llm_build_refact>(*this, params); - } break; - case LLM_ARCH_BERT: - case LLM_ARCH_JINA_BERT_V2: - case LLM_ARCH_JINA_BERT_V3: - case LLM_ARCH_NOMIC_BERT: - case LLM_ARCH_NOMIC_BERT_MOE: - { - llm = std::make_unique<llm_build_bert>(*this, params); - } break; - case LLM_ARCH_MODERN_BERT: - { - llm = std::make_unique<llm_build_modern_bert>(*this, params); - } break; - case LLM_ARCH_NEO_BERT: - { - llm = std::make_unique<llm_build_neo_bert>(*this, params); - } break; - case LLM_ARCH_BLOOM: - { - llm = std::make_unique<llm_build_bloom>(*this, params); - } break; - case LLM_ARCH_MPT: - { - llm = std::make_unique<llm_build_mpt>(*this, params); - } break; - case LLM_ARCH_STABLELM: - { - llm = std::make_unique<llm_build_stablelm>(*this, params); - } break; - case LLM_ARCH_QWEN: - { - llm = std::make_unique<llm_build_qwen>(*this, params); - } break; - case LLM_ARCH_QWEN2: - { - llm = std::make_unique<llm_build_qwen2>(*this, params); - } break; - case LLM_ARCH_DREAM: - { - llm = std::make_unique<llm_build_dream>(*this, params); - } - break; - case LLM_ARCH_LLADA: - { - llm = std::make_unique<llm_build_llada>(*this, params); - } - break; - case LLM_ARCH_LLADA_MOE: - { - llm = std::make_unique<llm_build_llada_moe>(*this, params); - } - break; - case LLM_ARCH_RND1: - { - llm = std::make_unique<llm_build_rnd1>(*this, params); - } - break; - case LLM_ARCH_QWEN2VL: - { - llm = std::make_unique<llm_build_qwen2vl>(*this, params); - } break; - case LLM_ARCH_QWEN2MOE: - { - llm = std::make_unique<llm_build_qwen2moe>(*this, params); - } break; - case LLM_ARCH_QWEN3: - { - llm = std::make_unique<llm_build_qwen3>(*this, params); - } break; - case LLM_ARCH_QWEN3MOE: - { - llm = std::make_unique<llm_build_qwen3moe>(*this, params); - } break; - case LLM_ARCH_QWEN3VL: - { - llm = std::make_unique<llm_build_qwen3vl>(*this, params); - } break; - case LLM_ARCH_QWEN3VLMOE: - { - llm = std::make_unique<llm_build_qwen3vlmoe>(*this, params); - } break; - case LLM_ARCH_PHI2: - { - llm = std::make_unique<llm_build_phi2>(*this, params); - } break; - case LLM_ARCH_PHI3: - case LLM_ARCH_PHIMOE: - { - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique<llm_build_phi3<true>> (*this, params); - } else { - llm = std::make_unique<llm_build_phi3<false>>(*this, params); - } - } break; - case LLM_ARCH_PLAMO: - { - llm = std::make_unique<llm_build_plamo>(*this, params); - } break; - case LLM_ARCH_PLAMO2: - { - llm = std::make_unique<llm_build_plamo2>(*this, params); - } break; - case LLM_ARCH_PLAMO3: - { - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique<llm_build_plamo3<true>> (*this, params); - } else { - llm = std::make_unique<llm_build_plamo3<false>>(*this, params); - } - } break; - case LLM_ARCH_GPT2: - { - llm = std::make_unique<llm_build_gpt2>(*this, params); - } break; - case LLM_ARCH_CODESHELL: - { - llm = std::make_unique<llm_build_codeshell>(*this, params); - } break; - case LLM_ARCH_ORION: - { - llm = std::make_unique<llm_build_orion>(*this, params); - } break; - case LLM_ARCH_INTERNLM2: - { - llm = std::make_unique<llm_build_internlm2>(*this, params); - } break; - case LLM_ARCH_MINICPM3: - { - llm = std::make_unique<llm_build_minicpm3>(*this, params); - } break; - case LLM_ARCH_GEMMA: - { - llm = std::make_unique<llm_build_gemma>(*this, params); - } break; - case LLM_ARCH_GEMMA2: - { - llm = std::make_unique<llm_build_gemma2_iswa>(*this, params); - } break; - case LLM_ARCH_GEMMA3: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique<llm_build_gemma3<true>>(*this, params); - } else { - llm = std::make_unique<llm_build_gemma3<false>>(*this, params); - } - } break; - case LLM_ARCH_GEMMA3N: - { - llm = std::make_unique<llm_build_gemma3n_iswa>(*this, params); - } break; - case LLM_ARCH_GEMMA_EMBEDDING: - { - llm = std::make_unique<llm_build_gemma_embedding>(*this, params); - } break; - case LLM_ARCH_STARCODER2: - { - llm = std::make_unique<llm_build_starcoder2>(*this, params); - } break; - case LLM_ARCH_MAMBA: - case LLM_ARCH_MAMBA2: - { - llm = std::make_unique<llm_build_mamba>(*this, params); - } break; - case LLM_ARCH_JAMBA: - { - llm = std::make_unique<llm_build_jamba>(*this, params); - } break; - case LLM_ARCH_XVERSE: - { - llm = std::make_unique<llm_build_xverse>(*this, params); - } break; - case LLM_ARCH_COMMAND_R: - { - llm = std::make_unique<llm_build_command_r>(*this, params); - } break; - case LLM_ARCH_COHERE2: - { - llm = std::make_unique<llm_build_cohere2_iswa>(*this, params); - } break; - case LLM_ARCH_DBRX: - { - llm = std::make_unique<llm_build_dbrx>(*this, params); - } break; - case LLM_ARCH_OLMO: - { - llm = std::make_unique<llm_build_olmo>(*this, params); - } break; - case LLM_ARCH_OLMO2: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique<llm_build_olmo2<true>>(*this, params); - } else { - llm = std::make_unique<llm_build_olmo2<false>>(*this, params); - } - } break; - case LLM_ARCH_OLMOE: - { - llm = std::make_unique<llm_build_olmoe>(*this, params); - } break; - case LLM_ARCH_OPENELM: - { - llm = std::make_unique<llm_build_openelm>(*this, params); - } break; - case LLM_ARCH_GPTNEOX: - { - llm = std::make_unique<llm_build_gptneox>(*this, params); - } break; - case LLM_ARCH_ARCTIC: - { - llm = std::make_unique<llm_build_arctic>(*this, params); - } break; - case LLM_ARCH_DEEPSEEK: - { - llm = std::make_unique<llm_build_deepseek>(*this, params); - } break; - case LLM_ARCH_DEEPSEEK2: - { - llm = std::make_unique<llm_build_deepseek2>(*this, params); - } break; - case LLM_ARCH_CHATGLM: - { - llm = std::make_unique<llm_build_chatglm>(*this, params); - } break; - case LLM_ARCH_GLM4: - { - llm = std::make_unique<llm_build_glm4>(*this, params); - } break; - case LLM_ARCH_GLM4_MOE: - { - llm = std::make_unique<llm_build_glm4_moe>(*this, params); - } break; - case LLM_ARCH_BITNET: - { - llm = std::make_unique<llm_build_bitnet>(*this, params); - } break; - case LLM_ARCH_T5: - { - switch (params.gtype) { - case LLM_GRAPH_TYPE_ENCODER: - llm = std::make_unique<llm_build_t5_enc>(*this, params); - break; - case LLM_GRAPH_TYPE_DEFAULT: - case LLM_GRAPH_TYPE_DECODER: - llm = std::make_unique<llm_build_t5_dec>(*this, params); - break; - default: - GGML_ABORT("invalid graph type"); - }; - } break; - case LLM_ARCH_T5ENCODER: - { - llm = std::make_unique<llm_build_t5_enc>(*this, params); - } - break; - case LLM_ARCH_JAIS: - { - llm = std::make_unique<llm_build_jais>(*this, params); - } break; - case LLM_ARCH_NEMOTRON: - { - llm = std::make_unique<llm_build_nemotron>(*this, params); - } break; - case LLM_ARCH_NEMOTRON_H: - case LLM_ARCH_NEMOTRON_H_MOE: - { - llm = std::make_unique<llm_build_nemotron_h>(*this, params); - } break; - case LLM_ARCH_EXAONE: - { - llm = std::make_unique<llm_build_exaone>(*this, params); - } break; - case LLM_ARCH_EXAONE4: - { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique<llm_build_exaone4<true>>(*this, params); - } else { - llm = std::make_unique<llm_build_exaone4<false>>(*this, params); - } - } break; - case LLM_ARCH_RWKV6: - { - llm = std::make_unique<llm_build_rwkv6>(*this, params); - } break; - case LLM_ARCH_RWKV6QWEN2: - { - llm = std::make_unique<llm_build_rwkv6qwen2>(*this, params); - } break; - case LLM_ARCH_RWKV7: - { - llm = std::make_unique<llm_build_rwkv7>(*this, params); - } break; - case LLM_ARCH_ARWKV7: - { - llm = std::make_unique<llm_build_arwkv7>(*this, params); - } break; - case LLM_ARCH_GRANITE: - case LLM_ARCH_GRANITE_MOE: - case LLM_ARCH_MINICPM: - { - llm = std::make_unique<llm_build_granite>(*this, params); - } break; - case LLM_ARCH_GRANITE_HYBRID: - { - llm = std::make_unique<llm_build_granite_hybrid>(*this, params); - } break; - case LLM_ARCH_CHAMELEON: - { - llm = std::make_unique<llm_build_chameleon>(*this, params); - } break; +template<typename F> +static ggml_backend_buffer_type_t select_buft(const buft_list_t & buft_list, const F & fn) { + for (const auto & cur : buft_list) { + ggml_backend_dev_t cur_dev = cur.first; + ggml_backend_buffer_type_t cur_buft = cur.second; + if (buft_supported(cur_buft, cur_dev, fn)) { + return cur_buft; + } + } + + throw std::runtime_error(format("no suitable buffer type found")); +} + +ggml_backend_buffer_type_t llama_model::select_buft(int il) const { + return ::select_buft( + *pimpl->dev_layer.at(il).buft_list, + [&](ggml_context * ctx) { + ggml_tensor * cur = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + ggml_tensor * layer_dir = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); + return ggml_add(ctx, cur, layer_dir); + }); +} + +bool llama_model::has_tensor_overrides() const { + return pimpl->has_tensor_overrides; +} + +const ggml_tensor * llama_model::get_tensor(const char * name) const { + auto it = std::find_if(tensors_by_name.begin(), tensors_by_name.end(), + [name](const std::pair<std::string, ggml_tensor *> & it) { + return it.first == name; + }); + if (it == tensors_by_name.end()) { + return nullptr; + } + + return it->second; +} + +float llama_model::get_rope_freq_base (const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; +} + +float llama_model::get_rope_freq_scale(const llama_cparams & cparams, int il) const { + return hparams.is_swa(il) ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; +} + +ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int il) const { + const uint32_t n_ctx_seq = cparams.n_ctx_seq; + + // choose long/short freq factors based on the context size + if (layers[il].rope_freqs != nullptr) { + return layers[il].rope_freqs; + } + + if (n_ctx_seq > hparams.n_ctx_orig_yarn) { + return layers[il].rope_long; + } + + return layers[il].rope_short; +} + +llama_memory_i * llama_model::create_memory(const llama_memory_params & params, const llama_cparams & cparams) const { + llama_memory_i * res; + + switch (arch) { + // Models that need specific instantiation should be handled in the + // switch statement + case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_JINA_BERT_V3: + case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_NEO_BERT: + case LLM_ARCH_EUROBERT: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_MODERN_BERT: + case LLM_ARCH_GEMMA_EMBEDDING: + case LLM_ARCH_DREAM: + case LLM_ARCH_LLADA: + case LLM_ARCH_LLADA_MOE: + case LLM_ARCH_RND1: { - llm = std::make_unique<llm_build_wavtokenizer_dec>(*this, params); - } break; - case LLM_ARCH_PLM: - { - llm = std::make_unique<llm_build_plm>(*this, params); - } break; - case LLM_ARCH_BAILINGMOE: - { - llm = std::make_unique<llm_build_bailingmoe>(*this, params); - } break; - case LLM_ARCH_BAILINGMOE2: - { - llm = std::make_unique<llm_build_bailingmoe2>(*this, params); - } break; - case LLM_ARCH_SEED_OSS: - { - llm = std::make_unique<llm_build_seed_oss>(*this, params); - } break; - case LLM_ARCH_DOTS1: - { - llm = std::make_unique<llm_build_dots1>(*this, params); - } break; - case LLM_ARCH_ARCEE: - { - llm = std::make_unique<llm_build_arcee>(*this, params); - } break; - case LLM_ARCH_AFMOE: - { - llm = std::make_unique<llm_build_afmoe>(*this, params); - } break; - case LLM_ARCH_ERNIE4_5: - { - llm = std::make_unique<llm_build_ernie4_5>(*this, params); - } break; - case LLM_ARCH_ERNIE4_5_MOE: - { - llm = std::make_unique<llm_build_ernie4_5_moe>(*this, params); - } break; - case LLM_ARCH_HUNYUAN_MOE: - { - llm = std::make_unique<llm_build_hunyuan_moe>(*this, params); - } break; - case LLM_ARCH_HUNYUAN_DENSE: - { - llm = std::make_unique<llm_build_hunyuan_dense>(*this, params); - } break; - case LLM_ARCH_SMOLLM3: - { - llm = std::make_unique<llm_build_smollm3>(*this, params); - } break; - case LLM_ARCH_OPENAI_MOE: - { - llm = std::make_unique<llm_build_openai_moe_iswa>(*this, params); - } break; - case LLM_ARCH_FALCON_H1: - { - llm = std::make_unique<llm_build_falcon_h1>(*this, params); + res = nullptr; } break; - case LLM_ARCH_LFM2: - case LLM_ARCH_LFM2MOE: - { - llm = std::make_unique<llm_build_lfm2>(*this, params); + case LLM_ARCH_DEEPSEEK32: + { + res = new llama_kv_cache_dsa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + 1, + hparams.n_swa, + hparams.swa_type, + nullptr, + nullptr); } break; - case LLM_ARCH_SMALLTHINKER: + // Models that need standard caching should rely on recurrent/hybrid + // checks + default: { - if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { - llm = std::make_unique<llm_build_smallthinker<true>> (*this, params); + // The MTP head is dense-attention only on hybrid Qwen3.5/3.6, so use a plain + // attention KV cache for the MTP context instead of the hybrid wrapper. + const bool mtp_on_hybrid_qwen35 = + params.ctx_type == LLAMA_CONTEXT_TYPE_MTP && + (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE); + + if (llm_arch_is_recurrent(arch)) { + res = new llama_memory_recurrent( + *this, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max, + cparams.n_rs_seq, + nullptr); + } else if (llm_arch_is_hybrid(arch) && !mtp_on_hybrid_qwen35) { + // The main difference between hybrid architectures is the + // layer filters, so pick the right one here + llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; + llama_memory_hybrid::layer_filter_cb filter_recr = nullptr; + if (arch == LLM_ARCH_FALCON_H1) { + filter_attn = [&](uint32_t) { return true; }; + filter_recr = [&](uint32_t) { return true; }; + } else if (arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { + filter_attn = [&](uint32_t il) { + return !hparams.is_recr(il) && hparams.n_ff(il) == 0; + }; + filter_recr = [&](uint32_t il) { + return hparams.is_recr(il) && hparams.n_ff(il) == 0; + }; + } else if (arch == LLM_ARCH_QWEN35 || arch == LLM_ARCH_QWEN35MOE) { + filter_attn = [&](uint32_t il) { + return il < hparams.n_layer() && !hparams.is_recr(il); + }; + filter_recr = [&](uint32_t il) { + return il < hparams.n_layer() && hparams.is_recr(il); + }; + } + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + // Use hybrid-iswa for hybrid models with SWA + res = new llama_memory_hybrid_iswa( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_swa_full */ params.swa_full, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_ubatch */ cparams.n_ubatch, + /* attn_n_pad */ 1, + /* recurrent_type_r */ GGML_TYPE_F32, + /* recurrent_type_s */ GGML_TYPE_F32, + /* recurrent_rs_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } else { + res = new llama_memory_hybrid( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx_seq, + /* attn_n_pad */ 1, + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* n_rs_seq */ cparams.n_rs_seq, + /* offload */ cparams.offload_kqv, + /* unified */ cparams.kv_unified, + /* filter_attn */ std::move(filter_attn), + /* filter_recr */ std::move(filter_recr)); + } } else { - llm = std::make_unique<llm_build_smallthinker<false>>(*this, params); + llama_kv_cache::layer_filter_cb filter = nullptr; + llama_memory_i::layer_reuse_cb reuse = nullptr; + llama_kv_cache::layer_share_cb share = nullptr; + + if (arch == LLM_ARCH_GEMMA3N || arch == LLM_ARCH_GEMMA4) { + reuse = [&](uint32_t il) { + GGML_ASSERT(hparams.n_layer_kv_from_start >= 2); + + if (il >= (uint32_t)hparams.n_layer_kv_from_start) { + return hparams.n_layer_kv_from_start - (hparams.is_swa(il) ? 2 : 1); + } + + return -1; + }; + } + + if (mtp_on_hybrid_qwen35) { + filter = [&](uint32_t il) { return il >= hparams.n_layer(); }; + } + + if (arch == LLM_ARCH_STEP35 && hparams.n_layer_nextn > 0) { + if (params.ctx_type == LLAMA_CONTEXT_TYPE_MTP) { + filter = [&](uint32_t il) { return il >= hparams.n_layer(); }; + } else { + filter = [&](uint32_t il) { return il < hparams.n_layer(); }; + } + } + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(hparams.is_swa_any()); + + if (arch == LLM_ARCH_GEMMA4_ASSISTANT) { + llama_memory_t mem_other = llama_get_memory(cparams.ctx_other); + + share = [&](int32_t il) { + const llama_model * model_other = llama_get_model(cparams.ctx_other); + + if (hparams.is_swa(il)) { + return llama_model_n_layer(model_other) - 2; + } + + return llama_model_n_layer(model_other) - 1; + }; + + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + mem_other, + filter, + reuse, + share); + } else { + res = new llama_kv_cache_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + cparams.n_ubatch, + 1, + nullptr, + filter, + reuse, + share); + } + } else { + GGML_ASSERT(!hparams.is_swa_any()); + + res = new llama_kv_cache( + *this, + hparams, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.kv_unified, + cparams.n_ctx_seq, + cparams.n_seq_max, + 1, + hparams.n_swa, + hparams.swa_type, + nullptr, + filter, + nullptr, + nullptr); + } } - } break; - case LLM_ARCH_GROVEMOE: - { - llm = std::make_unique<llm_build_grovemoe>(*this, params); - } break; - case LLM_ARCH_APERTUS: - { - llm = std::make_unique<llm_build_apertus>(*this, params); - } break; - case LLM_ARCH_MINIMAX_M2: - { - llm = std::make_unique<llm_build_minimax_m2>(*this, params); - } break; - case LLM_ARCH_COGVLM: - { - llm = std::make_unique<llm_build_cogvlm>(*this, params); - } break; - case LLM_ARCH_PANGU_EMBED: - { - llm = std::make_unique<llm_build_pangu_embedded>(*this, params); - } break; - case LLM_ARCH_QWEN3NEXT: - { - llm = std::make_unique<llm_build_qwen3next>(*this, params); - } break; - case LLM_ARCH_MISTRAL3: - { - llm = std::make_unique<llm_build_mistral3>(*this, params); - } break; - case LLM_ARCH_MIMO2: - { - llm = std::make_unique<llm_build_mimo2_iswa>(*this, params); - } break; - default: - GGML_ABORT("fatal error"); + } } + return res; +} + +ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { + std::unique_ptr<llm_graph_context> llm = build_arch_graph(params); + // add on pooling layer - llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm); // add backend sampling layers (if any) llm->build_sampling(); @@ -7960,9 +2238,9 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // there will be two additional dense projection layers // dense linear projections are applied after pooling // TODO: move reranking logic here and generalize - llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers); - llm->res->set_outputs(); + llm->res->set_outputs(params); return llm->res->get_gf(); } @@ -7985,7 +2263,7 @@ llama_model_params llama_model_default_params() { /*.kv_overrides =*/ nullptr, /*.vocab_only =*/ false, /*.use_mmap =*/ true, - /*.use_direct_io =*/ true, + /*.use_direct_io =*/ false, /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, @@ -8021,11 +2299,11 @@ int32_t llama_model_n_embd_inp(const llama_model * model) { } int32_t llama_model_n_embd_out(const llama_model * model) { - return model->hparams.get_n_embd_out(); + return model->hparams.n_embd_out(); } int32_t llama_model_n_layer(const llama_model * model) { - return model->hparams.n_layer; + return model->hparams.n_layer(); } int32_t llama_model_n_head(const llama_model * model) { @@ -8040,6 +2318,7 @@ int32_t llama_model_n_swa(const llama_model * model) { return model->hparams.n_swa; } + uint32_t llama_model_n_cls_out(const struct llama_model * model) { return model->hparams.n_cls_out; } @@ -8095,6 +2374,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_NEMOTRON_H: case LLM_ARCH_NEMOTRON_H_MOE: + case LLM_ARCH_KIMI_LINEAR: return LLAMA_ROPE_TYPE_NONE; // use what we call a normal RoPE, operating on pairs of consecutive head values @@ -8113,6 +2393,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ARCTIC: case LLM_ARCH_DEEPSEEK: case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_DEEPSEEK2OCR: + case LLM_ARCH_DEEPSEEK32: case LLM_ARCH_PLM: case LLM_ARCH_CHATGLM: case LLM_ARCH_GRANITE: @@ -8126,8 +2408,11 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5: case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: + case LLM_ARCH_EAGLE3: + case LLM_ARCH_MISTRAL4: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: + case LLM_ARCH_GLM_DSA: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -8140,6 +2425,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MODERN_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_EUROBERT: case LLM_ARCH_STABLELM: case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: @@ -8162,6 +2448,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: case LLM_ARCH_GEMMA3N: + case LLM_ARCH_GEMMA4: + case LLM_ARCH_GEMMA4_ASSISTANT: case LLM_ARCH_GEMMA_EMBEDDING: case LLM_ARCH_STARCODER2: case LLM_ARCH_OPENELM: @@ -8171,10 +2459,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: case LLM_ARCH_EXAONE4: + case LLM_ARCH_EXAONE_MOE: case LLM_ARCH_MINICPM3: case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_JAIS2: case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: @@ -8189,12 +2479,18 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_AFMOE: case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_MIMO2: + case LLM_ARCH_STEP35: + case LLM_ARCH_TALKIE: + case LLM_ARCH_MELLUM: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: + case LLM_ARCH_PADDLEOCR: return LLAMA_ROPE_TYPE_MROPE; case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return LLAMA_ROPE_TYPE_IMROPE; case LLM_ARCH_GLM4: @@ -8202,6 +2498,9 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GLM4_MOE: return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + case LLM_ARCH_HUNYUAN_VL: + return model->hparams.use_mrope() ? LLAMA_ROPE_TYPE_MROPE : LLAMA_ROPE_TYPE_NEOX; + // all model arches should be listed explicitly here case LLM_ARCH_UNKNOWN: GGML_ABORT("unknown architecture"); @@ -8304,8 +2603,9 @@ uint64_t llama_model_n_params(const llama_model * model) { bool llama_model_has_encoder(const llama_model * model) { switch (model->arch) { - case LLM_ARCH_T5: return true; - case LLM_ARCH_T5ENCODER: return true; + case LLM_ARCH_T5: + case LLM_ARCH_T5ENCODER: + case LLM_ARCH_EAGLE3: return true; default: return false; } } @@ -8336,3 +2636,67 @@ bool llama_model_is_diffusion(const llama_model * model) { const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } + +int32_t llama_model_n_expert(const struct llama_model * model) { + return model->hparams.n_expert; +} + +int32_t llama_model_n_devices(const struct llama_model * model) { + return (int32_t)model->devices.size(); +} + +ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i) { + if (i < 0 || i >= (int)model->devices.size()) { + return nullptr; + } + return model->devices[i].dev; +} + +// +// llama_model_base +// + +llama_model_base::llama_model_base(const struct llama_model_params & params) : llama_model(params), model(this), tn(model->arch), + TENSOR_DUPLICATED (llama_model_loader::TENSOR_DUPLICATED), + TENSOR_NOT_REQUIRED (llama_model_loader::TENSOR_NOT_REQUIRED), + TENSOR_SKIP (llama_model_loader::TENSOR_SKIP), + TENSOR_SKIP_IF_VIRTUAL(llama_model_loader::TENSOR_SKIP_IF_VIRTUAL) {} + +ggml_tensor * llama_model_base::create_tensor(const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags) { + GGML_ASSERT(ml != nullptr); + return create_tensor(*ml, tn, ne, flags); +} + +void llama_model_base::create_tensor_gate_up_exps(llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) { + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_up_exps == nullptr) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + } +} + +void llama_model_base::create_tensor_qkv(llama_layer & layer, int bid, + int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, + int flags) { + const int64_t n_embd_qkv = n_embd_q_ + n_embd_k_ + n_embd_v_; + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", bid), {n_embd_, n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + if (layer.wqkv) { + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", bid), {n_embd_qkv}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", bid), {n_embd_, n_embd_q_}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", bid), {n_embd_, n_embd_k_}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", bid), {n_embd_, n_embd_v_}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", bid), {n_embd_q_}, TENSOR_NOT_REQUIRED); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", bid), {n_embd_k_}, TENSOR_NOT_REQUIRED); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", bid), {n_embd_v_}, TENSOR_NOT_REQUIRED); + } +} + +const int32_t * llama_model_target_layer_ids(const struct llama_model * model) { + const auto & v = model->target_layer_ids; + return v.empty() ? nullptr : v.data(); +} + +uint32_t llama_model_target_layer_ids_n(const struct llama_model * model) { + return (uint32_t) model->target_layer_ids.size(); +} diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 79200a0d97a..f4718f6d584 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -11,6 +11,7 @@ #include <memory> #include <string> #include <unordered_map> +#include <unordered_set> #include <vector> struct llama_cparams; @@ -53,6 +54,7 @@ enum llm_type { LLM_TYPE_0_3B, LLM_TYPE_0_5B, LLM_TYPE_0_6B, + LLM_TYPE_0_8B, LLM_TYPE_1B, LLM_TYPE_1_2B, LLM_TYPE_1_3B, @@ -82,6 +84,7 @@ enum llm_type { LLM_TYPE_26B, LLM_TYPE_27B, LLM_TYPE_30B, + LLM_TYPE_31B, LLM_TYPE_32B, LLM_TYPE_34B, LLM_TYPE_35B, @@ -113,25 +116,40 @@ enum llm_type { LLM_TYPE_A13B, LLM_TYPE_7B_A1B, LLM_TYPE_8B_A1B, // lfm2moe + LLM_TYPE_12B_A2_5B, LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small + LLM_TYPE_24B_A2B, // lfm2moe + LLM_TYPE_26B_A4B, // Gemma4 LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, + LLM_TYPE_35B_A3B, // Qwen3.5 + LLM_TYPE_48B_A3B, // Kimi Linear LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_120B_A12B, // Nemotron 3 Super + LLM_TYPE_122B_A10B, // Qwen3.5 + LLM_TYPE_196B_A11B, // Step3.5-Flash LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 + LLM_TYPE_397B_A17B, // Qwen3.5 + LLM_TYPE_685B_A37B, // DeepSeek V3.2 + LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; std::string llama_rope_scaling_type_name(llama_rope_scaling_type rope_scaling_type); +// Map a GGUF activation-name string to llm_ffn_op_type. Returns `fallback` if +// the string is empty or not recognized. +llm_ffn_op_type llm_ffn_op_type_from_string(const std::string & name, llm_ffn_op_type fallback); + struct llama_layer_posnet { // resnet struct ggml_tensor * norm1 = nullptr; @@ -190,12 +208,16 @@ struct llama_layer_shortconv { }; struct llama_layer_nextn { - struct ggml_tensor * eh_proj = nullptr; - struct ggml_tensor * embed_tokens = nullptr; - struct ggml_tensor * enorm = nullptr; - struct ggml_tensor * hnorm = nullptr; - struct ggml_tensor * shared_head_head = nullptr; - struct ggml_tensor * shared_head_norm = nullptr; + struct ggml_tensor * eh_proj = nullptr; + struct ggml_tensor * eh_proj_s = nullptr; + struct ggml_tensor * eh_proj_in_s = nullptr; + struct ggml_tensor * embed_tokens = nullptr; + struct ggml_tensor * enorm = nullptr; + struct ggml_tensor * hnorm = nullptr; + struct ggml_tensor * shared_head_head = nullptr; + struct ggml_tensor * shared_head_head_s = nullptr; + struct ggml_tensor * shared_head_head_in_s = nullptr; + struct ggml_tensor * shared_head_norm = nullptr; }; struct llama_layer { @@ -234,6 +256,8 @@ struct llama_layer { struct ggml_tensor * wkv_b = nullptr; struct ggml_tensor * wk_b = nullptr; struct ggml_tensor * wv_b = nullptr; + struct ggml_tensor * wqkv_b = nullptr; + struct ggml_tensor * wo_b = nullptr; struct ggml_tensor * wq_cross = nullptr; struct ggml_tensor * wk_cross = nullptr; struct ggml_tensor * wv_cross = nullptr; @@ -244,13 +268,6 @@ struct llama_layer { struct ggml_tensor * wo_enc = nullptr; struct ggml_tensor * wqkv_gate = nullptr; - // attention bias - struct ggml_tensor * bq = nullptr; - struct ggml_tensor * bk = nullptr; - struct ggml_tensor * bv = nullptr; - struct ggml_tensor * bo = nullptr; - struct ggml_tensor * bqkv = nullptr; - // relative position bias struct ggml_tensor * attn_rel_b = nullptr; struct ggml_tensor * attn_rel_b_enc = nullptr; @@ -260,6 +277,9 @@ struct llama_layer { struct ggml_tensor * ffn_norm = nullptr; struct ggml_tensor * ffn_norm_b = nullptr; struct ggml_tensor * ffn_post_norm = nullptr; + struct ggml_tensor * ffn_post_norm_1 = nullptr; // gemma4 + struct ggml_tensor * ffn_post_norm_2 = nullptr; // gemma4 + struct ggml_tensor * ffn_pre_norm_2 = nullptr; // gemma4 struct ggml_tensor * layer_out_norm = nullptr; struct ggml_tensor * layer_out_norm_b = nullptr; struct ggml_tensor * ffn_norm_exps = nullptr; @@ -274,14 +294,26 @@ struct llama_layer { struct ggml_tensor * ffn_up_enc = nullptr; // ff MoE - struct ggml_tensor * ffn_gate_inp = nullptr; - struct ggml_tensor * ffn_gate_exps = nullptr; - struct ggml_tensor * ffn_down_exps = nullptr; - struct ggml_tensor * ffn_up_exps = nullptr; - struct ggml_tensor * ffn_gate_inp_b = nullptr; - struct ggml_tensor * ffn_gate_exps_b = nullptr; - struct ggml_tensor * ffn_down_exps_b = nullptr; - struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_inp_s = nullptr; // gemma4 + struct ggml_tensor * ffn_gate_exps = nullptr; + struct ggml_tensor * ffn_down_exps = nullptr; + struct ggml_tensor * ffn_up_exps = nullptr; + struct ggml_tensor * ffn_gate_up_exps = nullptr; + struct ggml_tensor * ffn_gate_inp_b = nullptr; + struct ggml_tensor * ffn_gate_exps_b = nullptr; + struct ggml_tensor * ffn_down_exps_b = nullptr; + struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_gate_up_exps_b = nullptr; + + // ff MoE per-expert scales (NVFP4 per-tensor scale2) + struct ggml_tensor * ffn_gate_exps_s = nullptr; + struct ggml_tensor * ffn_down_exps_s = nullptr; + struct ggml_tensor * ffn_up_exps_s = nullptr; + + // ff MoE latent proj + struct ggml_tensor * ffn_latent_down = nullptr; + struct ggml_tensor * ffn_latent_up = nullptr; // ff shared expert (shexp) struct ggml_tensor * ffn_gate_inp_shexp = nullptr; @@ -319,6 +351,9 @@ struct llama_layer { // qwen3next struct ggml_tensor * ssm_beta_alpha = nullptr; + // qwen3.5 + struct ggml_tensor * ssm_alpha = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; @@ -373,13 +408,43 @@ struct llama_layer { struct ggml_tensor * rope_freqs = nullptr; // bitnet scale - struct ggml_tensor * wq_scale = nullptr; - struct ggml_tensor * wk_scale = nullptr; - struct ggml_tensor * wv_scale = nullptr; - struct ggml_tensor * wo_scale = nullptr; - struct ggml_tensor * ffn_gate_scale = nullptr; - struct ggml_tensor * ffn_up_scale = nullptr; - struct ggml_tensor * ffn_down_scale = nullptr; + struct ggml_tensor * wq_s = nullptr; + struct ggml_tensor * wk_s = nullptr; + struct ggml_tensor * wv_s = nullptr; + struct ggml_tensor * wo_s = nullptr; + struct ggml_tensor * wqkv_s = nullptr; + struct ggml_tensor * wqkv_gate_s = nullptr; + struct ggml_tensor * ffn_gate_s = nullptr; + struct ggml_tensor * ffn_up_s = nullptr; + struct ggml_tensor * ffn_down_s = nullptr; + struct ggml_tensor * ffn_gate_shexp_s = nullptr; + struct ggml_tensor * ffn_up_shexp_s = nullptr; + struct ggml_tensor * ffn_down_shexp_s = nullptr; + struct ggml_tensor * ssm_in_s = nullptr; + struct ggml_tensor * ssm_out_s = nullptr; + struct ggml_tensor * ssm_alpha_s = nullptr; + struct ggml_tensor * ssm_beta_s = nullptr; + + // input scales + struct ggml_tensor * wq_in_s = nullptr; + struct ggml_tensor * wk_in_s = nullptr; + struct ggml_tensor * wv_in_s = nullptr; + struct ggml_tensor * wo_in_s = nullptr; + struct ggml_tensor * wqkv_in_s = nullptr; + struct ggml_tensor * wqkv_gate_in_s = nullptr; + struct ggml_tensor * ffn_gate_in_s = nullptr; + struct ggml_tensor * ffn_up_in_s = nullptr; + struct ggml_tensor * ffn_down_in_s = nullptr; + struct ggml_tensor * ffn_gate_exps_in_s = nullptr; + struct ggml_tensor * ffn_down_exps_in_s = nullptr; + struct ggml_tensor * ffn_up_exps_in_s = nullptr; + struct ggml_tensor * ffn_gate_shexp_in_s= nullptr; + struct ggml_tensor * ffn_up_shexp_in_s = nullptr; + struct ggml_tensor * ffn_down_shexp_in_s= nullptr; + struct ggml_tensor * ssm_in_in_s = nullptr; + struct ggml_tensor * ssm_out_in_s = nullptr; + struct ggml_tensor * ssm_alpha_in_s = nullptr; + struct ggml_tensor * ssm_beta_in_s = nullptr; // altup & laurel struct ggml_tensor * per_layer_inp_gate = nullptr; @@ -410,6 +475,28 @@ struct llama_layer { struct ggml_tensor * ffn_act_beta = nullptr; struct ggml_tensor * ffn_act_eps = nullptr; + // Kimi Linear KDA (using ssm_ prefix for consistency) + // Note: ssm_dt_b already exists above (mamba bias), reused for Kimi dt_bias + struct ggml_tensor * ssm_q_conv = nullptr; + struct ggml_tensor * ssm_k_conv = nullptr; + struct ggml_tensor * ssm_v_conv = nullptr; + struct ggml_tensor * ssm_f_a = nullptr; + struct ggml_tensor * ssm_f_b = nullptr; + struct ggml_tensor * ssm_beta = nullptr; + struct ggml_tensor * ssm_g_a = nullptr; + struct ggml_tensor * ssm_g_b = nullptr; + struct ggml_tensor * ssm_o_norm = nullptr; + + // DSA (deepseek sparse attention) + struct ggml_tensor * indexer_k_norm = nullptr; + struct ggml_tensor * indexer_k_norm_b = nullptr; + struct ggml_tensor * indexer_proj = nullptr; + struct ggml_tensor * indexer_attn_k = nullptr; + struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias + + // gemma4 layer output scale, reused for talkie embedding skip scale + struct ggml_tensor * out_scale = nullptr; + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -419,6 +506,19 @@ struct llama_layer { struct llama_layer_nextn nextn; }; +struct llama_device { + bool is_meta; + + ggml_backend_dev_t dev; +}; + +struct llama_meta_device_get_split_state_userdata { + size_t n_devices; + const struct llama_model * model; +}; + +struct ggml_backend_meta_split_state llama_meta_device_get_split_state(const struct ggml_tensor * tensor, void * userdata); + struct llama_model { llm_type type = LLM_TYPE_UNKNOWN; llm_arch arch = LLM_ARCH_UNKNOWN; @@ -443,53 +543,68 @@ struct llama_model { struct ggml_tensor * output_b = nullptr; struct ggml_tensor * output_norm_enc = nullptr; + + // NVFP4 per-tensor scale2, input_scale for LM head + struct ggml_tensor * output_s = nullptr; + struct ggml_tensor * output_in_s = nullptr; + + // NextN/MTP model-level projections + struct ggml_tensor * nextn_proj_pre = nullptr; + struct ggml_tensor * nextn_proj_post = nullptr; + // classifier struct ggml_tensor * cls = nullptr; struct ggml_tensor * cls_b = nullptr; struct ggml_tensor * cls_out = nullptr; struct ggml_tensor * cls_out_b = nullptr; + struct ggml_tensor * cls_norm = nullptr; struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; // gemma3n altup - struct ggml_tensor * tok_embd_per_layer = nullptr; struct ggml_tensor * altup_proj = nullptr; struct ggml_tensor * altup_unembd_proj = nullptr; + struct ggml_tensor * per_layer_tok_embd = nullptr; struct ggml_tensor * per_layer_model_proj = nullptr; struct ggml_tensor * per_layer_proj_norm = nullptr; + // eagle3 + struct ggml_tensor * fc = nullptr; // feature fusion layer + struct ggml_tensor * d2t = nullptr; // draft to target vocabulary mapping + + // unified vector to store target-model extracted layer ids in eagle3, dflash, etc. + std::vector<int32_t> target_layer_ids; + std::vector<llama_layer> layers; //Dense linear projections for SentenceTransformers models like embeddinggemma // For Sentence Transformers models structure see // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models - struct ggml_tensor * dense_2_out_layers = nullptr; - struct ggml_tensor * dense_3_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers_b = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; // gguf metadata std::unordered_map<std::string, std::string> gguf_kv; // list of devices used in this model - std::vector<ggml_backend_dev_t> devices; + std::vector<llama_device> devices; // for quantize-stats only std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name; - // for keeping track of extra nodes used by lora adapters - uint32_t n_lora_nodes = 0; + // for keeping track of associated LoRA adapters + std::unordered_set<llama_adapter_lora *> loras; + + // statically allocated context for assigning + struct llama_meta_device_get_split_state_userdata get_split_state_ud; int64_t t_load_us = 0; int64_t t_start_us = 0; - explicit llama_model(const struct llama_model_params & params); - ~llama_model(); - - void load_stats (llama_model_loader & ml); - void load_arch (llama_model_loader & ml); - void load_hparams(llama_model_loader & ml); - void load_vocab (llama_model_loader & ml); - bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback + explicit llama_model(const llama_model_params & params); + virtual ~llama_model(); std::string arch_name() const; std::string type_name() const; @@ -499,6 +614,7 @@ struct llama_model { size_t size() const; // file size size_t n_tensors() const; size_t n_devices() const; + const float * tensor_split() const; uint32_t n_gpu_layers() const; llama_split_mode split_mode() const; @@ -524,21 +640,96 @@ struct llama_model { ggml_tensor * get_rope_factors(const llama_cparams & cparams, int il) const; - // TODO: move this to new llm_arch_model_i interface llama_memory_i * create_memory(const llama_memory_params & params, const llama_cparams & cparams) const; - // TODO: move this to new llm_arch_model_i interface ggml_cgraph * build_graph(const llm_graph_params & params) const; -private: + virtual void load_stats (llama_model_loader & ml) = 0; + virtual void load_hparams(llama_model_loader & ml) = 0; + virtual void load_vocab (llama_model_loader & ml) = 0; + virtual bool load_tensors(llama_model_loader & ml) = 0; // returns false if cancelled by progress_callback + + // model must define these + virtual void load_arch_hparams(llama_model_loader & ml) = 0; + virtual void load_arch_tensors(llama_model_loader & ml) = 0; + virtual std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const = 0; + +protected: llama_model_params params; struct impl; std::unique_ptr<impl> pimpl; }; +llama_model * llama_model_create(llm_arch arch, const llama_model_params & params); +llama_model * llama_model_create(llama_model_loader & ml, const llama_model_params & params); + +// model must inherit from this +struct llama_model_base : public llama_model { + friend struct llama_model; + + llama_model * model; + llama_model_loader * ml = nullptr; + const LLM_TN tn; + + // llama_model_loader is not yet defined at this point, so we will set it after construction + const int TENSOR_DUPLICATED; + const int TENSOR_NOT_REQUIRED; + const int TENSOR_SKIP; + const int TENSOR_SKIP_IF_VIRTUAL; + + explicit llama_model_base(const llama_model_params & params); + virtual ~llama_model_base() = default; + + ggml_tensor * create_tensor(llama_model_loader & ml, const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags); + + // convenience overload of create_tensor that doesn't require llama_model_loader + ggml_tensor * create_tensor(const LLM_TN_IMPL & tn, const std::initializer_list<int64_t> & ne, int flags); + + // helper: try merged gate_up_exps first, fall back to separate gate and up + void create_tensor_gate_up_exps(llama_layer & layer, int bid, int64_t n_embd_, + int64_t n_ff_, int64_t n_expert_, int flags); + + // helper: try to load merged qkv first, fall back to separate q, k, v + void create_tensor_qkv(llama_layer & layer, int bid, + int64_t n_embd_, int64_t n_embd_q_, int64_t n_embd_k_, int64_t n_embd_v_, + int flags); + + void load_stats (llama_model_loader & ml) override; + void load_hparams(llama_model_loader & ml) override; + void load_vocab (llama_model_loader & ml) override; + bool load_tensors(llama_model_loader & ml) override; + + // model must define these + void load_arch_hparams(llama_model_loader & ml) override = 0; + void load_arch_tensors(llama_model_loader & ml) override = 0; + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override = 0; +}; + const char * llm_type_name(llm_type type); +// convenience macro for loading local variables for load_tensors() in llama_model_base +// note: cast to int64_t since we will use these for the tensor dimensions +#define LLAMA_LOAD_LOCALS \ + const int n_layer = hparams.n_layer(); GGML_UNUSED(n_layer); \ + const int n_layer_all = hparams.n_layer_all; GGML_UNUSED(n_layer_all); \ + const int n_layer_nextn = hparams.n_layer_nextn; GGML_UNUSED(n_layer_nextn); \ + const int64_t n_head = hparams.n_head(); GGML_UNUSED(n_head); \ + const int64_t n_head_kv = hparams.n_head_kv(); GGML_UNUSED(n_head_kv); \ + const int64_t n_embd = hparams.n_embd; GGML_UNUSED(n_embd); \ + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); GGML_UNUSED(n_embd_k_gqa); \ + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); GGML_UNUSED(n_embd_v_gqa); \ + const int64_t n_embd_head_k = hparams.n_embd_head_k(); GGML_UNUSED(n_embd_head_k); \ + const int64_t n_embd_head_v = hparams.n_embd_head_v(); GGML_UNUSED(n_embd_head_v); \ + const int64_t n_ff = hparams.n_ff(); GGML_UNUSED(n_ff); \ + const int64_t n_embd_gqa = n_embd_v_gqa; GGML_UNUSED(n_embd_gqa); \ + const int64_t n_vocab = vocab.n_tokens(); GGML_UNUSED(n_vocab); \ + const int64_t n_token_types = vocab.n_token_types(); GGML_UNUSED(n_token_types); \ + const int64_t n_rot = hparams.n_rot(); GGML_UNUSED(n_rot); \ + const int64_t n_expert = hparams.n_expert; GGML_UNUSED(n_expert); \ + const int64_t n_expert_used = hparams.n_expert_used; GGML_UNUSED(n_expert_used); \ + const int64_t n_ctx_train = hparams.n_ctx_train; GGML_UNUSED(n_ctx_train); + // For internal test use // TODO: remove const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model); diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 048d65a75c2..cf92ce4bb8b 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -1,7 +1,7 @@ -#include "llama-quant.h" #include "llama-impl.h" #include "llama-model.h" #include "llama-model-loader.h" +#include "llama-ext.h" #include <algorithm> #include <cmath> @@ -13,10 +13,28 @@ #include <thread> #include <unordered_map> -// Quantization types. Changes to this struct must be replicated in quantize.cpp -struct tensor_quantization { +// result of parsing --tensor-type option +// (changes to this struct must be reflected in tools/quantize/quantize.cpp) +struct tensor_type_option { std::string name; - ggml_type quant = GGML_TYPE_COUNT; + ggml_type type = GGML_TYPE_COUNT; +}; + +// tensor categorization - used to avoid repeated string matching in quantization logic. +// this is different from LLM_TN - we want broad categories, not specific tensor names per arch. +enum class tensor_category { + TOKEN_EMBD, + ATTENTION_Q, + ATTENTION_V, + ATTENTION_K, + ATTENTION_QKV, + ATTENTION_KV_B, + ATTENTION_OUTPUT, + FFN_UP, + FFN_GATE, + FFN_DOWN, + OUTPUT, + OTHER }; static void zeros(std::ofstream & file, size_t n) { @@ -54,7 +72,7 @@ static std::string remap_layer(const std::string & orig_name, const std::vector< return orig_name; } -static std::string remap_imatrix (const std::string & orig_name, const std::map<int, std::string> & mapped) { +static std::string remap_imatrix(const std::string & orig_name, const std::map<int, std::string> & mapped) { if (mapped.empty()) { return orig_name; } @@ -66,7 +84,6 @@ static std::string remap_imatrix (const std::string & orig_name, const std::map< for (const auto & p : mapped) { if (p.second == blk) { - LLAMA_LOG_DEBUG("(blk.%d imatrix) ", p.first); return new_name.replace(match.position(1), match.length(1), std::to_string(p.first)); } } @@ -76,6 +93,73 @@ static std::string remap_imatrix (const std::string & orig_name, const std::map< return orig_name; } +// +// helper functions for tensor name matching +// + +static bool tensor_name_match_token_embd(const char * tensor_name) { + return std::strcmp(tensor_name, "token_embd.weight") == 0 || + std::strcmp(tensor_name, "per_layer_token_embd.weight") == 0; +} + +static bool tensor_name_match_output_weight(const char * tensor_name) { + return std::strcmp(tensor_name, "output.weight") == 0; +} + +// +// tensor categorization for quantization +// +// (this is different from LLM_TN - we want broad categories, not specific tensor names per arch) +// + +static tensor_category tensor_get_category(const std::string & tensor_name) { + if (tensor_name_match_output_weight(tensor_name.c_str())) { + return tensor_category::OUTPUT; + } + if (tensor_name_match_token_embd(tensor_name.c_str())) { + return tensor_category::TOKEN_EMBD; + } + if (tensor_name.find("attn_qkv.weight") != std::string::npos) { + return tensor_category::ATTENTION_QKV; + } + if (tensor_name.find("attn_kv_b.weight") != std::string::npos) { + return tensor_category::ATTENTION_KV_B; + } + if (tensor_name.find("attn_v.weight") != std::string::npos) { + return tensor_category::ATTENTION_V; + } + if (tensor_name.find("attn_k.weight") != std::string::npos) { + return tensor_category::ATTENTION_K; + } + if (tensor_name.find("attn_q.weight") != std::string::npos) { + return tensor_category::ATTENTION_Q; + } + if (tensor_name.find("attn_output.weight") != std::string::npos) { + return tensor_category::ATTENTION_OUTPUT; + } + if (tensor_name.find("ffn_up") != std::string::npos) { + return tensor_category::FFN_UP; + } + if (tensor_name.find("ffn_gate") != std::string::npos) { + return tensor_category::FFN_GATE; + } + if (tensor_name.find("ffn_down") != std::string::npos) { + return tensor_category::FFN_DOWN; + } + return tensor_category::OTHER; +} + +// check if category is for attention-v-like tensors (more sensitive to quantization) +static bool category_is_attn_v(tensor_category cat) { + return cat == tensor_category::ATTENTION_V || + cat == tensor_category::ATTENTION_QKV || + cat == tensor_category::ATTENTION_KV_B; +} + +// +// quantization state +// + struct quantize_state_impl { const llama_model & model; const llama_model_quantize_params * params; @@ -89,20 +173,42 @@ struct quantize_state_impl { int i_ffn_gate = 0; int i_ffn_up = 0; - int n_k_quantized = 0; int n_fallback = 0; bool has_imatrix = false; - // used to figure out if a model shares tok_embd with the output weight - bool has_output = false; + // used to figure out if a model has tied embeddings (tok_embd shares weights with output) + bool has_tied_embeddings = true; // assume tied until we see output.weight + + // tensor type override patterns (compiled once, used twice) + std::vector<std::pair<std::regex, ggml_type>> tensor_type_patterns; + + quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params): + model(model), params(params) + { + // compile regex patterns once - they are expensive + if (params->tt_overrides) { + for (const auto * p = params->tt_overrides; p->pattern != nullptr; p++) { + tensor_type_patterns.emplace_back(std::regex(p->pattern), p->type); + } + } + } +}; - quantize_state_impl(const llama_model & model, const llama_model_quantize_params * params) - : model(model) - , params(params) - {} +// per-tensor metadata, computed in the preliminary loop and used in the main loop +struct tensor_metadata { + std::string name; + ggml_type target_type; + tensor_category category; + std::string remapped_imatrix_name; + bool allows_quantization; + bool requires_imatrix; }; +// +// dequantization +// + static void llama_tensor_dequantize_impl( ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers, const size_t nelements, const int nthread @@ -175,12 +281,138 @@ static void llama_tensor_dequantize_impl( workers.clear(); } -static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) { +// +// do we allow this tensor to be quantized? +// + +static bool tensor_allows_quantization(const llama_model_quantize_params * params, llm_arch arch, const ggml_tensor * tensor) { + // trivial checks first -- no string ops needed + if (params->only_copy) return false; + + // quantize only 2D and 3D tensors (experts) + if (ggml_n_dims(tensor) < 2) return false; + + const std::string name = ggml_get_name(tensor); + + // This used to be a regex, but <regex> has an extreme cost to compile times. + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? + + // do not quantize norm tensors + quantize &= name.find("_norm.weight") == std::string::npos; + + quantize &= params->quantize_output_tensor || name != "output.weight"; + + // do not quantize expert gating tensors + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; + + // these are very small (e.g. 4x4) + quantize &= name.find("altup") == std::string::npos; + quantize &= name.find("laurel") == std::string::npos; + + // these are not too big so keep them as it is + quantize &= name.find("per_layer_model_proj") == std::string::npos; + + // do not quantize positional embeddings and token types (BERT) + quantize &= name != LLM_TN(arch)(LLM_TENSOR_POS_EMBD, "weight"); + quantize &= name != LLM_TN(arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); + + // do not quantize Mamba/Kimi's small conv1d weights + // NOTE: can't use LLM_TN here because the layer number is not known + quantize &= name.find("ssm_conv1d") == std::string::npos; + quantize &= name.find("shortconv.conv.weight") == std::string::npos; + + // do not quantize RWKV's small yet 2D weights + quantize &= name.find("time_mix_first.weight") == std::string::npos; + quantize &= name.find("time_mix_w0.weight") == std::string::npos; + quantize &= name.find("time_mix_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_v0.weight") == std::string::npos; + quantize &= name.find("time_mix_v1.weight") == std::string::npos; + quantize &= name.find("time_mix_v2.weight") == std::string::npos; + quantize &= name.find("time_mix_a0.weight") == std::string::npos; + quantize &= name.find("time_mix_a1.weight") == std::string::npos; + quantize &= name.find("time_mix_a2.weight") == std::string::npos; + quantize &= name.find("time_mix_g1.weight") == std::string::npos; + quantize &= name.find("time_mix_g2.weight") == std::string::npos; + quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; + quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; + quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; + + // do not quantize relative position bias (T5) + quantize &= name.find("attn_rel_b.weight") == std::string::npos; + + // do not quantize specific multimodal tensors + quantize &= name.find(".position_embd") == std::string::npos; + quantize &= name.find("sam.pos_embd") == std::string::npos; + quantize &= name.find("sam.neck.") == std::string::npos; + quantize &= name.find("sam.net_") == std::string::npos; + quantize &= name.find(".rel_pos") == std::string::npos; + quantize &= name.find(".patch_embd") == std::string::npos; + quantize &= name.find(".patch_merger") == std::string::npos; + + return quantize; +} + +// +// tensor type selection +// + +// incompatible tensor shapes are handled here - fallback to a compatible type +static ggml_type tensor_type_fallback(quantize_state_impl & qs, const ggml_tensor * t, const ggml_type target_type) { + ggml_type return_type = target_type; + + const int64_t ncols = t->ne[0]; + const int64_t qk_k = ggml_blck_size(target_type); + + if (ncols % qk_k != 0) { // this tensor's shape is incompatible with this quant + LLAMA_LOG_WARN("warning: %-36s - ncols %6" PRId64 " not divisible by %3" PRId64 " (required for type %7s) ", + t->name, ncols, qk_k, ggml_type_name(target_type)); + ++qs.n_fallback; + + switch (target_type) { + // types on the left: block size 256 + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: // types on the right: block size 32 + case GGML_TYPE_IQ4_XS: return_type = GGML_TYPE_IQ4_NL; break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: return_type = GGML_TYPE_Q4_0; break; + case GGML_TYPE_Q4_K: return_type = GGML_TYPE_Q5_0; break; + case GGML_TYPE_Q5_K: return_type = GGML_TYPE_Q5_1; break; + case GGML_TYPE_Q6_K: return_type = GGML_TYPE_Q8_0; break; + default: + throw std::runtime_error(format("no tensor type fallback is defined for type %s", + ggml_type_name(target_type))); + } + if (ncols % ggml_blck_size(return_type) != 0) { + // + // the fallback return type is still not compatible for this tensor! + // + // most likely, this tensor's first dimension is not divisible by 32. + // this is very rare. we can either abort the quantization, or + // fallback to F16 / F32. + // + LLAMA_LOG_WARN("(WARNING: must use F16 due to unusual shape) "); + return_type = GGML_TYPE_F16; + } + LLAMA_LOG_WARN("-> falling back to %7s\n", ggml_type_name(return_type)); + } + return return_type; +} + +// internal standard logic for selecting the target tensor type based on tensor category, ftype, and model arch +static ggml_type llama_tensor_get_type_impl(quantize_state_impl & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype, tensor_category category) { const std::string name = ggml_get_name(tensor); // TODO: avoid hardcoded tensor names - use the TN_* constants const llm_arch arch = qs.model.arch; - const auto tn = LLM_TN(arch); auto use_more_bits = [](int i_layer, int n_layers) -> bool { return i_layer < n_layers/8 || i_layer >= 7*n_layers/8 || (i_layer - n_layers/8)%3 == 2; @@ -204,7 +436,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t // for arches that share the same tensor between the token embeddings and the output, we quantize the token embeddings // with the quantization of the output tensor - if (name == tn(LLM_TENSOR_OUTPUT, "weight") || (!qs.has_output && name == tn(LLM_TENSOR_TOKEN_EMBD, "weight"))) { + if (category == tensor_category::OUTPUT || (qs.has_tied_embeddings && category == tensor_category::TOKEN_EMBD)) { if (qs.params->output_tensor_type < GGML_TYPE_COUNT) { new_type = qs.params->output_tensor_type; } else { @@ -234,7 +466,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } else { new_type = GGML_TYPE_Q8_0; } - } else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") { + } else if (category == tensor_category::TOKEN_EMBD) { if (qs.params->token_embedding_type < GGML_TYPE_COUNT) { new_type = qs.params->token_embedding_type; } else { @@ -254,21 +486,21 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { - if (name.find("attn_v.weight") != std::string::npos) { + if (category_is_attn_v(category)) { if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K; else new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; ++qs.i_attention_wv; } - else if (qs.model.hparams.n_expert == 8 && name.find("attn_k.weight") != std::string::npos) { + else if (qs.model.hparams.n_expert == 8 && category == tensor_category::ATTENTION_K) { new_type = GGML_TYPE_Q4_K; } - else if (name.find("ffn_down") != std::string::npos) { + else if (category == tensor_category::FFN_DOWN) { if (qs.i_ffn_down < qs.n_ffn_down/8) { new_type = ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M ? GGML_TYPE_IQ3_S : GGML_TYPE_Q2_K; } ++qs.i_ffn_down; } - else if (name.find("attn_output.weight") != std::string::npos) { + else if (category == tensor_category::ATTENTION_OUTPUT) { if (qs.model.hparams.n_expert == 8) { new_type = GGML_TYPE_Q5_K; } else { @@ -276,7 +508,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M) new_type = GGML_TYPE_IQ3_S; } } - } else if (name.find("attn_v.weight") != std::string::npos) { + } else if (category_is_attn_v(category)) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) { new_type = qs.model.hparams.n_gqa() >= 4 ? GGML_TYPE_Q4_K : GGML_TYPE_Q3_K; } @@ -314,7 +546,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t new_type = GGML_TYPE_Q8_0; } ++qs.i_attention_wv; - } else if (name.find("attn_k.weight") != std::string::npos) { + } else if (category == tensor_category::ATTENTION_K) { if (qs.model.hparams.n_expert == 8) { // for the 8-expert model, bumping this to Q8_0 trades just ~128MB // TODO: explore better strategies @@ -326,14 +558,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = GGML_TYPE_IQ2_S; } - } else if (name.find("attn_q.weight") != std::string::npos) { + } else if (category == tensor_category::ATTENTION_Q) { if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS) { new_type = GGML_TYPE_IQ3_XXS; } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) { new_type = GGML_TYPE_IQ2_S; } - } else if (name.find("ffn_down") != std::string::npos) { + } else if (category == tensor_category::FFN_DOWN) { auto info = layer_info(qs.i_ffn_down, qs.n_ffn_down, name.c_str()); int i_layer = info.first, n_layer = info.second; if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; @@ -378,7 +610,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1; } ++qs.i_ffn_down; - } else if (name.find("attn_output.weight") != std::string::npos) { + } else if (category == tensor_category::ATTENTION_OUTPUT) { if (arch != LLM_ARCH_FALCON) { if (qs.model.hparams.n_expert == 8) { if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS || @@ -398,14 +630,14 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K; } } - else if (name.find("attn_qkv.weight") != std::string::npos) { + else if (category == tensor_category::ATTENTION_QKV) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L || ftype == LLAMA_FTYPE_MOSTLY_IQ3_M) { new_type = GGML_TYPE_Q4_K; } else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K; } - else if (name.find("ffn_gate") != std::string::npos) { + else if (category == tensor_category::FFN_GATE) { auto info = layer_info(qs.i_ffn_gate, qs.n_ffn_gate, name.c_str()); int i_layer = info.first, n_layer = info.second; if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { @@ -413,7 +645,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t } ++qs.i_ffn_gate; } - else if (name.find("ffn_up") != std::string::npos) { + else if (category == tensor_category::FFN_UP) { auto info = layer_info(qs.i_ffn_up, qs.n_ffn_up, name.c_str()); int i_layer = info.first, n_layer = info.second; if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XS && (i_layer >= n_layer/8 && i_layer < 7*n_layer/8)) { @@ -422,60 +654,58 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t ++qs.i_ffn_up; } - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // IK: let's remove this, else Q2_K is almost the same as Q3_K_S - //else if (name.find("ffn_gate") != std::string::npos || name.find("ffn_up") != std::string::npos) { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K; - //} - // This can be used to reduce the size of the Q5_K_S model. - // The associated PPL increase is fully in line with the size reduction - //else { - // if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_S) new_type = GGML_TYPE_Q4_K; - //} - bool convert_incompatible_tensor = false; - { - const int64_t nx = tensor->ne[0]; - const int64_t ny = tensor->ne[1]; - const int64_t qk_k = ggml_blck_size(new_type); + return new_type; +} - if (nx % qk_k != 0) { - LLAMA_LOG_WARN("\n\n%s : tensor cols %" PRId64 " x %" PRId64 " are not divisible by %" PRId64 ", required for %s", __func__, nx, ny, qk_k, ggml_type_name(new_type)); - convert_incompatible_tensor = true; - } else { - ++qs.n_k_quantized; - } +// outer wrapper: determine the ggml_type that this tensor should be quantized to +static ggml_type llama_tensor_get_type(quantize_state_impl & qs, const llama_model_quantize_params * params, const ggml_tensor * tensor, ggml_type default_type, const tensor_metadata & tm) { + if (!tensor_allows_quantization(params, qs.model.arch, tensor)) { + return tensor->type; + } + if (params->token_embedding_type < GGML_TYPE_COUNT && tm.category == tensor_category::TOKEN_EMBD) { + return params->token_embedding_type; + } + if (params->output_tensor_type < GGML_TYPE_COUNT && tm.category == tensor_category::OUTPUT) { + return params->output_tensor_type; } - if (convert_incompatible_tensor) { - switch (new_type) { - case GGML_TYPE_TQ1_0: - case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead - case GGML_TYPE_IQ2_XXS: - case GGML_TYPE_IQ2_XS: - case GGML_TYPE_IQ2_S: - case GGML_TYPE_IQ3_XXS: - case GGML_TYPE_IQ3_S: - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_IQ4_XS: new_type = GGML_TYPE_IQ4_NL; break; - case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break; - case GGML_TYPE_Q5_K: new_type = GGML_TYPE_Q5_1; break; - case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break; - default: throw std::runtime_error("\nUnsupported tensor size encountered\n"); + ggml_type new_type = default_type; + + // get more optimal quantization type based on the tensor shape, layer, etc. + if (!params->pure && ggml_is_quantized(default_type)) { + // if the user provided tensor types - use those + bool manual = false; + if (!qs.tensor_type_patterns.empty()) { + const std::string tensor_name(tensor->name); + for (const auto & [pattern, qtype] : qs.tensor_type_patterns) { + if (std::regex_search(tensor_name, pattern)) { + if (qtype != new_type) { + LLAMA_LOG_WARN("%s: %-36s - applying manual override: %s -> %s\n", + __func__, tensor_name.c_str(), ggml_type_name(new_type), ggml_type_name(qtype)); + new_type = qtype; + } + manual = true; + break; + } + } } - if (tensor->ne[0] % ggml_blck_size(new_type) != 0) { - new_type = GGML_TYPE_F16; + + // if not manual - use the standard logic for choosing the quantization type based on the selected mixture + if (!manual) { + new_type = llama_tensor_get_type_impl(qs, new_type, tensor, params->ftype, tm.category); } - LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type)); - ++qs.n_fallback; + + // incompatible tensor shapes are handled here - fallback to a compatible type + new_type = tensor_type_fallback(qs, tensor, new_type); } return new_type; } +// +// quantization implementation +// + static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * f32_data, void * new_data, const int64_t chunk_size, int64_t nrows, int64_t n_per_row, const float * imatrix, std::vector<std::thread> & workers, const int nthread) { if (nthread < 2) { // single-thread @@ -530,50 +760,102 @@ static size_t llama_tensor_quantize_impl(enum ggml_type new_type, const float * return new_size; } -static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { - ggml_type default_type; - llama_ftype ftype = params->ftype; +// +// imatrix requirement check +// - switch (params->ftype) { - case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; - case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; - case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; - case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; - case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; - case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; - case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; - case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; +static bool tensor_requires_imatrix(const char * tensor_name, const ggml_type dst_type, const llama_ftype ftype) { + if (tensor_name_match_token_embd(tensor_name) || tensor_name_match_output_weight(tensor_name)) { + return false; + } + switch (dst_type) { + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ1_S: + return true; + case GGML_TYPE_Q2_K: + // as a general rule, the k-type quantizations don't require imatrix data. + // the only exception is Q2_K tensors that are part of a Q2_K_S file. + return ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S; + default: + return false; + } +} - case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break; +// +// given a file type, get the default tensor type +// + +ggml_type llama_ftype_get_default_type(llama_ftype ftype) { + switch (ftype) { + case LLAMA_FTYPE_MOSTLY_Q4_0: return GGML_TYPE_Q4_0; + case LLAMA_FTYPE_MOSTLY_Q4_1: return GGML_TYPE_Q4_1; + case LLAMA_FTYPE_MOSTLY_Q5_0: return GGML_TYPE_Q5_0; + case LLAMA_FTYPE_MOSTLY_Q5_1: return GGML_TYPE_Q5_1; + case LLAMA_FTYPE_MOSTLY_Q8_0: return GGML_TYPE_Q8_0; + case LLAMA_FTYPE_MOSTLY_F16: return GGML_TYPE_F16; + case LLAMA_FTYPE_MOSTLY_BF16: return GGML_TYPE_BF16; + case LLAMA_FTYPE_ALL_F32: return GGML_TYPE_F32; + case LLAMA_FTYPE_MOSTLY_Q1_0: return GGML_TYPE_Q1_0; + + case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return GGML_TYPE_MXFP4; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: - case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break; - case LLAMA_FTYPE_MOSTLY_IQ3_XS: default_type = GGML_TYPE_IQ3_S; break; + case LLAMA_FTYPE_MOSTLY_Q2_K: return GGML_TYPE_Q2_K; + case LLAMA_FTYPE_MOSTLY_IQ3_XS: return GGML_TYPE_IQ3_S; case LLAMA_FTYPE_MOSTLY_Q3_K_S: case LLAMA_FTYPE_MOSTLY_Q3_K_M: - case LLAMA_FTYPE_MOSTLY_Q3_K_L: default_type = GGML_TYPE_Q3_K; break; + case LLAMA_FTYPE_MOSTLY_Q3_K_L: return GGML_TYPE_Q3_K; case LLAMA_FTYPE_MOSTLY_Q4_K_S: - case LLAMA_FTYPE_MOSTLY_Q4_K_M: default_type = GGML_TYPE_Q4_K; break; + case LLAMA_FTYPE_MOSTLY_Q4_K_M: return GGML_TYPE_Q4_K; case LLAMA_FTYPE_MOSTLY_Q5_K_S: - case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break; - case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break; - case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break; - case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break; - case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break; - case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; - case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break; - case LLAMA_FTYPE_MOSTLY_IQ2_M: default_type = GGML_TYPE_IQ2_S; break; - case LLAMA_FTYPE_MOSTLY_IQ3_XXS: default_type = GGML_TYPE_IQ3_XXS; break; - case LLAMA_FTYPE_MOSTLY_IQ1_S: default_type = GGML_TYPE_IQ1_S; break; - case LLAMA_FTYPE_MOSTLY_IQ1_M: default_type = GGML_TYPE_IQ1_M; break; - case LLAMA_FTYPE_MOSTLY_IQ4_NL: default_type = GGML_TYPE_IQ4_NL; break; - case LLAMA_FTYPE_MOSTLY_IQ4_XS: default_type = GGML_TYPE_IQ4_XS; break; - case LLAMA_FTYPE_MOSTLY_IQ3_S: default_type = GGML_TYPE_IQ3_S; break; - case LLAMA_FTYPE_MOSTLY_IQ3_M: default_type = GGML_TYPE_IQ3_S; break; - - default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); + case LLAMA_FTYPE_MOSTLY_Q5_K_M: return GGML_TYPE_Q5_K; + case LLAMA_FTYPE_MOSTLY_Q6_K: return GGML_TYPE_Q6_K; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return GGML_TYPE_TQ1_0; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return GGML_TYPE_TQ2_0; + case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return GGML_TYPE_IQ2_XXS; + case LLAMA_FTYPE_MOSTLY_IQ2_XS: return GGML_TYPE_IQ2_XS; + case LLAMA_FTYPE_MOSTLY_IQ2_S: return GGML_TYPE_IQ2_XS; + case LLAMA_FTYPE_MOSTLY_IQ2_M: return GGML_TYPE_IQ2_S; + case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return GGML_TYPE_IQ3_XXS; + case LLAMA_FTYPE_MOSTLY_IQ1_S: return GGML_TYPE_IQ1_S; + case LLAMA_FTYPE_MOSTLY_IQ1_M: return GGML_TYPE_IQ1_M; + case LLAMA_FTYPE_MOSTLY_IQ4_NL: return GGML_TYPE_IQ4_NL; + case LLAMA_FTYPE_MOSTLY_IQ4_XS: return GGML_TYPE_IQ4_XS; + case LLAMA_FTYPE_MOSTLY_IQ3_S: + case LLAMA_FTYPE_MOSTLY_IQ3_M: return GGML_TYPE_IQ3_S; + + default: return GGML_TYPE_COUNT; + } +} + + +static void init_quantize_state_counters(quantize_state_impl & qs, std::vector<tensor_metadata> & metadata) { + for (auto & tm : metadata) { + tensor_category cat = tensor_get_category(tm.name); + tm.category = cat; + + if (category_is_attn_v(cat)) { + ++qs.n_attention_wv; + } + + if (cat == tensor_category::OUTPUT) { + qs.has_tied_embeddings = false; + } } + qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)qs.model.hparams.n_layer(); +} + +// +// main quantization driver +// + +static void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { + llama_ftype ftype = params->ftype; int nthread = params->nthread; @@ -581,6 +863,11 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: nthread = std::thread::hardware_concurrency(); } + ggml_type default_type = llama_ftype_get_default_type(ftype); + if (default_type == GGML_TYPE_COUNT) { + throw std::runtime_error(format("invalid output file type %d\n", ftype)); + } + // mmap consistently increases speed on Linux, and also increases speed on Windows with // hot cache. It may cause a slowdown on macOS, possibly related to free memory. #if defined(__linux__) || defined(_WIN32) @@ -589,32 +876,38 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: constexpr bool use_mmap = false; #endif - llama_model_kv_override * kv_overrides = nullptr; - if (params->kv_overrides) { - auto * v = (std::vector<llama_model_kv_override>*)params->kv_overrides; - kv_overrides = v->data(); - } - + const llama_model_kv_override * kv_overrides = params->kv_overrides; std::vector<std::string> splits = {}; - llama_model_loader ml(fname_inp, splits, use_mmap, /*use_direct_io*/ true, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); + llama_model_loader ml(/*metadata*/ nullptr, /*set_tensor_data*/ nullptr, /*set_tensor_data_ud*/ nullptr, + fname_inp, splits, /*file*/ nullptr, use_mmap, /*use_direct_io*/ false, /*check_tensors*/ true, /*no_alloc*/ false, kv_overrides, nullptr); ml.init_mappings(false); // no prefetching - llama_model model(llama_model_default_params()); + auto mparams = llama_model_default_params(); + std::unique_ptr<llama_model> model_ptr(llama_model_create(ml, mparams)); + + auto * model = dynamic_cast<llama_model_base *>(model_ptr.get()); + if (model == nullptr) { + GGML_ABORT("fatal error: model does not implement llama_model_base"); + } - model.load_arch (ml); - model.load_hparams(ml); - model.load_stats (ml); + model->load_hparams(ml); + model->load_stats (ml); - quantize_state_impl qs(model, params); + quantize_state_impl qs(*model, params); if (params->only_copy) { ftype = ml.ftype; } + std::unordered_map<std::string, std::vector<float>> i_data; const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr; if (params->imatrix) { - imatrix_data = static_cast<const std::unordered_map<std::string, std::vector<float>>*>(params->imatrix); + for (const llama_model_imatrix_data * p = params->imatrix; p->name != nullptr; p++) { + i_data.emplace(p->name, std::vector<float>(p->data, p->data + p->size)); + } + imatrix_data = & i_data; if (imatrix_data) { - LLAMA_LOG_INFO("================================ Have weights data with %d entries\n",int(imatrix_data->size())); + LLAMA_LOG_INFO("\n%s: have importance matrix data with %d entries\n", + __func__, (int)imatrix_data->size()); qs.has_imatrix = true; // check imatrix for nans or infs for (const auto & kv : *imatrix_data) { @@ -632,11 +925,13 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::vector<int> prune_list = {}; if (params->prune_layers) { - prune_list = *static_cast<const std::vector<int> *>(params->prune_layers); + for (const int32_t * p = params->prune_layers; * p != -1; p++) { + prune_list.push_back(* p); + } } // copy the KV pairs from the input file - gguf_set_kv (ctx_out.get(), ml.meta.get()); + gguf_set_kv (ctx_out.get(), ml.metadata); gguf_set_val_u32(ctx_out.get(), "general.quantization_version", GGML_QNT_VERSION); // TODO: use LLM_KV gguf_set_val_u32(ctx_out.get(), "general.file_type", ftype); // TODO: use LLM_KV @@ -646,20 +941,18 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: gguf_remove_key(ctx_out.get(), ml.llm_kv(LLM_KV_SPLIT_TENSORS_COUNT).c_str()); if (params->kv_overrides) { - const std::vector<llama_model_kv_override> & overrides = *(const std::vector<llama_model_kv_override> *)params->kv_overrides; - for (const auto & o : overrides) { - if (o.key[0] == 0) break; - if (o.tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { - gguf_set_val_f32(ctx_out.get(), o.key, o.val_f64); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_INT) { + for (const llama_model_kv_override * o = params->kv_overrides; o->key[0] != 0; ++o) { + if (o->tag == LLAMA_KV_OVERRIDE_TYPE_FLOAT) { + gguf_set_val_f32(ctx_out.get(), o->key, o->val_f64); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_INT) { // Setting type to UINT32. See https://github.com/ggml-org/llama.cpp/pull/14182 for context - gguf_set_val_u32(ctx_out.get(), o.key, (uint32_t)std::abs(o.val_i64)); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { - gguf_set_val_bool(ctx_out.get(), o.key, o.val_bool); - } else if (o.tag == LLAMA_KV_OVERRIDE_TYPE_STR) { - gguf_set_val_str(ctx_out.get(), o.key, o.val_str); + gguf_set_val_u32(ctx_out.get(), o->key, (uint32_t)std::abs(o->val_i64)); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_BOOL) { + gguf_set_val_bool(ctx_out.get(), o->key, o->val_bool); + } else if (o->tag == LLAMA_KV_OVERRIDE_TYPE_STR) { + gguf_set_val_str(ctx_out.get(), o->key, o->val_str); } else { - LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o.key); + LLAMA_LOG_WARN("%s: unknown KV override type for key %s\n", __func__, o->key); } } } @@ -697,35 +990,16 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: }); } - for (const auto * it : tensors) { - const struct ggml_tensor * tensor = it->tensor; - - const std::string name = ggml_get_name(tensor); - - // TODO: avoid hardcoded tensor names - use the TN_* constants - if (name.find("attn_v.weight") != std::string::npos || - name.find("attn_qkv.weight") != std::string::npos || - name.find("attn_kv_b.weight")!= std::string::npos) { - ++qs.n_attention_wv; - } else if (name == LLM_TN(model.arch)(LLM_TENSOR_OUTPUT, "weight")) { - qs.has_output = true; - } + // compute tensor metadata once and cache it + std::vector<tensor_metadata> metadata(tensors.size()); + for (size_t i = 0; i < tensors.size(); ++i) { + metadata[i].name = ggml_get_name(tensors[i]->tensor); } - qs.n_ffn_down = qs.n_ffn_gate = qs.n_ffn_up = (int)model.hparams.n_layer; - - size_t total_size_org = 0; - size_t total_size_new = 0; - - std::vector<std::thread> workers; - workers.reserve(nthread); + // initialize quantization state counters and metadata categories + init_quantize_state_counters(qs, metadata); int idx = 0; - - std::vector<no_init<uint8_t>> read_data; - std::vector<no_init<uint8_t>> work; - std::vector<no_init<float>> f32_conv_buf; - uint16_t n_split = 1; // Assume split index is continuous @@ -737,14 +1011,48 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: std::vector<gguf_context_ptr> ctx_outs(n_split); ctx_outs[0] = std::move(ctx_out); - // populate the original tensors so we get an initial meta data - for (const auto * it : tensors) { + // flag for --dry-run + bool will_require_imatrix = false; + + // + // preliminary iteration over all weights + // + + for (size_t i = 0; i < tensors.size(); ++i) { + const auto * it = tensors[i]; + const struct ggml_tensor * tensor = it->tensor; + uint16_t i_split = params->keep_split ? it->idx : 0; - ggml_tensor * tensor = it->tensor; if (!ctx_outs[i_split]) { ctx_outs[i_split].reset(gguf_init_empty()); } gguf_add_tensor(ctx_outs[i_split].get(), tensor); + + metadata[i].allows_quantization = tensor_allows_quantization(params, model->arch, tensor); + + if (metadata[i].allows_quantization) { + metadata[i].target_type = llama_tensor_get_type(qs, params, tensor, default_type, metadata[i]); + } else { + metadata[i].target_type = tensor->type; + } + + metadata[i].requires_imatrix = tensor_requires_imatrix(tensor->name, metadata[i].target_type, ftype); + + if (params->imatrix) { + metadata[i].remapped_imatrix_name = remap_imatrix(tensor->name, mapped); + } else if (metadata[i].allows_quantization && metadata[i].requires_imatrix) { + if (params->dry_run) { + will_require_imatrix = true; + } else { + LLAMA_LOG_ERROR("\n============================================================================\n" + " ERROR: this quantization requires an importance matrix!\n" + " - offending tensor: %s\n" + " - target type: %s\n" + "============================================================================\n\n", + metadata[i].name.c_str(), ggml_type_name(metadata[i].target_type)); + throw std::runtime_error("this quantization requires an imatrix!"); + } + } } // Set split info if needed @@ -756,6 +1064,16 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: } } + size_t total_size_org = 0; + size_t total_size_new = 0; + + std::vector<std::thread> workers; + workers.reserve(nthread); + + std::vector<no_init<uint8_t>> read_data; + std::vector<no_init<uint8_t>> work; + std::vector<no_init<float>> f32_conv_buf; + int cur_split = -1; std::ofstream fout; auto close_ofstream = [&]() { @@ -785,251 +1103,181 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: ::zeros(fout, meta_size); }; - const auto tn = LLM_TN(model.arch); - new_ofstream(0); - for (const auto * it : tensors) { - const auto & weight = *it; + // no output file for --dry-run + if (!params->dry_run) { + new_ofstream(0); + } + + // + // main loop: iterate over all weights + // + + for (size_t i = 0; i < tensors.size(); ++i) { + const auto & weight = *tensors[i]; + const auto & tm = metadata[i]; ggml_tensor * tensor = weight.tensor; - if (weight.idx != cur_split && params->keep_split) { + + if (!params->dry_run && (weight.idx != cur_split && params->keep_split)) { close_ofstream(); new_ofstream(weight.idx); } - const std::string name = ggml_get_name(tensor); + const size_t tensor_size = ggml_nbytes(tensor); - if (!ml.use_mmap) { - if (read_data.size() < ggml_nbytes(tensor)) { - read_data.resize(ggml_nbytes(tensor)); + if (!params->dry_run) { + if (!ml.use_mmap) { + if (read_data.size() < tensor_size) { + read_data.resize(tensor_size); + } + tensor->data = read_data.data(); } - tensor->data = read_data.data(); + ml.load_data_for(tensor); } - ml.load_data_for(tensor); - LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", + LLAMA_LOG_INFO("[%4d/%4d] %-36s - [%s], type = %6s, ", ++idx, ml.n_tensors, ggml_get_name(tensor), llama_format_tensor_shape(tensor).c_str(), ggml_type_name(tensor->type)); - // This used to be a regex, but <regex> has an extreme cost to compile times. - bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? - - // quantize only 2D and 3D tensors (experts) - quantize &= (ggml_n_dims(tensor) >= 2); - - // do not quantize norm tensors - quantize &= name.find("_norm.weight") == std::string::npos; - - quantize &= params->quantize_output_tensor || name != "output.weight"; - quantize &= !params->only_copy; - - // do not quantize expert gating tensors - // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ffn_gate_inp.weight") == std::string::npos; - - // these are very small (e.g. 4x4) - quantize &= name.find("altup") == std::string::npos; - quantize &= name.find("laurel") == std::string::npos; - - // these are not too big so keep them as it is - quantize &= name.find("per_layer_model_proj") == std::string::npos; - - // do not quantize positional embeddings and token types (BERT) - quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_POS_EMBD, "weight"); - quantize &= name != LLM_TN(model.arch)(LLM_TENSOR_TOKEN_TYPES, "weight"); - - // do not quantize Mamba's small yet 2D weights - // NOTE: can't use LLM_TN here because the layer number is not known - quantize &= name.find("ssm_conv1d.weight") == std::string::npos; - quantize &= name.find("shortconv.conv.weight") == std::string::npos; - - // do not quantize RWKV's small yet 2D weights - quantize &= name.find("time_mix_first.weight") == std::string::npos; - quantize &= name.find("time_mix_w0.weight") == std::string::npos; - quantize &= name.find("time_mix_w1.weight") == std::string::npos; - quantize &= name.find("time_mix_w2.weight") == std::string::npos; - quantize &= name.find("time_mix_v0.weight") == std::string::npos; - quantize &= name.find("time_mix_v1.weight") == std::string::npos; - quantize &= name.find("time_mix_v2.weight") == std::string::npos; - quantize &= name.find("time_mix_a0.weight") == std::string::npos; - quantize &= name.find("time_mix_a1.weight") == std::string::npos; - quantize &= name.find("time_mix_a2.weight") == std::string::npos; - quantize &= name.find("time_mix_g1.weight") == std::string::npos; - quantize &= name.find("time_mix_g2.weight") == std::string::npos; - quantize &= name.find("time_mix_decay_w1.weight") == std::string::npos; - quantize &= name.find("time_mix_decay_w2.weight") == std::string::npos; - quantize &= name.find("time_mix_lerp_fused.weight") == std::string::npos; - - // do not quantize relative position bias (T5) - quantize &= name.find("attn_rel_b.weight") == std::string::npos; - - // do not quantize specific multimodal tensors - quantize &= name.find(".position_embd.") == std::string::npos; - - ggml_type new_type; + const ggml_type cur_type = tensor->type; + const ggml_type new_type = tm.target_type; + + // If we've decided to quantize to the same type the tensor is already + // in then there's nothing to do. + bool quantize = cur_type != new_type; + void * new_data; size_t new_size; - if (quantize) { - new_type = default_type; - - // get more optimal quantization type based on the tensor shape, layer, etc. - if (!params->pure && ggml_is_quantized(default_type)) { - int fallback = qs.n_fallback; - new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); - // unless the user specifies a type, and the tensor geometry will not require fallback quantisation - if (params->tensor_types && qs.n_fallback - fallback == 0) { - const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types); - const std::string tensor_name(tensor->name); - for (const auto & [tname, qtype] : tensor_types) { - if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { - if (qtype != new_type) { - LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); - new_type = qtype; // if two or more types are specified for the same tensor, the last match wins - } - } - } + if (params->dry_run) { + // the --dry-run option calculates the final quantization size without quantizing + if (quantize) { + new_size = ggml_nrows(tensor) * ggml_row_size(new_type, tensor->ne[0]); + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB (%s)\n", + tensor_size/1024.0/1024.0, + new_size/1024.0/1024.0, + ggml_type_name(new_type)); + if (!will_require_imatrix && tm.requires_imatrix) { + will_require_imatrix = true; } + } else { + new_size = tensor_size; + LLAMA_LOG_INFO("size = %8.3f MiB\n", new_size/1024.0/1024.0); } - if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { - new_type = params->token_embedding_type; - } - if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) { - new_type = params->output_tensor_type; - } - - // If we've decided to quantize to the same type the tensor is already - // in then there's nothing to do. - quantize = tensor->type != new_type; - } - - if (!quantize) { - new_type = tensor->type; - new_data = tensor->data; - new_size = ggml_nbytes(tensor); - LLAMA_LOG_INFO("size = %8.3f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0); + total_size_org += tensor_size; + total_size_new += new_size; + continue; } else { - const int64_t nelements = ggml_nelements(tensor); + // no --dry-run, perform quantization + if (!quantize) { + new_data = tensor->data; + new_size = tensor_size; + LLAMA_LOG_INFO("size = %8.3f MiB\n", tensor_size/1024.0/1024.0); + } else { + const int64_t nelements = ggml_nelements(tensor); - const float * imatrix = nullptr; - if (imatrix_data) { - auto it = imatrix_data->find(remap_imatrix(tensor->name, mapped)); - if (it == imatrix_data->end()) { - LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); - } else { - if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { - imatrix = it->second.data(); + const float * imatrix = nullptr; + if (imatrix_data) { + auto it = imatrix_data->find(tm.remapped_imatrix_name); + if (it == imatrix_data->end()) { + LLAMA_LOG_INFO("\n====== %s: did not find weights for %s\n", __func__, tensor->name); } else { - LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, - int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); - - // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix - // this is a significant error and it may be good idea to abort the process if this happens, - // since many people will miss the error and not realize that most of the model is being quantized without an imatrix - // tok_embd should be ignored in this case, since it always causes this warning - if (name != tn(LLM_TENSOR_TOKEN_EMBD, "weight")) { - throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", - int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); + if (it->second.size() == (size_t)tensor->ne[0]*tensor->ne[2]) { + imatrix = it->second.data(); + } else { + LLAMA_LOG_INFO("\n====== %s: imatrix size %d is different from tensor size %d for %s\n", __func__, + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name); + + // this can happen when quantizing an old mixtral model with split tensors with a new incompatible imatrix + // this is a significant error and it may be good idea to abort the process if this happens, + // since many people will miss the error and not realize that most of the model is being quantized without an imatrix + // tok_embd should be ignored in this case, since it always causes this warning + if (!tensor_name_match_token_embd(tensor->name)) { + throw std::runtime_error(format("imatrix size %d is different from tensor size %d for %s", + int(it->second.size()), int(tensor->ne[0]*tensor->ne[2]), tensor->name)); + } } } } - } - if ((new_type == GGML_TYPE_IQ2_XXS || - new_type == GGML_TYPE_IQ2_XS || - new_type == GGML_TYPE_IQ2_S || - new_type == GGML_TYPE_IQ1_S || - (new_type == GGML_TYPE_IQ1_M && strcmp(tensor->name, "token_embd.weight") && strcmp(tensor->name, "output.weight")) || - (new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) { - LLAMA_LOG_ERROR("\n\n============================================================\n"); - LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); - LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); - LLAMA_LOG_ERROR("============================================================\n\n"); - throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name)); - } + if (!imatrix && tm.requires_imatrix) { + LLAMA_LOG_ERROR("\n\n============================================================\n"); + LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name); + LLAMA_LOG_ERROR("The result will be garbage, so bailing out\n"); + LLAMA_LOG_ERROR("============================================================\n\n"); + throw std::runtime_error(format("Missing importance matrix for tensor %s in a very low-bit quantization", tensor->name)); + } - float * f32_data; + float * f32_data; - if (tensor->type == GGML_TYPE_F32) { - f32_data = (float *) tensor->data; - } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { - throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); - } else { - llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); - f32_data = (float *) f32_conv_buf.data(); - } + if (tensor->type == GGML_TYPE_F32) { + f32_data = (float *) tensor->data; + } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); + } else { + llama_tensor_dequantize_impl(tensor, f32_conv_buf, workers, nelements, nthread); + f32_data = (float *) f32_conv_buf.data(); + } - LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); - fflush(stdout); + LLAMA_LOG_INFO("converting to %s .. ", ggml_type_name(new_type)); + fflush(stdout); - if (work.size() < (size_t)nelements * 4) { - work.resize(nelements * 4); // upper bound on size - } - new_data = work.data(); - - const int64_t n_per_row = tensor->ne[0]; - const int64_t nrows = tensor->ne[1]; - - static const int64_t min_chunk_size = 32 * 512; - const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); - - const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; - const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; - const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; - - // quantize each expert separately since they have different importance matrices - new_size = 0; - for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { - const float * f32_data_03 = f32_data + i03 * nelements_matrix; - void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; - const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; - - new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); - - // TODO: temporary sanity check that the F16 -> MXFP4 is lossless -#if 0 - if (new_type == GGML_TYPE_MXFP4) { - auto * x = f32_data_03; - - //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row); - std::vector<float> deq(nrows*n_per_row); - const ggml_type_traits * qtype = ggml_get_type_traits(new_type); - qtype->to_float(new_data_03, deq.data(), deq.size()); - - double err = 0.0f; - for (int i = 0; i < (int) deq.size(); ++i) { - err += fabsf(deq[i] - x[i]); - //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) { - if (deq[i] != x[i]) { - LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]); - } - } - //LLAMA_LOG_INFO("err = %f\n", err); - GGML_ASSERT(err == 0.00000); + if (work.size() < (size_t)nelements * 4) { + work.resize(nelements * 4); // upper bound on size } -#endif - } - LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); - } - total_size_org += ggml_nbytes(tensor); - total_size_new += new_size; + new_data = work.data(); + + const int64_t n_per_row = tensor->ne[0]; + const int64_t nrows = tensor->ne[1]; - // update the gguf meta data as we go - gguf_set_tensor_type(ctx_outs[cur_split].get(), name.c_str(), new_type); - GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), name.c_str())) == new_size); - gguf_set_tensor_data(ctx_outs[cur_split].get(), name.c_str(), new_data); + static const int64_t min_chunk_size = 32 * 512; + const int64_t chunk_size = (n_per_row >= min_chunk_size ? n_per_row : n_per_row * ((min_chunk_size + n_per_row - 1)/n_per_row)); - // write tensor data + padding - fout.write((const char *) new_data, new_size); - zeros(fout, GGML_PAD(new_size, align) - new_size); + const int64_t nelements_matrix = tensor->ne[0] * tensor->ne[1]; + const int64_t nchunk = (nelements_matrix + chunk_size - 1)/chunk_size; + const int64_t nthread_use = nthread > 1 ? std::max((int64_t)1, std::min((int64_t)nthread, nchunk)) : 1; + + // quantize each expert separately since they have different importance matrices + new_size = 0; + for (int64_t i03 = 0; i03 < tensor->ne[2]; ++i03) { + const float * f32_data_03 = f32_data + i03 * nelements_matrix; + void * new_data_03 = (char *)new_data + ggml_row_size(new_type, n_per_row) * i03 * nrows; + const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr; + + new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use); + } + LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", tensor_size/1024.0/1024.0, new_size/1024.0/1024.0); + } + total_size_org += tensor_size; + total_size_new += new_size; + + // update the gguf meta data as we go + gguf_set_tensor_type(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_type); + GGML_ASSERT(gguf_get_tensor_size(ctx_outs[cur_split].get(), gguf_find_tensor(ctx_outs[cur_split].get(), metadata[i].name.c_str())) == new_size); + gguf_set_tensor_data(ctx_outs[cur_split].get(), metadata[i].name.c_str(), new_data); + + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); + } // no --dry-run + } // main loop + + if (!params->dry_run) { + close_ofstream(); } - close_ofstream(); - LLAMA_LOG_INFO("%s: model size = %8.2f MiB\n", __func__, total_size_org/1024.0/1024.0); - LLAMA_LOG_INFO("%s: quant size = %8.2f MiB\n", __func__, total_size_new/1024.0/1024.0); + LLAMA_LOG_INFO("%s: model size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_org/1024.0/1024.0, total_size_org*8.0/ml.n_elements); + LLAMA_LOG_INFO("%s: quant size = %8.2f MiB (%.2f BPW)\n", __func__, total_size_new/1024.0/1024.0, total_size_new*8.0/ml.n_elements); + + if (!params->imatrix && params->dry_run && will_require_imatrix) { + LLAMA_LOG_WARN("%s: WARNING: dry run completed successfully, but actually completing this quantization will require an imatrix!\n", + __func__ + ); + } if (qs.n_fallback > 0) { LLAMA_LOG_WARN("%s: WARNING: %d of %d tensor(s) required fallback quantization\n", - __func__, qs.n_fallback, qs.n_k_quantized + qs.n_fallback); + __func__, qs.n_fallback, ml.n_tensors); } } @@ -1040,7 +1288,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: llama_model_quantize_params llama_model_quantize_default_params() { llama_model_quantize_params result = { /*.nthread =*/ 0, - /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q8_0, /*.output_tensor_type =*/ GGML_TYPE_COUNT, /*.token_embedding_type =*/ GGML_TYPE_COUNT, /*.allow_requantize =*/ false, @@ -1048,6 +1296,7 @@ llama_model_quantize_params llama_model_quantize_default_params() { /*.only_copy =*/ false, /*.pure =*/ false, /*.keep_split =*/ false, + /*.dry_run =*/ false, /*.imatrix =*/ nullptr, /*.kv_overrides =*/ nullptr, /*.tensor_type =*/ nullptr, @@ -1070,3 +1319,89 @@ uint32_t llama_model_quantize( return 0; } + +// +// Helper functions for external tools exposed in llama-ext.h +// + +quantize_state_impl * llama_quant_init( + const llama_model * model, + const llama_model_quantize_params * params) { + return new quantize_state_impl(*model, params); +} + +void llama_quant_free(quantize_state_impl * qs) { + delete qs; +} + +llama_model * llama_quant_model_from_metadata(const llama_quant_model_desc * desc) { + struct llama_model_params mparams = llama_model_default_params(); + auto arch = llm_arch_from_string(desc->architecture); + auto * model = llama_model_create(arch, mparams); + model->arch = arch; + + // infer llm_type: only LLM_TYPE_70B matters for quantization logic + if (model->arch == LLM_ARCH_LLAMA && desc->n_layer == 80 && desc->n_head != desc->n_head_kv) { + model->type = LLM_TYPE_70B; + } + + model->hparams.n_embd = desc->n_embd; + model->hparams.n_embd_head_k_full = desc->n_embd_head_k; + model->hparams.n_embd_head_v_full = desc->n_embd_head_v; + model->hparams.n_layer_all = desc->n_layer; + model->hparams.n_expert = desc->n_expert; + + for (uint32_t i = 0; i < desc->n_layer; i++) { + model->hparams.n_head_arr[i] = desc->n_head; + model->hparams.n_head_kv_arr[i] = desc->n_head_kv; + model->hparams.n_ff_arr[i] = desc->n_ff; + } + + return model; +} + +bool llama_quant_tensor_allows_quantization( + const quantize_state_impl * qs, + const ggml_tensor * tensor) { + return tensor_allows_quantization(qs->params, qs->model.arch, tensor); +} + +void llama_quant_compute_types( + quantize_state_impl * qs, + llama_ftype ftype, + ggml_tensor ** tensors, + ggml_type * result_types, + size_t n_tensors) { + // reset per-computation state + qs->n_attention_wv = 0; + qs->n_ffn_down = 0; + qs->n_ffn_gate = 0; + qs->n_ffn_up = 0; + qs->i_attention_wv = 0; + qs->i_ffn_down = 0; + qs->i_ffn_gate = 0; + qs->i_ffn_up = 0; + qs->n_fallback = 0; + qs->has_imatrix = false; + qs->has_tied_embeddings = true; + + // build metadata from tensor names + std::vector<tensor_metadata> metadata(n_tensors); + for (size_t i = 0; i < n_tensors; i++) { + metadata[i].name = ggml_get_name(tensors[i]); + } + + // initialize counters and categories + init_quantize_state_counters(*qs, metadata); + + // use a local copy of params with the requested ftype + llama_model_quantize_params local_params = *qs->params; + local_params.ftype = ftype; + + ggml_type default_type = llama_ftype_get_default_type(ftype); + + // compute types + for (size_t i = 0; i < n_tensors; i++) { + result_types[i] = llama_tensor_get_type(*qs, &local_params, tensors[i], default_type, metadata[i]); + } +} diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampler.cpp similarity index 94% rename from examples/talk-llama/llama-sampling.cpp rename to examples/talk-llama/llama-sampler.cpp index 11f0394c4ce..9bbc5dbde24 100644 --- a/examples/talk-llama/llama-sampling.cpp +++ b/examples/talk-llama/llama-sampler.cpp @@ -1,4 +1,4 @@ -#include "llama-sampling.h" +#include "llama-sampler.h" #include "llama-impl.h" #include "llama-vocab.h" @@ -1025,11 +1025,7 @@ struct llama_sampler_dist : public llama_sampler_backend { std::mt19937 rng; - // backend input - struct ggml_tensor * inp_uniform; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; + ggml_tensor * inp_uniform; }; static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { @@ -1138,37 +1134,10 @@ static bool llama_sampler_dist_backend_init( ggml_backend_buffer_type_t buft) { auto * sctx = (llama_sampler_dist *) smpl->ctx; - // allocate inputs - { - ggml_init_params params = { - /*.mem_size =*/ ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - // Create the uniform random scalar input tensor. This will be set by - // llama_sampler_dist_backend_set_input after this graph is built. - sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); - ggml_set_name (sctx->inp_uniform, "uniform"); - ggml_set_input(sctx->inp_uniform); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - } - const bool res = llama_sampler_backend_support(smpl, buft); sctx->init(res); - if (!res) { - sctx->inp_ctx.reset(nullptr); - sctx->inp_buf.reset(nullptr); - } - return res; } @@ -1178,8 +1147,13 @@ static void llama_sampler_dist_backend_apply( struct ggml_cgraph * gf, struct llama_sampler_data * data) { GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + sctx->inp_uniform = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); ggml_set_name(probs, "dist_probs"); @@ -1226,6 +1200,7 @@ static void llama_sampler_dist_backend_apply( static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); // We sample in double precision and cast to float to match rnd numbers of @@ -1262,8 +1237,6 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { /* .seed_cur = */ seed_cur, /* .rng = */ std::mt19937(seed_cur), /* .inp_uniform = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } @@ -1513,12 +1486,9 @@ static void llama_sampler_top_p_backend_apply( mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32)); mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]); - // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: - // top_p_bias = (mask * 1e9f) - 1e9f. - // So entries in the mask that we want to discard will become -1e9f, and - // others will be 0 (meaning that will not effect the logits). - const float large_val = 1e9f; - struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * top_p_bias = ggml_log(ctx, mask); ggml_set_name(top_p_bias, "top_p_bias"); data->logits = ggml_add(ctx, sorted_logits, top_p_bias); @@ -1673,15 +1643,11 @@ static void llama_sampler_min_p_backend_apply( struct ggml_tensor * mask = ggml_step(ctx, sub); ggml_set_name(mask, "min_p_mask"); - // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: - // min_p_bias = (mask * 1e9f) - 1e9f. - // So entries in the mask that we want to discard will become -1e9f, and - // others will be 0 (meaning that will not effect the logits). - const float large_val = 1e9f; - struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + // Apply -INFINITY bias for masked-out tokens + // log(1) = 0 (keep), log(0) = -INF (discard) + struct ggml_tensor * min_p_bias = ggml_log(ctx, mask); ggml_set_name(min_p_bias, "min_p_bias"); - // Add the min_p bias to the logits. data->logits = ggml_add(ctx, data->logits, min_p_bias); ggml_set_name(data->logits, "min_p_logits"); @@ -3293,6 +3259,170 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa return result; } +// adaptive-p sampler state +// +// maintains an exponential moving average of the *ORIGINAL* probabilities +// of selected tokens, used to compute an adapted target at each sampling step. +// +// see llama.h for a full description of the sampler +// +// ref: https://github.com/ggml-org/llama.cpp/pull/17927 +// +struct llama_sampler_adaptive_p { + const float target; // target probability (0.0 - 1.0; negative = disabled) + const float decay; // EMA decay; history ~= 1/(1-decay) tokens (0.0 - 0.99) + const uint32_t seed; // original RNG seed + uint32_t seed_cur; // actual RNG seed + std::mt19937 rng; // RNG state + float weighted_sum; // sum(p_i * decay^i) + float total_weight; // sum(decay^i), converges to 1/(1-decay) + std::vector<float> original_probs; // pre-transform probs, cached for EMA update + llama_token pending_token_id; // token ID of selected token + int32_t pending_token_idx; // index of orig. prob. of selected token in original_probs +}; + +// adaptive probability transformation constants +static constexpr float DISTRIBUTION_WIDTH = 0.3f; +static constexpr float PEAK_LOGIT_VALUE = 5.0f; +static constexpr float SHARPNESS = 10.0f; +static constexpr float INV_WIDTH = 1.0f / DISTRIBUTION_WIDTH; + +static const char * llama_sampler_adaptive_p_name(const struct llama_sampler * /*smpl*/) { + return "adaptive-p"; +} + +static void llama_sampler_adaptive_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + + llama_sampler_softmax_impl(cur_p, false); + + if (ctx->target < 0.0f) { + // at negative target values, adaptive-p is no-op + // we simply sample from the existing distribution + cur_p->selected = llama_sample_dist(cur_p, ctx->rng); + return; + } + + // store the original probabilities + ctx->original_probs.resize(cur_p->size); + for (size_t i = 0; i < cur_p->size; ++i) { + ctx->original_probs[i] = cur_p->data[i].p; + } + + // using the EMA, compute the adapted target probability for the current sampling step + auto target = std::clamp(ctx->target, 0.0f, 1.0f); + float adapted_target = std::clamp( + ctx->total_weight == 0.0f ? target : 2.0f * target - (ctx->weighted_sum / ctx->total_weight), + 0.0f, 1.0f + ); + + // adaptive probability transform + // + // quadratic near target for fine differentiation, transitioning to linear decay in the + // tails. unbounded negative logits ensure proper suppression of far-from-target tokens + // after the softmax. + // + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].logit == -INFINITY) { + // don't transform logits that are -INFINITY + // (as masked out by e.g. min-p and top-p when using backend sampling) + continue; + } + float dist = std::abs((cur_p->data[i].p - adapted_target) * INV_WIDTH); + cur_p->data[i].logit = PEAK_LOGIT_VALUE - SHARPNESS * dist * dist / (1.0f + dist); + } + + // softmax and sample from the transformed distribution + llama_sampler_softmax_impl(cur_p, false); + const int idx = llama_sample_dist(cur_p, ctx->rng); + cur_p->selected = idx; + + // store the selected token ID for acceptance later + ctx->pending_token_id = cur_p->data[idx].id; + ctx->pending_token_idx = idx; +} + +static void llama_sampler_adaptive_p_accept(struct llama_sampler * smpl, llama_token token) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + if (ctx->pending_token_id == token) { + GGML_ASSERT(ctx->pending_token_id != LLAMA_TOKEN_NULL); + GGML_ASSERT(ctx->pending_token_idx != -1); + // update EMA with the original probability of the selected token + ctx->weighted_sum = ctx->original_probs[ctx->pending_token_idx] + ctx->decay * ctx->weighted_sum; + ctx->total_weight = 1.0f + ctx->decay * ctx->total_weight; + } + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; +} + +static void llama_sampler_adaptive_p_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_adaptive_p *) smpl->ctx; + // ctx->target and ctx->decay never change after init, so it's safe to keep them as is. + // original_probs is completely overwritten on every call to _apply. + // so we only need to reset the EMA state and pending token. + ctx->weighted_sum = ctx->target / (1.0f - ctx->decay); + ctx->total_weight = 1.0f / (1.0f - ctx->decay); + ctx->pending_token_id = LLAMA_TOKEN_NULL; + ctx->pending_token_idx = -1; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + +static struct llama_sampler * llama_sampler_adaptive_p_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_adaptive_p *) smpl->ctx; + auto * result = llama_sampler_init_adaptive_p(ctx->target, ctx->decay, ctx->seed); + auto * result_ctx = (llama_sampler_adaptive_p *) result->ctx; + + // copy everything (target, decay, seed, and RNG are already set) + result_ctx->weighted_sum = ctx->weighted_sum; + result_ctx->total_weight = ctx->total_weight; + result_ctx->pending_token_id = ctx->pending_token_id; + result_ctx->pending_token_idx = ctx->pending_token_idx; + + return result; +} + +static void llama_sampler_adaptive_p_free(struct llama_sampler * smpl) { + delete (llama_sampler_adaptive_p *) smpl->ctx; +} + +static struct llama_sampler_i llama_sampler_adaptive_p_i = { + /* .name = */ llama_sampler_adaptive_p_name, + /* .accept = */ llama_sampler_adaptive_p_accept, + /* .apply = */ llama_sampler_adaptive_p_apply, + /* .reset = */ llama_sampler_adaptive_p_reset, + /* .clone = */ llama_sampler_adaptive_p_clone, + /* .free = */ llama_sampler_adaptive_p_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, +}; + +struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed +) { + auto seed_cur = get_rng_seed(seed); + float clamped_decay = std::clamp(decay, 0.0f, 0.99f); + return llama_sampler_init( + /* .iface = */ &llama_sampler_adaptive_p_i, + /* .ctx = */ new llama_sampler_adaptive_p { + /* .target = */ target, + /* .decay = */ clamped_decay, + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .weighted_sum = */ target / (1.0f - clamped_decay), + /* .total_weight = */ 1.0f / (1.0f - clamped_decay), + /* .original_probs = */ {}, + /* .pending_token_id = */ LLAMA_TOKEN_NULL, + /* .pending_token_idx = */ -1 + } + ); +} + // logit-bias struct llama_sampler_logit_bias : public llama_sampler_backend { @@ -3304,9 +3434,6 @@ struct llama_sampler_logit_bias : public llama_sampler_backend { struct ggml_tensor * inp_logit_bias; struct ggml_tensor * inp_logit_idxs; - - ggml_context_ptr inp_ctx; - ggml_backend_buffer_ptr inp_buf; }; static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { @@ -3369,6 +3496,16 @@ static void llama_sampler_logit_bias_backend_apply( return; } + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); @@ -3405,6 +3542,8 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm static bool llama_sampler_logit_bias_backend_init( struct llama_sampler * smpl, ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; sctx->init(true); @@ -3413,29 +3552,6 @@ static bool llama_sampler_logit_bias_backend_init( return true; } - ggml_init_params params = { - /*.mem_size =*/ 2*ggml_tensor_overhead(), - /*.mem_buffer =*/ nullptr, - /*.no_alloc =*/ true, - }; - - sctx->inp_ctx.reset(ggml_init(params)); - - const size_t n = sctx->logit_bias.size(); - - sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); - ggml_set_name(sctx->inp_logit_bias, "logit_bias"); - ggml_set_input(sctx->inp_logit_bias); - - sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); - ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); - ggml_set_input(sctx->inp_logit_idxs); - - // Allocate all tensors from our context to the backend - sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); - - ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); - return true; } @@ -3471,8 +3587,6 @@ struct llama_sampler * llama_sampler_init_logit_bias( /* .to_search = */ {}, /* .inp_logit_bias = */ nullptr, /* .inp_logit_idxs = */ nullptr, - /* .inp_ctx = */ nullptr, - /* .inp_buf = */ nullptr, } ); } diff --git a/examples/talk-llama/llama-sampling.h b/examples/talk-llama/llama-sampler.h similarity index 92% rename from examples/talk-llama/llama-sampling.h rename to examples/talk-llama/llama-sampler.h index 6a963c0bb73..b9bfc20d251 100644 --- a/examples/talk-llama/llama-sampling.h +++ b/examples/talk-llama/llama-sampler.h @@ -1,7 +1,5 @@ #pragma once -// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? - #include "llama.h" #include <vector> diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index a20c6525e46..8543e178dba 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -90,7 +90,7 @@ static_assert(std::is_trivially_copyable<llm_symbol>::value, "llm_symbol is not // // SPM tokenizer // original implementation: -// https://github.com/ggerganov/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 +// https://github.com/ggml-org/llama.cpp/commit/074bea2eb1f1349a0118239c4152914aecaa1be4 // struct llm_bigram_spm { @@ -285,10 +285,19 @@ struct llm_tokenizer_bpe : llm_tokenizer { // original regex from tokenizer.json //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", - // adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989 + // adapted: https://github.com/ggml-org/llama.cpp/pull/6920#issuecomment-2080233989 "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_JAIS2: + regex_exprs = { + // original regex from tokenizer.json + //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}", + + // adapted: same as llama3 but with cascading whitespace pattern + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DBRX: case LLAMA_VOCAB_PRE_TYPE_SMAUG: regex_exprs = { @@ -308,6 +317,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: + case LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", @@ -343,6 +353,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_CODESHELL: case LLAMA_VOCAB_PRE_TYPE_EXAONE: case LLAMA_VOCAB_PRE_TYPE_MINERVA: + case LLAMA_VOCAB_PRE_TYPE_MELLUM2: regex_exprs = { "\\p{N}", "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", @@ -368,6 +379,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_QWEN35: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_PORO: case LLAMA_VOCAB_PRE_TYPE_BLOOM: case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: @@ -415,6 +433,23 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI: + // Same lookaheads as GPT4O but with \p{M} added so combining marks + // (diacritics) attach to their base letters. Avoids excessive + // backtracking on scripts that use them heavily (Bengali, Hindi, + // Telugu, Thai, ...). See PR #22716 for benchmarks. + regex_exprs = { + "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))*((?=[\\p{L}\\p{M}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}\\p{M}])([^a-z]))+((?=[\\p{L}\\p{M}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_TINY_AYA: + regex_exprs = { + // original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)" + "\\d{1,3}(?=(?:\\d{3})*\\b)", + // original regex from tokenizer.json: "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: regex_exprs = { // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp @@ -461,6 +496,46 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\\r\\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_GEMMA4: + // Gemma4 uses SPM-style BPE: spaces are replaced with ▁ by the + // normalizer, then BPE merges run on the whole text without + // word-level pre-splitting. We only need to split on newlines + // since BPE merge lookup asserts no newlines in tokens. + regex_exprs = { + "[^\\n]+|[\\n]+", + }; + byte_encode = false; // uses raw UTF-8, not GPT-2 byte encoding + break; + case LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE: + // Sarvam uses SPM-style BPE (same shape as Gemma4): spaces replaced with U+2581 + // by the normalizer, BPE merges over the whole text on raw UTF-8. + regex_exprs = { + "[^\\n]+|[\\n]+", + }; + byte_encode = false; + break; + case LLAMA_VOCAB_PRE_TYPE_MINICPM5: + regex_exprs = { + // original regex from tokenizer.json (openbmb/MiniCPM5-1B) + "\\p{N}{1,3}", + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}+| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; + case LLAMA_VOCAB_PRE_TYPE_WHITESPACE: + // whitespace pre-tokenizer (jinaai/jina-embeddings-v2-base-zh) + regex_exprs = { + "\\S+", + }; + byte_encode = false; + break; default: // default regex for BPE tokenization pre-processing regex_exprs = { @@ -474,11 +549,14 @@ struct llm_tokenizer_bpe : llm_tokenizer { } std::vector<std::string> regex_exprs; + bool byte_encode = true; // GPT-2 byte encoding; false for SPM-style BPE (raw UTF-8) }; struct llm_tokenizer_bpe_session { llm_tokenizer_bpe_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : vocab(vocab), tokenizer(tokenizer) {} + virtual ~llm_tokenizer_bpe_session() = default; + static void append(const llama_token token_id, std::vector<llama_token> & output) { output.push_back(token_id); } @@ -516,11 +594,12 @@ struct llm_tokenizer_bpe_session { } } - void tokenize(const std::string & text, std::vector<llama_token> & output) { + virtual void tokenize(const std::string & text, std::vector<llama_token> & output) { int final_prev_index = -1; - const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs); + const auto word_collection = unicode_regex_split(text, tokenizer.regex_exprs, tokenizer.byte_encode); symbols_final.clear(); + auto tok_pre = vocab.get_pre_type(); for (const auto & word : word_collection) { work_queue = llm_bigram_bpe::queue(); @@ -533,6 +612,13 @@ struct llm_tokenizer_bpe_session { if (vocab.get_ignore_merges() && vocab.text_to_token(word) != LLAMA_TOKEN_NULL) { symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); offset = word.size(); + } else if (tok_pre == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && word.find_first_not_of('\n') == std::string::npos) { + // fix for gemma 4, ref: https://github.com/ggml-org/llama.cpp/pull/21343 + auto tok = vocab.text_to_token(word); + if (tok != LLAMA_TOKEN_NULL) { + symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()}); + offset = word.size(); + } } while (offset < word.size()) { @@ -608,8 +694,17 @@ struct llm_tokenizer_bpe_session { if (token == LLAMA_TOKEN_NULL) { for (auto j = str.begin(); j != str.end(); ++j) { - std::string byte_str(1, *j); - auto token_multibyte = vocab.text_to_token(byte_str); + llama_token token_multibyte = LLAMA_TOKEN_NULL; + if (tokenizer.byte_encode) { + std::string byte_str(1, *j); + token_multibyte = vocab.text_to_token(byte_str); + } else { + // For non-byte-encoded BPE (e.g. gemma-4), byte tokens use <0xXX> format + static const char * hex = "0123456789ABCDEF"; + const uint8_t ch = (uint8_t)*j; + const char buf[7] = { '<', '0', 'x', hex[ch >> 4], hex[ch & 15], '>', 0 }; + token_multibyte = vocab.text_to_token(buf); + } if (token_multibyte != LLAMA_TOKEN_NULL) { output.push_back(token_multibyte); } @@ -669,7 +764,7 @@ struct llm_tokenizer_wpm_session { void tokenize(const std::string & text, std::vector<llama_token> & output) { // normalize and split by whitespace - std::vector<std::string> words = preprocess(text); + std::vector<std::string> words = preprocess(text, vocab.get_normalizer_opts()); // bos token prepended already // find the longest tokens that form the words @@ -714,11 +809,14 @@ struct llm_tokenizer_wpm_session { } // TODO: reduce string copies by using cpts_offs array - static std::vector<std::string> preprocess(const std::string & text) { - const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text)); + static std::vector<std::string> preprocess(const std::string & text, const llama_vocab::normalizer_options & normalizer_opts) { + std::vector<uint32_t> cpts = unicode_cpts_from_utf8(text); + if (normalizer_opts.strip_accents) { + cpts = unicode_cpts_normalize_nfd(cpts); + } std::vector<std::string> words(1, ""); - for (const uint32_t cpt : cpts_nfd) { + for (const uint32_t cpt : cpts) { const auto flags = unicode_cpt_flags_from_cpt(cpt); if (flags.is_whitespace) { @@ -733,7 +831,11 @@ struct llm_tokenizer_wpm_session { continue; } - const std::string s = unicode_cpt_to_utf8(unicode_tolower(cpt)); + if (normalizer_opts.strip_accents && flags.is_accent_mark) { + continue; + } + + const std::string s = unicode_cpt_to_utf8(normalizer_opts.lowercase ? unicode_tolower(cpt) : cpt); if (flags.is_punctuation || ( cpt < 0x7F && flags.is_symbol ) || is_chinese_char(cpt)) { if (words.back().size()) { // finish previous word if any words.emplace_back(); @@ -1511,6 +1613,117 @@ struct llm_tokenizer_plamo2_session { const llm_tokenizer_plamo2 & tokenizer; }; +// reserved suffix (U+E000) that keeps DNA k-mers distinct from identical +// base-vocab BPE tokens (e.g. CCCCCC) in token_to_id; erased from id_to_token +// text at load +static const std::string dna_kmer_marker = "\xee\x80\x80"; + +struct llm_tokenizer_hybriddna_session : llm_tokenizer_bpe_session { + llm_tokenizer_hybriddna_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} + + void tokenize(const std::string & text, std::vector<llama_token> & output) override { + static const std::string open_tag = "<dna>"; + static const std::string close_tag = "</dna>"; + + const auto dna_begin_id = vocab.text_to_token(open_tag); + const auto dna_end_id = vocab.text_to_token(close_tag); + const auto dna_oov_id = vocab.text_to_token("<oov>"); + + // Fall back to plain BPE if the DNA pieces aren't in the vocab. + if (dna_begin_id == LLAMA_TOKEN_NULL || dna_end_id == LLAMA_TOKEN_NULL || dna_oov_id == LLAMA_TOKEN_NULL) { + llm_tokenizer_bpe_session::tokenize(text, output); + return; + } + + const size_t k = 6; + size_t pos = 0; + + while (pos < text.size()) { + const size_t start = text.find(open_tag, pos); + if (start == std::string::npos) { + if (pos < text.size()) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos), output); + } + break; + } + if (start > pos) { + llm_tokenizer_bpe_session::tokenize(text.substr(pos, start - pos), output); + } + output.push_back(dna_begin_id); + + const size_t content_start = start + open_tag.size(); + const size_t end = text.find(close_tag, content_start); + const size_t content_end = (end == std::string::npos) ? text.size() : end; + + emit_dna_kmers(text.substr(content_start, content_end - content_start), k, dna_oov_id, output); + + if (end == std::string::npos) { + break; + } + output.push_back(dna_end_id); + pos = end + close_tag.size(); + } + } + +private: + void emit_dna_kmers(const std::string & raw, size_t k, llama_token oov_id, std::vector<llama_token> & output) { + std::string seq = raw; + for (char & c : seq) { + if (c >= 'a' && c <= 'z') { + c = char(c - 32); + } + } + + // k-mers carry the reserved marker suffix; a non-ACGT k-mer simply + // isn't in the vocab and falls back to <oov> + auto kmer_token = [&](const std::string & kmer) { + const auto tok = vocab.text_to_token(kmer + dna_kmer_marker); + return tok != LLAMA_TOKEN_NULL ? tok : oov_id; + }; + + size_t i = 0; + for (; i + k <= seq.size(); i += k) { + output.push_back(kmer_token(seq.substr(i, k))); + } + if (i < seq.size()) { + std::string kmer = seq.substr(i); + kmer.append(k - kmer.size(), 'A'); + output.push_back(kmer_token(kmer)); + } + } + + const llama_vocab & vocab; +}; + +struct llm_tokenizer_whitespace_session : llm_tokenizer_bpe_session { + llm_tokenizer_whitespace_session(const llama_vocab & vocab, const llm_tokenizer_bpe & tokenizer) : llm_tokenizer_bpe_session{vocab, tokenizer}, vocab{vocab} {} + + void tokenize(const std::string & text, std::vector<llama_token> & output) override { + const bool lowercase = vocab.get_normalizer_opts().lowercase; + + std::string segment; + auto flush = [&]() { + if (!segment.empty()) { + llm_tokenizer_bpe_session::tokenize(segment, output); + segment.clear(); + } + }; + + for (uint32_t cpt : unicode_cpts_from_utf8(text)) { + // drop whitespace + if (unicode_cpt_flags_from_cpt(cpt).is_whitespace) { + flush(); + } else { + segment += unicode_cpt_to_utf8(lowercase ? unicode_tolower(cpt) : cpt); + } + } + flush(); + } + +private: + const llama_vocab & vocab; +}; + // // impl // @@ -1592,6 +1805,9 @@ struct llama_vocab::impl { bool escape_whitespaces = true; bool treat_whitespace_as_suffix = false; + // BertNormalizer options + llama_vocab::normalizer_options normalizer_opts; + std::unordered_map<std::string, llama_token> token_to_id; std::vector<token_data> id_to_token; @@ -1608,6 +1824,8 @@ struct llama_vocab::impl { // set of all tokens that cause "end of generation" std::set<llama_token> special_eog_ids; + std::vector<llama_token> suppress_tokens; + std::unique_ptr<llm_tokenizer> tokenizer; std::vector<char> precompiled_charsmap; @@ -1687,7 +1905,7 @@ struct llama_vocab::impl { }; void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { - struct gguf_context * ctx = ml.meta.get(); + struct gguf_context * ctx = ml.metadata; // determine vocab type { @@ -1740,31 +1958,38 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_mask_id = 103; add_sep = true; - } else if (tokenizer_model == "gpt2") { + } else if (tokenizer_model == "gpt2" || tokenizer_model == "hybriddna" || tokenizer_model == "whitespace") { type = LLAMA_VOCAB_TYPE_BPE; // read bpe merges and populate bpe ranks const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + // Kimi-K2 uses custom tokenization without traditional BPE merges + const bool is_kimi_k2 = (tokenizer_pre == "kimi-k2"); + if (merges_keyidx == -1) { - throw std::runtime_error("cannot find tokenizer merges in model file\n"); - } + if (!is_kimi_k2) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + // Kimi-K2 doesn't need merges, skip + LLAMA_LOG_INFO("%s: Kimi-K2 tokenizer detected, skipping BPE merges\n", __func__); + } else { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); - const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); - for (int i = 0; i < n_merges; i++) { - const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); - //GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0); + std::string first; + std::string second; - std::string first; - std::string second; + const size_t pos = word.find(' ', 1); - const size_t pos = word.find(' ', 1); + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } - if (pos != std::string::npos) { - first = word.substr(0, pos); - second = word.substr(pos + 1); + bpe_ranks.emplace(std::make_pair(first, second), i); } - - bpe_ranks.emplace(std::make_pair(first, second), i); } // default special tokens @@ -1794,7 +2019,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); #if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - // correct endiannes of data in precompiled_charsmap binary blob + // correct endianness of data in precompiled_charsmap binary blob uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0]; *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); @@ -1824,6 +2049,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = 3; // <|plamo:pad|> special_mask_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "gemma4") { + type = LLAMA_VOCAB_TYPE_BPE; + + // read bpe merges and populate bpe ranks + const int merges_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_MERGES).c_str()); + if (merges_keyidx == -1) { + throw std::runtime_error("cannot find tokenizer merges in model file\n"); + } + { + const int n_merges = gguf_get_arr_n(ctx, merges_keyidx); + for (int i = 0; i < n_merges; i++) { + const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i); + + std::string first; + std::string second; + + const size_t pos = word.find(' ', 1); + + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_ranks.emplace(std::make_pair(first, second), i); + } + } + + // default special tokens (to be read from GGUF) + special_bos_id = LLAMA_TOKEN_NULL; + special_eos_id = LLAMA_TOKEN_NULL; + special_unk_id = LLAMA_TOKEN_NULL; + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = LLAMA_TOKEN_NULL; + special_mask_id = LLAMA_TOKEN_NULL; + + tokenizer_pre = "gemma4"; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -1831,6 +2092,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // for now, only BPE models have pre-tokenizers if (type == LLAMA_VOCAB_TYPE_BPE) { add_space_prefix = false; + escape_whitespaces = false; clean_spaces = true; if (tokenizer_pre.empty()) { LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__); @@ -1843,6 +2105,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if (tokenizer_pre == "default") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + } else if (tokenizer_pre == "minicpm5") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MINICPM5; + ignore_merges = true; } else if ( tokenizer_pre == "llama3" || tokenizer_pre == "llama-v3" || @@ -1851,7 +2116,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "falcon-h1" || tokenizer_pre == "pixtral" || tokenizer_pre == "midm-2.0" || - tokenizer_pre == "lfm2") { + tokenizer_pre == "lfm2" || + tokenizer_pre == "jina-v5-nano") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; add_bos = true; @@ -1891,14 +2157,31 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || tokenizer_pre == "mellum" || - tokenizer_pre == "modern-bert" ) { + tokenizer_pre == "modern-bert") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "jais-2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; + } else if ( + tokenizer_pre == "gemma4" || + tokenizer_pre == "granite-embed-multi-311m") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GEMMA4; + escape_whitespaces = true; + } else if ( + tokenizer_pre == "sarvam-moe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE; + escape_whitespaces = true; + clean_spaces = false; } else if ( tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-code" || tokenizer_pre == "roberta-bpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; add_sep = true; + } else if ( + tokenizer_pre == "whitespace") { + pre_type = LLAMA_VOCAB_PRE_TYPE_WHITESPACE; + normalizer_opts.lowercase = false; } else if ( tokenizer_pre == "refact") { pre_type = LLAMA_VOCAB_PRE_TYPE_REFACT; @@ -1909,9 +2192,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "qwen2" || tokenizer_pre == "deepseek-r1-qwen" || - tokenizer_pre == "kormo") { + tokenizer_pre == "kormo" || + tokenizer_pre == "f2llmv2") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; + } else if ( + tokenizer_pre == "qwen35") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN35; + clean_spaces = false; } else if ( tokenizer_pre == "stablelm2") { pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2; @@ -1965,6 +2253,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "exaone4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "exaone-moe") { + pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE; } else if ( tokenizer_pre == "chameleon") { pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; @@ -1977,10 +2268,21 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "megrez") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; } else if ( - tokenizer_pre == "gpt-4o" || - tokenizer_pre == "llama4") { + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4" || + tokenizer_pre == "kanana2" || + tokenizer_pre == "talkie") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "granite-embed-multi-97m") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI; + clean_spaces = false; + ignore_merges = true; + } else if ( + tokenizer_pre == "tiny_aya") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA; + clean_spaces = false; } else if ( tokenizer_pre == "superbpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; @@ -2011,6 +2313,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "hunyuan-dense") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE; clean_spaces = false; + } else if ( + tokenizer_pre == "joyai-llm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM; + clean_spaces = false; } else if ( tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; @@ -2031,6 +2337,9 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "solar-open") { pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; clean_spaces = false; + } else if ( + tokenizer_pre == "mellum2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_MELLUM2; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -2070,19 +2379,28 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { throw std::runtime_error("cannot find tokenizer vocab in model file\n"); } + const uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); + const float * scores = nullptr; const int score_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SCORES).c_str()); if (score_idx != -1) { + const uint32_t n_scores = gguf_get_arr_n(ctx, score_idx); + if (n_scores < n_tokens) { + throw std::runtime_error("Index out of array bounds for scores (" + std::to_string(n_scores) + " < " + std::to_string(n_tokens) + ")\n"); + } scores = (const float * ) gguf_get_arr_data(ctx, score_idx); } const int * toktypes = nullptr; const int toktype_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_TOKEN_TYPE).c_str()); if (toktype_idx != -1) { + const uint32_t n_toktypes = gguf_get_arr_n(ctx, toktype_idx); + if (n_toktypes < n_tokens) { + throw std::runtime_error("Index out of array bounds for toktypes (" + std::to_string(n_toktypes) + " < " + std::to_string(n_tokens) + ")\n"); + } toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); } - uint32_t n_tokens = gguf_get_arr_n(ctx, token_idx); id_to_token.resize(n_tokens); for (uint32_t i = 0; i < n_tokens; i++) { @@ -2115,6 +2433,23 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } GGML_ASSERT(id_to_token.size() == token_to_id.size()); + // hybriddna: the marker suffix kept k-mer ids distinct in token_to_id; erase + // it from id_to_token so the k-mers detokenize to the bare DNA sequence. The + // k-mers are the block right after <oov>, so only scan from there. + if (tokenizer_model == "hybriddna") { + const auto idx = token_to_id.find("<oov>"); + if (idx != token_to_id.end()) { + auto it = id_to_token.begin() + idx->second + 1; + for (; it != id_to_token.end(); ++it) { + std::string & text = it->text; + if (text.size() > dna_kmer_marker.size() + && text.compare(text.size() - dna_kmer_marker.size(), dna_kmer_marker.size(), dna_kmer_marker) == 0) { + text.erase(text.size() - dna_kmer_marker.size()); + } + } + } + } + init_tokenizer(type); // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' @@ -2196,6 +2531,29 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (ml.get_key(LLM_KV_TOKENIZER_ADD_SEP, temp, false)) { add_sep = temp; } + + // workaround for Gemma 4 + // ref: https://github.com/ggml-org/llama.cpp/pull/21500 + if (pre_type == LLAMA_VOCAB_PRE_TYPE_GEMMA4 && !add_bos) { + add_bos = true; + + LLAMA_LOG_WARN("%s: override '%s' to 'true' for Gemma4\n", __func__, kv(LLM_KV_TOKENIZER_ADD_BOS).c_str()); + } + } + + // BertNormalizer options + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_LOWERCASE, normalizer_opts.lowercase, false); + normalizer_opts.strip_accents = normalizer_opts.lowercase; + ml.get_key(LLM_KV_TOKENIZER_NORMALIZER_STRIP_ACCENTS, normalizer_opts.strip_accents, false); + + // suppress tokens + { + const int suppress_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_SUPPRESS_TOKENS).c_str()); + if (suppress_idx != -1) { + const int n = gguf_get_arr_n(ctx, suppress_idx); + const int32_t * data = (const int32_t *) gguf_get_arr_data(ctx, suppress_idx); + suppress_tokens.assign(data, data + n); + } } // auto-detect special tokens by text @@ -2216,6 +2574,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end_of_text|>" // granite || t.first == "<EOT>" || t.first == "_<EOT>" + || t.first == "[EOT]" // Kimi-K2 || t.first == "<|end▁of▁sentence|>" // DeepSeek || t.first == "<end_of_utterance>" // smoldocling ) { @@ -2252,6 +2611,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<PRE>" || t.first == "▁<PRE>" // CodeLlama || t.first == "<|code_prefix|>" // GLM-4.5 + || t.first == "<|prefix|>" // Falcon-H1-Tiny-Coder ) { special_fim_pre_id = t.second; if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { @@ -2272,6 +2632,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<SUF>" || t.first == "▁<SUF>" // CodeLlama || t.first == "<|code_suffix|>" // GLM-4.5 + || t.first == "<|suffix|>" // Falcon-H1-Tiny-Coder ) { special_fim_suf_id = t.second; if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { @@ -2292,6 +2653,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<MID>" || t.first == "▁<MID>" // CodeLlama || t.first == "<|code_middle|>" // GLM-4.5 + || t.first == "<|middle|>" // Falcon-H1-Tiny-Coder ) { special_fim_mid_id = t.second; if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { @@ -2309,6 +2671,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<fim-pad>" || t.first == "<fim_pad>" // Granite || t.first == "<PAD>" + || t.first == "[PAD]" // Kimi-K2 ) { special_fim_pad_id = t.second; if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { @@ -2380,7 +2743,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // maintain a list of tokens that cause end-of-generation // this is currently determined based on the token text, which is obviously not ideal - // ref: https://github.com/ggerganov/llama.cpp/issues/9606 + // ref: https://github.com/ggml-org/llama.cpp/issues/9606 special_eog_ids.clear(); if (special_fim_pad_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_fim_pad_id) == 0) { @@ -2408,11 +2771,18 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|calls|>" // solar-open || t.first == "<end_of_turn>" || t.first == "<|endoftext|>" + || t.first == "</s>" // paddleocr || t.first == "<|eom_id|>" || t.first == "<EOT>" || t.first == "_<EOT>" + || t.first == "[EOT]" // Kimi-K2 + || t.first == "[EOS]" // Kimi-K2 || t.first == "<|end_of_text|>" || t.first == "<end_of_utterance>" // smoldocling + || t.first == "<eos>" // gemma4 + || t.first == "<turn|>" // gemma4 + || t.first == "<|tool_response>" // gemma4 + || t.first == "<|end▁of▁sentence|>" // deepseek-ocr ) { special_eog_ids.insert(t.second); if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { @@ -2436,7 +2806,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { auto & attr = id_to_token[t.second].attr; if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") { - attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED); + LLAMA_LOG_WARN("%s: setting token '%s' (%d) attribute to USER_DEFINED (%u), old attributes: %u\n", + __func__, t.first.c_str(), t.second, LLAMA_TOKEN_ATTR_USER_DEFINED, attr); + + attr = LLAMA_TOKEN_ATTR_USER_DEFINED; } } @@ -2489,11 +2862,38 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_eog_ids.erase(end_id); auto & attr = id_to_token[end_id].attr; - attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED); + attr = LLAMA_TOKEN_ATTR_USER_DEFINED; LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__); } } + + // workaround for gemma4 and paddleocr: do not include </s> as an eog token + { + bool has_tool_response = false; + bool has_s = false; + + llama_token s_id = LLAMA_TOKEN_NULL; + + for (auto tid : special_eog_ids) { + const auto & text = id_to_token[tid].text; + if (text == "<|tool_response>") { + has_tool_response = true; + } else if (text == "</s>") { + has_s = true; + s_id = tid; + } + } + + if (has_tool_response && has_s) { + special_eog_ids.erase(s_id); + + auto & attr = id_to_token[s_id].attr; + attr = LLAMA_TOKEN_ATTR_NORMAL; + + LLAMA_LOG_WARN("%s: special_eog_ids contains '<|tool_response>', removing '</s>' token from EOG list\n", __func__); + } + } } // build special tokens cache @@ -2662,7 +3062,9 @@ uint8_t llama_vocab::impl::token_to_byte(llama_token id) const { return strtol(buf.c_str(), NULL, 16); } case LLAMA_VOCAB_TYPE_BPE: { - GGML_ABORT("fatal error"); + // Gemma4 uses BPE with SPM-style byte fallback tokens (<0xXX>) + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); } case LLAMA_VOCAB_TYPE_WPM: { GGML_ABORT("fatal error"); @@ -2941,28 +3343,42 @@ std::vector<llama_token> llama_vocab::impl::tokenize( } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe_session session(vocab, *static_cast<const llm_tokenizer_bpe *>(tokenizer.get())); // it calls some other methods that are not exist in llm_tokenizer, // here just cast it to bpe tokenizer object + const llm_tokenizer_bpe * tok_bpe = static_cast<const llm_tokenizer_bpe *>(tokenizer.get()); + + std::unique_ptr<llm_tokenizer_bpe_session> session; + if (vocab.get_tokenizer_model() == "hybriddna") { + session = std::make_unique<llm_tokenizer_hybriddna_session>(vocab, *tok_bpe); + } else if (vocab.get_tokenizer_model() == "whitespace") { + session = std::make_unique<llm_tokenizer_whitespace_session>(vocab, *tok_bpe); + } else { + session = std::make_unique<llm_tokenizer_bpe_session>(vocab, *tok_bpe); + } + if (add_special) { - session.append_bos(output); + session->append_bos(output); } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); + if (escape_whitespaces) { + llama_escape_whitespace(text); + } + #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif - session.tokenize(text, output); + session->tokenize(text, output); } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) - session.append(fragment.token, output); + session->append(fragment.token, output); } } if (add_special) { - session.append_eos(output); - session.check_double_bos_eos(output); + session->append_eos(output); + session->check_double_bos_eos(output); } } break; case LLAMA_VOCAB_TYPE_WPM: @@ -3066,7 +3482,7 @@ std::vector<llama_token> llama_vocab::impl::tokenize( } int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t length, int32_t lstrip, bool special) const { - // ref: https://github.com/ggerganov/llama.cpp/pull/7587#discussion_r1620983843 + // ref: https://github.com/ggml-org/llama.cpp/pull/7587#discussion_r1620983843 static const int attr_special = LLAMA_TOKEN_ATTR_UNKNOWN | LLAMA_TOKEN_ATTR_CONTROL; const llama_token_attr attr = token_get_attr(token); if (!special && (attr & attr_special)) { @@ -3130,9 +3546,19 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t return _try_copy(token_text.data(), token_text.size()); } if (attr & LLAMA_TOKEN_ATTR_NORMAL) { + if (escape_whitespaces) { + // SPM-style BPE: tokens contain ▁ for spaces + std::string result = token_text; + llama_unescape_whitespace(result); + return _try_copy(result.data(), result.size()); + } std::string result = llama_decode_text(token_text); return _try_copy(result.data(), result.size()); } + if (attr & LLAMA_TOKEN_ATTR_BYTE) { + char byte = (char) token_to_byte(token); + return _try_copy((char*) &byte, 1); + } break; } case LLAMA_VOCAB_TYPE_RWKV: { @@ -3289,34 +3715,34 @@ int32_t llama_vocab::impl::detokenize( } void llama_vocab::impl::print_info() const { - LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str()); - LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, vocab.n_tokens()); - LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, type_name().c_str()); + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, vocab.n_tokens()); + LLAMA_LOG_INFO("%s: n_merges = %u\n", __func__, (uint32_t) bpe_ranks.size()); // special tokens - if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); } - if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); } - if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); } - if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); } - if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); } - if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); } - if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); } - if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); } - - if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); } - - if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); } - if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); } - if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); } - if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); } - if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); } - if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); } + if (special_bos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, special_bos_id, id_to_token.at(special_bos_id).text.c_str() ); } + if (special_eos_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, special_eos_id, id_to_token.at(special_eos_id).text.c_str() ); } + if (special_eot_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, special_eot_id, id_to_token.at(special_eot_id).text.c_str() ); } + if (special_eom_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: EOM token = %d '%s'\n", __func__, special_eom_id, id_to_token.at(special_eom_id).text.c_str() ); } + if (special_unk_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, special_unk_id, id_to_token.at(special_unk_id).text.c_str() ); } + if (special_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, special_sep_id, id_to_token.at(special_sep_id).text.c_str() ); } + if (special_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, special_pad_id, id_to_token.at(special_pad_id).text.c_str() ); } + if (special_mask_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: MASK token = %d '%s'\n", __func__, special_mask_id, id_to_token.at(special_mask_id).text.c_str() ); } + + if (linefeed_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, linefeed_id, id_to_token.at(linefeed_id).text.c_str() ); } + + if (special_fim_pre_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PRE token = %d '%s'\n", __func__, special_fim_pre_id, id_to_token.at(special_fim_pre_id).text.c_str() ); } + if (special_fim_suf_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SUF token = %d '%s'\n", __func__, special_fim_suf_id, id_to_token.at(special_fim_suf_id).text.c_str() ); } + if (special_fim_mid_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM MID token = %d '%s'\n", __func__, special_fim_mid_id, id_to_token.at(special_fim_mid_id).text.c_str() ); } + if (special_fim_pad_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM PAD token = %d '%s'\n", __func__, special_fim_pad_id, id_to_token.at(special_fim_pad_id).text.c_str() ); } + if (special_fim_rep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM REP token = %d '%s'\n", __func__, special_fim_rep_id, id_to_token.at(special_fim_rep_id).text.c_str() ); } + if (special_fim_sep_id != LLAMA_TOKEN_NULL) { LLAMA_LOG_INFO( "%s: FIM SEP token = %d '%s'\n", __func__, special_fim_sep_id, id_to_token.at(special_fim_sep_id).text.c_str() ); } for (const auto & id : special_eog_ids) { - LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() ); + LLAMA_LOG_INFO( "%s: EOG token = %d '%s'\n", __func__, id, id_to_token.at(id).text.c_str() ); } - LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len); + LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, max_token_len); } llama_vocab::llama_vocab() : pimpl(new impl(*this)) { @@ -3554,15 +3980,21 @@ bool llama_vocab::get_treat_whitespace_as_suffix() const { return pimpl->treat_whitespace_as_suffix; } +const llama_vocab::normalizer_options & llama_vocab::get_normalizer_opts() const { + return pimpl->normalizer_opts; +} + +const std::vector<llama_token> & llama_vocab::get_suppress_tokens() const { + return pimpl->suppress_tokens; +} + int llama_vocab::max_token_len() const { return pimpl->max_token_len; } int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const { GGML_ASSERT(token_left.find(' ') == std::string::npos); - GGML_ASSERT(token_left.find('\n') == std::string::npos); GGML_ASSERT(token_right.find(' ') == std::string::npos); - GGML_ASSERT(token_right.find('\n') == std::string::npos); auto it = pimpl->bpe_ranks.find(std::make_pair(token_left, token_right)); if (it == pimpl->bpe_ranks.end()) { diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 2b240a5491b..707cd4bac4b 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -8,51 +8,62 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, - LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, - LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, - LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, - LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, - LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, - LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, + LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, + LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, + LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, + LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, + LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, + LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, + LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, + LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, + LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, + LLAMA_VOCAB_PRE_TYPE_GEMMA4 = 50, + LLAMA_VOCAB_PRE_TYPE_SARVAM_MOE = 51, + LLAMA_VOCAB_PRE_TYPE_MINICPM5 = 52, + LLAMA_VOCAB_PRE_TYPE_WHITESPACE = 53, + LLAMA_VOCAB_PRE_TYPE_GRANITE_EMB_MULTI = 54, + LLAMA_VOCAB_PRE_TYPE_MELLUM2 = 55, }; struct LLM_KV; @@ -65,6 +76,12 @@ struct llama_vocab { llama_token_attr attr; }; + struct normalizer_options { + bool lowercase = true; + bool strip_accents = true; + // TODO: clean_text, handle_chinese_chars + }; + llama_vocab(); ~llama_vocab(); @@ -130,6 +147,9 @@ struct llama_vocab { bool get_remove_extra_whitespaces () const; bool get_escape_whitespaces () const; bool get_treat_whitespace_as_suffix() const; + const normalizer_options & get_normalizer_opts() const; + + const std::vector<llama_token> & get_suppress_tokens() const; int max_token_len() const; diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp index f1096d960e1..a67fa8039a4 100644 --- a/examples/talk-llama/llama.cpp +++ b/examples/talk-llama/llama.cpp @@ -11,7 +11,9 @@ #include "llama-model.h" #include "ggml.h" +#include "ggml-cpp.h" #include "ggml-backend.h" +#include "gguf.h" #include <algorithm> #include <cassert> @@ -22,6 +24,7 @@ #include <cstring> #include <ctime> #include <stdexcept> +#include <vector> #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -43,718 +46,6 @@ const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_ty GGML_ABORT("fatal error"); } -struct llama_device_memory_data { - int64_t total; - int64_t free; - llama_memory_breakdown_data mb; -}; - -static std::vector<llama_device_memory_data> llama_get_device_memory_data( - const char * path_model, const llama_model_params * mparams, const llama_context_params * cparams, - std::vector<ggml_backend_dev_t> & devs, uint32_t & hp_ngl, uint32_t & hp_n_ctx_train, uint32_t & hp_n_expert, - const ggml_log_level log_level) { - struct user_data_t { - struct { - ggml_log_callback callback; - void * user_data; - } original_logger; - ggml_log_level min_level; // prints below this log level go to debug log - }; - user_data_t ud; - llama_log_get(&ud.original_logger.callback, &ud.original_logger.user_data); - ud.min_level = log_level; - - llama_log_set([](ggml_log_level level, const char * text, void * user_data) { - const user_data_t * ud = (const user_data_t *) user_data; - const ggml_log_level level_eff = level >= ud->min_level ? level : GGML_LOG_LEVEL_DEBUG; - ud->original_logger.callback(level_eff, text, ud->original_logger.user_data); - }, &ud); - - llama_model_params mparams_copy = *mparams; - mparams_copy.no_alloc = true; - mparams_copy.use_mmap = false; - mparams_copy.use_mlock = false; - - llama_model * model = llama_model_load_from_file(path_model, mparams_copy); - if (model == nullptr) { - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - throw std::runtime_error("failed to load model"); - } - - llama_context * ctx = llama_init_from_model(model, *cparams); - if (ctx == nullptr) { - llama_model_free(model); - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - throw std::runtime_error("failed to create llama_context from model"); - } - - std::vector<llama_device_memory_data> ret(model->devices.size()); - - std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown(); - - for (const auto & [buft, mb] : memory_breakdown) { - if (ggml_backend_buft_is_host(buft)) { - continue; - } - - ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft); - if (!dev) { - continue; - } - for (size_t i = 0; i < ret.size(); i++) { - if (model->devices[i] == dev) { - ret[i].mb.model += mb.model; - ret[i].mb.context += mb.context; - ret[i].mb.compute += mb.compute; - break; - } - } - } - for (size_t i = 0; i < ret.size(); i++) { - size_t free; - size_t total; - ggml_backend_dev_memory(model->devices[i], &free, &total); - - // devices can return 0 bytes for free and total memory if they do not - // have any to report. in this case, we will use the host memory as a fallback - // fixes: https://github.com/ggml-org/llama.cpp/issues/18577 - if (free == 0 && total == 0) { - ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); - if (cpu_dev == nullptr) { - throw std::runtime_error(format("%s: no CPU backend found", __func__)); - } - ggml_backend_dev_memory(cpu_dev, &free, &total); - } - ret[i].free = free; - ret[i].total = total; - } - - devs = model->devices; - hp_ngl = model->hparams.n_layer; - hp_n_ctx_train = model->hparams.n_ctx_train; - hp_n_expert = model->hparams.n_expert; - - llama_memory_breakdown_print(ctx); // goes to debug log - - llama_free(ctx); - llama_model_free(model); - llama_log_set(ud.original_logger.callback, ud.original_logger.user_data); - return ret; -} - -// enum to identify part of a layer for distributing its tensors: -enum layer_fraction_t { - LAYER_FRACTION_NONE = 0, // nothing - LAYER_FRACTION_ATTN = 1, // attention - LAYER_FRACTION_UP = 2, // attention + up - LAYER_FRACTION_GATE = 3, // attention + up + gate - LAYER_FRACTION_MOE = 4, // everything but sparse MoE weights -}; -// this enum is only used in llama_params_fit_impl but needs to be defined outside of it to fix a Windows compilation issue - -class llama_params_fit_exception : public std::runtime_error { - using std::runtime_error::runtime_error; -}; - -static void llama_params_fit_impl( - const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, - float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { - constexpr int64_t MiB = 1024*1024; - typedef std::vector<llama_device_memory_data> dmds_t; - const llama_model_params default_mparams = llama_model_default_params(); - - std::vector<ggml_backend_dev_t> devs; - uint32_t hp_ngl = 0; // hparams.n_gpu_layers - uint32_t hp_nct = 0; // hparams.n_ctx_train - uint32_t hp_nex = 0; // hparams.n_expert - - // step 1: get data for default parameters and check whether any changes are necessary in the first place - - LLAMA_LOG_DEBUG("%s: getting device memory data for initial parameters:\n", __func__); - const dmds_t dmds_full = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - const size_t nd = devs.size(); // number of devices - if (nd == 0) { - LLAMA_LOG_INFO("%s: no devices with dedicated memory found\n", __func__); - return; - } - - std::vector<int64_t> margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits - margins.reserve(nd); - for (size_t id = 0; id < nd; id++) { - margins.push_back(margins_s[id]); - } - - std::vector<std::string> dev_names; - { - dev_names.reserve(nd); - size_t max_length = 0; - for (ggml_backend_dev_t dev : devs) { - std::string name = ggml_backend_dev_name(dev); - name += " ("; - name += ggml_backend_dev_description(dev); - name += ")"; - dev_names.push_back(name); - max_length = std::max(max_length, name.length()); - } - for (std::string & dn : dev_names) { - dn.insert(dn.end(), max_length - dn.length(), ' '); - } - } - - int64_t sum_free = 0; - int64_t sum_projected_free = 0; - int64_t sum_projected_used = 0; - int64_t sum_projected_model = 0; - std::vector<int64_t> projected_free_per_device; - projected_free_per_device.reserve(nd); - - if (nd > 1) { - LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__); - } - for (size_t id = 0; id < nd; id++) { - const llama_device_memory_data & dmd = dmds_full[id]; - - const int64_t projected_used = dmd.mb.total(); - const int64_t projected_free = dmd.free - projected_used; - projected_free_per_device.push_back(projected_free); - - sum_free += dmd.free; - sum_projected_used += projected_used; - sum_projected_free += projected_free; - sum_projected_model += dmd.mb.model; - - if (nd > 1) { - LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", - __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); - } - } - assert(sum_free >= 0 && sum_projected_used >= 0); - LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", - __func__, sum_projected_used/MiB, sum_free/MiB); - if (nd == 1) { - if (projected_free_per_device[0] >= margins[0]) { - LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", - __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); - return; - } - } else { - bool changes_needed = false; - for (size_t id = 0; id < nd; id++) { - if (projected_free_per_device[id] < margins[id]) { - changes_needed = true; - break; - } - } - if (!changes_needed) { - LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); - return; - } - } - - // step 2: try reducing memory use by reducing the context size - - { - int64_t global_surplus = sum_projected_free; - for (size_t id = 0; id < nd; id++) { - global_surplus -= margins[id]; - } - if (global_surplus < 0) { - if (nd == 1) { - LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n", - __func__, margins[0]/MiB, -global_surplus/MiB); - } else { - LLAMA_LOG_INFO( - "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n", - __func__, -global_surplus/MiB); - } - if (cparams->n_ctx == 0) { - if (hp_nct > n_ctx_min) { - int64_t sum_used_target = sum_free; - for (size_t id = 0; id < nd; id++) { - sum_used_target -= margins[id]; - } - if (nd > 1) { - // for multiple devices we need to be more conservative in terms of how much context we think can fit: - // - for dense models only whole layers can be assigned to devices - // - for MoE models only whole tensors can be assigned to devices, which we estimate to be <= 1/3 of a layer - // - on average we expect a waste of 0.5 layers/tensors per device - // - use slightly more than the expected average for nd devices to be safe - const int64_t model_per_layer = sum_projected_model / std::min(uint32_t(mparams->n_gpu_layers), hp_ngl); - sum_used_target -= (nd + 1) * model_per_layer / (hp_nex == 0 ? 2 : 6); - } - - int64_t sum_projected_used_min_ctx = 0; - cparams->n_ctx = n_ctx_min; - const dmds_t dmds_min_ctx = llama_get_device_memory_data(path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - for (const auto & dmd : dmds_min_ctx) { - sum_projected_used_min_ctx += dmd.mb.total(); - } - if (sum_used_target > sum_projected_used_min_ctx) { - // linear interpolation between minimum and maximum context size: - cparams->n_ctx += (hp_nct - n_ctx_min) * (sum_used_target - sum_projected_used_min_ctx) - / (sum_projected_used - sum_projected_used_min_ctx); - cparams->n_ctx = std::max(cparams->n_ctx - cparams->n_ctx % 256, n_ctx_min); // round down context for CUDA backend - - const int64_t bytes_per_ctx = (sum_projected_used - sum_projected_used_min_ctx) / (hp_nct - n_ctx_min); - const int64_t memory_reduction = (hp_nct - cparams->n_ctx) * bytes_per_ctx; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - if (nd == 1) { - LLAMA_LOG_INFO("%s: entire model can be fit by reducing context\n", __func__); - return; - } - LLAMA_LOG_INFO("%s: entire model should be fit across devices by reducing context\n", __func__); - } else { - const int64_t memory_reduction = sum_projected_used - sum_projected_used_min_ctx; - LLAMA_LOG_INFO("%s: context size reduced from %" PRIu32 " to %" PRIu32 " -> need %" PRId64 " MiB less memory in total\n", - __func__, hp_nct, cparams->n_ctx, memory_reduction/MiB); - } - } else { - LLAMA_LOG_INFO("%s: default model context size is %" PRIu32 " which is <= the min. context size of %" PRIu32 " -> no change\n", - __func__, hp_nct, n_ctx_min); - } - } else { - LLAMA_LOG_INFO("%s: context size set by user to %" PRIu32 " -> no change\n", __func__, cparams->n_ctx); - } - } - } - - if (mparams->n_gpu_layers != default_mparams.n_gpu_layers) { - throw llama_params_fit_exception("n_gpu_layers already set by user to " + std::to_string(mparams->n_gpu_layers) + ", abort"); - } - if (nd > 1) { - if (!tensor_split) { - throw llama_params_fit_exception("did not provide a buffer to write the tensor_split to, abort"); - } - if (mparams->tensor_split) { - for (size_t id = 0; id < nd; id++) { - if (mparams->tensor_split[id] != 0.0f) { - throw llama_params_fit_exception("model_params::tensor_split already set by user, abort"); - } - } - } - if (mparams->split_mode == LLAMA_SPLIT_MODE_ROW) { - throw llama_params_fit_exception("changing weight allocation for LLAMA_SPLIT_MODE_ROW not implemented, abort"); - } - } - if (!tensor_buft_overrides) { - throw llama_params_fit_exception("did not provide buffer to set tensor_buft_overrides, abort"); - } - if (mparams->tensor_buft_overrides && (mparams->tensor_buft_overrides->pattern || mparams->tensor_buft_overrides->buft)) { - throw llama_params_fit_exception("model_params::tensor_buft_overrides already set by user, abort"); - } - - // step 3: iteratively fill the back to front with "dense" layers - // - for a dense model simply fill full layers, giving each device a contiguous slice of the model - // - for a MoE model, same as dense model but with all MoE tensors in system memory - - // utility function that returns a static C string matching the tensors for a specific layer index and layer fraction: - auto get_overflow_pattern = [&](const size_t il, const layer_fraction_t lf) -> const char * { - constexpr size_t n_strings = 1000; - if (il >= n_strings) { - throw std::runtime_error("at most " + std::to_string(n_strings) + " model layers are supported"); - } - switch (lf) { - case LAYER_FRACTION_ATTN: { - static std::array<std::string, n_strings> patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|gate|down).*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_UP: { - static std::array<std::string, n_strings> patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|down).*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_GATE: { - static std::array<std::string, n_strings> patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_down.*"; - } - return patterns[il].c_str(); - } - case LAYER_FRACTION_MOE: { - static std::array<std::string, n_strings> patterns; - if (patterns[il].empty()) { - patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|down|gate)_(ch|)exps"; - } - return patterns[il].c_str(); - } - default: - GGML_ABORT("fatal error"); - } - }; - - struct ngl_t { - uint32_t n_layer = 0; // number of total layers - uint32_t n_part = 0; // number of partial layers, <= n_layer - - // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: - layer_fraction_t overflow_type = LAYER_FRACTION_MOE; - - uint32_t n_full() const { - assert(n_layer >= n_part); - return n_layer - n_part; - } - }; - - const size_t ntbo = llama_max_tensor_buft_overrides(); - - // utility function to set n_gpu_layers and tensor_split - auto set_ngl_tensor_split_tbo = [&]( - const std::vector<ngl_t> & ngl_per_device, - const std::vector<ggml_backend_buffer_type_t> & overflow_bufts, - llama_model_params & mparams) { - mparams.n_gpu_layers = 0; - for (size_t id = 0; id < nd; id++) { - mparams.n_gpu_layers += ngl_per_device[id].n_layer; - if (nd > 1) { - tensor_split[id] = ngl_per_device[id].n_layer; - } - } - assert(uint32_t(mparams.n_gpu_layers) <= hp_ngl + 1); - uint32_t il0 = hp_ngl + 1 - mparams.n_gpu_layers; // start index for tensor buft overrides - - mparams.tensor_split = tensor_split; - - size_t itbo = 0; - for (size_t id = 0; id < nd; id++) { - il0 += ngl_per_device[id].n_full(); - for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { - if (itbo + 1 >= ntbo) { - tensor_buft_overrides[itbo].pattern = nullptr; - tensor_buft_overrides[itbo].buft = nullptr; - itbo++; - mparams.tensor_buft_overrides = tensor_buft_overrides; - throw llama_params_fit_exception("llama_max_tensor_buft_overrides() == " - + std::to_string(ntbo) + " is insufficient for model"); - } - tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); - tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type(); - itbo++; - } - il0 += ngl_per_device[id].n_part; - } - tensor_buft_overrides[itbo].pattern = nullptr; - tensor_buft_overrides[itbo].buft = nullptr; - itbo++; - mparams.tensor_buft_overrides = tensor_buft_overrides; - }; - - // utility function that returns the memory use per device for given numbers of layers per device - auto get_memory_for_layers = [&]( - const char * func_name, - const std::vector<ngl_t> & ngl_per_device, - const std::vector<ggml_backend_buffer_type_t> & overflow_bufts) -> std::vector<int64_t> { - llama_model_params mparams_copy = *mparams; - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, mparams_copy); - - const dmds_t dmd_nl = llama_get_device_memory_data( - path_model, &mparams_copy, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - - LLAMA_LOG_DEBUG("%s: memory for test allocation by device:\n", func_name); - for (size_t id = 0; id < nd; id++) { - const ngl_t & n = ngl_per_device[id]; - LLAMA_LOG_DEBUG( - "%s: id=%zu, n_layer=%2" PRIu32 ", n_part=%2" PRIu32 ", overflow_type=%d, mem=%6" PRId64 " MiB\n", - func_name, id, n.n_layer, n.n_part, int(n.overflow_type), dmd_nl[id].mb.total()/MiB); - } - - std::vector<int64_t> ret; - ret.reserve(nd); - for (const llama_device_memory_data & dmd : dmd_nl) { - ret.push_back(dmd.mb.total()); - } - return ret; - }; - - int64_t global_surplus_cpu_moe = 0; - if (hp_nex > 0) { - const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate)_(ch|)exps"; // matches all MoE tensors - ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type(); - tensor_buft_overrides[0] = {pattern_moe_all.c_str(), cpu_buft}; - tensor_buft_overrides[1] = {nullptr, nullptr}; - mparams->tensor_buft_overrides = tensor_buft_overrides; - - LLAMA_LOG_DEBUG("%s: getting device memory data with all MoE tensors moved to system memory:\n", __func__); - const dmds_t dmds_cpu_moe = llama_get_device_memory_data( - path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - - for (size_t id = 0; id < nd; id++) { - global_surplus_cpu_moe += dmds_cpu_moe[id].free; - global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id]; - } - - if (global_surplus_cpu_moe > 0) { - LLAMA_LOG_INFO("%s: with only dense weights in device memory there is a total surplus of %" PRId64 " MiB\n", - __func__, global_surplus_cpu_moe/MiB); - } else { - LLAMA_LOG_INFO("%s: with only dense weights in device memory there is still a total deficit of %" PRId64 " MiB\n", - __func__, -global_surplus_cpu_moe/MiB); - } - - // reset - tensor_buft_overrides[0] = {nullptr, nullptr}; - mparams->tensor_buft_overrides = tensor_buft_overrides; - } - - std::vector<int64_t> targets; // maximum acceptable memory use per device - targets.reserve(nd); - for (size_t id = 0; id < nd; id++) { - targets.push_back(dmds_full[id].free - margins[id]); - LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); - } - - std::vector<ggml_backend_buffer_type_t> overflow_bufts; // which bufts the first partial layer of a device overflows to: - overflow_bufts.reserve(nd); - for (size_t id = 0; id < nd; id++) { - overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); - } - - std::vector<ngl_t> ngl_per_device(nd); - std::vector<int64_t> mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); - - // optimize the number of layers per device using the method of false position: - // - ngl_per_device has 0 layers for each device, lower bound - // - try a "high" configuration where a device is given all unassigned layers - // - interpolate the memory use / layer between low and high linearly to get a guess where it meets our target - // - check memory use of our guess, replace either the low or high bound - // - once we only have a difference of a single layer, stop and return the lower bound that just barely still fits - // - the last device has the output layer, which cannot be a partial layer - if (hp_nex == 0) { - LLAMA_LOG_INFO("%s: filling dense layers back-to-front:\n", __func__); - } else { - LLAMA_LOG_INFO("%s: filling dense-only layers back-to-front:\n", __func__); - } - for (int id = nd - 1; id >= 0; id--) { - uint32_t n_unassigned = hp_ngl + 1; - for (size_t jd = id + 1; jd < nd; ++jd) { - assert(n_unassigned >= ngl_per_device[jd].n_layer); - n_unassigned -= ngl_per_device[jd].n_layer; - } - - std::vector<ngl_t> ngl_per_device_high = ngl_per_device; - ngl_per_device_high[id].n_layer = n_unassigned; - if (hp_nex > 0) { - ngl_per_device_high[id].n_part = size_t(id) < nd - 1 ? ngl_per_device_high[id].n_layer : ngl_per_device_high[id].n_layer - 1; - } - if (ngl_per_device_high[id].n_layer > 0) { - std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); - if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); - uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); - while (delta > 1) { - uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); - step_size = std::max(step_size, uint32_t(1)); - step_size = std::min(step_size, delta - 1); - - std::vector<ngl_t> ngl_per_device_test = ngl_per_device; - ngl_per_device_test[id].n_layer += step_size; - if (hp_nex) { - ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? - step_size - 1 : step_size; // the first layer is the output layer which must always be full - } - const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); - - if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); - } else { - ngl_per_device_high = ngl_per_device_test; - mem_high = mem_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device_high[id].n_layer); - } - delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - } - } else { - assert(ngl_per_device_high[id].n_layer == n_unassigned); - ngl_per_device = ngl_per_device_high; - mem = mem_high; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%d].n_layer=%" PRIu32 "\n", __func__, id, ngl_per_device[id].n_layer); - } - } - - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers, %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, mem[id]/MiB, projected_margin/MiB); - } - if (hp_nex == 0 || global_surplus_cpu_moe <= 0) { - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); - return; - } - - // step 4: for a MoE model where all dense tensors fit, - // convert the dense-only layers in the back to full layers in the front until all devices are full - // essentially the same procedure as for the dense-only layers except front-to-back - // also, try fitting at least part of one more layer to reduce waste for "small" GPUs with e.g. 24 GiB VRAM - - size_t id_dense_start = nd; - for (int id = nd - 1; id >= 0; id--) { - if (ngl_per_device[id].n_layer > 0) { - id_dense_start = id; - continue; - } - break; - } - assert(id_dense_start < nd); - - LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); - for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { - std::vector<ngl_t> ngl_per_device_high = ngl_per_device; - for (size_t jd = id_dense_start; jd < nd; jd++) { - const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; - ngl_per_device_high[id].n_layer += n_layer_move; - ngl_per_device_high[jd].n_layer -= n_layer_move; - ngl_per_device_high[jd].n_part = 0; - } - size_t id_dense_start_high = nd - 1; - std::vector<int64_t> mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); - - if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); - uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); - while (delta > 1) { - uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); - step_size = std::max(step_size, uint32_t(1)); - step_size = std::min(step_size, delta - 1); - - std::vector<ngl_t> ngl_per_device_test = ngl_per_device; - size_t id_dense_start_test = id_dense_start; - uint32_t n_converted_test = 0; - for (;id_dense_start_test < nd; id_dense_start_test++) { - const uint32_t n_convert_jd = std::min(step_size - n_converted_test, ngl_per_device_test[id_dense_start_test].n_part); - ngl_per_device_test[id_dense_start_test].n_layer -= n_convert_jd; - ngl_per_device_test[id_dense_start_test].n_part -= n_convert_jd; - ngl_per_device_test[id].n_layer += n_convert_jd; - n_converted_test += n_convert_jd; - - if (ngl_per_device_test[id_dense_start_test].n_part > 0) { - break; - } - } - const std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); - - if (mem_test[id] <= targets[id]) { - ngl_per_device = ngl_per_device_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } else { - ngl_per_device_high = ngl_per_device_test; - mem_high = mem_test; - id_dense_start_high = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", - __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); - } - assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); - delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); - } - } else { - ngl_per_device = ngl_per_device_high; - mem = mem_high; - id_dense_start = id_dense_start_high; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - - // try to fit at least part of one more layer - if (ngl_per_device[id_dense_start].n_layer > (id < nd - 1 ? 0 : 1)) { - std::vector<ngl_t> ngl_per_device_test = ngl_per_device; - size_t id_dense_start_test = id_dense_start; - ngl_per_device_test[id_dense_start_test].n_layer--; - ngl_per_device_test[id_dense_start_test].n_part--; - ngl_per_device_test[id].n_layer++; - ngl_per_device_test[id].n_part++; - if (ngl_per_device_test[id_dense_start_test].n_part == 0) { - id_dense_start_test++; - } - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; - std::vector<ggml_backend_buffer_type_t> overflow_bufts_test = overflow_bufts; - if (id < nd - 1) { - overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]); - } - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector<int64_t> mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - } else { - ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; - LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); - if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { - ngl_per_device = ngl_per_device_test; - overflow_bufts = overflow_bufts_test; - mem = mem_test; - id_dense_start = id_dense_start_test; - LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", - __func__, id, ngl_per_device[id].n_layer, ngl_per_device[id].n_part, id_dense_start); - } - } - } - - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); - } - - // print info for devices that were not changed during the conversion from dense only to full layers: - for (size_t id = id_dense_start + 1; id < nd; id++) { - const int64_t projected_margin = dmds_full[id].free - mem[id]; - LLAMA_LOG_INFO( - "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", - __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); - } - - set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); -} - -enum llama_params_fit_status llama_params_fit( - const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, - float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) { - const int64_t t0_us = llama_time_us(); - llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS; - try { - llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level); - LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__); - } catch (const llama_params_fit_exception & e) { - LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what()); - status = LLAMA_PARAMS_FIT_STATUS_FAILURE; - } catch (const std::runtime_error & e) { - LLAMA_LOG_ERROR("%s: encountered an error while trying to fit params to free device memory: %s\n", __func__, e.what()); - status = LLAMA_PARAMS_FIT_STATUS_ERROR; - } - const int64_t t1_us = llama_time_us(); - LLAMA_LOG_INFO("%s: fitting params to free memory took %.2f seconds\n", __func__, (t1_us - t0_us) * 1e-6); - return status; -} - struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { /*.no_perf =*/ true, @@ -780,12 +71,18 @@ bool llama_supports_mlock(void) { } bool llama_supports_gpu_offload(void) { + if (!ggml_backend_reg_count()) { + ggml_backend_load_all(); + } return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr || ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU) != nullptr || llama_supports_rpc(); } bool llama_supports_rpc(void) { + if (!ggml_backend_reg_count()) { + ggml_backend_load_all(); + } return ggml_backend_reg_by_name("RPC") != nullptr; } @@ -798,6 +95,10 @@ void llama_backend_init(void) { struct ggml_context * ctx = ggml_init(params); ggml_free(ctx); } + + if (!ggml_backend_reg_count()) { + ggml_backend_load_all(); + } } void llama_numa_init(enum ggml_numa_strategy numa) { @@ -820,65 +121,247 @@ int64_t llama_time_us(void) { return ggml_time_us(); } -// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback -static int llama_model_load(const std::string & fname, std::vector<std::string> & splits, llama_model & model, llama_model_params & params) { - // loading time will be recalculated after the first eval, so - // we take page faults deferred by mmap() into consideration - model.t_load_us = 0; - time_meas tm(model.t_load_us); +// returns true on success +static bool llama_prepare_model_devices(const llama_model_params & params, llama_model * model) { + // create list of devices to use with this model + if (params.devices) { + if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) { + size_t n_devs = 0; + while (params.devices[n_devs]) { + n_devs++; + } + if (n_devs == 0) { + LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); + return false; + } + LLAMA_LOG_INFO("%s: creating a Meta device with %zu devices\n", __func__, n_devs); + for (size_t i = 0; i < n_devs; ++i) { + LLAMA_LOG_INFO("%s: - device %zu: %s\n", __func__, i, ggml_backend_dev_name(params.devices[i])); + } + model->get_split_state_ud.n_devices = n_devs; + model->get_split_state_ud.model = model; + model->devices.push_back({ + true, ggml_backend_meta_device( + params.devices, n_devs, llama_meta_device_get_split_state, &model->get_split_state_ud) + }); + } else { + for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { + model->devices.push_back({false, *dev}); + } + } + } else { + // default device selection + + // build list of available devices + std::vector<llama_device> gpus; + std::vector<llama_device> igpus; + std::vector<llama_device> rpc_servers; + + if (params.split_mode == LLAMA_SPLIT_MODE_TENSOR) { + std::vector<ggml_backend_dev_t> devs; + devs.reserve(ggml_backend_dev_count()); + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + auto * dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_buffer_type(dev) == ggml_backend_cpu_buffer_type()) { + LLAMA_LOG_INFO("%s: skipping %s (%s) for tensor parallelism\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev)); + continue; + } + devs.push_back(dev); + } + if (devs.empty()) { + LLAMA_LOG_ERROR("%s: LLAMA_SPLIT_MODE_TENSOR needs >= 1 devices\n", __func__); + return false; + } + + LLAMA_LOG_INFO("%s: creating a Meta device for tensor parallelism from %zu devices:\n", __func__, devs.size()); + for (size_t i = 0; i < devs.size(); ++i) { + LLAMA_LOG_INFO("%s: - device %zu: %s (%s)\n", __func__, i, ggml_backend_dev_name(devs[i]), ggml_backend_dev_description(devs[i])); + } + + GGML_ASSERT(!devs.empty()); + model->get_split_state_ud.n_devices = devs.size(); + model->get_split_state_ud.model = model; + gpus.push_back({ + true, ggml_backend_meta_device( + devs.data(), devs.size(), llama_meta_device_get_split_state, &model->get_split_state_ud) + }); + } else { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + switch (ggml_backend_dev_type(dev)) { + case GGML_BACKEND_DEVICE_TYPE_CPU: + case GGML_BACKEND_DEVICE_TYPE_ACCEL: + // skip CPU backends since they are handled separately + break; + + case GGML_BACKEND_DEVICE_TYPE_GPU: { + ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); + if (ggml_backend_reg_name(reg) == std::string("RPC")) { + rpc_servers.push_back({false, dev}); + } else { + // check if there is already a GPU with the same device id + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev, &props); + auto it = std::find_if(gpus.begin(), gpus.end(), [&props](const llama_device & d) { + ggml_backend_dev_props d_props; + ggml_backend_dev_get_props(d.dev, &d_props); + if (props.device_id && d_props.device_id) { + return strcmp(props.device_id, d_props.device_id) == 0; + } + return false; + }); + + if (it != gpus.end()) { + LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", + __func__, + ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), + props.device_id ? props.device_id : "unknown id", + ggml_backend_dev_name(it->dev), ggml_backend_dev_description(it->dev)); + } else { + gpus.push_back({false, dev}); + } + } + break; + } + + case GGML_BACKEND_DEVICE_TYPE_IGPU: + if (igpus.empty()) { + igpus.push_back({false, dev}); + } + break; + case GGML_BACKEND_DEVICE_TYPE_META: + GGML_ABORT("fatal error"); + } + } + } + + // add RPC servers at the front of the list to minimize network transfers + model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end()); + + // add GPUs + model->devices.insert(model->devices.end(), gpus.begin(), gpus.end()); + + // add integrated GPUs only if no discrete GPUs were found + // (RPC servers do not count, otherwise the local iGPU would be dropped on iGPU+RPC setups) + if (gpus.empty()) { + model->devices.insert(model->devices.end(), igpus.begin(), igpus.end()); + } + } + + // if using single GPU mode, remove all except the main GPU + if (params.split_mode == LLAMA_SPLIT_MODE_NONE) { + if (params.main_gpu < 0) { + model->devices.clear(); + } else { + if (params.main_gpu >= (int)model->devices.size()) { + LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %zu)\n", __func__, params.main_gpu, model->devices.size()); + return false; + } + llama_device main_gpu = model->devices[params.main_gpu]; + model->devices.clear(); + model->devices.push_back(main_gpu); + } + } + + for (const auto & dev : model->devices) { + ggml_backend_dev_props props; + ggml_backend_dev_get_props(dev.dev, &props); + LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__, + ggml_backend_dev_name(dev.dev), ggml_backend_dev_description(dev.dev), + props.device_id ? props.device_id : "unknown id", + props.memory_free/1024/1024); + } - model.t_start_us = tm.t_start_us; + return true; +} +// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback +static std::pair<int, llama_model *> llama_model_load(struct gguf_context * metadata, llama_model_set_tensor_data_t set_tensor_data, void * set_tensor_data_ud, + const std::string & fname, std::vector<std::string> & splits, FILE * file, llama_model_params & params) { try { - llama_model_loader ml(fname, splits, params.use_mmap, params.use_direct_io, params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); + llama_model_loader ml(metadata, set_tensor_data, set_tensor_data_ud, fname, splits, file, params.use_mmap, params.use_direct_io, + params.check_tensors, params.no_alloc, params.kv_overrides, params.tensor_buft_overrides); ml.print_info(); + std::unique_ptr<llama_model> model_ptr(llama_model_create(ml, params)); - model.hparams.vocab_only = params.vocab_only; - model.hparams.no_alloc = params.no_alloc; + bool ok = llama_prepare_model_devices(params, model_ptr.get()); + if (!ok) { + return {-1, nullptr}; + } - try { - model.load_arch(ml); - } catch(const std::exception & e) { - throw std::runtime_error("error loading model architecture: " + std::string(e.what())); + auto * model = dynamic_cast<llama_model_base *>(model_ptr.get()); + if (model == nullptr) { + GGML_ABORT("fatal error: model does not implement llama_model_base"); } + + // loading time will be recalculated after the first eval, so + // we take page faults deferred by mmap() into consideration + model->t_load_us = 0; + time_meas tm(model->t_load_us); + + model->t_start_us = tm.t_start_us; + + model->hparams.vocab_only = params.vocab_only; + model->hparams.no_alloc = params.no_alloc; + try { - model.load_hparams(ml); + model->load_hparams(ml); } catch(const std::exception & e) { throw std::runtime_error("error loading model hyperparameters: " + std::string(e.what())); } - if (model.arch == LLM_ARCH_CLIP) { + if (model->arch == LLM_ARCH_CLIP) { throw std::runtime_error("CLIP cannot be used as main model, use it with --mmproj instead"); } try { - model.load_vocab(ml); + model->load_vocab(ml); } catch(const std::exception & e) { throw std::runtime_error("error loading model vocabulary: " + std::string(e.what())); } - model.load_stats(ml); - model.print_info(); + model->load_stats(ml); + model->print_info(); if (params.vocab_only) { LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); - return 0; + return {0, model_ptr.release()}; } - if (!model.load_tensors(ml)) { - return -2; + if (!model->load_tensors(ml)) { + return {-2, nullptr}; } + + return {0, model_ptr.release()}; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what()); - return -1; + return {-1, nullptr}; } - - return 0; } static struct llama_model * llama_model_load_from_file_impl( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, const std::string & path_model, std::vector<std::string> & splits, + FILE * file, struct llama_model_params params) { + { + int n_sources_defined = 0; + if (metadata != nullptr) { + n_sources_defined++; + } + if (!path_model.empty()) { + n_sources_defined++; + } + if (file != nullptr) { + n_sources_defined++; + } + if (n_sources_defined != 1) { + LLAMA_LOG_ERROR("%s: exactly one out metadata, path_model, and file must be defined\n", __func__); + return nullptr; + } + } ggml_time_init(); if (!params.vocab_only && ggml_backend_reg_count() == 0) { @@ -903,103 +386,7 @@ static struct llama_model * llama_model_load_from_file_impl( }; } - llama_model * model = new llama_model(params); - - // create list of devices to use with this model - if (params.devices) { - for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) { - model->devices.push_back(*dev); - } - } else { - // default device selection - - // build list of available devices - std::vector<ggml_backend_dev_t> gpus; - std::vector<ggml_backend_dev_t> igpus; - std::vector<ggml_backend_dev_t> rpc_servers; - - for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { - ggml_backend_dev_t dev = ggml_backend_dev_get(i); - switch (ggml_backend_dev_type(dev)) { - case GGML_BACKEND_DEVICE_TYPE_CPU: - case GGML_BACKEND_DEVICE_TYPE_ACCEL: - // skip CPU backends since they are handled separately - break; - - case GGML_BACKEND_DEVICE_TYPE_GPU: { - ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(dev); - if (ggml_backend_reg_name(reg) == std::string("RPC")) { - rpc_servers.push_back(dev); - } else { - // check if there is already a GPU with the same device id - ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); - auto it = std::find_if(gpus.begin(), gpus.end(), [&props](ggml_backend_dev_t d) { - ggml_backend_dev_props d_props; - ggml_backend_dev_get_props(d, &d_props); - if (props.device_id && d_props.device_id) { - return strcmp(props.device_id, d_props.device_id) == 0; - } - return false; - }); - - if (it != gpus.end()) { - LLAMA_LOG_INFO("%s: skipping device %s (%s) with id %s - already using device %s (%s) with the same id\n", - __func__, - ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), - props.device_id ? props.device_id : "unknown id", - ggml_backend_dev_name(*it), ggml_backend_dev_description(*it)); - } else { - gpus.push_back(dev); - } - } - break; - } - - case GGML_BACKEND_DEVICE_TYPE_IGPU: - igpus.push_back(dev); - break; - } - } - - // add RPC servers at the front of the list to minimize network transfers - model->devices.insert(model->devices.begin(), rpc_servers.begin(), rpc_servers.end()); - - // add GPUs - model->devices.insert(model->devices.end(), gpus.begin(), gpus.end()); - - // add integrated GPUs only if no other devices were found - if (model->devices.empty()) { - model->devices.insert(model->devices.end(), igpus.begin(), igpus.end()); - } - } - - // if using single GPU mode, remove all except the main GPU - if (params.split_mode == LLAMA_SPLIT_MODE_NONE) { - if (params.main_gpu < 0) { - model->devices.clear(); - } else { - if (params.main_gpu >= (int)model->devices.size()) { - LLAMA_LOG_ERROR("%s: invalid value for main_gpu: %d (available devices: %zu)\n", __func__, params.main_gpu, model->devices.size()); - llama_model_free(model); - return nullptr; - } - ggml_backend_dev_t main_gpu = model->devices[params.main_gpu]; - model->devices.clear(); - model->devices.push_back(main_gpu); - } - } - - for (auto * dev : model->devices) { - ggml_backend_dev_props props; - ggml_backend_dev_get_props(dev, &props); - LLAMA_LOG_INFO("%s: using device %s (%s) (%s) - %zu MiB free\n", __func__, - ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), - props.device_id ? props.device_id : "unknown id", - props.memory_free/1024/1024); - } - - const int status = llama_model_load(path_model, splits, *model, params); + const auto [status, model] = llama_model_load(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, file, params); GGML_ASSERT(status <= 0); if (status < 0) { if (status == -1) { @@ -1008,13 +395,27 @@ static struct llama_model * llama_model_load_from_file_impl( LLAMA_LOG_INFO("%s: cancelled model load\n", __func__); } - llama_model_free(model); + if (model) { + llama_model_free(model); + } return nullptr; } return model; } +struct llama_model * llama_model_init_from_user( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, + void * set_tensor_data_ud, + struct llama_model_params params) { + GGML_ASSERT(metadata != nullptr); + std::string path_model; + std::vector<std::string> splits = {}; + params.use_mmap = false; + params.use_extra_bufts = false; + return llama_model_load_from_file_impl(metadata, set_tensor_data, set_tensor_data_ud, path_model, splits, /*file*/ nullptr, params); +} // deprecated struct llama_model * llama_load_model_from_file( const char * path_model, @@ -1026,7 +427,7 @@ struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params) { std::vector<std::string> splits = {}; - return llama_model_load_from_file_impl(path_model, splits, params); + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, /*file*/ nullptr, params); } struct llama_model * llama_model_load_from_splits( @@ -1042,11 +443,21 @@ struct llama_model * llama_model_load_from_splits( for (size_t i = 0; i < n_paths; ++i) { splits.push_back(paths[i]); } - return llama_model_load_from_file_impl(splits.front(), splits, params); + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, splits.front(), splits, /*file*/ nullptr, params); +} + +struct llama_model * llama_model_load_from_file_ptr(FILE * file, struct llama_model_params params) { + if (!file) { + LLAMA_LOG_ERROR("%s: file is NULL\n", __func__); + return nullptr; + } + std::string path_model; + std::vector<std::string> splits = {}; + return llama_model_load_from_file_impl(nullptr, nullptr, nullptr, path_model, splits, file, params); } void llama_model_save_to_file(const struct llama_model * model, const char * path_model) { - llama_model_saver ms(*model); + llama_model_saver ms(model); ms.add_kv_from_model(); ms.add_tensors_from_model(); ms.save(path_model); @@ -1091,25 +502,55 @@ int32_t llama_chat_apply_template( // model split // -int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count) { +int32_t llama_split_path( + char * split_path, + size_t maxlen, + const char * path_prefix, + int32_t split_no, + int32_t split_count) { + static const char * const SPLIT_PATH_FORMAT = "%s-%05d-of-%05d.gguf"; - if (snprintf(split_path, maxlen, SPLIT_PATH_FORMAT, path_prefix, split_no + 1, split_count)) { - return strlen(split_path); + + const int written = snprintf( + split_path, + maxlen, + SPLIT_PATH_FORMAT, + path_prefix, + split_no + 1, + split_count + ); + + if (written < 0 || (size_t) written >= maxlen) { + return 0; } - return 0; + + return (int32_t) written; } -int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count) { - std::string str_split_path(split_path); +int32_t llama_split_prefix( + char * split_prefix, + size_t maxlen, + const char * split_path, + int32_t split_no, + int32_t split_count) { + + const std::string str_split_path(split_path); + char postfix[32]; - snprintf(postfix, 32, "-%05d-of-%05d.gguf", split_no + 1, split_count); - std::string str_postfix(postfix); - - // check if split_prefix ends with postfix - int size_prefix = str_split_path.size() - str_postfix.size(); - if (size_prefix > 0 && str_split_path.find(str_postfix, size_prefix) != std::string::npos) { - snprintf(split_prefix, std::min((size_t) size_prefix + 1, maxlen), "%s", split_path); - return size_prefix; + snprintf(postfix, sizeof(postfix), "-%05d-of-%05d.gguf", split_no + 1, split_count); + + const std::string str_postfix(postfix); + if (str_split_path.size() <= str_postfix.size()) { + return 0; + } + + const size_t size_prefix = str_split_path.size() - str_postfix.size(); + + if (str_split_path.compare(size_prefix, std::string::npos, str_postfix) == 0) { + const size_t copy_len = std::min(size_prefix + 1, maxlen); + snprintf(split_prefix, copy_len, "%s", split_path); + + return (int32_t) size_prefix; } return 0; diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index 1c17efb9fa1..27e48067428 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -5,6 +5,7 @@ #include "ggml-cpu.h" #include "ggml-backend.h" #include "ggml-opt.h" +#include "gguf.h" #include <stddef.h> #include <stdint.h> @@ -152,6 +153,8 @@ extern "C" { LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors + LLAMA_FTYPE_MOSTLY_NVFP4 = 39, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q1_0 = 40, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; @@ -189,9 +192,15 @@ extern "C" { LLAMA_API const char * llama_flash_attn_type_name(enum llama_flash_attn_type flash_attn_type); enum llama_split_mode { - LLAMA_SPLIT_MODE_NONE = 0, // single GPU - LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs - LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + LLAMA_SPLIT_MODE_NONE = 0, // single GPU + LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs + LLAMA_SPLIT_MODE_ROW = 2, // split layers and KV across GPUs, use tensor parallelism if supported + LLAMA_SPLIT_MODE_TENSOR = 3, + }; + + enum llama_context_type { + LLAMA_CONTEXT_TYPE_DEFAULT = 0, + LLAMA_CONTEXT_TYPE_MTP = 1, }; // TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979) @@ -309,7 +318,7 @@ extern "C" { // Keep the booleans together to avoid misalignment during copy-by-value. bool vocab_only; // only load the vocabulary, no weights bool use_mmap; // use mmap if possible - bool use_direct_io; // use direct io, takes precedence over use_mmap + bool use_direct_io; // use direct io, takes precedence over use_mmap when supported bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) @@ -329,9 +338,12 @@ extern "C" { uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode uint32_t n_ubatch; // physical maximum batch size uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) + uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL] + uint32_t n_outputs_max; // max outputs in a ubatch (0 = n_batch) int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + enum llama_context_type ctx_type; // set the context type (e.g. MTP) enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id enum llama_attention_type attention_type; // attention type to use for embeddings @@ -376,23 +388,39 @@ extern "C" { // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init) struct llama_sampler_seq_config * samplers; size_t n_samplers; + + // a source/target/parent context + // can be utilized in various ways, for example by sharing results or llama_memory between 2 contexts + struct llama_context * ctx_other; + }; + + struct llama_model_tensor_override { + const char * pattern; + enum ggml_type type; + }; + + struct llama_model_imatrix_data { + const char * name; + const float * data; + size_t size; }; // model quantization parameters typedef struct llama_model_quantize_params { - int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() - enum llama_ftype ftype; // quantize to this llama_ftype - enum ggml_type output_tensor_type; // output tensor type - enum ggml_type token_embedding_type; // token embeddings tensor type - bool allow_requantize; // allow quantizing non-f32/f16 tensors - bool quantize_output_tensor; // quantize output.weight - bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored - bool pure; // quantize all tensors to the default type - bool keep_split; // quantize to the same number of shards - void * imatrix; // pointer to importance matrix data - void * kv_overrides; // pointer to vector containing overrides - void * tensor_types; // pointer to vector containing tensor types - void * prune_layers; // pointer to vector containing layer indices to prune + int32_t nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + enum llama_ftype ftype; // quantize to this llama_ftype + enum ggml_type output_tensor_type; // output tensor type + enum ggml_type token_embedding_type; // token embeddings tensor type + bool allow_requantize; // allow quantizing non-f32/f16 tensors + bool quantize_output_tensor; // quantize output.weight + bool only_copy; // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + bool pure; // quantize all tensors to the default type + bool keep_split; // quantize to the same number of shards + bool dry_run; // calculate and show the final quantization size without performing quantization + const struct llama_model_imatrix_data * imatrix; // pointer to importance matrix data + const struct llama_model_kv_override * kv_overrides; // pointer to kv overrides + const struct llama_model_tensor_override * tt_overrides; // pointer to tensor overrides + const int32_t * prune_layers; // pointer to layer indices to prune } llama_model_quantize_params; typedef struct llama_logit_bias { @@ -439,19 +467,35 @@ extern "C" { LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); + typedef void (*llama_model_set_tensor_data_t)(struct ggml_tensor * tensor, void * userdata); + + // Create a new model from GGUF metadata as well as a function to set the tensor data + // - tensors are created as GGML_TYPE_F32 by default, + // override by adding a tensor with the same name but a different name to the context + LLAMA_API struct llama_model * llama_model_init_from_user( + struct gguf_context * metadata, + llama_model_set_tensor_data_t set_tensor_data, // function to initialize tensor data with + void * set_tensor_data_ud, // userdata for function + struct llama_model_params params); + DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( const char * path_model, struct llama_model_params params), "use llama_model_load_from_file instead"); - // Load the model from a file + // Load a model from a file // If the file is split into multiple parts, the file name must follow this pattern: <name>-%05d-of-%05d.gguf // If the split file name does not follow this pattern, use llama_model_load_from_splits LLAMA_API struct llama_model * llama_model_load_from_file( const char * path_model, struct llama_model_params params); - // Load the model from multiple splits (support custom naming scheme) + // Load a model from an open FILE pointer + LLAMA_API struct llama_model * llama_model_load_from_file_ptr( + FILE * file, + struct llama_model_params params); + + // Load a model from multiple splits (support custom naming scheme) // The paths must be in the correct order LLAMA_API struct llama_model * llama_model_load_from_splits( const char ** paths, @@ -479,26 +523,6 @@ extern "C" { // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); - enum llama_params_fit_status { - LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path - }; - - // fits mparams and cparams to free device memory (assumes system memory is unlimited) - // - returns true if the parameters could be successfully modified to fit device memory - // - this function is NOT thread safe because it modifies the global llama logger state - // - only parameters that have the same value as in llama_default_model_params are modified - LLAMA_API enum llama_params_fit_status llama_params_fit( - const char * path_model, - struct llama_model_params * mparams, - struct llama_context_params * cparams, - float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements - struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements - size_t * margins, // margins of memory to leave per device in bytes - uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use - enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log - LLAMA_API int64_t llama_time_us(void); LLAMA_API size_t llama_max_devices(void); @@ -518,6 +542,7 @@ extern "C" { LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); + LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); @@ -621,7 +646,6 @@ extern "C" { // Load a LoRA adapter from file // The adapter is valid as long as the associated model is not freed - // All adapters must be loaded before context creation LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init( struct llama_model * model, const char * path_lora); @@ -645,7 +669,7 @@ extern "C" { LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size); // Manually free a LoRA adapter - // NOTE: loaded adapters will be free when the associated model is deleted + // NOTE: loaded adapters that are not manually freed will be freed when the associated model is deleted LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); // Get the invocation tokens if the current lora is an alora @@ -654,21 +678,12 @@ extern "C" { // The following functions operate on a llama_context, hence the naming: llama_verb_... - // Add a loaded LoRA adapter to given context - // This will not modify model's weight - LLAMA_API int32_t llama_set_adapter_lora( - struct llama_context * ctx, - struct llama_adapter_lora * adapter, - float scale); - - // Remove a specific LoRA adapter from given context - // Return -1 if the adapter is not present in the context - LLAMA_API int32_t llama_rm_adapter_lora( + // Set LoRa adapters on the context. Will only modify if the adapters currently in context are different. + LLAMA_API int32_t llama_set_adapters_lora( struct llama_context * ctx, - struct llama_adapter_lora * adapter); - - // Remove all LoRA adapters from given context - LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx); + struct llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales); // Apply a loaded control vector to a llama_context, or if data is NULL, clear // the currently loaded vector. @@ -676,7 +691,7 @@ extern "C" { // to an n_embd x n_layers buffer starting from layer 1. // il_start and il_end are the layer range the vector should apply to (both inclusive) // See llama_control_vector_load in common to load a control vector. - LLAMA_API int32_t llama_apply_adapter_cvec( + LLAMA_API int32_t llama_set_adapter_cvec( struct llama_context * ctx, const float * data, size_t len, @@ -856,12 +871,18 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +#define LLAMA_STATE_SEQ_FLAGS_NONE 0 + // for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 +// Keeps the tensor data on device buffers (i.e. not accessible in host memory, but faster save/load). +// Getting the state for a seq_id with this flag invalidates all prior states gotten for that seq_id with this flag. +#define LLAMA_STATE_SEQ_FLAGS_ON_DEVICE 2 + typedef uint32_t llama_state_seq_flags; LLAMA_API size_t llama_state_seq_get_size_ext( @@ -959,7 +980,11 @@ extern "C" { // Set whether the model is in warmup mode or not // If true, all model tensors are activated during llama_decode() to load and cache their weights. - LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup); + // + // note: using this can cause extra graph reallocations because it changes the graph topology with MoE models, + // so it is generally not recommended to use in practice. will be removed in the future + DEPRECATED(LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup), + "user code should do warmup runs manually [TAG_LLAMA_GRAPH_NO_WARMUP]"); // Set abort callback LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data); @@ -979,7 +1004,7 @@ extern "C" { // Logits for the ith token. For positive indices, Equivalent to: // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab - // Negative indicies can be used to access logits in reverse order, -1 is the last logit. + // Negative indices can be used to access logits in reverse order, -1 is the last logit. // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); @@ -994,7 +1019,7 @@ extern "C" { // Get the embeddings for the ith token. For positive indices, Equivalent to: // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd - // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding. + // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding. // shape: [n_embd] (1-dimensional) // returns NULL for invalid ids. LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); @@ -1014,9 +1039,9 @@ extern "C" { // Returns LLAMA_TOKEN_NULL if no token was sampled. LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i); - // Get the backend sampled probabilites for the ith token + // Get the backend sampled probabilities for the ith token // The index matches llama_get_sampled_token_ith(). - // Returns NULL if no probabilites were generated. + // Returns NULL if no probabilities were generated. LLAMA_API float * llama_get_sampled_probs_ith (struct llama_context * ctx, int32_t i); LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i); @@ -1148,9 +1173,9 @@ extern "C" { // /// Apply chat template. Inspired by hf apply_chat_template() on python. - /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" + /// /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template - /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. + /// @param tmpl A Jinja template to use for this chat. /// @param chat Pointer to a list of multiple llama_chat_message /// @param n_msg Number of llama_chat_message in this chat /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. @@ -1255,7 +1280,6 @@ extern "C" { // [EXPERIMENTAL] // attach a sampler to the context // note: prefer initializing the context with llama_context_params.samplers when possible - // note: changing the samplers of a context can cause graph reallocations and degraded performance LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); // mirror of llama_sampler_i: @@ -1344,7 +1368,7 @@ extern "C" { float tau, float eta); - /// @details Intializes a GBNF grammar, see grammars/README.md for details. + /// @details Initializes a GBNF grammar, see grammars/README.md for details. /// @param vocab The vocabulary that this grammar will be used with. /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. /// @param grammar_root The name of the start symbol for the grammar. @@ -1395,6 +1419,33 @@ extern "C" { const char ** seq_breakers, size_t num_breakers); + /// adaptive-p: select tokens near a configurable target probability over time. + /// + /// the adaptive-p sampler transforms the token probability distribution to favor tokens + /// that fall near a user-configurable probability target. + /// + /// internally, the sampler maintains an exponential moving average of the *ORIGINAL* + /// probabilities of selected tokens at each sampling step. it uses this EMA to compute an + /// adapted target probability at each sampling step, thus maintaining the desired target + /// probability over time. + /// + /// adaptive-p selects a token ID rather than just mutating candidates, so it must be last + /// in the sampler chain (like mirostat, dist, greedy). + /// + /// only mild truncation before this sampler is recommended. we suggest applying min-p + /// before adaptive-p as the only other active sampler in the chain. + /// + /// @param target select tokens near this probability (valid range 0.0 to 1.0; negative = disabled) + /// @param decay EMA decay for adaptation; history ≈ 1/(1-decay) tokens (valid range 0.0 - 0.99) + /// @param seed RNG seed + /// + /// ref: https://github.com/ggml-org/llama.cpp/pull/17927 + /// + LLAMA_API struct llama_sampler * llama_sampler_init_adaptive_p( + float target, + float decay, + uint32_t seed); + LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, @@ -1448,12 +1499,12 @@ extern "C" { /// @details Build a split GGUF final path for this chunk. /// llama_split_path(split_path, sizeof(split_path), "/models/ggml-model-q4_0", 2, 4) => split_path = "/models/ggml-model-q4_0-00002-of-00004.gguf" // Returns the split_path length. - LLAMA_API int llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int split_no, int split_count); + LLAMA_API int32_t llama_split_path(char * split_path, size_t maxlen, const char * path_prefix, int32_t split_no, int32_t split_count); /// @details Extract the path prefix from the split_path if and only if the split_no and split_count match. /// llama_split_prefix(split_prefix, 64, "/models/ggml-model-q4_0-00002-of-00004.gguf", 2, 4) => split_prefix = "/models/ggml-model-q4_0" // Returns the split_prefix length. - LLAMA_API int llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int split_no, int split_count); + LLAMA_API int32_t llama_split_prefix(char * split_prefix, size_t maxlen, const char * split_path, int32_t split_no, int32_t split_count); // Print system information LLAMA_API const char * llama_print_system_info(void); @@ -1497,9 +1548,6 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); - // print a breakdown of per-device memory use via LLAMA_LOG: - LLAMA_API void llama_memory_breakdown_print(const struct llama_context * ctx); - // // training // diff --git a/examples/talk-llama/models/afmoe.cpp b/examples/talk-llama/models/afmoe.cpp index 6a752a403f6..063b214256e 100644 --- a/examples/talk-llama/models/afmoe.cpp +++ b/examples/talk-llama/models/afmoe.cpp @@ -1,8 +1,114 @@ #include "models.h" -llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); +void llama_model_afmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + // Set up interleaved sliding window attention (ISWA) + // Pattern: 3 sliding - 1 full (global_attn_every_n_layers = 4) + if (hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + // Default to sigmoid if not set + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + switch (hparams.n_layer()) { + case 56: type = LLM_TYPE_6B; break; + case 32: type = LLM_TYPE_26B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_afmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // dual attention normalization + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + // attention projections + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // Q/K normalization + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + // attention gating + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + + // dual ffn normalization + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { + // MoE layers + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + + // grouped expert weights + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + } + } else { + // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_afmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_afmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -41,22 +147,13 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para { ggml_tensor * attn_inp = cur; // save input for gate computation - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // compute gate from input ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, attn_inp); cb(gate, "attn_gate_proj", il); - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - // Q/K normalization Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -77,10 +174,8 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para cb(Kcur, "Kcur_rope", il); } - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, - NULL, NULL, // wo will be applied after gating + NULL, NULL, NULL, // wo will be applied after gating Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -91,7 +186,7 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para cb(cur, "attn_gated", il); // now apply output projection - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_o_proj", il); } @@ -127,7 +222,6 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, // norm_w (route_norm=True) - hparams.expert_weights_scale, // scale_w hparams.expert_weights_scale, // w_scale (route_scale=2.826) (llama_expert_gating_func_type) hparams.expert_gating_func, il); @@ -183,7 +277,7 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/apertus.cpp b/examples/talk-llama/models/apertus.cpp index 9af19c1bfe8..6dfb8905fbe 100644 --- a/examples/talk-llama/models/apertus.cpp +++ b/examples/talk-llama/models/apertus.cpp @@ -1,12 +1,67 @@ #include "models.h" +void llama_model_apertus::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer()); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer()); -llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_apertus::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // Q and K layernorms for Apertus + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_apertus::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_apertus::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -32,25 +87,15 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -62,7 +107,7 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur_pos", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -116,7 +161,7 @@ llm_build_apertus::llm_build_apertus(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arcee.cpp b/examples/talk-llama/models/arcee.cpp index aa6167dba1e..9536e7c5d42 100644 --- a/examples/talk-llama/models/arcee.cpp +++ b/examples/talk-llama/models/arcee.cpp @@ -1,11 +1,55 @@ #include "models.h" +void llama_model_arcee::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); -llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + // Arcee uses the same structure as Llama + switch (hparams.n_layer()) { + case 36: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_arcee::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_arcee::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_arcee::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -36,30 +80,8 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -78,7 +100,7 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -126,7 +148,7 @@ llm_build_arcee::llm_build_arcee(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arctic.cpp b/examples/talk-llama/models/arctic.cpp index e8f028a723e..09ee0f752f0 100644 --- a/examples/talk-llama/models/arctic.cpp +++ b/examples/talk-llama/models/arctic.cpp @@ -1,11 +1,63 @@ #include "models.h" +void llama_model_arctic::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); -llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + if (hparams.n_expert == 128) { + switch (hparams.n_layer()) { + case 35: type = LLM_TYPE_10B_128x3_66B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } else { + type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_arctic::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_norm_exps = create_tensor(tn(LLM_TENSOR_FFN_NORM_EXPS, "weight", i), {n_embd}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, false); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_arctic::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_arctic::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -31,18 +83,8 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -61,7 +103,7 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -104,7 +146,7 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); @@ -129,7 +171,7 @@ llm_build_arctic::llm_build_arctic(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/arwkv7.cpp b/examples/talk-llama/models/arwkv7.cpp index 107a3bef8da..b38b2064785 100644 --- a/examples/talk-llama/models/arwkv7.cpp +++ b/examples/talk-llama/models/arwkv7.cpp @@ -1,7 +1,123 @@ #include "models.h" +void llama_model_arwkv7::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer()) { + case 12: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_190M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_450M; break; + case 2048: type = LLM_TYPE_1_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 28: + switch (hparams.n_embd) { + case 1536: type = LLM_TYPE_1_5B; break; + case 3584: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_2_9B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: + switch (hparams.n_embd) { + case 4096: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_arwkv7::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + if (i == 0) { + // actually not used + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); + } else { + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + } + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, TENSOR_NOT_REQUIRED); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, TENSOR_NOT_REQUIRED); + + try { + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); + } catch(std::runtime_error & e) { + // ARWKV models may not have gate tensors + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); + } + + layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + +} + +std::unique_ptr<llm_graph_context> llama_model_arwkv7::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} -llm_build_arwkv7::llm_build_arwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { +llama_model_arwkv7::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { GGML_ASSERT(n_embd == hparams.n_embd_r()); ggml_tensor * cur; @@ -77,7 +193,7 @@ llm_build_arwkv7::llm_build_arwkv7(const llama_model & model, const llm_graph_pa cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/baichuan.cpp b/examples/talk-llama/models/baichuan.cpp index c04b0c98b0b..585f3614174 100644 --- a/examples/talk-llama/models/baichuan.cpp +++ b/examples/talk-llama/models/baichuan.cpp @@ -1,11 +1,53 @@ #include "models.h" +void llama_model_baichuan::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + if (type == LLM_TYPE_13B) { + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; + } +} + +void llama_model_baichuan::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); -llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_baichuan::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_baichuan::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -29,18 +71,8 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); switch (model.type) { case LLM_TYPE_7B: @@ -56,6 +88,7 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap ); break; case LLM_TYPE_13B: + case LLM_TYPE_UNKNOWN: break; default: GGML_ABORT("fatal error"); @@ -66,7 +99,7 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -113,7 +146,7 @@ llm_build_baichuan::llm_build_baichuan(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe.cpp b/examples/talk-llama/models/bailingmoe.cpp index ed56b9c4713..7faf73c835b 100644 --- a/examples/talk-llama/models/bailingmoe.cpp +++ b/examples/talk-llama/models/bailingmoe.cpp @@ -1,7 +1,65 @@ #include "models.h" +void llama_model_bailingmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + switch (hparams.n_layer()) { + case 28: type = LLM_TYPE_16B; break; + case 88: type = LLM_TYPE_290B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bailingmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } -llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_bailingmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_bailingmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -29,30 +87,8 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head_k, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -71,7 +107,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } @@ -97,7 +133,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ nullptr, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - false, hparams.expert_weights_scale, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); @@ -135,7 +171,7 @@ llm_build_bailingmoe::llm_build_bailingmoe(const llama_model & model, const llm_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bailingmoe2.cpp b/examples/talk-llama/models/bailingmoe2.cpp index fbf7b210c42..5000e9c6db8 100644 --- a/examples/talk-llama/models/bailingmoe2.cpp +++ b/examples/talk-llama/models/bailingmoe2.cpp @@ -1,13 +1,100 @@ #include "models.h" +void llama_model_bailingmoe2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + switch (hparams.n_layer()) { + case 20: type = LLM_TYPE_16B_A1B; break; + case 32: type = LLM_TYPE_100B_A6B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bailingmoe2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for bailingmoe2"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for bailingmoe2"); + + for (int i = 0; i < n_layer_all; ++i) { + int flags = 0; + if (i >= n_layer) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); -llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params) : + if (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers + const int64_t n_ff_shexp = (hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff_exp) * n_expert_shared; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } else { // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (i >= n_layer) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED | flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED | flags); + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, flags); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_bailingmoe2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_bailingmoe2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -21,8 +108,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll ggml_tensor * inp_out_ids = build_inp_out_ids(); - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -31,15 +117,8 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll // self_attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 0 * sizeof(float) * (n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -58,11 +137,11 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -90,7 +169,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); @@ -126,7 +205,7 @@ llm_build_bailingmoe2::llm_build_bailingmoe2(const llama_model & model, const ll res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/bert.cpp b/examples/talk-llama/models/bert.cpp index bca0e254fc5..53ce29f23ca 100644 --- a/examples/talk-llama/models/bert.cpp +++ b/examples/talk-llama/models/bert.cpp @@ -1,12 +1,86 @@ #include "models.h" +void llama_model_bert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer()) { + case 3: + type = LLM_TYPE_17M; break; // bge-micro + case 6: + type = LLM_TYPE_22M; break; // MiniLM-L6 + case 12: + switch (hparams.n_embd) { + case 384: type = LLM_TYPE_33M; break; // MiniLM-L12, bge-small + case 768: type = LLM_TYPE_109M; break; // bge-base + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + type = LLM_TYPE_335M; break; // bge-large + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); -llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_bert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,8 +104,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(inpL, "inp_embd", -1); // embed layer norm - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); + cb(inpL, "inp_norm", 0); auto * inp_attn = build_attn_inp_no_cache(); @@ -41,35 +115,8 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params ggml_tensor * cur = inpL; { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - // self-attention - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], - 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, cur), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, cur), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, cur), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head * n_head, n_tokens); @@ -102,7 +149,7 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } @@ -129,9 +176,17 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params // feed-forward network if (hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1) { // MoE branch - cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, nullptr, - model.layers[il].ffn_down_exps, nullptr, hparams.n_expert, hparams.n_expert_used, - LLM_FFN_GELU, false, false, 0.0f, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + nullptr, + model.layers[il].ffn_down_exps, + nullptr, + hparams.n_expert, hparams.n_expert_used, + LLM_FFN_GELU, false, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); cb(cur, "ffn_moe_out", il); } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE || model.arch == LLM_ARCH_JINA_BERT_V3) { diff --git a/examples/talk-llama/models/bitnet.cpp b/examples/talk-llama/models/bitnet.cpp index 331a3f11197..c8330274580 100644 --- a/examples/talk-llama/models/bitnet.cpp +++ b/examples/talk-llama/models/bitnet.cpp @@ -1,10 +1,57 @@ #include "models.h" +void llama_model_bitnet::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); -llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + switch (hparams.n_layer()) { + case 26: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_bitnet::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_sub_norm = create_tensor(tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wq_s = create_tensor(tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wk_s = create_tensor(tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv_s = create_tensor(tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_s = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, TENSOR_NOT_REQUIRED); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_sub_norm = create_tensor(tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_s = create_tensor(tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_s = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_s = create_tensor(tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_bitnet::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_bitnet::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -28,42 +75,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa // self-attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].wq_scale) { - Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); - } - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - // B1.K - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].wk_scale) { - Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); - } - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - // B1.V - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].wv_scale) { - Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); - } - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -82,7 +95,7 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - NULL, NULL, + NULL, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cur = build_norm(cur, @@ -90,12 +103,9 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa LLM_NORM_RMS, il); cb(cur, "attn_sub_norm", il); - cur = build_lora_mm(model.layers[il].wo, cur); - if (model.layers[il].wo_scale) { - cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); - } - if (model.layers[il].bo) { - cur = ggml_add(ctx0, cur, model.layers[il].bo); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); + if (model.layers[il].wo_b) { + cur = ggml_add(ctx0, cur, model.layers[il].wo_b); } cb(cur, "attn_out", il); } @@ -115,8 +125,8 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_scale, - model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_scale, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, NULL, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); @@ -127,15 +137,15 @@ llm_build_bitnet::llm_build_bitnet(const llama_model & model, const llm_graph_pa LLM_NORM_RMS, il); cb(cur, "ffn_sub_norm", il); - cur = build_lora_mm(model.layers[il].ffn_down, cur); - if (model.layers[il].ffn_down_scale) { - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); - } + cur = build_lora_mm(model.layers[il].ffn_down, cur, model.layers[il].ffn_down_s); cb(cur, "ffn_down", il); cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "l_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // input for next layer inpL = cur; } diff --git a/examples/talk-llama/models/bloom.cpp b/examples/talk-llama/models/bloom.cpp index 2c552d1d15e..609d2ddf998 100644 --- a/examples/talk-llama/models/bloom.cpp +++ b/examples/talk-llama/models/bloom.cpp @@ -1,10 +1,71 @@ #include "models.h" -llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); +void llama_model_bloom::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_1B; break; + case 30: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; +} + +void llama_model_bloom::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_bloom::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_bloom::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -16,8 +77,8 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, - LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + LLM_NORM, 0); + cb(inpL, "inp_norm", 0); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -30,22 +91,11 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -92,7 +142,7 @@ llm_build_bloom::llm_build_bloom(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/chameleon.cpp b/examples/talk-llama/models/chameleon.cpp index 184511aed4c..4f45acecf84 100644 --- a/examples/talk-llama/models/chameleon.cpp +++ b/examples/talk-llama/models/chameleon.cpp @@ -1,12 +1,60 @@ #include "models.h" - #include <float.h> -llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_chameleon::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-5; // eps for qk-norm, torch default + ml.get_key(LLM_KV_SWIN_NORM, hparams.swin_norm, false); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_34B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_chameleon::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_chameleon::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_chameleon::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -36,22 +84,10 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { - Qcur = ggml_view_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens, - ggml_element_size(Qcur) * n_embd_head, - ggml_element_size(Qcur) * n_embd_head * n_head, - 0); - cb(Qcur, "Qcur", il); - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -60,12 +96,6 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr } if (model.layers[il].attn_k_norm) { - Kcur = ggml_view_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens, - ggml_element_size(Kcur) * n_embd_head, - ggml_element_size(Kcur) * n_embd_head * n_head_kv, - 0); - cb(Kcur, "Kcur", il); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, @@ -73,10 +103,6 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr cb(Kcur, "Kcur", il); } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -94,7 +120,7 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -155,7 +181,7 @@ llm_build_chameleon::llm_build_chameleon(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_with_img_logits", -1); // TODO: this suppresses the output of image tokens, which is required to enable text-only outputs. diff --git a/examples/talk-llama/models/chatglm.cpp b/examples/talk-llama/models/chatglm.cpp index 2685d4fbcbe..7ae5b938fde 100644 --- a/examples/talk-llama/models/chatglm.cpp +++ b/examples/talk-llama/models/chatglm.cpp @@ -1,11 +1,64 @@ #include "models.h" +void llama_model_chatglm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); -llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + switch (hparams.n_layer()) { + case 28: { + if (hparams.n_head(0) == 16) { + type = LLM_TYPE_1_5B; + } else { + type = LLM_TYPE_6B; + } + } break; + case 40: { + if (hparams.n_head(0) == 24) { + type = LLM_TYPE_4B; + } else { + type = LLM_TYPE_9B; + } + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_chatglm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_chatglm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_chatglm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,37 +83,8 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv == nullptr) { - Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); //printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor); Qcur = ggml_rope_ext( @@ -80,7 +104,7 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -111,8 +135,13 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = build_norm(inpL, @@ -123,7 +152,7 @@ llm_build_chatglm::llm_build_chatglm(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/codeshell.cpp b/examples/talk-llama/models/codeshell.cpp index 0b3bdbff529..de53bb98184 100644 --- a/examples/talk-llama/models/codeshell.cpp +++ b/examples/talk-llama/models/codeshell.cpp @@ -1,11 +1,60 @@ #include "models.h" -llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); +void llama_model_codeshell::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + switch (hparams.n_layer()) { + case 42: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_codeshell::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if tok embd is NULL, init from output + if (tok_embd == NULL) { + tok_embd = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_codeshell::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_codeshell::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -28,15 +77,8 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -55,7 +97,7 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -102,7 +144,7 @@ llm_build_codeshell::llm_build_codeshell(const llama_model & model, const llm_gr cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/cogvlm.cpp b/examples/talk-llama/models/cogvlm.cpp index 0ceae3aaeb5..750f57a394e 100644 --- a/examples/talk-llama/models/cogvlm.cpp +++ b/examples/talk-llama/models/cogvlm.cpp @@ -1,12 +1,62 @@ #include "models.h" -llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : +void llama_model_cogvlm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_cogvlm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.visexp_attn_wqkv = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_QKV, "weight", i), {n_embd, n_embd_head_k * n_head * 3}, 0); + layer.visexp_attn_wo = create_tensor(tn(LLM_TENSOR_VISEXP_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.visexp_ffn_gate = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.visexp_ffn_down = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.visexp_ffn_up = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_cogvlm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_cogvlm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * inpL; ggml_tensor * cur; @@ -28,18 +78,20 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa for (int il = 0; il < n_layer; ++il) { // get either the text or image weight tensors - ggml_tensor *wqkv, *wo; + ggml_tensor *wqkv, *wo, *wo_s; ggml_tensor *ffn_gate, *ffn_down, *ffn_up; if (is_text) { wqkv = model.layers[il].wqkv; wo = model.layers[il].wo; + wo_s = model.layers[il].wo_s; ffn_gate = model.layers[il].ffn_gate; ffn_down = model.layers[il].ffn_down; ffn_up = model.layers[il].ffn_up; } else { wqkv = model.layers[il].visexp_attn_wqkv; wo = model.layers[il].visexp_attn_wo; + wo_s = nullptr; ffn_gate = model.layers[il].visexp_ffn_gate; ffn_down = model.layers[il].visexp_ffn_down; ffn_up = model.layers[il].visexp_ffn_up; @@ -64,7 +116,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa Kcur = ggml_rope(ctx0, Kcur, inp_pos, n_embd_head, rope_type); cur = build_attn(inp_attn, - wo, nullptr, + wo, nullptr, wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); @@ -86,6 +138,10 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "ffn_out", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer inpL = cur; } @@ -95,7 +151,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); diff --git a/examples/talk-llama/models/cohere2-iswa.cpp b/examples/talk-llama/models/cohere2.cpp similarity index 53% rename from examples/talk-llama/models/cohere2-iswa.cpp rename to examples/talk-llama/models/cohere2.cpp index 9334b5e4263..61a5945a194 100644 --- a/examples/talk-llama/models/cohere2-iswa.cpp +++ b/examples/talk-llama/models/cohere2.cpp @@ -1,9 +1,58 @@ #include "models.h" -llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_cohere2::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_cohere2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, + TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_cohere2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_cohere2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); const float f_logit_scale = hparams.f_logit_scale; @@ -36,30 +85,8 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (is_swa) { Qcur = ggml_rope_ext( @@ -80,7 +107,7 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -121,7 +148,7 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/command-r.cpp b/examples/talk-llama/models/command-r.cpp index 4d3b643b444..94a46188bb8 100644 --- a/examples/talk-llama/models/command-r.cpp +++ b/examples/talk-llama/models/command-r.cpp @@ -1,12 +1,53 @@ #include "models.h" +void llama_model_command_r::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + switch (hparams.n_layer()) { + case 40: type = LLM_TYPE_35B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_command_r::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; -llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_graph_params & params) : + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (n_layer >= 64){ + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, 0); + } + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_command_r::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_command_r::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); const float f_logit_scale = hparams.f_logit_scale; @@ -32,27 +73,8 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM, il); @@ -73,7 +95,7 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -110,7 +132,7 @@ llm_build_command_r::llm_build_command_r(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (f_logit_scale) { cur = ggml_scale(ctx0, cur, f_logit_scale); diff --git a/examples/talk-llama/models/dbrx.cpp b/examples/talk-llama/models/dbrx.cpp index 6d2a0ebf1b7..4f5ac4d06a4 100644 --- a/examples/talk-llama/models/dbrx.cpp +++ b/examples/talk-llama/models/dbrx.cpp @@ -1,12 +1,54 @@ #include "models.h" +void llama_model_dbrx::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv); -llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + switch (hparams.n_layer()) { + case 40: type = LLM_TYPE_16x12B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_dbrx::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error("DBRX model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_dbrx::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_dbrx::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -31,19 +73,8 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(cur, "wqkv_clamped", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -62,7 +93,7 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -89,7 +120,7 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); @@ -114,7 +145,7 @@ llm_build_dbrx::llm_build_dbrx(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deci.cpp b/examples/talk-llama/models/deci.cpp index 7410a3a46d9..cdfcf29e02f 100644 --- a/examples/talk-llama/models/deci.cpp +++ b/examples/talk-llama/models/deci.cpp @@ -1,12 +1,87 @@ #include "models.h" +void llama_model_deci::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + case 162: type = LLM_TYPE_405B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deci::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + const int64_t n_ff = hparams.n_ff(i); + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_kv = hparams.n_head_kv(i); + + if (n_head_kv == 0 && n_head > 0) { + // linear attention for DeciLMCausalModel + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + else if (n_head_kv > 0) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + } + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (n_ff > 0) { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } -llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + if (n_ff > 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_deci::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_deci::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -47,27 +122,8 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -80,7 +136,7 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -126,7 +182,7 @@ llm_build_deci::llm_build_deci(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deepseek.cpp b/examples/talk-llama/models/deepseek.cpp index 17866c0d88e..f52ec9518b6 100644 --- a/examples/talk-llama/models/deepseek.cpp +++ b/examples/talk-llama/models/deepseek.cpp @@ -1,13 +1,82 @@ #include "models.h" +void llama_model_deepseek::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + + switch (hparams.n_ff_exp) { + case 1408: type = LLM_TYPE_16B; break; + case 1792: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} +std::unique_ptr<llm_graph_context> llama_model_deepseek::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} -llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_graph_params & params) : +llama_model_deepseek::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -37,27 +106,8 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -70,7 +120,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -100,7 +150,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, hparams.expert_weights_scale, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); @@ -135,7 +185,7 @@ llm_build_deepseek::llm_build_deepseek(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/deepseek2.cpp b/examples/talk-llama/models/deepseek2.cpp index ca63a62ad1b..a9e8bc51403 100644 --- a/examples/talk-llama/models/deepseek2.cpp +++ b/examples/talk-llama/models/deepseek2.cpp @@ -1,23 +1,166 @@ #include "models.h" -llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : +void llama_model_deepseek2::load_arch_hparams(llama_model_loader & ml) { + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B + const bool is_lite = (hparams.n_layer() == 27 || hparams.n_layer() == 26 || (hparams.n_layer() == 48 && n_vocab == 128256)); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + if (!is_lite) { + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + } + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + // for compatibility with existing DeepSeek V2 and V2.5 GGUFs + // that have no expert_gating_func model parameter set + if ((hparams.n_layer() == 47 || hparams.n_layer() == 48) && n_vocab == 154880) { + // GLM 4.7 Lite + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } else { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + } + + if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false)) { + // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + // cancel the factor from the convert script + hparams.rope_yarn_log_mul /= 0.1f; + } + + // (optional) temperature tuning - used by mistral-large + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false); // FIXME why not use temperature_length? + + hparams.f_attn_temp_offset = 0.0f; + + switch (hparams.n_layer()) { + case 27: type = LLM_TYPE_16B; break; + case 47: type = LLM_TYPE_30B_A3B; break; + case 60: type = LLM_TYPE_236B; break; + case 61: type = LLM_TYPE_671B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const bool is_mla = hparams.is_mla(); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + GGML_ASSERT(n_embd_head_qk_nope >= 1); + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + if (q_lora_rank > 0) { + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + } + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (q_lora_rank > 0) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + if (is_mla) { + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } else { + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v_mla)}, 0); + } + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_deepseek2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_deepseek2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + bool is_ocr = model.arch == LLM_ARCH_DEEPSEEK2OCR; - const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); + const bool is_mla = hparams.is_mla(); // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA - const int64_t n_embd_head_k = is_mla ? hparams.n_embd_head_k_mla : hparams.n_embd_head_k; - const int64_t n_embd_head_v = is_mla ? hparams.n_embd_head_v_mla : hparams.n_embd_head_v; + const int64_t n_embd_head_k = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v = hparams.n_embd_head_v_mla(); - const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_rope = hparams.n_rot(); const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; const uint32_t kv_lora_rank = hparams.n_lora_kv; // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. - // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation. // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor @@ -43,7 +186,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv(); + auto * inp_attn_kv = !is_mla ? build_attn_inp_kv() : nullptr; + auto * inp_attn_k = is_mla ? build_attn_inp_k() : nullptr; ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -55,8 +199,42 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self_attention - { + if (is_ocr) { + const int n_embed_head = hparams.n_embd / hparams.n_head(); + const int ocr_rope_type = GGML_ROPE_TYPE_NEOX; + GGML_ASSERT(n_embed_head == n_embd_head_k && n_embed_head == n_embd_head_v); + + ggml_tensor * Qcur = NULL; + ggml_tensor * Kcur = NULL; + ggml_tensor * Vcur = NULL; + + Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur); + Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur); + Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); + cb(Qcur, "q", il); + cb(Kcur, "k", il); + cb(Vcur, "v", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embed_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embed_head, n_head, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embed_head, n_head, n_tokens); + + GGML_ASSERT(fabs(freq_base - 10000.0) < 1e-4); + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_embed_head, ocr_rope_type, 0, freq_base, 1, 0, 1, 0, 0); + cb(Qcur, "q_pe", il); + cb(Kcur, "k_pe", il); + + cur = build_attn(inp_attn_kv, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + else { ggml_tensor * q = NULL; + + const bool is_lite = model.layers[il].wq; + if (!is_lite) { q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); cb(q, "q", il); @@ -124,14 +302,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} // note: rope must go first for in-place context shifting in build_rope_shift() - ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0); + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); cb(Qcur, "Qcur", il); kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); cb(kv_cmpr, "kv_cmpr_reshape", il); // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} - ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0); + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); cb(Kcur, "Kcur", il); // {kv_lora_rank, 1, n_tokens} @@ -144,9 +322,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(Qcur, "Qcur_attn_temp_scaled", il); } - // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn_k, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); } else { ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_cmpr); @@ -169,11 +347,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Vcur = ggml_cont(ctx0, Vcur); cb(Vcur, "Vcur_cont", il); - // note: rope must go first for in-place context shifting in build_rope_shift() - ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0); + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); + ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(Kcur, "Kcur", il); if (inp_attn_scale) { @@ -183,8 +360,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr } // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + cur = build_attn(inp_attn_kv, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } @@ -215,9 +392,11 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - hparams.expert_weights_scale, hparams.expert_weights_scale, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, - il); + il, + nullptr, + model.layers[il].ffn_gate_up_exps); cb(moe_out, "ffn_moe_out", il); // FFN shared expert diff --git a/examples/talk-llama/models/deepseek2ocr.cpp b/examples/talk-llama/models/deepseek2ocr.cpp new file mode 100644 index 00000000000..65d31c31b93 --- /dev/null +++ b/examples/talk-llama/models/deepseek2ocr.cpp @@ -0,0 +1,82 @@ +#include "models.h" + +void llama_model_deepseek2ocr::load_arch_hparams(llama_model_loader & ml) { + // similar to deepseek2, but without MLA + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; + } + + switch (hparams.n_layer()) { + case 12: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek2ocr::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + // similar to deepseek2, but without MLA + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // norm + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_deepseek2ocr::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/deepseek32.cpp b/examples/talk-llama/models/deepseek32.cpp new file mode 100644 index 00000000000..9a20e2ce907 --- /dev/null +++ b/examples/talk-llama/models/deepseek32.cpp @@ -0,0 +1,499 @@ +#include "models.h" + +#include "llama-kv-cache.h" +#include "llama-kv-cache-dsa.h" + +void llama_model_deepseek32::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.f_norm_eps = 1e-6; // eps for layer norm + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + if (ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, 0.0f)) { + // [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + // cancel the factor from the convert script + hparams.rope_yarn_log_mul /= 0.1f; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer"); + + switch (hparams.n_layer()) { + case 62: type = LLM_TYPE_685B_A37B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_deepseek32::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("DEEPSEEK32 architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer_all; ++i) { + int flags = 0; + if (i >= n_layer) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (i >= n_layer) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_deepseek32::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_deepseek32::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const bool is_mla = hparams.is_mla(); + GGML_ASSERT(is_mla); + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v = hparams.n_embd_head_v_mla(); + GGML_UNUSED(n_embd_head_v); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope; + + const int64_t n_indexer_head = hparams.indexer_n_head; + const int64_t n_embd_indexer_head = hparams.indexer_head_size; + const int64_t n_embd_indexer_head_rope = hparams.n_rot(); + const int64_t n_embd_indexer_head_nope = n_embd_indexer_head - n_embd_indexer_head_rope; + const uint32_t n_indexer_top_k = hparams.indexer_top_k; + + const uint32_t kv_lora_rank = hparams.n_lora_kv; + + // We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly. + // See https://github.com/ggml-org/llama.cpp/discussions/7416 for detailed explanation. + // And also: https://github.com/ggml-org/llama.cpp/pull/17945 [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] + + // first cancel the adjustment from llama_hparams::yarn_attn_factor_adjust to get the original attn_factor + GGML_ASSERT(ext_factor >= 0.0f); + const float attn_factor_org = attn_factor * (1.0f + 0.1f * logf(1.0f / freq_scale)); + + // use the original attn_factor to pre-scale the kq_scale + const float mscale = attn_factor_org * (1.0f + 0.1f * hparams.rope_yarn_log_mul * logf(1.0f / freq_scale)); + const float kq_scale = 1.0f * mscale * mscale / sqrtf(float(n_embd_head_k)); + + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + llm_graph_input_attn_k_dsa * inp_attn_dsa = build_attn_inp_k_dsa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + ggml_tensor * qr = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur); + cb(qr, "qr", il); + + qr = build_norm(qr, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il); + cb(qr, "qr", il); + + ggml_tensor * top_k = nullptr; + + // lightning indexer + { + ggml_tensor * indexer_q = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_q_b, qr); + cb(indexer_q, "indexer_q", il); + + // split into {n_embd_indexer_head_rope, n_indexer_head, n_tokens} + ggml_tensor * indexer_q_pe = + ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_rope, n_indexer_head, n_tokens, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head) * n_indexer_head, 0); + cb(indexer_q_pe, "indexer_q_pe", il); + + // and {n_embd_indexer_head_nope, n_indexer_head, n_tokens} + ggml_tensor * indexer_q_nope = + ggml_view_3d(ctx0, indexer_q, n_embd_indexer_head_nope, n_indexer_head, n_tokens, + ggml_row_size(indexer_q->type, n_embd_indexer_head), + ggml_row_size(indexer_q->type, n_embd_indexer_head) * n_indexer_head, + ggml_row_size(indexer_q->type, n_embd_indexer_head_nope)); + cb(indexer_q_nope, "indexer_q_nope", il); + + indexer_q_pe = ggml_rope_ext(ctx0, indexer_q_pe, inp_pos, nullptr, n_rot, + LLAMA_ROPE_TYPE_NEOX, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(indexer_q_pe, "indexer_q_pe", il); + + // {n_embd_indexer_head_rope + n_embd_indexer_head_nope, n_head, n_tokens} + indexer_q = ggml_concat(ctx0, indexer_q_pe, indexer_q_nope, 0); + cb(indexer_q, "indexer_q", il); + + ggml_tensor * indexer_k = ggml_mul_mat(ctx0, model.layers[il].indexer_attn_k, cur); + cb(indexer_k, "indexer_k", il); + + indexer_k = build_norm(indexer_k, model.layers[il].indexer_k_norm, model.layers[il].indexer_k_norm_b, LLM_NORM, il); + cb(indexer_k, "indexer_k", il); + + // split into {n_embd_indexer_head_rope, 1, n_tokens} + ggml_tensor * indexer_k_pe = + ggml_view_3d(ctx0, indexer_k, n_embd_indexer_head_rope, 1, n_tokens, + ggml_row_size(indexer_k->type, n_embd_indexer_head), + ggml_row_size(indexer_k->type, n_embd_indexer_head) * 1, 0); + cb(indexer_k_pe, "indexer_k_pe", il); + + // and {n_embd_indexer_head_nope, 1, n_tokens} + ggml_tensor * indexer_k_nope = + ggml_view_3d(ctx0, indexer_k, n_embd_indexer_head_nope, 1, n_tokens, + ggml_row_size(indexer_k->type, n_embd_indexer_head), + ggml_row_size(indexer_k->type, n_embd_indexer_head) * 1, + ggml_row_size(indexer_k->type, n_embd_indexer_head_nope)); + cb(indexer_k_nope, "indexer_k_nope", il); + + indexer_k_pe = ggml_rope_ext(ctx0, indexer_k_pe, inp_pos, nullptr, n_rot, + LLAMA_ROPE_TYPE_NEOX, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(indexer_k_pe, "indexer_k_pe", il); + + // {n_embd_indexer_head_rope + n_embd_indexer_head_nope, 1, n_tokens} + indexer_k = ggml_concat(ctx0, indexer_k_pe, indexer_k_nope, 0); + cb(indexer_k, "indexer_k", il); + + // perform Hadamard transform on indexer q and k + indexer_q = ggml_mul_mat(ctx0, inp_attn_dsa->self_k_rot_lid, indexer_q); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_mul_mat(ctx0, inp_attn_dsa->self_k_rot_lid, indexer_k); + cb(indexer_k, "indexer_k", il); + + // store indexer keys to KV cache + const auto * mctx_lid = inp_attn_dsa->mctx->get_lid(); + const auto & k_idxs_lid = inp_attn_dsa->get_k_idxs_lid(); + ggml_build_forward_expand(gf, mctx_lid->cpy_k(ctx0, indexer_k, k_idxs_lid, il)); + + // prepare indexer weights + ggml_tensor * indexer_weights = ggml_mul_mat(ctx0, model.layers[il].indexer_proj, cur); + cb(indexer_weights, "indexer_weights", il); + + // get cached indexer keys + indexer_k = mctx_lid->get_k(ctx0, il); + + // split the batch into streams if needed + const auto n_stream = indexer_k->ne[3]; + indexer_q = ggml_view_4d(ctx0, indexer_q, indexer_q->ne[0], indexer_q->ne[1], indexer_q->ne[2]/n_stream, n_stream, indexer_q->nb[1], indexer_q->nb[2], indexer_q->nb[3]/n_stream, 0); + indexer_weights = ggml_view_4d(ctx0, indexer_weights, indexer_weights->ne[0], indexer_weights->ne[1]/n_stream, indexer_weights->ne[2], n_stream, indexer_weights->nb[1], indexer_weights->nb[2]/n_stream, indexer_weights->nb[3]/n_stream, 0); + + // calculate indexer kq + indexer_q = ggml_permute(ctx0, indexer_q, 0, 2, 1, 3); + cb(indexer_q, "indexer_q", il); + indexer_k = ggml_permute(ctx0, indexer_k, 0, 2, 1, 3); + cb(indexer_k, "indexer_k", il); + + ggml_tensor * indexer_kq = ggml_mul_mat(ctx0, indexer_k, indexer_q); + cb(indexer_kq, "indexer_kq", il); + + // ReLU requires contiguous tensors + indexer_kq = ggml_cont(ctx0, ggml_permute(ctx0, indexer_kq, 2, 1, 0, 3)); + cb(indexer_kq, "indexer_kq", il); + + // apply ReLU + ggml_tensor * indexer_score = ggml_relu(ctx0, indexer_kq); + cb(indexer_score, "indexer_score", il); + + // pre-scale weights to avoid scaling operations on huge indexer_score tensor + indexer_weights = ggml_scale(ctx0, indexer_weights, 1.0f / sqrtf(float(n_embd_indexer_head * n_indexer_head))); + cb(indexer_weights, "indexer_weights", il); + + // multiply scores by indexer weights + indexer_score = ggml_mul(ctx0, indexer_score, indexer_weights); + cb(indexer_score, "indexer_score", il); + + // sum by q n_indexer_head dimension + indexer_score = ggml_sum_rows(ctx0, indexer_score); + cb(indexer_score, "indexer_score", il); + + // permute result to match KQ mask + indexer_score = ggml_cont(ctx0, ggml_permute(ctx0, indexer_score, 2, 1, 0, 3)); + cb(indexer_score, "indexer_score", il); + + // mask indexer scores + ggml_tensor * indexer_kq_mask = inp_attn_dsa->get_kq_mask_lid(); + indexer_score = ggml_add(ctx0, indexer_score, indexer_kq_mask); + cb(indexer_score, "indexer_score", il); + + // get indices of top k indexer scores + uint32_t n_top_k = indexer_score->ne[0] < n_indexer_top_k ? indexer_score->ne[0] : n_indexer_top_k; + top_k = ggml_cont(ctx0, ggml_top_k(ctx0, indexer_score, n_top_k)); + cb(top_k, "top_k", il); + } + + ggml_tensor * q = ggml_mul_mat(ctx0, model.layers[il].wq_b, qr); + cb(q, "q", il); + + // split into {n_embd_head_qk_nope, n_head, n_tokens} + ggml_tensor * q_nope = + ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, 0); + cb(q_nope, "q_nope", il); + + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d( + ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k), + ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur); + cb(kv_cmpr_pe, "kv_cmpr_pe", il); + + // split into {kv_lora_rank, n_tokens} + ggml_tensor * kv_cmpr = + ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); + cb(kv_cmpr, "kv_cmpr", il); + + // and {n_embd_head_qk_rope, 1, n_tokens} + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); + cb(k_pe, "k_pe", il); + + q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(q_pe, "q_pe", il); + + k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(k_pe, "k_pe", il); + + kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); + cb(kv_cmpr, "kv_cmpr", il); + + // MLA attention + { + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + ggml_tensor * Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) + cur = build_attn(inp_attn_dsa, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, top_k, kq_scale, il); + } + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il, + nullptr, + model.layers[il].ffn_gate_up_exps, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/delta-net-base.cpp b/examples/talk-llama/models/delta-net-base.cpp new file mode 100644 index 00000000000..ad9ce771408 --- /dev/null +++ b/examples/talk-llama/models/delta-net-base.cpp @@ -0,0 +1,606 @@ +#include "models.h" + +#include "llama-impl.h" +#include "llama-memory-recurrent.h" + +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {} + +std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k); + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs] + b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [ 1, n_tokens, H_v, n_seqs] + + const int CS = kda ? 16 : 64; // chunk size + + const int pad = (CS - n_tokens % CS) % CS; + const int n_chunks = (n_tokens + pad) / CS; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, 0, pad, 0, 0); + b = ggml_pad(ctx0, b, 0, pad, 0, 0); + + ggml_tensor * v_b = ggml_mul(ctx0, v, b); + ggml_tensor * k_b = ggml_mul(ctx0, k, b); + + cb(v_b, "v_b", il); + cb(k_b, "k_b", il); + + q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); + k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); + v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); + + // [CS, g_0, n_chunks, H_v * n_seqs] + // TODO: extend ggml_cumsum with axis parameter to avoid transpose + ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g))); + cb(g_cs, "g_cs", il); + + ggml_tensor * kb = nullptr; + ggml_tensor * kq = nullptr; + if (kda) { + const int64_t CHB = n_chunks * H_k * n_seqs; + + ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB] + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] + + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB] + + // decay_mask [chunk_size,chunk_size,S_k,CHB] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched + decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB); + + ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS, 1, CHB); + ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, CS, CHB); + ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, CS, 1, CHB); + + ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i); + ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i); + + // decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB] + kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j); + kq = ggml_mul_mat(ctx0, decay_q_i, k_j); + + kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs))); + kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs))); + } else { + ggml_tensor * g_cs_i = g_cs; + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); + + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); + + // [CS, CS, n_chunks, H_v * n_seqs] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // [CS, CS, n_chunks, H_k * n_seqs] + kb = ggml_mul_mat(ctx0, k, k_b); + kb = ggml_mul (ctx0, kb, decay_mask); + + // [CS, CS, n_chunks, H_k * n_seqs] + kq = ggml_mul_mat(ctx0, k, q); + kq = ggml_mul(ctx0, kq, decay_mask); + } + + kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); + cb(kq, "kq", il); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * attn; + attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); + cb(attn, "attn", il); + + ggml_tensor * identity; + identity = ggml_view_1d(ctx0, attn, CS, 0); + identity = ggml_fill (ctx0, identity, 1.0f); + identity = ggml_diag (ctx0, identity); + + ggml_tensor * lhs = ggml_add(ctx0, attn, identity); + cb(lhs, "dnet_add_ch_lhs", il); + + attn = ggml_neg(ctx0, attn); + cb(attn, "attn_pre_solve", il); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_add(ctx0, lin_solve, identity); + cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] + + // [S_v, CS, n_chunks, H_v * n_seqs] + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); + + // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); + + k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); + + // [CS, S_k, n_chunks, H_k * n_seqs] + ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); + cb(kbg, "k_beta_g_exp", il); + + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); + cb(k_cd, "k_cumdecay", il); + + // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp)); + ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along CS dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3], + g_cs->nb[1], + g_cs->nb[2], + g_cs->nb[3], + ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); + cb(g_last, "g_last", il); + + // TODO: remove this cont when CUDA supports non-cont unary ops + g_last = ggml_cont(ctx0, g_last); + + // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last)); + cb(g_last_exp_t, "g_last_exp_t", il); + + // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff))); + + // [S_k, CS, n_chunks, H_v * n_seqs] + ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); + cb(kg, "key_gdiff", il); + + // [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); + cb(kg_t, "key_gdiff_t", il); + + s = ggml_reshape_4d(ctx0, s, S_v, S_v, 1, H_v * n_seqs); + cb(s, "dnet_add_ch_state", il); + + // [CS, S_v, n_chunks, H_v * n_seqs] + ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] + ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] + ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s); + cb(v_t_p, "v_prime", il); + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); + cb(v_t_new, "v_t_new", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); + cb(v_attn, "v_attn", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s, ch_q_g_exp); + cb(attn_inter, "attn_inter", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn); + cb(o_ch, "dnet_add_ch_attn_out", il); + + v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); + + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // TODO: head broadcast might not work here - probably will need a transpose + ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] + + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk); + + s = ggml_mul(ctx0, s, ch_g_last_exp_t); + s = ggml_add(ctx0, s, kgv); + cb(s, "dnet_add_ch_state", il); + } + + // truncate padded tokens + ggml_tensor * o = ggml_view_4d(ctx0, v, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(v->type, S_v), + ggml_row_size(v->type, S_v * CS * n_chunks), + ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); + o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_reshape_4d(ctx0, s, S_v, S_v, H_v, n_seqs); + cb(s, "output_state", il); + + return {o, s}; +} + +std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, // beta + ggml_tensor * s, // state + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + // GDA: [1, 1, H_v, n_seqs] + // KDA: [1, S_k, H_v, n_seqs] + g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); + + // [S_v, S_v, H_v, n_seqs] + g = ggml_exp(ctx0, g); + s = ggml_mul(ctx0, s, g); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * sk; + sk = ggml_mul (ctx0, s, k); + sk = ggml_sum_rows(ctx0, sk); + + // [S_v, 1, H_v, n_seqs] + ggml_tensor * d; + d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); + d = ggml_mul(ctx0, d, b); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * d_t; + d_t = ggml_transpose(ctx0, d); + + // [S_v, S_v, H_v, n_seqs] + ggml_tensor * kd; + k = ggml_repeat(ctx0, k, s); + kd = ggml_mul (ctx0, k, d_t); + + s = ggml_add(ctx0, s, kd); + + cb(s, "dnet_add_ar_state", il); + + ggml_tensor * s_q = ggml_mul (ctx0, s, q); + ggml_tensor * o = ggml_sum_rows(ctx0, s_q); + + o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + + return {o, s}; +} + +std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net_fused( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + // K=1: output carries the final state only. state s is 4D [S_v, S_v, H_v, n_seqs]. + ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, /*K=*/1); + if (n_tokens == 1) { + cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il); + } else { + cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il); + } + + ggml_tensor * output = ggml_view_4d(ctx0, result, + S_v, H_v, n_tokens, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens), 0); + + ggml_tensor * new_state = ggml_view_4d(ctx0, result, + S_v, S_v, H_v, n_seqs, + ggml_row_size(result->type, S_v), + ggml_row_size(result->type, S_v * S_v), + ggml_row_size(result->type, S_v * S_v * H_v), + ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs)); + + return {output, new_state}; +} + +std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_net( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t n_seq_tokens = q->ne[2]; + + if (n_seq_tokens == 1) { + if (cparams.fused_gdn_ar) { + return build_delta_net_fused(q, k, v, g, b, s, il); + } + return build_delta_net_autoregressive(q, k, v, g, b, s, il); + } + + if (cparams.fused_gdn_ch) { + return build_delta_net_fused(q, k, v, g, b, s, il); + } + + return build_delta_net_chunking(q, k, v, g, b, s, il); +} + +ggml_tensor * llm_build_delta_net_base::build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il) { + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + const auto mem_size = mctx_cur->get_size(); + + const int64_t n_seqs = ubatch.n_seqs; + + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + const int64_t row_count = (conv_kernel_size - 1) * conv_channels; + + const size_t row_size = ggml_row_size(conv_states_all->type, row_count); + + if (cparams.n_rs_seq == 0) { + const int64_t s_idx = conv_input->ne[0] - conv_states->ne[0]; + const int64_t s_slot = 0; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + cb(conv_state_last, "conv_state_last", il); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, conv_states_all, + row_count, n_seqs, conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + cb(conv_state_update, "conv_state_update", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } else { + // [TAG_RECURRENT_ROLLBACK_SPLITS] + // TODO: this logic incorrectly assumes that the last (n_rs_seq + 1) tokens of a sequence in a batch are + // inside the same ubatch. currently with `split_equal()` this is not correct + + const int64_t K = (int64_t) cparams.n_rs_seq + 1; + + for (int64_t t = 1; t <= K; ++t) { + const int64_t s_idx = std::max<int64_t>(0, conv_input->ne[0] - conv_states->ne[0] - K + t); + const int64_t s_slot = K - t; + + ggml_tensor * conv_state_last = + ggml_view_3d(ctx0, conv_input, + conv_kernel_size - 1, conv_channels, n_seqs, + conv_input->nb[1], conv_input->nb[2], + ggml_row_size(conv_input->type, s_idx)); + + ggml_tensor * conv_state_update = + ggml_view_2d(ctx0, + conv_states_all, row_count, n_seqs, + conv_states_all->nb[1], + (s_slot * mem_size + kv_head) * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_state_last, conv_state_update)); + } + } + + return conv_input; +} + +ggml_tensor * llm_build_delta_net_base::build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const auto * mctx_cur = inp->mctx; + const auto kv_head = mctx_cur->get_head(); + const uint32_t mem_size = mctx_cur->get_size(); + + const int64_t S_v = s->ne[0]; + const int64_t H_v = s->ne[2]; + const int64_t n_seqs = s->ne[3]; + const int64_t n_seq_tokens = q->ne[2]; + + const bool keep = cparams.n_rs_seq > 0; + + if (!keep) { + auto attn_out = build_delta_net(q, k, v, g, b, s, il); + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_2d(ctx0, ssm_states_all, hparams.n_embd_s(), n_seqs, ssm_states_all->nb[1], + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + return output; + } + + const int64_t D = S_v * S_v * H_v; + const int64_t K = cparams.n_rs_seq + 1; + + // state s is 4D [S_v, S_v, H_v, n_seqs]; K snapshot slots are written into the output. + ggml_tensor * gdn_out = ggml_gated_delta_net(ctx0, q, k, v, g, b, s, K); + if (n_seq_tokens > 1) { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_CH, il); + } else { + cb(gdn_out, LLAMA_TENSOR_NAME_FGDN_AR, il); + } + + const int64_t attn_score_elems = S_v * H_v * n_seq_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H_v * n_seqs; + + ggml_tensor * output = ggml_view_4d(ctx0, gdn_out, + S_v, H_v, n_seq_tokens, n_seqs, + ggml_row_size(gdn_out->type, S_v), + ggml_row_size(gdn_out->type, S_v * H_v), + ggml_row_size(gdn_out->type, S_v * H_v * n_seq_tokens), + 0); + cb(output, "attn_output", il); + + const size_t row_size = hparams.n_embd_s() * ggml_element_size(ssm_states_all); + + // op writes the last min(n_seq_tokens, K) snapshots; trailing slots are left unwritten + const int64_t n_written = std::min<int64_t>(n_seq_tokens, K); + + // write the produced snapshots into the recurrent cache (snapshot slot i -> rollback group i) + ggml_tensor * src = ggml_view_3d(ctx0, gdn_out, + D, n_seqs, n_written, + ggml_row_size(gdn_out->type, D), + ggml_row_size(gdn_out->type, state_size_per_snap), + ggml_row_size(gdn_out->type, attn_score_elems)); + + ggml_tensor * dst = ggml_view_3d(ctx0, ssm_states_all, + D, n_seqs, n_written, + ssm_states_all->nb[1], + (size_t) mem_size * row_size, + (size_t) kv_head * row_size); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, src, dst)); + + return output; +} diff --git a/examples/talk-llama/models/dots1.cpp b/examples/talk-llama/models/dots1.cpp index 09c36f82fe2..07d6ab1b7cd 100644 --- a/examples/talk-llama/models/dots1.cpp +++ b/examples/talk-llama/models/dots1.cpp @@ -1,13 +1,82 @@ #include "models.h" +void llama_model_dots1::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer()) { + case 62: type = LLM_TYPE_142B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_dots1::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_head_k * n_head, n_embd_head_k * n_head, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); + } + } +} +std::unique_ptr<llm_graph_context> llama_model_dots1::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} -llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_params & params) : +llama_model_dots1::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -31,18 +100,8 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -61,7 +120,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -91,7 +150,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); @@ -125,7 +184,7 @@ llm_build_dots1::llm_build_dots1(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/dream.cpp b/examples/talk-llama/models/dream.cpp index 2aafbae1397..abe737c335a 100644 --- a/examples/talk-llama/models/dream.cpp +++ b/examples/talk-llama/models/dream.cpp @@ -1,14 +1,61 @@ #include "models.h" +void llama_model_dream::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Dream models are primarily 7B with 28 layers + switch (hparams.n_layer()) { + case 28: + type = LLM_TYPE_7B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; +} + +void llama_model_dream::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} -llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_params & params) : +std::unique_ptr<llm_graph_context> llama_model_dream::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_dream::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //copied from qwen2 - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -31,22 +78,8 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para // self-attention { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -59,7 +92,7 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -96,7 +129,7 @@ llm_build_dream::llm_build_dream(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/eagle3.cpp b/examples/talk-llama/models/eagle3.cpp new file mode 100644 index 00000000000..3321b390515 --- /dev/null +++ b/examples/talk-llama/models/eagle3.cpp @@ -0,0 +1,323 @@ +#include "models.h" + +void llama_model_eagle3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (!ml.get_arr(LLM_KV_TARGET_LAYERS, target_layer_ids, false)) { + throw std::runtime_error("EAGLE3 model requires 'extract_layers' in GGUF metadata"); + } + if (target_layer_ids.size() != 3) { + throw std::runtime_error("EAGLE3 requires exactly 3 entries in 'extract_layers'"); + } + LLAMA_LOG_INFO("%s: EAGLE3 extract_layers = [%d, %d, %d]\n", __func__, + target_layer_ids[0], + target_layer_ids[1], + target_layer_ids[2]); + + uint32_t n_embd_tgt = 0; + + ml.get_key(LLM_KV_TARGET_HIDDEN_SIZE, n_embd_tgt); + LLAMA_LOG_INFO("%s: EAGLE3 n_embd_tgt = %u (draft n_embd = %u)\n", __func__, n_embd_tgt, hparams.n_embd); + + hparams.n_embd_inp_impl = (uint32_t) target_layer_ids.size() * n_embd_tgt; + + // eagle3 norm_before_residual (optional, default false) + // compatible with Readhat eagle3 speculator model + ml.get_key(LLM_KV_NORM_BEFORE_RESIDUAL, hparams.norm_before_residual, false); + if (hparams.norm_before_residual) { + LLAMA_LOG_INFO("%s: EAGLE3gnorm_before_residual = true\n", __func__); + } + + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_eagle3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_inp = hparams.n_embd_inp(); + const int64_t n_embd_attn_input = 2 * n_embd; + + // Get vocab size from the d2t tensor in the GGUF file (optional - only needed if eagle3 has different vocab_size than target) + // d2t: draft to target vocabulary mapping + int64_t n_draft_vocab = n_vocab; // Default: same as target vocab + const struct ggml_tensor * d2t_meta = ml->get_tensor_meta("d2t"); + if (d2t_meta) { + n_draft_vocab = d2t_meta->ne[0]; // update draft vocab size + d2t = create_tensor(tn(LLM_TENSOR_D2T), {n_draft_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using d2t mapping (draft_vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } else { + d2t = nullptr; // no d2t, use default vocab size + LLAMA_LOG_INFO("%s: EAGLE3 without d2t - sharing same vocab_size with target (vocab_size = %lld)\n", __func__, (long long)n_draft_vocab); + } + + // Feature fusion layer: projects 3 target layers to draft hidden size + fc = create_tensor(tn(LLM_TENSOR_FC, "weight"), {n_embd_inp, n_embd}, 0); + + // Output layer (uses draft vocab size) + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_draft_vocab}, TENSOR_NOT_REQUIRED); + + // Token embeddings (optional - Llama 3.3 70B EAGLE3 has its own) + const struct ggml_tensor * tok_embd_meta = ml->get_tensor_meta(tn(LLM_TENSOR_TOKEN_EMBD, "weight").str().c_str()); + if (tok_embd_meta) { + const int64_t n_target_vocab = tok_embd_meta->ne[1]; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_target_vocab}, 0); + LLAMA_LOG_INFO("%s: EAGLE3 using its own token_embd (vocab = %lld)\n", __func__, (long long)n_target_vocab); + } + + // Single decoder layer + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // input_layernorm: applied to token embeddings + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // eagle3 specific: hidden_norm applied to fused target features + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + + // Attention takes input_embeds_normed + fused_target_normed as input + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd_attn_input, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd_attn_input, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd_attn_input, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // rope_freqs for llama3 rope scaling (optional - only if eagle3 config has rope_scaling) + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_eagle3::build_arch_graph(const llm_graph_params & params) const { + switch (params.gtype) { + case LLM_GRAPH_TYPE_ENCODER: + return std::make_unique<graph<true>>(*this, params); + case LLM_GRAPH_TYPE_DEFAULT: + case LLM_GRAPH_TYPE_DECODER: + return std::make_unique<graph<false>>(*this, params); + default: + GGML_ABORT("invalid graph type"); + }; +} + +template <> +ggml_tensor * llama_model_eagle3::graph<true>::build_inp_embd_enc() const { + ggml_tensor * cur = nullptr; + + // Input: Target model features (3 layers concatenated: low, mid, high) + // Data will be provided via ubatch->embd in encode_eagle3_features() + auto inp_target = std::make_unique<llm_graph_input_embd>(hparams.n_embd_inp()); + inp_target->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32,hparams.n_embd_inp(), n_tokens); + ggml_set_input(inp_target->embd); + + cur = inp_target->embd; + cb(cur, "inp_embd", -1); + + res->add_input(std::move(inp_target)); + + return cur; +} + +// eagle3 Encoder: processes target model features through feature fusion layer +// Input: target_features e.g. [12288, n_tokens] from target model layers low, middle, high +// Output: g_embeddings e.g. [4096, n_tokens] stored in context +template <> +llama_model_eagle3::graph<true>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur = nullptr; + + cur = build_inp_embd_enc(); + + // Feature fusion layer + cur = build_lora_mm(model.fc, cur); + cb(cur, "fc_out", -1); + + // Output: g_embeddings e.g. [4096, n_tokens] + // store in t_h_nextn (same as MTP) so can be read via llama_get_embeddings_nextn(ctx_dft) + ggml_set_output(cur); + res->t_h_nextn = cur; + + ggml_build_forward_expand(gf, cur); +} + +// eagle3 Decoder: processes draft tokens using g_embeddings from encoder +// Input: draft tokens + g_embeddings from encoder +// Output: draft logits +template <> +llama_model_eagle3::graph<false>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_layer == 1); // eagle3 has only one decoder layer + + ggml_tensor * cur; + ggml_tensor * inpL; + + // eagle3 Decoder receives: + // 1. Token embeddings (e.g.from eagle3's own tok_embd for Llama 3.3 70B, or target model for Llama 3.1 8B) + // 2. g_embeddings from encoder + auto * tok_embd = model.tok_embd; + if (model.tok_embd == nullptr) { + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + GGML_ASSERT(model_other->tok_embd != nullptr && "EAGLE3 decoder requires token embeddings (own or from target model)"); + tok_embd = model_other->tok_embd; + } + + auto inp = std::make_unique<llm_graph_input_embd>(n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + ggml_set_input(inp->embd); + + ggml_tensor * inp_embd = ggml_get_rows(ctx0, tok_embd, inp->tokens); + cb(inp_embd, "inp_embd", -1); + + ggml_tensor * inp_g = inp->embd; + cb(inp_g, "inp_g_embeddings", -1); + + res->add_input(std::move(inp)); + + inpL = inp_g; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); + + // Single decoder layer (il = 0) + const int il = 0; + { + // Apply input_layernorm to the token embeddings + ggml_tensor * embd_norm = build_norm(inp_embd, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(embd_norm, "embd_norm", il); + + // Apply hidden_norm to inp_g + ggml_tensor * g_norm = build_norm(inp_g, + model.layers[il].attn_norm_2, NULL, + LLM_NORM_RMS, -1); + cb(g_norm, "g_norm", il); + + // norm_before_residual: determines what goes into the residual connection (compatible with Readhat eagle3 speculator model) + // - false (default): use raw inp_g for residual + // - true: use normalized g_norm for residual + // inpL is the concatenated input (normalized inp_embd + normalized inp_g) + ggml_tensor * inpSA = hparams.norm_before_residual ? g_norm : inpL; + + // Concatenate normalized inp_embd and normalized inp_g + cur = ggml_concat(ctx0, embd_norm, g_norm, il); + cb(cur, "concat_embd", il); + + // Self-attention with concatenated input + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // rope freq factors, returns nullptr if not available + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + + // Add residual and update it + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // Apply FFN norm to the sum + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "post_attn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + // Output norm with residual + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "eagle3_prenorm", il); + + inpL = cur; + } + + cur = inpL; + + // Output prenorm state (for next token's g_embeddings in autoregressive generation) + ggml_set_output(cur); + res->t_h_nextn = cur; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head - projects to draft vocabulary + // if the draft has no own output projection, inherit the target model's lm_head + auto * output = model.output; + if (output == nullptr) { + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + GGML_ASSERT(model_other->output != nullptr && "EAGLE3 decoder requires an output projection (own or from target model)"); + output = model_other->output; + } + cur = build_lora_mm(output, cur); + + if (model.d2t) { + const int64_t n_draft_vocab = cur->ne[0]; + const int64_t n_outputs = cur->ne[1]; + const int64_t n_vocab = (int64_t) model.vocab.n_tokens(); + + GGML_ASSERT(model.d2t->type == GGML_TYPE_I64); + GGML_ASSERT(model.d2t->ne[0] == n_draft_vocab); + + ggml_tensor * logits = ggml_fill(ctx0, ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, n_vocab, n_outputs), -INFINITY); + cur = ggml_set_rows(ctx0, logits, + ggml_reshape_3d(ctx0, cur, 1, n_draft_vocab, n_outputs), + ggml_reshape_3d(ctx0, model.d2t, n_draft_vocab, 1, 1)); + cur = ggml_reshape_2d(ctx0, cur, n_vocab, n_outputs); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/ernie4-5-moe.cpp b/examples/talk-llama/models/ernie4-5-moe.cpp index 0d96d14e6fd..8d9ff138676 100644 --- a/examples/talk-llama/models/ernie4-5-moe.cpp +++ b/examples/talk-llama/models/ernie4-5-moe.cpp @@ -1,13 +1,15 @@ #include "models.h" +std::unique_ptr<llm_graph_context> llama_model_ernie4_5_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} - -llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) : +llama_model_ernie4_5_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -32,27 +34,8 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -65,7 +48,7 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } @@ -103,7 +86,7 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); @@ -141,7 +124,7 @@ llm_build_ernie4_5_moe::llm_build_ernie4_5_moe(const llama_model & model, const res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/ernie4-5.cpp b/examples/talk-llama/models/ernie4-5.cpp index 99aead53283..895cf690bd2 100644 --- a/examples/talk-llama/models/ernie4-5.cpp +++ b/examples/talk-llama/models/ernie4-5.cpp @@ -1,11 +1,84 @@ #include "models.h" -llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) : +void llama_model_ernie4_5::load_arch_hparams(llama_model_loader & ml) { + // paddleocr need mrope_section + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (arch == LLM_ARCH_ERNIE4_5_MOE) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + } + + switch (hparams.n_layer()) { + case 18: type = LLM_TYPE_0_3B; break; + case 28: type = LLM_TYPE_21B_A3B; break; + case 54: type = LLM_TYPE_300B_A47B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_ernie4_5::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (arch == LLM_ARCH_ERNIE4_5_MOE && static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead) { // MoE layers + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert (if present) + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + } + } else { // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_ernie4_5::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_ernie4_5::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -29,27 +102,8 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap } // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -62,7 +116,7 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1) { @@ -101,7 +155,7 @@ llm_build_ernie4_5::llm_build_ernie4_5(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/eurobert.cpp b/examples/talk-llama/models/eurobert.cpp new file mode 100644 index 00000000000..0948d7de656 --- /dev/null +++ b/examples/talk-llama/models/eurobert.cpp @@ -0,0 +1,124 @@ +#include "models.h" + +void llama_model_eurobert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_layer() == 12) { + type = LLM_TYPE_SMALL; // 0.2B + } +} + +void llama_model_eurobert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_eurobert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_eurobert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = build_inp_pos(); + + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "inp_embd", -1); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * cur = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + + { + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, inpL); + + ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/exaone-moe.cpp b/examples/talk-llama/models/exaone-moe.cpp new file mode 100644 index 00000000000..5aed9379400 --- /dev/null +++ b/examples/talk-llama/models/exaone-moe.cpp @@ -0,0 +1,244 @@ +#include "models.h" + +void llama_model_exaone_moe::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 128; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_30B_A3B; break; + case 48: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_exaone_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : n_ff_exp; + const int64_t head_dim = hparams.n_embd_head_k(); + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer_all; ++i) { + int flags = 0; + if (i >= n_layer) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, flags); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0) | flags); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // dense layers for first n_layer_dense_lead layers or nextn_predict_layers layers at the end + if (i < (int) hparams.n_layer_dense_lead || (i >= n_layer)) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | flags); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, flags); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (i >= n_layer) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), {n_embd, n_vocab}, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_exaone_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_exaone_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn_iswa = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // use RoPE for SWA layers + const bool is_local_layer = hparams.is_swa(il); + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + + if (is_local_layer) { + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn_iswa, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + cb(cur, "attn_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // norm + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // dense branch + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + // FFN shared expert + { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + // final norm + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/exaone.cpp b/examples/talk-llama/models/exaone.cpp index 62602b284de..676fb37b5a6 100644 --- a/examples/talk-llama/models/exaone.cpp +++ b/examples/talk-llama/models/exaone.cpp @@ -1,13 +1,54 @@ #include "models.h" +void llama_model_exaone::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_exaone::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); -llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_params & params) : + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_exaone::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_exaone::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -34,27 +75,8 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -67,7 +89,7 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -105,7 +127,7 @@ llm_build_exaone::llm_build_exaone(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/exaone4.cpp b/examples/talk-llama/models/exaone4.cpp index 8b7e3dc06e5..863268abcef 100644 --- a/examples/talk-llama/models/exaone4.cpp +++ b/examples/talk-llama/models/exaone4.cpp @@ -1,13 +1,95 @@ #include "models.h" +void llama_model_exaone4::load_arch_hparams(llama_model_loader & ml) { + if (hparams.n_layer() == 64) { // 32B + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer"); + + switch (hparams.n_layer()) { + case 30: type = LLM_TYPE_1_2B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_exaone4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer_all; ++i) { + const bool is_nextn = i >= n_layer; + int flags = 0; + if (is_nextn) { + // NextN/MTP layers are preserved in GGUF but are not executed yet. + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, flags); + + if (!is_nextn) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, flags); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); + + if (is_nextn) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, flags); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_exaone4::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique<graph<true>>(*this, params); + } else { + return std::make_unique<graph<false>>(*this, params); + } +} template <bool iswa> -llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : +llama_model_exaone4::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_k; + const int64_t n_embd_head = hparams.n_embd_head_k(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_v); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -39,18 +121,8 @@ llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_ { ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -69,7 +141,7 @@ llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); cb(cur, "attn_out", il); } @@ -110,7 +182,7 @@ llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -119,5 +191,5 @@ llm_build_exaone4<iswa>::llm_build_exaone4(const llama_model & model, const llm_ } // Explicit template instantiations -template struct llm_build_exaone4<false>; -template struct llm_build_exaone4<true>; +template struct llama_model_exaone4::graph<false>; +template struct llama_model_exaone4::graph<true>; diff --git a/examples/talk-llama/models/falcon-h1.cpp b/examples/talk-llama/models/falcon-h1.cpp index b641a094079..d6ef2d51986 100644 --- a/examples/talk-llama/models/falcon-h1.cpp +++ b/examples/talk-llama/models/falcon-h1.cpp @@ -1,10 +1,117 @@ #include "models.h" +void llama_model_falcon_h1::load_arch_hparams(llama_model_loader & ml) { + // Common parameters + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + std::fill(hparams.is_recr_impl.begin(), hparams.is_recr_impl.end(), true); + + switch (hparams.n_layer()) { + case 36: + type = LLM_TYPE_0_5B; break; + case 24: + type = LLM_TYPE_1_5B; break; + case 66: + type = LLM_TYPE_1B; break; + case 32: + type = LLM_TYPE_3B; break; + case 44: + type = LLM_TYPE_7B; break; + case 72: + type = LLM_TYPE_34B; break; + default: + type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_falcon_h1::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // Common + const int64_t hidden_size = hparams.n_embd; // hidden_size + + // mamba2 Mixer SSM params + const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size + const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups + const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size + const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand + const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads + const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; + const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; + // attn params + const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head + const int64_t attn_num_key_value_head = hparams.n_head_kv(0); -llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + // ffn params + const int64_t ffn_intermediate_size = hparams.n_ff(0); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0); + + // output + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + /*SSM LAYERS*/ + // ssm in + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0); + // ssm 1d conv + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, TENSOR_NOT_REQUIRED); + // ssm_dt + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0); + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); + // ssm_norm + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, TENSOR_NOT_REQUIRED); + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); + + /*ATTENTION LAYERS*/ + // attention layers (with optional bias) + create_tensor_qkv(layer, i, hidden_size, n_embd_head_k * attn_num_attention_head, attn_num_key_value_head * n_embd_head_k, attn_num_key_value_head * n_embd_head_v, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); + + + // feed forward (w/ optional biases) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, i), {hidden_size}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_falcon_h1::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_falcon_h1::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_build_mamba_base(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); ggml_tensor * cur; ggml_tensor * inpL; @@ -29,19 +136,8 @@ llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_gr cb(cur, "attn_norm", il); // self-attention - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, hparams.rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -54,7 +150,7 @@ llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_gr cb(Vcur, "Vcur-post-rope", il); ggml_tensor * attn_out = build_attn(inp->get_attn(), - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); @@ -104,7 +200,7 @@ llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/falcon.cpp b/examples/talk-llama/models/falcon.cpp index db1ccdb5008..b2ad90b3272 100644 --- a/examples/talk-llama/models/falcon.cpp +++ b/examples/talk-llama/models/falcon.cpp @@ -1,12 +1,57 @@ #include "models.h" +void llama_model_falcon::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_7B; break; + case 60: type = LLM_TYPE_40B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_falcon::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_falcon::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_falcon::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -42,12 +87,8 @@ llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_pa cur = attn_norm; } - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -67,7 +108,7 @@ llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -111,7 +152,7 @@ llm_build_falcon::llm_build_falcon(const llama_model & model, const llm_graph_pa cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma-embedding.cpp b/examples/talk-llama/models/gemma-embedding.cpp index 944c198bf95..80ed3b1a460 100644 --- a/examples/talk-llama/models/gemma-embedding.cpp +++ b/examples/talk-llama/models/gemma-embedding.cpp @@ -1,15 +1,87 @@ #include "models.h" -llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : +void llama_model_gemma_embedding::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + uint32_t swa_period = 6; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.causal_attn = false; // embeddings do not use causal attention + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + //applied only if model converted with --sentence-transformers-dense-modules + ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); + ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + + GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); + GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_0_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + hparams.f_attention_scale = 1.0f / std::sqrt(float(hparams.n_embd_head_k())); + +} + +void llama_model_gemma_embedding::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gemma_embedding::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_gemma_embedding::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_k; + const int64_t n_embd_head = hparams.n_embd_head_k(); ggml_tensor * cur; ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -31,18 +103,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -65,7 +127,7 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } diff --git a/examples/talk-llama/models/gemma.cpp b/examples/talk-llama/models/gemma.cpp index 4893d9af4b8..651cd7e64de 100644 --- a/examples/talk-llama/models/gemma.cpp +++ b/examples/talk-llama/models/gemma.cpp @@ -1,8 +1,45 @@ #include "models.h" +void llama_model_gemma::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 18: type = LLM_TYPE_2B; break; + case 28: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gemma::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gemma::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} -llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +llama_model_gemma::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); ggml_tensor * cur; ggml_tensor * inpL; @@ -29,18 +66,8 @@ llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +87,7 @@ llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_para cb(Qcur, "Qcur_scaled", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -103,7 +130,7 @@ llm_build_gemma::llm_build_gemma(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gemma2-iswa.cpp b/examples/talk-llama/models/gemma2-iswa.cpp deleted file mode 100644 index 7a9198193ac..00000000000 --- a/examples/talk-llama/models/gemma2-iswa.cpp +++ /dev/null @@ -1,128 +0,0 @@ -#include "models.h" - -llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_k; - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); - cb(inpL, "inp_scaled", -1); - - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv_iswa(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - const float freq_base_l = model.get_rope_freq_base (cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); - - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); - } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); - } - cur = build_norm(cur, - model.layers[il].attn_post_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_post_norm", il); - - ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); - cb(sa_out, "sa_out", il); - - cur = build_norm(sa_out, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - // feed-forward network - { - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_GELU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } - cur = build_norm(cur, - model.layers[il].ffn_post_norm, NULL, - LLM_NORM_RMS, -1); - cb(cur, "ffn_post_norm", -1); - - cur = ggml_add(ctx0, cur, sa_out); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - // final logit soft-capping - cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); - cur = ggml_tanh(ctx0, cur); - cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/gemma2.cpp b/examples/talk-llama/models/gemma2.cpp new file mode 100644 index 00000000000..2fbfb15a94a --- /dev/null +++ b/examples/talk-llama/models/gemma2.cpp @@ -0,0 +1,177 @@ +#include "models.h" + +void llama_model_gemma2::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; // default value of gemma 2 + uint32_t swa_period = 2; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + hparams.attn_soft_cap = true; + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + switch (hparams.n_layer()) { + case 26: type = LLM_TYPE_2B; break; + case 42: type = LLM_TYPE_9B; break; + case 46: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L173 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); +} + +void llama_model_gemma2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gemma2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_gemma2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k(); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * sa_out = ggml_add(ctx0, cur, inpL); + cb(sa_out, "sa_out", il); + + cur = build_norm(sa_out, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + { + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, sa_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + // final logit soft-capping + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/gemma3.cpp b/examples/talk-llama/models/gemma3.cpp index dec3fc4b8bc..690194529e3 100644 --- a/examples/talk-llama/models/gemma3.cpp +++ b/examples/talk-llama/models/gemma3.cpp @@ -1,15 +1,95 @@ #include "models.h" +void llama_model_gemma3::load_arch_hparams(llama_model_loader & ml) { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 6; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + hparams.f_final_logit_softcapping = 0.0f; + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 18: type = LLM_TYPE_270M; break; + case 26: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_8B; break; // Rnj-1 + case 34: type = LLM_TYPE_4B; break; + case 48: type = LLM_TYPE_12B; break; + case 62: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/config.py#L289 + hparams.f_attention_scale = type == LLM_TYPE_27B + ? 1.0f / std::sqrt(float(hparams.n_embd / hparams.n_head(0))) + : 1.0f / std::sqrt(float(hparams.n_embd_head_k())); +} + +void llama_model_gemma3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gemma3::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique<graph<true>>(*this, params); + } else { + return std::make_unique<graph<false>>(*this, params); + } +} + template <bool iswa> -llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_k; +llama_model_gemma3::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k(); ggml_tensor * cur; ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -47,18 +127,8 @@ llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -84,7 +154,7 @@ llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_gr Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -137,7 +207,7 @@ llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (hparams.f_final_logit_softcapping) { cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); @@ -151,5 +221,5 @@ llm_build_gemma3<iswa>::llm_build_gemma3(const llama_model & model, const llm_gr ggml_build_forward_expand(gf, cur); } -template struct llm_build_gemma3<false>; -template struct llm_build_gemma3<true>; +template struct llama_model_gemma3::graph<false>; +template struct llama_model_gemma3::graph<true>; diff --git a/examples/talk-llama/models/gemma3n-iswa.cpp b/examples/talk-llama/models/gemma3n.cpp similarity index 63% rename from examples/talk-llama/models/gemma3n-iswa.cpp rename to examples/talk-llama/models/gemma3n.cpp index 93defbeef9c..83eb8250aa9 100644 --- a/examples/talk-llama/models/gemma3n-iswa.cpp +++ b/examples/talk-llama/models/gemma3n.cpp @@ -1,9 +1,97 @@ #include "models.h" -llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : +void llama_model_gemma3n::load_arch_hparams(llama_model_loader & ml) { + uint32_t swa_period = 5; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.set_swa_pattern(swa_period); + + hparams.n_layer_kv_from_start = 20; + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 30: type = LLM_TYPE_E2B; break; + case 35: type = LLM_TYPE_E4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gemma3n::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_altup = hparams.n_altup; + const int64_t laurel_rank = hparams.laurel_rank; + const int64_t n_embd_altup = hparams.n_embd_altup; + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + altup_proj = create_tensor(tn(LLM_TENSOR_ALTUP_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + altup_unembd_proj = create_tensor(tn(LLM_TENSOR_ALTUP_UNEMBD_PROJ, "weight"), {n_embd, n_embd, n_altup - 1}, 0); + + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_altup * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_altup * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_altup}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // altup & laurel + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_altup}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_altup, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + layer.altup_correct_coef = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_COEF, "weight", i), {n_altup, n_altup}, 0); + layer.altup_correct_scale = create_tensor(tn(LLM_TENSOR_ALTUP_CORRECT_SCALE, "weight", i), {n_embd}, 0); + layer.altup_predict_coef = create_tensor(tn(LLM_TENSOR_ALTUP_PREDICT_COEF, "weight", i), {n_altup, n_altup * n_altup}, 0); + layer.altup_router = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER, "weight", i), {n_embd, n_altup}, 0); + layer.altup_router_norm = create_tensor(tn(LLM_TENSOR_ALTUP_ROUTER_NORM, "weight", i), {n_embd}, 0); + layer.laurel_l = create_tensor(tn(LLM_TENSOR_LAUREL_L, "weight", i), {n_embd, laurel_rank}, 0); + layer.laurel_r = create_tensor(tn(LLM_TENSOR_LAUREL_R, "weight", i), {laurel_rank, n_embd}, 0); + layer.laurel_post_norm = create_tensor(tn(LLM_TENSOR_LAUREL_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gemma3n::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim +static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { + GGML_ASSERT(idx < (int) x->ne[2]); + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); +} + +llama_model_gemma3n::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), - n_embd_head(model.hparams.n_embd_head_k), + n_embd_head(model.hparams.n_embd_head_k()), n_embd_altup(model.hparams.n_embd_altup), n_altup(model.hparams.n_altup), i_altup_act(model.hparams.i_altup_act) { @@ -12,7 +100,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = build_inp_embd(model.tok_embd); - // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + // important: do not normalize weights for raw embeddings input (i.e. encoded image embeddings) inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); cb(inpL, "inp_scaled", -1); @@ -22,8 +110,11 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // TODO: is causal == true correct? might need some changes auto * inp_attn = build_attn_inp_kv_iswa(); - // inp_per_layer shape: [n_embd_altup, n_tokens, n_layer] - ggml_tensor * inp_per_layer = project_per_layer_inputs(inpL, get_per_layer_inputs()); + ggml_tensor * inp_per_layer = build_inp_per_layer(); + ggml_build_forward_expand(gf, inp_per_layer); + + // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] + inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer); // inpL now has only 1 altup, project it to the rest of the altups // these "added" altups will be concat to the last dim of inpL @@ -37,8 +128,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = ggml_concat(ctx0, inpL, altup_added, 2); // shape: [n_embd, n_tokens, n_altup] cb(inpL, "inp_stacked", -1); } - // inpL now has shape: [n_embd, n_tokens, n_altup] - // inp_per_layer now has shape: [n_embd_altup, n_tokens, n_layer] + // inpL now has shape: [n_embd, n_tokens, n_altup] for (int il = 0; il < n_layer; ++il) { // this block is made to be closely resemble Gemma3p5DecoderLayer on python code @@ -49,8 +139,8 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_tensor * predictions = altup_predict(cur, il); // [n_embd, n_tokens, n_altup] // predicted value will go through self-attention and laurel - ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); // [n_embd, n_tokens] - cur = active_prediction; + ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); // [n_embd, n_tokens] + cur = active_prediction; cb(cur, "active_prediction", il); // norm @@ -62,19 +152,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // self-attention if (hparams.has_kv(il)) { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); @@ -94,7 +172,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(Kcur, "Kcur_pos", il); cur = build_attn(inp_attn, model.layers[il].wo, - NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, + NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } else { // reuse KV cache of earlier layers @@ -110,7 +188,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(Qcur, "Qcur_pos", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); @@ -151,12 +229,13 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_tensor * first_prediction; // [n_embd, n_tokens] { - first_prediction = view_2d_slice(corrected, i_altup_act); // [n_embd, n_tokens] + first_prediction = ggml_view_2d_slice(ctx0, corrected, i_altup_act); // [n_embd, n_tokens] first_prediction = ggml_mul(ctx0, first_prediction, model.layers[il].altup_correct_scale); first_prediction = build_lora_mm(model.layers[il].per_layer_inp_gate, first_prediction); first_prediction = ggml_gelu(ctx0, first_prediction); // [n_embd_altup, n_tokens] cb(first_prediction, "first_prediction_gated", il); - ggml_tensor * inp_this_layer = view_2d_slice(inp_per_layer, il); // [n_embd_altup, n_tokens] + + ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_altup, n_tokens] first_prediction = ggml_mul(ctx0, first_prediction, inp_this_layer); // [n_embd_altup, n_tokens] cb(first_prediction, "first_prediction_scaled", il); @@ -167,7 +246,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const } // equivalent to python code: corrected_predictions[1:] += first_prediction { - ggml_tensor * slice_first = view_2d_slice(corrected, 0); + ggml_tensor * slice_first = ggml_view_2d_slice(ctx0, corrected, 0); ggml_tensor * slice_rest = ggml_view_3d( ctx0, corrected, n_embd, n_tokens, n_altup - 1, ggml_row_size(corrected->type, n_embd), ggml_row_size(corrected->type, n_embd * n_tokens), n_embd * n_tokens * ggml_element_size(corrected)); @@ -185,7 +264,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const // cur now has multiple altup(s), we want to merge them back to 1 altup { - ggml_tensor * target_magnitude = calc_magnitude(view_2d_slice(cur, i_altup_act)); // [n_embd, n_tokens] + ggml_tensor * target_magnitude = calc_magnitude(ggml_view_2d_slice(ctx0, cur, i_altup_act)); // [n_embd, n_tokens] // do a view to skip the first slice (active altup) ggml_tensor * alt_slice = ggml_view_3d(ctx0, cur, n_embd, n_tokens, n_altup - 1, ggml_row_size(cur->type, n_embd), @@ -197,9 +276,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(altup_unembd, "altup_unembd", -1); // equivalent to torch.mean(hidden_states, dim=0) - cur = view_2d_slice(cur, 0); // [n_embd, n_tokens] + cur = ggml_view_2d_slice(ctx0, cur, 0); // [n_embd, n_tokens] for (int i = 0; i < n_altup - 1; ++i) { - cur = ggml_add(ctx0, cur, view_2d_slice(altup_unembd, i)); + cur = ggml_add(ctx0, cur, ggml_view_2d_slice(ctx0, altup_unembd, i)); } cur = ggml_scale(ctx0, cur, 1.0f / float(n_altup)); // [n_embd, n_tokens] cb(cur, "unembd_merged", -1); @@ -217,7 +296,7 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); { // final logit soft-capping @@ -231,43 +310,38 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_gemma3n_iswa::calc_magnitude(ggml_tensor * x) { +ggml_tensor * llama_model_gemma3n::graph::calc_magnitude(ggml_tensor * x) { return ggml_sqrt(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, x))); } -// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim -ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) { - GGML_ASSERT(idx < (int) x->ne[2]); - return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), - idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); -} - // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_altup, n_layer, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { - auto inp = std::make_unique<llm_graph_input_embd>(); +ggml_tensor * llama_model_gemma3n::graph::build_inp_per_layer() { + auto inp = std::make_unique<llm_graph_input_embd>(n_embd); ggml_tensor * inp_per_layer; + float tok_embd_scale = sqrtf((float) n_embd_altup); if (ubatch.token) { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); ggml_set_input(inp->tokens); - res->t_tokens = inp->tokens; - inp_per_layer = ggml_get_rows(ctx0, model.tok_embd_per_layer, inp->tokens); + res->t_inp_tokens = inp->tokens; + inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens); inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, n_tokens); - inp_per_layer = ggml_scale(ctx0, inp_per_layer, sqrtf((float) n_embd_altup)); + inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale); cb(inp_per_layer, "inp_per_layer_selected", -1); res->add_input(std::move(inp)); } else { - // Vision embedding path: use padding token (ID=0) embedding - const int64_t embd_size = model.tok_embd_per_layer->ne[0]; // n_embd_altup * n_layer + // Multimodal embedding path: use padding token (ID=0) embedding + // TODO: verify if this is the correct behavior in transformers implementation + const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_altup * n_layer - // Extract and dequantize padding token embedding (column 0) - ggml_tensor * padding_q = ggml_view_1d(ctx0, model.tok_embd_per_layer, embd_size, 0); - ggml_tensor * padding_f32 = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, embd_size); - inp_per_layer = ggml_cpy(ctx0, padding_q, padding_f32); + // Extract and dequantize padding token embedding (row 0) + ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0); + inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale); // Reshape to [n_embd_altup, n_layer, 1] inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_altup, n_layer, 1); - cb(inp_per_layer, "inp_per_layer_vision", -1); + cb(inp_per_layer, "inp_per_layer_multimodal", -1); } return inp_per_layer; } @@ -275,18 +349,19 @@ ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { // equivalent to project_per_layer_inputs() in python code // this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim // output shape: [n_embd_altup, n_tokens, n_layer] -ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer) { +ggml_tensor * llama_model_gemma3n::graph::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); const float per_layer_input_scale = 1.0f / sqrtf(2.0f); - ggml_tensor * per_layer_proj = ggml_mul_mat(ctx0, model.per_layer_model_proj, inputs_embeds); - per_layer_proj = ggml_scale(ctx0, per_layer_proj, per_layer_projection_scale); - per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); - per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, - -1); // [n_embd_altup, n_layer, n_tokens] + ggml_tensor * per_layer_proj; + per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch); + per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale); + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_altup, n_layer, n_tokens); + + per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, NULL, LLM_NORM_RMS, -1); cb(per_layer_proj, "per_layer_proj", -1); - inp_per_layer = ggml_add(ctx0, per_layer_proj, inp_per_layer); + inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer); inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); cb(inp_per_layer, "inp_per_layer", -1); @@ -297,7 +372,7 @@ ggml_tensor * llm_build_gemma3n_iswa::project_per_layer_inputs(ggml_tensor * inp // input cur shape: [n_altup, n_tokens] // output shape: [n_altup, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) { +ggml_tensor * llama_model_gemma3n::graph::laurel(ggml_tensor * cur, int il) { ggml_tensor * tmp = cur; tmp = build_lora_mm(model.layers[il].laurel_l, tmp); tmp = build_lora_mm(model.layers[il].laurel_r, tmp); @@ -309,7 +384,7 @@ ggml_tensor * llm_build_gemma3n_iswa::laurel(ggml_tensor * cur, int il) { // input x shape: [n_embd, n_tokens] // output shape: [n_embd, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) { +ggml_tensor * llama_model_gemma3n::graph::gaussian_topk(ggml_tensor * x) { ggml_tensor * mean = ggml_mean(ctx0, x); ggml_tensor * std = ggml_sqrt(ctx0, ggml_scale(ctx0, ggml_sum_rows(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x, mean))), 1.0f / (float) (x->ne[0] - 1))); @@ -324,7 +399,7 @@ ggml_tensor * llm_build_gemma3n_iswa::gaussian_topk(ggml_tensor * x) { // equivalent to compute_router_modalities() in python code // input x shape: [n_embd, n_tokens] // output shape: [n_altup, n_tokens] -ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tensor * x, int il) { +ggml_tensor * llama_model_gemma3n::graph::altup_compute_router_modalities(ggml_tensor * x, int il) { ggml_tensor * router_inputs = build_norm(x, model.layers[il].altup_router_norm, NULL, LLM_NORM_RMS, il); // router_input_scale @@ -336,8 +411,8 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_compute_router_modalities(ggml_tenso // input cur shape: [n_embd, n_tokens, n_altup] // output shape: [n_embd, n_tokens, n_altup] -ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) { - ggml_tensor * activated = view_2d_slice(cur, i_altup_act); // [n_embd, n_tokens] +ggml_tensor * llama_model_gemma3n::graph::altup_predict(ggml_tensor * cur, int il) { + ggml_tensor * activated = ggml_view_2d_slice(ctx0, cur, i_altup_act); // [n_embd, n_tokens] ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); @@ -361,11 +436,11 @@ ggml_tensor * llm_build_gemma3n_iswa::altup_predict(ggml_tensor * cur, int il) { // input predictions shape: [n_embd, n_tokens, n_altup] // input activated shape: [n_embd, n_tokens] // output shape: [n_embd, n_tokens, n_altup] -ggml_tensor * llm_build_gemma3n_iswa::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) { +ggml_tensor * llama_model_gemma3n::graph::altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il) { ggml_tensor * modalities = altup_compute_router_modalities(activated, il); // [n_altup, n_tokens] cb(modalities, "modalities", il); - ggml_tensor * active_prediction = view_2d_slice(predictions, i_altup_act); + ggml_tensor * active_prediction = ggml_view_2d_slice(ctx0, predictions, i_altup_act); ggml_tensor * innovation = ggml_sub(ctx0, activated, active_prediction); // [n_embd, n_tokens] cb(innovation, "innovation", il); diff --git a/examples/talk-llama/models/gemma4-assistant.cpp b/examples/talk-llama/models/gemma4-assistant.cpp new file mode 100644 index 00000000000..6378130e79e --- /dev/null +++ b/examples/talk-llama/models/gemma4-assistant.cpp @@ -0,0 +1,203 @@ +#include "models.h" + +void llama_model_gemma4_assistant::load_arch_hparams(llama_model_loader & ml) { + hparams.n_embd_inp_impl = hparams.n_embd_out(); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.f_attention_scale = 1.0f; + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn == hparams.n_layer_all && "n_layer_nextn must be == n_layer_impl"); + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); +} + +void llama_model_gemma4_assistant::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 assistant requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + if (hparams.n_embd_out() == n_embd) { + throw std::runtime_error("Gemma 4 assistant requires embedding_length_out to carry the target hidden size"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + + create_tensor(tn(LLM_TENSOR_MASKED_EMBD_CENTROIDS, "weight"), {}, TENSOR_NOT_REQUIRED); + create_tensor(tn(LLM_TENSOR_MASKED_EMBD_ORDERING), {}, TENSOR_NOT_REQUIRED); + + const int64_t n_embd_backbone = hparams.n_embd_inp(); + nextn_proj_post = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_POST, "weight"), { n_embd, n_embd_backbone }, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer_nextn; ++i) { + auto & layer = layers[i]; + + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_ff = hparams.n_ff(i); + + if (i == 0) { + nextn_proj_pre = create_tensor(tn(LLM_TENSOR_NEXTN_PROJ_PRE, "weight", i), { 2*n_embd_backbone, n_embd }, 0); + } + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head*n_head }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head*n_head, n_embd }, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), { 1u }, 0); + + if (!hparams.is_swa(i)) { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_embd_head/2 }, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), { n_embd }, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gemma4_assistant::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_gemma4_assistant::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + const int64_t n_embd_backbone = hparams.n_embd_inp(); + + ggml_tensor * inp_tokens; + ggml_tensor * inp_h; + { + auto inp = std::make_unique<llm_graph_input_embd>(n_embd_backbone); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + cb(inp->tokens, "inp_tokens", -1); + ggml_set_input(inp->tokens); + inp_tokens = inp->tokens; + res->t_inp_tokens = inp->tokens; + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd_backbone, ubatch.n_tokens); + cb(inp->embd, "inp_h", -1); + ggml_set_input(inp->embd); + inp_h = inp->embd; + res->t_inp_embd = inp->embd; + + res->add_input(std::move(inp)); + } + + GGML_ASSERT(cparams.ctx_other != nullptr); + const auto * model_other = llama_get_model(cparams.ctx_other); + + ggml_tensor * x = ggml_get_rows(ctx0, model_other->tok_embd, inp_tokens); + x = ggml_scale(ctx0, x, sqrtf((float) n_embd_backbone)); + cb(x, "inp_embd_target", -1); + + ggml_tensor * xh = ggml_concat(ctx0, x, inp_h, 0); + cb(xh, "inp_xh", -1); + + ggml_tensor * cur = ggml_mul_mat(ctx0, model.nextn_proj_pre, xh); + cb(cur, "pre_proj", -1); + + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inpL = cur; + + for (int il = 0; il < n_layer_nextn; ++il) { + const bool is_swa = hparams.is_swa(il); + + const int64_t n_embd_head = hparams.n_embd_head_k(il); + const int64_t n_head = hparams.n_head(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + ggml_tensor * cur_norm = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur_norm, "attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur_norm); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * freq_factors = is_swa ? nullptr : model.layers[il].rope_freqs; + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, + freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, nullptr, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + + if (il == n_layer_nextn - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + cur = build_norm(attn_out, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, nullptr, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, model.layers[il].ffn_post_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + cur = ggml_add(ctx0, cur, attn_out); + + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + ggml_tensor * logits = build_lora_mm(model.output, cur); + cb(logits, "result_output", -1); + res->t_logits = logits; + + ggml_tensor * h_next = ggml_mul_mat(ctx0, model.nextn_proj_post, cur); + cb(h_next, "h_nextn", -1); + res->t_h_nextn = h_next; + + ggml_build_forward_expand(gf, logits); + ggml_build_forward_expand(gf, h_next); +} diff --git a/examples/talk-llama/models/gemma4.cpp b/examples/talk-llama/models/gemma4.cpp new file mode 100644 index 00000000000..6a96979cebd --- /dev/null +++ b/examples/talk-llama/models/gemma4.cpp @@ -0,0 +1,508 @@ +#include "models.h" + +void llama_model_gemma4::load_arch_hparams(llama_model_loader & ml) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + uint32_t n_kv_shared_layers = 0; + ml.get_key(LLM_KV_ATTENTION_SHARED_KV_LAYERS, n_kv_shared_layers, false); + + hparams.n_layer_kv_from_start = hparams.n_layer_all - (int32_t)n_kv_shared_layers; + hparams.f_attention_scale = 1.0f; // Gemma4 uses self.scaling = 1.0 (no pre-attn scaling) + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_PER_LAYER, hparams.n_embd_per_layer); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_SWA, hparams.n_embd_head_k_swa); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_SWA, hparams.n_embd_head_v_swa); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + switch (hparams.n_layer()) { + case 30: type = LLM_TYPE_26B_A4B; break; + case 35: type = LLM_TYPE_E2B; break; + case 42: type = LLM_TYPE_E4B; break; + case 60: type = LLM_TYPE_31B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gemma4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const uint32_t n_embd_per_layer = hparams.n_embd_per_layer; + const int64_t n_ff_exp = hparams.n_ff_exp; + + if (n_embd_head_k != n_embd_head_v) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k == n_embd_head_v"); + } + if (hparams.n_embd_head_k_swa != hparams.n_embd_head_v_swa) { + throw std::runtime_error("Gemma 4 requires n_embd_head_k_swa == n_embd_head_v_swa"); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + if (n_embd_per_layer > 0) { + per_layer_tok_embd = create_tensor(tn(LLM_TENSOR_PER_LAYER_TOKEN_EMBD, "weight"), {n_embd_per_layer * n_layer, n_vocab}, 0); + per_layer_model_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_MODEL_PROJ, "weight", 0), {n_embd, n_embd_per_layer * n_layer}, 0); + per_layer_proj_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ_NORM, "weight", 0), {n_embd_per_layer}, 0); + } + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + int rope_freqs_flag = 0; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const int64_t n_head = hparams.n_head(i); + const int64_t n_embd_head = hparams.n_embd_head_k(i); + const int64_t n_embd_k = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v = hparams.n_embd_v_gqa(i); + const int kv_flags = hparams.has_kv(i) ? 0 : TENSOR_NOT_REQUIRED; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // note: use_alternative_attention (v_proj is optional, if it's not present, use k_proj) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k}, kv_flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v}, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head * n_head, n_embd}, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head}, kv_flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1u}, TENSOR_NOT_REQUIRED); + + if (!hparams.is_swa(i)) { + // full_attention layers use rope_freqs for proportional rope + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_embd_head/2}, rope_freqs_flag); + rope_freqs_flag = TENSOR_DUPLICATED; + } + + // handle use_double_wide_mlp + int64_t n_ff_cur = hparams.n_ff(i); + + // for expert layers, we use normal FFN as shared expert (same as python code) + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + + // MoE router + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + bool has_expert = layer.ffn_gate_inp != nullptr; + + // norm + if (has_expert) { + layer.ffn_gate_inp_s = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "scale", i), {n_embd}, 0); + + layer.ffn_pre_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM_2, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_1 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_1, "weight", i), {n_embd}, 0); + layer.ffn_post_norm_2 = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM_2, "weight", i), {n_embd}, 0); + + // MoE FFN + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", i), {n_embd, n_ff_exp * 2, n_expert}, TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_up_exps == nullptr) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + } + + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + + // per-expert scale will be loaded as down_exps_s at the end of the current switch case + } + + // per-layer embeddings + if (n_embd_per_layer > 0) { + layer.per_layer_inp_gate = create_tensor(tn(LLM_TENSOR_PER_LAYER_INP_GATE, "weight", i), {n_embd, n_embd_per_layer}, 0); + layer.per_layer_proj = create_tensor(tn(LLM_TENSOR_PER_LAYER_PROJ, "weight", i), {n_embd_per_layer, n_embd}, 0); + layer.per_layer_post_norm = create_tensor(tn(LLM_TENSOR_PER_LAYER_POST_NORM, "weight", i), {n_embd}, 0); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_gemma4::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +// get 2D slice view from a 3D tensor, the idx corresponds to the 3rd dim +static ggml_tensor * ggml_view_2d_slice(ggml_context * ctx0, ggml_tensor * x, int idx) { + GGML_ASSERT(idx < (int) x->ne[2]); + return ggml_view_2d(ctx0, x, x->ne[0], x->ne[1], ggml_row_size(x->type, x->ne[0]), + idx * x->ne[0] * x->ne[1] * ggml_element_size(x)); +} + +// TODO @ngxson : maybe improve this in the future +class llm_graph_input_logits_bias : public llm_graph_input_i { +public: + llm_graph_input_logits_bias(const llama_vocab & vocab) { + arr.resize(vocab.n_tokens(), 0.0f); + for (llama_token id : vocab.get_suppress_tokens()) { + if (0 <= id && id < (int32_t)vocab.n_tokens()) { + arr[id] = -INFINITY; + } + } + } + virtual ~llm_graph_input_logits_bias() = default; + + void set_input(const llama_ubatch * /*ubatch*/) override { + const int64_t n_vocab = arr.size(); + ggml_backend_tensor_set(logits_bias, arr.data(), 0, n_vocab*ggml_element_size(logits_bias)); + } + + bool can_reuse(const llm_graph_params & /*params*/) override { + return true; + } + + ggml_tensor * logits_bias = nullptr; // F32 [n_vocab] + + std::vector<float> arr; +}; + +llama_model_gemma4::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params), + model(model), + n_embd_per_layer(model.hparams.n_embd_per_layer) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // TODO: is causal == true correct? might need some changes + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + ggml_tensor * inp_per_layer = nullptr; + if (model.per_layer_tok_embd) { + inp_per_layer = build_inp_per_layer(); + ggml_build_forward_expand(gf, inp_per_layer); + + // inp_per_layer shape: [n_embd_per_layer, n_tokens, n_layer] + inp_per_layer = project_per_layer_inputs(inpL, inp_per_layer); + } + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_embd_head = hparams.n_embd_head_k(il); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v(il)); + + const int64_t n_head = hparams.n_head(il); + const int64_t n_head_kv = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + const int n_rot_l = hparams.n_rot(il); + + res->t_layer_inp[il] = inpL; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * freq_factors = nullptr; + if (!hparams.is_swa(il)) { + // full_attention layers use rope_freqs for proportional rope + freq_factors = model.layers[il].rope_freqs; + } + + // Q projection (shared for both non-KV and KV layers) + // this is to mirror Gemma4Attention in pytorch code + ggml_tensor * Qcur; + { + Qcur = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); + cb(Qcur, "Qcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "Qcur_pos", il); + } + + // self-attention + if (hparams.has_kv(il)) { + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = model.layers[il].wv + ? build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s) + : Kcur; // if v_proj is not present, use Kcur as Vcur + cb(Vcur, "Vcur", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + Vcur = ggml_rms_norm(ctx0, Vcur, hparams.f_norm_rms_eps); + + cb(Kcur, "Kcur_normed", il); + cb(Vcur, "Vcur_normed", il); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, freq_factors, n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Kcur, "Kcur_pos", il); + + cur = build_attn(inp_attn, model.layers[il].wo, + nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, + hparams.f_attention_scale, il); + } else { + // reuse KV cache of earlier layers + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, nullptr, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); + } + + // TODO @ngxson : strip unused token right after the last KV layer to speed up prompt processing + // keep all rows when extracting unmasked nextn embeddings (MTP target needs the hidden state for every token) + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + cur = build_norm(cur, + model.layers[il].attn_post_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * attn_out = ggml_add(ctx0, cur, inpL); + cb(attn_out, "attn_out", il); + + // feed-forward network + const bool is_moe_layer = model.layers[il].ffn_gate_inp != nullptr; + if (is_moe_layer) { + // MLP (shared exp) + ggml_tensor * cur_mlp = build_norm(attn_out, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur_mlp, "ffn_norm_1", il); + + cur_mlp = build_ffn(cur_mlp, + model.layers[il].ffn_up, nullptr, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cur_mlp = build_norm(cur_mlp, + model.layers[il].ffn_post_norm_1, nullptr, + LLM_NORM_RMS, il); + cb(cur_mlp, "ffn_mlp", il); + + // Expert FFN + ggml_tensor * cur_moe = build_norm(attn_out, + model.layers[il].ffn_pre_norm_2, nullptr, + LLM_NORM_RMS, il); + cb(cur_moe, "ffn_norm_2", il); + + // custom MoE logits calculation (router operates on attn_out, not cur) + ggml_tensor * tmp = ggml_rms_norm(ctx0, attn_out, hparams.f_norm_rms_eps); + tmp = ggml_scale(ctx0, tmp, 1.0f / sqrtf((float) n_embd)); + tmp = ggml_mul(ctx0, tmp, model.layers[il].ffn_gate_inp_s); + ggml_tensor * logits = build_lora_mm(model.layers[il].ffn_gate_inp, tmp); // [n_expert, n_tokens] + cb(logits, "ffn_moe_logits", il); + + cur_moe = build_moe_ffn(cur_moe, + nullptr, // gate_inp + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, // exp_probs_b (not used for gemma4) + n_expert, n_expert_used, + LLM_FFN_GELU, true, + 1.0f, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, logits, + model.layers[il].ffn_gate_up_exps, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cur_moe = build_norm(cur_moe, + model.layers[il].ffn_post_norm_2, nullptr, + LLM_NORM_RMS, il); + cb(cur_moe, "ffn_moe", il); + + cur = ggml_add(ctx0, cur_mlp, cur_moe); + cb(cur, "ffn_moe_combined", il); + } else { + cur = build_norm(attn_out, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, nullptr, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_GELU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + cur = build_norm(cur, + model.layers[il].ffn_post_norm, nullptr, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, attn_out); + + // per-layer embedding + if (inp_per_layer) { + ggml_tensor * pe_in = cur; + cb(cur, "pe_in", il); + + cur = build_lora_mm(model.layers[il].per_layer_inp_gate, cur); // [n_embd_per_layer, n_tokens] + cur = ggml_gelu(ctx0, cur); + + ggml_tensor * inp_this_layer = ggml_view_2d_slice(ctx0, inp_per_layer, il); // [n_embd_per_layer, n_tokens] + + // TODO @ngxson : improve this + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { + inp_this_layer = ggml_get_rows(ctx0, inp_this_layer, inp_out_ids); + } + + cur = ggml_mul(ctx0, cur, inp_this_layer); + cur = build_lora_mm(model.layers[il].per_layer_proj, cur); // [n_embd, n_tokens] + cur = build_norm(cur, model.layers[il].per_layer_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "per_layer_embd_out", il); + + // residual connection + cur = ggml_add(ctx0, pe_in, cur); + } + + // layer_scalar + if (model.layers[il].out_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].out_scale); + cb(cur, "out_scaled", il); + } + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + // Expose the post-output-norm hidden state (the LM-head input feature) so that + // MTP draft contexts can read it via llama_get_embeddings_nextn_ith() as the + // recurrent h input. This matches the reference (transformers/vLLM/SGLang), + // which feeds the drafter the target's post-final-norm hidden state. + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + if (hparams.f_final_logit_softcapping) { + cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping); + cur = ggml_tanh(ctx0, cur); + cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping); + } + + // apply logits bias if needed (e.g. for gemma4_unified patch) + // this is to mirror the suppress_tokens patch on transformers, to avoid model from outputing <image|> and <audio|> tokens (which is a known issue related to the checkpoint) + // TODO: maybe handle this inside the sampling system in the future + if (!model.vocab.get_suppress_tokens().empty()) { + auto inp_bias = std::make_unique<llm_graph_input_logits_bias>(model.vocab); + inp_bias->logits_bias = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, inp_bias->arr.size()); + cur = ggml_add(ctx0, cur, inp_bias->logits_bias); + res->add_input(std::move(inp_bias)); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// equivalent to get_per_layer_inputs() in python code +// output shape: [n_embd_per_layer, n_layer, n_tokens] +ggml_tensor * llama_model_gemma4::graph::build_inp_per_layer() { + auto inp = std::make_unique<llm_graph_input_embd>(n_embd); + + ggml_tensor * inp_per_layer; + float tok_embd_scale = sqrtf((float) n_embd_per_layer); + if (ubatch.token) { + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); + ggml_set_input(inp->tokens); + res->t_inp_tokens = inp->tokens; + + inp_per_layer = ggml_get_rows (ctx0, model.per_layer_tok_embd, inp->tokens); + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, n_tokens); + inp_per_layer = ggml_scale (ctx0, inp_per_layer, tok_embd_scale); + cb(inp_per_layer, "inp_per_layer_selected", -1); + + res->add_input(std::move(inp)); + } else { + // Multimodal embedding path: use padding token (ID=0) embedding + // TODO: verify if this is the correct behavior in transformers implementation + const int64_t embd_size = model.per_layer_tok_embd->ne[0]; // n_embd_per_layer * n_layer + + // Extract and dequantize padding token embedding (row 0) + ggml_tensor * padding = ggml_view_1d(ctx0, model.per_layer_tok_embd, embd_size, 0); + inp_per_layer = ggml_cast (ctx0, padding, GGML_TYPE_F32); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, tok_embd_scale); + + // Reshape to [n_embd_per_layer, n_layer, 1] + inp_per_layer = ggml_reshape_3d(ctx0, inp_per_layer, n_embd_per_layer, n_layer, 1); + cb(inp_per_layer, "inp_per_layer_multimodal", -1); + } + return inp_per_layer; +} + +// equivalent to project_per_layer_inputs() in python code +// this calculates the per-layer inputs, so the final tensor shape will have n_layer as the last dim +// inp_batch shape: [n_embd, n_tokens] +// inp_per_layer shape: [n_embd_per_layer, n_layer, n_tokens] (from build_inp_per_layer) +// output shape: [n_embd_per_layer, n_tokens, n_layer] +ggml_tensor * llama_model_gemma4::graph::project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer) { + const float per_layer_projection_scale = 1.0f / sqrtf((float) n_embd); + const float per_layer_input_scale = 1.0f / sqrtf(2.0f); + + // note: this matrix multiplication will be performed in the input layer (i.e. on the CPU) + ggml_tensor * per_layer_proj; + per_layer_proj = ggml_mul_mat (ctx0, model.per_layer_model_proj, inp_batch); + per_layer_proj = ggml_scale (ctx0, per_layer_proj, per_layer_projection_scale); + per_layer_proj = ggml_reshape_3d(ctx0, per_layer_proj, n_embd_per_layer, n_layer, n_tokens); + + per_layer_proj = build_norm(per_layer_proj, model.per_layer_proj_norm, nullptr, LLM_NORM_RMS, -1); + cb(per_layer_proj, "per_layer_proj", -1); + + inp_per_layer = ggml_add (ctx0, per_layer_proj, inp_per_layer); + inp_per_layer = ggml_scale(ctx0, inp_per_layer, per_layer_input_scale); + cb(inp_per_layer, "inp_per_layer", -1); + + // permute to shape: [n_embd_per_layer, n_tokens, n_layer] + inp_per_layer = ggml_cont(ctx0, ggml_permute(ctx0, inp_per_layer, 0, 2, 1, 3)); + return inp_per_layer; +} diff --git a/examples/talk-llama/models/glm-dsa.cpp b/examples/talk-llama/models/glm-dsa.cpp new file mode 100644 index 00000000000..11d91312def --- /dev/null +++ b/examples/talk-llama/models/glm-dsa.cpp @@ -0,0 +1,152 @@ +#include "models.h" + +void llama_model_glm_dsa::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + switch (hparams.n_layer()) { + case 79: type = LLM_TYPE_744B_A40B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("GLM_DSA architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer_all; ++i) { + int flags = 0; + if (i >= n_layer) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last n_layer_nextn + if (i >= n_layer) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_glm_dsa::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/glm4-moe.cpp b/examples/talk-llama/models/glm4-moe.cpp index 003f70f7396..d60e47ddf0c 100644 --- a/examples/talk-llama/models/glm4-moe.cpp +++ b/examples/talk-llama/models/glm4-moe.cpp @@ -1,9 +1,139 @@ #include "models.h" -llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_glm4_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + switch (hparams.n_layer()) { + case 46: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air + case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open + case 92: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_glm4_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + const int64_t n_expert_shared = hparams.n_expert_shared; + + + GGML_ASSERT(hparams.n_expert > 0 && "n_expert must be > 0 for GLM4_MOE MoE layers"); + GGML_ASSERT(hparams.n_expert_used > 0 && "n_expert_used must be > 0 for GLM4_MOE MoE layers"); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Load ALL tensors including NextN layer to satisfy total tensor count + // but only PROCESS up to last layer (skipping final NextN layer) in forward pass + for (int i = 0; i < n_layer_all; ++i) { + int flags = 0; + if (i >= n_layer) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, flags); + + // GLM-style attention with bias terms + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); + + // K/Q norm tensors (optional for GLM-4.5 355B variant) + layer.attn_q_norm = create_tensor( + tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + layer.attn_k_norm = create_tensor( + tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED | flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, flags); + + // Check if this layer uses MoE or dense FFN based on n_layer_dense_lead + // GLM 4.5 uses hybrid architecture: layer 0 is dense, layers 1+ are MoE + const bool use_moe = (static_cast<uint32_t>(i) >= hparams.n_layer_dense_lead); + + if (use_moe) { + // MoE layers + layer.ffn_gate_inp = + create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), { n_expert }, flags); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor( + tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + layer.ffn_down_exps = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, flags); + layer.ffn_up_exps = create_tensor( + tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, flags); + + // Shared expert + if (n_expert_shared > 0) { + const int64_t n_ff_shexp = n_ff_exp * n_expert_shared; + layer.ffn_gate_shexp = create_tensor( + tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor( + tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, flags); + layer.ffn_up_shexp = create_tensor( + tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, flags); + } + } else { + // Dense layers (first k layers) - GLM uses separate gate/up projections + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (i >= n_layer) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_glm4_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_glm4_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); @@ -28,8 +158,7 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap // Only process up to last layer (skip final NextN layer) // Final layer tensors are loaded but not processed in forward pass - const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; - for (int il = 0; il < n_transformer_layers; ++il) { + for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -38,27 +167,8 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Apply Q/K norm if available (GLM-4.5 355B variant) if (model.layers[il].attn_q_norm) { @@ -94,10 +204,10 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } - if (il == n_transformer_layers - 1 && inp_out_ids) { + if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -128,7 +238,7 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(routed_out, "ffn_moe_out", il); @@ -161,7 +271,7 @@ llm_build_glm4_moe::llm_build_glm4_moe(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/glm4.cpp b/examples/talk-llama/models/glm4.cpp index 204aa3932af..b4326c5f210 100644 --- a/examples/talk-llama/models/glm4.cpp +++ b/examples/talk-llama/models/glm4.cpp @@ -1,12 +1,78 @@ #include "models.h" +void llama_model_glm4::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // NextN/MTP parameters (GLM-OCR) + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + switch (hparams.n_layer()) { + case 17: type = LLM_TYPE_1B; break; // GLM-OCR + case 40: type = LLM_TYPE_9B; break; + case 61: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_glm4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer_all; ++i) { + int flags = 0; + if (i >= n_layer) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags); -llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (i >= n_layer) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_glm4::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_glm4::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); @@ -29,6 +95,8 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params ggml_tensor * inp_out_ids = build_inp_out_ids(); + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -38,40 +106,8 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv == nullptr) { - Qcur = build_lora_mm(model.layers[il].wq, cur); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - } - Kcur = build_lora_mm(model.layers[il].wk, cur); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - } - Vcur = build_lora_mm(model.layers[il].wv, cur); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } else { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], - 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_mrope) { Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, @@ -97,7 +133,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -130,9 +166,13 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); cb(cur, "post_mlp_norm", il); } - // Add residual connection after post-MLP norm - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } // Final norm cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1); @@ -141,7 +181,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params res->t_embd = cur; // Output projection - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gpt2.cpp b/examples/talk-llama/models/gpt2.cpp index 60761c8e765..45afbccc121 100644 --- a/examples/talk-llama/models/gpt2.cpp +++ b/examples/talk-llama/models/gpt2.cpp @@ -1,10 +1,64 @@ #include "models.h" -llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); +void llama_model_gpt2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer()) { + case 12: type = LLM_TYPE_SMALL; break; + case 24: type = LLM_TYPE_MEDIUM; break; + case 36: type = LLM_TYPE_LARGE; break; + case 48: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gpt2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gpt2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_gpt2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * pos; @@ -34,22 +88,11 @@ llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -96,7 +139,7 @@ llm_build_gpt2::llm_build_gpt2(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/gptneox.cpp b/examples/talk-llama/models/gptneox.cpp index 2151b14e939..ed5e8c50da2 100644 --- a/examples/talk-llama/models/gptneox.cpp +++ b/examples/talk-llama/models/gptneox.cpp @@ -1,11 +1,93 @@ #include "models.h" +void llama_model_gptneox::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res); + + switch (hparams.n_layer()) { + case 6: + switch (hparams.n_ff()) { + case 512: type = LLM_TYPE_14M; break; + case 2048: type = LLM_TYPE_70M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 12: + switch (hparams.n_ff()) { + case 3072: type = LLM_TYPE_160M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 16: + switch (hparams.n_ff()) { + case 8192: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: type = LLM_TYPE_410M; break; + case 8192: type = LLM_TYPE_1_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_ff()) { + case 10240: type = LLM_TYPE_2_8B; break; + case 16384: type = LLM_TYPE_6_9B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 36: + switch (hparams.n_ff()) { + case 20480: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 44: + switch (hparams.n_ff()) { + case 24576: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_gptneox::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); -llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_gptneox::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_gptneox::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -28,15 +110,8 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -55,7 +130,7 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -135,7 +210,7 @@ llm_build_gptneox::llm_build_gptneox(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/granite-hybrid.cpp b/examples/talk-llama/models/granite-hybrid.cpp index f6ca4c17a21..eb23095aece 100644 --- a/examples/talk-llama/models/granite-hybrid.cpp +++ b/examples/talk-llama/models/granite-hybrid.cpp @@ -1,10 +1,140 @@ #include "models.h" +void llama_model_granite_hybrid::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /* required */ false); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /* required */ false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /* required */ false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, /* required */ false); + + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + // A layer is recurrent IFF the n_head_kv value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_350M; break; + case 1536: type = (hparams.n_ff() == 512 ? LLM_TYPE_7B_A1B : LLM_TYPE_1B); break; + case 2048: case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); +} + +void llama_model_granite_hybrid::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.is_recr(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + + // feed forward (w/ optional biases) + if (n_expert > 0) { + // MoE FFN + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } else { + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_granite_hybrid::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} -llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); +llama_model_granite_hybrid::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_build_mamba_base(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -28,7 +158,7 @@ llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, co cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // ssm layer // cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else { @@ -56,7 +186,7 @@ llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, co res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits if (hparams.f_logit_scale) { @@ -68,37 +198,13 @@ llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, co ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * cur, +ggml_tensor * llama_model_granite_hybrid::graph::build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, const llama_model & model, const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const bool use_rope = hparams.rope_finetuned; if (use_rope) { @@ -117,13 +223,13 @@ ggml_tensor * llm_build_granite_hybrid::build_attention_layer(ggml_tensor * const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; } -ggml_tensor * llm_build_granite_hybrid::build_layer_ffn(ggml_tensor * cur, +ggml_tensor * llama_model_granite_hybrid::graph::build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il) { @@ -160,7 +266,7 @@ ggml_tensor * llm_build_granite_hybrid::build_layer_ffn(ggml_tensor * cur, nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/granite-moe.cpp b/examples/talk-llama/models/granite-moe.cpp new file mode 100644 index 00000000000..115263c418f --- /dev/null +++ b/examples/talk-llama/models/granite-moe.cpp @@ -0,0 +1,89 @@ +#include "models.h" + +void llama_model_granite_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); + + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: type = LLM_TYPE_UNKNOWN; + } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); +} + +void llama_model_granite_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_granite_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/granite.cpp b/examples/talk-llama/models/granite.cpp index 18748e9c26c..4a75c5ff3cc 100644 --- a/examples/talk-llama/models/granite.cpp +++ b/examples/talk-llama/models/granite.cpp @@ -1,15 +1,124 @@ #include "models.h" +#include <sstream> + +void llama_model_granite::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale, false); + + // Granite4 Vision uses array deepstack_mapping + ml.get_arr(LLM_KV_DEEPSTACK_MAPPING, hparams.deepstack_mapping_arr, false); + + // Count the unique deepstack input indices + std::unordered_set<uint32_t> unique_deepstack_idxs; + for (const auto val : hparams.deepstack_mapping_arr) { + if (val >= 0) { + unique_deepstack_idxs.insert(val); + } + } + hparams.n_deepstack_layers = unique_deepstack_idxs.size(); + + // Ensure all values are valid (avoid overflow attacks) + for (const auto val : unique_deepstack_idxs) { + if (val > hparams.n_deepstack_layers) { + std::stringstream ss; + ss << "Invalid deepstack index: " << val << " > " << hparams.n_deepstack_layers; + throw std::runtime_error(ss.str()); + } + } + + // Granite uses rope_finetuned as a switch for rope, so default to true + bool rope_finetuned = true; + ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); + hparams.rope_finetuned = rope_finetuned; + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_3B; break; + // Add additional layer/vocab/etc checks here for other model sizes + default: type = LLM_TYPE_UNKNOWN; + } + + // For Granite MoE Shared + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, /* required */ false); +} + +void llama_model_granite::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); -llm_build_granite::llm_build_granite( + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_granite::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_granite::graph::graph( const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -26,6 +135,20 @@ llm_build_granite::llm_build_granite( ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + + // Granite Vision 4.1 deepstack: inject the projector stream that + // targets decoder layer `il` before the decoder runs. + // NOTE: skip the first deepstack layer since that's inpL + const auto & deepstack_emb_idx = hparams.deepstack_mapping_arr[il]; + if (il > 0 && deepstack_emb_idx >= 0) { + ggml_tensor * ds = ggml_view_2d(ctx0, + res->t_inp_embd, n_embd, n_tokens, + res->t_inp_embd->nb[1], + deepstack_emb_idx * n_embd * sizeof(float)); + inpL = ggml_add(ctx0, inpL, ds); + cb(inpL, "deepstack_in", il); + } + ggml_tensor * inpSA = inpL; // norm @@ -59,7 +182,7 @@ llm_build_granite::llm_build_granite( res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); // For Granite architectures - scale logits cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_logit_scale); @@ -69,7 +192,7 @@ llm_build_granite::llm_build_granite( ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_granite::build_attention_layer( +ggml_tensor * llama_model_granite::graph::build_attention_layer( ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, @@ -77,31 +200,8 @@ ggml_tensor * llm_build_granite::build_attention_layer( const int64_t n_embd_head, const int il) { - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const bool use_rope = hparams.rope_finetuned; if (use_rope) { @@ -125,13 +225,13 @@ ggml_tensor * llm_build_granite::build_attention_layer( const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; } -ggml_tensor * llm_build_granite::build_layer_ffn( +ggml_tensor * llama_model_granite::graph::build_layer_ffn( ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, @@ -175,7 +275,7 @@ ggml_tensor * llm_build_granite::build_layer_ffn( nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); diff --git a/examples/talk-llama/models/grok.cpp b/examples/talk-llama/models/grok.cpp index 3c54dfee636..42f38af6724 100644 --- a/examples/talk-llama/models/grok.cpp +++ b/examples/talk-llama/models/grok.cpp @@ -1,10 +1,93 @@ #include "models.h" -llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_grok::load_arch_hparams(llama_model_loader & ml) { + // defaults for old GGUFs + hparams.yarn_beta_fast = 8.0f; + hparams.f_logit_scale = 0.5773502691896257f; + hparams.f_embedding_scale = 78.38367176906169f; + hparams.f_attn_out_scale = 0.08838834764831845f; + hparams.f_attn_logit_softcapping = 30.0f; + hparams.f_router_logit_softcapping = 30.0f; + // no final_logit_softcapping in grok-1 + hparams.f_final_logit_softcapping = 0.0f; + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false); + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false); + ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false); + ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); + ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false); + ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false); + + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + + switch (hparams.n_layer()) { + case 64: type = LLM_TYPE_314B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_grok::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + if (!layer.ffn_post_norm) { + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_grok::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_grok::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,27 +113,8 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +133,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -99,7 +163,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params nullptr, n_expert, n_expert_used, LLM_FFN_GELU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); @@ -142,7 +206,7 @@ llm_build_grok::llm_build_grok(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); diff --git a/examples/talk-llama/models/grovemoe.cpp b/examples/talk-llama/models/grovemoe.cpp index 56b6db9a3d0..643a448e59a 100644 --- a/examples/talk-llama/models/grovemoe.cpp +++ b/examples/talk-llama/models/grovemoe.cpp @@ -1,14 +1,76 @@ #include "models.h" +void llama_model_grovemoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp, false); + ml.get_key(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale); + ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 48: type = LLM_TYPE_30B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_grovemoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for GROVEMOE"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for GROVEMOE"); + GGML_ASSERT(hparams.n_group_experts > 0 && "n_group_experts must be > 0 for GROVEMOE"); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_chexp = hparams.n_ff_chexp ? hparams.n_ff_chexp : n_embd_head_k; + const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.ffn_gate_chexps = create_tensor(tn(LLM_TENSOR_FFN_GATE_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); + layer.ffn_down_chexps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_CHEXPS, "weight", i), {n_ff_chexp, n_embd, n_chunk_expert}, 0); + layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_grovemoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} -llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_graph_params & params) : +llama_model_grovemoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); const int64_t n_chunk_expert = n_expert / hparams.n_group_experts; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -32,18 +94,8 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -62,7 +114,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } @@ -90,7 +142,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, probs); @@ -106,7 +158,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap nullptr, n_chunk_expert, n_expert_used > n_chunk_expert ? n_chunk_expert : n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, probs); @@ -132,7 +184,7 @@ llm_build_grovemoe::llm_build_grovemoe(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-dense.cpp b/examples/talk-llama/models/hunyuan-dense.cpp index 7d5dcc7828b..c137bd37c02 100644 --- a/examples/talk-llama/models/hunyuan-dense.cpp +++ b/examples/talk-llama/models/hunyuan-dense.cpp @@ -1,132 +1,6 @@ #include "models.h" -llm_build_hunyuan_dense::llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv(); - - const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - // self-attention - { - // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = build_norm(Kcur, - model.layers[il].attn_k_norm, nullptr, - LLM_NORM_RMS, il); - cb(Kcur, "Kcur_norm", il); - - Qcur = build_norm(Qcur, - model.layers[il].attn_q_norm, nullptr, - LLM_NORM_RMS, il); - cb(Qcur, "Qcur_norm", il); - - cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - cb(cur, "attn_out", il); - } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - // feed-forward network (non-MoE) - ggml_tensor * cur_mlp = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur_mlp, "ffn_out", il); - - cur = ggml_add(ctx0, cur_mlp, ffn_inp); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - // lm_head - cur = build_lora_mm(model.output, cur); - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); +std::unique_ptr<llm_graph_context> llama_model_hunyuan_dense::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); } + diff --git a/examples/talk-llama/models/hunyuan-moe.cpp b/examples/talk-llama/models/hunyuan-moe.cpp index 77e39de5b8b..4d55f5e7f31 100644 --- a/examples/talk-llama/models/hunyuan-moe.cpp +++ b/examples/talk-llama/models/hunyuan-moe.cpp @@ -1,10 +1,63 @@ #include "models.h" -llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_hunyuan_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_A13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_hunyuan_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_hunyuan_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_hunyuan_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -35,27 +88,8 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -84,7 +118,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll cb(Qcur, "Qcur_norm", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -119,8 +153,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll n_expert, n_expert_used, LLM_FFN_SILU, true, // norm_topk_prob - false, - 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur_moe, "ffn_moe_out", il); @@ -146,7 +179,7 @@ llm_build_hunyuan_moe::llm_build_hunyuan_moe(const llama_model & model, const ll res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/hunyuan-vl.cpp b/examples/talk-llama/models/hunyuan-vl.cpp new file mode 100644 index 00000000000..da9bb74de7e --- /dev/null +++ b/examples/talk-llama/models/hunyuan-vl.cpp @@ -0,0 +1,189 @@ +#include "models.h" + +void llama_model_hunyuan_vl::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // XDRoPE / NTK-aware scaling: base = rope_theta * alpha^(dim / (dim - 2)) + if (hparams.rope_scaling_alpha > 0.0f) { + const int dim = hparams.n_embd_head_k(); + hparams.rope_freq_base_train = hparams.rope_freq_base_train + * powf(hparams.rope_scaling_alpha, (float)dim / (float)(dim - 2)); + } + + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_0_5B; break; + case 2048: type = LLM_TYPE_1_8B; break; + case 3072: type = LLM_TYPE_4B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_hunyuan_vl::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + } +} + +std::unique_ptr<llm_graph_context> llama_model_hunyuan_vl::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_hunyuan_vl::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + const bool use_mrope = hparams.use_mrope(); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + if (use_mrope) { + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Kcur = build_norm(Kcur, + model.layers[il].attn_k_norm, nullptr, + LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + Qcur = build_norm(Qcur, + model.layers[il].attn_q_norm, nullptr, + LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + // feed-forward network (non-MoE) + ggml_tensor * cur_mlp = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur_mlp, "ffn_out", il); + + cur = ggml_add(ctx0, cur_mlp, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/internlm2.cpp b/examples/talk-llama/models/internlm2.cpp index 387e8211270..f6cfdfb9458 100644 --- a/examples/talk-llama/models/internlm2.cpp +++ b/examples/talk-llama/models/internlm2.cpp @@ -1,10 +1,48 @@ #include "models.h" -llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_internlm2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_internlm2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + // layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_internlm2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_internlm2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,27 +68,8 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +88,7 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -111,7 +130,7 @@ llm_build_internlm2::llm_build_internlm2(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais.cpp b/examples/talk-llama/models/jais.cpp index 3e3376e6a62..415103ce23a 100644 --- a/examples/talk-llama/models/jais.cpp +++ b/examples/talk-llama/models/jais.cpp @@ -1,10 +1,61 @@ #include "models.h" -llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); +void llama_model_jais::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_1_3B; break; + case 40: type = LLM_TYPE_13B; break; + /* TODO: add variants */ + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jais::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_jais::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_jais::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -24,22 +75,11 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*cur->nb[0]*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*cur->nb[0]*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/float(n_embd_head), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -66,8 +106,14 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } cur = build_norm(inpL, model.output_norm, @@ -77,7 +123,7 @@ llm_build_jais::llm_build_jais(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jais2.cpp b/examples/talk-llama/models/jais2.cpp new file mode 100644 index 00000000000..8610fcc9f82 --- /dev/null +++ b/examples/talk-llama/models/jais2.cpp @@ -0,0 +1,161 @@ +#include "models.h" + +void llama_model_jais2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_8B; break; + case 68: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jais2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // attention biases - all have shape n_embd (output dimension of projections) + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + // Jais-2 uses simple MLP (no gate) with biases + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_jais2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +// JAIS-2 model graph builder +// Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings +llama_model_jais2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // KV input for attention + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + // Pre-attention LayerNorm + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // Self-attention with separate Q, K, V projections + { + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + // Apply RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // Residual connection + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // Pre-FFN LayerNorm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + // FFN with relu2 activation (ReLU squared) - no gate projection + // up -> relu2 -> down + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, // no gate + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // Residual connection + inpL = ggml_add(ctx0, cur, ffn_inp); + inpL = build_cvec(inpL, il); + cb(inpL, "l_out", il); + } + + // Final LayerNorm + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + // Output projection + cur = build_lora_mm(model.output, cur, model.output_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/jamba.cpp b/examples/talk-llama/models/jamba.cpp index a0187772ccb..dba160b014f 100644 --- a/examples/talk-llama/models/jamba.cpp +++ b/examples/talk-llama/models/jamba.cpp @@ -1,7 +1,112 @@ #include "models.h" -llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_jamba::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer()) { + // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. + case 12: // 900M 8x???M + case 32: // 51B 16x?B + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jamba::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head_kv = hparams.n_head_kv(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (n_head_kv == 0) { + // Mamba layer + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); + + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else { + // Attention layers + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_inp) { + // MoE + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } else { + // FFN (no MoE) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_jamba::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_jamba::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); ggml_tensor * cur; ggml_tensor * inpL; @@ -24,25 +129,12 @@ llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_para } else { // Attention - struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // No RoPE :) cur = build_attn(inp_hybrid->get_attn(), - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -76,7 +168,7 @@ llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_para nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); @@ -97,7 +189,7 @@ llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/jina-bert-v2.cpp b/examples/talk-llama/models/jina-bert-v2.cpp new file mode 100644 index 00000000000..86ff1c84d1a --- /dev/null +++ b/examples/talk-llama/models/jina-bert-v2.cpp @@ -0,0 +1,66 @@ +#include "models.h" + +void llama_model_jina_bert_v2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + hparams.f_max_alibi_bias = 8.0f; + + switch (hparams.n_layer()) { + case 4: type = LLM_TYPE_33M; break; // jina-embeddings-small + case 12: type = LLM_TYPE_137M; break; // jina-embeddings-base + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jina_bert_v2::load_arch_tensors(llama_model_loader & ml) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0); // token_type_embeddings + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); // LayerNorm + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); // LayerNorm bias + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, 1}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {1}, TENSOR_NOT_REQUIRED); + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; // JinaBertLayer + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); //output_dens + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); //output_dens + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); //output_norm + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i); + ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str()); + const int64_t n_ffn_up = t_ffn_up ? t_ffn_up->ne[1] : n_ff; + + GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2); + layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_jina_bert_v2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/jina-bert-v3.cpp b/examples/talk-llama/models/jina-bert-v3.cpp new file mode 100644 index 00000000000..1c974a6f16c --- /dev/null +++ b/examples/talk-llama/models/jina-bert-v3.cpp @@ -0,0 +1,69 @@ +#include "models.h" + +void llama_model_jina_bert_v3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer()) { + case 24: + type = LLM_TYPE_558M; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_jina_bert_v3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_jina_bert_v3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/kimi-linear.cpp b/examples/talk-llama/models/kimi-linear.cpp new file mode 100644 index 00000000000..367f6990d1f --- /dev/null +++ b/examples/talk-llama/models/kimi-linear.cpp @@ -0,0 +1,550 @@ +#include "models.h" +#include "llama-memory-recurrent.h" + +void llama_model_kimi_linear::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_KDA_HEAD_DIM, hparams.n_embd_head_kda); + + // MLA qk_rope_head_dim (for reference) + // qk_rope_head_dim = 64, qk_nope_head_dim = 128, qk_head_dim = 192 + + // Mark KDA layers as recurrent using n_head_kv pattern (like Jamba) + // Set n_head_kv = 0 for KDA layers (recurrent), n_head_kv = n_head for MLA layers (attention) + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; // KDA layers are recurrent + } + + // MoE parameters - Kimi uses moe_intermediate_size = 1024 + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + switch (hparams.n_layer()) { + case 27: type = LLM_TYPE_48B_A3B; break; // Kimi-Linear-48B-A3B + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_kimi_linear::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // Check for KDA specific tensors to determine layer type or if it's a mixed model + // Assuming KDA layer if KDA tensors are present + + // KDA uses head_dim = 128 (from linear_attn_config.head_dim) + const int64_t n_embd_head_k_kda = hparams.n_embd_head_kda; + const int64_t n_embd_head_v_kda = hparams.n_embd_head_kda; + const int64_t ssm_d_conv = hparams.ssm_d_conv; + + if (hparams.is_recr(i)) { + // Conv1d weights: try 4D first, then 3D (quantization may remove trailing 1) + // 4D: [d_conv, 1, d_inner, 1], 3D: [d_conv, 1, d_inner] + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_q_conv) { + layer.ssm_q_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_Q, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + + // KDA Layer - Conv1d weights may be 3D or 4D + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_k_conv) { + layer.ssm_k_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_K, "weight", i), {ssm_d_conv, 1, n_embd_head_k_kda * n_head}, 0); + } + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_v_conv) { + layer.ssm_v_conv = create_tensor(tn(LLM_TENSOR_SSM_CONV1D_V, "weight", i), {ssm_d_conv, 1, n_embd_head_v_kda * n_head}, 0); + } + + // q, k, v projections + // Python: q_proj, k_proj, v_proj + create_tensor_qkv(layer, i, n_embd, n_embd_head_k_kda * n_head, n_embd_head_k_kda * n_head, n_embd_head_v_kda * n_head, 0); + + // KDA specific projections + // f_a_proj, f_b_proj + layer.ssm_f_a = create_tensor(tn(LLM_TENSOR_SSM_F_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); // head_dim + layer.ssm_f_b = create_tensor(tn(LLM_TENSOR_SSM_F_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); // projection_size + + // b_proj (beta mixing coefficient) + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), {n_embd, n_head}, 0); + + // A_log - Shape in GGUF: [1, num_heads, 1, 1] (4D) or [1, num_heads] (2D after quantization) Note: -exp(A_log) is applied in convert_hf_to_gguf.py + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head, 1, 1}, TENSOR_NOT_REQUIRED); + if (!layer.ssm_a) { + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + } + + // dt_bias - shape [n_embd_head_k_kda * n_head] = [4096] + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_embd_head_k_kda * n_head}, 0); + + // g_a_proj, g_b_proj (output gate) + layer.ssm_g_a = create_tensor(tn(LLM_TENSOR_SSM_G_A, "weight", i), {n_embd, n_embd_head_k_kda}, 0); + layer.ssm_g_b = create_tensor(tn(LLM_TENSOR_SSM_G_B, "weight", i), {n_embd_head_k_kda, n_embd_head_k_kda * n_head}, 0); + + // o_norm (reusing SSM_NORM) + layer.ssm_o_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {n_embd_head_k_kda}, 0); // FusedRMSNormGated + + // o_proj + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v_kda * n_head, n_embd}, 0); + + } else { + // MLA Layer - use MLA-specific head dimensions + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, TENSOR_NOT_REQUIRED); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + if (layer.attn_q_a_norm) { + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0); + } else { + // Kimi MLA without Q compression: wq = [n_embd, n_head * n_embd_head_k_mla] + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_embd_head_k_mla}, 0); + } + + // Kimi: qk_rope_head_dim = 64 (actual RoPE dimension for MLA) + // Note: hparams.n_rot may be 72 (from conversion) but actual is 64 + const int64_t qk_rope_head_dim = hparams.n_rot(); // From config: qk_rope_head_dim + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + qk_rope_head_dim}, 0); + // Support Legacy GGUFs that don't split wkv_b (MLA KV cache disabled) + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), + {kv_lora_rank, n_head * (n_embd_head_k_mla - qk_rope_head_dim + n_embd_head_v_mla)}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + if (!layer.wkv_b) { // MLA KV cache enabled + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_k_mla - qk_rope_head_dim, kv_lora_rank, n_head}, 0); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0); + } + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0); + } + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + // MoE intermediate size (different from dense FFN) + const int64_t n_ff_exp = hparams.n_ff_exp; + + // Kimi uses n_layer_dense_lead to determine which layers use dense FFN vs MoE + // first_k_dense_replace = 1 means layer 0 uses dense FFN, layers 1+ use MoE + if (i < (int) hparams.n_layer_dense_lead) { + // Dense FFN layer - use normal n_ff + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } else { + // MoE layer - use n_ff_exp (1024) instead of n_ff (9216) + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared experts use moe_intermediate_size * num_shared_experts + // Kimi: shared_expert_intermediate_size = 1024 * 1 = 1024 + // Tensors are 2D: [n_embd, n_ff_shexp] or [n_ff_shexp, n_embd] + const int64_t n_ff_shexp_actual = n_ff_exp * (hparams.n_expert_shared > 0 ? hparams.n_expert_shared : 1); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp_actual, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp_actual}, TENSOR_NOT_REQUIRED); + + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_kimi_linear::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +// Causal Conv1d function for Q,K,V +// When qkv is 0, it is Q, 1 is K, 2 is V +static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_tensor * conv_states_all, ggml_tensor * conv_state_all, int64_t qkv, ggml_tensor * x, ggml_tensor * proj_w, ggml_tensor * conv_w, int64_t d_conv, int64_t head_dim, int64_t n_head, int64_t n_seq_tokens, int64_t n_seqs, int64_t n_tokens, int64_t kv_head) { + const int64_t d_inner = head_dim * n_head; + const int64_t conv_state_size = (d_conv - 1) * d_inner; + const int64_t n_embd_r_total = 3 * conv_state_size; // Q + K + V + + // conv_state_all is [n_embd_r_total, n_seqs], split into Q, K, V + // Each conv state is [(d_conv-1) * d_inner] per sequence, need to reshape to [d_conv-1, d_inner, n_seqs] + // Memory layout: for each seq, Q state is first conv_state_size elements, then K, then V + // conv_state_all has stride: nb[0] = element_size, nb[1] = n_embd_r_total * element_size + // View Q conv state: offset 0, size conv_state_size per seq + // conv_state_all is [n_embd_r_total, n_seqs] with memory layout: + // state[i + seq * n_embd_r_total] where i = conv_step + channel * (d_conv-1) + {0, conv_state_size, 2*conv_state_size} for Q/K/V + // We want [d_conv-1, d_inner, n_seqs] view: + // nb1 = (d_conv-1) * element_size (stride between channels) + // nb2 = n_embd_r_total * element_size (stride between seqs) + ggml_tensor * conv_state_x = ggml_view_3d(ctx0, conv_state_all, d_conv - 1, d_inner, n_seqs, + (d_conv - 1) * ggml_element_size(conv_state_all), // nb1: stride between channels + n_embd_r_total * ggml_element_size(conv_state_all), // nb2: stride between seqs + qkv * conv_state_size * ggml_element_size(conv_state_all)); + +// Causal Conv1d function for Q,K,V +// When qkv is 0, it is Q, 1 is K, 2 is V + // Step 1: Q, K, V projections -> [d_inner, n_tokens] + ggml_tensor * x_proj = ggml_mul_mat(ctx0, proj_w, x); + + // Reshape input: {d_inner, n_tokens} -> {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * x_3d = ggml_reshape_3d(ctx0, x_proj, d_inner, n_seq_tokens, n_seqs); + + // Concat Q conv state and current input: {d_conv-1 + n_seq_tokens, d_inner, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv_state_x, ggml_transpose(ctx0, x_3d), 0); + + // Save last (d_conv-1) columns back to Q conv state + ggml_tensor * last_conv_x = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, + conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]); + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv_x, + ggml_view_3d(ctx0, conv_states_all, + d_conv - 1, d_inner, n_seqs, + (d_conv - 1) * ggml_element_size(conv_states_all), // nb1: contiguous within one channel's conv taps + n_embd_r_total * ggml_element_size(conv_states_all), // nb2: stride between sequences (skip over K,V states) + (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); // offset to first seq's Q/K/V state + // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner] + // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv] + // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step] + // ggml_ssm_conv computes: c[conv_step + channel * d_conv] + // GGUF layout: [d_conv, 1, d_inner] or [d_conv, 1, d_inner, 1] -> reshape to [d_conv, d_inner] + // Reshape conv weight from [d_conv, 1, d_inner, 1] to [d_conv, d_inner] for ggml_ssm_conv + ggml_tensor * conv_weight = ggml_reshape_2d(ctx0, conv_w, d_conv, d_inner); + + // Apply conv1d + // ggml_ssm_conv output: {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * Xcur = ggml_ssm_conv(ctx0, conv_x, conv_weight); + // Reshape to 2D for bias add: {d_inner, n_tokens} + Xcur = ggml_reshape_2d(ctx0, Xcur, d_inner, n_tokens); + Xcur = ggml_silu(ctx0, Xcur); + + return ggml_reshape_4d(ctx0, Xcur, head_dim, n_head, n_seq_tokens, n_seqs); +} + +llama_model_kimi_linear::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_build_delta_net_base(params), model(model) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "model.embed_tokens", -1); + + // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM) + // So we don't need inp_pos + + auto * inp_kv = !hparams.is_mla() ? build_inp_mem_hybrid() : nullptr; + auto * inp_k = hparams.is_mla() ? build_inp_mem_hybrid_k() : nullptr; + auto * inp_rs = hparams.is_mla() ? inp_k->get_recr() : inp_kv->get_recr(); + auto * inp_attn_kv = !hparams.is_mla() ? inp_kv->get_attn() : nullptr; + auto * inp_attn_k = hparams.is_mla() ? inp_k->get_attn() : nullptr; + + // Output ids for selecting which tokens to output + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // Kimi dimension constants + const int64_t n_head = hparams.n_head(); + const int64_t head_dim = hparams.n_embd_head_kda; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = n_head * head_dim; // 32 * 128 = 4096 + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + // Verify batch consistency for recurrent layers + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // MLA params + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + const int64_t kv_lora_rank = hparams.n_lora_kv; + // qk_rope_head_dim = 64 (from Kimi config) which is hparams.n_rot + // Confirmed from tensor shape: wkv_a_mqa [2304, 576] = [n_embd, kv_lora_rank + qk_rope_head_dim] + const int64_t n_embd_head_qk_rope = hparams.n_rot(); // config.qk_rope_head_dim + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; // 192 - 64 = 128 + // Attention scale for MLA + const float kq_scale_mla = 1.0f / sqrtf((float)n_embd_head_k_mla); + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + ggml_tensor * inpSA = inpL; + + // Attention Norm + cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_build_forward_expand(gf, cur); + + if (hparams.is_recr(il)) { + // === KDA Layer (Kimi Delta Attention) with Recurrent State === + // Reference: vLLM kda.py + const auto * mctx_cur = inp_rs->mctx; + const auto kv_head = mctx_cur->get_head(); + + // Get conv states from r_l tensor (Q, K, V each have separate state) + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + cb(conv_states_all, "conv_states_all", il); + ggml_tensor * conv_state_all = build_rs(inp_rs, conv_states_all, hparams.n_embd_r(), n_seqs); + ggml_tensor * Qcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 0, cur, layer.wq, layer.ssm_q_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); + ggml_tensor * Kcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 1, cur, layer.wk, layer.ssm_k_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); + ggml_tensor * Vcur = causal_conv1d(gf, ctx0, conv_states_all, conv_state_all, 2, cur, layer.wv, layer.ssm_v_conv, d_conv, head_dim, n_head, n_seq_tokens, n_seqs, n_tokens, kv_head); + + // g1 = -exp(A_log) * softplus(f_b(f_a(x)) + dt_bias) + ggml_tensor * f_a = ggml_mul_mat(ctx0, layer.ssm_f_a, cur); + ggml_tensor * g1 = ggml_mul_mat(ctx0, layer.ssm_f_b, f_a); + cb(g1, "g1 f_b(f_a(cur))", il); + g1 = ggml_add(ctx0, g1, layer.ssm_dt_b); + g1 = ggml_softplus(ctx0, g1); + g1 = ggml_reshape_3d(ctx0, g1, head_dim, n_head, n_tokens); + + // A_log shape is [1, n_head] or [1, n_head, 1, 1], need to broadcast to [head_dim, n_head, n_tokens]. No need to -exp(a_log) because it was done in convert_hf_to_gguf.py + // Reshape to [1, n_head, 1] for broadcasting with g1 [head_dim, n_head, n_tokens] + ggml_tensor * A = ggml_reshape_3d(ctx0, layer.ssm_a, 1, n_head, 1); + g1 = ggml_mul(ctx0, g1, A); + cb(g1, "kda_g1", il); + + g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs); + + // Compute beta (mixing coefficient) + ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur); + beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs); + cb(beta, "kda_beta", il); + + beta = ggml_sigmoid(ctx0, beta); + + // Reshape for KDA recurrence + // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // Get SSM state and compute KDA recurrence using ggml_kda_scan + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs); + + const float eps_norm = hparams.f_norm_rms_eps; + + Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm); + Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm); + + // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens + auto attn_out = build_delta_net(Qcur, Kcur, Vcur, g1, beta, state, il); + + ggml_tensor * output = ggml_cont(ctx0, attn_out.first); + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + // Output gating g2 = g_b(g_a(x)) + ggml_tensor * cur_2d = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + ggml_tensor * g_a = ggml_mul_mat(ctx0, layer.ssm_g_a, cur_2d); + ggml_tensor * g2 = ggml_mul_mat(ctx0, layer.ssm_g_b, g_a); + cb(g2, "g2 g_b(g_a(cur_2d))", il); + g2 = ggml_reshape_3d(ctx0, g2, head_dim, n_head, n_seq_tokens * n_seqs); + + // Apply o_norm with sigmoid gating + // Note: Kimi model uses sigmoid gating, not SiLU (despite FusedRMSNormGated default being swish) + // Formula: output = RMSNorm(x) * sigmoid(g) + ggml_tensor * attn_out_final = ggml_reshape_3d(ctx0, output, head_dim, n_head, n_seq_tokens * n_seqs); + ggml_tensor * normed = build_norm(attn_out_final, layer.ssm_o_norm, nullptr, LLM_NORM_RMS, il); + cb(normed, "kda_normed", il); + ggml_tensor * gate = ggml_sigmoid(ctx0, g2); + ggml_tensor * gated = ggml_mul(ctx0, normed, gate); + + // Output projection + gated = ggml_cont_2d(ctx0, gated, d_inner, n_tokens); + cur = ggml_mul_mat(ctx0, layer.wo, gated); + cb(cur, "kda_out", il); + + } else { + // === MLA Layer (Multi-head Latent Attention) without KV Cache === + // Reference: vLLM mla.py + // Step 1: Q projection and reshape + // vLLM Kimi: q = q_proj(hidden_states), then view as [n_tokens, n_head, qk_head_dim] + // Note: Kimi MLA does NOT use RoPE (rotary_emb=None in vLLM) + ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.wq, cur); + + // Step 2: KV compression + // kv_cmpr_pe = kv_a_proj_with_mqa(hidden_states) -> [kv_lora_rank + qk_rope_head_dim, n_tokens] + ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, layer.wkv_a_mqa, cur); + + // Split: kv_cmpr = kv_lora[:kv_lora_rank], k_pe = kv_lora[kv_lora_rank:] + ggml_tensor * kv_cmpr = ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0); + ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens, + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), + ggml_row_size(kv_cmpr_pe->type, kv_lora_rank)); + // Note: Kimi MLA does NOT apply RoPE (rotary_emb=None in vLLM) + // k_pe is used directly without RoPE + // Normalize kv_c + kv_cmpr = build_norm(kv_cmpr, layer.attn_kv_a_norm, nullptr, LLM_NORM_RMS, il); + + if (layer.wk_b && layer.wv_b) { // MLA KV cache enabled + // extract q_nope + ggml_tensor * q_nope = + ggml_view_3d(ctx0, Qcur, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla), + ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, 0); + cb(q_nope, "q_nope", il); + + // and {n_embd_head_qk_rope, n_head, n_tokens} + ggml_tensor * q_pe = ggml_view_3d( + ctx0, Qcur, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(Qcur->type, n_embd_head_k_mla), + ggml_row_size(Qcur->type, n_embd_head_k_mla) * n_head, ggml_row_size(Qcur->type, n_embd_head_qk_nope)); + cb(q_pe, "q_pe", il); + + // {n_embd_head_qk_nope, n_tokens, n_head} + q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3); + cb(q_nope, "q_nope_perm", il); + + // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head} + ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, layer.wk_b, q_nope); + cb(q_nope_absorbed, "q_nope_absorbed", il); + + // {kv_lora_rank, n_head, n_tokens} + q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3); + cb(q_nope_absorbed, "q_nope_absorbed_perm", il); + + // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} + // note: rope must go first for in-place context shifting in build_rope_shift() + Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); + cb(Qcur, "Qcur", il); + + kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); + cb(kv_cmpr, "kv_cmpr_reshape", il); + + // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); + cb(Kcur, "Kcur", il); + + // {kv_lora_rank, 1, n_tokens} + ggml_tensor * Vcur = kv_cmpr; + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn_k, layer.wo, NULL, layer.wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, layer.wv_b, kq_scale_mla, il); + cb(cur, "mla_out", il); + } else { // MLA KV cache disabled. Fall back to MHA KV cache. + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k_mla, n_head, n_tokens); + cb(Qcur, "mla_Q", il); + // KV decompression: kv = kv_b_proj(kv_c_normed) + ggml_tensor * kv = ggml_mul_mat(ctx0, layer.wkv_b, kv_cmpr); + const int64_t kv_per_head = n_embd_head_qk_nope + n_embd_head_v_mla; + + // Split kv into k_nope and v + ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, + ggml_row_size(kv->type, kv_per_head), + ggml_row_size(kv->type, kv_per_head * n_head), 0); + ggml_tensor * Vcur = ggml_view_3d(ctx0, kv, n_embd_head_v_mla, n_head, n_tokens, + ggml_row_size(kv->type, kv_per_head), + ggml_row_size(kv->type, kv_per_head * n_head), + ggml_row_size(kv->type, n_embd_head_qk_nope)); + Vcur = ggml_cont(ctx0, Vcur); + cb(Vcur, "mla_V", il); + + // Concatenate k_nope + k_pe (broadcast k_pe to all heads) + // K = [k_nope, k_pe] where k_nope is [qk_nope_head_dim, n_head, n_tokens] + // and k_pe is [qk_rope_head_dim, 1, n_tokens] broadcast to all heads + // Need to broadcast k_pe from [qk_rope, 1, n_tokens] to [qk_rope, n_head, n_tokens] + ggml_tensor * k_pe_target = ggml_new_tensor_3d(ctx0, k_pe->type, n_embd_head_qk_rope, n_head, n_tokens); + ggml_tensor * k_pe_repeated = ggml_repeat(ctx0, k_pe, k_pe_target); + ggml_tensor * Kcur = ggml_concat(ctx0, k_pe_repeated, k_nope, 0); + cb(Kcur, "mla_K", il); + + // Direct softmax attention (with MHA KV cache) + // Use build_attn with inp_attn for proper mask handling + cur = build_attn(inp_attn_kv, layer.wo, NULL, layer.wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale_mla, il); + cb(cur, "mla_out", il); + } + } + + // On last layer, select only the output tokens + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // FFN Norm + cur = build_norm(ffn_inp, layer.ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + if ((uint32_t) il < hparams.n_layer_dense_lead) { + // Dense FFN layer + cur = build_ffn(cur, + layer.ffn_up, NULL, NULL, + layer.ffn_gate, NULL, NULL, + layer.ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE layer + // Kimi uses moe_renormalize=True and routed_scaling_factor (stored as expert_weights_scale) = 2.446 + ggml_tensor * moe_out = build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + layer.ffn_exp_probs_b, + hparams.n_expert, + hparams.n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + // Shared expert + { + ggml_tensor * ffn_shexp = build_ffn(cur, + layer.ffn_up_shexp, NULL, NULL, + layer.ffn_gate_shexp, NULL, NULL, + layer.ffn_down_shexp, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } + } + // Residual + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + // Final Norm + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // Output + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/lfm2.cpp b/examples/talk-llama/models/lfm2.cpp index 7f805d78795..97da8a6abb8 100644 --- a/examples/talk-llama/models/lfm2.cpp +++ b/examples/talk-llama/models/lfm2.cpp @@ -1,18 +1,235 @@ #include "models.h" - +#include "../llama-memory-hybrid-iswa.h" #include "../llama-memory-hybrid.h" +void llama_model_lfm2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_recr_impl[il] = hparams.n_head_kv(il) == 0; + } + + hparams.n_layer_dense_lead = hparams.n_layer(); + + switch (hparams.n_ff()) { + case 4608: type = LLM_TYPE_350M; break; + case 6912: type = LLM_TYPE_700M; break; + case 8192: type = LLM_TYPE_1_2B; break; + case 10752: type = LLM_TYPE_2_6B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_swa_impl[il] = !hparams.is_recr_impl[il]; + } + } +} + +void llama_model_lfm2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const bool is_moe_layer = i >= static_cast<int>(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // for operator_norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (!hparams.is_recr(i)) { + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); -llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : - llm_graph_context(params), - model(model) { + create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } else { + layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); + layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); + layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); + } + } + + // for LFM2-ColBert-350M + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); +} + +std::unique_ptr<llm_graph_context> llama_model_lfm2::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique<graph<true>>(*this, params); + } else { + return std::make_unique<graph<false>>(*this, params); + } +} + +template <bool iswa> +llama_model_lfm2::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + using inp_hybrid_type = std::conditional_t<iswa, llm_graph_input_mem_hybrid_iswa, llm_graph_input_mem_hybrid>; + using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>; + using mem_hybrid_ctx = std::conditional_t<iswa, llama_memory_hybrid_iswa_context, llama_memory_hybrid_context>; + + // lambda helpers for readability + auto build_dense_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * { + GGML_ASSERT(!model.layers[il].ffn_up_b); + GGML_ASSERT(!model.layers[il].ffn_gate_b); + GGML_ASSERT(!model.layers[il].ffn_down_b); + return build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + }; + auto build_moe_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * { + return build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), + il); + }; + auto build_attn_block = [&model, this](ggml_tensor * cur, + ggml_tensor * inp_pos, + inp_attn_type * inp_attn, + int il) -> ggml_tensor * { + GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); + const auto n_embd_head = hparams.n_embd_head_v(); + const auto n_head_kv = hparams.n_head_kv(il); + + auto [q, k, v] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + // qk norm + q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(q, "model.layers.{}.self_attn.q_layernorm", il); + k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(k, "model.layers.{}.self_attn.k_layernorm", il); + + // RoPE + q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow); + k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, model.layers[il].wo_s, + q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + + cb(cur, "model.layers.{}.self_attn.out_proj", il); + + return cur; + }; + auto build_shortconv_block = [&model, this](ggml_tensor * cur, + llm_graph_input_rs * inp_recr, + int il) -> ggml_tensor * { + const auto * mctx_cur = static_cast<const mem_hybrid_ctx *>(mctx)->get_recr(); + const uint32_t kv_head = mctx_cur->get_head(); + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + GGML_ASSERT(hparams.n_shortconv_l_cache > 1); + const uint32_t d_conv = hparams.n_shortconv_l_cache - 1; + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur); + cb(bcx, "model.layers.{}.conv.in_proj", il); + + constexpr auto n_chunks = 3; + GGML_ASSERT(bcx->ne[0] % n_chunks == 0); + const auto chunk_size = bcx->ne[0] / n_chunks; + auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], + 0 * chunk_size * ggml_element_size(bcx)); + auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], + 1 * chunk_size * ggml_element_size(bcx)); + auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], + 2 * chunk_size * ggml_element_size(bcx)); + + auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); + + // read conv state + auto * conv_state = mctx_cur->get_r_l(il); + auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); + auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); + + bx = ggml_concat(ctx0, conv, bx, 0); + GGML_ASSERT(bx->ne[0] > conv->ne[0]); + + // last d_conv columns is a new conv state + auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], + (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); + GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); + + // write new conv conv state + ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, + ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv), + kv_head * d_conv * n_embd * ggml_element_size(new_conv)))); + + auto * conv_kernel = model.layers[il].shortconv.conv; + auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); + cb(conv_out, "model.layers.{}.conv.conv", il); + + auto * y = ggml_mul(ctx0, c, conv_out); + y = build_lora_mm(model.layers[il].shortconv.out_proj, y); + cb(y, "model.layers.{}.conv.out_proj", il); + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs); + + return y; + }; + + // actual graph construction starts here ggml_tensor * cur = build_inp_embd(model.tok_embd); cb(cur, "model.embed_tokens", -1); ggml_build_forward_expand(gf, cur); + inp_hybrid_type * inp_hybrid = nullptr; + if constexpr (iswa) { + inp_hybrid = build_inp_mem_hybrid_iswa(); + } else { + inp_hybrid = build_inp_mem_hybrid(); + } + ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_hybrid = build_inp_mem_hybrid(); ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { @@ -22,8 +239,8 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); - cur = hparams.is_recurrent(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) : - build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il); + cur = hparams.is_recr(il) ? build_shortconv_block(cur, inp_hybrid->get_recr(), il) : + build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il); if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -40,13 +257,16 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params cb(ffn_norm_out, "model.layers.{}.ffn_out", il); cur = ggml_add(ctx0, cur, ffn_out); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); } cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -54,122 +274,6 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_lfm2::build_moe_feed_forward(ggml_tensor * cur, int il) const { - return build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, - static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il); -} - -ggml_tensor * llm_build_lfm2::build_dense_feed_forward(ggml_tensor * cur, int il) const { - GGML_ASSERT(!model.layers[il].ffn_up_b); - GGML_ASSERT(!model.layers[il].ffn_gate_b); - GGML_ASSERT(!model.layers[il].ffn_down_b); - return build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); -} - -ggml_tensor * llm_build_lfm2::build_attn_block(ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv * inp_attn, - int il) const { - GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); - const auto n_embd_head = hparams.n_embd_head_v; - const auto n_head_kv = hparams.n_head_kv(il); - - auto * q = build_lora_mm(model.layers[il].wq, cur); - cb(q, "model.layers.{}.self_attn.q_proj", il); - auto * k = build_lora_mm(model.layers[il].wk, cur); - cb(k, "model.layers.{}.self_attn.k_proj", il); - auto * v = build_lora_mm(model.layers[il].wv, cur); - cb(v, "model.layers.{}.self_attn.v_proj", il); - - q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens); - k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens); - v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens); - - // qk norm - q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(q, "model.layers.{}.self_attn.q_layernorm", il); - k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(k, "model.layers.{}.self_attn.k_layernorm", il); - - // RoPE - q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, - attn_factor, beta_fast, beta_slow); - k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, - attn_factor, beta_fast, beta_slow); - - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); - - cb(cur, "model.layers.{}.self_attn.out_proj", il); - - return cur; -} - -ggml_tensor * llm_build_lfm2::build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) { - const auto * mctx_cur = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr(); - const uint32_t kv_head = mctx_cur->get_head(); - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(ubatch.equal_seqs()); - GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - - GGML_ASSERT(hparams.n_shortconv_l_cache > 1); - const uint32_t d_conv = hparams.n_shortconv_l_cache - 1; - - // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} - cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); - - auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur); - cb(bcx, "model.layers.{}.conv.in_proj", il); - - constexpr auto n_chunks = 3; - GGML_ASSERT(bcx->ne[0] % n_chunks == 0); - const auto chunk_size = bcx->ne[0] / n_chunks; - auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], - 0 * chunk_size * ggml_element_size(bcx)); - auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], - 1 * chunk_size * ggml_element_size(bcx)); - auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], - 2 * chunk_size * ggml_element_size(bcx)); - - auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); - - // read conv state - auto * conv_state = mctx_cur->get_r_l(il); - auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); - auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); - - bx = ggml_concat(ctx0, conv, bx, 0); - GGML_ASSERT(bx->ne[0] > conv->ne[0]); - - // last d_conv columns is a new conv state - auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], - (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); - GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); - - // write new conv conv state - ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, - ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv), - kv_head * d_conv * n_embd * ggml_element_size(new_conv)))); - - auto * conv_kernel = model.layers[il].shortconv.conv; - auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); - cb(conv_out, "model.layers.{}.conv.conv", il); - - auto * y = ggml_mul(ctx0, c, conv_out); - y = build_lora_mm(model.layers[il].shortconv.out_proj, y); - cb(y, "model.layers.{}.conv.out_proj", il); - // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} - y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs); - - return y; -} +// Explicit template instantiations +template struct llama_model_lfm2::graph<true>; +template struct llama_model_lfm2::graph<false>; diff --git a/examples/talk-llama/models/lfm2moe.cpp b/examples/talk-llama/models/lfm2moe.cpp new file mode 100644 index 00000000000..490f5c223eb --- /dev/null +++ b/examples/talk-llama/models/lfm2moe.cpp @@ -0,0 +1,85 @@ +#include "models.h" +#include "../llama-memory-hybrid-iswa.h" +#include "../llama-memory-hybrid.h" + +void llama_model_lfm2moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + for (uint32_t il = 0; il < hparams.n_layer(); ++il) { + hparams.is_recr_impl[il] = hparams.n_head_kv(il) == 0; + } + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_8B_A1B; break; + case 40: type = LLM_TYPE_24B_A2B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_lfm2moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM_LFM2, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const bool is_moe_layer = i >= static_cast<int>(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // for operator_norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (!hparams.is_recr(i)) { + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + GGML_ASSERT(n_embd_v_gqa == n_embd_k_gqa); + + create_tensor_qkv(layer, i, n_embd, n_embd, hparams.n_embd_k_gqa(i), hparams.n_embd_v_gqa(i), 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + } else { + layer.shortconv.conv = create_tensor(tn(LLM_TENSOR_SHORTCONV_CONV, "weight", i), {hparams.n_shortconv_l_cache, n_embd}, 0); + layer.shortconv.in_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_INPROJ, "weight", i), {n_embd, 3 * n_embd}, 0); + layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); + } + } + + // for LFM2-ColBert-350M + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); +} + +std::unique_ptr<llm_graph_context> llama_model_lfm2moe::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique<graph<true>>(*this, params); + } else { + return std::make_unique<graph<false>>(*this, params); + } +} + diff --git a/examples/talk-llama/models/llada-moe.cpp b/examples/talk-llama/models/llada-moe.cpp index 5f64686f5fb..2ae89386447 100644 --- a/examples/talk-llama/models/llada-moe.cpp +++ b/examples/talk-llama/models/llada-moe.cpp @@ -1,10 +1,61 @@ #include "models.h" -llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_llada_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + // diffusion language model uses non-causal attention + hparams.causal_attn = false; + + switch (hparams.n_layer()) { + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_llada_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe"); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_llada_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_llada_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,18 +81,8 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,7 +107,7 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -90,7 +131,7 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); @@ -113,7 +154,7 @@ llm_build_llada_moe::llm_build_llada_moe(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llada.cpp b/examples/talk-llama/models/llada.cpp index 857033660a0..87d4259f9a7 100644 --- a/examples/talk-llama/models/llada.cpp +++ b/examples/talk-llama/models/llada.cpp @@ -1,11 +1,79 @@ #include "models.h" -llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_llada::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion + switch (hparams.n_layer()) { + case 32: + type = LLM_TYPE_8B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + + // Set non-causal attention for diffusion models + hparams.causal_attn = false; +} + +void llama_model_llada::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = + create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + // Use separate Q, K, V projections without bias, matching LLaDALlamaBlock + layer.wq = + create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false + layer.wo = + create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot / 2 }, + TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // optional MLP bias + layer.ffn_gate_b = + create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = + create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_llada::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_llada::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { // LLaDA is similar to LLaMA but uses non-causal attention for diffusion - const int64_t n_embd_head = hparams.n_embd_head_v; + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,17 +98,8 @@ llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_para // self-attention { // compute separate Q, K, V projections without bias, matching LLaDALlamaBlock - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -53,7 +112,7 @@ llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -90,7 +149,7 @@ llm_build_llada::llm_build_llada(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/llama-embed.cpp b/examples/talk-llama/models/llama-embed.cpp new file mode 100644 index 00000000000..0699e744461 --- /dev/null +++ b/examples/talk-llama/models/llama-embed.cpp @@ -0,0 +1,6 @@ +#include "models.h" + +std::unique_ptr<llm_graph_context> llama_model_llama_embed::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph<true>>(*this, params); +} + diff --git a/examples/talk-llama/models/llama-iswa.cpp b/examples/talk-llama/models/llama-iswa.cpp deleted file mode 100644 index 61dd2c179f1..00000000000 --- a/examples/talk-llama/models/llama-iswa.cpp +++ /dev/null @@ -1,178 +0,0 @@ -#include "models.h" - -llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - // temperature tuning - ggml_tensor * inp_attn_scale = nullptr; - inp_attn_scale = build_inp_attn_scale(); - - auto * inp_attn = build_attn_inp_kv_iswa(); - - const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - const float freq_base_l = model.get_rope_freq_base (cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - - ggml_tensor * inpSA = inpL; - - // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous - const bool use_rope = hparams.n_no_rope_layer_step > 0 && - (il + 1) % hparams.n_no_rope_layer_step != 0; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - { - // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); - - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - if (use_rope) { - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - } else if (inp_attn_scale) { - Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); - } - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - if (use_rope && hparams.use_kq_norm) { - // Llama4TextL2Norm - Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); - Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); - cb(Qcur, "Qcur_normed", il); - cb(Kcur, "Kcur_normed", il); - } - cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); - cb(cur, "attn_out", il); - } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network (non-MoE) - if (model.layers[il].ffn_gate_inp == nullptr) { - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } else { - ggml_tensor * ffn_inp_normed = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - ggml_tensor * moe_out = build_moe_ffn(ffn_inp_normed, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, false, - false, 0.0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, - il); - - // Shared experts - ggml_tensor * shexp_out = build_ffn(ffn_inp_normed, - model.layers[il].ffn_up_shexp, NULL, NULL, - model.layers[il].ffn_gate_shexp, NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(shexp_out, "ffn_moe_shexp", il); - - cur = ggml_add(ctx0, moe_out, shexp_out); - cb(cur, "ffn_moe_out_merged", il); - } - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/llama.cpp b/examples/talk-llama/models/llama.cpp index 42b5fcdf42e..4bfebc8843c 100644 --- a/examples/talk-llama/models/llama.cpp +++ b/examples/talk-llama/models/llama.cpp @@ -1,11 +1,106 @@ #include "models.h" +void llama_model_llama::load_arch_hparams(llama_model_loader & ml) { + uint32_t n_vocab = 0; + ml.get_key(LLM_KV_VOCAB_SIZE, n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + if (hparams.n_expert == 8) { + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_8x7B; break; + case 56: type = LLM_TYPE_8x22B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } else { + switch (hparams.n_layer()) { + case 16: type = LLM_TYPE_1B; break; // Llama 3.2 1B + case 22: type = LLM_TYPE_1B; break; + case 26: type = LLM_TYPE_3B; break; + case 28: type = LLM_TYPE_3B; break; // Llama 3.2 3B + case 30: type = LLM_TYPE_256M; break; // smoldocling 256M + // granite uses a vocab with len 49152 + case 32: type = n_vocab == 49152 ? LLM_TYPE_3B : (n_vocab < 40000 ? LLM_TYPE_7B : LLM_TYPE_8B); break; + case 36: type = LLM_TYPE_8B; break; // granite + case 40: type = LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_34B; break; + case 60: type = LLM_TYPE_30B; break; + case 80: type = hparams.n_head() == hparams.n_head_kv() ? LLM_TYPE_65B : LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } +} + +void llama_model_llama::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_llama::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph<false>>(*this, params); +} + template <bool embed> -llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +llama_model_llama::graph<embed>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -29,6 +124,8 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm @@ -43,27 +140,8 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -89,7 +167,7 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra cb(Kcur, "Kcur_normed", il); } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -109,9 +187,9 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); @@ -130,9 +208,13 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -155,7 +237,7 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra if constexpr (!embed) { // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -164,5 +246,5 @@ llm_build_llama<embed>::llm_build_llama(const llama_model & model, const llm_gra ggml_build_forward_expand(gf, cur); } -template struct llm_build_llama<false>; -template struct llm_build_llama<true>; +template struct llama_model_llama::graph<false>; +template struct llama_model_llama::graph<true>; diff --git a/examples/talk-llama/models/llama4.cpp b/examples/talk-llama/models/llama4.cpp new file mode 100644 index 00000000000..7194c72a585 --- /dev/null +++ b/examples/talk-llama/models/llama4.cpp @@ -0,0 +1,274 @@ +#include "models.h" + +void llama_model_llama4::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa == 0) { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + hparams.n_no_rope_layer_step = hparams.n_layer(); // always use rope + } else { + hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED; + hparams.n_swa = 8192; + hparams.n_attn_temp_floor_scale = 8192; + hparams.f_attn_temp_scale = 0.1f; + hparams.f_attn_temp_offset = 1.0f; + + uint32_t swa_period = 4; // pattern: 3 chunked - 1 full + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } + + switch (hparams.n_expert) { + case 0: { + // MobileLLM (no MoE) + switch (hparams.n_embd) { + case 2048: type = LLM_TYPE_140M; break; + case 4096: type = LLM_TYPE_360M; break; + case 6144: type = LLM_TYPE_950M; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case 16: type = LLM_TYPE_17B_16E; break; + case 128: type = LLM_TYPE_17B_128E; break; + default: type = LLM_TYPE_UNKNOWN; + } + + hparams.use_kq_norm = type != LLM_TYPE_17B_128E; +} + +void llama_model_llama4::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + const bool is_moe_layer = hparams.n_moe_layer_step > 0 && (i + 1) % hparams.n_moe_layer_step == 0; + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + if (is_moe_layer) { + const int64_t n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert + const int64_t n_ff_shexp = n_ff_exp; + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } else { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_llama4::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_NONE) { + return std::make_unique<graph<false>>(*this, params); + } else { + return std::make_unique<graph<true>>(*this, params); + } +} + +template <bool iswa> +llama_model_llama4::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // temperature tuning + ggml_tensor * inp_attn_scale = nullptr; + inp_attn_scale = build_inp_attn_scale(); + + using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + ggml_tensor * inpSA = inpL; + + // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous + const bool use_rope = hparams.n_no_rope_layer_step > 0 && + (il + 1) % hparams.n_no_rope_layer_step != 0; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // rope freq factors for llama3; may return nullptr for llama2 and other models + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + if (use_rope) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } else if (inp_attn_scale) { + Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale); + } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + if (use_rope && hparams.use_kq_norm) { + // Llama4TextL2Norm + Qcur = ggml_rms_norm(ctx0, Qcur, hparams.f_norm_rms_eps); + Kcur = ggml_rms_norm(ctx0, Kcur, hparams.f_norm_rms_eps); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + } + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network (non-MoE) + if (model.layers[il].ffn_gate_inp == nullptr) { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + ggml_tensor * ffn_inp_normed = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(ffn_inp_normed, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + + // Shared experts + ggml_tensor * shexp_out = build_ffn(ffn_inp_normed, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(shexp_out, "ffn_moe_shexp", il); + + cur = ggml_add(ctx0, moe_out, shexp_out); + cb(cur, "ffn_moe_out_merged", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// Explicit template instantiations +template struct llama_model_llama4::graph<false>; +template struct llama_model_llama4::graph<true>; diff --git a/examples/talk-llama/models/maincoder.cpp b/examples/talk-llama/models/maincoder.cpp index da57308167e..ae56a26a1f6 100644 --- a/examples/talk-llama/models/maincoder.cpp +++ b/examples/talk-llama/models/maincoder.cpp @@ -1,10 +1,54 @@ #include "models.h" -llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_maincoder::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_maincoder::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_maincoder::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_maincoder::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,18 +74,8 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -66,7 +100,7 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -108,7 +142,7 @@ llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_gr res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/graph-context-mamba.cpp b/examples/talk-llama/models/mamba-base.cpp similarity index 95% rename from examples/talk-llama/models/graph-context-mamba.cpp rename to examples/talk-llama/models/mamba-base.cpp index b9a363b32b6..c37f29c487e 100644 --- a/examples/talk-llama/models/graph-context-mamba.cpp +++ b/examples/talk-llama/models/mamba-base.cpp @@ -1,8 +1,10 @@ #include "models.h" -llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} +#include "llama-memory-recurrent.h" -ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp, +llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {} + +ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, @@ -28,6 +30,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + GGML_ASSERT(d_inner % n_head == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); @@ -39,7 +42,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); // {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} - ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur); + ggml_tensor * xz = build_lora_mm(layer.ssm_in, cur, layer.ssm_in_s); // split the above in two // => {d_inner, n_seq_tokens, n_seqs} ggml_tensor * x = ggml_view_3d(ctx0, xz, d_inner, xz->ne[1], xz->ne[2], xz->nb[1], xz->nb[2], 0); @@ -134,7 +137,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(layer.ssm_out, y); + cur = build_lora_mm(layer.ssm_out, y, layer.ssm_out_s); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} @@ -143,7 +146,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in return cur; } -ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp, +ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, @@ -165,6 +168,9 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + GGML_ASSERT(d_inner % n_head == 0); + GGML_ASSERT(d_inner % d_state == 0); + GGML_ASSERT(d_inner % n_group == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); @@ -178,7 +184,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} - ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur, model.layers[il].ssm_in_s); // split the above in three ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim * zxBCdt->nb[0], @@ -272,7 +278,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * i y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} - cur = build_lora_mm(model.layers[il].ssm_out, y); + cur = build_lora_mm(model.layers[il].ssm_out, y, model.layers[il].ssm_out_s); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} diff --git a/examples/talk-llama/models/mamba.cpp b/examples/talk-llama/models/mamba.cpp index 46819613c2d..0d94e98281c 100644 --- a/examples/talk-llama/models/mamba.cpp +++ b/examples/talk-llama/models/mamba.cpp @@ -1,7 +1,90 @@ #include "models.h" +void llama_model_mamba::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_DT_B_C_RMS, hparams.ssm_dt_b_c_rms, false); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mamba::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + if (2 * n_embd != d_inner) { + throw std::runtime_error("only an expansion factor of 2 is supported for now"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}, 0); + + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); -llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_mamba::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_mamba::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -45,11 +128,10 @@ llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); } - diff --git a/examples/talk-llama/models/mamba2.cpp b/examples/talk-llama/models/mamba2.cpp new file mode 100644 index 00000000000..c5951cf0f7f --- /dev/null +++ b/examples/talk-llama/models/mamba2.cpp @@ -0,0 +1,87 @@ +#include "models.h" + +void llama_model_mamba2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mamba2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_mamba2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/mellum.cpp b/examples/talk-llama/models/mellum.cpp new file mode 100644 index 00000000000..28823018bc0 --- /dev/null +++ b/examples/talk-llama/models/mellum.cpp @@ -0,0 +1,225 @@ +#include "models.h" + +void llama_model_mellum::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + uint32_t swa_period = 4; + const auto res = ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + if (res) { + hparams.set_swa_pattern(swa_period); + } else { + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + } + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer()) { + case 28: type = LLM_TYPE_12B_A2_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mellum::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for Mellum"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for Mellum"); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_mellum::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique<graph<true>>(*this, params); + } + return std::make_unique<graph<false>>(*this, params); +} + +template <bool iswa> +llama_model_mellum::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_iswa, llm_graph_input_attn_kv>; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_iswa(); + } else { + inp_attn = build_attn_inp_kv(); + } + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + const bool is_swa = hparams.is_swa(il); + + if (is_swa) { + // For sliding window layers, use regular rope with no yarn rope scaling. + // This is achieved here by setting freq_scale and attn_factor to 1. + // We also set ext_factor to 0 to avoid a few unnecessary computations. + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, 1.0, + 0.0, 1.0, beta_fast, beta_slow + ); + } else { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +template struct llama_model_mellum::graph<false>; +template struct llama_model_mellum::graph<true>; diff --git a/examples/talk-llama/models/mimo2-iswa.cpp b/examples/talk-llama/models/mimo2-iswa.cpp deleted file mode 100644 index edc87cc9f0d..00000000000 --- a/examples/talk-llama/models/mimo2-iswa.cpp +++ /dev/null @@ -1,123 +0,0 @@ - -#include "models.h" - -llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_attn = build_attn_inp_kv_iswa(); - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - uint32_t n_head_l = hparams.n_head(il); - uint32_t n_head_kv_l = hparams.n_head_kv(il); - const float freq_base_l = model.get_rope_freq_base(cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - - cur = inpL; - - // self_attention - { - cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - ggml_tensor * sinks = model.layers[il].attn_sinks; - - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); - } - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - // feed-forward network - if (model.layers[il].ffn_gate_inp == nullptr) { - // dense branch - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } else { - // MoE branch - cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, - 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il); - cb(cur, "ffn_moe_out", il); - } - - cur = ggml_add(ctx0, cur, ffn_inp); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/mimo2.cpp b/examples/talk-llama/models/mimo2.cpp new file mode 100644 index 00000000000..88989160570 --- /dev/null +++ b/examples/talk-llama/models/mimo2.cpp @@ -0,0 +1,235 @@ +#include "models.h" + +void llama_model_mimo2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + float value_scale = 0.0f; + if (ml.get_key(LLM_KV_ATTENTION_VALUE_SCALE, value_scale, false) && value_scale != 1.0f) { + hparams.f_attn_value_scale = value_scale; + } + + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + switch (hparams.n_layer()) { + case 48: type = LLM_TYPE_310B_A15B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mimo2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer_all; ++i) { + auto & layer = layers[i]; + uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + uint32_t n_head = hparams.n_head(i); + + // NextN/MTP layers (the last n_nextn blocks) are preserved but disabled pending support + const bool is_nextn = i >= n_layer; + const int skip = is_nextn ? TENSOR_SKIP : 0; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, skip); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, skip); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, skip); + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED | skip); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, skip); + + // non-MoE branch + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED | skip); + + // MoE branch + int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED | skip); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED | skip); + + if (is_nextn) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), {2 * n_embd, n_embd}, skip); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), {n_embd}, skip); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), {n_embd}, skip); + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, skip); + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_mimo2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_mimo2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const float v_scale = hparams.f_attn_value_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + uint32_t n_head_l = hparams.n_head(il); + uint32_t n_head_kv_l = hparams.n_head_kv(il); + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + cur = inpL; + + // self_attention + { + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + + if (model.layers[il].wqkv) { + // Fused qkv_proj - Q/K share head_dim_k, V uses head_dim_v + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + + const size_t row_k = ggml_row_size(qkv->type, n_embd_head_k); + const size_t row_v = ggml_row_size(qkv->type, n_embd_head_v); + const size_t row_full = qkv->nb[1]; + const size_t k_off = row_k * n_head_l; + const size_t v_off = k_off + row_k * n_head_kv_l; + + Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_l, n_tokens, row_k, row_full, 0); + Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv_l, n_tokens, row_k, row_full, k_off); + Vcur = ggml_view_3d(ctx0, qkv, n_embd_head_v, n_head_kv_l, n_tokens, row_v, row_full, v_off); + } else { + // Split path + Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + } + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + ggml_tensor * sinks = model.layers[il].attn_sinks; + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il); + cb(cur, "attn_out", il); + + if (v_scale) { + cur = ggml_scale(ctx0, cur, v_scale); + cb(cur, "attn_out_scaled", il); + } + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // dense branch + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, + il); + cb(cur, "ffn_moe_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/minicpm.cpp b/examples/talk-llama/models/minicpm.cpp new file mode 100644 index 00000000000..fc3e5b171d5 --- /dev/null +++ b/examples/talk-llama/models/minicpm.cpp @@ -0,0 +1,89 @@ +#include "models.h" + +void llama_model_minicpm::load_arch_hparams(llama_model_loader & ml) { + // Backward-compatible defaults for older MiniCPM GGUFs + hparams.f_embedding_scale = 12.0f; + hparams.f_residual_scale = 1.4f / sqrtf(float(hparams.n_layer())); + hparams.f_logit_scale = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f; + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Optional KV reads, override defaults if present in newer GGUF exports + ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /*required=*/false); + ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /*required=*/false); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /*required=*/false); + + // MiniCPM uses rope by default, unlike Granite which uses it as a switch + hparams.rope_finetuned = true; + + switch (hparams.n_layer()) { + case 52: type = LLM_TYPE_1B; break; + case 40: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_minicpm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_minicpm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/minicpm3.cpp b/examples/talk-llama/models/minicpm3.cpp index f374a9fd030..e011b1ff0a8 100644 --- a/examples/talk-llama/models/minicpm3.cpp +++ b/examples/talk-llama/models/minicpm3.cpp @@ -1,14 +1,75 @@ #include "models.h" -llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_minicpm3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + + switch (hparams.n_layer()) { + case 62: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_minicpm3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0); + + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k}, 0); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head_qk_rope/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } +} + +std::unique_ptr<llm_graph_context> llama_model_minicpm3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_minicpm3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //TODO: if the model varies, these parameters need to be read from the model const int64_t n_embd_base = 256; const float scale_embd = 12.0f; const float scale_depth = 1.4f; - const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k)); + const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k())); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot(); + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); - const uint32_t n_embd_head_qk_rope = hparams.n_rot; - const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; ggml_tensor * cur; @@ -50,21 +111,21 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap LLM_NORM_RMS, il); cb(q, "q", il); - // {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens} + // {q_lora_rank, n_head * hparams.n_embd_head_k()} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k(), n_tokens} q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q); cb(q, "q", il); // split into {n_head * n_embd_head_qk_nope, n_tokens} ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, hparams.n_embd_head_k()), + ggml_row_size(q->type, hparams.n_embd_head_k() * n_head), 0); cb(q_nope, "q_nope", il); // and {n_head * n_embd_head_qk_rope, n_tokens} ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, hparams.n_embd_head_k()), + ggml_row_size(q->type, hparams.n_embd_head_k() * n_head), ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); @@ -96,15 +157,15 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap // split into {n_head * n_embd_head_qk_nope, n_tokens} ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v()), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v())), 0); cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v(), n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())*n_head), ggml_row_size(kv->type, (n_embd_head_qk_nope))); cb(v_states, "v_states", il); @@ -133,7 +194,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap cb(k_states, "k_states", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -190,7 +251,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap cb(cur, "lmhead_scaling", -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/minimax-m2.cpp b/examples/talk-llama/models/minimax-m2.cpp index f7001badf75..b25435e4d97 100644 --- a/examples/talk-llama/models/minimax-m2.cpp +++ b/examples/talk-llama/models/minimax-m2.cpp @@ -1,11 +1,54 @@ - #include "models.h" -llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_minimax_m2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer()) { + case 62: type = LLM_TYPE_230B_A10B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_minimax_m2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k * n_head}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_k_gqa}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_minimax_m2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_minimax_m2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - // GGML_ASSERT(n_embd_head == hparams.n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64 + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + // GGML_ASSERT(n_embd_head == n_rot); this is wrong in case of minimax, head_dim = 128, n_rot = 64 ggml_tensor * cur; ggml_tensor * inpL; @@ -65,7 +108,7 @@ llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -91,7 +134,7 @@ llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_ model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(cur, "ffn_moe_out", il); @@ -115,7 +158,7 @@ llm_build_minimax_m2::llm_build_minimax_m2(const llama_model & model, const llm_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mistral3.cpp b/examples/talk-llama/models/mistral3.cpp index 0b672235911..9a8e3f9a50b 100644 --- a/examples/talk-llama/models/mistral3.cpp +++ b/examples/talk-llama/models/mistral3.cpp @@ -1,10 +1,100 @@ #include "models.h" -llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_mistral3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); + + hparams.f_attn_temp_offset = 0.0f; + + // TODO: maybe add n_attn_temp_floor_scale as a separate KV? + if (hparams.f_attn_temp_scale != 0.0f) { + hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn; + if (hparams.n_attn_temp_floor_scale == 0) { + throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling"); + } + } + + switch (hparams.n_layer()) { + case 26: type = LLM_TYPE_3B; break; + case 34: type = LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mistral3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_mistral3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_mistral3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -41,27 +131,8 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, @@ -86,7 +157,7 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap } cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -106,9 +177,9 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); @@ -127,9 +198,13 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(cur, "ffn_moe_out", il); } cur = ggml_add(ctx0, cur, ffn_inp); @@ -151,7 +226,7 @@ llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/mistral4.cpp b/examples/talk-llama/models/mistral4.cpp new file mode 100644 index 00000000000..3d9190650e3 --- /dev/null +++ b/examples/talk-llama/models/mistral4.cpp @@ -0,0 +1,6 @@ +#include "models.h" + +std::unique_ptr<llm_graph_context> llama_model_mistral4::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/models.h b/examples/talk-llama/models/models.h index 6c40f48042b..ee3aff07b9a 100644 --- a/examples/talk-llama/models/models.h +++ b/examples/talk-llama/models/models.h @@ -1,23 +1,95 @@ #pragma once -#include "../llama-model.h" -#include "../llama-graph.h" +#include "llama-model.h" +#include "llama-graph.h" +#include "llama-model-loader.h" -// TODO: remove in follow-up PR - move to .cpp files -#include "../llama-memory-recurrent.h" +// note: almost all graphs require at least sqrtf, so include cmath globally #include <cmath> -struct llm_graph_context_mamba : public llm_graph_context { - llm_graph_context_mamba(const llm_graph_params & params); +// +// base classes +// - virtual ~llm_graph_context_mamba() = default; +struct llm_build_mamba_base : public llm_graph_context { + llm_build_mamba_base(const llm_graph_params & params); + + virtual ~llm_build_mamba_base() = default; ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const; }; -// Base class for RWKV-related models +struct llm_build_delta_net_base : public llm_graph_context { + llm_build_delta_net_base(const llm_graph_params & params); + + virtual ~llm_build_delta_net_base() = default; + + // returns pair of output and new state + std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // returns pair of output and new state + std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // use the ggml_gated_delta_net fused operator (K=1; state has shape [S_v, S_v, H_v, n_seqs]) + std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_fused( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // choose one of two implementations above based on the number of tokens + std::pair<ggml_tensor *, ggml_tensor *> build_delta_net( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // read conv state from cache, concat with qkv_mixed, write back (single slot or per-token) + // qkv_mixed: (qkv_dim, n_seq_tokens, n_seqs); returns conv_input: (kernel_size + n_seq_tokens - 1, channels, n_seqs) + ggml_tensor * build_conv_state( + llm_graph_input_rs * inp, + ggml_tensor * conv_states_all, + ggml_tensor * qkv_mixed, + int64_t conv_kernel_size, + int64_t conv_channels, + int il); + + // run delta-net attention and write the new recurrent state(s) back to ssm_states_all + // s: (head_v_dim, head_v_dim, num_v_heads, n_seqs); returns output: (head_v_dim, num_v_heads, n_seq_tokens, n_seqs) + ggml_tensor * build_recurrent_attn( + llm_graph_input_rs * inp, + ggml_tensor * ssm_states_all, + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); +}; + struct llm_build_rwkv6_base : public llm_graph_context { const llama_model & model; @@ -58,512 +130,1832 @@ struct llm_build_rwkv7_base : public llm_graph_context { int il) const; }; -struct llm_build_afmoe : public llm_graph_context { - llm_build_afmoe(const llama_model & model, const llm_graph_params & params); -}; +// +// models +// -struct llm_build_apertus : public llm_graph_context { - llm_build_apertus(const llama_model & model, const llm_graph_params & params); -}; +struct llama_model_llama : public llama_model_base { + llama_model_llama(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; -struct llm_build_arcee : public llm_graph_context { - llm_build_arcee(const llama_model & model, const llm_graph_params & params); -}; + template <bool embed> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; -struct llm_build_arctic : public llm_graph_context { - llm_build_arctic(const llama_model & model, const llm_graph_params & params); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_arwkv7 : public llm_build_rwkv7_base { - llm_build_arwkv7(const llama_model & model, const llm_graph_params & params); -}; -struct llm_build_baichuan : public llm_graph_context { - llm_build_baichuan(const llama_model & model, const llm_graph_params & params); -}; +struct llama_model_llama4 : public llama_model_base { + llama_model_llama4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; -struct llm_build_bailingmoe2 : public llm_graph_context { - llm_build_bailingmoe2(const llama_model & model, const llm_graph_params & params); -}; + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; -struct llm_build_bailingmoe : public llm_graph_context { - llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_bert : public llm_graph_context { - llm_build_bert(const llama_model & model, const llm_graph_params & params); -}; -struct llm_build_bitnet : public llm_graph_context { - llm_build_bitnet(const llama_model & model, const llm_graph_params & params); -}; +struct llama_model_llama_embed : public llama_model_llama { + llama_model_llama_embed(const struct llama_model_params & params) : llama_model_llama(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_llama -struct llm_build_bloom : public llm_graph_context { - llm_build_bloom(const llama_model & model, const llm_graph_params & params); -}; + template <bool embed> + using graph = llama_model_llama::graph<embed>; -struct llm_build_chameleon : public llm_graph_context { - llm_build_chameleon(const llama_model & model, const llm_graph_params & params); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_chatglm : public llm_graph_context { - llm_build_chatglm(const llama_model & model, const llm_graph_params & params); -}; -struct llm_build_codeshell : public llm_graph_context { - llm_build_codeshell(const llama_model & model, const llm_graph_params & params); -}; +struct llama_model_maincoder : public llama_model_base { + llama_model_maincoder(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; -struct llm_build_cogvlm : public llm_graph_context { - llm_build_cogvlm(const llama_model & model, const llm_graph_params & params); -}; + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; -struct llm_build_cohere2_iswa : public llm_graph_context { - llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_command_r : public llm_graph_context { - llm_build_command_r(const llama_model & model, const llm_graph_params & params); -}; -struct llm_build_dbrx : public llm_graph_context { - llm_build_dbrx(const llama_model & model, const llm_graph_params & params); -}; +struct llama_model_talkie : public llama_model_base { + llama_model_talkie(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; -struct llm_build_deci : public llm_graph_context { - llm_build_deci(const llama_model & model, const llm_graph_params & params); -}; + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; -struct llm_build_deepseek2 : public llm_graph_context { - llm_build_deepseek2(const llama_model & model, const llm_graph_params & params); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_deepseek : public llm_graph_context { - llm_build_deepseek(const llama_model & model, const llm_graph_params & params); -}; -struct llm_build_dots1 : public llm_graph_context { - llm_build_dots1(const llama_model & model, const llm_graph_params & params); -}; +struct llama_model_deci : public llama_model_base { + llama_model_deci(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; -struct llm_build_dream : public llm_graph_context { - llm_build_dream(const llama_model & model, const llm_graph_params & params); -}; + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; -struct llm_build_ernie4_5 : public llm_graph_context { - llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_ernie4_5_moe : public llm_graph_context { - llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params); -}; -template <bool iswa> -struct llm_build_exaone4 : public llm_graph_context { - llm_build_exaone4(const llama_model & model, const llm_graph_params & params); -}; +struct llama_model_baichuan : public llama_model_base { + llama_model_baichuan(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; -struct llm_build_exaone : public llm_graph_context { - llm_build_exaone(const llama_model & model, const llm_graph_params & params); -}; + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; -struct llm_build_falcon : public llm_graph_context { - llm_build_falcon(const llama_model & model, const llm_graph_params & params); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_falcon_h1 : public llm_graph_context_mamba { - llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params); -}; -struct llm_build_gemma2_iswa : public llm_graph_context { - llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params); -}; +struct llama_model_falcon : public llama_model_base { + llama_model_falcon(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; -template <bool iswa> -struct llm_build_gemma3 : public llm_graph_context { - llm_build_gemma3(const llama_model & model, const llm_graph_params & params); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma3n_iswa : public llm_graph_context { - const llama_model & model; - const int64_t n_embd_head; - const int64_t n_embd_altup; - const int64_t n_altup; - const int i_altup_act; - const int n_layer_sparsity = 10; // number of layers using activation sparsity - const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) +struct llama_model_grok : public llama_model_base { + llama_model_grok(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params); - ggml_tensor * calc_magnitude(ggml_tensor * x); - ggml_tensor * view_2d_slice(ggml_tensor * x, int idx); - ggml_tensor * get_per_layer_inputs(); - ggml_tensor * project_per_layer_inputs(ggml_tensor * inputs_embeds, ggml_tensor * inp_per_layer); - ggml_tensor * gaussian_topk(ggml_tensor * x); - ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il); - ggml_tensor * altup_predict(ggml_tensor * cur, int il); - ggml_tensor * laurel(ggml_tensor * cur, int il); - ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il); -}; + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; -struct llm_build_gemma_embedding : public llm_graph_context { - llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gemma : public llm_graph_context { - llm_build_gemma(const llama_model & model, const llm_graph_params & params); -}; -struct llm_build_glm4 : public llm_graph_context { - llm_build_glm4(const llama_model & model, const llm_graph_params & params); -}; +struct llama_model_starcoder : public llama_model_base { + llama_model_starcoder(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; -struct llm_build_glm4_moe : public llm_graph_context { - llm_build_glm4_moe(const llama_model & model, const llm_graph_params & params); -}; + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; -struct llm_build_gpt2 : public llm_graph_context { - llm_build_gpt2(const llama_model & model, const llm_graph_params & params); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_gptneox : public llm_graph_context { - llm_build_gptneox(const llama_model & model, const llm_graph_params & params); -}; -struct llm_build_granite : public llm_graph_context { - llm_build_granite(const llama_model & model, const llm_graph_params & params); +struct llama_model_refact : public llama_model_base { + llama_model_refact(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; -private: - ggml_tensor * build_attention_layer( - ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv * inp_attn, - const llama_model & model, - const int64_t n_embd_head, - const int il); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_granite_hybrid : public llm_graph_context_mamba { - llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); - ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, - const llama_model & model,const int64_t n_embd_head, const int il); -}; -struct llm_build_grok : public llm_graph_context { - llm_build_grok(const llama_model & model, const llm_graph_params & params); -}; +struct llama_model_bert : public llama_model_base { + llama_model_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; -struct llm_build_grovemoe : public llm_graph_context { - llm_build_grovemoe(const llama_model & model, const llm_graph_params & params); -}; + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; -struct llm_build_hunyuan_dense : public llm_graph_context { - llm_build_hunyuan_dense(const llama_model & model, const llm_graph_params & params); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_hunyuan_moe : public llm_graph_context { - llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params); -}; -struct llm_build_internlm2 : public llm_graph_context { - llm_build_internlm2(const llama_model & model, const llm_graph_params & params); +struct llama_model_jina_bert_v2 : public llama_model_base { + llama_model_jina_bert_v2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_jais : public llm_graph_context { - llm_build_jais(const llama_model & model, const llm_graph_params & params); + +struct llama_model_jina_bert_v3 : public llama_model_base { + llama_model_jina_bert_v3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_jamba : public llm_graph_context_mamba { - llm_build_jamba(const llama_model & model, const llm_graph_params & params); + +struct llama_model_nomic_bert : public llama_model_base { + llama_model_nomic_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_bert::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_lfm2 : public llm_graph_context { - const llama_model & model; - llm_build_lfm2(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, int il) const; - ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, int il) const; - ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, int il) const; - ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il); +struct llama_model_nomic_bert_moe : public llama_model_base { + llama_model_nomic_bert_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + using graph = llama_model_bert::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_llada : public llm_graph_context { - llm_build_llada(const llama_model & model, const llm_graph_params & params); + +struct llama_model_modern_bert : public llama_model_base { + llama_model_modern_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_llada_moe : public llm_graph_context { - llm_build_llada_moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_neo_bert : public llama_model_base { + llama_model_neo_bert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -template <bool embed> -struct llm_build_llama : public llm_graph_context { - llm_build_llama(const llama_model & model, const llm_graph_params & params); + +struct llama_model_eurobert : public llama_model_base { + llama_model_eurobert(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_llama_iswa : public llm_graph_context { - llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_bloom : public llama_model_base { + llama_model_bloom(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_maincoder : public llm_graph_context { - llm_build_maincoder(const llama_model & model, const llm_graph_params & params); + +struct llama_model_mpt : public llama_model_base { + llama_model_mpt(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mamba : public llm_graph_context_mamba { - llm_build_mamba(const llama_model & model, const llm_graph_params & params); + +struct llama_model_stablelm : public llama_model_base { + llama_model_stablelm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mimo2_iswa : public llm_graph_context { - llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params); +struct llama_model_mellum : public llama_model_base { + llama_model_mellum(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_minicpm3 : public llm_graph_context { - llm_build_minicpm3(const llama_model & model, const llm_graph_params & params); +struct llama_model_qwen : public llama_model_base { + llama_model_qwen(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_minimax_m2 : public llm_graph_context { - llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen2 : public llama_model_base { + llama_model_qwen2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mistral3 : public llm_graph_context { - llm_build_mistral3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_dream : public llama_model_base { + llama_model_dream(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_modern_bert : public llm_graph_context { - llm_build_modern_bert(const llama_model & model, const llm_graph_params & params); + +struct llama_model_llada : public llama_model_base { + llama_model_llada(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_mpt : public llm_graph_context { - llm_build_mpt(const llama_model & model, const llm_graph_params & params); + +struct llama_model_llada_moe : public llama_model_base { + llama_model_llada_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_nemotron : public llm_graph_context { - llm_build_nemotron(const llama_model & model, const llm_graph_params & params); + +struct llama_model_rnd1 : public llama_model_base { + llama_model_rnd1(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_nemotron_h : public llm_graph_context_mamba { - llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il); - ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, - const llama_model & model, const int64_t n_embd_head, const int il); + +struct llama_model_qwen2vl : public llama_model_base { + llama_model_qwen2vl(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_neo_bert : public llm_graph_context { - llm_build_neo_bert(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen2moe : public llama_model_base { + llama_model_qwen2moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -template <bool iswa> -struct llm_build_olmo2 : public llm_graph_context { - llm_build_olmo2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3 : public llama_model_base { + llama_model_qwen3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_olmoe : public llm_graph_context { - llm_build_olmoe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3moe : public llama_model_base { + llama_model_qwen3moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_olmo : public llm_graph_context { - llm_build_olmo(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3vl : public llama_model_base { + llama_model_qwen3vl(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_openai_moe_iswa : public llm_graph_context { - llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params); + +struct llama_model_qwen3vlmoe : public llama_model_base { + llama_model_qwen3vlmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_openelm : public llm_graph_context { - llm_build_openelm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_phi2 : public llama_model_base { + llama_model_phi2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_orion : public llm_graph_context { - llm_build_orion(const llama_model & model, const llm_graph_params & params); + +struct llama_model_phi3 : public llama_model_base { + llama_model_phi3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_pangu_embedded : public llm_graph_context { - llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params); + +struct llama_model_phimoe : public llama_model_base { + llama_model_phimoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + using graph = llama_model_phi3::graph<iswa>; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_phi2 : public llm_graph_context { - llm_build_phi2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_plamo : public llama_model_base { + llama_model_plamo(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -template<bool iswa> -struct llm_build_phi3 : public llm_graph_context { - llm_build_phi3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_plamo2 : public llama_model_base { + llama_model_plamo2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); + ggml_tensor * build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, + const llama_model & model, int il); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_plamo2 : public llm_graph_context_mamba { - llm_build_plamo2(const llama_model & model, const llm_graph_params & params); - private: - ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); - ggml_tensor * build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, - const llama_model & model, int il); + +struct llama_model_plamo3 : public llama_model_base { + llama_model_plamo3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_plamo : public llm_graph_context { - llm_build_plamo(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gpt2 : public llama_model_base { + llama_model_gpt2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -template <bool iswa> -struct llm_build_plamo3 : public llm_graph_context { - llm_build_plamo3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_codeshell : public llama_model_base { + llama_model_codeshell(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_plm : public llm_graph_context { - llm_build_plm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_orion : public llama_model_base { + llama_model_orion(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen2 : public llm_graph_context { - llm_build_qwen2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_internlm2 : public llama_model_base { + llama_model_internlm2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen2moe : public llm_graph_context { - llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_minicpm3 : public llama_model_base { + llama_model_minicpm3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen2vl : public llm_graph_context { - llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma : public llama_model_base { + llama_model_gemma(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3 : public llm_graph_context { - llm_build_qwen3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma2 : public llama_model_base { + llama_model_gemma2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3moe : public llm_graph_context { - llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma3 : public llama_model_base { + llama_model_gemma3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3vl : public llm_graph_context { - llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma3n : public llama_model_base { + llama_model_gemma3n(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + const llama_model & model; + + const int64_t n_embd_head; + const int64_t n_embd_altup; + const int64_t n_altup; + const int i_altup_act; + const int n_layer_sparsity = 10; // number of layers using activation sparsity + const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) + + graph(const llama_model & model, const llm_graph_params & params); + ggml_tensor * calc_magnitude(ggml_tensor * x); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + + ggml_tensor * gaussian_topk(ggml_tensor * x); + ggml_tensor * altup_compute_router_modalities(ggml_tensor * x, int il); + ggml_tensor * altup_predict(ggml_tensor * cur, int il); + ggml_tensor * laurel(ggml_tensor * cur, int il); + ggml_tensor * altup_correct(ggml_tensor * predictions, ggml_tensor * activated, int il); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3vlmoe : public llm_graph_context { - llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gemma4 : public llama_model_base { + llama_model_gemma4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + const llama_model & model; + + const int64_t n_embd_per_layer; + + graph(const llama_model & model, const llm_graph_params & params); + + // TODO: refactor in common "per-layer" functionality [TAG_PER_LAYER] + ggml_tensor * build_inp_per_layer(); + ggml_tensor * project_per_layer_inputs(ggml_tensor * inp_batch, ggml_tensor * inp_per_layer); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen3next : public llm_graph_context_mamba { - llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); -private: - ggml_tensor * build_layer_attn( - llm_graph_input_attn_kv * inp_attn, - ggml_tensor * cur, - ggml_tensor * inp_pos, - int il); - ggml_tensor * build_layer_attn_linear( - llm_graph_input_rs * inp, - ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); - ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - int il); +struct llama_model_gemma4_assistant : public llama_model_base { + llama_model_gemma4_assistant(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - // returns pair of output and new state - std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; - // returns pair of output and new state - std::pair<ggml_tensor *, ggml_tensor *> build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; - ggml_tensor * build_norm_gated( - ggml_tensor * input, - ggml_tensor * weights, - ggml_tensor * gate, - int layer); - // returns pair of qkv, z - std::pair<ggml_tensor *, ggml_tensor *> build_qkvz( - ggml_tensor * input, - int il); +struct llama_model_gemma_embedding : public llama_model_base { + llama_model_gemma_embedding(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; - const llama_model & model; + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_qwen : public llm_graph_context { - llm_build_qwen(const llama_model & model, const llm_graph_params & params); + +struct llama_model_starcoder2 : public llama_model_base { + llama_model_starcoder2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_refact : public llm_graph_context { - llm_build_refact(const llama_model & model, const llm_graph_params & params); + +struct llama_model_mamba : public llama_model_base { + llama_model_mamba(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rnd1 : public llm_graph_context { - llm_build_rnd1(const llama_model & model, const llm_graph_params & params); + +struct llama_model_mamba2 : public llama_model_base { + llama_model_mamba2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_mamba::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rwkv6 : public llm_build_rwkv6_base { - llm_build_rwkv6(const llama_model & model, const llm_graph_params & params); + +struct llama_model_jamba : public llama_model_base { + llama_model_jamba(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { - llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_xverse : public llama_model_base { + llama_model_xverse(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_rwkv7 : public llm_build_rwkv7_base { - llm_build_rwkv7(const llama_model & model, const llm_graph_params & params); + +struct llama_model_command_r : public llama_model_base { + llama_model_command_r(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_seed_oss : public llm_graph_context { - llm_build_seed_oss(const llama_model & model, const llm_graph_params & params); + +struct llama_model_cohere2 : public llama_model_base { + llama_model_cohere2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -template <bool iswa> -struct llm_build_smallthinker : public llm_graph_context { - llm_build_smallthinker(const llama_model & model, const llm_graph_params & params); + +struct llama_model_dbrx : public llama_model_base { + llama_model_dbrx(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_smollm3 : public llm_graph_context { - llm_build_smollm3(const llama_model & model, const llm_graph_params & params); + +struct llama_model_olmo : public llama_model_base { + llama_model_olmo(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_stablelm : public llm_graph_context { - llm_build_stablelm(const llama_model & model, const llm_graph_params & params); + +struct llama_model_olmo2 : public llama_model_base { + llama_model_olmo2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_starcoder2 : public llm_graph_context { - llm_build_starcoder2(const llama_model & model, const llm_graph_params & params); + +struct llama_model_olmoe : public llama_model_base { + llama_model_olmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_starcoder : public llm_graph_context { - llm_build_starcoder(const llama_model & model, const llm_graph_params & params); + +struct llama_model_openelm : public llama_model_base { + llama_model_openelm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_t5_dec : public llm_graph_context { - llm_build_t5_dec(const llama_model & model, const llm_graph_params & params); + +struct llama_model_gptneox : public llama_model_base { + llama_model_gptneox(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_t5_enc : public llm_graph_context { - llm_build_t5_enc(const llama_model & model, const llm_graph_params & params); + +struct llama_model_arctic : public llama_model_base { + llama_model_arctic(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_wavtokenizer_dec : public llm_graph_context { - llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params); + +struct llama_model_deepseek : public llama_model_base { + llama_model_deepseek(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; -struct llm_build_xverse : public llm_graph_context { - llm_build_xverse(const llama_model & model, const llm_graph_params & params); + +struct llama_model_deepseek2 : public llama_model_base { + llama_model_deepseek2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_deepseek32 : public llama_model_base { + llama_model_deepseek32(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_deepseek2ocr : public llama_model_base { + llama_model_deepseek2ocr(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_deepseek2::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_glm_dsa : public llama_model_base { + llama_model_glm_dsa(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_deepseek2::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + +struct llama_model_eagle3 : public llama_model_base { + llama_model_eagle3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool is_enc> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + + ggml_tensor * build_inp_embd_enc() const; + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_mistral4 : public llama_model_deepseek2 { + llama_model_mistral4(const struct llama_model_params & params) : llama_model_deepseek2(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_deepseek2 + + using graph = llama_model_deepseek2::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_chatglm : public llama_model_base { + llama_model_chatglm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_glm4 : public llama_model_base { + llama_model_glm4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_glm4_moe : public llama_model_base { + llama_model_glm4_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_bitnet : public llama_model_base { + llama_model_bitnet(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_t5 : public llama_model_base { + llama_model_t5(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool is_enc> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_t5encoder : public llama_model_base { + llama_model_t5encoder(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_t5::graph<true>; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_jais : public llama_model_base { + llama_model_jais(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_jais2 : public llama_model_base { + llama_model_jais2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_nemotron : public llama_model_base { + llama_model_nemotron(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_nemotron_h : public llama_model_base { + llama_model_nemotron_h(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il); + ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, + const llama_model & model, int64_t n_embd_head, int il); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_nemotron_h_moe : public llama_model_nemotron_h { + llama_model_nemotron_h_moe(const struct llama_model_params & params) : llama_model_nemotron_h(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_nemotron_h + + using graph = llama_model_nemotron_h::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_exaone : public llama_model_base { + llama_model_exaone(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_exaone4 : public llama_model_base { + llama_model_exaone4(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_exaone_moe : public llama_model_base { + llama_model_exaone_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_rwkv6 : public llama_model_base { + llama_model_rwkv6(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv6_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_rwkv6qwen2 : public llama_model_base { + llama_model_rwkv6qwen2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv6_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_rwkv7 : public llama_model_base { + llama_model_rwkv7(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv7_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_arwkv7 : public llama_model_base { + llama_model_arwkv7(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_rwkv7_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_granite : public llama_model_base { + llama_model_granite(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + + private: + ggml_tensor * build_attention_layer( + ggml_tensor * cur, + ggml_tensor * inp_pos, + llm_graph_input_attn_kv * inp_attn, + const llama_model & model, + const int64_t n_embd_head, + const int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_granite_moe : public llama_model_base { + llama_model_granite_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_granite::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_minicpm : public llama_model_base { + llama_model_minicpm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + using graph = llama_model_granite::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_granite_hybrid : public llama_model_base { + llama_model_granite_hybrid(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); + ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, + const llama_model & model,const int64_t n_embd_head, const int il); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_chameleon : public llama_model_base { + llama_model_chameleon(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_wavtokenizer_dec : public llama_model_base { + llama_model_wavtokenizer_dec(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_plm : public llama_model_base { + llama_model_plm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_bailingmoe : public llama_model_base { + llama_model_bailingmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_bailingmoe2 : public llama_model_base { + llama_model_bailingmoe2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_seed_oss : public llama_model_base { + llama_model_seed_oss(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_dots1 : public llama_model_base { + llama_model_dots1(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_arcee : public llama_model_base { + llama_model_arcee(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_afmoe : public llama_model_base { + llama_model_afmoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_ernie4_5 : public llama_model_base { + llama_model_ernie4_5(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_ernie4_5_moe : public llama_model_ernie4_5 { + llama_model_ernie4_5_moe(const struct llama_model_params & params) : llama_model_ernie4_5(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_ernie4_5 + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_paddleocr : public llama_model_ernie4_5 { + llama_model_paddleocr(const struct llama_model_params & params) : llama_model_ernie4_5(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_ernie4_5 + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_hunyuan_moe : public llama_model_base { + llama_model_hunyuan_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_hunyuan_vl : public llama_model_base { + llama_model_hunyuan_vl(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_hunyuan_dense : public llama_model_hunyuan_vl { + llama_model_hunyuan_dense(const struct llama_model_params & params) : llama_model_hunyuan_vl(params) {} + // reuse load_arch_hparams and load_arch_tensors from llama_model_hunyuan_vl + + using graph = llama_model_hunyuan_vl::graph; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_smollm3 : public llama_model_base { + llama_model_smollm3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_openai_moe : public llama_model_base { + llama_model_openai_moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_falcon_h1 : public llama_model_base { + llama_model_falcon_h1(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_mamba_base { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_lfm2 : public llama_model_base { + llama_model_lfm2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_lfm2moe : public llama_model_base { + llama_model_lfm2moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + using graph = llama_model_lfm2::graph<iswa>; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_smallthinker : public llama_model_base { + llama_model_smallthinker(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + template <bool iswa> + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_grovemoe : public llama_model_base { + llama_model_grovemoe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_apertus : public llama_model_base { + llama_model_apertus(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_minimax_m2 : public llama_model_base { + llama_model_minimax_m2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_cogvlm : public llama_model_base { + llama_model_cogvlm(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_pangu_embed : public llama_model_base { + llama_model_pangu_embed(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_qwen3next : public llama_model_base { + llama_model_qwen3next(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair<ggml_tensor *, ggml_tensor *> build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_qwen35 : public llama_model_base { + llama_model_qwen35(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair<ggml_tensor *, ggml_tensor *> build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; + }; + + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_qwen35moe : public llama_model_base { + llama_model_qwen35moe(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair<ggml_tensor *, ggml_tensor *> build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; + }; + + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_mistral3 : public llama_model_base { + llama_model_mistral3(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_mimo2 : public llama_model_base { + llama_model_mimo2(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_kimi_linear : public llama_model_base { + llama_model_kimi_linear(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_build_delta_net_base { + graph(const llama_model & model, const llm_graph_params & params); + + std::pair<ggml_tensor *, ggml_tensor *> build_kda_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * gk, + ggml_tensor * beta, + ggml_tensor * state, + int il); + + std::pair<ggml_tensor *, ggml_tensor *> build_kda_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * gk, + ggml_tensor * beta, + ggml_tensor * state, + ggml_tensor * causal_mask, + ggml_tensor * identity, + ggml_tensor * diag_mask, + int il); + + const llama_model & model; + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; +}; + + +struct llama_model_step35 : public llama_model_base { + llama_model_step35(const struct llama_model_params & params) : llama_model_base(params) {} + void load_arch_hparams(llama_model_loader & ml) override; + void load_arch_tensors(llama_model_loader & ml) override; + + struct graph : public llm_graph_context { + graph(const llama_model & model, const llm_graph_params & params); + }; + + struct graph_mtp : public llm_graph_context { + graph_mtp(const llama_model & model, const llm_graph_params & params); + }; + + std::unique_ptr<llm_graph_context> build_arch_graph(const llm_graph_params & params) const override; }; diff --git a/examples/talk-llama/models/modern-bert.cpp b/examples/talk-llama/models/modern-bert.cpp index bb12ed819f7..f3e9407e012 100644 --- a/examples/talk-llama/models/modern-bert.cpp +++ b/examples/talk-llama/models/modern-bert.cpp @@ -1,10 +1,80 @@ #include "models.h" -llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); +void llama_model_modern_bert::load_arch_hparams(llama_model_loader & ml) { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + uint32_t swa_period = 3; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period, true); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + // Some ModernBert derivatives (e.g. IBM Granite Embedding 97m R2) use + // SiLU/SwiGLU in the FFN instead of the default GELU/GeGLU. + hparams.llm_ffn_op = LLM_FFN_GEGLU; + std::string hidden_act; + if (ml.get_key(LLM_KV_HIDDEN_ACT, hidden_act, false)) { + hparams.llm_ffn_op = llm_ffn_op_type_from_string(hidden_act, LLM_FFN_GEGLU); + } + + switch (hparams.n_layer()) { + case 12: + type = LLM_TYPE_47M; break; // granite-embedding-small + case 22: + type = LLM_TYPE_149M; break; // modern-bert-base + case 28: + type = LLM_TYPE_395M; break; // modern-bert-large + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_modern_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for(int i = 0; i < n_layer; ++i) { + auto& layer = layers[i]; + + if ( i != 0 ) { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + } else{ + // layer 0 uses identity + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + } + - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, 2 * n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + } + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED); + +} + +std::unique_ptr<llm_graph_context> llama_model_modern_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_modern_bert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -15,8 +85,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll cb(inpL, "inp_embd", -1); // embed layer norm - inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1); - cb(inpL, "inp_norm", -1); + inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, 0); + cb(inpL, "inp_norm", 0); ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -37,14 +107,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll } // self attention - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - const size_t type_size = ggml_type_size(cur->type); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*type_size, cur->nb[1], 0*type_size*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // RoPE Qcur = ggml_rope_ext( @@ -64,7 +128,7 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); @@ -88,7 +152,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll NULL, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, - LLM_FFN_GEGLU, LLM_FFN_SEQ, il); + hparams.llm_ffn_op, + LLM_FFN_SEQ, il); // attentions bypass the intermediate layer cur = ggml_add(ctx0, cur, ffn_inp); @@ -104,13 +169,6 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll LLM_NORM, -1); cb(cur, "final_norm_out", -1); - if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { - // extracting cls token - cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0); - cb(cur, "cls_pooled_embd", -1); - } - - cb(cur, "res_embd", -1); res->t_embd = cur; ggml_build_forward_expand(gf, cur); } diff --git a/examples/talk-llama/models/mpt.cpp b/examples/talk-llama/models/mpt.cpp index 2328e027a74..d094fd9f80b 100644 --- a/examples/talk-llama/models/mpt.cpp +++ b/examples/talk-llama/models/mpt.cpp @@ -1,12 +1,73 @@ #include "models.h" +void llama_model_mpt::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, false); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_7B; break; + case 48: type = LLM_TYPE_30B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_mpt::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, TENSOR_NOT_REQUIRED); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); // needs to be on GPU + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); -llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + // FIXME test-llama-archs crashes if q_norm is created + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + // AWQ ScaleActivation layer + layer.ffn_act = create_tensor(tn(LLM_TENSOR_FFN_ACT, "scales", i), {n_ff}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_mpt::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_mpt::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * pos; @@ -38,25 +99,8 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & { cur = attn_norm; - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - if (model.layers[il].bqkv) { - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - } - - if (hparams.f_clamp_kqv > 0.0f) { - cur = ggml_clamp(ctx0, cur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(cur, "wqkv_clamped", il); - } - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 0 * sizeof(float) * (n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), - cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // Q/K Layernorm if (model.layers[il].attn_q_norm) { @@ -76,7 +120,7 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } @@ -117,7 +161,7 @@ llm_build_mpt::llm_build_mpt(const llama_model & model, const llm_graph_params & cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/nemotron-h-moe.cpp b/examples/talk-llama/models/nemotron-h-moe.cpp new file mode 100644 index 00000000000..a59cc6c9fbd --- /dev/null +++ b/examples/talk-llama/models/nemotron-h-moe.cpp @@ -0,0 +1,6 @@ +#include "models.h" + +std::unique_ptr<llm_graph_context> llama_model_nemotron_h_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/nemotron-h.cpp b/examples/talk-llama/models/nemotron-h.cpp index eb135e63f18..a456269347b 100644 --- a/examples/talk-llama/models/nemotron-h.cpp +++ b/examples/talk-llama/models/nemotron-h.cpp @@ -1,11 +1,130 @@ #include "models.h" +void llama_model_nemotron_h::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + // A layer is recurrent IFF the n_head_kv value is set to 0 and + // the n_ff value is set to 0 + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0); + } + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_MOE_LATENT_SIZE, hparams.moe_latent_size, false); + + switch (hparams.n_layer()) { + case 52: type = LLM_TYPE_31B_A3_5B; break; // Nemotron-H_MOE 31B + case 56: type = LLM_TYPE_9B; break; + case 88: type = LLM_TYPE_120B_A12B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_nemotron_h::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // mamba2 Mixer SSM params + // NOTE: int64_t for tensor dimensions + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_ssm_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head; + const int64_t moe_n_embd = hparams.moe_latent_size > 0 ? hparams.moe_latent_size : n_embd; + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // all blocks use the attn norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.is_recr(i)) { + // ssm layers + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } else if (hparams.n_ff(i) == 0) { + // attention layers (with optional bias) + const int64_t n_head_i = hparams.n_head(i); + const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i); + const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_i, n_embd_k_gqa_i, n_embd_v_gqa_i, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + } else { + if (n_expert != 0) { + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert }, 0); + + // MoE branch + layer.ffn_latent_down = create_tensor(tn(LLM_TENSOR_FFN_LATENT_DOWN, "weight", i), {n_embd, moe_n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_latent_up = create_tensor(tn(LLM_TENSOR_FFN_LATENT_UP, "weight", i), {moe_n_embd, n_embd}, TENSOR_NOT_REQUIRED); -llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, moe_n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {moe_n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_shexp}, 0); + + } else { + // mlp layers + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED); + } + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_nemotron_h::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_nemotron_h::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_build_mamba_base(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -24,7 +143,7 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // ssm layer // cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else if (hparams.n_ff(il) == 0) { @@ -55,70 +174,51 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * cur, +ggml_tensor * llama_model_nemotron_h::graph::build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, const llama_model & model, - const int64_t n_embd_head, - const int il) { - // compute Q and K and (optionally) RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + int64_t n_embd_head, + int il) { + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, hparams.n_head(il), hparams.n_head_kv(il), il); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); return cur; } -ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { +ggml_tensor * llama_model_nemotron_h::graph::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, model.layers[il].ffn_up_s, NULL, NULL, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, model.layers[il].ffn_down_s, NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } else { - ggml_tensor * ffn_inp = cur; + ggml_tensor * inp_emb = cur; + ggml_tensor * inp_latent = cur; + + if (model.layers[il].ffn_latent_down) { + inp_latent = ggml_mul_mat(ctx0, model.layers[il].ffn_latent_down, cur); + } + + ggml_tensor * router_logits = build_lora_mm(model.layers[il].ffn_gate_inp, cur); + cb(router_logits, "ffn_moe_logits", il); + ggml_tensor * moe_out = - build_moe_ffn(ffn_inp, + build_moe_ffn(inp_latent, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, nullptr, // no gate @@ -126,15 +226,23 @@ ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const lla model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_RELU_SQR, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, - il); + il, + router_logits, nullptr, + model.layers[il].ffn_up_exps_s, + nullptr, // no gate + model.layers[il].ffn_down_exps_s); cb(moe_out, "ffn_moe_out", il); - ggml_tensor * ffn_shexp = build_ffn(ffn_inp, - model.layers[il].ffn_up_shexp, NULL, NULL, - NULL /* no gate */ , NULL, NULL, - model.layers[il].ffn_down_shexp, NULL, NULL, + if (model.layers[il].ffn_latent_up) { + moe_out = ggml_mul_mat(ctx0, model.layers[il].ffn_latent_up, moe_out); + } + + ggml_tensor * ffn_shexp = build_ffn(inp_emb, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + NULL /* no gate */ , NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, NULL, LLM_FFN_RELU_SQR, LLM_FFN_PAR, il); cb(ffn_shexp, "ffn_shexp", il); diff --git a/examples/talk-llama/models/nemotron.cpp b/examples/talk-llama/models/nemotron.cpp index fcead041f0a..6e2bd9a33ca 100644 --- a/examples/talk-llama/models/nemotron.cpp +++ b/examples/talk-llama/models/nemotron.cpp @@ -1,10 +1,57 @@ #include "models.h" -llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_nemotron::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - //GGML_ASSERT(n_embd_head == hparams.n_rot); + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_4B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_nemotron::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } +} + +std::unique_ptr<llm_graph_context> llama_model_nemotron::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_nemotron::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + //GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -31,27 +78,8 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -70,7 +98,7 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -113,7 +141,7 @@ llm_build_nemotron::llm_build_nemotron(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/neo-bert.cpp b/examples/talk-llama/models/neo-bert.cpp index 7c32bfca5f5..4a08d7abd40 100644 --- a/examples/talk-llama/models/neo-bert.cpp +++ b/examples/talk-llama/models/neo-bert.cpp @@ -1,10 +1,49 @@ #include "models.h" -llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); +void llama_model_neo_bert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + if (hparams.n_layer() == 28) { + type = LLM_TYPE_250M; + } +} + +void llama_model_neo_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_neo_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_neo_bert::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -27,17 +66,8 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_grap LLM_NORM_RMS, il); { - ggml_tensor * Qcur; - ggml_tensor * Kcur; - ggml_tensor * Vcur; - - // self-attention - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // RoPE Qcur = ggml_rope_ext( @@ -57,7 +87,7 @@ llm_build_neo_bert::llm_build_neo_bert(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); } diff --git a/examples/talk-llama/models/nomic-bert-moe.cpp b/examples/talk-llama/models/nomic-bert-moe.cpp new file mode 100644 index 00000000000..da4b62919bb --- /dev/null +++ b/examples/talk-llama/models/nomic-bert-moe.cpp @@ -0,0 +1,72 @@ +#include "models.h" + +void llama_model_nomic_bert_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); + + if (hparams.n_layer() == 12 && hparams.n_embd == 768) { + if (arch == LLM_ARCH_NOMIC_BERT) { + type = LLM_TYPE_137M; + } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { + type = LLM_TYPE_475M; + } + } +} + +void llama_model_nomic_bert_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_nomic_bert_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/nomic-bert.cpp b/examples/talk-llama/models/nomic-bert.cpp new file mode 100644 index 00000000000..e7fc72286a6 --- /dev/null +++ b/examples/talk-llama/models/nomic-bert.cpp @@ -0,0 +1,72 @@ +#include "models.h" + +void llama_model_nomic_bert::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS, hparams.moe_every_n_layers, 0); + + if (hparams.n_layer() == 12 && hparams.n_embd == 768) { + if (arch == LLM_ARCH_NOMIC_BERT) { + type = LLM_TYPE_137M; + } else if (arch == LLM_ARCH_NOMIC_BERT_MOE && hparams.moe_every_n_layers == 2) { + type = LLM_TYPE_475M; + } + } +} + +void llama_model_nomic_bert::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_token_types == 0) { + throw std::runtime_error(arch_name() + " model needs to define token type count"); + } + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + type_embd = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_BERT) { + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + } + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.attn_out_norm = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0); + layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i), {n_embd}, 0); + + if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) { + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + } else { + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + if (arch == LLM_ARCH_NOMIC_BERT) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + } + } + + layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0); + layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_nomic_bert::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + diff --git a/examples/talk-llama/models/olmo.cpp b/examples/talk-llama/models/olmo.cpp index bbd623f1112..9f7a2ba60ef 100644 --- a/examples/talk-llama/models/olmo.cpp +++ b/examples/talk-llama/models/olmo.cpp @@ -1,10 +1,50 @@ #include "models.h" -llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_olmo::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false); + + switch (hparams.n_layer()) { + case 22: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_olmo::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_olmo::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_olmo::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,27 +70,8 @@ llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Qcur = ggml_clamp(ctx0, Qcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Kcur = ggml_clamp(ctx0, Kcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (hparams.f_clamp_kqv > 0.0f) { - Vcur = ggml_clamp(ctx0, Vcur, -hparams.f_clamp_kqv, hparams.f_clamp_kqv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +90,7 @@ llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, nullptr, + model.layers[il].wo, nullptr, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -112,7 +133,7 @@ llm_build_olmo::llm_build_olmo(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/olmo2.cpp b/examples/talk-llama/models/olmo2.cpp index 713552dab89..cb52cdef720 100644 --- a/examples/talk-llama/models/olmo2.cpp +++ b/examples/talk-llama/models/olmo2.cpp @@ -1,11 +1,72 @@ #include "models.h" +void llama_model_olmo2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer()) { + case 16: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_olmo2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_head_kv * n_embd_head}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_olmo2::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique<graph<true>>(*this, params); + } else { + return std::make_unique<graph<false>>(*this, params); + } +} + template <bool iswa> -llm_build_olmo2<iswa>::llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +llama_model_olmo2::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -89,7 +150,7 @@ llm_build_olmo2<iswa>::llm_build_olmo2(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -137,7 +198,7 @@ llm_build_olmo2<iswa>::llm_build_olmo2(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -146,5 +207,5 @@ llm_build_olmo2<iswa>::llm_build_olmo2(const llama_model & model, const llm_grap } // Explicit template instantiations -template struct llm_build_olmo2<false>; -template struct llm_build_olmo2<true>; +template struct llama_model_olmo2::graph<false>; +template struct llama_model_olmo2::graph<true>; diff --git a/examples/talk-llama/models/olmoe.cpp b/examples/talk-llama/models/olmoe.cpp index b8b6988f897..1e2baeb207f 100644 --- a/examples/talk-llama/models/olmoe.cpp +++ b/examples/talk-llama/models/olmoe.cpp @@ -1,10 +1,60 @@ #include "models.h" -llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_olmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + switch (hparams.n_layer()) { + case 16: type = LLM_TYPE_A1_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_olmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_olmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_olmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -68,7 +118,7 @@ llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -92,7 +142,7 @@ llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_para nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); @@ -115,7 +165,7 @@ llm_build_olmoe::llm_build_olmoe(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/openai-moe-iswa.cpp b/examples/talk-llama/models/openai-moe-iswa.cpp deleted file mode 100644 index dbe3ca1851f..00000000000 --- a/examples/talk-llama/models/openai-moe-iswa.cpp +++ /dev/null @@ -1,127 +0,0 @@ -#include "models.h" - -llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv_iswa(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - const float freq_base_l = model.get_rope_freq_base (cparams, il); - const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, nullptr, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il); - - cb(cur, "attn_out", il); - } - if (il == n_layer - 1) { - // skip computing output for unused tokens - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - cur = ffn_inp; - cur = build_norm(cur, - model.layers[il].attn_post_norm, nullptr, - LLM_NORM_RMS, il); - cb(cur, "attn_post_norm", il); - - // MoE branch - cur = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b, - model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b, - model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SWIGLU_OAI_MOE, false, - false, 0.0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT, - il); - cb(cur, "ffn_moe_out", il); - - cur = ggml_add(ctx0, cur, ffn_inp); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/openai-moe.cpp b/examples/talk-llama/models/openai-moe.cpp new file mode 100644 index 00000000000..6d74f9c7e6e --- /dev/null +++ b/examples/talk-llama/models/openai-moe.cpp @@ -0,0 +1,171 @@ +#include "models.h" + +void llama_model_openai_moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + uint32_t swa_period = 2; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_20B; break; + case 36: type = LLM_TYPE_120B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_openai_moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_ff_exp = hparams.n_ff_exp; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_head * n_rot, n_head_kv * n_rot, n_head_kv * n_rot, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0); + + layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0); + layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); + layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0); + layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_openai_moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_openai_moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_rot, n_head, n_head_kv, il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, model.layers[il].attn_sinks, nullptr, 1.0f/sqrtf(float(n_rot)), il); + + cb(cur, "attn_out", il); + } + if (il == n_layer - 1) { + // skip computing output for unused tokens + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = ffn_inp; + cur = build_norm(cur, + model.layers[il].attn_post_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + // MoE branch + cur = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b, + model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b, + model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SWIGLU_OAI_MOE, false, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT, + il); + cb(cur, "ffn_moe_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/openelm.cpp b/examples/talk-llama/models/openelm.cpp index ee46a3375e8..13120bd3236 100644 --- a/examples/talk-llama/models/openelm.cpp +++ b/examples/talk-llama/models/openelm.cpp @@ -1,9 +1,56 @@ #include "models.h" -llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_openelm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 16: type = LLM_TYPE_270M; break; + case 20: type = LLM_TYPE_450M; break; + case 28: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_openelm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // init output from the input tok embed + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head = hparams.n_head(i); + const int64_t n_head_qkv = 2*hparams.n_head_kv(i) + n_head; + const int64_t n_ff = hparams.n_ff(i); + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_head_qkv*n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head*n_embd_head_k, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_openelm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_openelm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -43,7 +90,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_ ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head); cb(Kcur, "Kcur", il); - ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv))); + ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head+n_head_kv)); cb(Vcur, "Vcur", il); Qcur = build_norm(Qcur, @@ -73,7 +120,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_ cb(Qcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -115,7 +162,7 @@ llm_build_openelm::llm_build_openelm(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/orion.cpp b/examples/talk-llama/models/orion.cpp index bb02273bfe7..863a2822269 100644 --- a/examples/talk-llama/models/orion.cpp +++ b/examples/talk-llama/models/orion.cpp @@ -1,10 +1,50 @@ #include "models.h" -llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_orion::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + switch (hparams.n_layer()) { + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_orion::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_orion::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_orion::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,30 +70,8 @@ llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - // if (model.layers[il].bq) { - // Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - // cb(Qcur, "Qcur", il); - // } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - // if (model.layers[il].bk) { - // Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - // cb(Kcur, "Kcur", il); - // } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - // if (model.layers[il].bv) { - // Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - // cb(Vcur, "Vcur", il); - // } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -72,7 +90,7 @@ llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -114,7 +132,7 @@ llm_build_orion::llm_build_orion(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/paddleocr.cpp b/examples/talk-llama/models/paddleocr.cpp new file mode 100644 index 00000000000..d39220bd778 --- /dev/null +++ b/examples/talk-llama/models/paddleocr.cpp @@ -0,0 +1,107 @@ +#include "models.h" + +std::unique_ptr<llm_graph_context> llama_model_paddleocr::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_paddleocr::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + + // NOTE: same with qwen2vl.cpp, but bias tensors are optional + + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + { + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + } + // self-attention + { + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1) { + // skip computing output for unused tokens + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/pangu-embed.cpp b/examples/talk-llama/models/pangu-embed.cpp new file mode 100644 index 00000000000..90f05c088c1 --- /dev/null +++ b/examples/talk-llama/models/pangu-embed.cpp @@ -0,0 +1,162 @@ +#include "models.h" + +void llama_model_pangu_embed::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1 + case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1 + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_pangu_embed::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + // weight tensors + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_pangu_embed::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_pangu_embed::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self attention + { + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + if (model.output_b != nullptr) { + cur = ggml_add(ctx0, cur, model.output_b); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/pangu-embedded.cpp b/examples/talk-llama/models/pangu-embedded.cpp deleted file mode 100644 index 664572a5001..00000000000 --- a/examples/talk-llama/models/pangu-embedded.cpp +++ /dev/null @@ -1,121 +0,0 @@ -#include "models.h" - - -llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self attention - { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - } - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - cur = build_ffn(cur, - model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, - model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, - model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - if (model.output_b != nullptr) { - cur = ggml_add(ctx0, cur, model.output_b); - } - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/phi2.cpp b/examples/talk-llama/models/phi2.cpp index 22dbf610767..81b1ad12cc0 100644 --- a/examples/talk-llama/models/phi2.cpp +++ b/examples/talk-llama/models/phi2.cpp @@ -1,11 +1,53 @@ #include "models.h" +void llama_model_phi2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_phi2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_phi2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_phi2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * attn_norm_output; @@ -30,29 +72,8 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params // self-attention { - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - } else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], attn_norm_output, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -74,7 +95,7 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -109,7 +130,7 @@ llm_build_phi2::llm_build_phi2(const llama_model & model, const llm_graph_params cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output_no_bias", -1); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/phi3.cpp b/examples/talk-llama/models/phi3.cpp index c8e5da33db7..716ff814cc1 100644 --- a/examples/talk-llama/models/phi3.cpp +++ b/examples/talk-llama/models/phi3.cpp @@ -1,11 +1,74 @@ #include "models.h" +void llama_model_phi3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (found_swa && hparams.n_swa > 0) { + LLAMA_LOG_WARN("%s: Phi SWA is currently disabled - results might be suboptimal for some models (see %s)\n", + __func__, "https://github.com/ggml-org/llama.cpp/pull/13676"); + + // TODO: fix conversion scripts to correctly populate `n_swa` and `n_swa_pattern` + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + + hparams.n_swa = 0; + hparams.set_swa_pattern(1); + } +} + +void llama_model_phi3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, TENSOR_NOT_REQUIRED); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } +} + +std::unique_ptr<llm_graph_context> llama_model_phi3::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + return std::make_unique<graph<true>> (*this, params); + } else { + return std::make_unique<graph<false>>(*this, params); + } +} + template<bool iswa> -llm_build_phi3<iswa>::llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); +llama_model_phi3::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -39,27 +102,8 @@ llm_build_phi3<iswa>::llm_build_phi3(const llama_model & model, const llm_graph_ LLM_NORM_RMS, il); cb(attn_norm_output, "attn_norm", il); - ggml_tensor * Qcur = nullptr; - ggml_tensor * Kcur = nullptr; - ggml_tensor * Vcur = nullptr; - - if (model.layers[il].wqkv) { - cur = build_lora_mm(model.layers[il].wqkv, attn_norm_output); - cb(cur, "wqkv", il); - - Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 0 * sizeof(float) * (n_embd)); - Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd)); - Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head * sizeof(float), cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)); - } - else { - Qcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wq, attn_norm_output), model.layers[il].bq); - Kcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wk, attn_norm_output), model.layers[il].bk); - Vcur = ggml_add(ctx0, build_lora_mm(model.layers[il].wv, attn_norm_output), model.layers[il].bv); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - } + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], attn_norm_output, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -80,7 +124,7 @@ llm_build_phi3<iswa>::llm_build_phi3(const llama_model & model, const llm_graph_ cb(Qcur, "Qcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -114,7 +158,7 @@ llm_build_phi3<iswa>::llm_build_phi3(const llama_model & model, const llm_graph_ nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(cur, "ffn_moe_out", il); @@ -135,7 +179,7 @@ llm_build_phi3<iswa>::llm_build_phi3(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cb(cur, "result_output_no_bias", -1); @@ -148,5 +192,5 @@ llm_build_phi3<iswa>::llm_build_phi3(const llama_model & model, const llm_graph_ } // Explicit template instantiations -template struct llm_build_phi3<false>; -template struct llm_build_phi3<true>; +template struct llama_model_phi3::graph<false>; +template struct llama_model_phi3::graph<true>; diff --git a/examples/talk-llama/models/phimoe.cpp b/examples/talk-llama/models/phimoe.cpp new file mode 100644 index 00000000000..c332553bc7d --- /dev/null +++ b/examples/talk-llama/models/phimoe.cpp @@ -0,0 +1,55 @@ +#include "models.h" + +void llama_model_phimoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_16x3_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_phimoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head = n_embd / n_head; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), { n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), { n_embd }, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } +} + +std::unique_ptr<llm_graph_context> llama_model_phimoe::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + return std::make_unique<graph<true>> (*this, params); + } else { + return std::make_unique<graph<false>>(*this, params); + } +} + diff --git a/examples/talk-llama/models/plamo.cpp b/examples/talk-llama/models/plamo.cpp index 04ff709f9c6..246144519e4 100644 --- a/examples/talk-llama/models/plamo.cpp +++ b/examples/talk-llama/models/plamo.cpp @@ -1,10 +1,46 @@ #include "models.h" -llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_plamo::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + switch (hparams.n_layer()) { + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plamo::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_plamo::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_plamo::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,18 +66,8 @@ llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -60,7 +86,7 @@ llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -101,7 +127,7 @@ llm_build_plamo::llm_build_plamo(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/plamo2.cpp b/examples/talk-llama/models/plamo2.cpp index 31115a08f95..0b81513c368 100644 --- a/examples/talk-llama/models/plamo2.cpp +++ b/examples/talk-llama/models/plamo2.cpp @@ -1,7 +1,114 @@ #include "models.h" +#include "llama-memory-recurrent.h" -llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { +void llama_model_plamo2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load Mamba SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Load attention parameters + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k_full, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v_full, false); + + for (uint32_t i = 0; i < hparams.n_layer(); ++i) { + hparams.is_recr_impl[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer()) { + case 16: type = LLM_TYPE_1B; break; + case 32: + if (hparams.n_embd == 2048) { + type = LLM_TYPE_2B; + } else if (hparams.n_embd == 4096) { + type = LLM_TYPE_8B; + } + break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plamo2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + // mamba parameters + const uint32_t d_conv = hparams.ssm_d_conv; + const uint32_t d_state = hparams.ssm_d_state; + const uint32_t num_heads = hparams.ssm_dt_rank; + const uint32_t intermediate_size = hparams.ssm_d_inner; + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + + // attention parameters + const uint32_t qk_dim = hparams.n_embd_head_k(); + const uint32_t v_dim = hparams.n_embd_head_v(); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + bool is_mamba_layer = hparams.is_recr(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (is_mamba_layer) { + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); + + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); + + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); + } else { + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t q_num_heads = num_attention_heads; + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t k_num_heads = num_key_value_heads; + const int64_t v_num_heads = num_key_value_heads; + const int64_t q_proj_dim = q_num_heads * qk_dim; + const int64_t k_proj_dim = k_num_heads * qk_dim; + const int64_t v_proj_dim = v_num_heads * v_dim; + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); + } + + // All layers have post-attention norm, FFN norm, and FFN tensors + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_plamo2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_plamo2::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -25,7 +132,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); // check if this layer is Mamba or Attention - bool is_mamba_layer = hparams.is_recurrent(il); + const bool is_mamba_layer = hparams.is_recr(il); if (is_mamba_layer) { // PLaMo-2 Mamba layer @@ -69,6 +176,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa cur = ggml_add(ctx0, cur, residual); cb(cur, "ffn_residual", il); + // input for next layer inpL = cur; } @@ -81,7 +189,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); // Explicitly mark as output tensor to ensure proper backend assignment @@ -92,7 +200,7 @@ llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_pa ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, +ggml_tensor * llama_model_plamo2::graph::build_plamo2_attn_layer(llm_graph_input_attn_kv * inp, ggml_tensor * inp_pos, ggml_tensor * cur, const llama_model & model, @@ -104,9 +212,9 @@ ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv cb(qkv, "wqkv", il); // split QKV tensor into Q, K, V - const int64_t n_embd_head_q = hparams.n_embd_head_k; - const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_head_v = hparams.n_embd_head_v; + const int64_t n_embd_head_q = hparams.n_embd_head_k(); + const int64_t n_embd_head_k = hparams.n_embd_head_k(); + const int64_t n_embd_head_v = hparams.n_embd_head_v(); int32_t n_head = hparams.n_head(il); int32_t n_head_kv = hparams.n_head_kv(il); @@ -138,7 +246,7 @@ ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv ext_factor, attn_factor, beta_fast, beta_slow); cur = build_attn(inp, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, NULL, NULL, NULL, 1.0f / sqrtf(float(n_embd_head_v)), il); } @@ -147,7 +255,7 @@ ggml_tensor * llm_build_plamo2::build_plamo2_attn_layer(llm_graph_input_attn_kv return cur; } -ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * inp, +ggml_tensor * llama_model_plamo2::graph::build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, @@ -169,6 +277,8 @@ ggml_tensor * llm_build_plamo2::build_plamo2_mamba_layer(llm_graph_input_rs * in GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + GGML_ASSERT(d_inner % n_heads == 0); + GGML_ASSERT(n_group == 0); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); diff --git a/examples/talk-llama/models/plamo3.cpp b/examples/talk-llama/models/plamo3.cpp index 55c8064679e..16d0b1dcef7 100644 --- a/examples/talk-llama/models/plamo3.cpp +++ b/examples/talk-llama/models/plamo3.cpp @@ -1,10 +1,77 @@ #include "models.h" +void llama_model_plamo3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + uint32_t swa_period = 8; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + } + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plamo3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t head_dim_q = hparams.n_embd_head_k(); + const int64_t head_dim_v = hparams.n_embd_head_v(); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t q_proj_dim = num_attention_heads * head_dim_q; + const int64_t k_proj_dim = num_key_value_heads * head_dim_q; + const int64_t v_proj_dim = num_key_value_heads * head_dim_v; + const int64_t n_ff_cur = hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), + {n_embd,q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim_q}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim_q}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {num_attention_heads * head_dim_v, n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff_cur * 2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff_cur, n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_plamo3::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + return std::make_unique<graph<true>> (*this, params); + } else { + return std::make_unique<graph<false>>(*this, params); + } +} + template <bool iswa> -llm_build_plamo3<iswa>::llm_build_plamo3(const llama_model & model, const llm_graph_params & params) : +llama_model_plamo3::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t head_dim_q = hparams.n_embd_head_k; - const int64_t head_dim_v = hparams.n_embd_head_v; + const int64_t head_dim_q = hparams.n_embd_head_k(); + const int64_t head_dim_v = hparams.n_embd_head_v(); ggml_tensor * cur; ggml_tensor * inpL = build_inp_embd(model.tok_embd); @@ -73,7 +140,7 @@ llm_build_plamo3<iswa>::llm_build_plamo3(const llama_model & model, const llm_gr const float attn_scale = 1.0f / sqrtf(float(head_dim_q)); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, attn_scale, il); cb(cur, "attn_out", il); @@ -109,6 +176,8 @@ llm_build_plamo3<iswa>::llm_build_plamo3(const llama_model & model, const llm_gr cur = build_cvec(cur, il); cb(cur, "l_out", il); + + // input for next layer inpL = cur; } @@ -117,12 +186,12 @@ llm_build_plamo3<iswa>::llm_build_plamo3(const llama_model & model, const llm_gr cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); res->t_logits = cur; ggml_build_forward_expand(gf, cur); } // Explicit template instantiations -template struct llm_build_plamo3<false>; -template struct llm_build_plamo3<true>; +template struct llama_model_plamo3::graph<false>; +template struct llama_model_plamo3::graph<true>; diff --git a/examples/talk-llama/models/plm.cpp b/examples/talk-llama/models/plm.cpp index 481cbba6907..8ca325f5e2c 100644 --- a/examples/talk-llama/models/plm.cpp +++ b/examples/talk-llama/models/plm.cpp @@ -1,10 +1,56 @@ #include "models.h" -llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k)); +void llama_model_plm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_1_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_plm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const int64_t n_embd_head_qk_rope = hparams.n_rot(); + const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); + const int64_t kv_lora_rank = hparams.n_lora_kv; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)}, 0); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0); + layer.wkv_b = create_tensor(tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_plm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_plm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k())); + + const uint32_t n_embd_head_qk_rope = hparams.n_rot(); + const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k() - hparams.n_rot(); - const uint32_t n_embd_head_qk_rope = hparams.n_rot; - const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; const uint32_t kv_lora_rank = hparams.n_lora_kv; ggml_tensor * cur; @@ -37,15 +83,15 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & // split into {n_head * n_embd_head_qk_nope, n_tokens} ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, hparams.n_embd_head_k()), + ggml_row_size(q->type, hparams.n_embd_head_k() * n_head), 0); cb(q_nope, "q_nope", il); // and {n_head * n_embd_head_qk_rope, n_tokens} ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, - ggml_row_size(q->type, hparams.n_embd_head_k), - ggml_row_size(q->type, hparams.n_embd_head_k * n_head), + ggml_row_size(q->type, hparams.n_embd_head_k()), + ggml_row_size(q->type, hparams.n_embd_head_k() * n_head), ggml_row_size(q->type, n_embd_head_qk_nope)); cb(q_pe, "q_pe", il); @@ -77,23 +123,23 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & // split into {n_head * n_embd_head_qk_nope, n_tokens} ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, - ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v), - ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)), + ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v()), + ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v())), 0); cb(k_nope, "k_nope", il); // and {n_head * n_embd_head_v, n_tokens} - ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)), - ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head), + ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v(), n_head, n_tokens, + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())), + ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v())*n_head), ggml_row_size(kv->type, (n_embd_head_qk_nope))); cb(v_states, "v_states", il); v_states = ggml_cont(ctx0, v_states); cb(v_states, "v_states", il); - v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, - ggml_row_size(kv->type, hparams.n_embd_head_v * n_head), + v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v() * n_head, n_tokens, + ggml_row_size(kv->type, hparams.n_embd_head_v() * n_head), 0); cb(v_states, "v_states", il); @@ -119,7 +165,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & cb(k_states, "k_states", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, q_states, k_states, v_states, nullptr, nullptr, nullptr, kq_scale, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -159,7 +205,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen.cpp b/examples/talk-llama/models/qwen.cpp index 31fd9b73763..1f5dff3843c 100644 --- a/examples/talk-llama/models/qwen.cpp +++ b/examples/talk-llama/models/qwen.cpp @@ -1,10 +1,49 @@ #include "models.h" +void llama_model_qwen::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); -llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd*3}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd*3}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff/2}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff/2, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff/2}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_qwen::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_qwen::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -28,15 +67,8 @@ llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 2*sizeof(float)*(n_embd)); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); // using mode = 2 for neox mode Qcur = ggml_rope_ext( @@ -56,7 +88,7 @@ llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -99,7 +131,7 @@ llm_build_qwen::llm_build_qwen(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2.cpp b/examples/talk-llama/models/qwen2.cpp index 3da4dea3c16..e9c2ea80a6b 100644 --- a/examples/talk-llama/models/qwen2.cpp +++ b/examples/talk-llama/models/qwen2.cpp @@ -1,10 +1,60 @@ #include "models.h" -llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_qwen2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_5B : LLM_TYPE_1B; break; + case 28: type = hparams.n_embd == 1536 ? LLM_TYPE_1_5B : LLM_TYPE_7B; break; + case 32: type = LLM_TYPE_7B; break; + case 36: type = LLM_TYPE_3B; break; + case 40: type = hparams.n_head() == 20 ? LLM_TYPE_4B : LLM_TYPE_13B; break; + case 48: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + case 80: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_qwen2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_qwen2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,30 +80,8 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -72,7 +100,7 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -114,7 +142,7 @@ llm_build_qwen2::llm_build_qwen2(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); if (model.output_b != nullptr) { cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/qwen2moe.cpp b/examples/talk-llama/models/qwen2moe.cpp index 49142b71236..e831ed11aad 100644 --- a/examples/talk-llama/models/qwen2moe.cpp +++ b/examples/talk-llama/models/qwen2moe.cpp @@ -1,10 +1,72 @@ #include "models.h" -llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_qwen2moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_A2_7B; break; + case 28: type = LLM_TYPE_57B_A14B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen2moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN2MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN2MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert branch + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd}, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp, n_embd}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_qwen2moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_qwen2moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,27 +92,8 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +112,7 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -94,7 +137,7 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap nullptr, n_expert, n_expert_used, LLM_FFN_SILU, false, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); @@ -142,7 +185,7 @@ llm_build_qwen2moe::llm_build_qwen2moe(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen2vl.cpp b/examples/talk-llama/models/qwen2vl.cpp index 9be38675cf7..d79db682cd4 100644 --- a/examples/talk-llama/models/qwen2vl.cpp +++ b/examples/talk-llama/models/qwen2vl.cpp @@ -1,10 +1,49 @@ #include "models.h" -llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_qwen2vl::load_arch_hparams(llama_model_loader & ml) { + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); +} +// fall through + +void llama_model_qwen2vl::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_qwen2vl::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_qwen2vl::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -33,21 +72,8 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_multi( ctx0, Qcur, inp_pos, nullptr, @@ -66,7 +92,7 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -108,7 +134,7 @@ llm_build_qwen2vl::llm_build_qwen2vl(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3.cpp b/examples/talk-llama/models/qwen3.cpp index a5cfffa5314..f4b2a2aebe0 100644 --- a/examples/talk-llama/models/qwen3.cpp +++ b/examples/talk-llama/models/qwen3.cpp @@ -1,10 +1,60 @@ #include "models.h" -llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_qwen3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 28: type = hparams.n_embd == 1024 ? LLM_TYPE_0_6B : LLM_TYPE_1_7B; break; + case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; + case 40: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // output rerank head + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_qwen3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_qwen3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -19,6 +69,8 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm @@ -30,18 +82,8 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,7 +108,7 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -83,9 +125,9 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para cb(cur, "ffn_norm", il); cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); @@ -108,7 +150,7 @@ llm_build_qwen3::llm_build_qwen3(const llama_model & model, const llm_graph_para res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen35.cpp b/examples/talk-llama/models/qwen35.cpp new file mode 100644 index 00000000000..6783d98ec20 --- /dev/null +++ b/examples/talk-llama/models/qwen35.cpp @@ -0,0 +1,642 @@ +#include "models.h" +#include "llama-memory-recurrent.h" + +void llama_model_qwen35::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer()) { + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; + case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; + case 64: type = LLM_TYPE_27B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen35::load_arch_tensors(llama_model_loader & ml) { + LLAMA_LOAD_LOCALS; + + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); + + if (!hparams.is_recr(il)) { + // Attention layers + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + // MTP block looks like a full-attention Qwen3.5 decoder block. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", il), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", il), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", il), {n_embd, n_ff}, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < n_layer; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = n_layer; i < n_layer_all; ++i) { + load_block_mtp(i); + } +} + +std::unique_ptr<llm_graph_context> llama_model_qwen35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique<graph_mtp>(*this, params); + } + return std::make_unique<graph>(*this, params); +} + +llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_build_delta_net_base(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + cb(inpL, "model.input_embed", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_build_forward_expand(gf, cur); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recr(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, il); + } else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); + } + + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // Dense FFN layer - without residual connection + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_ffn", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +std::pair<ggml_tensor *, ggml_tensor *> llama_model_qwen35::graph::build_qkvz( + ggml_tensor * input, + int il) { + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s); + cb(z, "z", il); + + return { qkv_mixed, z }; +} + +ggml_tensor * llama_model_qwen35::graph::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llama_model_qwen35::graph::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] + cb(Qcur_full, "Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply MRoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llama_model_qwen35::graph::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; + + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s); + beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); + cb(beta, "beta", il); + + beta = ggml_sigmoid(ctx0, beta); + cb(beta, "beta_sigmoid", il); + + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); + alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + cb(alpha, "alpha", il); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); + + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); + + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = conv_output_silu; + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + + cb(q_conv, "q_conv", il); + cb(k_conv, "k_conv", il); + cb(v_conv, "v_conv", il); + + const float eps_norm = hparams.f_norm_rms_eps; + + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); + + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + // if head keys and value keys are different, repeat to force tensors into matching shapes + // note: need explicit repeat only if we are not using the fused GDN. + if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + + return cur; +} + +ggml_tensor * llama_model_qwen35::graph::build_layer_ffn(ggml_tensor * cur, const int il) { + // Qwen3.5 does not use MoE FFN + GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, model.layers[il].ffn_up_s, + model.layers[il].ffn_gate, NULL, model.layers[il].ffn_gate_s, + model.layers[il].ffn_down, NULL, model.layers[il].ffn_down_s, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + return cur; +} + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 dense series +llama_model_qwen35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.n_layer_nextn > 0 && "QWEN35 MTP requires n_layer_nextn > 0"); + GGML_ASSERT(hparams.n_layer_nextn == 1 && "QWEN35 MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // hparams.n_layer includes both main model layers and MTP layers. The MTP + // layer is stored immediately after the main layers in model.layers[]. + const int il = hparams.n_layer(); + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + // TODO: extract in a common llm_graph_context::build_inp_embd_h() + auto inp = std::make_unique<llm_graph_input_embd_h>(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp(), n_tokens); + ggml_set_input(inp->embd); + + // TODO: make static using `ggml_build_forward_select()` + // see llm_graph_context::build_inp_embd() for reference + ggml_tensor * tok_embd; + if (ubatch.token) { + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + } else { + tok_embd = inp->embd; + } + cb(tok_embd, "mtp_tok_embd", il); + + inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->h); + ggml_set_name(inp->h, "mtp_h_input"); + + ggml_tensor * h_embd = inp->h; + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_embd, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + cur = build_ffn(cur, + layer.ffn_up, nullptr, layer.ffn_up_s, + layer.ffn_gate, nullptr, layer.ffn_gate_s, + layer.ffn_down, nullptr, layer.ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen35moe.cpp b/examples/talk-llama/models/qwen35moe.cpp new file mode 100644 index 00000000000..eb5e9a406a1 --- /dev/null +++ b/examples/talk-llama/models/qwen35moe.cpp @@ -0,0 +1,739 @@ +#include "models.h" +#include "llama-memory-recurrent.h" + +void llama_model_qwen35moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // NextN/MTP (Qwen3.5/3.6): extra decoder block appended beyond the main stack + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + // Mark recurrent layers (linear attention layers). MTP layers are dense + // attention-only and must be flagged non-recurrent. + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer()) { + case 40: type = LLM_TYPE_35B_A3B; break; + case 48: type = LLM_TYPE_122B_A10B; break; + case 60: type = LLM_TYPE_397B_A17B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen35moe::load_arch_tensors(llama_model_loader & ml) { + LLAMA_LOAD_LOCALS; + + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + auto load_block_trunk = [&](int il, int flags) { + auto & layer = layers[il]; + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, flags); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, flags); + + if (!hparams.is_recr(il)) { + // Attention layers + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, flags); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, flags); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, flags); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", il), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", il), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", il), { hparams.ssm_d_conv, conv_dim }, flags); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", il), { hparams.ssm_dt_rank }, flags); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, il), { hparams.ssm_dt_rank }, flags); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", il), { n_embd, n_v_heads }, flags); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", il), { head_v_dim }, flags); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", il), { value_dim, n_embd }, flags); + } + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, flags); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, flags); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, flags); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, flags); + }; + + auto load_block_mtp = [&](int il) { + auto & layer = layers[il]; + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + // MTP block looks like a full-attention Qwen3.5 decoder block with MoE FFN. + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", il), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", il), { n_embd }, 0); + + create_tensor_qkv(layer, il, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", il), { n_embd_head_k * n_head, n_embd }, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", il), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", il), { n_embd_head_k }, 0); + + // Routed experts + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", il), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", il), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, il, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", il), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", il), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", il), { n_ff_shexp, n_embd }, 0); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", il), { 2 * n_embd, n_embd }, 0); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", il), { n_embd }, 0); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", il), { n_embd }, 0); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", il), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", il), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < n_layer; ++i) { + load_block_trunk(i, trunk_flags); + } + for (int i = n_layer; i < n_layer_all; ++i) { + load_block_mtp(i); + } +} + +std::unique_ptr<llm_graph_context> llama_model_qwen35moe::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique<graph_mtp>(*this, params); + } + return std::make_unique<graph>(*this, params); +} + +llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_build_delta_net_base(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + cb(inpL, "model.input_embed", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + ggml_build_forward_expand(gf, cur); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recr(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, il); + } else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); + } + + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // MOE FFN layer + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_moe", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // post-norm hidden state feeds both the LM head and the MTP seed below + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +std::pair<ggml_tensor *, ggml_tensor *> llama_model_qwen35moe::graph::build_qkvz( + ggml_tensor * input, + int il) { + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input, model.layers[il].wqkv_s); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input, model.layers[il].wqkv_gate_s); + cb(z, "z", il); + + return { qkv_mixed, z }; +} + +ggml_tensor * llama_model_qwen35moe::graph::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur, model.layers[il].wq_s); // [ (n_embd_head * 2) * n_head, n_tokens ] + cb(Qcur_full, "Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur, model.layers[il].wk_s); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur, model.layers[il].wv_s); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply IMRoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llama_model_qwen35moe::graph::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; + + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur, model.layers[il].ssm_beta_s); + beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); + cb(beta, "beta", il); + + beta = ggml_sigmoid(ctx0, beta); + cb(beta, "beta_sigmoid", il); + + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur, model.layers[il].ssm_alpha_s); + alpha = ggml_reshape_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + cb(alpha, "alpha", il); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); + + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); + + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = conv_output_silu; + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + + cb(q_conv, "q_conv", il); + cb(k_conv, "k_conv", il); + cb(v_conv, "v_conv", il); + + const float eps_norm = hparams.f_norm_rms_eps; + + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); + + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + // if head keys and value keys are different, repeat to force tensors into matching shapes + // note: need explicit repeat only if we are not using the fused GDN. + if (num_k_heads != num_v_heads && (!cparams.fused_gdn_ar || !cparams.fused_gdn_ch)) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = build_lora_mm(model.layers[il].ssm_out, final_output, model.layers[il].ssm_out_s); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + + return cur; +} + +ggml_tensor * llama_model_qwen35moe::graph::build_layer_ffn(ggml_tensor * cur, const int il) { + // Check if this is an MoE layer + GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, model.layers[il].ffn_gate_up_exps, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); + cb(moe_out, "ffn_moe_out", il); + + // Add shared experts if present - following Qwen3Next reference implementation + if (model.layers[il].ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, model.layers[il].ffn_up_shexp_s, + model.layers[il].ffn_gate_shexp, NULL, model.layers[il].ffn_gate_shexp_s, + model.layers[il].ffn_down_shexp, NULL, model.layers[il].ffn_down_shexp_s, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + // Apply shared expert gating as in the reference implementation + // The shared expert has its own gate that is sigmoided + // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token) + ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(shared_gate, "shared_expert_gate", il); + + // Apply sigmoid to the gate + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "shared_expert_gate_sigmoid", il); + + + // Apply the gate to the shared expert output + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + + return cur; +} + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Qwen3.5/3.6 MoE +llama_model_qwen35moe::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.n_layer_nextn > 0 && "QWEN35MOE MTP requires n_layer_nextn > 0"); + GGML_ASSERT(hparams.n_layer_nextn == 1 && "QWEN35MOE MTP currently only supports a single MTP block"); + + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + const int il = hparams.n_layer(); + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + GGML_ASSERT(layer.ffn_gate_inp && "MTP block missing ffn_gate_inp"); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + // TODO: extract in a common llm_graph_context::build_inp_embd_h() + auto inp = std::make_unique<llm_graph_input_embd_h>(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd_inp(), n_tokens); + ggml_set_input(inp->embd); + + // TODO: make static using `ggml_build_forward_select()` + // see llm_graph_context::build_inp_embd() for reference + ggml_tensor * tok_embd; + if (ubatch.token) { + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + } else { + tok_embd = inp->embd; + } + cb(tok_embd, "mtp_tok_embd", il); + + inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->h); + ggml_set_name(inp->h, "mtp_h_input"); + + ggml_tensor * h_embd = inp->h; + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * h_norm = build_norm(h_embd, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat, layer.nextn.eh_proj_s); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur_full = build_lora_mm(layer.wq, cur, layer.wq_s); + cb(Qcur_full, "mtp_Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + 0); + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, + n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "mtp_gate", il); + + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_rope_multi(ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_multi(ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + const float kq_scale = hparams.f_attention_scale == 0.0f + ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "mtp_attn_pregate", il); + + cur = ggml_mul(ctx0, cur, ggml_sigmoid(ctx0, gate)); + cur = build_lora_mm(layer.wo, cur, layer.wo_s); + cb(cur, "mtp_attn_out", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_residual = cur; + cur = build_norm(cur, layer.attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_post_norm", il); + + // MoE FFN — routed experts plus gated shared expert (mirrors qwen35moe). + ggml_tensor * moe_out = + build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, layer.ffn_gate_up_exps, + layer.ffn_up_exps_s, + layer.ffn_gate_exps_s, + layer.ffn_down_exps_s); + cb(moe_out, "mtp_ffn_moe_out", il); + + if (layer.ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + layer.ffn_up_shexp, nullptr, layer.ffn_up_shexp_s, + layer.ffn_gate_shexp, nullptr, layer.ffn_gate_shexp_s, + layer.ffn_down_shexp, nullptr, layer.ffn_down_shexp_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "mtp_ffn_shexp", il); + + ggml_tensor * shared_gate = build_lora_mm(layer.ffn_gate_inp_shexp, cur); + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "mtp_shared_expert_gate_sigmoid", il); + + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "mtp_ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "mtp_ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "mtp_post_ffn", il); + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "QWEN35MOE MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "h_nextn", -1); + res->t_h_nextn= cur; + + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + ggml_tensor * head_s = layer.nextn.shared_head_head ? layer.nextn.shared_head_head_s : model.output_s; + GGML_ASSERT(head_w && "QWEN35MOE MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur, head_s); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/qwen3moe.cpp b/examples/talk-llama/models/qwen3moe.cpp index 888534fb347..6f6df5390e3 100644 --- a/examples/talk-llama/models/qwen3moe.cpp +++ b/examples/talk-llama/models/qwen3moe.cpp @@ -1,10 +1,69 @@ #include "models.h" -llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_qwen3moe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 48: type = LLM_TYPE_30B_A3B; break; + case 94: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3moe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_qwen3moe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_qwen3moe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -19,6 +78,8 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; // norm @@ -30,18 +91,8 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -66,7 +117,7 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -91,9 +142,13 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); + il, + nullptr, nullptr, + model.layers[il].ffn_up_exps_s, + model.layers[il].ffn_gate_exps_s, + model.layers[il].ffn_down_exps_s); cb(moe_out, "ffn_moe_out", il); cur = moe_out; @@ -115,7 +170,7 @@ llm_build_qwen3moe::llm_build_qwen3moe(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3next.cpp b/examples/talk-llama/models/qwen3next.cpp index 57b6659baf0..97200a44072 100644 --- a/examples/talk-llama/models/qwen3next.cpp +++ b/examples/talk-llama/models/qwen3next.cpp @@ -1,10 +1,114 @@ -#include "ggml.h" #include "models.h" +#include "llama-memory-recurrent.h" + +void llama_model_qwen3next::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + if (!ml.get_key_or_arr(LLM_KV_ATTENTION_RECURRENT_LAYERS, hparams.is_recr_impl, hparams.n_layer_all, false)) { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer_all; ++i) { + hparams.is_recr_impl[i] = (i < hparams.n_layer()) && ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer()) { + case 48: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3next::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + if (n_expert == 0) { + throw std::runtime_error(arch_name() + " model cannot have zero experts"); + } + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + // Calculate projection sizes + const int64_t qkvz_dim = key_dim * 2 + value_dim * 2; + const int64_t ba_dim = n_v_heads * 2; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + const uint32_t n_ff_shexp = hparams.n_ff_shexp > 0 ? hparams.n_ff_shexp : hparams.n_ff(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recr(i)) { + // Attention layers + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head * 2, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + // note: ssm_in is used by legacy GGUF + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); + + // Shared experts + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } +} -#define CHUNK_SIZE 64 +std::unique_ptr<llm_graph_context> llama_model_qwen3next::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} -llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { +llama_model_qwen3next::graph::graph(const llama_model & model, const llm_graph_params & params) : + llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -16,27 +120,18 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); - - ggml_build_forward_expand(gf, causal_mask); - ggml_build_forward_expand(gf, identity); - ggml_build_forward_expand(gf, diag_mask); - for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); + ggml_build_forward_expand(gf, cur); + // Determine layer type and build appropriate attention mechanism - if (hparams.is_recurrent(il)) { + if (hparams.is_recr(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); @@ -66,6 +161,9 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_residual); cb(cur, "post_moe", il); + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + // Input for next layer inpL = cur; } @@ -78,7 +176,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr res->t_embd = cur; // LM head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -94,349 +192,7 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); } -std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, pad, 0, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(g, "g_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); - - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); - cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); - - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); - attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); - attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); - cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along chunk_size dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], - g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], - (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); - g_last = ggml_cont(ctx0, g_last); - cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); - cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp); - cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - - // state to be updated per chunk - ggml_tensor * new_state = state; // ggml_dup(ctx0, state); - cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) - - // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) - ggml_tensor * core_attn_out = nullptr; - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - // shape: (S_k, chunk_size, 1, H_k * n_seqs) - ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul - - // shape: (S_v, chunk_size, 1, H_v * n_seqs) - ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat - - // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul - - // shape: (chunk_size, 1, H_v * n_seqs) - ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat - - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - // replaced by precomputed attn_kq - ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); - cb(attn_chunk, "attn_chunk", il); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - cb(v_new, "v_new_chunk", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); - cb(attn_inter, "attn_inter_chunk", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); - cb(v_attn, "v_attn_chunk", il); - - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - - core_attn_out = core_attn_out == nullptr - ? core_attn_out_chunk - : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - ggml_tensor * k_gdiff = ggml_cont(ctx0, get_slice_2d(ctx0, key_gdiff, chunk)); - //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, k_gdiff))); - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - cb(output_tokens, "output_tokens", il); - - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - return {output_tokens, new_state}; -} - -std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to g_t - g_t = ggml_exp(ctx0, g_t); - - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * g_t - state = ggml_mul(ctx0, state, g_t); - - // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - // we need to sum over dim=-2, so we transpose, sum, then transpose again - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); - - // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - // delta = (v_t - kv_mem) * beta_t - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] - ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - - // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); - - // Compute the attention output - // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // again, since it's over dim = -2, transpose, sum, transpose back - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); - - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - -ggml_tensor * llm_build_qwen3next::build_norm_gated( +ggml_tensor * llama_model_qwen3next::graph::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, ggml_tensor * gate, @@ -447,13 +203,13 @@ ggml_tensor * llm_build_qwen3next::build_norm_gated( return ggml_mul(ctx0, normalized, gated_silu); } -ggml_tensor * llm_build_qwen3next::build_layer_attn( +ggml_tensor * llama_model_qwen3next::graph::build_layer_attn( llm_graph_input_attn_kv * inp, ggml_tensor * cur, ggml_tensor * inp_pos, int il) { - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + const int64_t n_embd_head = hparams.n_embd_head_v(); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention @@ -466,39 +222,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( // Split Q projection into query and gate // The split should be along dimension 0 (the feature dimension) ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, - Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + cb(Qcur, "Qcur_view", il); + ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); - cb(Qcur, "Qcur", il); cb(gate, "gate", il); - // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cb(Qcur, "Qcur_reshaped", il); - - // Apply Q normalization - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - // Apply K normalization Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads) - gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cb(gate, "gate_reshaped", il); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); - // Apply RoPE Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -513,27 +259,31 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - // Attention computation const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, - nullptr, nullptr, + nullptr, nullptr, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); - ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); - cb(gate_sigmoid, "gate_sigmoid", il); + // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "gate_sigmoid", il); + + gate = ggml_reshape_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cur = ggml_mul(ctx0, cur, gate_sigmoid); + cur = ggml_mul(ctx0, cur, gate); cb(cur, "attn_gated", il); - cur = build_lora_mm(model.layers[il].wo, cur); + cur = build_lora_mm(model.layers[il].wo, cur, model.layers[il].wo_s); cb(cur, "attn_output", il); return cur; } -std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz( +std::pair<ggml_tensor *, ggml_tensor *> llama_model_qwen3next::graph::build_qkvz( ggml_tensor * input, int il) { const int64_t d_inner = hparams.ssm_d_inner; @@ -554,7 +304,6 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz( cb(z, "z", il); return { qkv_mixed, z }; - } else { // legacy (slower) path ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input); @@ -615,12 +364,9 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_qwen3next::build_qkvz( } } -ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( +ggml_tensor * llama_model_qwen3next::graph::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -632,8 +378,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( const int64_t head_v_dim = d_inner / num_v_heads; const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const auto kv_head = mctx_cur->get_head(); - GGML_ASSERT(n_seqs != 0); GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); @@ -665,7 +409,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); cb(a, "a", il); - ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); + // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont + b = ggml_cont(ctx0, b); + + ggml_tensor * beta = ggml_sigmoid(ctx0, b); // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); @@ -673,48 +420,26 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus cb(gate, "gate", il); - // Get convolution states from cache + beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); + gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); - - // Build the convolution states tensor - ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); - cb(conv_states, "conv_states", il); - - // Calculate convolution kernel size ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); - cb(conv_states, "conv_states_reshaped", il); - - qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); - cb(qkv_mixed, "qkv_mixed_permuted", il); - - ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); - cb(conv_input, "conv_input", il); - // Update convolution state cache - // Extract the last (conv_kernel_size - 1) states from conv_input - ggml_tensor * last_conv_states = - ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], - conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); - cb(last_conv_states, "last_conv_states", il); + ggml_tensor * conv_input = build_conv_state(inp, conv_states_all, qkv_mixed, conv_kernel_size, conv_channels, il); - ggml_tensor * state_update_target = - ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, - kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); - cb(state_update_target, "state_update_target", il); - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); - cb(conv_states_all, "conv_states_updated", il); + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); - // Apply SSM convolution ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); cb(conv_output_proper, "conv_output_raw", il); @@ -728,45 +453,56 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); // Extract the convolved Q, K, V from conv_output - ggml_tensor * q_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + cb(q_conv, "q_conv", il); - ggml_tensor * k_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, - head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(k_conv, "k_conv", il); - ggml_tensor * v_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, - 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(v_conv, "v_conv", il); - // Unsqueeze them - q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + const float eps_norm = hparams.f_norm_rms_eps; - ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); - state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); - cb(state, "state_predelta", il); + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); + + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes + // TODO: avoid repeats for fused GDN, needs broadcast configuration for GDN op [TAG_GGML_GDN_BCAST] if (num_k_heads != num_v_heads) { GGML_ASSERT(num_v_heads % num_k_heads == 0); int64_t repeat_factor = num_v_heads / num_k_heads; - // repeat interleave: reshape to (repeat part, 1, remaining part), do repeat, then reshape back - ggml_tensor * q_reshaped = ggml_reshape_3d(ctx0, q_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); - ggml_tensor * k_reshaped = ggml_reshape_3d(ctx0, k_conv, head_k_dim, 1, num_k_heads * n_seq_tokens * n_seqs); + // repeat interleave: reshape to (repeat part, 1, remaining part...), do repeat, then reshape back + ggml_tensor * q_reshaped = ggml_reshape_4d(ctx0, q_conv, head_k_dim, 1, num_k_heads, n_seq_tokens * n_seqs); + ggml_tensor * k_reshaped = ggml_reshape_4d(ctx0, k_conv, head_k_dim, 1, num_k_heads, n_seq_tokens * n_seqs); // Repeat along the third dimension (the new dimension with size 1) ggml_tensor * q_repeated = - ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_repeat_4d(ctx0, q_reshaped, head_k_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs); ggml_tensor * k_repeated = - ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads * n_seq_tokens * n_seqs, 1); + ggml_repeat_4d(ctx0, k_reshaped, head_k_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs); // Reshape back to merge the head and repeat dimensions - // From [head_dim, num_k_heads, repeat_factor, n_seq_tokens * n_seqs] - // Back to [head_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs] + // From [head_dim, repeat_factor, num_k_heads, n_seq_tokens * n_seqs] + // Back to [head_dim, repeat_factor * num_k_heads, n_seq_tokens, n_seqs] q_conv = ggml_reshape_4d(ctx0, q_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); k_conv = ggml_reshape_4d(ctx0, k_repeated, head_k_dim, num_k_heads * repeat_factor, n_seq_tokens, n_seqs); } @@ -775,33 +511,13 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(k_conv, "k_conv_predelta", il); cb(v_conv, "v_conv_predelta", il); - // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens - std::pair<ggml_tensor *, ggml_tensor *> attn_out; // pair of (output, new_state) - if (n_seq_tokens == 1) { - attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); - } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); - } - ggml_tensor * output = attn_out.first; - ggml_tensor * new_state = attn_out.second; - cb(output, "attn_output", il); - cb(new_state, "new_state", il); - - // Update the recurrent states - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); - - // Reshape both attn_out_final and z to 2D tensors for normalization - // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * output = build_recurrent_attn(inp, ssm_states_all, q_conv, k_conv, v_conv, gate, beta, state, il); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // Apply gated normalization: self.norm(core_attn_out, z) - ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); @@ -812,28 +528,34 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(cur, "linear_attn_out", il); // Reshape back to original dimensions - cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; } -ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int il) { +ggml_tensor * llama_model_qwen3next::graph::build_layer_ffn(ggml_tensor * cur, const int il) { // Check if this is an MoE layer if (model.layers[il].ffn_gate_inp != nullptr) { // MoE branch ggml_tensor * moe_out = build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, nullptr, - n_expert, n_expert_used, LLM_FFN_SILU, - true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, model.layers[il].ffn_gate_up_exps); cb(moe_out, "ffn_moe_out", il); // Add shared experts if present - following Qwen3Next reference implementation if (model.layers[il].ffn_up_shexp != nullptr) { ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, NULL, @@ -846,11 +568,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); cb(shared_gate, "shared_expert_gate", il); - // Apply sigmoid to the gate shared_gate = ggml_sigmoid(ctx0, shared_gate); cb(shared_gate, "shared_expert_gate_sigmoid", il); - // Apply the gate to the shared expert output ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); cb(ffn_shexp, "ffn_shexp_gated", il); diff --git a/examples/talk-llama/models/qwen3vl-moe.cpp b/examples/talk-llama/models/qwen3vl-moe.cpp deleted file mode 100644 index f72f80a8376..00000000000 --- a/examples/talk-llama/models/qwen3vl-moe.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include "models.h" - -llm_build_qwen3vlmoe::llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const size_t n_deepstack_layers = hparams.n_deepstack_layers; - const int64_t n_embd = hparams.n_embd; - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - int sections[4]; - std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - - std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr); - - if (ubatch.embd) { - // Image input: split main embd and deepstack embds - ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0); - for (size_t i = 0; i < n_deepstack_layers; i++) { - deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float)); - } - inpL = inpL_main; - } - - // inp_pos - contains the positions - ggml_tensor * inp_pos = build_inp_pos(); - - auto * inp_attn = build_attn_inp_kv(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self_attention - { - // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - - Qcur = ggml_rope_multi( - ctx0, Qcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(Kcur, "Kcur_normed", il); - - Kcur = ggml_rope_multi( - ctx0, Kcur, inp_pos, nullptr, - n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); - - cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); - } - - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // MoE branch - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - ggml_tensor * moe_out = - build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, - model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, - model.layers[il].ffn_down_exps, - nullptr, - n_expert, n_expert_used, - LLM_FFN_SILU, true, - false, 0.0, - LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, - il); - cb(moe_out, "ffn_moe_out", il); - cur = moe_out; - - cur = ggml_add(ctx0, cur, ffn_inp); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - if (ubatch.embd && (size_t)il < n_deepstack_layers) { - cur = ggml_add(ctx0, cur, deepstack_features[il]); - cb(cur, "deepstack_out", il); - } - - // input for next layer - inpL = cur; - } - - cur = inpL; - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); -} - diff --git a/examples/talk-llama/models/qwen3vl.cpp b/examples/talk-llama/models/qwen3vl.cpp index 0bae52239ca..724d6140d19 100644 --- a/examples/talk-llama/models/qwen3vl.cpp +++ b/examples/talk-llama/models/qwen3vl.cpp @@ -1,12 +1,64 @@ #include "models.h" -llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_qwen3vl::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 28: type = LLM_TYPE_1_7B; break; + case 36: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_8B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3vl::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // output rerank head + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_qwen3vl::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_qwen3vl::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const size_t n_deepstack_layers = hparams.n_deepstack_layers; - const int64_t n_embd = hparams.n_embd; - const int64_t n_embd_head = hparams.n_embd_head_v; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -16,17 +68,6 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ int sections[4]; std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); - std::vector<ggml_tensor *> deepstack_features(n_deepstack_layers, nullptr); - - if (ubatch.embd) { - // Image input: split main embd and deepstack embds - ggml_tensor * inpL_main = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], 0); - for (size_t i = 0; i < n_deepstack_layers; i++) { - deepstack_features[i] = ggml_view_2d(ctx0, inpL, n_embd, n_tokens, inpL->nb[1], (i + 1) * n_embd * sizeof(float)); - } - inpL = inpL_main; - } - // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -46,18 +87,8 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -82,7 +113,7 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -113,8 +144,9 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ cur = build_cvec(cur, il); cb(cur, "l_out", il); - if (ubatch.embd && (size_t)il < n_deepstack_layers) { - cur = ggml_add(ctx0, cur, deepstack_features[il]); + if (il < (int) n_deepstack_layers) { + ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float)); + cur = ggml_add(ctx0, cur, ds); cb(cur, "deepstack_out", il); } @@ -132,7 +164,7 @@ llm_build_qwen3vl::llm_build_qwen3vl(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/qwen3vlmoe.cpp b/examples/talk-llama/models/qwen3vlmoe.cpp new file mode 100644 index 00000000000..7c41592f772 --- /dev/null +++ b/examples/talk-llama/models/qwen3vlmoe.cpp @@ -0,0 +1,190 @@ +#include "models.h" + +void llama_model_qwen3vlmoe::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 48: type = LLM_TYPE_30B_A3B; break; + case 94: type = LLM_TYPE_235B_A22B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_qwen3vlmoe::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_qwen3vlmoe::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_qwen3vlmoe::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const size_t n_deepstack_layers = hparams.n_deepstack_layers; + + const int64_t n_embd = hparams.n_embd; + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self_attention + { + // compute Q and K and RoPE them + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + hparams.expert_weights_scale, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + cur = moe_out; + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + if (il < (int) n_deepstack_layers) { + ggml_tensor * ds = ggml_view_2d(ctx0, res->t_inp_embd, n_embd, n_tokens, res->t_inp_embd->nb[1], (il + 1) * n_embd * sizeof(float)); + cur = ggml_add(ctx0, cur, ds); + cb(cur, "deepstack_out", il); + } + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/refact.cpp b/examples/talk-llama/models/refact.cpp index ff5eb2841db..a46c358fa68 100644 --- a/examples/talk-llama/models/refact.cpp +++ b/examples/talk-llama/models/refact.cpp @@ -1,9 +1,85 @@ #include "models.h" -llm_build_refact::llm_build_refact(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_refact::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // TODO: become GGUF KV parameter + hparams.f_max_alibi_bias = 8.0f; +} + +void llama_model_refact::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + if (n_expert == 0) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional MLP bias + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0); + + // For Granite MoE Shared + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0); + } + } + } +} + +std::unique_ptr<llm_graph_context> llama_model_refact::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_refact::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -24,25 +100,15 @@ llm_build_refact::llm_build_refact(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cb(Qcur, "Qcur", il); cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -85,7 +151,7 @@ llm_build_refact::llm_build_refact(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rnd1.cpp b/examples/talk-llama/models/rnd1.cpp index 46b3dc3efca..fc276ce591b 100644 --- a/examples/talk-llama/models/rnd1.cpp +++ b/examples/talk-llama/models/rnd1.cpp @@ -1,11 +1,72 @@ #include "models.h" +void llama_model_rnd1::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 48: type = LLM_TYPE_30B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + + // Set non-causal attention for diffusion models + hparams.causal_attn = false; +} + +void llama_model_rnd1::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0 for QWEN3MOE"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0 for QWEN3MOE"); + } + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_rnd1::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + // RND1 is a Qwen3Moe AR model converted to diffusion model. -llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +llama_model_rnd1::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -32,18 +93,8 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params // self_attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); cb(Qcur, "Qcur_normed", il); @@ -68,7 +119,7 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -93,7 +144,7 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params nullptr, n_expert, n_expert_used, LLM_FFN_SILU, true, - false, 0.0, + hparams.expert_weights_scale, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); cb(moe_out, "ffn_moe_out", il); @@ -117,7 +168,7 @@ llm_build_rnd1::llm_build_rnd1(const llama_model & model, const llm_graph_params res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6-base.cpp b/examples/talk-llama/models/rwkv6-base.cpp index 7beed2daffb..83aeab7280b 100644 --- a/examples/talk-llama/models/rwkv6-base.cpp +++ b/examples/talk-llama/models/rwkv6-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} diff --git a/examples/talk-llama/models/rwkv6.cpp b/examples/talk-llama/models/rwkv6.cpp index 15453fbf50f..0b5013dc758 100644 --- a/examples/talk-llama/models/rwkv6.cpp +++ b/examples/talk-llama/models/rwkv6.cpp @@ -1,6 +1,97 @@ #include "models.h" -llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_params & params) : +void llama_model_rwkv6::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_rwkv6::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // Block 0, LN0 + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); + + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); + + layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); + layer.time_mix_lerp_w = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_W, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_v = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_V, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_r = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_g = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_G, "weight", i), {n_embd, 1, 1}, TENSOR_NOT_REQUIRED); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, TENSOR_NOT_REQUIRED); + GGML_ASSERT(!(layer.time_mix_lerp_fused == NULL && layer.time_mix_lerp_w == NULL)); + + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, 0); + layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); + layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); + layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); + layer.channel_mix_lerp_r = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_R, "weight", i), {n_embd, 1, 1}, 0); + + layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); + layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); + layer.channel_mix_receptance = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_RECEPTANCE, "weight", i), {n_embd, n_embd}, 0); + } + +} + +std::unique_ptr<llm_graph_context> llama_model_rwkv6::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_rwkv6::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { GGML_ASSERT(hparams.token_shift_count == 2); @@ -8,7 +99,7 @@ llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_para ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); auto * rs_inp = build_rs_inp(); @@ -85,7 +176,7 @@ llm_build_rwkv6::llm_build_rwkv6(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv6qwen2.cpp b/examples/talk-llama/models/rwkv6qwen2.cpp index e84e5973820..6c7db514435 100644 --- a/examples/talk-llama/models/rwkv6qwen2.cpp +++ b/examples/talk-llama/models/rwkv6qwen2.cpp @@ -1,6 +1,87 @@ #include "models.h" -llm_build_rwkv6qwen2::llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { +void llama_model_rwkv6qwen2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_TIME_MIX_EXTRA_DIM, hparams.time_mix_extra_dim); + ml.get_key(LLM_KV_TIME_DECAY_EXTRA_DIM, hparams.time_decay_extra_dim); + ml.get_key(LLM_KV_RESCALE_EVERY_N_LAYERS, hparams.rescale_every_n_layers, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_1_6B; break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: type = LLM_TYPE_14B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_rwkv6qwen2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, TENSOR_NOT_REQUIRED); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int time_mix_extra_dim = hparams.time_mix_extra_dim; + const int time_decay_extra_dim = hparams.time_decay_extra_dim; + const int head_size = hparams.wkv_head_size; + const int attn_hidden_size = n_embd; + int attn_key_value_size; + if (n_head_kv == 0 || attn_hidden_size / head_size == n_head_kv) { + attn_key_value_size = attn_hidden_size; + } else { + attn_key_value_size = n_head_kv * head_size; + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, time_mix_extra_dim * 5}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {time_mix_extra_dim, n_embd, 5}, 0); + + layer.time_mix_lerp_x = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_X, "weight", i), {n_embd, 1, 1}, 0); + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 5}, 0); + + layer.time_mix_first = create_tensor(tn(LLM_TENSOR_TIME_MIX_FIRST, "weight", i), {head_size, n_embd / head_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_decay = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY, "weight", i), {n_embd}, 0); + layer.time_mix_decay_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W1, "weight", i), {n_embd, time_decay_extra_dim}, 0); + layer.time_mix_decay_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_DECAY_W2, "weight", i), {time_decay_extra_dim, attn_hidden_size}, 0); + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {n_embd, attn_key_value_size}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {n_embd, attn_key_value_size}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_gate = create_tensor(tn(LLM_TENSOR_TIME_MIX_GATE, "weight", i), {attn_hidden_size, n_embd}, 0); + // optional bias tensors + layer.time_mix_key_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_value_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "bias", i), {attn_key_value_size}, TENSOR_NOT_REQUIRED); + layer.time_mix_receptance_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "bias", i), {attn_hidden_size}, TENSOR_NOT_REQUIRED); + + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_rwkv6qwen2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_rwkv6qwen2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { GGML_ASSERT(n_embd == hparams.n_embd_r()); ggml_tensor * cur; @@ -77,7 +158,7 @@ llm_build_rwkv6qwen2::llm_build_rwkv6qwen2(const llama_model & model, const llm_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/rwkv7-base.cpp b/examples/talk-llama/models/rwkv7-base.cpp index cda44653849..7fcab77745c 100644 --- a/examples/talk-llama/models/rwkv7-base.cpp +++ b/examples/talk-llama/models/rwkv7-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} diff --git a/examples/talk-llama/models/rwkv7.cpp b/examples/talk-llama/models/rwkv7.cpp index 5caf6553dfe..67c51f5b59c 100644 --- a/examples/talk-llama/models/rwkv7.cpp +++ b/examples/talk-llama/models/rwkv7.cpp @@ -1,6 +1,127 @@ #include "models.h" -llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_params & params) : +void llama_model_rwkv7::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, false); + ml.get_key(LLM_KV_WKV_HEAD_SIZE, hparams.wkv_head_size); + ml.get_key(LLM_KV_ATTENTION_DECAY_LORA_RANK, hparams.n_lora_decay); + ml.get_key(LLM_KV_ATTENTION_ICLR_LORA_RANK, hparams.n_lora_iclr); + ml.get_key(LLM_KV_ATTENTION_VALUE_RESIDUAL_MIX_LORA_RANK, hparams.n_lora_value_res_mix); + ml.get_key(LLM_KV_ATTENTION_GATE_LORA_RANK, hparams.n_lora_gate, false); + ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); + + switch (hparams.n_layer()) { + case 12: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_190M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_450M; break; + case 2048: type = LLM_TYPE_1_5B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 28: + switch (hparams.n_embd) { + case 1536: type = LLM_TYPE_1_5B; break; + case 3584: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_2_9B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: + switch (hparams.n_embd) { + case 4096: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_rwkv7::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // Block 0, LN0 + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {n_embd}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + const int n_lora_decay = hparams.n_lora_decay; + const int n_lora_iclr = hparams.n_lora_iclr; + const int n_lora_value_res_mix = hparams.n_lora_value_res_mix; + const int n_lora_gate = hparams.n_lora_gate; + const int attn_hidden_size = n_embd; + const int ffn_size = hparams.n_ff_arr[0]; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.attn_norm_2 = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, 0); + layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, 0); + + layer.time_mix_w0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W0, "weight", i), {n_embd}, 0); + layer.time_mix_w1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W1, "weight", i), {n_embd, n_lora_decay}, 0); + layer.time_mix_w2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_W2, "weight", i), {n_lora_decay, n_embd}, 0); + + layer.time_mix_a0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A0, "weight", i), {n_embd}, 0); + layer.time_mix_a1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_a2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_A2, "weight", i), {n_lora_iclr, n_embd}, 0); + + if (i == 0) { + // actually not used + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_iclr}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_iclr, n_embd}, 0); + } else { + layer.time_mix_v0 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V0, "weight", i), {n_embd}, 0); + layer.time_mix_v1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V1, "weight", i), {n_embd, n_lora_value_res_mix}, 0); + layer.time_mix_v2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_V2, "weight", i), {n_lora_value_res_mix, n_embd}, 0); + } + + layer.time_mix_g1 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G1, "weight", i), {n_embd, n_lora_gate}, 0); + layer.time_mix_g2 = create_tensor(tn(LLM_TENSOR_TIME_MIX_G2, "weight", i), {n_lora_gate, n_embd}, 0); + + layer.time_mix_lerp_fused = create_tensor(tn(LLM_TENSOR_TIME_MIX_LERP_FUSED, "weight", i), {n_embd, 1, 1, 6}, 0); + + layer.time_mix_k_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_K, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_k_a = create_tensor(tn(LLM_TENSOR_TIME_MIX_K_A, "weight", i), {attn_hidden_size}, 0); + layer.time_mix_r_k = create_tensor(tn(LLM_TENSOR_TIME_MIX_R_K, "weight", i), {attn_hidden_size}, 0); + + layer.time_mix_key = create_tensor(tn(LLM_TENSOR_TIME_MIX_KEY, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_value = create_tensor(tn(LLM_TENSOR_TIME_MIX_VALUE, "weight", i), {attn_hidden_size, n_embd}, 0); + layer.time_mix_receptance = create_tensor(tn(LLM_TENSOR_TIME_MIX_RECEPTANCE, "weight", i), {attn_hidden_size, n_embd}, 0); + + layer.time_mix_ln = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "weight", i), {n_embd}, 0); + layer.time_mix_ln_b = create_tensor(tn(LLM_TENSOR_TIME_MIX_LN, "bias", i), {n_embd}, 0); + layer.time_mix_output = create_tensor(tn(LLM_TENSOR_TIME_MIX_OUTPUT, "weight", i), {n_embd, attn_hidden_size}, 0); + + layer.channel_mix_lerp_k = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_LERP_K, "weight", i), {n_embd, 1, 1}, 0); + + layer.channel_mix_key = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_KEY, "weight", i), {n_embd, ffn_size}, 0); + layer.channel_mix_value = create_tensor(tn(LLM_TENSOR_CHANNEL_MIX_VALUE, "weight", i), {ffn_size, n_embd}, 0); + } + +} + +std::unique_ptr<llm_graph_context> llama_model_rwkv7::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_rwkv7::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { GGML_ASSERT(hparams.token_shift_count == 2); @@ -9,7 +130,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para ggml_tensor * v_first = nullptr; inpL = build_inp_embd(model.tok_embd); - inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); + inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, 0); auto * rs_inp = build_rs_inp(); @@ -81,7 +202,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/seed-oss.cpp b/examples/talk-llama/models/seed-oss.cpp index 0dc33c50ba3..57de881a091 100644 --- a/examples/talk-llama/models/seed-oss.cpp +++ b/examples/talk-llama/models/seed-oss.cpp @@ -1,10 +1,56 @@ #include "models.h" -llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_seed_oss::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + switch (hparams.n_layer()) { + case 64: type = LLM_TYPE_36B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_seed_oss::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const uint32_t head_dim = hparams.n_embd_head_k(); + const int64_t n_qo_dim = n_head * head_dim; + const int64_t n_kv_dim = n_head_kv * head_dim; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_qo_dim, n_kv_dim, n_kv_dim, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_qo_dim, n_embd}, 0); + + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_seed_oss::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_seed_oss::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -32,27 +78,8 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -71,7 +98,7 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -115,7 +142,7 @@ llm_build_seed_oss::llm_build_seed_oss(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/smallthinker.cpp b/examples/talk-llama/models/smallthinker.cpp index 4c497ca76f4..a8e3d957f1f 100644 --- a/examples/talk-llama/models/smallthinker.cpp +++ b/examples/talk-llama/models/smallthinker.cpp @@ -1,11 +1,84 @@ #include "models.h" +void llama_model_smallthinker::load_arch_hparams(llama_model_loader & ml) { + const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + + if (found_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + uint32_t swa_period = 4; + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); + hparams.set_swa_pattern(swa_period, true); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + } else { + hparams.swa_type = LLAMA_SWA_TYPE_NONE; + hparams.n_no_rope_layer_step = hparams.n_layer(); + } + + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_4B; break; + case 52: type = LLM_TYPE_20B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_smallthinker::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + + GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER"); + GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER"); + + // MoE branch + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_smallthinker::build_arch_graph(const llm_graph_params & params) const { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + return std::make_unique<graph<true>> (*this, params); + } else { + return std::make_unique<graph<false>>(*this, params); + } +} + template <bool iswa> -llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ - const int64_t n_embd_head = hparams.n_embd_head_v; +llama_model_smallthinker::graph<iswa>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){ + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -45,18 +118,8 @@ llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model, // self_attention { // compute Q and K and RoPE them - struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, @@ -69,7 +132,7 @@ llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model, cb(Kcur, "Kcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -93,7 +156,7 @@ llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model, nullptr, n_expert, n_expert_used, LLM_FFN_RELU, true, - false, 0.0, + hparams.expert_weights_scale, static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il, probs); @@ -101,6 +164,7 @@ llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model, cur = ffn_out; cur = ggml_add(ctx0, cur, ffn_inp); + cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -114,7 +178,7 @@ llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model, res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; @@ -122,5 +186,5 @@ llm_build_smallthinker<iswa>::llm_build_smallthinker(const llama_model & model, } // Explicit template instantiations -template struct llm_build_smallthinker<false>; -template struct llm_build_smallthinker<true>; +template struct llama_model_smallthinker::graph<false>; +template struct llama_model_smallthinker::graph<true>; diff --git a/examples/talk-llama/models/smollm3.cpp b/examples/talk-llama/models/smollm3.cpp index 97c30deed54..c67d967b204 100644 --- a/examples/talk-llama/models/smollm3.cpp +++ b/examples/talk-llama/models/smollm3.cpp @@ -1,10 +1,53 @@ #include "models.h" -llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_smollm3::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + hparams.n_no_rope_layer_step = 4; - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + switch (hparams.n_layer()) { + case 36: type = LLM_TYPE_3B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_smollm3::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_k_gqa, n_embd_v_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_smollm3::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_smollm3::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -34,27 +77,8 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (use_rope) { Qcur = ggml_rope_ext( @@ -74,7 +98,7 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); } @@ -119,7 +143,7 @@ llm_build_smollm3::llm_build_smollm3(const llama_model & model, const llm_graph_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/stablelm.cpp b/examples/talk-llama/models/stablelm.cpp index bed1915c006..bf6087b8796 100644 --- a/examples/talk-llama/models/stablelm.cpp +++ b/examples/talk-llama/models/stablelm.cpp @@ -1,9 +1,57 @@ #include "models.h" -llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_stablelm::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_1B; break; + case 32: type = LLM_TYPE_3B; break; + case 40: type = LLM_TYPE_12B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_stablelm::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional q and k layernorms, present in StableLM 2 12B + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k, n_head}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k, n_head_kv}, TENSOR_NOT_REQUIRED); + + // optional FFN norm, not present in StableLM 2 12B which uses parallel residual + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_stablelm::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_stablelm::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,30 +78,8 @@ llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_grap // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); if (model.layers[il].attn_q_norm) { Qcur = build_norm(Qcur, @@ -87,7 +113,7 @@ llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_grap cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -137,7 +163,7 @@ llm_build_stablelm::llm_build_stablelm(const llama_model & model, const llm_grap res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder.cpp b/examples/talk-llama/models/starcoder.cpp index e197af4a8c6..f73a88fd4e9 100644 --- a/examples/talk-llama/models/starcoder.cpp +++ b/examples/talk-llama/models/starcoder.cpp @@ -1,10 +1,66 @@ #include "models.h" -llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); +void llama_model_starcoder::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer()) { + case 24: type = LLM_TYPE_1B; break; + case 36: type = LLM_TYPE_3B; break; + case 42: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_starcoder::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + pos_embd = create_tensor(tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, n_ctx_train}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + // needs to be on GPU + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0); + layer.wqkv_b = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_starcoder::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_starcoder::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); ggml_tensor * cur; ggml_tensor * inpL; @@ -33,22 +89,11 @@ llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_gr // self-attention { - cur = build_lora_mm(model.layers[il].wqkv, cur); - cb(cur, "wqkv", il); - - cur = ggml_add(ctx0, cur, model.layers[il].bqkv); - cb(cur, "bqkv", il); - - ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 0*sizeof(float)*(n_embd)); - ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd)); - ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*sizeof(float), cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)); - - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -91,7 +136,7 @@ llm_build_starcoder::llm_build_starcoder(const llama_model & model, const llm_gr cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/starcoder2.cpp b/examples/talk-llama/models/starcoder2.cpp index e40ef2cb749..b81b469374a 100644 --- a/examples/talk-llama/models/starcoder2.cpp +++ b/examples/talk-llama/models/starcoder2.cpp @@ -1,10 +1,66 @@ #include "models.h" -llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_starcoder2::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer()) { + case 30: type = LLM_TYPE_3B; break; + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_15B; break; + case 52: type = LLM_TYPE_20B; break; // granite + case 88: type = LLM_TYPE_34B; break; // granite + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_starcoder2::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + // optional bias tensors + layer.wo_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + // optional bias tensors + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP , "bias", i), { n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_starcoder2::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_starcoder2::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -30,27 +86,8 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ // self-attention { // compute Q and K and RoPE them - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - if (model.layers[il].bq) { - Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); - cb(Qcur, "Qcur", il); - } - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - if (model.layers[il].bk) { - Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); - cb(Kcur, "Kcur", il); - } - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - if (model.layers[il].bv) { - Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); - cb(Vcur, "Vcur", il); - } - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -69,7 +106,7 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, model.layers[il].bo, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -112,7 +149,7 @@ llm_build_starcoder2::llm_build_starcoder2(const llama_model & model, const llm_ res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/models/step35.cpp b/examples/talk-llama/models/step35.cpp new file mode 100644 index 00000000000..e2218c58704 --- /dev/null +++ b/examples/talk-llama/models/step35.cpp @@ -0,0 +1,557 @@ +#include "models.h" + +void llama_model_step35::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + + // full_attention layer only use half of the RoPE dimensions + hparams.n_rot_full = hparams.n_rot_full / 2; + + // MoE + SWA parameters + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // Step35 uses sigmoid gating by default (if not set in GGUF) + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + + ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.is_swa_impl, hparams.n_layer()); + + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_EXP, hparams.swiglu_clamp_exp, hparams.n_layer(), false); + ml.get_key_or_arr(LLM_KV_SWIGLU_CLAMP_SHEXP, hparams.swiglu_clamp_shexp, hparams.n_layer(), false); + + // NextN/MTP (Step3p5): extra decoder block appended beyond the main stack. + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.n_layer_nextn, false); + GGML_ASSERT(hparams.n_layer_nextn < hparams.n_layer_all && "n_layer_nextn must be < n_layer_impl"); + + switch (hparams.n_layer()) { + case 45: type = LLM_TYPE_196B_A11B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { + LLAMA_LOAD_LOCALS; + + const bool mtp_only = (hparams.n_layer_nextn > 0) && (ml.get_weight("blk.0.attn_norm.weight") == nullptr); + // Trunk-only: the GGUF declares MTP layers in metadata but the actual MTP + // tensors live in a separate file (e.g. user split target/draft). Mark + // MTP tensors NOT_REQUIRED so the trunk loads cleanly. + const std::string mtp_probe = "blk." + std::to_string(n_layer) + ".nextn.eh_proj.weight"; + const bool trunk_only = (hparams.n_layer_nextn > 0) && (ml.get_weight(mtp_probe.c_str()) == nullptr); + const int trunk_flags = mtp_only ? TENSOR_NOT_REQUIRED : 0; + const int mtp_flags = trunk_only ? TENSOR_NOT_REQUIRED : 0; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, trunk_flags); + + // STEP35 supports per-layer partial RoPE dims; rope factors are stored as a single shared tensor + // ("rope_freqs.weight") and ggml uses only the first (n_rot_l/2) entries per layer. + uint32_t n_rot_max = 0; + for (int i = 0; i < n_layer; ++i) { + n_rot_max = std::max(n_rot_max, hparams.n_rot(i)); + } + if (n_rot_max == 0) { + n_rot_max = n_rot; + } + + auto load_block_trunk = [&](int i, int flags) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + // optional rope factors (llama3) / longrope tensors + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, flags); + + // head-wise attention gate (Step35 self_attn.g_proj) + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // dense MLP (leading dense blocks) + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + // shared expert MLP + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + }; + + auto load_block_mtp = [&](int i, bool is_first_mtp) { + auto & layer = layers[i]; + + const uint32_t n_head_l = hparams.n_head(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i); + + // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the + // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). + // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. + // + // Only the FIRST MTP block (i == n_main) is required for the + // single-block MTP runtime; trailing MTP blocks are always tolerated + // as missing so pruned GGUFs (block 0 only) load cleanly. Override + // mtp_flags to NOT_REQUIRED for those. + const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); + } + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags); + + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + + // dense MLP (leading dense blocks) — present if the MTP block isn't MoE + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, TENSOR_NOT_REQUIRED); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + + // MoE routed experts + selection bias (router_bias) + const int64_t n_ff_exp = hparams.n_ff_exp; + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, hparams.n_ff_shexp}, TENSOR_NOT_REQUIRED); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); + + // NextN-specific tensors that define the MTP block. + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); + }; + + for (int i = 0; i < n_layer; ++i) { + load_block_trunk(i, trunk_flags); + } + // Only the first MTP block (i == n_main) is required at runtime — the + // single-block-MTP graph in build_arch_graph always uses that one. + // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with + // all MTP layers still works) but tolerated when absent via the pruning + // path. See scripts/prune_step35_extra_mtp.py for the pruner. + for (int i = n_layer; i < n_layer_all; ++i) { + load_block_mtp(i, /*is_first_mtp=*/ i == n_layer); + } +} + +std::unique_ptr<llm_graph_context> llama_model_step35::build_arch_graph(const llm_graph_params & params) const { + if (params.gtype == LLM_GRAPH_TYPE_DECODER_MTP) { + return std::make_unique<graph_mtp>(*this, params); + } + return std::make_unique<graph>(*this, params); +} + +llama_model_step35::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + const uint32_t n_head_l = hparams.n_head(il); + const uint32_t n_head_kv_l = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + cur = inpL; + + // dump pre-attn RMSNorm input to pinpoint layer boundary issues + cb(cur, "attn_norm_in", il); + + // self-attention + { + cur = build_norm(cur, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + // Q/K per-head RMSNorm (Step35 q_norm / k_norm) + if (model.layers[il].attn_q_norm) { + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + } + if (model.layers[il].attn_k_norm) { + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + } + + // RoPE (partial rotary factors per layer) + const bool is_swa = hparams.is_swa(il); + ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il); + const int64_t n_rot_l = hparams.n_rot(il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow + ); + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); + ggml_tensor * attn_out = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(attn_out, "attn_out", il); + // head-wise attention gate: sigmoid(g_proj(x)) in torch + if (model.layers[il].wqkv_gate) { + ggml_tensor * gate = build_lora_mm(model.layers[il].wqkv_gate, cur); // [n_head_l, n_tokens] + cb(gate, "attn_gate", il); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "attn_gate_sigmoid", il); + + // reshape + broadcast to [n_embd_head_v, n_head_l, n_tokens] + ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens); + ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens); + cb(gate_3d, "attn_gate_3d", il); + + attn_3d = ggml_mul(ctx0, attn_3d, gate_3d); + cb(attn_3d, "attn_gated_3d", il); + + attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens); + cb(attn_out, "attn_gated", il); + } + + // output projection + cur = build_lora_mm(model.layers[il].wo, attn_out, model.layers[il].wo_s); + cb(cur, "attn_proj", il); + } + + if (il == n_layer - 1 && inp_out_ids && cparams.embeddings_nextn_masked) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // feed-forward + if (model.layers[il].ffn_gate_inp == nullptr) { + // dense MLP + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, nullptr, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, nullptr, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE routed experts + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "ffn_moe_out", il); + + // shared expert MLP (always added on MoE layers in Step35) + ggml_tensor * sh_out = build_ffn(cur, + model.layers[il].ffn_up_shexp, nullptr, nullptr, + model.layers[il].ffn_gate_shexp, nullptr, nullptr, + model.layers[il].ffn_down_shexp, nullptr, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(sh_out, "ffn_shared_out", il); + + cur = ggml_add(ctx0, moe_out, sh_out); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + if (!cparams.embeddings_nextn_masked && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + } + + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur, model.output_s); + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +// LLM_GRAPH_TYPE_DECODER_MTP draft head for Step3p5 (MoE) +llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_graph_params & params) + : llm_graph_context(params) { + GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0"); + + // Single-block MTP only: always run the first trained MTP block (Qwen + // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to + // be a much deeper refactor than this PR justifies; the trailing MTP + // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just + // block 0) also work — see load_arch_tensors below and + // scripts/prune_step35_extra_mtp.py. + const int il = hparams.n_layer(); + const auto & layer = model.layers[il]; + + GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); + GGML_ASSERT(layer.nextn.enorm && "MTP block missing nextn.enorm"); + GGML_ASSERT(layer.nextn.hnorm && "MTP block missing nextn.hnorm"); + + const uint32_t n_head_l = hparams.n_head(il); + const uint32_t n_head_kv_l = hparams.n_head_kv(il); + + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + + auto inp = std::make_unique<llm_graph_input_embd>(hparams.n_embd); + + inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_input(inp->tokens); + + inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens); + ggml_set_input(inp->embd); + ggml_set_name(inp->embd, "mtp_h_input"); + + ggml_tensor * tok_embd_w = layer.nextn.embed_tokens ? layer.nextn.embed_tokens : model.tok_embd; + + ggml_tensor * h_input = inp->embd; + ggml_tensor * tok_embd = ggml_get_rows(ctx0, tok_embd_w, inp->tokens); + cb(tok_embd, "mtp_tok_embd", il); + + res->add_input(std::move(inp)); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv_iswa(); + + ggml_tensor * h_norm = build_norm(h_input, layer.nextn.hnorm, nullptr, LLM_NORM_RMS, il); + cb(h_norm, "mtp_hnorm", il); + + ggml_tensor * e_norm = build_norm(tok_embd, layer.nextn.enorm, nullptr, LLM_NORM_RMS, il); + cb(e_norm, "mtp_enorm", il); + + ggml_tensor * concat = ggml_concat(ctx0, e_norm, h_norm, /*dim=*/ 0); + cb(concat, "mtp_concat", il); + + ggml_tensor * cur = build_lora_mm(layer.nextn.eh_proj, concat); + cb(cur, "mtp_eh_proj", il); + + ggml_tensor * inpSA = cur; + + // mtp_block: full Step3p5 decoder layer (attention with optional head-wise gate, then MoE/dense FFN) + cur = build_norm(cur, layer.attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_attn_norm", il); + + ggml_tensor * Qcur = build_lora_mm(layer.wq, cur, layer.wq_s); + ggml_tensor * Kcur = build_lora_mm(layer.wk, cur, layer.wk_s); + ggml_tensor * Vcur = build_lora_mm(layer.wv, cur, layer.wv_s); + cb(Qcur, "mtp_Qcur", il); + cb(Kcur, "mtp_Kcur", il); + cb(Vcur, "mtp_Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens); + + if (layer.attn_q_norm) { + Qcur = build_norm(Qcur, layer.attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "mtp_Qcur_normed", il); + } + if (layer.attn_k_norm) { + Kcur = build_norm(Kcur, layer.attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "mtp_Kcur_normed", il); + } + + const bool is_swa = hparams.is_swa(il); + ggml_tensor * rope_factors = is_swa ? nullptr : model.get_rope_factors(cparams, il); + const int64_t n_rot_l = hparams.n_rot(il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot_l, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, + ext_factor, attn_factor, beta_fast, beta_slow); + cb(Qcur, "mtp_Qcur_pos", il); + cb(Kcur, "mtp_Kcur_pos", il); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k)); + ggml_tensor * attn_out = build_attn(inp_attn, + nullptr, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(attn_out, "mtp_attn_out", il); + + // head-wise attention gate: sigmoid(g_proj(x)) + if (layer.wqkv_gate) { + ggml_tensor * gate = build_lora_mm(layer.wqkv_gate, cur); // [n_head_l, n_tokens] + cb(gate, "mtp_attn_gate", il); + + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "mtp_attn_gate_sigmoid", il); + + ggml_tensor * attn_3d = ggml_reshape_3d(ctx0, attn_out, n_embd_head_v, n_head_l, n_tokens); + ggml_tensor * gate_3d = ggml_reshape_3d(ctx0, gate, 1, n_head_l, n_tokens); + cb(gate_3d, "mtp_attn_gate_3d", il); + + attn_3d = ggml_mul(ctx0, attn_3d, gate_3d); + cb(attn_3d, "mtp_attn_gated_3d", il); + + attn_out = ggml_reshape_2d(ctx0, attn_3d, n_embd_head_v * n_head_l, n_tokens); + cb(attn_out, "mtp_attn_gated", il); + } + + cur = build_lora_mm(layer.wo, attn_out, layer.wo_s); + cb(cur, "mtp_attn_proj", il); + + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "mtp_attn_residual", il); + + ggml_tensor * ffn_inp = cur; + cur = build_norm(cur, layer.ffn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "mtp_ffn_norm", il); + + // FFN: dense MLP or MoE (mirrors trunk path) + if (layer.ffn_gate_inp == nullptr) { + cur = build_ffn(cur, + layer.ffn_up, layer.ffn_up_b, nullptr, + layer.ffn_gate, layer.ffn_gate_b, nullptr, + layer.ffn_down, layer.ffn_down_b, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "mtp_ffn_out", il); + } else { + ggml_tensor * moe_out = build_moe_ffn(cur, + layer.ffn_gate_inp, + layer.ffn_up_exps, + layer.ffn_gate_exps, + layer.ffn_down_exps, + layer.ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, hparams.expert_weights_norm, + hparams.expert_weights_scale, + (llama_expert_gating_func_type) hparams.expert_gating_func, + il); + cb(moe_out, "mtp_ffn_moe_out", il); + + ggml_tensor * sh_out = build_ffn(cur, + layer.ffn_up_shexp, nullptr, nullptr, + layer.ffn_gate_shexp, nullptr, nullptr, + layer.ffn_down_shexp, nullptr, nullptr, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(sh_out, "mtp_ffn_shared_out", il); + + cur = ggml_add(ctx0, moe_out, sh_out); + cb(cur, "mtp_ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "mtp_post_ffn", il); + + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. + cb(cur, "h_nextn", -1); + res->t_h_nextn = cur; + + ggml_tensor * head_norm_w = layer.nextn.shared_head_norm + ? layer.nextn.shared_head_norm + : model.output_norm; + GGML_ASSERT(head_norm_w && "STEP35 MTP: missing both nextn.shared_head_norm and output_norm"); + cur = build_norm(cur, head_norm_w, nullptr, LLM_NORM_RMS, -1); + cb(cur, "mtp_shared_head_norm", -1); + + ggml_tensor * head_w = layer.nextn.shared_head_head ? layer.nextn.shared_head_head : model.output; + GGML_ASSERT(head_w && "STEP35 MTP: missing LM head (nextn.shared_head_head or model.output)"); + cur = build_lora_mm(head_w, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/t5-dec.cpp b/examples/talk-llama/models/t5-dec.cpp deleted file mode 100644 index 297e450de76..00000000000 --- a/examples/talk-llama/models/t5-dec.cpp +++ /dev/null @@ -1,166 +0,0 @@ -#include "models.h" - -llm_build_t5_dec::llm_build_t5_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - ggml_tensor * embd_enc = build_inp_cross_embd(); - ggml_tensor * pos_bucket_dec = build_inp_pos_bucket_dec(); - - const int64_t n_outputs_enc = embd_enc->ne[1]; - - auto * inp_attn_self = build_attn_inp_kv(); - auto * inp_attn_cross = build_attn_inp_cross(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - const int64_t dec_n_layer = hparams.dec_n_layer; - - for (int il = 0; il < dec_n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; - ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); - - cur = build_attn(inp_attn_self, - model.layers[il].wo, model.layers[il].bo, - Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); - cb(cur, "kqv_out", il); - } - cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "cross_inp", il); - - ggml_tensor * inpCA = cur; - - // norm - cur = build_norm(cur, - model.layers[il].attn_norm_cross, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm_cross", il); - - // cross-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc); - - cur = build_attn(inp_attn_cross, - model.layers[il].wo_cross, nullptr, - Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); - cb(cur, "kqv_out", il); - - //ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); - //ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); - - //ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); - //cb(kq, "kq", il); - - //kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias); - //cb(kq, "kq_soft_max_ext", il); - - //ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc))); - //cb(v, "v", il); - - //ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq); - //cb(kqv, "kqv", il); - - //ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - //cb(kqv_merged, "kqv_merged", il); - - //cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); - //cb(cur, "kqv_merged_cont", il); - - //ggml_build_forward_expand(gf, cur); - - //cur = build_lora_mm(model.layers[il].wo_cross, cur); - //cb(cur, "kqv_out", il); - } - if (il == dec_n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - // T5 uses relu, flan-T5 uses gelu-gated - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU, - model.layers[il].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ, - il); - cb(cur, "ffn_out", il); - } - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - cb(cur, "result_embd", -1); - - cur = build_norm(cur, - model.output_norm, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - // lm_head - cur = build_lora_mm(model.output, cur); - - cb(cur, "result_output", -1); - res->t_logits = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/t5-enc.cpp b/examples/talk-llama/models/t5-enc.cpp deleted file mode 100644 index 70e1d80dcdd..00000000000 --- a/examples/talk-llama/models/t5-enc.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include "models.h" - -llm_build_t5_enc::llm_build_t5_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; - - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - - ggml_tensor * cur; - ggml_tensor * inpL; - - inpL = build_inp_embd(model.tok_embd); - - ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc(); - - auto * inp_attn = build_attn_inp_no_cache(); - - ggml_tensor * inp_out_ids = build_inp_out_ids(); - - for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - - // norm - cur = build_norm(inpL, - model.layers[il].attn_norm_enc, NULL, - LLM_NORM_RMS, il); - cb(cur, "attn_norm", il); - - // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; - ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); - - cur = build_attn(inp_attn, - model.layers[il].wo_enc, nullptr, - Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); - cb(cur, "kqv_out", il); - } - if (il == n_layer - 1 && inp_out_ids) { - cur = ggml_get_rows(ctx0, cur, inp_out_ids); - inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); - } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); - cb(ffn_inp, "ffn_inp", il); - - // feed-forward network - { - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm_enc, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); - - // T5 uses relu, flan-T5 uses gelu-gated - cur = build_ffn(cur, - model.layers[il].ffn_up_enc, NULL, NULL, - model.layers[il].ffn_gate_enc, NULL, NULL, - model.layers[il].ffn_down_enc, NULL, NULL, - NULL, - model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, - model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, - il); - cb(cur, "ffn_out", il); - } - cur = ggml_add(ctx0, cur, ffn_inp); - cb(cur, "ffn_out", il); - - cur = build_cvec(cur, il); - cb(cur, "l_out", il); - - // input for next layer - inpL = cur; - } - cur = inpL; - cb(cur, "result_embd", -1); - - cur = build_norm(cur, - model.output_norm_enc, NULL, - LLM_NORM_RMS, -1); - - cb(cur, "result_norm", -1); - res->t_embd = cur; - - ggml_build_forward_expand(gf, cur); -} diff --git a/examples/talk-llama/models/t5.cpp b/examples/talk-llama/models/t5.cpp new file mode 100644 index 00000000000..b0e3f062572 --- /dev/null +++ b/examples/talk-llama/models/t5.cpp @@ -0,0 +1,370 @@ +#include "models.h" + +void llama_model_t5::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + + uint32_t dec_start_token_id; + if (ml.get_key(LLM_KV_DECODER_START_TOKEN_ID, dec_start_token_id, false)) { + hparams.dec_start_token_id = dec_start_token_id; + } + + hparams.dec_n_layer = hparams.n_layer(); + ml.get_key(LLM_KV_DECODER_BLOCK_COUNT, hparams.dec_n_layer, false); + + switch (hparams.n_layer()) { + case 6: type = LLM_TYPE_60M; break; // t5-small + case 8: type = LLM_TYPE_80M; break; // flan-t5-small + case 12: + switch (hparams.n_ff()) { + case 3072: type = LLM_TYPE_220M; break; // t5-base + case 2048: type = LLM_TYPE_250M; break; // flan-t5-base + default: type = LLM_TYPE_UNKNOWN; + } break; + case 24: + switch (hparams.n_ff()) { + case 4096: type = LLM_TYPE_770M; break; // t5-large + case 2816: type = LLM_TYPE_780M; break; // flan-t5-large + case 16384: type = LLM_TYPE_3B; break; // t5-3b + case 5120: type = LLM_TYPE_3B; break; // flan-t5-xl + case 65536: type = LLM_TYPE_11B; break; // t5-11b + case 10240: type = LLM_TYPE_11B; break; // flan-t5-xxl + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_t5::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm = create_tensor(tn(LLM_TENSOR_DEC_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + // n_layer: number of encoder_layers + // dec_n_layer: number of decoder_layers + const int dec_n_layer = hparams.dec_n_layer; + if (dec_n_layer > n_layer) { + layers.resize(dec_n_layer); + } + + // load encoder layers + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + + // load decoder layers + for (int i = 0; i < dec_n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_DEC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b = create_tensor(tn(LLM_TENSOR_DEC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq = create_tensor(tn(LLM_TENSOR_DEC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_DEC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_DEC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_DEC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.attn_norm_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_NORM, "weight", i), {n_embd}, 0); + // this tensor seems to be unused in HF transformers implementation + layer.attn_rel_b_cross = create_tensor( + tn(LLM_TENSOR_DEC_CROSS_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED | TENSOR_SKIP_IF_VIRTUAL); + + layer.wq_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_cross = create_tensor(tn(LLM_TENSOR_DEC_CROSS_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_DEC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_DEC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_DEC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_DEC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_t5::build_arch_graph(const llm_graph_params & params) const { + switch (params.gtype) { + case LLM_GRAPH_TYPE_ENCODER: + return std::make_unique<graph<true>>(*this, params); + case LLM_GRAPH_TYPE_DEFAULT: + case LLM_GRAPH_TYPE_DECODER: + return std::make_unique<graph<false>>(*this, params); + default: + GGML_ABORT("invalid graph type"); + }; +} + +template <> +llama_model_t5::graph<false>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * embd_enc = build_inp_cross_embd(); + ggml_tensor * pos_bucket_dec = build_inp_pos_bucket_dec(); + + const int64_t n_outputs_enc = embd_enc->ne[1]; + + auto * inp_attn_self = build_attn_inp_kv(); + auto * inp_attn_cross = build_attn_inp_cross(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const int64_t dec_n_layer = hparams.dec_n_layer; + + for (int il = 0; il < dec_n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, n_embd_head, n_head, n_head_kv, il); + + ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; + ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); + + cur = build_attn(inp_attn_self, + model.layers[il].wo, model.layers[il].wo_b, model.layers[il].wo_s, + Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); + cb(cur, "kqv_out", il); + } + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "cross_inp", il); + + ggml_tensor * inpCA = cur; + + // norm + cur = build_norm(cur, + model.layers[il].attn_norm_cross, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm_cross", il); + + // cross-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_cross, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_cross, embd_enc); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_cross, embd_enc); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc); + + cur = build_attn(inp_attn_cross, + model.layers[il].wo_cross, nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f, il); + cb(cur, "kqv_out", il); + + //ggml_tensor * q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + //ggml_tensor * k = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 0, 2, 1, 3)); + + //ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + //cb(kq, "kq", il); + + //kq = ggml_soft_max_ext(ctx0, kq, KQ_mask_cross, 1.0f, hparams.f_max_alibi_bias); + //cb(kq, "kq_soft_max_ext", il); + + //ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_outputs_enc))); + //cb(v, "v", il); + + //ggml_tensor * kqv = ggml_mul_mat(ctx0, ggml_reshape_3d(ctx0, v, n_outputs_enc, n_embd_head, n_head_kv), kq); + //cb(kqv, "kqv", il); + + //ggml_tensor * kqv_merged = ggml_permute(ctx0, kqv, 0, 2, 1, 3); + //cb(kqv_merged, "kqv_merged", il); + + //cur = ggml_cont_2d(ctx0, kqv_merged, n_embd_gqa, n_tokens); + //cb(cur, "kqv_merged_cont", il); + + //ggml_build_forward_expand(gf, cur); + + //cur = build_lora_mm(model.layers[il].wo_cross, cur); + //cb(cur, "kqv_out", il); + } + if (il == dec_n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpCA = ggml_get_rows(ctx0, inpCA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpCA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate ? LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + cb(cur, "result_embd", -1); + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur, model.output_s); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +template <> +llama_model_t5::graph<true>::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * pos_bucket_enc = build_inp_pos_bucket_enc(); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq_enc, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk_enc, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv_enc, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; + ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); + + cur = build_attn(inp_attn, + model.layers[il].wo_enc, nullptr, nullptr, + Qcur, Kcur, Vcur, kq_b, nullptr, nullptr, 1.0f, il); + cb(cur, "kqv_out", il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm_enc, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // T5 uses relu, flan-T5 uses gelu-gated + cur = build_ffn(cur, + model.layers[il].ffn_up_enc, NULL, NULL, + model.layers[il].ffn_gate_enc, NULL, NULL, + model.layers[il].ffn_down_enc, NULL, NULL, + NULL, + model.layers[il].ffn_gate_enc ? LLM_FFN_GELU : LLM_FFN_RELU, + model.layers[il].ffn_gate_enc ? LLM_FFN_PAR : LLM_FFN_SEQ, + il); + cb(cur, "ffn_out", il); + } + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + cb(cur, "result_embd", -1); + + cur = build_norm(cur, + model.output_norm_enc, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/t5encoder.cpp b/examples/talk-llama/models/t5encoder.cpp new file mode 100644 index 00000000000..23c5f9b6a1c --- /dev/null +++ b/examples/talk-llama/models/t5encoder.cpp @@ -0,0 +1,44 @@ +#include "models.h" + +void llama_model_t5encoder::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, hparams.n_rel_attn_bkts); + type = LLM_TYPE_UNKNOWN; +} + +void llama_model_t5encoder::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + const auto n_rel_attn_bkts = hparams.n_rel_attn_bkts; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_rel_b_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_REL_B, "weight", i), {n_head, n_rel_attn_bkts}, TENSOR_NOT_REQUIRED); + + layer.wq_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wk_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo_enc = create_tensor(tn(LLM_TENSOR_ENC_ATTN_OUT, "weight", i), {n_embd_v_gqa, n_embd}, 0); + + layer.ffn_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); + layer.ffn_down_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up_enc = create_tensor(tn(LLM_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_t5encoder::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} diff --git a/examples/talk-llama/models/talkie.cpp b/examples/talk-llama/models/talkie.cpp new file mode 100644 index 00000000000..393e8f65bf4 --- /dev/null +++ b/examples/talk-llama/models/talkie.cpp @@ -0,0 +1,149 @@ +#include "models.h" + +void llama_model_talkie::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + + switch (hparams.n_layer()) { + case 40: type = LLM_TYPE_13B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_talkie::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // no k gain + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {1, n_head}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + + layer.out_scale = create_tensor(tn(LLM_TENSOR_LAYER_OUT_SCALE, "weight", i), {1}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_talkie::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_talkie::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v()); + GGML_ASSERT(n_embd_head == n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + inpL = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, -1); + cb(inpL, "inp_norm", -1); + + ggml_tensor * embd_skip = inpL; + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + ggml_tensor * inp_skip = embd_skip; + + cur = build_norm(inpL, nullptr, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + // reference applies qknorm after rope + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_norm", il); + + Kcur = build_norm(Kcur, nullptr, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_norm", il); + + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, model.layers[il].wo_s, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + inp_skip = ggml_get_rows(ctx0, inp_skip, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, nullptr, nullptr, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, nullptr, nullptr, + model.layers[il].ffn_gate, nullptr, nullptr, + model.layers[il].ffn_down, nullptr, model.layers[il].ffn_down_s, + nullptr, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + ggml_tensor * skip = ggml_mul(ctx0, inp_skip, model.layers[il].out_scale); + cb(skip, "embd_skip", il); + + cur = ggml_add(ctx0, cur, skip); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, nullptr, nullptr, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + cur = build_lora_mm(model.output, cur); + cur = ggml_scale(ctx0, cur, hparams.f_logit_scale); + cb(cur, "result_output", -1); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/examples/talk-llama/models/wavtokenizer-dec.cpp b/examples/talk-llama/models/wavtokenizer-dec.cpp index 537a0d41248..214fed99bad 100644 --- a/examples/talk-llama/models/wavtokenizer-dec.cpp +++ b/examples/talk-llama/models/wavtokenizer-dec.cpp @@ -1,6 +1,121 @@ #include "models.h" -llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +void llama_model_wavtokenizer_dec::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_EPS, hparams.f_norm_group_eps); + ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups); +} + +void llama_model_wavtokenizer_dec::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); + + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight", 0), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); + conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias", 0), {1, hparams.posnet.n_embd}, 0); + + // posnet + { + const int64_t n_embd = hparams.posnet.n_embd; + + for (uint32_t i = 0; i < hparams.posnet.n_layer; ++i) { + auto & layer = layers[i].posnet; + + // posnet: + // + // - resnet + // - resnet + // - attn + // - resnet + // - resnet + // - norm + // + switch (i) { + case 0: + case 1: + case 3: + case 4: + { + layer.norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", i), {1, n_embd}, 0); + layer.norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", i), {1, n_embd}, 0); + + layer.conv1 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv1_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV1, "bias", i), {1, n_embd}, 0); + + layer.norm2 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "weight", i), {1, n_embd}, 0); + layer.norm2_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM2, "bias", i), {1, n_embd}, 0); + + layer.conv2 = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "weight", i), {3, n_embd, n_embd}, 0); + layer.conv2_b = create_tensor(tn(LLM_TENSOR_POS_NET_CONV2, "bias", i), {1, n_embd}, 0); + } break; + case 2: + { + layer.attn_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + + layer.attn_q = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_q_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_Q, "bias", i), {1, n_embd}, 0); + + layer.attn_k = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_k_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_K, "bias", i), {1, n_embd}, 0); + + layer.attn_v = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_v_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_V, "bias", i), {1, n_embd}, 0); + + layer.attn_o = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "weight", i), {1, n_embd, n_embd}, 0); + layer.attn_o_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_OUT, "bias", i), {1, n_embd}, 0); + } break; + case 5: + { + layer.norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", i), {1, n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", i), {1, n_embd}, 0); + } break; + default: GGML_ABORT("unknown posnet layer"); + }; + } + } + + GGML_ASSERT(hparams.posnet.n_embd == hparams.convnext.n_embd); + + tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight", 0), {hparams.posnet.n_embd}, 0); + tok_norm_b = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias", 0), {hparams.posnet.n_embd}, 0); + + // convnext + { + const int64_t n_embd = hparams.convnext.n_embd; + + for (uint32_t i = 0; i < hparams.convnext.n_layer; ++i) { + auto & layer = layers[i].convnext; + + layer.dw = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "weight", i), {7, 1, n_embd}, 0); + layer.dw_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_DW, "bias", i), {1, n_embd}, 0); + + layer.norm = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "weight", i), {n_embd}, 0); + layer.norm_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_NORM, "bias", i), {n_embd}, 0); + + layer.pw1 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "weight", i), {n_embd, n_ff}, 0); + layer.pw1_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW1, "bias", i), {n_ff}, 0); + + layer.pw2 = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "weight", i), {n_ff, n_embd}, 0); + layer.pw2_b = create_tensor(tn(LLM_TENSOR_CONVNEXT_PW2, "bias", i), {n_embd}, 0); + + layer.gamma = create_tensor(tn(LLM_TENSOR_CONVNEXT_GAMMA, "weight", i), {n_embd}, 0); + } + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + } + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0); +} + +std::unique_ptr<llm_graph_context> llama_model_wavtokenizer_dec::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} + +llama_model_wavtokenizer_dec::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -93,7 +208,7 @@ llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model cur = build_norm(cur, model.tok_norm, model.tok_norm_b, - LLM_NORM, -1); + LLM_NORM, 0); cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); @@ -138,7 +253,7 @@ llm_build_wavtokenizer_dec::llm_build_wavtokenizer_dec(const llama_model & model LLM_NORM, -1); // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cur = ggml_add(ctx0, cur, model.output_b); diff --git a/examples/talk-llama/models/xverse.cpp b/examples/talk-llama/models/xverse.cpp index 364797dd31b..3135001293a 100644 --- a/examples/talk-llama/models/xverse.cpp +++ b/examples/talk-llama/models/xverse.cpp @@ -1,10 +1,48 @@ #include "models.h" -llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { - const int64_t n_embd_head = hparams.n_embd_head_v; +void llama_model_xverse::load_arch_hparams(llama_model_loader & ml) { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer()) { + case 32: type = LLM_TYPE_7B; break; + case 40: type = LLM_TYPE_13B; break; + case 80: type = LLM_TYPE_65B; break; + default: type = LLM_TYPE_UNKNOWN; + } +} + +void llama_model_xverse::load_arch_tensors(llama_model_loader &) { + LLAMA_LOAD_LOCALS; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + create_tensor_qkv(layer, i, n_embd, n_embd, n_embd_gqa, n_embd_gqa, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } +} + +std::unique_ptr<llm_graph_context> llama_model_xverse::build_arch_graph(const llm_graph_params & params) const { + return std::make_unique<graph>(*this, params); +} - GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); - GGML_ASSERT(n_embd_head == hparams.n_rot); +llama_model_xverse::graph::graph(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v(); + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k()); + GGML_ASSERT(n_embd_head == n_rot); ggml_tensor * cur; ggml_tensor * inpL; @@ -28,18 +66,8 @@ llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_pa // self-attention { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + auto [Qcur, Kcur, Vcur] = build_qkv(model.layers[il], cur, + n_embd_head, n_head, n_head_kv, il); Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, @@ -58,7 +86,7 @@ llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_pa cb(Vcur, "Vcur", il); cur = build_attn(inp_attn, - model.layers[il].wo, NULL, + model.layers[il].wo, NULL, model.layers[il].wo_s, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -99,7 +127,7 @@ llm_build_xverse::llm_build_xverse(const llama_model & model, const llm_graph_pa res->t_embd = cur; // lm_head - cur = build_lora_mm(model.output, cur); + cur = build_lora_mm(model.output, cur, model.output_s); cb(cur, "result_output", -1); res->t_logits = cur; diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index b47dcbe6198..b02ecdc930f 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -1,16 +1,10 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "unicode.h" #include "unicode-data.h" #include <algorithm> #include <cassert> -#include <codecvt> #include <cstddef> #include <cstdint> -#include <locale> #include <map> #include <regex> #include <stdexcept> @@ -199,27 +193,6 @@ static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() { return map; } -static inline std::wstring unicode_wstring_from_utf8(const std::string & s) { -#if defined(__clang__) - // disable C++17 deprecation warning for std::codecvt_utf8 -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - - std::wstring_convert<std::codecvt_utf8<wchar_t>> conv; - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - - return conv.from_bytes(s); -} - static std::vector<std::string> unicode_byte_encoding_process(const std::vector<std::string> & bpe_words) { std::vector<std::string> bpe_encoded_words; for (const auto & word : bpe_words) { @@ -497,49 +470,291 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string & return bpe_offsets; } -// use std::wregex to split the text -static std::vector<size_t> unicode_regex_split_stl(const std::wstring & wtext, const std::wstring & regex_expr, const std::vector<size_t> & offsets) { - std::wregex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs); +// Qwen2 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +static std::vector<size_t> unicode_regex_split_custom_qwen2(const std::string & text, const std::vector<size_t> & offsets) { std::vector<size_t> bpe_offsets; // store the offset of each word bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + size_t start = 0; for (auto offset : offsets) { - std::wcregex_iterator it(wtext.data() + start, wtext.data() + start + offset, expr); - std::wcregex_iterator end; + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; - int64_t start_idx = 0; - while (it != end) { - std::wcmatch match = *it; - if (match.position() > start_idx) { - bpe_offsets.emplace_back(match.position() - start_idx); + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); } - bpe_offsets.emplace_back(match.length()); - start_idx = match.position() + match.length(); - ++it; + _prev_end = end; + //if (len > 0) { + // std::string s = ""; + // for(size_t p = end-len; p < end; p++) + // s += unicode_cpt_to_utf8(cpts[p]); + // printf(">>> '%s'\n", s.c_str()); + //} + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?\p{L}+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters + pos++; + while (_get_flags(pos).is_letter) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N} + if (flags.is_number) { + pos++; + _add_token(pos); + continue; + } + + // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos+num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); } + } - if (start_idx < (int64_t) offset) { - bpe_offsets.emplace_back(offset - start_idx); + return bpe_offsets; +} + +// Qwen3.5 system regex: "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +// Compared to Qwen2, letter-runs also consume Unicode combining marks (\p{M}): [\p{L}\p{M}]+ instead of \p{L}+ +static std::vector<size_t> unicode_regex_split_custom_qwen35(const std::string & text, const std::vector<size_t> & offsets) { + std::vector<size_t> bpe_offsets; // store the offset of each word + bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive + if (cpt == '\'' && pos+1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos+1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += _add_token(pos+2); + continue; + } + if (pos+2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos+2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += _add_token(pos+3); + continue; + } + } + } + + // regex: [^\r\n\p{L}\p{N}]?[\p{L}\p{M}]+ + if (!(cpt == '\r' || cpt == '\n' || flags.is_number)) { + if (flags.is_letter || flags.is_accent_mark || _get_flags(pos + 1).is_accent_mark || _get_flags(pos+1).is_letter) { + pos++; + while (_get_flags(pos).is_letter || _get_flags(pos).is_accent_mark) { + pos++; + } + _add_token(pos); + continue; + } + } + + // regex: \p{N} + if (flags.is_number) { + pos++; + _add_token(pos); + continue; + } + + // regex: <space>?[^\s\p{L}\p{M}\p{N}]+[\r\n]* + auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags); + if (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace | flags2.is_letter | flags2.is_accent_mark | flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos+num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos+num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // regex: \s*[\r\n]+ + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // regex: \s+(?!\S) + if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // regex: \s+ + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // no matches + _add_token(++pos); } - start += offset; } return bpe_offsets; } -// use std::regex to split the text -static std::vector<size_t> unicode_regex_split_stl(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) { - std::regex expr(regex_expr, std::regex_constants::optimize | std::regex_constants::nosubs); +template <typename CharT> +static std::vector<size_t> unicode_regex_split_stl(const std::basic_string<CharT> & text, const std::basic_string<CharT> & regex, const std::vector<size_t> & offsets) { + using BidirIt = typename std::basic_string<CharT>::const_iterator; +#ifdef _MSC_VER + // Bypass bug in MSVC: https://github.com/ggml-org/llama.cpp/issues/17830 + constexpr auto regex_flags = std::regex_constants::ECMAScript; +#else + constexpr auto regex_flags = std::regex_constants::optimize | std::regex_constants::nosubs; +#endif + std::basic_regex<CharT> expr(regex, regex_flags); std::vector<size_t> bpe_offsets; // store the offset of each word bpe_offsets.reserve(offsets.size()); // Reserve memory for the approximate size size_t start = 0; for (auto offset : offsets) { - std::cregex_iterator it(text.data() + start, text.data() + start + offset, expr); - std::cregex_iterator end; + std::regex_iterator<BidirIt> it(text.begin() + start, text.begin() + start + offset, expr); + std::regex_iterator<BidirIt> end; int64_t start_idx = 0; while (it != end) { - std::cmatch match = *it; + std::match_results<BidirIt> match = *it; if (match.position() > start_idx) { bpe_offsets.emplace_back(match.position() - start_idx); } @@ -803,6 +1018,35 @@ static std::vector<size_t> unicode_regex_split_custom_afmoe(const std::string & return bpe_offsets; } +// regex: [^\n]+|[\n]+ +// splits text into runs of non-newline characters and runs of newline characters +static std::vector<size_t> unicode_regex_split_custom_newlines(const std::string & text, const std::vector<size_t> & offsets) { + std::vector<size_t> bpe_offsets; + bpe_offsets.reserve(offsets.size()); + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + size_t pos = offset_ini; + while (pos < offset_end) { + const bool is_newline = (cpts[pos] == '\n'); + const size_t run_start = pos; + while (pos < offset_end && (cpts[pos] == '\n') == is_newline) { + pos++; + } + bpe_offsets.push_back(pos - run_start); + } + } + + return bpe_offsets; +} + static std::vector<size_t> unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector<size_t> & offsets) { std::vector<size_t> bpe_offsets; @@ -811,14 +1055,27 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text, } else if ( regex_expr == "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" || regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { - bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); + } else if ( + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_qwen2(text, offsets); + } else if ( + regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { + bpe_offsets = unicode_regex_split_custom_qwen35(text, offsets); } else if (regex_expr == "\\p{Han}+") { // K2's first pattern - handle all K2 patterns together bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); } else if (regex_expr == "\\p{AFMoE_digits}") { // AFMOE digit pattern - use custom implementation for proper splitting bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); + } else if (regex_expr == "[^\\n]+|[\\n]+") { + bpe_offsets = unicode_regex_split_custom_newlines(text, offsets); + } else if (regex_expr == "\\d{1,3}(?=(?:\\d{3})*\\b)") { + // tiny_aya digit grouping pattern from tokenizer.json: + // {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"} + // Splits digits into groups of 3 from the right (e.g., 1234567 -> 1, 234, 567) + // TODO: Revisit this regex, in case there are any subtle tokenization differences with the original regex. + bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); } return bpe_offsets; @@ -956,7 +1213,7 @@ bool unicode_cpt_is_han(uint32_t cpt) { return false; } -std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) { +std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode) { // unicode categories static const std::map<std::string, int> k_ucat_enum = { { "\\p{N}", unicode_cpt_flags::NUMBER }, @@ -1051,10 +1308,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std break; } } + const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); if (use_collapsed) { // sanity-check that the original regex does not contain any non-ASCII characters - const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); for (size_t i = 0; i < cpts_regex.size(); ++i) { if (cpts_regex[i] >= 128) { throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported"); @@ -1110,7 +1367,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); } else { // no unicode category used, we can use std::wregex directly - const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); + std::wstring wregex_expr(cpts_regex.begin(), cpts_regex.end()); // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback std::wstring wtext(cpts.begin(), cpts.end()); @@ -1143,5 +1400,9 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std start += offset; } - return unicode_byte_encoding_process(bpe_words); + if (byte_encode) { + return unicode_byte_encoding_process(bpe_words); + } + + return bpe_words; } diff --git a/examples/talk-llama/unicode.h b/examples/talk-llama/unicode.h index 5bd1362ff41..600ab9216b9 100644 --- a/examples/talk-llama/unicode.h +++ b/examples/talk-llama/unicode.h @@ -108,4 +108,4 @@ uint32_t unicode_tolower(uint32_t cpt); bool unicode_cpt_is_han(uint32_t cpt); -std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs); +std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs, bool byte_encode = true); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 0176ca1ce93..249ed3da290 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -1,12 +1,15 @@ -cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. +cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. + project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) -set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 5) +set(GGML_VERSION_MINOR 15) +set(GGML_VERSION_PATCH 1) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") +list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") + find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) if(GIT_EXE) # Get current git commit hash @@ -166,15 +169,16 @@ if (NOT MSVC) option(GGML_AMX_INT8 "ggml: enable AMX-INT8" OFF) option(GGML_AMX_BF16 "ggml: enable AMX-BF16" OFF) endif() -option(GGML_LASX "ggml: enable lasx" ON) -option(GGML_LSX "ggml: enable lsx" ON) -option(GGML_RVV "ggml: enable rvv" ON) -option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) -option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) -option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) -option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause " ON) -option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) -option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE}) +option(GGML_LASX "ggml: enable lasx" ON) +option(GGML_LSX "ggml: enable lsx" ON) +option(GGML_RVV "ggml: enable rvv" ON) +option(GGML_RV_ZFH "ggml: enable riscv zfh" ON) +option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON) +option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) +option(GGML_RV_ZIHINTPAUSE "ggml: enable riscv zihintpause" ON) +option(GGML_RV_ZVFBFWMA "ggml: enable riscv zvfbfwma" OFF) +option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) +option(GGML_VXE "ggml: enable vxe" ${GGML_NATIVE}) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") @@ -203,12 +207,14 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" option(GGML_CUDA_FA "ggml: compile ggml FlashAttention CUDA kernels" ON) option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF) option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT}) +option(GGML_CUDA_NCCL "ggml: use NVIDIA Collective Comm. Library" ON) set (GGML_CUDA_COMPRESSION_MODE "size" CACHE STRING "ggml: cuda link binary compression mode; requires cuda 12.8+") set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balance;size") option(GGML_HIP "ggml: use HIP" OFF) -option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF) +option(GGML_HIP_GRAPHS "ggml: use HIP graph" ON) +option(GGML_HIP_RCCL "ggml: use ROCm Collective Comm. Library" OFF) option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_MMQ_MFMA "ggml: enable MFMA MMA for CDNA in MMQ" ON) @@ -228,6 +234,8 @@ option(GGML_WEBGPU_CPU_PROFILE "ggml: enable WebGPU profiling (CPU) option(GGML_WEBGPU_GPU_PROFILE "ggml: enable WebGPU profiling (GPU)" OFF) option(GGML_WEBGPU_JSPI "ggml: use JSPI for WebGPU" ON) option(GGML_ZDNN "ggml: use zDNN" OFF) +option(GGML_VIRTGPU "ggml: use the VirtGPU/Virglrenderer API Remoting frontend" OFF) +option(GGML_VIRTGPU_BACKEND "ggml: build the VirtGPU/Virglrenderer API Remoting backend" OFF) option(GGML_METAL "ggml: use Metal" ${GGML_METAL_DEFAULT}) option(GGML_METAL_NDEBUG "ggml: disable Metal debugging" OFF) option(GGML_METAL_SHADER_DEBUG "ggml: compile Metal with -fno-fast-math" OFF) @@ -240,18 +248,22 @@ option(GGML_RPC "ggml: use RPC" option(GGML_SYCL "ggml: use SYCL" OFF) option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF) option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON) +option(GGML_SYCL_HOST_MEM_FALLBACK "ggml: allow host memory fallback in SYCL reorder (requires kernel 6.8+)" ON) +option(GGML_SYCL_SUPPORT_LEVEL_ZERO "ggml: use Level Zero API in SYCL backend" ON) option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON) set (GGML_SYCL_TARGET "INTEL" CACHE STRING "ggml: sycl target device") set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING "ggml: sycl device architecture") +option(GGML_OPENVINO "ggml: use OPENVINO" OFF) + option(GGML_OPENCL "ggml: use OpenCL" OFF) option(GGML_OPENCL_PROFILING "ggml: use OpenCL profiling (increases overhead)" OFF) option(GGML_OPENCL_EMBED_KERNELS "ggml: embed kernels" ON) option(GGML_OPENCL_USE_ADRENO_KERNELS "ggml: use optimized kernels for Adreno" ON) set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING - "gmml: OpenCL API version to target") + "ggml: OpenCL API version to target") option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF) set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)") @@ -320,10 +332,12 @@ set(GGML_PUBLIC_HEADERS include/ggml-opt.h include/ggml-metal.h include/ggml-rpc.h + include/ggml-virtgpu.h include/ggml-sycl.h include/ggml-vulkan.h include/ggml-webgpu.h include/ggml-zendnn.h + include/ggml-openvino.h include/gguf.h) set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}") @@ -339,7 +353,7 @@ if (GGML_STANDALONE) @ONLY) install(FILES ${CMAKE_CURRENT_BINARY_DIR}/ggml.pc - DESTINATION share/pkgconfig) + DESTINATION ${CMAKE_INSTALL_LIBDIR}/pkgconfig) endif() # diff --git a/ggml/cmake/BuildTypes.cmake b/ggml/cmake/BuildTypes.cmake deleted file mode 100644 index a9c7b6c91ec..00000000000 --- a/ggml/cmake/BuildTypes.cmake +++ /dev/null @@ -1,54 +0,0 @@ -# Add new build types - -# ReleaseGG - Release with enabled asserts - -SET(CMAKE_CXX_FLAGS_RELEASEGG - "-O3" - CACHE STRING "Flags used by the c++ compiler during release builds with enabled asserts." - FORCE ) -SET(CMAKE_C_FLAGS_RELEASEGG - "-O3" - CACHE STRING "Flags used by the compiler during release builds with enabled asserts." - FORCE ) -SET(CMAKE_EXE_LINKER_FLAGS_RELEASEGG - "" - CACHE STRING "Flags used for linking binaries during release builds with enabled asserts." - FORCE ) -SET(CMAKE_SHARED_LINKER_FLAGS_RELEASEGG - "" - CACHE STRING "Flags used by the shared libraries linker during release builds with enabled asserts." - FORCE ) -MARK_AS_ADVANCED( - CMAKE_CXX_FLAGS_RELEASEGG - CMAKE_C_FLAGS_RELEASEGG - CMAKE_EXE_LINKER_FLAGS_RELEASEGG - CMAKE_SHARED_LINKER_FLAGS_RELEASEGG ) - -# RelWithDebInfoGG - RelWithDebInfo with enabled asserts - -SET(CMAKE_CXX_FLAGS_RELWITHDEBINFOGG - "-O2 -g" - CACHE STRING "Flags used by the c++ compiler during release builds with debug symbols and enabled asserts." - FORCE ) -SET(CMAKE_C_FLAGS_RELWITHDEBINFOGG - "-O2 -g" - CACHE STRING "Flags used by the compiler during release builds with debug symbols and enabled asserts." - FORCE ) -SET(CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG - "" - CACHE STRING "Flags used for linking binaries during release builds with debug symbols and enabled asserts." - FORCE ) -SET(CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG - "" - CACHE STRING "Flags used by the shared libraries linker during release builds with debug symbols and enabled asserts." - FORCE ) -MARK_AS_ADVANCED( - CMAKE_CXX_FLAGS_RELWITHDEBINFOGG - CMAKE_C_FLAGS_RELWITHDEBINFOGG - CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFOGG - CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFOGG ) - -if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) - set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) - set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo" "ReleaseGG" "RelWithDebInfoGG") -endif() diff --git a/ggml/cmake/FindNCCL.cmake b/ggml/cmake/FindNCCL.cmake new file mode 100644 index 00000000000..67511e2d56a --- /dev/null +++ b/ggml/cmake/FindNCCL.cmake @@ -0,0 +1,36 @@ +# cmake/FindNCCL.cmake + +# NVIDIA does not distribute CMake files with NCCl, therefore use this file to find it instead. + +find_path(NCCL_INCLUDE_DIR + NAMES nccl.h + HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda + PATH_SUFFIXES include +) + +find_library(NCCL_LIBRARY + NAMES nccl + HINTS ${NCCL_ROOT} $ENV{NCCL_ROOT} $ENV{CUDA_HOME} /usr/local/cuda + PATH_SUFFIXES lib lib64 +) + +include(FindPackageHandleStandardArgs) +find_package_handle_standard_args(NCCL + DEFAULT_MSG + NCCL_LIBRARY NCCL_INCLUDE_DIR +) + +if(NCCL_FOUND) + set(NCCL_LIBRARIES ${NCCL_LIBRARY}) + set(NCCL_INCLUDE_DIRS ${NCCL_INCLUDE_DIR}) + + if(NOT TARGET NCCL::NCCL) + add_library(NCCL::NCCL UNKNOWN IMPORTED) + set_target_properties(NCCL::NCCL PROPERTIES + IMPORTED_LOCATION "${NCCL_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${NCCL_INCLUDE_DIR}" + ) + endif() +endif() + +mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY) diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in index 91c9d5cd343..23a3066f56d 100644 --- a/ggml/cmake/ggml-config.cmake.in +++ b/ggml/cmake/ggml-config.cmake.in @@ -6,6 +6,7 @@ include(CMakeFindDependencyMacro) find_dependency(Threads) if (NOT GGML_SHARED_LIB) + set(GGML_BASE_INTERFACE_LINK_LIBRARIES "") set(GGML_CPU_INTERFACE_LINK_LIBRARIES "") set(GGML_CPU_INTERFACE_LINK_OPTIONS "") @@ -20,7 +21,15 @@ if (NOT GGML_SHARED_LIB) if (GGML_OPENMP_ENABLED) find_dependency(OpenMP) - list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX) + set(GGML_OPENMP_INTERFACE_LINK_LIBRARIES "") + if (TARGET OpenMP::OpenMP_C) + list(APPEND GGML_OPENMP_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C) + endif() + if (TARGET OpenMP::OpenMP_CXX) + list(APPEND GGML_OPENMP_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_CXX) + endif() + list(APPEND GGML_BASE_INTERFACE_LINK_LIBRARIES ${GGML_OPENMP_INTERFACE_LINK_LIBRARIES}) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${GGML_OPENMP_INTERFACE_LINK_LIBRARIES}) endif() if (GGML_CPU_HBM) @@ -122,7 +131,8 @@ if(NOT TARGET ggml::ggml) add_library(ggml::ggml-base UNKNOWN IMPORTED) set_target_properties(ggml::ggml-base PROPERTIES - IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") + IMPORTED_LOCATION "${GGML_BASE_LIBRARY}" + INTERFACE_LINK_LIBRARIES "${GGML_BASE_INTERFACE_LINK_LIBRARIES}") set(_ggml_all_targets "") if (NOT GGML_BACKEND_DL) diff --git a/ggml/include/ggml-alloc.h b/ggml/include/ggml-alloc.h index 78aa059dde3..a7926a21a9a 100644 --- a/ggml/include/ggml-alloc.h +++ b/ggml/include/ggml-alloc.h @@ -76,6 +76,7 @@ GGML_API size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_i // Utils // Create a buffer and allocate all the tensors in a ggml_context // ggml_backend_alloc_ctx_tensors_from_buft_size returns the size of the buffer that would be allocated by ggml_backend_alloc_ctx_tensors_from_buft +// ggml_backend_alloc_ctx_tensors_from_buft returns NULL on failure or if all tensors in ctx are already allocated or zero-sized GGML_API size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); GGML_API struct ggml_backend_buffer * ggml_backend_alloc_ctx_tensors(struct ggml_context * ctx, ggml_backend_t backend); diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index a9d1778641e..2924fdbe988 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -68,7 +68,7 @@ extern "C" { GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer); // tensor copy between different backends - GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst); // // Backend (stream) @@ -83,13 +83,17 @@ extern "C" { GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend); GGML_API size_t ggml_backend_get_max_size(ggml_backend_t backend); - GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_async (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get_async (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); // "offset" refers to the offset in tensor->data for setting/getting data - GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set ( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_get (const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API void ggml_backend_tensor_set_2d( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); GGML_API void ggml_backend_synchronize(ggml_backend_t backend); @@ -109,7 +113,7 @@ extern "C" { // the copy is performed after all the currently queued operations in backend_src // backend_dst will wait for the copy to complete before performing other operations // automatic fallback to sync copy if async is not supported - GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst); + GGML_API void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend); @@ -135,7 +139,9 @@ extern "C" { // integrated GPU device using host memory GGML_BACKEND_DEVICE_TYPE_IGPU, // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) - GGML_BACKEND_DEVICE_TYPE_ACCEL + GGML_BACKEND_DEVICE_TYPE_ACCEL, + // "meta" device wrapping multiple other devices for tensor parallelism + GGML_BACKEND_DEVICE_TYPE_META, }; // functionality supported by the device @@ -163,7 +169,7 @@ extern "C" { // device type enum ggml_backend_dev_type type; // device id - // for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0") + // for PCI devices, this should be the lower-case PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:c1:00.0") // if the id is unknown, this should be NULL const char * device_id; // device capabilities @@ -196,7 +202,12 @@ extern "C" { // Common functions that may be obtained using ggml_backend_reg_get_proc_address - // Split buffer type for tensor parallelism + // Context management and operations for faster communication between backends, used for tensor parallelism (meta backend) + typedef void * (*ggml_backend_comm_init_t)(ggml_backend_t * backends, size_t n_backends); + typedef void (*ggml_backend_comm_free_t)(void * comm_ctx); + typedef bool (*ggml_backend_comm_allreduce_tensor_t)(void * comm_ctx, struct ggml_tensor ** tensors); + + // Split buffer type for tensor parallelism (old) typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(int main_device, const float * tensor_split); // Set the number of threads for the backend typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t backend, int n_threads); @@ -259,7 +270,7 @@ extern "C" { Example usage: // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned - // preferrably to run on the same backend as the buffer + // preferably to run on the same backend as the buffer ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true); @@ -340,6 +351,57 @@ extern "C" { // Set a callback to be called for each resulting node during graph compute GGML_API void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data); + // + // Meta backend + // + +#define GGML_BACKEND_META_MAX_DEVICES 16 + + enum ggml_backend_meta_split_axis { + // tensor split by tensor dimensions: + GGML_BACKEND_SPLIT_AXIS_0 = 0, + GGML_BACKEND_SPLIT_AXIS_1 = 1, + GGML_BACKEND_SPLIT_AXIS_2 = 2, + GGML_BACKEND_SPLIT_AXIS_3 = 3, + + GGML_BACKEND_SPLIT_AXIS_MIRRORED = 10, // all values on all backends + GGML_BACKEND_SPLIT_AXIS_PARTIAL = 11, // each backend has a partial sum + + // for internal bookkeeping only: + GGML_BACKEND_SPLIT_AXIS_NONE = 98, + GGML_BACKEND_SPLIT_AXIS_UNKNOWN = 99, + }; + GGML_API const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis); + + struct ggml_backend_meta_split_state { + enum ggml_backend_meta_split_axis axis; + + // for tensors with axis >= 0 && axis < GGML_MAX_DIMS: + // - each device has a slice of the tensor along the split axis + // - most tensors have n_segments == 1 and a contiguous slice of the tensor data + // - some tensors have an inhomogenenous data layout along the split axis, + // those tensors are divided into segments which are each individually split across devices + // - ne has one entry per segment and device and that segment repeats nr times, + // in total when accounting for repetitions the segments add up to ggml_tensor::ne for that axis, + // the outer/inner loops are over segments/devices like [seg0_dev0_r0, seg0_dev1_r0, seg0_dev0_r1, seg0_dev1_r1, seg1_dev0_r0, seg1_dev1_r0], + // - for example, a transformer may have a fused QKV matrix rather than 3 matrices, those would be 3 separate segments + // that each need to be split individually across devices so that each device gets a slice of Q, K, and V, + // the Q matrix can be larger than the K and V matrices so this can either be expressed as 3 segments or as 2 segments + // where the segment for K/V repeats twice + int64_t ne[16*GGML_BACKEND_META_MAX_DEVICES]; + uint32_t nr[16]; + uint32_t n_segments; + }; + + // function to assign split states for statically allocated tensors, compute tensor split states will be assigned to be compatible: + typedef struct ggml_backend_meta_split_state(*ggml_backend_meta_get_split_state_t)(const struct ggml_tensor * tensor, void * userdata); + + // create a new meta device from "simple" devices, meta buffer type/buffer/backend is then derived from this: + // TODO: this looks a bit strange - a backend API creates a device. I think we should try + // express this as a backend registry functionality instead + GGML_API ggml_backend_dev_t ggml_backend_meta_device( + ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud); + // // Utils // diff --git a/ggml/include/ggml-cann.h b/ggml/include/ggml-cann.h index b469e228d06..74af465337a 100644 --- a/ggml/include/ggml-cann.h +++ b/ggml/include/ggml-cann.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/include/ggml-cpu.h b/ggml/include/ggml-cpu.h index 4f3b99c8d07..e3e067c916f 100644 --- a/ggml/include/ggml-cpu.h +++ b/ggml/include/ggml-cpu.h @@ -19,6 +19,9 @@ extern "C" { // abort ggml_graph_compute when true ggml_abort_callback abort_callback; void * abort_callback_data; + + // use only reference implementations + bool use_ref; }; // numa strategies @@ -132,6 +135,8 @@ extern "C" { GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); + GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref); + GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); diff --git a/ggml/include/ggml-cuda.h b/ggml/include/ggml-cuda.h index 22ad2c00963..5436c7ef579 100644 --- a/ggml/include/ggml-cuda.h +++ b/ggml/include/ggml-cuda.h @@ -27,6 +27,9 @@ GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend); // device buffer GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device); +// conduct allreduce operation between devices +GGML_BACKEND_API bool ggml_backend_cuda_allreduce_tensor(ggml_backend_t * backends, struct ggml_tensor ** tensors, size_t n_backends); + // split tensor buffer that splits matrices by rows across multiple devices GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split); diff --git a/ggml/include/ggml-openvino.h b/ggml/include/ggml-openvino.h new file mode 100644 index 00000000000..c43beb07b6a --- /dev/null +++ b/ggml/include/ggml-openvino.h @@ -0,0 +1,37 @@ +#pragma once + +#include "ggml-backend.h" + +#include <cstring> + +#ifdef __cplusplus +extern "C" { +#endif + +#define GGML_OPENVINO_NAME "OPENVINO" + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device); + +GGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend); + +GGML_BACKEND_API bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer); + +GGML_BACKEND_API bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft); + +GGML_BACKEND_API bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft); + +GGML_BACKEND_API size_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer); + +// device buffer +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device); + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(int device); + +GGML_BACKEND_API int ggml_backend_openvino_get_device_count(void); + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index 4703a05afe1..1c2ed79b774 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -138,7 +138,7 @@ extern "C" { GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); - // set gradients to zero, initilize loss, and optionally reset the optimizer + // set gradients to zero, initialize loss, and optionally reset the optimizer GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index df1ad2a5168..5ad121ae57f 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -6,9 +6,14 @@ extern "C" { #endif -#define RPC_PROTO_MAJOR_VERSION 3 -#define RPC_PROTO_MINOR_VERSION 6 -#define RPC_PROTO_PATCH_VERSION 0 +#define RPC_PROTO_MAJOR_VERSION 4 +#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 1 + +#ifdef __cplusplus +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT has changed - update RPC_PROTO_PATCH_VERSION"); +#endif + #define GGML_RPC_MAX_SERVERS 16 // backend API diff --git a/ggml/include/ggml-virtgpu.h b/ggml/include/ggml-virtgpu.h new file mode 100644 index 00000000000..faaba8f246d --- /dev/null +++ b/ggml/include/ggml-virtgpu.h @@ -0,0 +1,14 @@ +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#ifdef __cplusplus +extern "C" { +#endif + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg(); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b69583dd3fd..d6807b6dd47 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -6,7 +6,7 @@ // This documentation is still a work in progress. // If you wish some specific topics to be covered, feel free to drop a comment: // -// https://github.com/ggerganov/whisper.cpp/issues/40 +// https://github.com/ggml-org/whisper.cpp/issues/40 // // ## Overview // @@ -427,7 +427,9 @@ extern "C" { // GGML_TYPE_IQ4_NL_4_8 = 37, // GGML_TYPE_IQ4_NL_8_8 = 38, GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) - GGML_TYPE_COUNT = 40, + GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) + GGML_TYPE_Q1_0 = 41, + GGML_TYPE_COUNT = 42, }; // precision @@ -436,6 +438,12 @@ extern "C" { GGML_PREC_F32 = 10, }; + // op hint + enum ggml_op_hint { + GGML_HINT_NONE = 0, + GGML_HINT_SRC0_IS_HADAMARD = 1, + }; + // model file types enum ggml_ftype { GGML_FTYPE_UNKNOWN = -1, @@ -463,6 +471,8 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors + GGML_FTYPE_MOSTLY_NVFP4 = 26, // except 1d tensors + GGML_FTYPE_MOSTLY_Q1_0 = 27, // except 1d tensors }; // available tensor operations: @@ -525,6 +535,7 @@ extern "C" { GGML_OP_IM2COL, GGML_OP_IM2COL_BACK, GGML_OP_IM2COL_3D, + GGML_OP_COL2IM_1D, GGML_OP_CONV_2D, GGML_OP_CONV_3D, GGML_OP_CONV_2D_DW, @@ -556,6 +567,7 @@ extern "C" { GGML_OP_GATED_LINEAR_ATTN, GGML_OP_RWKV_WKV7, GGML_OP_SOLVE_TRI, + GGML_OP_GATED_DELTA_NET, GGML_OP_UNARY, @@ -630,10 +642,11 @@ extern "C" { // this tensor... enum ggml_tensor_flag { - GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph - GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph - GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters - GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph + GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph + GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters + GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) + GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed }; enum ggml_tri_type { @@ -751,6 +764,7 @@ extern "C" { GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_view (const struct ggml_tensor * tensor); GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); @@ -895,15 +909,17 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * ids); - GGML_API struct ggml_tensor * ggml_add1( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b), + "use ggml_add instead"); - GGML_API struct ggml_tensor * ggml_add1_inplace( + GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_add1_inplace( struct ggml_context * ctx, struct ggml_tensor * a, - struct ggml_tensor * b); + struct ggml_tensor * b), + "use ggml_add_inplace instead"); // dst = a // view(dst, nb1, nb2, nb3, offset) += b @@ -1174,8 +1190,8 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); - // a - x - // b - dy + // a - dy + // b - x GGML_API struct ggml_tensor * ggml_silu_back( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1410,6 +1426,11 @@ extern "C" { struct ggml_tensor * a, enum ggml_prec prec); + // change the hint of a matrix multiplication + GGML_API void ggml_mul_mat_set_hint( + struct ggml_tensor * a, + enum ggml_op_hint hint); + // indirect matrix multiplication GGML_API struct ggml_tensor * ggml_mul_mat_id( struct ggml_context * ctx, @@ -1764,8 +1785,32 @@ extern "C" { int n_dims, int mode); - // custom RoPE + // RoPE operations with extended options + // a is the input tensor to apply RoPE to, shape [n_embd, n_head, n_token] + // b is an int32 vector with size n_token // c is freq factors (e.g. phi3-128k), (optional) + // mode can be GGML_ROPE_TYPE_NORMAL or NEOX; for MROPE and VISION mode, use ggml_rope_multi + // + // pseudo-code for computing theta: + // for i in [0, n_dims/2): + // theta[i] = b[i] * powf(freq_base, -2.0 * i / n_dims); + // theta[i] = theta[i] / c[i]; # if c is provided, divide theta by c + // theta[i] = rope_yarn(theta[i], ...); # note: theta = theta * freq_scale is applied here + // + // other params are used by YaRN RoPE scaling, these default values will disable YaRN: + // freq_scale = 1.0f + // ext_factor = 0.0f + // attn_factor = 1.0f + // beta_fast = 0.0f + // beta_slow = 0.0f + // + // example: + // (marking: c = cos, s = sin, 0 = unrotated) + // given a single head with size = 8 --> [00000000] + // GGML_ROPE_TYPE_NORMAL n_dims = 4 --> [cscs0000] + // GGML_ROPE_TYPE_NORMAL n_dims = 8 --> [cscscscs] + // GGML_ROPE_TYPE_NEOX n_dims = 4 --> [ccss0000] + // GGML_ROPE_TYPE_NEOX n_dims = 8 --> [ccccssss] GGML_API struct ggml_tensor * ggml_rope_ext( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1781,6 +1826,36 @@ extern "C" { float beta_fast, float beta_slow); + // multi-dimensional RoPE, for Qwen-VL and similar vision models + // mode can be either VISION, MROPE, IMROPE, cannot be combined with NORMAL or NEOX + // sections specify how many dimensions to rotate in each section: + // section length is equivalent to number of cos/sin pairs, NOT the number of dims + // (i.e. sum of 4 sections are expected to be n_dims/2) + // last sections can be 0, means ignored + // all other options are identical to ggml_rope_ext + // + // important note: + // - NEOX ordering is automatically applied and cannot be disabled for MROPE and VISION + // if you need normal ordering, there are 2 methods: + // (1) split the tensor manually using ggml_view + // (2) permute the weight upon conversion + // - for VISION, n_dims must be head_size/2 + // + // example M-RoPE: + // given sections = [t=4, y=2, x=2, 0] + // given a single head with size = 18 --> [000000000000000000] + // GGML_ROPE_TYPE_MROPE n_dims = 16 --> [ttttyyxxttttyyxx00] (cos/sin are applied in NEOX ordering) + // GGML_ROPE_TYPE_IMROPE n_dims = 16 --> [ttyxttyxttyxttyx00] (interleaved M-RoPE, still NEOX ordering) + // note: the theta for each dim is computed the same way as ggml_rope_ext, no matter the section + // in other words, idx used for theta: [0123456789... until n_dims/2], not reset for each section + // + // example vision RoPE: + // given sections = [y=4, x=4, 0, 0] (last 2 sections are ignored) + // given a single head with size = 8 --> [00000000] + // GGML_ROPE_TYPE_VISION n_dims = 4 --> [yyyyxxxx] + // other values of n_dims are untested and is undefined behavior + // note: unlike MROPE, the theta for each dim is computed differently for each section + // in other words, idx used for theta: [0123] for y section, then [0123] for x section GGML_API struct ggml_tensor * ggml_rope_multi( struct ggml_context * ctx, struct ggml_tensor * a, @@ -1933,6 +2008,16 @@ extern "C" { int d1, // dilation dimension 1 bool is_2D); + // col2im_1d: scatter-add GEMM columns back to 1D signal + // a: [K*OC, T_in] (columns from matmul, K = a->ne[0]/OC) + // result: [T_out, OC] where T_out = (T_in - 1)*s0 + K - 2*p0 + GGML_API struct ggml_tensor * ggml_col2im_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, // columns [K*OC, T_in] + int s0, // stride + int oc, // output channels + int p0); // padding to crop from both sides + GGML_API struct ggml_tensor * ggml_conv_1d( struct ggml_context * ctx, struct ggml_tensor * a, // convolution kernel @@ -2465,6 +2550,29 @@ extern "C" { bool lower, bool uni); + // TODO: add ggml_gated_delta_net_set_bcast() to be able to configure Q, K broadcast type: tiled vs interleaved [TAG_GGML_GDN_BCAST] + // ref: https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + // + // tensor shapes (S_k == S_v, H_v % H_k == 0): + // q, k : [S_k, H_k, n_tokens, n_seqs] + // v : [S_v, H_v, n_tokens, n_seqs] + // g : [1, H_v, n_tokens, n_seqs] (scalar gate) or [S_v, H_v, n_tokens, n_seqs] (KDA) + // beta : [1, H_v, n_tokens, n_seqs] + // state : [S_v, S_v, H_v, n_seqs] -- initial recurrent state s0 + // + // the output packs the attention scores [S_v, H_v, n_tokens, n_seqs] followed by K state + // snapshots, most-recent first (slot 0 = final state, slot s = state s tokens back). K == 1 + // keeps only the final state; when n_tokens < K only slots 0..n_tokens-1 are written. + GGML_API struct ggml_tensor * ggml_gated_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state, + int64_t K); + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); @@ -2577,11 +2685,42 @@ extern "C" { struct ggml_tensor * grad, struct ggml_tensor * sgd_params); // alpha, weight decay + // build forward multiple tensors and select one of them for computing + // this is useful for creating graphs that have constant topology but compute different things based on the input + // ref: https://github.com/ggml-org/llama.cpp/pull/18550 + // + // nodes: + // | - build forward into the graph but do not compute + // c - build forward into the graph and compute + // + // | | ... c ... | + // | | ... c ... | + // | | ... c ... | + // [0 1 ... idx ... n-1] <-- ggml_build_forward_select(..., n, idx) + // c + // c + // + // example: + // struct ggml_tensor * curs[3]; + // + // curs[0] = compute0(...); + // curs[1] = compute1(...); + // curs[2] = compute2(...); // - // automatic differentiation + // int idx = select_branch(some_input); // + // struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx); + // + GGML_API struct ggml_tensor * ggml_build_forward_select( + struct ggml_cgraph * cgraph, + struct ggml_tensor ** tensors, + int n_tensors, + int idx); + + GGML_API void ggml_build_forward_expand( + struct ggml_cgraph * cgraph, + struct ggml_tensor * tensor); - GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API void ggml_build_backward_expand( struct ggml_context * ctx, // context for gradient computation struct ggml_cgraph * cgraph, @@ -2613,7 +2752,7 @@ extern "C" { GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph); // dump the graph into a file using the dot format - GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); + GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename); // TODO these functions were sandwiched in the old optimization interface, is there a better place for them? typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); diff --git a/ggml/include/gguf.h b/ggml/include/gguf.h index 79ee202062b..67851ba6f16 100644 --- a/ggml/include/gguf.h +++ b/ggml/include/gguf.h @@ -76,9 +76,16 @@ extern "C" { struct ggml_context ** ctx; }; + // callback to simulate or wrap a FILE pointer - read up to `len` bytes at `offset` into `output` and return the number of bytes read + typedef size_t (*gguf_reader_callback_t)(void * userdata, void * output, uint64_t offset, size_t len); + GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params); GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); - //GGML_API struct gguf_context * gguf_init_from_buffer(..); + GGML_API struct gguf_context * gguf_init_from_buffer(const void * data, size_t size, struct gguf_init_params params); + + // max_chunk_read is the maximum number of bytes that the GGUF code will read at once from the callback, a value of 0 means no limit + GGML_API struct gguf_context * gguf_init_from_callback(gguf_reader_callback_t callback, void * userdata, size_t max_chunk_read, uint64_t max_expected_size, struct gguf_init_params params); GGML_API void gguf_free(struct gguf_context * ctx); @@ -86,7 +93,7 @@ extern "C" { GGML_API uint32_t gguf_get_version (const struct gguf_context * ctx); GGML_API size_t gguf_get_alignment (const struct gguf_context * ctx); - GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(const struct gguf_context * ctx); // padded to gguf_get_alignment if and only if the gguf_context contains at least one tensor GGML_API int64_t gguf_get_n_kv(const struct gguf_context * ctx); GGML_API int64_t gguf_find_key(const struct gguf_context * ctx, const char * key); // returns -1 if key is not found @@ -189,6 +196,7 @@ extern "C" { // // write the entire context to a binary file + GGML_API bool gguf_write_to_file_ptr(const struct gguf_context * ctx, FILE * file, bool only_meta); GGML_API bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta); // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 6192a870466..c26c3f1470d 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -200,6 +200,7 @@ add_library(ggml-base ggml.cpp ggml-alloc.c ggml-backend.cpp + ggml-backend-meta.cpp ggml-opt.cpp ggml-threading.cpp ggml-threading.h @@ -221,7 +222,25 @@ if (GGML_SCHED_NO_REALLOC) target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC) endif() +if (GGML_OPENMP) + find_package(OpenMP) + if (OpenMP_FOUND) + set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") + else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") + message(WARNING "OpenMP not found") + endif() +else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") +endif() + +if (GGML_OPENMP_ENABLED) + target_compile_definitions(ggml-base PRIVATE GGML_USE_OPENMP) + target_link_libraries(ggml-base PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) +endif() + add_library(ggml + ggml-backend-dl.cpp ggml-backend-reg.cpp) add_library(ggml::ggml ALIAS ggml) @@ -451,6 +470,7 @@ ggml_add_backend(HIP) ggml_add_backend(METAL) ggml_add_backend(MUSA) ggml_add_backend(RPC) +ggml_add_backend(VirtGPU) ggml_add_backend(SYCL) ggml_add_backend(Vulkan) ggml_add_backend(WebGPU) @@ -458,6 +478,7 @@ ggml_add_backend(zDNN) ggml_add_backend(OpenCL) ggml_add_backend(Hexagon) ggml_add_backend(ZenDNN) +ggml_add_backend(OPENVINO) foreach (target ggml-base ggml) target_include_directories(${target} PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include> $<INSTALL_INTERFACE:include>) @@ -466,11 +487,10 @@ endforeach() target_link_libraries(ggml-base PRIVATE Threads::Threads) -find_library(MATH_LIBRARY m) -if (MATH_LIBRARY) - if (NOT WIN32 OR NOT DEFINED ENV{ONEAPI_ROOT}) - target_link_libraries(ggml-base PRIVATE m) - endif() +if (DEFINED MATH_LIBRARY) + target_link_libraries(ggml-base PRIVATE ${MATH_LIBRARY}) +elseif (NOT WIN32 AND NOT DEFINED ENV{ONEAPI_ROOT}) + target_link_libraries(ggml-base PRIVATE m) endif() if (CMAKE_SYSTEM_NAME MATCHES "Android") diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 41419b617bd..3bda9abbe03 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -2,6 +2,7 @@ #include "ggml-backend-impl.h" #include "ggml.h" #include "ggml-impl.h" + #include <assert.h> #include <limits.h> #include <stdarg.h> @@ -17,11 +18,6 @@ //#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__) #define AT_PRINTF(...) - -static bool ggml_is_view(const struct ggml_tensor * t) { - return t->view_src != NULL; -} - // ops that return true for this function must not use restrict pointers for their backend implementations bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { @@ -154,7 +150,7 @@ static void ggml_dyn_tallocr_insert_block(struct tallocr_chunk * chunk, size_t o static void ggml_dyn_tallocr_remove_block(struct tallocr_chunk * chunk, int idx) { // shift all elements after idx by 1 to the left, overwriting the element at idx - for (int i = idx; i < chunk->n_free_blocks; i++) { + for (int i = idx; i < chunk->n_free_blocks - 1; i++) { chunk->free_blocks[i] = chunk->free_blocks[i+1]; } chunk->n_free_blocks--; @@ -627,7 +623,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor GGML_ASSERT(buffer_id >= 0); struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); - if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) { + if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) { hn->allocated = true; assert(hn->addr.offset == 0); @@ -658,7 +654,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent); if (p_hn->n_children == 1 && p_hn->n_views == 0) { - if (ggml_is_view(parent)) { + if (ggml_impl_is_view(parent)) { struct ggml_tensor * view_src = parent->view_src; struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { @@ -739,7 +735,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to // control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node // itself is never used and should not be considered a dependency - if (ggml_is_view(node) && node->op != GGML_OP_NONE) { + if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) { struct ggml_tensor * view_src = node->view_src; ggml_gallocr_hash_get(galloc, view_src)->n_views += 1; } @@ -806,7 +802,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated); if (p_hn->n_children == 0 && p_hn->n_views == 0) { - if (ggml_is_view(parent)) { + if (ggml_impl_is_view(parent)) { struct ggml_tensor * view_src = parent->view_src; struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); view_src_hn->n_views -= 1; @@ -1241,6 +1237,9 @@ size_t ggml_backend_alloc_ctx_tensors_from_buft_size(struct ggml_context * ctx, ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { size_t nbytes_total = 0; + if (ggml_backend_buft_is_meta(buft)) { + return ggml_backend_meta_alloc_ctx_tensors_from_buft(ctx, buft); + } return ggml_backend_alloc_ctx_tensors_from_buft_impl(ctx, buft, &nbytes_total, /*no_alloc =*/ false); } diff --git a/ggml/src/ggml-backend-dl.cpp b/ggml/src/ggml-backend-dl.cpp new file mode 100644 index 00000000000..a65cf009055 --- /dev/null +++ b/ggml/src/ggml-backend-dl.cpp @@ -0,0 +1,48 @@ +#include "ggml-backend-dl.h" + +#ifdef _WIN32 + +dl_handle * dl_load_library(const fs::path & path) { + // suppress error dialogs for missing DLLs + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + HMODULE handle = LoadLibraryW(path.wstring().c_str()); + + SetErrorMode(old_mode); + + return handle; +} + +void * dl_get_sym(dl_handle * handle, const char * name) { + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + void * p = (void *) GetProcAddress(handle, name); + + SetErrorMode(old_mode); + + return p; +} + +const char * dl_error() { + return ""; +} + +#else + +dl_handle * dl_load_library(const fs::path & path) { + dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); + return handle; +} + +void * dl_get_sym(dl_handle * handle, const char * name) { + return dlsym(handle, name); +} + +const char * dl_error() { + const char *rslt = dlerror(); + return rslt != nullptr ? rslt : ""; +} + +#endif diff --git a/ggml/src/ggml-backend-dl.h b/ggml/src/ggml-backend-dl.h new file mode 100644 index 00000000000..f74b7c94894 --- /dev/null +++ b/ggml/src/ggml-backend-dl.h @@ -0,0 +1,45 @@ +#pragma once + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include <windows.h> +# include <winevt.h> +#else +# include <dlfcn.h> +# include <unistd.h> +#endif +#include <filesystem> + +namespace fs = std::filesystem; + +#ifdef _WIN32 + +using dl_handle = std::remove_pointer_t<HMODULE>; + +struct dl_handle_deleter { + void operator()(HMODULE handle) { + FreeLibrary(handle); + } +}; + +#else + +using dl_handle = void; + +struct dl_handle_deleter { + void operator()(void * handle) { + dlclose(handle); + } +}; + +#endif + +using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>; + +dl_handle * dl_load_library(const fs::path & path); +void * dl_get_sym(dl_handle * handle, const char * name); +const char * dl_error(); + diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 59190b7c465..9c56ec30c5f 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -49,6 +49,10 @@ extern "C" { void (*memset_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + // (optional) 2d data copies + void (*set_tensor_2d)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + void (*get_tensor_2d)(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + // (optional) tensor copy: dst is in the buffer, src may be in any buffer, including buffers from a different backend (return false if not supported) bool (*cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // clear the entire buffer @@ -80,6 +84,20 @@ extern "C" { GGML_API bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer); GGML_API void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage); + // + // Backend (meta) + // + + GGML_API bool ggml_backend_is_meta (ggml_backend_t backend); + GGML_API bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf); + GGML_API bool ggml_backend_buft_is_meta (ggml_backend_buffer_type_t buft); + + GGML_API size_t ggml_backend_meta_n_backends (ggml_backend_t meta_backend); + GGML_API ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index); + + // temporary workaround to statically allocate tensors from a context in a deduplicated way: + GGML_API struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft); + // // Backend (stream) // @@ -90,8 +108,10 @@ extern "C" { void (*free)(ggml_backend_t backend); // (optional) asynchronous tensor data access - void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*set_tensor_async) (ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*get_tensor_async) (ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + void (*set_tensor_2d_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); + void (*get_tensor_2d_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data); bool (*cpy_tensor_async)(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); // (optional) complete all pending operations (required if the backend supports async operations) diff --git a/ggml/src/ggml-backend-meta.cpp b/ggml/src/ggml-backend-meta.cpp new file mode 100644 index 00000000000..0a36f099000 --- /dev/null +++ b/ggml/src/ggml-backend-meta.cpp @@ -0,0 +1,2263 @@ +#include "ggml.h" +#include "ggml-impl.h" +#include "ggml-backend.h" +#include "ggml-backend-impl.h" +#include "ggml-alloc.h" +#include "ggml-cpp.h" + +#include <algorithm> +#include <cassert> +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <map> +#include <memory> +#include <set> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +struct ggml_backend_meta_device; +struct ggml_backend_meta_buffer_type; +struct ggml_backend_meta_buffer; +struct ggml_backend_meta; + +const char * ggml_backend_meta_split_axis_name(enum ggml_backend_meta_split_axis split_axis) { + switch (split_axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + return "0"; + case GGML_BACKEND_SPLIT_AXIS_1: + return "1"; + case GGML_BACKEND_SPLIT_AXIS_2: + return "2"; + case GGML_BACKEND_SPLIT_AXIS_3: + return "3"; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + return "MIRRORED"; + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: + return "PARTIAL"; + case GGML_BACKEND_SPLIT_AXIS_NONE: + return "NONE"; + case GGML_BACKEND_SPLIT_AXIS_UNKNOWN: + return "UNKNOWN"; + default: + GGML_ABORT("fatal error"); + } +} + +// +// meta backend device +// + +struct ggml_backend_meta_device_context { + std::vector<ggml_backend_dev_t> simple_devs; + ggml_backend_meta_get_split_state_t get_split_state; + void * get_split_state_ud; + + std::string name; + std::string description; + + ggml_backend_meta_device_context( + std::vector<ggml_backend_dev_t> simple_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) : + simple_devs(std::move(simple_devs)), get_split_state(get_split_state), get_split_state_ud(get_split_state_ud) { + name = std::string("Meta("); + description = std::string("Meta("); + for (size_t i = 0; i < simple_devs.size(); i++) { + if (i > 0) { + name += ","; + description += ","; + } + name += ggml_backend_dev_name (simple_devs[i]); + description += ggml_backend_dev_description(simple_devs[i]); + } + name += ")"; + description += ")"; + } + + bool operator<(const ggml_backend_meta_device_context & other) const { + return std::tie(simple_devs, get_split_state, get_split_state_ud) + < std::tie(other.simple_devs, other.get_split_state, other.get_split_state_ud); + } +}; + +static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev); + +static const char * ggml_backend_meta_device_get_name(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return meta_dev_ctx->name.c_str(); +} + +static const char * ggml_backend_meta_device_get_description(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return meta_dev_ctx->description.c_str(); +} + +static void ggml_backend_meta_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + *free = 0; + *total = 0; + for (ggml_backend_dev_t dev : meta_dev_ctx->simple_devs) { + size_t tmp_free, tmp_total; + ggml_backend_dev_memory(dev, &tmp_free, &tmp_total); + *free += tmp_free; + *total += tmp_total; + } +} + +static enum ggml_backend_dev_type ggml_backend_meta_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_META; + + GGML_UNUSED(dev); +} + +static void ggml_backend_meta_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + + // TODO replace placeholders + props->name = ggml_backend_meta_device_get_name(dev); + props->description = ggml_backend_meta_device_get_description(dev); + props->type = ggml_backend_meta_device_get_type(dev); + props->device_id = 0; + + ggml_backend_meta_device_get_memory(dev, &props->memory_free, &props->memory_total); + + props->caps = { + /* .async = */ true, + /* .host_buffer = */ false, // Not implemented. + /* .buffer_from_host_ptr = */ false, // Not implemented. + /* .events = */ false, // Not implemented. + }; + for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { + ggml_backend_dev_props tmp_props; + ggml_backend_dev_get_props(simple_dev, &tmp_props); + props->caps.async = props->caps.async && tmp_props.caps.async; + props->caps.host_buffer = props->caps.host_buffer && tmp_props.caps.host_buffer; + props->caps.buffer_from_host_ptr = props->caps.buffer_from_host_ptr && tmp_props.caps.buffer_from_host_ptr; + props->caps.events = props->caps.events && tmp_props.caps.events; + } +} + +static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params); + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev); + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev); + +static bool ggml_backend_meta_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + return std::all_of(meta_dev_ctx->simple_devs.begin(), meta_dev_ctx->simple_devs.end(), + [op](ggml_backend_dev_t simple_dev) { return ggml_backend_dev_supports_op(simple_dev, op); }); +} + +static bool ggml_backend_meta_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + ggml_backend_dev_t dev_buft = ggml_backend_buft_get_device(buft); + if (!ggml_backend_dev_is_meta(dev_buft)) { + return false; + } + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + const ggml_backend_meta_device_context * meta_buft_dev_ctx = (const ggml_backend_meta_device_context *) dev_buft->context; + if (meta_dev_ctx->simple_devs.size() != meta_buft_dev_ctx->simple_devs.size()) { + return false; + } + for (size_t i = 0; i < meta_dev_ctx->simple_devs.size(); i++) { + if (meta_dev_ctx->simple_devs[i] != meta_buft_dev_ctx->simple_devs[i]) { + return false; + } + } + return true; +} + +static const ggml_backend_device_i ggml_backend_meta_device_iface = { + /* .get_name = */ ggml_backend_meta_device_get_name, + /* .get_description = */ ggml_backend_meta_device_get_description, + /* .get_memory = */ ggml_backend_meta_device_get_memory, + /* .get_type = */ ggml_backend_meta_device_get_type, + /* .get_props = */ ggml_backend_meta_device_get_props, + /* .init_backend = */ ggml_backend_meta_device_init_backend, + /* .get_buffer_type = */ ggml_backend_meta_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_meta_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ nullptr, + /* .supports_op = */ ggml_backend_meta_device_supports_op, + /* .supports_buft = */ ggml_backend_meta_device_supports_buft, + /* .offload_op = */ nullptr, + /* .event_new = */ nullptr, + /* .event_free = */ nullptr, + /* .event_synchronize = */ nullptr, +}; + +static bool ggml_backend_dev_is_meta(ggml_backend_dev_t dev) { + return dev != nullptr && dev->iface.get_name == ggml_backend_meta_device_iface.get_name; +} + +static size_t ggml_backend_meta_dev_n_devs(ggml_backend_dev_t meta_dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; + return meta_dev_ctx->simple_devs.size(); +} + +static ggml_backend_dev_t ggml_backend_meta_dev_simple_dev(ggml_backend_dev_t meta_dev, size_t index) { + GGML_ASSERT(ggml_backend_dev_is_meta(meta_dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) meta_dev->context; + GGML_ASSERT(index < meta_dev_ctx->simple_devs.size()); + return meta_dev_ctx->simple_devs[index]; +} + +ggml_backend_dev_t ggml_backend_meta_device( + ggml_backend_dev_t * devs, size_t n_devs, ggml_backend_meta_get_split_state_t get_split_state, void * get_split_state_ud) { + GGML_ASSERT(n_devs <= GGML_BACKEND_META_MAX_DEVICES); + // TODO: this is not thread-safe - needs to be fixed + static std::vector<std::unique_ptr<ggml_backend_meta_device_context>> ctxs; + static std::map<ggml_backend_meta_device_context, struct ggml_backend_device> meta_devs; + + std::vector<ggml_backend_dev_t> simple_devs; + simple_devs.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + simple_devs.push_back(devs[i]); + } + ggml_backend_meta_device_context ctx(simple_devs, get_split_state, get_split_state_ud); + + { + auto it = meta_devs.find(ctx); + if (it != meta_devs.end()) { + return &it->second; + } + } + ctxs.push_back(std::make_unique<ggml_backend_meta_device_context>(ctx)); + + struct ggml_backend_device meta_dev = { + /*iface =*/ ggml_backend_meta_device_iface, + /*reg =*/ nullptr, + /*ctx =*/ ctxs.back().get(), + }; + + auto result = meta_devs.emplace(*ctxs.back(), meta_dev); + return &result.first->second; +} + +// +// meta backend buffer type +// + +struct ggml_backend_meta_buffer_type_context { + std::vector<ggml_backend_buffer_type_t> simple_bufts; + + std::string name; + + ggml_backend_meta_buffer_type_context(std::vector<ggml_backend_buffer_type_t> simple_bufts) : simple_bufts(std::move(simple_bufts)) { + name = "Meta("; + for (size_t i = 0; i < simple_bufts.size(); i++) { + if (i > 0) { + name += ","; + } + name += ggml_backend_buft_name(simple_bufts[i]); + } + name += ")"; + } + + bool operator<(const ggml_backend_meta_buffer_type_context & other) const { + return simple_bufts < other.simple_bufts; + } +}; + +static size_t ggml_backend_meta_buft_n_bufts(ggml_backend_buffer_type_t meta_buft) { + GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; + return meta_buft_ctx->simple_bufts.size(); +} + +static const char * ggml_backend_meta_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + GGML_ASSERT(ggml_backend_buft_is_meta(buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) buft->context; + return meta_buft_ctx->name.c_str(); +} + +static ggml_backend_buffer_type_t ggml_backend_meta_buft_simple_buft(ggml_backend_buffer_type_t meta_buft, size_t index) { + GGML_ASSERT(ggml_backend_buft_is_meta(meta_buft)); + const ggml_backend_meta_buffer_type_context * meta_buft_ctx = (const ggml_backend_meta_buffer_type_context *) meta_buft->context; + GGML_ASSERT(index < meta_buft_ctx->simple_bufts.size()); + return meta_buft_ctx->simple_bufts[index]; +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); + +static size_t ggml_backend_meta_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_alignment = 1; + for (size_t i = 0; i < n_simple_bufts; i++) { + const size_t alignment = ggml_backend_buft_get_alignment(ggml_backend_meta_buft_simple_buft(buft, i)); + max_alignment = std::max(max_alignment, alignment); + GGML_ASSERT(max_alignment % alignment == 0); + } + return max_alignment; +} + +static size_t ggml_backend_meta_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_size = SIZE_MAX; + for (size_t i = 0; i < n_simple_bufts; i++) { + max_size = std::min(max_size, ggml_backend_buft_get_max_size(ggml_backend_meta_buft_simple_buft(buft, i))); + } + return max_size; +} + +static size_t ggml_backend_meta_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + size_t max_alloc_size = 0; + for (size_t i = 0; i < n_simple_bufts; i++) { + const size_t alloc_size = ggml_backend_buft_get_alloc_size(ggml_backend_meta_buft_simple_buft(buft, i), tensor); + max_alloc_size = std::max(max_alloc_size, alloc_size); + } + return max_alloc_size; +} + +static bool ggml_backend_meta_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + for (size_t i = 0; i < n_simple_bufts; i++) { + if (!ggml_backend_buft_is_host(ggml_backend_meta_buft_simple_buft(buft, i))) { + return false; + } + } + return true; +} + +static const struct ggml_backend_buffer_type_i ggml_backend_meta_buffer_type_iface = { + /* .get_name = */ ggml_backend_meta_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_meta_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_meta_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_meta_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_meta_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_meta_buffer_type_is_host, +}; + +bool ggml_backend_buft_is_meta(ggml_backend_buffer_type_t buft) { + return buft != nullptr && buft->iface.get_name == ggml_backend_meta_buffer_type_iface.get_name; +} + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_buffer_type(ggml_backend_dev_t dev) { + static std::map<ggml_backend_dev_t, struct ggml_backend_buffer_type> meta_bufts; + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + { + auto it = meta_bufts.find(dev); + if (it != meta_bufts.end()) { + return &it->second; + } + } + + const size_t n_devs = ggml_backend_meta_dev_n_devs(dev); + std::vector<ggml_backend_buffer_type_t> simple_bufts; + simple_bufts.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + simple_bufts.push_back(ggml_backend_dev_buffer_type(ggml_backend_meta_dev_simple_dev(dev, i))); + } + ggml_backend_meta_buffer_type_context * buft_ctx = new ggml_backend_meta_buffer_type_context(simple_bufts); + + struct ggml_backend_buffer_type meta_buft = { + /*iface =*/ ggml_backend_meta_buffer_type_iface, + /*device =*/ dev, + /*ctx =*/ buft_ctx, + }; + auto result = meta_bufts.emplace(dev, meta_buft); + return &result.first->second; +} + +static ggml_backend_buffer_type_t ggml_backend_meta_device_get_host_buffer_type(ggml_backend_dev_t dev) { + GGML_ASSERT(ggml_backend_dev_is_meta(dev)); + const ggml_backend_meta_device_context * meta_dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + + ggml_backend_buffer_type_t host_buft = nullptr; + for (ggml_backend_dev_t simple_dev : meta_dev_ctx->simple_devs) { + ggml_backend_buffer_type_t simple_host_buft = ggml_backend_dev_host_buffer_type(simple_dev); + if (simple_host_buft == nullptr) { + return nullptr; + } + if (host_buft == nullptr) { + host_buft = simple_host_buft; + } else if (host_buft != simple_host_buft) { + // if different simple devices have different host buffer types, + // we cannot provide a single host buffer type for the meta device + return nullptr; + } + } + return host_buft; +} + +// +// meta backend buffer +// + +// Container to hold the tensor slices per simple ggml backend buffer. +struct ggml_backend_meta_simple_tensor_container { + std::vector<ggml_context_ptr> ctxs; + std::map<const ggml_tensor *, std::vector<ggml_tensor *>> simple_tensors; + + ggml_backend_meta_simple_tensor_container(const ggml_init_params & params, const int n_simple) { + ctxs.reserve(n_simple); + for (int i = 0; i < n_simple; i++) { + ctxs.emplace_back(ggml_init(params)); + } + } + ggml_backend_meta_simple_tensor_container() {} +}; + +struct ggml_backend_meta_buffer_context { + // FIXME + // Most tensors can simply be stored statically in their own buffer. + // Externally created views however also need a mapping to simple tensors but they use the buffer of the view source. + // If external views are simply using that buffer they will slowly deplete its memory. + // Current solution: rotating set of 2 "compute" containers to hold external views, works correctly for llama.cpp. + // Long-term: tie the lifetime of external views to the meta backend executing the graph instead, + // currently not possible due to graph-external operations in the backend scheduler. + ggml_backend_meta_simple_tensor_container stc_static; + ggml_backend_meta_simple_tensor_container stc_compute[2]; + int stc_compute_index = 0; + int stc_compute_index_next = 0; + std::vector<ggml_backend_buffer_ptr> bufs; + + // FIXME + // The size of the split state cache is unbounded and can theoretically grow infinitely large. + // However, it is also expensive to build and clearing it on every rebuild in ggml_backend_meta_graph_compute is too expensive. + static constexpr size_t nbtc = GGML_TENSOR_SIZE - sizeof(ggml_tensor::padding); + std::map<std::pair<const ggml_tensor *, bool>, std::pair<ggml_backend_meta_split_state, char[nbtc]>> split_state_cache; + + int debug; + + ggml_backend_meta_buffer_context( + ggml_backend_meta_simple_tensor_container & stc_static, + ggml_backend_meta_simple_tensor_container & stc_compute_0, + ggml_backend_meta_simple_tensor_container & stc_compute_1, + const std::vector<ggml_backend_buffer_t> & bufs) + : stc_static(std::move(stc_static)), stc_compute{std::move(stc_compute_0), std::move(stc_compute_1)} { + this->bufs.reserve(bufs.size()); + for (ggml_backend_buffer_t buf : bufs) { + this->bufs.emplace_back(buf); + } + const char * GGML_META_DEBUG = getenv("GGML_META_DEBUG"); + debug = GGML_META_DEBUG ? atoi(GGML_META_DEBUG) : 0; + } + + ggml_backend_meta_simple_tensor_container & get_simple_tensor_container(const ggml_tensor * tensor) { + if (stc_static.simple_tensors.find(tensor) != stc_static.simple_tensors.end()) { + return stc_static; + } + return stc_compute[stc_compute_index]; + } +}; + +static void ggml_backend_meta_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + delete buf_ctx; +} + +static size_t ggml_backend_meta_buffer_n_bufs(ggml_backend_buffer_t meta_buf) { + GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; + return buf_ctx->bufs.size(); +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_simple_buffer(ggml_backend_buffer_t meta_buf, size_t index) { + GGML_ASSERT(ggml_backend_buffer_is_meta(meta_buf)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) meta_buf->context; + GGML_ASSERT(index < buf_ctx->bufs.size()); + return buf_ctx->bufs[index].get(); +} + +static struct ggml_tensor * ggml_backend_meta_buffer_simple_tensor(const struct ggml_tensor * tensor, size_t index) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + GGML_ASSERT(index < buf_ctx->bufs.size()); + + ggml_backend_meta_simple_tensor_container & stc = buf_ctx->get_simple_tensor_container(tensor); + auto it = stc.simple_tensors.find(tensor); + if (it == stc.simple_tensors.end()) { + return nullptr; + } + return it->second[index]; +} + +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync); + +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state( + ggml_backend_meta_simple_tensor_container & stc, const struct ggml_tensor * tensor, bool assume_sync) { + // FIXME Currently this function preserves/erases the information in n_segments and nr in an inconsistent way. + // Since the operations in question are developed specifically for llama.cpp this currently does not manifest as a bug there. + // However, in a broader ggml context with arbitrary ggml graphs this can lead to unexpected results. + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + + auto split_states_equal = [&](const ggml_backend_meta_split_state & a, const ggml_backend_meta_split_state & b) -> bool { + if (a.axis != b.axis) { + return false; + } + for (size_t j = 0; j < n_bufs; j++) { + int64_t sum_a = 0; + for (size_t s = 0; s < a.n_segments; s++) { + sum_a += a.ne[s*n_bufs + j] * a.nr[s]; + } + int64_t sum_b = 0; + for (size_t s = 0; s < b.n_segments; s++) { + sum_b += b.ne[s*n_bufs + j] * b.nr[s]; + } + if (sum_a != sum_b) { + return false; + } + } + return true; + }; + + auto handle_generic = [&](const std::vector<ggml_backend_meta_split_state> & src_ss, bool scalar_only) -> ggml_backend_meta_split_state { + ggml_backend_meta_split_state ret = {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { + continue; + } + if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + ret = src_ss[i]; + } else if (!split_states_equal(src_ss[i], ret)) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + break; + } + } + if (ret.axis == GGML_BACKEND_SPLIT_AXIS_NONE) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + } + if (scalar_only && ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { + ret = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + } + GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + return ret; + }; + + // Some ops process data on a per-row bases: + auto handle_per_row = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_0); + return src_ss[0]; + }; + + // Some ops broadcast the src1 data across src0: + auto handle_bin_bcast = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS && + tensor->src[1]->ne[src_ss[0].axis] == 1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + if (src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[0].axis == src_ss[1].axis || + (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL)))) { + return src_ss[0]; // GGML_OP_ADD_ID + } + GGML_ASSERT(tensor->src[2] == nullptr || src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_concat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + const ggml_backend_meta_split_axis concat_axis = ggml_backend_meta_split_axis(ggml_get_op_params_i32(tensor, 0)); + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis >= 0 && src_ss[1].axis < GGML_MAX_DIMS) { + GGML_ASSERT(concat_axis != src_ss[1].axis); + return src_ss[1]; + } + if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + GGML_ASSERT(concat_axis != src_ss[0].axis); + return src_ss[0]; + } + if (src_ss[0].axis == src_ss[1].axis && src_ss[0].axis != concat_axis) { + return src_ss[0]; + } + return handle_generic(src_ss, /*scalar_only =*/ true); + }; + + auto handle_mul_mat = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + ggml_backend_meta_split_state ret = src_ss[0]; + ret.axis = GGML_BACKEND_SPLIT_AXIS_0; + ret.nr[0] = 1; + ret.n_segments = 1; + return ret; + } + if (src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1 && src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[1]; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(split_states_equal(src_ss[0], src_ss[1])); + return {assume_sync ? GGML_BACKEND_SPLIT_AXIS_MIRRORED : GGML_BACKEND_SPLIT_AXIS_PARTIAL, {0}, {1}, 1}; + } + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + }; + + auto handle_reshape = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + GGML_ASSERT(src_ss[0].n_segments == 1); + if (src_ss[0].axis == ggml_n_dims(tensor->src[0]) - 1 && src_ss[0].nr[0] == 1) { + return {ggml_backend_meta_split_axis(ggml_n_dims(tensor) - 1), {0}, {1}, 1}; + } + int64_t base_ne_in = tensor->src[0]->ne[0]; + for (int dim = 1; dim <= src_ss[0].axis; dim++) { + base_ne_in *= tensor->src[0]->ne[dim]; + } + base_ne_in /= src_ss[0].nr[0]; + int64_t base_ne_out = 1; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + const int64_t base_ne_out_next = base_ne_out *= tensor->ne[dim]; + if (base_ne_out_next % base_ne_in == 0) { + return {ggml_backend_meta_split_axis(dim), {0}, {uint32_t(base_ne_out_next/base_ne_in)}, 1}; + } + if (base_ne_out_next > base_ne_in) { + GGML_ASSERT(src_ss[0].n_segments == 1); + GGML_ASSERT(src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; + } + base_ne_out = base_ne_out_next; + } + GGML_ABORT("shape mismatch for %s", ggml_op_name(tensor->op)); + } + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + } + } + }; + + auto handle_cpy = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + return handle_reshape(src_ss); + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_view = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (ggml_is_contiguous(tensor) && ggml_is_contiguous(tensor->src[0])) { + return handle_reshape(src_ss); + } + const int axis = src_ss[0].axis; + { + bool all_strides_the_same = true; + for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { + if (tensor->ne[dim] == 1 && tensor->src[0]->ne[dim] == 1) { + continue; + } + if (tensor->nb[dim] != tensor->src[0]->nb[dim]) { + all_strides_the_same = false; + break; + } + } + if (all_strides_the_same) { + return src_ss[0]; + } + } + if (!ggml_is_permuted(tensor) && !ggml_is_permuted(tensor->src[0]) && axis >= 0 && axis < GGML_MAX_DIMS-1) { + for (int dim = 0; dim < GGML_MAX_DIMS-1; dim++) { + if (tensor->nb[dim+1] == tensor->src[0]->nb[axis+1]) { + return {ggml_backend_meta_split_axis(dim), {0}, {1}, 1}; + } + } + GGML_ABORT("fatal error"); + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED || src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + return src_ss[0]; + } + GGML_ABORT("view of permuted tensor not implemented"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + }; + + auto handle_permute = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: { + GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(tensor->op_params[src_ss[0].axis]), {0}, {src_ss[0].nr[0]}, 1}; + } + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + } + } + }; + + auto handle_transpose = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + switch (src_ss[0].axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: { + GGML_ASSERT(src_ss[0].n_segments == 1 || src_ss[0].nr[0] == 1); + return {ggml_backend_meta_split_axis(int(src_ss[0].axis) ^ 1), {0}, {src_ss[0].nr[0]}, 1}; + } + case GGML_BACKEND_SPLIT_AXIS_2: + case GGML_BACKEND_SPLIT_AXIS_3: + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + return src_ss[0]; + } + default: { + GGML_ABORT("fatal error"); + //return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + } + } + }; + + auto handle_get_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0 && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + return handle_generic(src_ss, /*scalar_only =*/ true); + }; + + auto handle_set_rows = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[0].axis != GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(split_states_equal(src_ss[0], src_ss[2])); + return src_ss[0]; + }; + + auto handle_rope = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + return src_ss[0]; + }; + + auto handle_pad = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis >= 0 && src_ss[0].axis < GGML_MAX_DIMS) { + GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 0] == 0); + GGML_ASSERT(tensor->op_params[2*src_ss[0].axis + 1] == 0); + } + return src_ss[0]; + }; + + auto handle_flash_attn_ext = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + GGML_ASSERT( src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT( src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_2); + GGML_ASSERT(tensor->src[4] == nullptr || src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED); + GGML_ASSERT(tensor->src[4] == nullptr || src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_0); + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; + }; + + auto handle_ssm_conv = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == src_ss[1].axis) { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_0) { + return {GGML_BACKEND_SPLIT_AXIS_1, {0}, {1}, 1}; + } + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1) { + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; + } + } + return handle_generic(src_ss, /*scalar_only =*/ false); + }; + + auto handle_gated_delta_net = [&](const std::vector<ggml_backend_meta_split_state> & src_ss) -> ggml_backend_meta_split_state { + if (src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && + src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED && src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + return src_ss[0]; + } + GGML_ASSERT(src_ss[0].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[1].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[2].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[3].axis == GGML_BACKEND_SPLIT_AXIS_1); + GGML_ASSERT(src_ss[4].axis == GGML_BACKEND_SPLIT_AXIS_1); + // state shape is [S_v, S_v, H_v, n_seqs] (s0 only); the heads dim is its own axis 2, + // so a head-aligned split on the input cache lands on axis 2 here. + GGML_ASSERT(src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_2 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_1 || src_ss[5].axis == GGML_BACKEND_SPLIT_AXIS_0); + return {GGML_BACKEND_SPLIT_AXIS_0, {0}, {1}, 1}; + }; + + auto calculate_split_state = [&]() -> ggml_backend_meta_split_state { + if (ggml_nelements(tensor) == 0) { + return {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + } + if (ggml_backend_buffer_get_usage(tensor->buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE && tensor->view_src == nullptr) { + ggml_backend_dev_t dev = ggml_backend_buft_get_device(ggml_backend_buffer_get_type(tensor->buffer)); + const ggml_backend_meta_device_context * dev_ctx = (const ggml_backend_meta_device_context *) dev->context; + ggml_backend_meta_split_state ret = dev_ctx->get_split_state(tensor, dev_ctx->get_split_state_ud); + if (ret.axis >= 0 && ret.axis <= GGML_MAX_DIMS) { + const int64_t granularity = ret.axis == GGML_BACKEND_SPLIT_AXIS_0 ? ggml_blck_size(tensor->type) : 1; + int64_t ne_sum = 0; + for (size_t s = 0; s < ret.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + GGML_ASSERT(ret.ne[s*n_bufs + j] % granularity == 0); + ne_sum += ret.ne[s*n_bufs + j] * ret.nr[s]; + } + } + GGML_ASSERT(ne_sum == tensor->ne[ret.axis]); + } + return ret; + } + + std::vector<ggml_backend_meta_split_state> src_ss(GGML_MAX_SRC, {GGML_BACKEND_SPLIT_AXIS_NONE, {0}, {1}, 1}); + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || tensor->src[i] == tensor) { + src_ss[i] = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + continue; + } + src_ss[i] = ggml_backend_meta_get_split_state(stc, tensor->src[i], /*assume_sync =*/ true); + GGML_ASSERT(src_ss[i].axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + } + + ggml_backend_meta_split_state split_state; + switch (tensor->op) { + case GGML_OP_NONE: { + split_state = {GGML_BACKEND_SPLIT_AXIS_MIRRORED, {0}, {1}, 1}; + } break; + case GGML_OP_DUP: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_ADD: + case GGML_OP_ADD_ID: { + split_state = handle_bin_bcast(src_ss); + } break; + case GGML_OP_ADD1: + case GGML_OP_ACC: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: { + split_state = handle_bin_bcast(src_ss); + } break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_SIN: + case GGML_OP_COS: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_SUM: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_CONCAT: { + split_state = handle_concat(src_ss); + } break; + case GGML_OP_SILU_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_GROUP_NORM: + case GGML_OP_L2_NORM: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: { + split_state = handle_mul_mat(src_ss); + } break; + case GGML_OP_OUT_PROD: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SCALE: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_SET: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CPY: { + split_state = handle_cpy(src_ss); + } break; + case GGML_OP_CONT: + case GGML_OP_RESHAPE: { + split_state = handle_reshape(src_ss); + } break; + case GGML_OP_VIEW: { + split_state = handle_view(src_ss); + } break; + case GGML_OP_PERMUTE: { + split_state = handle_permute(src_ss); + } break; + case GGML_OP_TRANSPOSE: { + split_state = handle_transpose(src_ss); + } break; + case GGML_OP_GET_ROWS: { + split_state = handle_get_rows(src_ss); + } break; + case GGML_OP_GET_ROWS_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SET_ROWS: { + split_state = handle_set_rows(src_ss); + } break; + case GGML_OP_DIAG: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_DIAG_MASK_ZERO: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_ROPE: { + split_state = handle_rope(src_ss); + } break; + case GGML_OP_ROPE_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CLAMP: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_BACK: + case GGML_OP_IM2COL_3D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_3D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_POOL_1D: + case GGML_OP_POOL_2D: + case GGML_OP_POOL_2D_BACK: + case GGML_OP_UPSCALE: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_PAD: { + split_state = handle_pad(src_ss); + } break; + case GGML_OP_PAD_REFLECT_1D: + case GGML_OP_ROLL: + case GGML_OP_ARANGE: + case GGML_OP_TIMESTEP_EMBEDDING: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_LEAKY_RELU: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_TRI: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_FILL: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_FLASH_ATTN_EXT: { + split_state = handle_flash_attn_ext(src_ss); + } break; + case GGML_OP_FLASH_ATTN_BACK: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_SSM_CONV: { + split_state = handle_ssm_conv(src_ss); + } break; + case GGML_OP_SSM_SCAN: + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_GET_REL_POS: + case GGML_OP_ADD_REL_POS: + case GGML_OP_RWKV_WKV6: + case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_RWKV_WKV7: + case GGML_OP_SOLVE_TRI: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_GATED_DELTA_NET: { + split_state = handle_gated_delta_net(src_ss); + } break; + case GGML_OP_UNARY: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + case GGML_OP_MAP_CUSTOM1: + case GGML_OP_MAP_CUSTOM2: + case GGML_OP_MAP_CUSTOM3: + case GGML_OP_CUSTOM: { + split_state = handle_generic(src_ss, /*scalar_only =*/ true); + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { + split_state = handle_per_row(src_ss); + } break; + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + case GGML_OP_GLU: { + split_state = handle_generic(src_ss, /*scalar_only =*/ false); + } break; + default: { + GGML_ABORT("ggml op not implemented: %s", ggml_op_name(tensor->op)); + split_state = {GGML_BACKEND_SPLIT_AXIS_UNKNOWN, {0}, {1}, 1}; + } break; + } + if (split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS) { + bool first_src_split_by_axis = true; + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || src_ss[i].axis < 0 || src_ss[i].axis >= GGML_MAX_DIMS) { + continue; + } + if (first_src_split_by_axis) { + for (size_t j = 0; j < n_bufs; j++) { + // Take over ratio from src: + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + split_state.ne[s*n_bufs + j] = 0; + } + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + split_state.ne[j] += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; + } + split_state.ne[j] *= tensor->ne[split_state.axis]; + if (split_state.ne[j] != 0 || tensor->src[i]->ne[src_ss[i].axis] != 0) { + const int64_t div = tensor->src[i]->ne[src_ss[i].axis] * split_state.nr[0]; + GGML_ASSERT(split_state.ne[j] % div == 0); + split_state.ne[j] /= div; + } + } + } else { + GGML_ASSERT(split_state.n_segments == 1); + for (size_t j = 0; j < n_bufs; j++) { + // Assert that ratio is consistent: + int64_t sum = 0; + for (size_t s = 0; s < src_ss[i].n_segments; s++) { + sum += src_ss[i].ne[s*n_bufs + j] * src_ss[i].nr[s]; + } + GGML_ASSERT(split_state.ne[j]*split_state.nr[0] * tensor->src[i]->ne[src_ss[i].axis] + == sum * tensor->ne[split_state.axis]); + } + } + first_src_split_by_axis = false; + } + GGML_ASSERT(!first_src_split_by_axis); + } + return split_state; + }; + + const std::pair key = std::make_pair(tensor, assume_sync); + auto it = buf_ctx->split_state_cache.find(key); + if (it != buf_ctx->split_state_cache.end() && memcmp(it->second.second, (const char *) tensor, sizeof(it->second.second)) != 0) { + buf_ctx->split_state_cache.clear(); + it = buf_ctx->split_state_cache.end(); + } + + if (it == buf_ctx->split_state_cache.end()) { + buf_ctx->split_state_cache[key].first = calculate_split_state(); + memcpy(buf_ctx->split_state_cache[key].second, tensor, sizeof(buf_ctx->split_state_cache[key].second)); + if (buf_ctx->debug > 0) { + std::string srcs_info; + for (size_t i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr) { + continue; + } + if (!srcs_info.empty()) { + srcs_info += ", "; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor->src[0], true); + GGML_ASSERT(split_state.n_segments == 1); + const char * axis_name = ggml_backend_meta_split_axis_name(split_state.axis); + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + ne_info += std::to_string(split_state.ne[j]) + "x" + std::to_string(split_state.nr[0]); + } + srcs_info += std::string(tensor->src[i]->name) + "[" + ggml_op_name(tensor->src[i]->op) + ", " + axis_name + ", {" + ne_info + "}]"; + } + std::string ne_info; + for (size_t j = 0; j < n_bufs; j++) { + if (!ne_info.empty()) { + ne_info += ", "; + } + const ggml_backend_meta_split_state & ss = buf_ctx->split_state_cache[key].first; + ne_info += std::to_string(ss.ne[j]) + "x" + std::to_string(ss.nr[0]); + } + GGML_LOG_DEBUG("SPLIT_STATE: {%s} -> %s[%s, %s, {%s}]\n", srcs_info.c_str(), tensor->name, ggml_op_name(tensor->op), + ggml_backend_meta_split_axis_name(buf_ctx->split_state_cache[key].first.axis), ne_info.c_str()); + } + } + + ggml_backend_meta_split_state ret = buf_ctx->split_state_cache[key].first; + GGML_ASSERT(ret.axis != GGML_BACKEND_SPLIT_AXIS_NONE); +#ifndef NDEBUG + if (ret.axis >= 0 && ret.axis < GGML_MAX_DIMS) { + int64_t ne_ret = 0; + for (size_t s = 0; s < ret.n_segments; s++) { + for (size_t j = 0; j < n_bufs; j++) { + ne_ret += ret.ne[s*n_bufs + j] * ret.nr[s]; + } + } + assert(ne_ret == tensor->ne[int(ret.axis)]); + } +#endif // NDEBUG + return ret; +} + +static struct ggml_backend_meta_split_state ggml_backend_meta_get_split_state(const struct ggml_tensor * tensor, bool assume_sync) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + return ggml_backend_meta_get_split_state(buf_ctx->get_simple_tensor_container(tensor), tensor, assume_sync); +} + +static void * ggml_backend_meta_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_UNUSED(buffer); + return (void *) 0x1000000000000000; // FIXME +} + +static enum ggml_status ggml_backend_meta_buffer_init_tensor_impl(ggml_backend_meta_simple_tensor_container & stc, ggml_tensor * tensor) { + GGML_ASSERT(ggml_backend_buffer_is_meta(tensor->buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) tensor->buffer->context; + const size_t n_simple_bufs = ggml_backend_meta_buffer_n_bufs(tensor->buffer); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(stc, tensor, /*assume_sync =*/ true); + GGML_ASSERT(ggml_nelements(tensor) == 0 || split_state.axis != GGML_BACKEND_SPLIT_AXIS_UNKNOWN); + GGML_ASSERT(split_state.n_segments <= 16); + + int split_dim = split_state.axis; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; + for (size_t k = 0; k < GGML_MAX_DIMS; k++) { + ne[k] = tensor->ne[k]; + nb[k] = tensor->nb[k]; + } + + std::vector<ggml_tensor *> simple_tensors; + simple_tensors.reserve(n_simple_bufs); + for (size_t j = 0; j < n_simple_bufs; j++) { + ggml_context * simple_ctx = stc.ctxs[j].get(); + ggml_backend_buffer_t simple_buf = buf_ctx->bufs[j].get(); + + if (split_dim >= 0 && split_dim < GGML_MAX_DIMS) { + // TODO: the following assert fails for llama-parallel even though the results are correct: + // GGML_ASSERT(ggml_is_contiguously_allocated(tensor)); + ne[split_dim] = 0; + for (size_t s = 0; s < split_state.n_segments; s++) { + ne[split_dim] += split_state.ne[s*n_simple_bufs + j] * split_state.nr[s]; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (tensor->nb[i] > tensor->nb[split_dim]) { + nb[i] = tensor->nb[i] * ne[split_dim]/tensor->ne[split_dim]; + } + } + } + + ggml_tensor * t_ij = ggml_new_tensor(simple_ctx, tensor->type, GGML_MAX_DIMS, ne); + t_ij->op = tensor->op; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + t_ij->nb[i] = nb[i]; + } + t_ij->flags = tensor->flags; + memcpy(t_ij->op_params, tensor->op_params, sizeof(tensor->op_params)); + ggml_set_name(t_ij, tensor->name); + t_ij->buffer = simple_buf; + t_ij->view_src = tensor->view_src; + t_ij->view_offs = tensor->view_offs; + if (t_ij->view_src != nullptr && ggml_backend_buffer_is_meta(t_ij->view_src->buffer)) { + t_ij->view_src = ggml_backend_meta_buffer_simple_tensor(tensor->view_src, j); + if (t_ij->view_offs > 0 && split_dim >= 0 && split_dim < GGML_MAX_DIMS) { + GGML_ASSERT(tensor->ne[split_dim] != 0); + const int split_dim_view_src = ggml_backend_meta_get_split_state(tensor->view_src, /*assume_sync =*/ true).axis; + GGML_ASSERT(split_dim_view_src >= 0 && split_dim_view_src < GGML_MAX_DIMS); + + // The offset can be internal to the data split, in those cases the view offset should not be scaled. + // If however, the offset is larger than the data split then it needs to be scaled proportionally. + bool split_internal_offset = t_ij->view_offs <= tensor->view_src->nb[split_dim_view_src]; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + const size_t dim_size = tensor->ne[i] * tensor->nb[i]; + if (tensor->view_offs <= dim_size && dim_size < tensor->nb[split_dim]) { + split_internal_offset = true; + break; + } + } + if (!split_internal_offset) { + t_ij->view_offs = t_ij->view_offs * ne[split_dim]/tensor->ne[split_dim]; + } + } + } + if (t_ij->view_src != nullptr) { + t_ij->data = (char *) t_ij->view_src->data + t_ij->view_offs; + } else if (simple_buf != nullptr) { + t_ij->data = (char *) ggml_backend_buffer_get_base(simple_buf) + + size_t(tensor->data) - size_t(ggml_backend_buffer_get_base(tensor->buffer)); + } + t_ij->extra = tensor->extra; + for (int i = 0; i < GGML_MAX_SRC; i++) { + t_ij->src[i] = tensor->src[i]; + if (tensor->src[i] == tensor) { + t_ij->src[i] = t_ij; + } else if (t_ij->src[i] != nullptr && ggml_backend_buffer_is_meta(t_ij->src[i]->buffer)) { + t_ij->src[i] = ggml_backend_meta_buffer_simple_tensor(tensor->src[i], j); + } + } + + simple_tensors.push_back(t_ij); + } + + // If one of the sources has a zero-sized slice, disable the computation: + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] == nullptr || !ggml_backend_buffer_is_meta(tensor->src[i]->buffer)) { + continue; + } + + const ggml_backend_meta_split_state split_state_src = ggml_backend_meta_get_split_state(tensor->src[i], /*assume_sync =*/ true); + if (split_state_src.axis < 0 || split_state_src.axis >= GGML_MAX_DIMS) { + continue; + } + for (size_t j = 0; j < n_simple_bufs; j++) { + int64_t ne_sum = 0; + for (size_t s = 0; s < split_state_src.n_segments; s++) { + ne_sum += split_state_src.ne[s*n_simple_bufs + j] * split_state_src.nr[s]; + } + if (ne_sum == 0) { + simple_tensors[j]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; + } + } + } + + stc.simple_tensors[tensor] = simple_tensors; + + return GGML_STATUS_SUCCESS; +} + +static enum ggml_status ggml_backend_meta_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + buf_ctx->stc_compute_index = buf_ctx->stc_compute_index_next; + return ggml_backend_meta_buffer_init_tensor_impl(buf_ctx->get_simple_tensor_container(tensor), tensor); +} + +static void ggml_backend_meta_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + + if (split_state.n_segments != 1 || split_state.nr[0] != 1) { + GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(split_state.nr[0] != 0); + GGML_ASSERT(tensor->ne[3] == 1); + + size_t offset_data = 0; + std::vector<size_t> simple_offsets(n_bufs, 0); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(tensor->ne[2] == 1); + + const size_t row_stride = tensor->nb[1]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[1]); + + const int64_t blck_size = ggml_blck_size(tensor->type); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, + row_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + } + GGML_ASSERT(offset_data*row_count == size); + return; + } + GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + + const size_t row_stride = tensor->nb[2]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[2]); + + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, + row_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + } + GGML_ASSERT(offset_data*row_count == size); + return; + } + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } + const size_t simple_offset = i_start * chunk_size_j; + ggml_backend_tensor_set_2d(simple_tensor, (const char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + ggml_backend_tensor_set(simple_tensor, data, offset, size); + } + } break; + case GGML_BACKEND_SPLIT_AXIS_PARTIAL: { + GGML_ASSERT(tensor->type == GGML_TYPE_F32); + const int64_t ne = ggml_nelements(tensor); + std::vector<float> tmp; + tmp.reserve(ne); + for (int64_t i = 0; i < ne; i++) { + tmp.push_back(((const float *) data)[i] / n_bufs); + } + for (size_t j = 0; j < n_bufs; j++) { + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + ggml_backend_tensor_set(simple_tensor, tmp.data(), offset, size); + } + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + const size_t n_bufs = ggml_backend_meta_buffer_n_bufs(buffer); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + + if (split_state.n_segments != 1 || split_state.nr[0] != 1) { + GGML_ASSERT(split_state.axis >= 0 && split_state.axis < GGML_MAX_DIMS); + GGML_ASSERT(split_state.nr[0] != 0); + GGML_ASSERT(tensor->ne[3] == 1); + + size_t offset_data = 0; + std::vector<size_t> simple_offsets(n_bufs, 0); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_0) { + GGML_ASSERT(tensor->ne[2] == 1); + + const size_t row_stride = tensor->nb[1]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[1]); + + const int64_t blck_size = ggml_blck_size(tensor->type); + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + GGML_ASSERT(split_state.ne[s*n_bufs + j] % blck_size == 0); + const size_t nbytes = split_state.ne[s*n_bufs + j]/blck_size * tensor->nb[0]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[1], nbytes, + row_count, simple_tensor->nb[1], tensor->nb[1]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + } + GGML_ASSERT(offset_data*row_count == size); + return; + } + GGML_ASSERT(split_state.axis == GGML_BACKEND_SPLIT_AXIS_1); + + const size_t row_stride = tensor->nb[2]; + GGML_ASSERT(offset % row_stride == 0); + GGML_ASSERT(size % row_stride == 0); + const int64_t row_start = offset / row_stride; + const int64_t row_count = size / row_stride; + GGML_ASSERT(row_start + row_count <= tensor->ne[2]); + + for (size_t s = 0; s < split_state.n_segments; s++) { + for (size_t r = 0; r < split_state.nr[s]; r++) { + for (size_t j = 0; j < n_bufs; j++) { + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t nbytes = split_state.ne[s*n_bufs + j] * tensor->nb[1]; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_data, + simple_offsets[j] + row_start * simple_tensor->nb[2], nbytes, + row_count, simple_tensor->nb[2], tensor->nb[2]); + offset_data += nbytes; + simple_offsets[j] += nbytes; + } + } + } + GGML_ASSERT(offset_data*row_count == size); + return; + } + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_bufs; j++){ + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } + const size_t simple_offset = i_start * chunk_size_j; + ggml_backend_tensor_get_2d(simple_tensor, (char *) data + offset_j, simple_offset, chunk_size_j, i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + // TODO other simple backend may be better + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); + ggml_backend_tensor_get(simple_tensor, data, offset, size); + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + const size_t n_buffers = ggml_backend_meta_buffer_n_bufs(buffer); + for (size_t i = 0; i < n_buffers; i++) { + ggml_backend_buffer_clear(ggml_backend_meta_buffer_simple_buffer(buffer, i), value); + } +} + +static void ggml_backend_meta_buffer_reset(ggml_backend_buffer_t buffer) { + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer)); + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buffer->context; + for (size_t i = 0; i < buf_ctx->bufs.size(); i++) { + ggml_backend_buffer_reset(ggml_backend_meta_buffer_simple_buffer(buffer, i)); + } +} + +static const ggml_backend_buffer_i ggml_backend_meta_buffer_iface = { + /* .free_buffer = */ ggml_backend_meta_buffer_free_buffer, + /* .get_base = */ ggml_backend_meta_buffer_get_base, + /* .init_tensor = */ ggml_backend_meta_buffer_init_tensor, + /* .memset_tensor = */ nullptr, // TODO implement + /* .set_tensor = */ ggml_backend_meta_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_meta_buffer_get_tensor, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_meta_buffer_clear, + /* .reset = */ ggml_backend_meta_buffer_reset, +}; + +bool ggml_backend_buffer_is_meta(ggml_backend_buffer_t buf) { + return buf != nullptr && buf->iface.free_buffer == ggml_backend_meta_buffer_iface.free_buffer; +} + +static ggml_backend_buffer_t ggml_backend_meta_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + + const ggml_init_params params = { + /*.mem_size =*/ 1024*1024*ggml_tensor_overhead(), // FIXME + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_backend_meta_simple_tensor_container stc_static; + ggml_backend_meta_simple_tensor_container stc_compute_0(params, n_simple_bufts); + ggml_backend_meta_simple_tensor_container stc_compute_1(params, n_simple_bufts); + + size_t max_size = 0; + std::vector<ggml_backend_buffer_t> bufs; + bufs.reserve(n_simple_bufts); + for (size_t i = 0; i < n_simple_bufts; i++) { + bufs.push_back(ggml_backend_buft_alloc_buffer(ggml_backend_meta_buft_simple_buft(buft, i), size)); + GGML_ASSERT(bufs.back() != nullptr); + max_size = std::max(max_size, ggml_backend_buffer_get_size(bufs.back())); + } + ggml_backend_meta_buffer_context * buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs); + + return ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, buf_ctx, max_size); +} + +struct ggml_backend_buffer * ggml_backend_meta_alloc_ctx_tensors_from_buft(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) { + const size_t n_simple_bufts = ggml_backend_meta_buft_n_bufts(buft); + + constexpr size_t compute_headroom = 16; // Maximum number of views per statically allocated tensor that can be created between evals. + const ggml_init_params params_static = { + /*.mem_size =*/ ggml_get_mem_size(ctx), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + const ggml_init_params params_compute = { + /*.mem_size =*/ compute_headroom*ggml_get_mem_size(ctx), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_backend_meta_simple_tensor_container stc_static (params_static, n_simple_bufts); + ggml_backend_meta_simple_tensor_container stc_compute_0(params_compute, n_simple_bufts); + ggml_backend_meta_simple_tensor_container stc_compute_1(params_compute, n_simple_bufts); + + std::vector<ggml_backend_buffer_t> bufs(n_simple_bufts, nullptr); + ggml_backend_meta_buffer_context * meta_buf_ctx = new ggml_backend_meta_buffer_context(stc_static, stc_compute_0, stc_compute_1, bufs); + + ggml_backend_buffer_t meta_buf = ggml_backend_buffer_init(buft, ggml_backend_meta_buffer_iface, meta_buf_ctx, 0); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = meta_buf; + ggml_backend_meta_buffer_init_tensor_impl(meta_buf_ctx->stc_static, t); + t->data = (void *) 0x2000000000000000; // FIXME + } + for (size_t i = 0; i < n_simple_bufts; i++) { + ggml_context * ctx = meta_buf_ctx->stc_static.ctxs[i].get(); + ggml_backend_buffer_type_t simple_buft = ggml_backend_meta_buft_simple_buft(buft, i); + + // If a ggml_context only has zero-sized tensors, ggml_backend_alloc_ctx_tensors_from_buft returns NULL. + // For those edge cases, allocate a dummy buffer instead. + bool any_nonzero_slice = false; + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + if (ggml_nelements(t) != 0) { + any_nonzero_slice = true; + break; + } + } + if (any_nonzero_slice) { + meta_buf_ctx->bufs[i].reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx, simple_buft)); + } else { + meta_buf_ctx->bufs[i].reset(ggml_backend_buft_alloc_buffer(simple_buft, 0)); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { + t->buffer = meta_buf_ctx->bufs[i].get(); + } + } + GGML_ASSERT(meta_buf_ctx->bufs[i]); + meta_buf->size = std::max(meta_buf->size, ggml_backend_buffer_get_size(meta_buf_ctx->bufs[i].get())); + } + return meta_buf; +} + +// +// meta backend +// + +static ggml_guid_t ggml_backend_meta_guid() { + static ggml_guid guid = {0xf1, 0x0e, 0x34, 0xcf, 0x9c, 0x6f, 0x43, 0xcb, 0x96, 0x92, 0xbe, 0x8e, 0xbb, 0x71, 0x3f, 0xda}; + return &guid; +} + +struct ggml_backend_meta_context { + struct cgraph_config { + ggml_cgraph * cgraph_main = nullptr; + int offset = 0; // Node offset vs. original graph + + std::vector<ggml_cgraph *> cgraphs_aux; + }; + struct backend_config { + ggml_backend_t backend; + + std::vector<cgraph_config> cgraphs; + std::vector<ggml_tensor *> nodes; + std::vector<ggml_backend_buffer_ptr> bufs; + + backend_config(ggml_backend_t backend, const size_t n_reduce_steps) : backend(backend) { + bufs.resize(n_reduce_steps); + } + }; + std::string name; + std::vector<backend_config> backend_configs; + ggml_context_ptr ctx; + std::vector<ggml_cgraph *> cgraphs_aux; + std::vector<ggml_tensor *> nodes_aux; + size_t n_reduce_steps; + int max_nnodes = 0; + size_t max_tmp_size = 0; + size_t max_subgraphs = 0; + size_t n_subgraphs = 0; + uint64_t uid = 0; + + void * comm_ctx = nullptr; + ggml_backend_comm_allreduce_tensor_t comm_allreduce = nullptr; + + ggml_backend_meta_context(ggml_backend_dev_t meta_dev, const char * params) { + const size_t n_devs = ggml_backend_meta_dev_n_devs(meta_dev); + n_reduce_steps = std::ceil(std::log2(n_devs)); + name = "Meta("; + std::vector<ggml_backend_t> simple_backends; + backend_configs.reserve(n_devs); + simple_backends.reserve(n_devs); + for (size_t i = 0; i < n_devs; i++) { + ggml_backend_dev_t simple_dev = ggml_backend_meta_dev_simple_dev(meta_dev, i); + if (i > 0) { + name += ","; + } + name += ggml_backend_dev_name(simple_dev); + simple_backends.push_back(ggml_backend_dev_init(simple_dev, params)); + backend_configs.emplace_back(simple_backends.back(), n_reduce_steps); + } + name += ")"; + + if (n_devs > 1) { + ggml_backend_comm_init_t comm_init = (ggml_backend_comm_init_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_init"); + if (comm_init != nullptr) { + comm_ctx = comm_init(simple_backends.data(), simple_backends.size()); + } + } + if (comm_ctx != nullptr) { + comm_allreduce = (ggml_backend_comm_allreduce_tensor_t) + ggml_backend_reg_get_proc_address(ggml_backend_dev_backend_reg( + ggml_backend_get_device(simple_backends[0])), "ggml_backend_comm_allreduce_tensor"); + GGML_ASSERT(comm_allreduce != nullptr); + } + } + + ~ggml_backend_meta_context() { + if (comm_ctx != nullptr) { + ggml_backend_comm_free_t comm_free = (ggml_backend_comm_free_t) ggml_backend_reg_get_proc_address( + ggml_backend_dev_backend_reg(ggml_backend_get_device(backend_configs[0].backend)), "ggml_backend_comm_free"); + GGML_ASSERT(comm_free != nullptr); + comm_free(comm_ctx); + } + for (auto & bc : backend_configs) { + ggml_backend_free(bc.backend); + } + } +}; + +static const char * ggml_backend_meta_get_name(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_meta(backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) backend->context; + return backend_ctx->name.c_str(); +} + +static void ggml_backend_meta_free(ggml_backend_t backend) { + GGML_ASSERT(ggml_backend_is_meta(backend)); + ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + delete backend_ctx; + delete backend; +} + +static void ggml_backend_meta_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + GGML_ASSERT(split_state.n_segments == 1); + GGML_ASSERT(split_state.nr[0] == 1); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_backends; j++){ + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); + ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } + ggml_backend_tensor_set_2d_async(simple_backend, simple_tensor, (const char *) data + offset_j, offset, chunk_size_j, + i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + for (size_t j = 0; j < n_backends; j++) { + ggml_backend_tensor_set_async( + ggml_backend_meta_simple_backend(backend, j), ggml_backend_meta_buffer_simple_tensor(tensor, j), data, offset, size); + } + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + GGML_ASSERT(offset == 0); + GGML_ASSERT(ggml_is_contiguous(tensor)); + + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(tensor, /*assume_sync =*/ false); + GGML_ASSERT(split_state.n_segments == 1); + GGML_ASSERT(split_state.nr[0] == 1); + + switch (split_state.axis) { + case GGML_BACKEND_SPLIT_AXIS_0: + case GGML_BACKEND_SPLIT_AXIS_1: + case GGML_BACKEND_SPLIT_AXIS_2: { + // Exploit that tensors are contiguous to splice it with simple tensors as "chunks". + const size_t chunk_size_full = tensor->nb[split_state.axis + 1]; + GGML_ASSERT(offset % chunk_size_full == 0); + GGML_ASSERT(size % chunk_size_full == 0); + const int64_t i_start = offset /chunk_size_full; + const int64_t i_stop = (offset + size)/chunk_size_full; + size_t offset_j = 0; + for (size_t j = 0; j < n_backends; j++){ + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, j); + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, j); + const size_t chunk_size_j = simple_tensor->nb[split_state.axis + 1]; + if (chunk_size_j == 0) { + continue; + } + ggml_backend_tensor_get_2d_async(simple_backend, simple_tensor, (char *) data + offset_j, offset, chunk_size_j, + i_stop - i_start, chunk_size_j, chunk_size_full); + offset_j += chunk_size_j; + } + GGML_ASSERT(offset_j == chunk_size_full); + } break; + case GGML_BACKEND_SPLIT_AXIS_MIRRORED: { + // TODO other simple backend may be better + ggml_backend_t simple_backend = ggml_backend_meta_simple_backend(backend, 0); + const ggml_tensor * simple_tensor = ggml_backend_meta_buffer_simple_tensor(tensor, 0); + ggml_backend_tensor_get_async(simple_backend, simple_tensor, data, offset, size); + } break; + default: { + GGML_ABORT("fatal error"); + } + } +} + +static void ggml_backend_meta_synchronize(ggml_backend_t backend) { + const size_t n_backends = ggml_backend_meta_n_backends(backend); + for (size_t i = 0; i < n_backends; i++) { + ggml_backend_synchronize(ggml_backend_meta_simple_backend(backend, i)); + } +} + +static enum ggml_status ggml_backend_meta_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(cgraph->grads == nullptr); + const size_t n_backends = ggml_backend_meta_n_backends(backend); + ggml_backend_meta_context * backend_ctx = (ggml_backend_meta_context *) backend->context; + + // If the previous cgraph had a defined UID it can be used to skip rebuilding the subgraphs per simple backend. + const bool needs_rebuild = (cgraph->uid == 0) || (cgraph->uid != backend_ctx->uid); + + bool max_nnodes_raised = false; + if (cgraph->n_nodes > backend_ctx->max_nnodes) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.nodes.resize(cgraph->n_nodes); + bcj.cgraphs.resize(cgraph->n_nodes); + } + backend_ctx->max_nnodes = cgraph->n_nodes; + max_nnodes_raised = true; + assert(needs_rebuild); + } + + if (needs_rebuild) { + std::set<ggml_backend_buffer_t> used_buffers; + for (int i = 0; i < cgraph->n_leafs; i++) { + if (ggml_backend_buffer_is_meta(cgraph->leafs[i]->buffer)) { + used_buffers.emplace(cgraph->leafs[i]->buffer); + } + } + for (int i = 0; i < cgraph->n_nodes; i++) { + if (ggml_backend_buffer_is_meta(cgraph->nodes[i]->buffer)) { + used_buffers.emplace(cgraph->nodes[i]->buffer); + } + } + for (ggml_backend_buffer_t buf : used_buffers) { + ggml_backend_meta_buffer_context * buf_ctx = (ggml_backend_meta_buffer_context *) buf->context; + buf_ctx->stc_compute_index_next = buf_ctx->stc_compute_index ^ 1; + ggml_backend_meta_simple_tensor_container & stc = buf_ctx->stc_compute[buf_ctx->stc_compute_index_next]; + for (ggml_context_ptr & ctx : stc.ctxs) { + ggml_reset(ctx.get()); + } + stc.simple_tensors.clear(); + } + size_t n_subgraphs = 0; + size_t max_tmp_size = 0; + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + // FIXME s_copy_main is on the CPU and its view seems to be incorrectly added to the graph nodes. + // For regular usage this doesn't matter since it's a noop but trying to call ggml_backend_meta_buffer_simple_tensor results in a crash. + bcj.nodes[i] = node; + continue; + } + bcj.nodes[i] = ggml_backend_meta_buffer_simple_tensor(node, j); + GGML_ASSERT(bcj.nodes[i]); + } + } + + { + // For MoE models it may make sense to delay the AllReduce in order to reduce I/O: + auto get_i_delayed = [&](const int i) -> int { + int id = i; // i_delayed + int idr = i; // i_delayed return, last safe return value + + ggml_tensor * node = cgraph->nodes[id]; + int32_t n_used = ggml_node_get_use_count(cgraph, id); + + // Skip MIRRORED nodes that don't consume node + auto skip_unrelated = [&]() { + while (id + 1 < cgraph->n_nodes) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (ggml_backend_meta_get_split_state(next, false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + break; + } + bool safe = true; + for (int s = 0; s < GGML_MAX_SRC; s++) { + if (next->src[s] == nullptr) { + continue; + } + if (next->src[s] == node) { + safe = false; + break; + } + if (ggml_backend_meta_get_split_state(next->src[s], false).axis != GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + safe = false; + break; + } + } + if (!safe) { + break; + } + id++; + } + }; + + skip_unrelated(); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_ADD_ID && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL && + ggml_backend_meta_get_split_state(next->src[2], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } + } + // Chain of MULs with MIRRORED src[1] + while (true) { + skip_unrelated(); + if (id + 1 >= cgraph->n_nodes) { + return idr; + } + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op == GGML_OP_MUL && next->src[0] == node && + ggml_backend_meta_get_split_state(next->src[1], false).axis == GGML_BACKEND_SPLIT_AXIS_MIRRORED) { + node = next; + id++; + idr = id; + n_used = ggml_node_get_use_count(cgraph, id); + } else { + break; + } + } + + if (n_used != node->ne[1] || id + 2*n_used-1 >= cgraph->n_nodes) { + return idr; + } + for (int32_t k = 0; k < n_used; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_VIEW || next->view_src != node || next->view_offs != k*node->nb[1] || + next->ne[0] != node->ne[0] || next->ne[1] != node->ne[2] || next->nb[1] != node->nb[2] || + ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id - (n_used-1)] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + for (int32_t k = 0; k < n_used - 2; k++) { + ggml_tensor * next = cgraph->nodes[id+1]; + if (next->op != GGML_OP_ADD || next->src[0] != cgraph->nodes[id] || + next->src[1] != cgraph->nodes[id - (n_used-2)] || ggml_node_get_use_count(cgraph, id+1) != 1) { + return idr; + } + id++; + } + idr = id; + return idr; + }; + + int i_start = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->view_src != nullptr && node->view_src->op == GGML_OP_NONE && ggml_backend_buffer_is_host(node->view_src->buffer)) { + continue; + } + const ggml_backend_meta_split_state split_state = ggml_backend_meta_get_split_state(node, /*assume_sync =*/ false); + if (split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL) { + max_tmp_size = std::max(max_tmp_size, ggml_nbytes(node)); + } + const bool new_subgraph = i + 1 == cgraph->n_nodes || split_state.axis == GGML_BACKEND_SPLIT_AXIS_PARTIAL; + if (!new_subgraph) { + continue; + } + + const int i_delayed = get_i_delayed(i); + + // If we can delay the AllReduce we need to consider the interaction with zero-sized tensor slices. + // A backend with such a slice would normally have valid data after participating in the AllReduce with a node that has + // its compute flag disabled and thus gets its data zeroed out. + // If the AllReduce is delayed then the nodes until that point also need to have their compute flag disabled. + if (i_delayed > i) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + if ((bcj.nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + for (int ii = i + 1; ii <= i_delayed; ii++) { + bcj.nodes[ii]->flags &= ~GGML_TENSOR_FLAG_COMPUTE; + } + } + } + } + + i = i_delayed; + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + bcj.cgraphs[n_subgraphs].offset = i_start; + } + n_subgraphs++; + i_start = i + 1; + } + GGML_ASSERT(i_start == cgraph->n_nodes); + } + + backend_ctx->uid = cgraph->uid; + backend_ctx->n_subgraphs = n_subgraphs; + + if (max_tmp_size > backend_ctx->max_tmp_size) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i = 0; i < backend_ctx->n_reduce_steps; i++) { + bcj.bufs[i].reset(ggml_backend_alloc_buffer(bcj.backend, max_tmp_size)); + } + } + backend_ctx->max_tmp_size = max_tmp_size; + } + + if (max_nnodes_raised || n_subgraphs > backend_ctx->max_subgraphs) { + backend_ctx->max_subgraphs = std::max(backend_ctx->max_subgraphs, n_subgraphs); + const size_t n_nodes_per_device = 3 * backend_ctx->n_reduce_steps; // tmp + ADD (+zeroing) graph per step and device + const size_t n_cgraphs_per_device = 2 * backend_ctx->n_reduce_steps; // ADD ( + zeroing) graph per step and device + const size_t mem_per_device_graphs_main = backend_ctx->max_subgraphs*ggml_graph_overhead_custom(backend_ctx->max_nnodes, cgraph->grads); + const size_t mem_per_device_graphs_aux = n_cgraphs_per_device*backend_ctx->max_subgraphs*ggml_graph_overhead_custom(1, cgraph->grads); + const size_t mem_per_device_nodes_aux = n_nodes_per_device*backend_ctx->max_subgraphs*ggml_tensor_overhead(); + const ggml_init_params params = { + /*.mem_size =*/ n_backends * (mem_per_device_graphs_main + mem_per_device_graphs_aux + mem_per_device_nodes_aux), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + backend_ctx->ctx.reset(ggml_init(params)); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i = 0; i < n_subgraphs; i++) { + bcj.cgraphs[i].cgraph_main = ggml_new_graph_custom(backend_ctx->ctx.get(), cgraph->n_nodes, /*grads =*/ false); + } + } + backend_ctx->cgraphs_aux.resize(n_backends*n_cgraphs_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->cgraphs_aux.size(); k++) { + backend_ctx->cgraphs_aux[k] = ggml_new_graph_custom(backend_ctx->ctx.get(), 1, cgraph->grads); + } + backend_ctx->nodes_aux.resize(n_backends*n_nodes_per_device*backend_ctx->max_subgraphs); + for (size_t k = 0; k < backend_ctx->nodes_aux.size(); k++) { + backend_ctx->nodes_aux[k] = ggml_new_tensor_1d(backend_ctx->ctx.get(), GGML_TYPE_F32, 1); + } + } + + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + for (size_t i_graph = 0; i_graph < n_subgraphs; i_graph++) { + ggml_cgraph * cgraph_ij = bcj.cgraphs[i_graph].cgraph_main; + const size_t i_node_start = bcj.cgraphs[i_graph].offset; + const size_t i_node_stop = i_graph + 1 < n_subgraphs ? bcj.cgraphs[i_graph + 1].offset : cgraph->n_nodes; + cgraph_ij->n_nodes = i_node_stop - i_node_start; + ggml_hash_set_reset(&cgraph_ij->visited_hash_set); + for (size_t i_node = i_node_start; i_node < i_node_stop; i_node++) { + ggml_tensor * node_ij = bcj.nodes[i_node]; + cgraph_ij->nodes[i_node - i_node_start] = node_ij; + const size_t hash_pos_orig = ggml_hash_find(&cgraph->visited_hash_set, cgraph->nodes[i_node]); + const size_t hash_pos_ij = ggml_hash_insert(&cgraph_ij->visited_hash_set, node_ij); + cgraph_ij->use_counts[hash_pos_ij] = cgraph->use_counts[hash_pos_orig]; + } + cgraph_ij->uid = ggml_graph_next_uid(); + } + } + } + + size_t iga = 0; // i graph aux + size_t ina = 0; // i node aux + + auto get_node_aux = [&](ggml_tensor * t) -> ggml_tensor * { + ggml_tensor * ret = backend_ctx->nodes_aux[ina++]; + memset(ret, 0, sizeof(ggml_tensor)); + ret->op = GGML_OP_NONE; + ret->type = t->type; + for (size_t k = 0; k < GGML_MAX_DIMS; k++) { + ret->ne[k] = t->ne[k]; + ret->nb[k] = t->nb[k]; + } + return ret; + }; + auto set_tmp_data = [&](ggml_tensor * tensor, const size_t j, const size_t i_buf) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_backend_buffer_ptr & buf_ptr = bcj.bufs[i_buf]; + if (!buf_ptr || ggml_backend_buffer_get_size(buf_ptr.get()) < backend_ctx->max_tmp_size) { + buf_ptr.reset(ggml_backend_alloc_buffer(bcj.backend, backend_ctx->max_tmp_size)); + } + tensor->buffer = buf_ptr.get(); + tensor->data = ggml_backend_buffer_get_base(buf_ptr.get()); + }; + // FIXME usage_counts + auto get_cgraph_aux = [&]() -> ggml_cgraph * { + ggml_cgraph * ret = backend_ctx->cgraphs_aux[iga++]; + return ret; + }; + + // Preferentially use backend-specific allreduce_tensor_async (e.g. NCCL for CUDA), use a generic fallback if unavailable: + auto allreduce_fallback = [&](size_t i) -> ggml_status { + std::vector<ggml_cgraph *> step_cgraphs(n_backends, nullptr); + + // Zero out nodes that were disabled due to having a zero-sized slice: + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_tensor * node = bcj.cgraphs[i].cgraph_main->nodes[bcj.cgraphs[i].cgraph_main->n_nodes - 1]; + if (node->flags & GGML_TENSOR_FLAG_COMPUTE) { + continue; + } + ggml_tensor * node_zero = get_node_aux(node); + node_zero->op = GGML_OP_SCALE; // FIXME 0.0f * NaN == NaN + node_zero->src[0] = node; + ggml_set_op_params_f32(node_zero, 0, 0.0f); + node_zero->data = node->data; + node_zero->buffer = node->buffer; + node_zero->flags |= GGML_TENSOR_FLAG_COMPUTE; + + step_cgraphs[j] = get_cgraph_aux(); + step_cgraphs[j]->nodes[0] = node_zero; + step_cgraphs[j]->n_nodes = 1; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); + + auto push_data = [&](const size_t j_src, const size_t j_dst, const size_t i_buf) { + assert(step_cgraphs[j_dst] == nullptr); + auto & bcj_src = backend_ctx->backend_configs[j_src]; + auto & bcj_dst = backend_ctx->backend_configs[j_dst]; + + ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; + GGML_ASSERT(ggml_is_contiguous(node_src)); + GGML_ASSERT(ggml_is_contiguous(node_dst)); + + ggml_tensor * node_tmp = get_node_aux(node_dst); + set_tmp_data(node_tmp, j_dst, i_buf); + + ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_tmp); + + ggml_tensor * node_red = get_node_aux(node_dst); + node_red->view_src = node_dst->view_src == nullptr ? node_dst : node_dst->view_src; + node_red->view_offs = node_dst->view_offs; + node_red->op = GGML_OP_ADD; + node_red->src[0] = node_dst; + node_red->src[1] = node_tmp; + node_red->flags |= GGML_TENSOR_FLAG_COMPUTE; + ggml_backend_view_init(node_red); + + ggml_cgraph * cgraph_aux = get_cgraph_aux(); + cgraph_aux->nodes[0] = node_red; + cgraph_aux->n_nodes = 1; + step_cgraphs[j_dst] = cgraph_aux; + }; + + size_t offset_j = n_backends/2; + while ((offset_j & (offset_j - 1)) != 0) { + offset_j--; + } + const size_t offset_j_max = offset_j; + size_t i_buf = 0; + + // If n_backends is not a power of 2, fold in the excess prior to butterfly reduction: + for (size_t j_src = 2*offset_j_max; j_src < n_backends; j_src++) { + const size_t j_dst = j_src - 2*offset_j_max; + push_data(j_src, j_dst, i_buf); + const ggml_status status = ggml_backend_graph_compute_async(backend_ctx->backend_configs[j_dst].backend, step_cgraphs[j_dst]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + i_buf = 1; + } + + // Butterfly reduction: + for (; offset_j >= 1; offset_j /= 2) { + std::fill(step_cgraphs.begin(), step_cgraphs.end(), nullptr); + + for (size_t j = 0; j < 2*offset_j_max; j++) { + const size_t j_other = j ^ offset_j; + if (j_other >= n_backends) { + continue; + } + push_data(j, j_other, i_buf); + } + + for (size_t j = 0; j < 2*offset_j_max; j++) { + if (step_cgraphs[j] == nullptr) { + continue; + } + auto & bcj = backend_ctx->backend_configs[j]; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, step_cgraphs[j]); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + i_buf++; + } + assert(i_buf == backend_ctx->n_reduce_steps); + + // If n_backends is not a power of 2, copy back the reduced tensors to the excess: + for (size_t j = 2*offset_j_max; j < n_backends; j++) { + auto & bcj_src = backend_ctx->backend_configs[j - 2*offset_j_max]; + auto & bcj_dst = backend_ctx->backend_configs[j]; + + ggml_tensor * node_src = bcj_src.cgraphs[i].cgraph_main->nodes[bcj_src.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_tensor * node_dst = bcj_dst.cgraphs[i].cgraph_main->nodes[bcj_dst.cgraphs[i].cgraph_main->n_nodes - 1]; + ggml_backend_tensor_copy_async(bcj_src.backend, bcj_dst.backend, node_src, node_dst); + } + + return GGML_STATUS_SUCCESS; + }; + + + for (size_t i = 0; i < backend_ctx->n_subgraphs; i++) { + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + const ggml_status status = ggml_backend_graph_compute_async(bcj.backend, bcj.cgraphs[i].cgraph_main); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + + if (n_backends > 1 && i < backend_ctx->n_subgraphs - 1) { + bool backend_allreduce_success = false; + if (backend_ctx->comm_ctx) { + std::vector<ggml_tensor *> nodes; + nodes.reserve(n_backends); + for (size_t j = 0; j < n_backends; j++) { + auto & bcj = backend_ctx->backend_configs[j]; + ggml_cgraph * cgraph_ij = bcj.cgraphs[i].cgraph_main; + nodes.push_back(cgraph_ij->nodes[cgraph_ij->n_nodes-1]); + } + backend_allreduce_success = backend_ctx->comm_allreduce(backend_ctx->comm_ctx, nodes.data()); + } + + if (!backend_allreduce_success) { + const ggml_status status = allreduce_fallback(i); + if (status != GGML_STATUS_SUCCESS) { + return status; + } + } + } + } + return GGML_STATUS_SUCCESS; +} + +static const ggml_backend_i ggml_backend_meta_i = { + /* .get_name = */ ggml_backend_meta_get_name, + /* .free = */ ggml_backend_meta_free, + /* .set_tensor_async = */ ggml_backend_meta_set_tensor_async, + /* .get_tensor_async = */ ggml_backend_meta_get_tensor_async, + /* .set_tensor_2d_async = */ nullptr, + /* .get_tensor_2d_async = */ nullptr, + /* .cpy_tensor_async = */ nullptr, + /* .synchronize = */ ggml_backend_meta_synchronize, + /* .graph_plan_create = */ nullptr, + /* .graph_plan_free = */ nullptr, + /* .graph_plan_update = */ nullptr, + /* .graph_plan_compute = */ nullptr, + /* .graph_compute = */ ggml_backend_meta_graph_compute, + /* .event_record = */ nullptr, + /* .event_wait = */ nullptr, + /* .graph_optimize = */ nullptr, +}; + +bool ggml_backend_is_meta(ggml_backend_t backend) { + return backend != nullptr && backend->iface.get_name == ggml_backend_meta_i.get_name; +} + +static ggml_backend_t ggml_backend_meta_device_init_backend(ggml_backend_dev_t dev, const char * params) { + ggml_backend_meta_context * backend_ctx = new ggml_backend_meta_context(dev, params); + + ggml_backend_t backend = new struct ggml_backend; + backend->guid = ggml_backend_meta_guid(); + backend->iface = ggml_backend_meta_i; + backend->device = dev; + backend->context = backend_ctx; + return backend; +} + +size_t ggml_backend_meta_n_backends(ggml_backend_t meta_backend) { + GGML_ASSERT(ggml_backend_is_meta(meta_backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; + return backend_ctx->backend_configs.size(); +} + +ggml_backend_t ggml_backend_meta_simple_backend(ggml_backend_t meta_backend, size_t index) { + GGML_ASSERT(ggml_backend_is_meta(meta_backend)); + const ggml_backend_meta_context * backend_ctx = (const ggml_backend_meta_context *) meta_backend->context; + return backend_ctx->backend_configs[index].backend; +} diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 4181a714ad6..8165ae2c8bb 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -1,5 +1,6 @@ #include "ggml-backend-impl.h" #include "ggml-backend.h" +#include "ggml-backend-dl.h" #include "ggml-impl.h" #include <algorithm> #include <cstring> @@ -69,6 +70,10 @@ #include "ggml-rpc.h" #endif +#ifdef GGML_USE_VIRTGPU_FRONTEND +#include "ggml-virtgpu.h" +#endif + #ifdef GGML_USE_CANN #include "ggml-cann.h" #endif @@ -77,105 +82,27 @@ #include "ggml-zendnn.h" #endif -// disable C++17 deprecation warning for std::codecvt_utf8 -#if defined(__clang__) -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" +#ifdef GGML_USE_OPENVINO +#include "ggml-openvino.h" #endif namespace fs = std::filesystem; static std::string path_str(const fs::path & path) { - std::string u8path; try { #if defined(__cpp_lib_char8_t) // C++20 and later: u8string() returns std::u8string - std::u8string u8str = path.u8string(); - u8path = std::string(reinterpret_cast<const char*>(u8str.c_str())); + const std::u8string u8str = path.u8string(); + return std::string(reinterpret_cast<const char *>(u8str.data()), u8str.size()); #else // C++17: u8string() returns std::string - u8path = path.u8string(); + return path.u8string(); #endif } catch (...) { + return std::string(); } - return u8path; -} - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - -#ifdef _WIN32 - -using dl_handle = std::remove_pointer_t<HMODULE>; - -struct dl_handle_deleter { - void operator()(HMODULE handle) { - FreeLibrary(handle); - } -}; - -static dl_handle * dl_load_library(const fs::path & path) { - // suppress error dialogs for missing DLLs - DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); - SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); - - HMODULE handle = LoadLibraryW(path.wstring().c_str()); - - SetErrorMode(old_mode); - - return handle; -} - -static void * dl_get_sym(dl_handle * handle, const char * name) { - DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); - SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); - - void * p = (void *) GetProcAddress(handle, name); - - SetErrorMode(old_mode); - - return p; -} - -static const char * dl_error() { - return ""; -} - -#else - -using dl_handle = void; - -struct dl_handle_deleter { - void operator()(void * handle) { - dlclose(handle); - } -}; - -static void * dl_load_library(const fs::path & path) { - dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); - - return handle; } -static void * dl_get_sym(dl_handle * handle, const char * name) { - return dlsym(handle, name); -} - -static const char * dl_error() { - const char *rslt = dlerror(); - return rslt != nullptr ? rslt : ""; -} - -#endif - -using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>; - struct ggml_backend_reg_entry { ggml_backend_reg_t reg; dl_handle_ptr handle; @@ -196,7 +123,12 @@ struct ggml_backend_registry { register_backend(ggml_backend_sycl_reg()); #endif #ifdef GGML_USE_VULKAN + // Add runtime disable check + if (getenv("GGML_DISABLE_VULKAN") == nullptr) { register_backend(ggml_backend_vk_reg()); + } else { + GGML_LOG_DEBUG("Vulkan backend disabled by GGML_DISABLE_VULKAN environment variable\n"); + } #endif #ifdef GGML_USE_WEBGPU register_backend(ggml_backend_webgpu_reg()); @@ -204,6 +136,10 @@ struct ggml_backend_registry { #ifdef GGML_USE_ZDNN register_backend(ggml_backend_zdnn_reg()); #endif +#ifdef GGML_USE_VIRTGPU_FRONTEND + register_backend(ggml_backend_virtgpu_reg()); +#endif + #ifdef GGML_USE_OPENCL register_backend(ggml_backend_opencl_reg()); #endif @@ -222,6 +158,9 @@ struct ggml_backend_registry { #ifdef GGML_USE_RPC register_backend(ggml_backend_rpc_reg()); #endif +#ifdef GGML_USE_OPENVINO + register_backend(ggml_backend_openvino_reg()); +#endif #ifdef GGML_USE_CPU register_backend(ggml_backend_cpu_reg()); #endif @@ -242,6 +181,12 @@ struct ggml_backend_registry { return; } + for (auto & entry : backends) { + if (entry.reg == reg) { + return; + } + } + #ifndef NDEBUG GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n", __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg)); @@ -253,6 +198,12 @@ struct ggml_backend_registry { } void register_device(ggml_backend_dev_t device) { + for (auto & dev : devices) { + if (dev == device) { + return; + } + } + #ifndef NDEBUG GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device)); #endif @@ -539,9 +490,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, int best_score = 0; fs::path best_path; + std::error_code ec; for (const auto & search_path : search_paths) { - if (std::error_code ec; !fs::exists(search_path, ec)) { + if (!fs::exists(search_path, ec)) { if (ec) { GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str()); } else { @@ -551,7 +503,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, } fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); for (const auto & entry : dir_it) { - if (entry.is_regular_file()) { + if (entry.is_regular_file(ec)) { auto filename = entry.path().filename(); auto ext = entry.path().extension(); if (filename.native().find(file_prefix) == 0 && ext == file_extension) { @@ -620,9 +572,11 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("rpc", silent, dir_path); ggml_backend_load_best("sycl", silent, dir_path); ggml_backend_load_best("vulkan", silent, dir_path); + ggml_backend_load_best("virtgpu", silent, dir_path); ggml_backend_load_best("opencl", silent, dir_path); ggml_backend_load_best("hexagon", silent, dir_path); ggml_backend_load_best("musa", silent, dir_path); + ggml_backend_load_best("openvino", silent, dir_path); ggml_backend_load_best("cpu", silent, dir_path); // check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend const char * backend_path = std::getenv("GGML_BACKEND_PATH"); diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 1b59924b8cb..87615921c09 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -123,7 +123,7 @@ size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { GGML_ASSERT(buffer); // get_base is optional if the buffer is zero-sized - if (buffer->size == 0) { + if (!ggml_backend_buffer_is_meta(buffer) && buffer->size == 0) { return NULL; } @@ -258,6 +258,7 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); if (backend->iface.set_tensor_async == NULL) { + ggml_backend_synchronize(backend); ggml_backend_tensor_set(tensor, data, offset, size); } else { backend->iface.set_tensor_async(backend, tensor, data, offset, size); @@ -271,21 +272,64 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); if (backend->iface.get_tensor_async == NULL) { + ggml_backend_synchronize(backend); ggml_backend_tensor_get(tensor, data, offset, size); } else { backend->iface.get_tensor_async(backend, tensor, data, offset, size); } } +void ggml_backend_tensor_set_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + if (n_copies <= 1 || backend->iface.set_tensor_2d_async == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_set_async(backend, tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + backend->iface.set_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + +void ggml_backend_tensor_get_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(backend); + GGML_ASSERT(tensor); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + + if (n_copies <= 1 || backend->iface.get_tensor_2d_async == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_get_async(backend, tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + backend->iface.get_tensor_2d_async(backend, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); if (size == 0) { return; } - GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); @@ -295,18 +339,62 @@ void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, siz void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); if (size == 0) { return; } - GGML_ASSERT(buf != NULL && "tensor buffer not set"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); buf->iface.get_tensor(buf, tensor, data, offset, size); } +void ggml_backend_tensor_set_2d(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + + if (n_copies <= 1 || buf->iface.set_tensor_2d == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_set(tensor, (const char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + buf->iface.set_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + +void ggml_backend_tensor_get_2d(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size, + size_t n_copies, size_t stride_tensor, size_t stride_data) { + GGML_ASSERT(tensor); + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + + if (n_copies <= 1 || buf->iface.get_tensor_2d == NULL) { + for (size_t i = 0; i < n_copies; i++) { + ggml_backend_tensor_get(tensor, (char *) data + i*stride_data, offset + i*stride_tensor, size); + } + return; + } + if (size == 0) { + return; + } + + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + (n_copies-1)*stride_tensor + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); + + buf->iface.get_tensor_2d(buf, tensor, data, offset, size, n_copies, stride_tensor, stride_data); +} + void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { GGML_ASSERT(tensor); ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; @@ -386,7 +474,7 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { // backend copy -void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy(const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { @@ -400,7 +488,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } else if (!ggml_backend_buffer_copy_tensor(src, dst)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer)); -#endif +#endif // NDEBUG size_t nbytes = ggml_nbytes(src); void * data = malloc(nbytes); ggml_backend_tensor_get(src, data, 0, nbytes); @@ -409,7 +497,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst } } -void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, struct ggml_tensor * src, struct ggml_tensor * dst) { +void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts"); if (src == dst) { @@ -498,6 +586,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) { } void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) { + GGML_ASSERT(device); memset(props, 0, sizeof(*props)); device->iface.get_props(device, props); } @@ -608,6 +697,8 @@ static const struct ggml_backend_buffer_i ggml_backend_multi_buffer_i = { /* .memset_tensor = */ NULL, /* .set_tensor = */ NULL, /* .get_tensor = */ NULL, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_multi_buffer_clear, /* .reset = */ NULL, @@ -874,9 +965,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str } if (sched->debug > 1) { ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node); - GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name, + GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_desc(node), node->name, fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), - graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]); + graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0); for (int j = 0; j < GGML_MAX_SRC; j++) { struct ggml_tensor * src = node->src[j]; if (src == NULL) { @@ -939,6 +1030,8 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra GGML_ABORT("%s: failed to initialize context\n", __func__); } + graph->uid = ggml_graph_next_uid(); + // pass 1: assign backends to ops with pre-allocated inputs for (int i = 0; i < graph->n_leafs; i++) { struct ggml_tensor * leaf = graph->leafs[i]; @@ -1386,6 +1479,11 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra assert(graph_copy->size > graph_copy->n_leafs); graph_copy->leafs[graph_copy->n_leafs++] = leaf; } + + // set ids for all splits + for (int i = 0; i < sched->n_splits; ++i) { + sched->splits[i].graph.uid = ggml_graph_next_uid(); + } } static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) { @@ -1897,8 +1995,9 @@ enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct GGML_ASSERT(tensor->data == NULL); GGML_ASSERT(tensor->view_src == NULL); GGML_ASSERT(addr >= ggml_backend_buffer_get_base(buffer)); - GGML_ASSERT((char *)addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= - (char *)ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); + GGML_ASSERT(ggml_backend_buffer_is_meta(buffer) || + (char *) addr + ggml_backend_buffer_get_alloc_size(buffer, tensor) <= + (char *) ggml_backend_buffer_get_base(buffer) + ggml_backend_buffer_get_size(buffer)); tensor->buffer = buffer; tensor->data = addr; @@ -1922,6 +2021,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set, dst->view_offs = src->view_offs; } dst->op = src->op; + dst->flags = src->flags; memcpy(dst->op_params, src->op_params, sizeof(dst->op_params)); ggml_set_name(dst, src->name); @@ -2171,6 +2271,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_i = { /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, /* .clear = */ ggml_backend_cpu_buffer_clear, /* .reset = */ NULL, @@ -2183,6 +2285,8 @@ static const struct ggml_backend_buffer_i ggml_backend_cpu_buffer_from_ptr_i = { /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, /* .clear = */ ggml_backend_cpu_buffer_clear, /* .reset = */ NULL, diff --git a/ggml/src/ggml-blas/CMakeLists.txt b/ggml/src/ggml-blas/CMakeLists.txt index fb0936f47b7..c27dc174c00 100644 --- a/ggml/src/ggml-blas/CMakeLists.txt +++ b/ggml/src/ggml-blas/CMakeLists.txt @@ -93,7 +93,7 @@ if (BLAS_FOUND) endif() target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES}) - target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS}) + target_include_directories(ggml-blas SYSTEM PRIVATE ${BLAS_INCLUDE_DIRS}) else() message(FATAL_ERROR "BLAS not found, please refer to " "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 84956cbb9ce..b4c735267e0 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -121,6 +121,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg bli_thread_set_num_threads(ctx->n_threads); #elif defined(GGML_BLAS_USE_NVPL) nvpl_blas_set_num_threads(ctx->n_threads); +#elif defined(GGML_BLAS_USE_MKL) + mkl_set_num_threads(ctx->n_threads); #endif for (int64_t i13 = 0; i13 < ne13; i13++) { @@ -226,6 +228,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + switch (node->op) { case GGML_OP_MUL_MAT: ggml_backend_blas_mul_mat(ctx, node); @@ -257,6 +263,8 @@ static struct ggml_backend_i blas_backend_i = { /* .free = */ ggml_backend_blas_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, @@ -335,8 +343,8 @@ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t } static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // TODO - *free = 0; + // no memory to report + *free = 0; *total = 0; GGML_UNUSED(dev); diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp index 7b7042a1f54..e95d3c4d88d 100644 --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/acl_tensor.h b/ggml/src/ggml-cann/acl_tensor.h index 7deac383420..4737773a4d4 100644 --- a/ggml/src/ggml-cann/acl_tensor.h +++ b/ggml/src/ggml-cann/acl_tensor.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 6b718e01c31..2dc0f40917d 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to @@ -25,6 +25,7 @@ #include "ggml-impl.h" #include "ggml.h" + #include <aclnnop/aclnn_add.h> #include <aclnnop/aclnn_add_rms_norm.h> #include <aclnnop/aclnn_addcdiv.h> @@ -45,7 +46,9 @@ #include <aclnnop/aclnn_fused_infer_attention_score_v2.h> #include <aclnnop/aclnn_ger.h> #include <aclnnop/aclnn_group_norm.h> +#include <aclnnop/aclnn_gather_v2.h> #include <aclnnop/aclnn_grouped_matmul_v3.h> +#include <aclnnop/aclnn_scatter.h> #include <aclnnop/aclnn_gt_scalar.h> #include <aclnnop/aclnn_im2col.h> #include <aclnnop/aclnn_index_copy.h> @@ -58,9 +61,11 @@ #include <aclnnop/aclnn_mean.h> #include <aclnnop/aclnn_mm.h> #include <aclnnop/aclnn_mul.h> +#include <aclnnop/aclnn_mv.h> #include <aclnnop/aclnn_permute.h> #include <aclnnop/aclnn_pow.h> #include <aclnnop/aclnn_pow_tensor_tensor.h> +#include <aclnnop/aclnn_recurrent_gated_delta_rule.h> #include <aclnnop/aclnn_reduce_sum.h> #include <aclnnop/aclnn_reflection_pad1d.h> #include <aclnnop/aclnn_repeat.h> @@ -68,11 +73,15 @@ #include <aclnnop/aclnn_rms_norm.h> #include <aclnnop/aclnn_roll.h> #include <aclnnop/aclnn_softmax.h> +#include <aclnnop/aclnn_softmax_cross_entropy_with_logits.h> #include <aclnnop/aclnn_sub.h> #include <aclnnop/aclnn_sum.h> #include <aclnnop/aclnn_threshold.h> #include <aclnnop/aclnn_tril.h> +#include <aclnnop/aclnn_triangular_solve.h> #include <aclnnop/aclnn_triu.h> +#include <aclnnop/aclnn_logical_not.h> +#include <aclnnop/aclnn_masked_fill_scalar.h> #include <aclnnop/aclnn_upsample_nearest_2d.h> #include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h> #include <aclnnop/aclnn_zero.h> @@ -150,6 +159,107 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst.get(), acl_src1.get()); } +// Fused SwiGLU using aclnnSwiGlu: splits input along innermost dim, applies +// SiLU to left half, multiplies by right half. +// +// Falls back to the generic two-kernel path when src[1] != nullptr (two +// independent halves) or swapped != 0 (reversed activation order), as +// aclnnSwiGlu only handles the single interleaved tensor in standard order. +// +// CANN tiling for SwiGlu requires (storageShapeDim + viewDims) to be even. +// aclCreateTensor always uses storageShapeDim=1, so viewDims must be odd. +// We use a 3D view (1+3=4, even) to satisfy this constraint while preserving +// correct split semantics along the innermost (ne[0]) dimension. +void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + auto silu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, Silu, acl_src, acl_dst); + }; + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + if (dst->src[1] != nullptr || swapped != 0) { + ggml_cann_op_unary_gated(silu_fn, ctx, dst); + return; + } + + // aclnnSwiGlu requires the split dim (src->ne[0]) to be even; fall back otherwise. + if (dst->src[0]->ne[0] % 2 != 0) { + ggml_cann_op_unary_gated(silu_fn, ctx, dst); + return; + } + + ggml_tensor * src0 = dst->src[0]; + size_t elem_size = ggml_element_size(src0); + + // src0 GGML: [2*ne0, ne1, ne2, ne3] → 3D view [2*ne0, ne1, ne2*ne3] + // CANN reversed: [ne2*ne3, ne1, 2*ne0], split along CANN dim 2 (last). + int64_t ne0_x2 = src0->ne[0]; + int64_t ne1 = src0->ne[1]; + int64_t ne23 = src0->ne[2] * src0->ne[3]; + int64_t src3d_ne[] = { ne0_x2, ne1, ne23 }; + size_t src3d_nb[] = { (size_t)src0->nb[0], (size_t)src0->nb[1], (size_t)src0->nb[2] }; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type), + elem_size, src3d_ne, src3d_nb, 3); + + // dst GGML: [ne0, ne1, ne2, ne3] → 3D view [ne0, ne1, ne2*ne3] + int64_t ne0 = dst->ne[0]; + int64_t dst3d_ne[] = { ne0, ne1, ne23 }; + size_t dst3d_nb[] = { (size_t)dst->nb[0], (size_t)dst->nb[1], (size_t)dst->nb[2] }; + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + elem_size, dst3d_ne, dst3d_nb, 3); + + // CANN tensor [ne23, ne1, 2*ne0]: split along CANN dim 2 (last) = 2*ne0. + GGML_CANN_CALL_ACLNN_OP(ctx, SwiGlu, acl_src.get(), (int64_t)2, acl_dst.get()); +} + +// Fused GeGLU using aclnnGeGluV3: splits input along ne[0] (CANN last dim), +// activates the LEFT half with GELU, multiplies by right half. +// approximate: 0=tanh, 1=none(erf). activateLeft=true matches GGML convention. +// outGelu is a required-but-discard output buffer. +// +// Falls back to the generic two-kernel path when src[1] != nullptr (two +// independent halves) or swapped != 0 (reversed activation order), as +// aclnnGeGluV3 only handles the single interleaved tensor in standard order. +void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate) { + auto gelu_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, Gelu, acl_src, acl_dst); + }; + + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + if (dst->src[1] != nullptr || swapped != 0) { + ggml_cann_op_unary_gated(gelu_fn, ctx, dst); + return; + } + + // aclnnGeGluV3 requires the split dim (src->ne[0]) to be even; fall back otherwise. + if (dst->src[0]->ne[0] % 2 != 0) { + ggml_cann_op_unary_gated(gelu_fn, ctx, dst); + return; + } + + ggml_tensor * src0 = dst->src[0]; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + + // Allocate a temporary buffer for the required outGelu output (same shape as dst). + // Build contiguous strides since the pool allocation is a fresh buffer. + size_t elem_size = ggml_element_size(dst); + int64_t ne[GGML_MAX_DIMS] = { dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3] }; + size_t nb[GGML_MAX_DIMS]; + nb[0] = elem_size; + for (int i = 1; i < GGML_MAX_DIMS; i++) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + size_t gelu_out_size = nb[GGML_MAX_DIMS - 1] * ne[GGML_MAX_DIMS - 1]; + ggml_cann_pool_alloc gelu_out_alloc(ctx.pool(), gelu_out_size); + + acl_tensor_ptr acl_gelu_out = ggml_cann_create_tensor( + gelu_out_alloc.get(), ggml_cann_type_mapping(dst->type), elem_size, ne, nb, GGML_MAX_DIMS); + // V3 adds activateLeft param; true → Gelu(left)*right, matching GGML convention. + // GGML dim 0 → CANN last dim (index GGML_MAX_DIMS-1 = 3 for 4D tensor). + GGML_CANN_CALL_ACLNN_OP(ctx, GeGluV3, acl_src.get(), (int64_t)(GGML_MAX_DIMS - 1), approximate, true, + acl_dst.get(), acl_gelu_out.get()); +} + /** * @brief Repeats elements of a tensor along each dimension according to the * specified repeat array. @@ -433,6 +543,9 @@ void ggml_cann_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src = dst->src[0]; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src); acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); @@ -441,21 +554,33 @@ void ggml_cann_l2_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes); void * buffer = temp_buffer_allocator.get(); - int64_t div_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] }; - size_t div_nb[GGML_MAX_DIMS]; - div_nb[0] = sizeof(float); + int64_t norm_ne[] = { 1, src->ne[1], src->ne[2], src->ne[3] }; + size_t norm_nb[GGML_MAX_DIMS]; + norm_nb[0] = sizeof(float); for (int i = 1; i < GGML_MAX_DIMS; ++i) { - div_nb[i] = div_nb[i - 1] * div_ne[i - 1]; + norm_nb[i] = norm_nb[i - 1] * norm_ne[i - 1]; } - acl_tensor_ptr acl_div = ggml_cann_create_tensor(buffer, ACL_FLOAT, type_size, div_ne, div_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_norm = ggml_cann_create_tensor(buffer, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS); std::vector<int64_t> norm_dims = { 3 }; acl_int_array_ptr dims_array = ggml_cann_create_int_array(norm_dims.data(), norm_dims.size()); float p_value = 2.0f; acl_scalar_ptr p_scalar = ggml_cann_create_scalar(&p_value, aclDataType::ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_div.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div.get(), acl_dst.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, Norm, acl_src.get(), p_scalar.get(), dims_array.get(), true, acl_norm.get()); + + ggml_cann_pool_alloc clamp_buffer_allocator(ctx.pool()); + acl_tensor_ptr acl_clamped; + + if (eps > 0.0f) { + void * clamp_buf = clamp_buffer_allocator.alloc(n_bytes); + acl_clamped = ggml_cann_create_tensor(clamp_buf, ACL_FLOAT, sizeof(float), norm_ne, norm_nb, GGML_MAX_DIMS); + acl_scalar_ptr eps_scalar = ggml_cann_create_scalar(&eps, aclDataType::ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, ClampMin, acl_norm.get(), eps_scalar.get(), acl_clamped.get()); + } + + aclTensor * acl_div_input = acl_clamped ? acl_clamped.get() : acl_norm.get(); + GGML_CANN_CALL_ACLNN_OP(ctx, Div, acl_src.get(), acl_div_input, acl_dst.get()); } void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * dst) { @@ -471,56 +596,30 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * logits_nb[1] = logits_nb[0] * logits_ne[0]; acl_tensor_ptr acl_logits = ggml_cann_create_tensor(src0->data, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2); - size_t log_softmax_type_size = sizeof(float); - int64_t log_softmax_n_bytes = nr * nc * log_softmax_type_size; - ggml_cann_pool_alloc log_softmax_allocator(ctx.pool(), log_softmax_n_bytes); - void * log_softmax_buffer = log_softmax_allocator.get(); - - int64_t log_softmax_ne[] = { nc, nr }; - size_t log_softmax_nb[2]; - log_softmax_nb[0] = log_softmax_type_size; - log_softmax_nb[1] = log_softmax_nb[0] * log_softmax_ne[0]; - acl_tensor_ptr acl_log_softmax = ggml_cann_create_tensor(log_softmax_buffer, ACL_FLOAT, log_softmax_type_size, - log_softmax_ne, log_softmax_nb, 2); - - GGML_CANN_CALL_ACLNN_OP(ctx, LogSoftmax, acl_logits.get(), 1, acl_log_softmax.get()); - int64_t labels_ne[] = { nc, nr }; size_t labels_nb[2]; labels_nb[0] = ggml_type_size(src1->type); labels_nb[1] = labels_nb[0] * labels_ne[0]; acl_tensor_ptr acl_labels = ggml_cann_create_tensor(src1->data, ACL_FLOAT, sizeof(float), labels_ne, labels_nb, 2); - size_t mul_type_size = sizeof(float); - int64_t mul_n_bytes = nr * nc * mul_type_size; - ggml_cann_pool_alloc mul_allocator(ctx.pool(), mul_n_bytes); - void * mul_buffer = mul_allocator.get(); + size_t loss_per_sample_type_size = sizeof(float); + int64_t loss_per_sample_n_bytes = nr * loss_per_sample_type_size; + ggml_cann_pool_alloc loss_per_sample_allocator(ctx.pool(), loss_per_sample_n_bytes); + void * loss_per_sample_buffer = loss_per_sample_allocator.get(); - int64_t mul_ne[] = { nc, nr }; - size_t mul_nb[2]; - mul_nb[0] = mul_type_size; - mul_nb[1] = mul_nb[0] * mul_ne[0]; - acl_tensor_ptr acl_mul_result = ggml_cann_create_tensor(mul_buffer, ACL_FLOAT, mul_type_size, mul_ne, mul_nb, 2); + int64_t loss_per_sample_ne[] = { nr }; + size_t loss_per_sample_nb[1]; + loss_per_sample_nb[0] = loss_per_sample_type_size; + acl_tensor_ptr acl_loss_per_sample = ggml_cann_create_tensor( + loss_per_sample_buffer, ACL_FLOAT, loss_per_sample_type_size, loss_per_sample_ne, loss_per_sample_nb, 1); - GGML_CANN_CALL_ACLNN_OP(ctx, Mul, acl_log_softmax.get(), acl_labels.get(), acl_mul_result.get()); + size_t backprop_n_bytes = nr * nc * sizeof(float); + ggml_cann_pool_alloc backprop_allocator(ctx.pool(), backprop_n_bytes); + void * backprop_buffer = backprop_allocator.get(); + acl_tensor_ptr acl_backprop = ggml_cann_create_tensor(backprop_buffer, ACL_FLOAT, sizeof(float), logits_ne, logits_nb, 2); - size_t sum_per_sample_type_size = sizeof(float); - int64_t sum_per_sample_n_bytes = nr * sum_per_sample_type_size; - ggml_cann_pool_alloc sum_per_sample_allocator(ctx.pool(), sum_per_sample_n_bytes); - void * sum_per_sample_buffer = sum_per_sample_allocator.get(); - - int64_t sum_per_sample_ne[] = { nr }; - size_t sum_per_sample_nb[1]; - sum_per_sample_nb[0] = sum_per_sample_type_size; - acl_tensor_ptr acl_sum_per_sample = ggml_cann_create_tensor( - sum_per_sample_buffer, ACL_FLOAT, sum_per_sample_type_size, sum_per_sample_ne, sum_per_sample_nb, 1); - - std::vector<int64_t> sum_dims = { 1 }; - acl_int_array_ptr dims_array = ggml_cann_create_int_array(sum_dims.data(), sum_dims.size()); - bool keep_dims = false; - - GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_mul_result.get(), dims_array.get(), keep_dims, ACL_FLOAT, - acl_sum_per_sample.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, SoftmaxCrossEntropyWithLogits, acl_logits.get(), acl_labels.get(), + acl_loss_per_sample.get(), acl_backprop.get()); size_t total_sum_type_size = sizeof(float); int64_t total_sum_n_bytes = 1 * total_sum_type_size; @@ -536,11 +635,12 @@ void ggml_cann_cross_entropy_loss(ggml_backend_cann_context & ctx, ggml_tensor * std::vector<int64_t> total_sum_dims = { 0 }; acl_int_array_ptr total_sum_dims_array = ggml_cann_create_int_array(total_sum_dims.data(), total_sum_dims.size()); + bool keep_dims = false; - GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_sum_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT, + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_loss_per_sample.get(), total_sum_dims_array.get(), keep_dims, ACL_FLOAT, acl_total_sum.get()); - float value = -1.0f / static_cast<float>(nr); + float value = 1.0f / static_cast<float>(nr); acl_scalar_ptr scale_factor = ggml_cann_create_scalar(&value, aclDataType::ACL_FLOAT); acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, sizeof(float), total_sum_ne, total_sum_nb, 1); @@ -578,6 +678,33 @@ void ggml_cann_group_norm(ggml_backend_cann_context & ctx, ggml_tensor * dst) { acl_mean_out.get(), acl_rstd_out.get()); } +void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + + size_t nb1 = ((int32_t *) dst->op_params)[0]; + size_t nb2 = ((int32_t *) dst->op_params)[1]; + size_t nb3 = ((int32_t *) dst->op_params)[2]; + size_t offset = ((int32_t *) dst->op_params)[3]; + bool inplace = (bool) ((int32_t *) dst->op_params)[4]; + + size_t param_nb[] = { ggml_element_size(src0), nb1, nb2, nb3 }; + + // Create a view of dst at the target offset with src1's dimensions + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset); + acl_tensor_ptr acl_src1 = ggml_cann_create_tensor(src1); + + if (!inplace) { + // First copy src0 to dst entirely + size_t cpy_size = ggml_nbytes(dst); + ACL_CHECK( + aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size, ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + } + + // Copy src1 into the target region of dst + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst.get(), acl_src1.get()); +} + void ggml_cann_acc(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; @@ -641,6 +768,113 @@ void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { aclnn_reduce_sum(ctx, dst, reduce_dims, 4); } +void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + // GGML cumsum operates along dim 0 (innermost / ne[0]). + // ggml_cann_create_tensor reverses dimensions to [ne3,ne2,ne1,ne0], + // so GGML dim 0 maps to CANN dim 3 (the last dim of the 4-D tensor). + GGML_CANN_CALL_ACLNN_OP(ctx, Cumsum, acl_src.get(), (int64_t)3, + ggml_cann_type_mapping(dst->type), acl_dst.get()); +} + +void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; // A: [N, N, B2, B3] lower triangular + ggml_tensor * src1 = dst->src[1]; // B: [K, N, B2, B3] + + acl_tensor_ptr acl_a = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_b = ggml_cann_create_tensor(src1); + acl_tensor_ptr acl_x = ggml_cann_create_tensor(dst); + + // mOut: triangular copy of A (required output), same shape as A. + const size_t a_bytes = ggml_nbytes(src0); + ggml_cann_pool_alloc m_alloc(ctx.pool(), a_bytes); + acl_tensor_ptr acl_m = ggml_cann_create_tensor( + m_alloc.get(), ggml_cann_type_mapping(src0->type), + ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS); + + // Solve AX = B: upper=false (lower tri), transpose=false, unitriangular=false. + GGML_CANN_CALL_ACLNN_OP(ctx, TriangularSolve, + acl_b.get(), acl_a.get(), false, false, false, + acl_x.get(), acl_m.get()); +} + +void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + + GGML_ASSERT(src->ne[1] == 1); + + const int64_t N = src->ne[0]; + const int64_t n_batch = src->ne[2] * src->ne[3]; + const size_t nb_f32 = sizeof(float); + + // Fill dst with zeros. + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + { + float zero = 0.0f; + acl_scalar_ptr acl_zero = ggml_cann_create_scalar(&zero, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_zero.get()); + } + + // Copy src vector onto the diagonal of dst via strided views. + // src viewed as [N, n_batch], contiguous strides. + int64_t ne_vec[2] = { N, n_batch }; + size_t nb_src_vec[2] = { nb_f32, N * nb_f32 }; + // dst diagonal view: stride (N+1)*4 steps along the diagonal. + size_t nb_dst_diag[2] = { (N + 1) * nb_f32, N * N * nb_f32 }; + + acl_tensor_ptr acl_src_vec = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne_vec, nb_src_vec, 2); + acl_tensor_ptr acl_dst_diag = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne_vec, nb_dst_diag, 2); + + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceCopy, acl_dst_diag.get(), acl_src_vec.get()); +} + +void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + float c = ggml_get_op_params_f32(dst, 0); + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + acl_scalar_ptr acl_c = ggml_cann_create_scalar(&c, ACL_FLOAT); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceFillScalar, acl_dst.get(), acl_c.get()); +} + +void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src = dst->src[0]; + + const int64_t S = src->ne[0]; + const int64_t n_batch = src->ne[2] * src->ne[3]; + const size_t nb_f32 = sizeof(float); + + int64_t ne3d[3] = { S, S, n_batch }; + size_t nb3d[3] = { nb_f32, S * nb_f32, S * S * nb_f32 }; + + const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0); + + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ACL_FLOAT, nb_f32, ne3d, nb3d, 3); + + switch (ttype) { + case GGML_TRI_TYPE_LOWER: + // Tril(-1): preserve row > col (strict lower), zero upper + diagonal. + GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)-1, acl_dst.get()); + break; + case GGML_TRI_TYPE_UPPER_DIAG: + // Triu(0): preserve row <= col (upper + diagonal), zero strict lower. + GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)0, acl_dst.get()); + break; + case GGML_TRI_TYPE_UPPER: + // Triu(1): preserve row < col (strict upper), zero lower + diagonal. + GGML_CANN_CALL_ACLNN_OP(ctx, Triu, acl_src.get(), (int64_t)1, acl_dst.get()); + break; + case GGML_TRI_TYPE_LOWER_DIAG: + // Tril(0): preserve row >= col (lower + diagonal), zero strict upper. + GGML_CANN_CALL_ACLNN_OP(ctx, Tril, acl_src.get(), (int64_t)0, acl_dst.get()); + break; + default: + GGML_ABORT("unsupported tri type"); + } +} + void ggml_cann_upsample_nearest2d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src = dst->src[0]; acl_tensor_ptr acl_src = ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW); @@ -1543,8 +1777,8 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, end = 2 * ((n_head - 1) - n_head_log2) + 1; step = 2; count = n_head - n_head_log2; - aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), m1, count, start, end + 1, step, - dtype); + aclnn_get_slope_inner(ctx, (char *) slope_buffer + n_head_log2 * ggml_type_size(dtype), m1, count, start, end + 1, + step, dtype); } } @@ -1684,150 +1918,90 @@ void ggml_cann_softmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { aclnn_softmax(ctx, softmax_tensor.get(), 3, acl_dst.get()); } -/** - * @brief Performs index select operation on a 4D tensor using the CANN backend. - * - * This function applies the `IndexSelect` operation along a specific dimension - * of the source tensor (`src_buffer`) using the indices from the index tensor (`index`). - * It iterates over the last two dimensions of the source tensor, creates the corresponding - * CANN tensors for the source, index, and output slices, and executes the `IndexSelect` - * operation for each slice. - * - * @param ctx The context for CANN backend operations. - * @param src_buffer The source buffer containing the 4D input tensor data. - * @param src_ne The dimensions of the source tensor. - * @param src_nb The strides (byte offsets) of the source tensor. - * @param dst_buffer The destination buffer where the output tensor data will be written. - * @param dst_ne The dimensions of the destination tensor. - * @param dst_nb The strides (byte offsets) of the destination tensor. - * @param index The index tensor specifying the indices to select from the source tensor. - * @param type The data type of the source and destination tensors. - */ -static void aclnn_index_select_4d(ggml_backend_cann_context & ctx, - void * src_buffer, - int64_t * src_ne, - size_t * src_nb, - void * dst_buffer, - int64_t * dst_ne, - size_t * dst_nb, - ggml_tensor * index, - ggml_type type) { - for (int64_t i = 0; i < src_ne[3]; i++) { - for (int64_t j = 0; j < src_ne[2]; j++) { - // src - acl_tensor_ptr acl_src_tensor = - ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2); - - // index - acl_tensor_ptr acl_index = ggml_cann_create_tensor( - (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], - ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1); - - // out - acl_tensor_ptr acl_out = - ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, acl_src_tensor.get(), 0, acl_index.get(), acl_out.get()); - } - } -} - -/** - * @brief Performs inplace index copy operation on a 4D tensor using the CANN backend. - * - * This function applies the `IndexCopy` operation along a specific dimension of the - * destination tensor (`dst_buffer`) by copying elements from the source tensor (`src_buffer`) - * to positions specified by the index tensor (`index`). - * It iterates over the last two dimensions of the tensors, creates the corresponding - * CANN tensors for source, index, and destination slices, and performs the index copy - * operation for each slice. - * - * @param ctx The context for CANN backend operations. - * @param src_buffer The source buffer containing the 4D input tensor data to be copied. - * @param src_ne The dimensions of the source tensor. - * @param src_nb The strides (byte offsets) of the source tensor. - * @param dst_buffer The destination buffer where values will be copied to. - * @param dst_ne The dimensions of the destination tensor. - * @param dst_nb The strides (byte offsets) of the destination tensor. - * @param index The index tensor specifying target positions in the destination tensor. - * @param type The data type of the source and destination tensors. - */ -static void aclnn_index_copy_4d(ggml_backend_cann_context & ctx, - void * src_buffer, - int64_t * src_ne, - size_t * src_nb, - void * dst_buffer, - int64_t * dst_ne, - size_t * dst_nb, - ggml_tensor * index, - ggml_type type) { - for (int64_t i = 0; i < src_ne[3]; i++) { - for (int64_t j = 0; j < src_ne[2]; j++) { - // src - acl_tensor_ptr acl_src_tensor = - ggml_cann_create_tensor((char *) src_buffer + i * src_nb[3] + j * src_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), src_ne, src_nb, 2); - - // index - acl_tensor_ptr acl_index = ggml_cann_create_tensor( - (char *) index->data + (i % index->ne[2]) * index->nb[2] + (j % index->ne[1]) * index->nb[1], - ggml_cann_type_mapping(index->type), ggml_element_size(index), index->ne, index->nb, 1); - - // out - acl_tensor_ptr acl_out = - ggml_cann_create_tensor((char *) dst_buffer + i * dst_nb[3] + j * dst_nb[2], - ggml_cann_type_mapping(type), ggml_type_size(type), dst_ne, dst_nb, 2); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_out.get(), 0, acl_index.get(), acl_src_tensor.get()); - } - } -} void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // src + ggml_tensor * src0 = dst->src[0]; // weight ggml_tensor * src1 = dst->src[1]; // index - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 + || dst->type == GGML_TYPE_BF16); + + // n_idx: number of row indices per (i2, i3) batch slice. + // ggml guarantees: src0->ne[2] == src1->ne[1], src0->ne[3] == src1->ne[2], src1->ne[3] == 1. + const int64_t n_idx = src1->ne[0]; + + // Gather all (i2, i3) batch slices from src into dst. + // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0]. + // GatherV2 with dim=0 gathers along ACL dim-0 == ggml ne[1] (the vocabulary / row axis). + // nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape, + // nb[2..3] for computing per-batch-slice base pointer offsets). + auto gather_batched = [&](void * src_base, aclDataType acl_type, size_t type_size, + const size_t * nb) { + int64_t src_ne[2] = { src0->ne[0], src0->ne[1] }; + size_t src_nb_2d[2] = { nb[0], nb[1] }; + int64_t dst_ne[2] = { src0->ne[0], n_idx }; + size_t dst_nb_2d[2] = { dst->nb[0], dst->nb[1] }; + int64_t idx_ne[1] = { n_idx }; + size_t idx_nb[1] = { (size_t)ggml_element_size(src1) }; + + for (int64_t i3 = 0; i3 < src0->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < src0->ne[2]; i2++) { + acl_tensor_ptr acl_src = ggml_cann_create_tensor( + (char *)src_base + i3 * nb[3] + i2 * nb[2], + acl_type, type_size, src_ne, src_nb_2d, 2); + acl_tensor_ptr acl_idx = ggml_cann_create_tensor( + (char *)src1->data + i3 * src1->nb[2] + i2 * src1->nb[1], + ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1), + idx_ne, idx_nb, 1); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor( + (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2], + acl_type, type_size, dst_ne, dst_nb_2d, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, GatherV2, acl_src.get(), 0, acl_idx.get(), acl_dst.get()); + } + } + }; switch (src0->type) { + case GGML_TYPE_BF16: case GGML_TYPE_F16: case GGML_TYPE_F32: if (src0->type == dst->type) { - aclnn_index_select_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + gather_batched(src0->data, + ggml_cann_type_mapping(src0->type), ggml_type_size(src0->type), + src0->nb); } else { - acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * ggml_element_size(dst)); - void * src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = dst->nb[0]; + // Cast src0 to dst type, then gather. + ggml_cann_pool_alloc src_cast_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_element_size(dst)); + size_t src_cast_nb[GGML_MAX_DIMS]; + src_cast_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1]; } - acl_tensor_ptr src_trans_tensor = - ggml_cann_create_tensor(src_trans_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); - aclnn_index_select_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor( + src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->ne, src_cast_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type)); + + gather_batched(src_cast_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src_cast_nb); } break; case GGML_TYPE_Q8_0: { - // add 1 dim for bcast mul. + // Dequantize Q8_0 to dst type, then gather. size_t weight_nb[GGML_MAX_DIMS + 1], scale_nb[GGML_MAX_DIMS + 1], dequant_nb[GGML_MAX_DIMS + 1]; int64_t weight_ne[GGML_MAX_DIMS + 1], scale_ne[GGML_MAX_DIMS + 1], *dequant_ne; - int64_t scale_offset = 0; - // [3,4,5,64] -> [3,4,5,2,32] - weight_ne[0] = QK8_0; - weight_ne[1] = src0->ne[0] / QK8_0; - weight_nb[0] = sizeof(int8_t); - weight_nb[1] = weight_nb[0] * weight_ne[0]; + weight_ne[0] = QK8_0; + weight_ne[1] = src0->ne[0] / QK8_0; + weight_nb[0] = sizeof(int8_t); + weight_nb[1] = weight_nb[0] * weight_ne[0]; for (int i = 2; i < GGML_MAX_DIMS + 1; i++) { weight_ne[i] = src0->ne[i - 1]; weight_nb[i] = weight_nb[i - 1] * weight_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,1] scale_ne[0] = 1; scale_ne[1] = src0->ne[0] / QK8_0; scale_nb[0] = sizeof(uint16_t); @@ -1836,31 +2010,33 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { scale_ne[i] = src0->ne[i - 1]; scale_nb[i] = scale_nb[i - 1] * scale_ne[i - 1]; } - // [3,4,5,64] -> [3,4,5,2,32] dequant_ne = weight_ne; dequant_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; } - scale_offset = ggml_nelements(src0) * sizeof(int8_t); - ggml_cann_pool_alloc dequant_buffer_allocator(ctx.pool(), - ggml_nelements(src0) * ggml_type_size(dst->type)); - acl_tensor_ptr acl_weight_tensor = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), - weight_ne, weight_nb, GGML_MAX_DIMS + 1); - acl_tensor_ptr acl_scale_tensor = - ggml_cann_create_tensor(src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, - GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); - acl_tensor_ptr dequant_tensor = - ggml_cann_create_tensor(dequant_buffer_allocator.get(), ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); - aclnn_mul(ctx, acl_weight_tensor.get(), acl_scale_tensor.get(), dequant_tensor.get()); - dequant_nb[0] = ggml_type_size(dst->type); + const int64_t scale_offset = ggml_nelements(src0) * sizeof(int8_t); + ggml_cann_pool_alloc dequant_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_type_size(dst->type)); + acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), + weight_ne, weight_nb, GGML_MAX_DIMS + 1); + acl_tensor_ptr acl_scale = ggml_cann_create_tensor( + src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, + GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); + acl_tensor_ptr acl_dequant = ggml_cann_create_tensor( + dequant_allocator.get(), ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); + aclnn_mul(ctx, acl_weight.get(), acl_scale.get(), acl_dequant.get()); + + // Reinterpret dequant buffer as 4D [src0->ne] with contiguous strides. dequant_ne = src0->ne; + dequant_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; } - aclnn_index_select_4d(ctx, dequant_buffer_allocator.get(), dequant_ne, dequant_nb, dst->data, dst->ne, - dst->nb, src1, dst->type); + gather_batched(dequant_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + dequant_nb); break; } default: @@ -1870,30 +2046,70 @@ void ggml_cann_get_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { } void ggml_cann_set_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // src - ggml_tensor * src1 = dst->src[1]; // index + ggml_tensor * src0 = dst->src[0]; // source values + ggml_tensor * src1 = dst->src[1]; // row indices + + // n_idx: number of source rows to scatter per batch slice. + // ggml guarantees: src0->ne[1] == src1->ne[0]. + const int64_t n_idx = src1->ne[0]; + + // Copy n_idx rows of src [ne0, n_idx] into dst [ne0, ne1] at positions given by a 1D index. + // ggml_cann_create_tensor reverses dims, so ACL sees [ne1, ne0] for dst. + // InplaceIndexCopy with dim=0 copies along ACL dim-0 == ggml ne[1] (the row axis). + // src_nb: the 4 strides of the source buffer (nb[0..1] for the 2D slice shape, + // nb[2..3] for computing per-batch-slice base pointer offsets). + auto scatter_batched = [&](void * src_base, aclDataType acl_type, size_t type_size, + const size_t * src_nb) { + int64_t d_ne[2] = { dst->ne[0], dst->ne[1] }; + size_t d_nb[2] = { dst->nb[0], dst->nb[1] }; + int64_t s_ne[2] = { dst->ne[0], n_idx }; + size_t s_nb_2d[2] = { src_nb[0], src_nb[1] }; + int64_t i_ne[1] = { n_idx }; + size_t i_nb[1] = { (size_t)ggml_element_size(src1) }; + + for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) { + for (int64_t i2 = 0; i2 < dst->ne[2]; i2++) { + acl_tensor_ptr acl_dst = ggml_cann_create_tensor( + (char *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2], + acl_type, type_size, d_ne, d_nb, 2); + acl_tensor_ptr acl_idx = ggml_cann_create_tensor( + (char *)src1->data + (i3 % src1->ne[2]) * src1->nb[2] + (i2 % src1->ne[1]) * src1->nb[1], + ggml_cann_type_mapping(src1->type), (size_t)ggml_element_size(src1), + i_ne, i_nb, 1); + acl_tensor_ptr acl_src = ggml_cann_create_tensor( + (char *)src_base + i3 * src_nb[3] + i2 * src_nb[2], + acl_type, type_size, s_ne, s_nb_2d, 2); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceIndexCopy, acl_dst.get(), 0, acl_idx.get(), acl_src.get()); + } + } + }; switch (dst->type) { case GGML_TYPE_F32: - { - aclnn_index_copy_4d(ctx, src0->data, src0->ne, src0->nb, dst->data, dst->ne, dst->nb, src1, dst->type); - break; - } + scatter_batched(src0->data, + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->nb); + break; case GGML_TYPE_F16: + case GGML_TYPE_BF16: { - acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); - ggml_cann_pool_alloc src_buffer_allocator(ctx.pool(), ggml_nelements(src0) * sizeof(uint16_t)); - void * src_trans_buffer = src_buffer_allocator.get(); - size_t src_trans_nb[GGML_MAX_DIMS]; - src_trans_nb[0] = sizeof(uint16_t); + // Cast src0 (F32) to dst type first. + ggml_cann_pool_alloc src_cast_allocator(ctx.pool(), + ggml_nelements(src0) * ggml_type_size(dst->type)); + size_t src_cast_nb[GGML_MAX_DIMS]; + src_cast_nb[0] = ggml_type_size(dst->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; + src_cast_nb[i] = src_cast_nb[i - 1] * src0->ne[i - 1]; } - acl_tensor_ptr src_trans_tensor = ggml_cann_create_tensor( - src_trans_buffer, ACL_FLOAT16, ggml_type_size(dst->type), src0->ne, src_trans_nb, GGML_MAX_DIMS); - aclnn_cast(ctx, acl_src0.get(), src_trans_tensor.get(), ggml_cann_type_mapping(dst->type)); - aclnn_index_copy_4d(ctx, src_trans_buffer, src0->ne, src_trans_nb, dst->data, dst->ne, dst->nb, src1, - dst->type); + acl_tensor_ptr acl_src0 = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_src_cast = ggml_cann_create_tensor( + src_cast_allocator.get(), ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src0->ne, src_cast_nb, GGML_MAX_DIMS); + aclnn_cast(ctx, acl_src0.get(), acl_src_cast.get(), ggml_cann_type_mapping(dst->type)); + + scatter_batched(src_cast_allocator.get(), + ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), + src_cast_nb); break; } default: @@ -1964,7 +2180,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor * // Only check env once. static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); - if (weight_to_nz && is_matmul_weight(weight)) { + if (weight_to_nz && weight->type != GGML_TYPE_BF16 && is_matmul_weight(weight)) { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); } else { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND); @@ -2145,6 +2361,9 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst) { switch (type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif ggml_cann_mat_mul_fp(ctx, dst); break; case GGML_TYPE_Q4_0: @@ -2338,20 +2557,21 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, // Step1.2: prepare rope_yarn_ramp, if this part updated, should update theta_scale_tensor. // TODO: acl_yarn_ramp_tensor use rope cache. - bool yarn_ramp_tensor_updated = false; - acl_tensor_ptr acl_yarn_ramp_tensor; + bool yarn_ramp_tensor_updated = false; + acl_tensor_ptr acl_yarn_ramp_tensor; if (ext_factor != 0 && (theta_scale_updated || ctx.rope_cache.theta_scale_length != theta_scale_length || ctx.rope_cache.freq_scale != freq_scale)) { yarn_ramp_tensor_updated = true; if (ctx.rope_cache.yarn_ramp_cache != nullptr) { ACL_CHECK(aclrtFree(ctx.rope_cache.yarn_ramp_cache)); } - ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST)); + ACL_CHECK(aclrtMalloc(&ctx.rope_cache.yarn_ramp_cache, theta_scale_length * sizeof(float), + ACL_MEM_MALLOC_HUGE_FIRST)); // -rope_yarn_ramp // const float y = (i0 / 2 - low) / MAX(0.001f, high - low); // return MIN(1, MAX(0, y)) - 1; - acl_yarn_ramp_tensor = - ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, 1); float zero_value = 0, one_value = 1; float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); acl_scalar_ptr low = ggml_cann_create_scalar(&corr_dims[0], aclDataType::ACL_FLOAT); @@ -2382,8 +2602,8 @@ static void aclnn_rope_cache_init(ggml_backend_cann_context & ctx, GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor.get(), freq_scale_1_sc.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor.get(), freq_scale_sc.get(), one.get()); } else { - acl_yarn_ramp_tensor = - ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), theta_scale_ne, theta_scale_nb, 1); + acl_yarn_ramp_tensor = ggml_cann_create_tensor(ctx.rope_cache.yarn_ramp_cache, ACL_FLOAT, sizeof(float), + theta_scale_ne, theta_scale_nb, 1); } // Step 1.3: update theta_scale_tensor according to ext_factor or freq_scale. if (ext_factor != 0) { @@ -2941,6 +3161,27 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // Rotate full tensor (no tail), using trans tensors GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_trans_tensor.get(), acl_cos_reshape_tensor.get(), acl_sin_reshape_tensor.get(), acl_mode, acl_dst_trans_tensor.get()); + } else if (src0->data == dst->data && !ggml_is_contiguous(src0)) { + // In-place on non-contiguous tensor: RotaryPositionEmbedding cannot safely + // read and write the same non-contiguous buffer. Use contiguous temporaries. + size_t contiguous_nb[GGML_MAX_DIMS]; + contiguous_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + contiguous_nb[i] = contiguous_nb[i - 1] * src0->ne[i - 1]; + } + int64_t total_elements = ggml_nelements(src0); + ggml_cann_pool_alloc inplace_src_alloc(ctx.pool(), total_elements * sizeof(float)); + ggml_cann_pool_alloc inplace_dst_alloc(ctx.pool(), total_elements * sizeof(float)); + + acl_tensor_ptr acl_src_contig = ggml_cann_create_tensor(inplace_src_alloc.get(), ACL_FLOAT, sizeof(float), + src0->ne, contiguous_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_dst_contig = ggml_cann_create_tensor(inplace_dst_alloc.get(), ACL_FLOAT, sizeof(float), + dst->ne, contiguous_nb, GGML_MAX_DIMS); + + cann_copy(ctx, acl_src.get(), acl_src_contig.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src_contig.get(), acl_cos_reshape_tensor.get(), + acl_sin_reshape_tensor.get(), acl_mode, acl_dst_contig.get()); + cann_copy(ctx, acl_dst_contig.get(), acl_dst.get()); } else { // Rotate full tensor (no tail), using original tensors GGML_CANN_CALL_ACLNN_OP(ctx, RotaryPositionEmbedding, acl_src.get(), acl_cos_reshape_tensor.get(), @@ -2982,6 +3223,58 @@ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst) { } } +void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + int sections[4]; + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + GGML_TENSOR_UNARY_OP_LOCALS + + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (int32_t *) dst->op_params + 11, sizeof(int) * 4); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE; + const bool mrope_used = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_imrope || mrope_used) { + is_neox = true; + } + + int64_t rope_dims = n_dims; + if (is_vision) { + rope_dims = src0->ne[0]; + } + + // Run the full cache init on the non-captured stream. This performs all + // host-to-device memcpy, aclrtMalloc/Free, and on-device computations + // so that the memory pool is warmed up and cache metadata is populated. + aclnn_rope_cache_init(ctx, dst, corr_dims, ext_factor, theta_scale, freq_scale, attn_factor, is_neox, sections, + mrope_used, is_imrope, is_vision, rope_dims); + + // Reset `cached` so that during graph capture the on-device computations + // (sin/cos, position multiply, repeat, etc.) still execute and get recorded + // into the captured graph. The cache metadata (theta_scale_length, + // theta_scale, sections, position_length, etc.) remains set, which causes + // all host-to-device copy and malloc/free branches to be skipped. + ctx.rope_cache.cached = false; +} + void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; @@ -2991,20 +3284,20 @@ void ggml_cann_argmax(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_CANN_CALL_ACLNN_OP(ctx, ArgMax, acl_src.get(), 3, false, acl_dst.get()); } -void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* dst){ +void ggml_cann_conv_transpose_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; // stride - int64_t s0 = ((const int32_t*)(dst->op_params))[0]; + int64_t s0 = ((const int32_t *) (dst->op_params))[0]; - acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_input = ggml_cann_create_tensor(src1, src1->ne, src1->nb, 3, ACL_FORMAT_NCL); acl_tensor_ptr acl_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3, ACL_FORMAT_NCL); - acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3, ACL_FORMAT_NCL); // get base information of input and kernel - int64_t input_len = *(src1->ne); - int64_t dst_len = *(dst->ne); + int64_t input_len = *(src1->ne); + int64_t dst_len = *(dst->ne); int64_t kernel_size = *(src0->ne); // set the max kernel size for each conv @@ -3012,56 +3305,55 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds // compute the partition of kernel int64_t part_num = 1; - part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size; + part_num = (kernel_size + max_kernel_size - 1) / max_kernel_size; int64_t strideVal[1]; - strideVal[0] = s0; - acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1); - int64_t paddingVal[] = {0}; - acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1); - int64_t dilationVal[] = {1}; - acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1); - bool transposed = true; - int64_t groups = 1; - int8_t cubeMathType = 0; + strideVal[0] = s0; + acl_int_array_ptr stride = ggml_cann_create_int_array(strideVal, 1); + int64_t paddingVal[] = { 0 }; + acl_int_array_ptr padding = ggml_cann_create_int_array(paddingVal, 1); + int64_t dilationVal[] = { 1 }; + acl_int_array_ptr dilation = ggml_cann_create_int_array(dilationVal, 1); + bool transposed = true; + int64_t groups = 1; + int8_t cubeMathType = 0; #ifdef ASCEND_310P cubeMathType = 1; #endif auto weight_type = ggml_cann_type_mapping(src0->type); - auto dst_type = ggml_cann_type_mapping(dst->type); + auto dst_type = ggml_cann_type_mapping(dst->type); // slice the kernel to make each conv available - int64_t slice_dim = -1; + int64_t slice_dim = -1; int64_t slice_start = 0; - int64_t slice_end = max_kernel_size; - int64_t slice_step = 1; - int64_t interval = max_kernel_size; + int64_t slice_end = max_kernel_size; + int64_t slice_step = 1; + int64_t interval = max_kernel_size; - int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0]; + int64_t left_pad_len = dilationVal[0] * (max_kernel_size - 1) + 1 - 2 * paddingVal[0]; int64_t right_pad_len = 0; - acl_scalar_ptr alpha = nullptr; - float alphaValue = 1.0; - alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT); + acl_scalar_ptr alpha = nullptr; + float alphaValue = 1.0; + alpha = ggml_cann_create_scalar(&alphaValue, aclDataType::ACL_FLOAT); // set zero to destination GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); - for(int k = 0; k < part_num; k++){ - + for (int k = 0; k < part_num; k++) { // create part kernel tensor and slice from big kernel slice_start = max_kernel_size * k; - if(k == part_num - 1){ + if (k == part_num - 1) { slice_end = kernel_size; - interval = kernel_size - max_kernel_size * k; - }else{ - slice_end = max_kernel_size * (k+1); + interval = kernel_size - max_kernel_size * k; + } else { + slice_end = max_kernel_size * (k + 1); } int64_t part_ne[4]; - for(int i = 0; i < 4; i++) { + for (int i = 0; i < 4; i++) { part_ne[i] = *(src0->ne + i); } part_ne[0] = interval; @@ -3074,16 +3366,17 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds ggml_cann_pool_alloc part_kernel_allocator; part_kernel_allocator.alloc(ctx.pool(), part_nb[3]); - void* part_kernel_buf = part_kernel_allocator.get(); + void * part_kernel_buf = part_kernel_allocator.get(); - acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, - ggml_element_size(src0), part_ne, part_nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr part_kernel = ggml_cann_create_tensor(part_kernel_buf, weight_type, ggml_element_size(src0), + part_ne, part_nb, 3, ACL_FORMAT_NCL); - GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, part_kernel.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, acl_weight.get(), slice_dim, slice_start, slice_end, slice_step, + part_kernel.get()); // create the part conv result tensor int64_t part_dst_ne[4]; - for(int i = 0; i < 4; i++){ + for (int i = 0; i < 4; i++) { part_dst_ne[i] = *(dst->ne + i); } part_dst_ne[0] = (input_len - 1) * strideVal[0] - 2 * paddingVal[0] + dilationVal[0] * (part_ne[0] - 1) + 1; @@ -3095,32 +3388,33 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds } ggml_cann_pool_alloc part_dst_allocator; part_dst_allocator.alloc(ctx.pool(), part_dst_nb[3]); - void* part_dst_buf = part_dst_allocator.get(); + void * part_dst_buf = part_dst_allocator.get(); acl_tensor_ptr acl_part_dst = ggml_cann_create_tensor(part_dst_buf, dst_type, ggml_element_size(dst), - part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL); + part_dst_ne, part_dst_nb, 3, ACL_FORMAT_NCL); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_part_dst.get()); // compute part conv transpose 1d GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input.get(), part_kernel.get(), nullptr, stride.get(), - padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), cubeMathType); + padding.get(), dilation.get(), transposed, padding.get(), groups, acl_part_dst.get(), + cubeMathType); // compute the position of part result in final result int64_t global_start = slice_start; - int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len); + int64_t global_end = std::min((input_len - 1) * strideVal[0] + slice_end, dst_len); - left_pad_len = global_start; + left_pad_len = global_start; right_pad_len = dst_len - global_end; - std::vector<int64_t> padDataVal = {left_pad_len,right_pad_len}; - acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2); + std::vector<int64_t> padDataVal = { left_pad_len, right_pad_len }; + acl_int_array_ptr padData = ggml_cann_create_int_array(padDataVal.data(), 2); - acl_scalar_ptr pad_value = nullptr; - float pad_valueVal = 0.0; - pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT); + acl_scalar_ptr pad_value = nullptr; + float pad_valueVal = 0.0; + pad_value = ggml_cann_create_scalar(&pad_valueVal, aclDataType::ACL_FLOAT); int64_t conv_result_ne[4]; - for(int i = 0; i < 4; i++){ + for (int i = 0; i < 4; i++) { conv_result_ne[i] = *(dst->ne + i); } @@ -3132,13 +3426,14 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds ggml_cann_pool_alloc conv_result_allocator; conv_result_allocator.alloc(ctx.pool(), conv_result_nb[3]); - void* conv_result_buf = conv_result_allocator.get(); + void * conv_result_buf = conv_result_allocator.get(); acl_tensor_ptr conv_result = ggml_cann_create_tensor(conv_result_buf, dst_type, ggml_element_size(dst), - conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL); + conv_result_ne, conv_result_nb, 3, ACL_FORMAT_NCL); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, conv_result.get()); - GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), conv_result.get()); + GGML_CANN_CALL_ACLNN_OP(ctx, ConstantPadNd, acl_part_dst.get(), padData.get(), pad_value.get(), + conv_result.get()); GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, acl_dst.get(), conv_result.get(), alpha.get()); } } @@ -3175,29 +3470,50 @@ void ggml_cann_pad_reflect_1d(ggml_backend_cann_context & ctx, ggml_tensor * dst int64_t paddingsArray[2] = { opts[0], opts[1] }; acl_int_array_ptr paddings = ggml_cann_create_int_array(paddingsArray, 2); - for (int64_t i = 0; i < src0->ne[3]; i++) { - acl_tensor_ptr acl_src = - ggml_cann_create_tensor((char *) src0->data + i * src0->ne[3], ggml_cann_type_mapping(src0->type), - ggml_element_size(src0), src0->ne, src0->nb, 3); + // Collapsing ne[2]*ne[3] into a single batch dimension requires that dim3 + // is contiguous with respect to dim2 in both src and dst. + GGML_ASSERT(src0->nb[3] == src0->nb[2] * src0->ne[2]); + GGML_ASSERT(dst->nb[3] == dst->nb[2] * dst->ne[2]); - acl_tensor_ptr acl_dst = - ggml_cann_create_tensor((char *) dst->data + i * src0->ne[3], ggml_cann_type_mapping(dst->type), - ggml_element_size(dst), dst->ne, dst->nb, 3); + int64_t src_ne_3d[3] = { src0->ne[0], src0->ne[1], src0->ne[2] * src0->ne[3] }; + int64_t dst_ne_3d[3] = { dst->ne[0], dst->ne[1], dst->ne[2] * dst->ne[3] }; - GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get()); - } + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0->data, ggml_cann_type_mapping(src0->type), + ggml_element_size(src0), src_ne_3d, src0->nb, 3); + + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + ggml_element_size(dst), dst_ne_3d, dst->nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, ReflectionPad1d, acl_src.get(), paddings.get(), acl_dst.get()); } void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; + // Write element-wise equality (0 or 1) into a temporary buffer to avoid + // modifying src0 in-place. Use the same type as src0 so ReduceSum can + // consume it directly without a type cast. + ggml_cann_pool_alloc eq_alloc(ctx.pool(), ggml_nelements(src0) * ggml_element_size(src0)); + size_t eq_nb[GGML_MAX_DIMS]; + eq_nb[0] = ggml_element_size(src0); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + eq_nb[i] = eq_nb[i - 1] * src0->ne[i - 1]; + } + acl_tensor_ptr acl_eq = ggml_cann_create_tensor( + eq_alloc.get(), ggml_cann_type_mapping(src0->type), ggml_element_size(src0), + src0->ne, eq_nb, GGML_MAX_DIMS); + acl_tensor_ptr acl_self = ggml_cann_create_tensor(src0); acl_tensor_ptr acl_other = ggml_cann_create_tensor(src1); - - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceEqTensor, acl_self.get(), acl_other.get()); - - ggml_cann_sum(ctx, dst); + GGML_CANN_CALL_ACLNN_OP(ctx, EqTensor, acl_self.get(), acl_other.get(), acl_eq.get()); + + // Sum the 0/1 values into dst. + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + int64_t dims[4] = { 0, 1, 2, 3 }; + acl_int_array_ptr dims_arr = ggml_cann_create_int_array(dims, 4); + GGML_CANN_CALL_ACLNN_OP(ctx, ReduceSum, acl_eq.get(), dims_arr.get(), true, + ggml_cann_type_mapping(dst->type), acl_dst.get()); } void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) { @@ -3213,6 +3529,27 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst) { GGML_CANN_CALL_ACLNN_OP(ctx, GtScalar, acl_src.get(), alpha.get(), acl_dst.get()); } +void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * src0 = dst->src[0]; + + acl_tensor_ptr acl_src = ggml_cann_create_tensor(src0); + acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); + + float beta_val = 1.0f; + float threshold_val = 20.0f; + acl_scalar_ptr beta = ggml_cann_create_scalar(&beta_val, ACL_FLOAT); + acl_scalar_ptr threshold = ggml_cann_create_scalar(&threshold_val, ACL_FLOAT); + + GGML_CANN_CALL_ACLNN_OP(ctx, Softplus, acl_src.get(), beta.get(), threshold.get(), acl_dst.get()); +} + +void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + auto gelu_quick_fn = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); + }; + ggml_cann_op_unary_gated(gelu_quick_fn, ctx, dst); +} + /** * @brief Performs expert-specific matrix multiplication (MoE) with * floating-point precision using the CANN backend. @@ -3282,130 +3619,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor } /** - * @brief Performs expert-specific matrix multiplication (MoE) with - * quantized precision using the CANN backend. + * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE) + * models using the CANN backend. * - * This function executes a matrix multiplication operation tailored for - * Mixture of Experts (MoE) models, where the input tensor is multiplied - * with expert-specific quantized weight matrices. It leverages the CANN - * backend to perform efficient low-precision computations and stores the - * quantized result in the destination tensor `dst`. + * This function implements MUL_MAT_ID operation for quantized weight matrices + * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on + * the provided expert indices, and computes matrix multiplication using CANN's + * WeightQuantBatchMatmulV2 operator. * - * Quantization techniques reduce memory footprint and improve performance - * by using lower-bit representations (e.g., int8) instead of floating-point. - * This function is designed to work with such formats and may incorporate - * optimizations like identity-based fast paths or routing masks for sparse - * expert selection. + * The function performs the following steps: + * 1. Converts input/output tensors to F16 format if necessary + * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices + * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2 + * 4. Converts output back to the target type if needed * - * @param ctx The context for executing CANN backend operations. - * @param dst The destination tensor where the quantized MoE multiplication result - * will be stored. + * Tensor shapes: + * - dst: [M, K, N, 1] - output tensor + * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0) + * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast) + * - ids: [K, N] - expert indices for routing + * + * @param ctx The CANN backend context for operation execution. + * @param dst The destination tensor where the multiplication result will be stored. * - * @note This function assumes quantized data types and is designed for - * MoE architectures with potential sparse expert routing. + * @note Only Q4_0 and Q8_0 quantization formats are supported. + * @note The function handles automatic type conversion to/from F16 as needed by the hardware. */ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - // TODO: Use aclnnGroupedMatMul - //dst [M, K, N, 1] - ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] - ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 - ggml_tensor * ids = dst->src[2]; //ids [K, N] + // dst: [M, K, N, 1] + // src0: [D, M, A, 1] - quantized weights + // src1: [D, B, N, 1] - input activations, B = K or B = 1 + // ids: [K, N] - expert indices + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + ggml_tensor * ids = dst->src[2]; - GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(src1->ne[3] == 1); + GGML_ASSERT(dst->ne[3] == 1); + GGML_ASSERT(src1->ne[2] == ids->ne[1]); + + const int64_t n_batches = ids->ne[1]; + const int64_t n_select_experts = ids->ne[0]; + const enum ggml_type type = src0->type; + + const int32_t group_size = QK8_0; // Both Q4_0 and Q8_0 use group size of 32 + GGML_ASSERT(group_size == QK4_0); + + // Calculate element size for quantized weights + const float weight_elem_size = + (type == GGML_TYPE_Q4_0) ? 0.5f : + (type == GGML_TYPE_Q8_0) ? 1.0f : + (GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f); + + // Calculate scale offset in memory + const size_t weight_size = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size; + const size_t scale_elem_size = sizeof(uint16_t); + char * scale_data = (char *) src0->data + weight_size; + + // Allocate buffers for selected expert weights and scales + const size_t selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size; + ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size); + void * selected_weight_buffer = selected_weight_alloc.get(); + + const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size; + ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size); + void * selected_scale_buffer = selected_scale_alloc.get(); + + // Helper lambda to allocate and cast tensor to F16 if needed + constexpr size_t f16_elem_size = sizeof(uint16_t); + auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator, + bool need_cast = false) -> void * { + if (tensor->type == GGML_TYPE_F16) { + return tensor->data; + } + + size_t total_size = f16_elem_size; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + total_size *= tensor->ne[i]; + } + void * buffer = allocator.alloc(total_size); + + if (need_cast == false) { + return buffer; + } - // copy index from npu to cpu - int64_t n_as = ne02; // A - int64_t n_ids = ids->ne[0]; // K + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS] = { f16_elem_size }; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + ne[i] = tensor->ne[i]; + if (i > 0) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + } - std::vector<char> ids_host(ggml_nbytes(ids)); - ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids), - ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream())); - ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor); + acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS); + aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16); - char * src0_original = (char *) src0->data; - char * src1_original = (char *) src1->data; - char * dst_original = (char *) dst->data; + return buffer; + }; - ggml_tensor src0_row = *src0; - ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; + // Prepare input and output buffers + ggml_cann_pool_alloc input_alloc(ctx.pool()); + void * input_buffer = prepare_f16_buffer(src1, input_alloc, true); - const enum ggml_type type = dst->src[0]->type; - float weight_elem_size; - if (type == GGML_TYPE_Q4_0) { - weight_elem_size = float(sizeof(uint8_t)) / 2; - } else if (type == GGML_TYPE_Q8_0) { - weight_elem_size = float(sizeof(uint8_t)); - } else { - GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 "); - } + ggml_cann_pool_alloc output_alloc(ctx.pool()); + void * output_buffer = prepare_f16_buffer(dst, output_alloc, false); - // src0_row [D, M, 1, 1] weight without permute - src0_row.ne[2] = 1; - src0_row.ne[3] = 1; - src0_row.nb[0] = weight_elem_size; - src0_row.nb[1] = weight_elem_size * ne00; - src0_row.nb[2] = weight_elem_size * ne00; - src0_row.nb[3] = weight_elem_size * ne00; - size_t weight_stride = ne00 * ne01 * weight_elem_size; - size_t weight_size = weight_stride * ne02 * ne03; + // Process each batch + for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) { + // Create index tensor for current batch + const size_t index_offset = batch_idx * ids->nb[1]; + acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset); - // scale [D, M, 1, 1] -> scale && permute - size_t scale_elem_size = sizeof(uint16_t); - size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; + // Select quantized weights using expert indices + // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte + const int64_t weight_d = (type == GGML_TYPE_Q4_0) ? src0->ne[0] / 2 : src0->ne[0]; + const int64_t weight_m = src0->ne[1]; + const int64_t weight_n_experts = src0->ne[2]; + + int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts }; + size_t weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) }; + + acl_tensor_ptr all_weights = + ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3); + + int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts }; + size_t selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), + weight_d * weight_m * sizeof(int8_t) }; + + acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t), + selected_weight_ne, selected_weight_nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get()); - // src1_row [D, 1, 1, 1] -> input - src1_row.ne[1] = 1; - src1_row.ne[2] = 1; - src1_row.ne[3] = 1; - src1_row.nb[2] = nb11; - src1_row.nb[3] = nb11; - - // dst_row [M, 1, 1, 1] -> out - dst_row.ne[1] = 1; - dst_row.ne[2] = 1; - dst_row.ne[3] = 1; - dst_row.nb[2] = nb1; - dst_row.nb[3] = nb1; - - //create weight for one row - ggml_cann_pool_alloc weight_allocator(ctx.pool()); - void * weight_buffer = weight_allocator.alloc(nb02); - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - // expert index - int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); - - // If B = 1 (broadcast), always use 0; otherwise, use id. - int64_t i11 = (ne11 == 1 ? 0 : id); - int64_t i12 = iid1; - - int64_t i1 = id; - int64_t i2 = i12; - - void * src0_tmp_ptr = src0_original + i02 * weight_stride; - void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride; - void * src1_tmp_ptr = src1_original + i11 * nb11 + i12 * nb12; - void * dst_tmp_ptr = dst_original + i1 * nb1 + i2 * nb2; - - // mem cpy - ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); - void * scale_buffer = (char *) weight_buffer + weight_stride; - ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); - - src0_row.data = weight_buffer; - src1_row.data = src1_tmp_ptr; - dst_row.data = dst_tmp_ptr; - dst_row.src[0] = &src0_row; - dst_row.src[1] = &src1_row; - - ggml_cann_mul_mat(ctx, &dst_row); + // Select scales using the same expert indices + const int64_t scale_d = src0->ne[0] / group_size; + int64_t scale_ne[3] = { scale_d, weight_m, weight_n_experts }; + size_t scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size }; + + acl_tensor_ptr all_scales = + ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3); + + int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts }; + size_t selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, + scale_d * weight_m * scale_elem_size }; + + acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, + selected_scale_ne, selected_scale_nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get()); + + // Process each expert for current batch + // IndexSelect output layout: [D, M, K] in contiguous format + // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride + for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) { + // Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input + const size_t input_offset = + (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size; + const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size; + + // Create weight view for current expert: [D, M, K] -> [M, D] + int64_t weight_view_ne[2] = { weight_m, src0->ne[0] }; + float weight_view_nb[2] = { src0->ne[0] * weight_elem_size, weight_elem_size }; + const size_t weight_view_offset = expert_idx * selected_weight_nb[2]; + + acl_tensor_ptr weight_view = + ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size, + weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset); + + // Create scale view for current expert: [D, M, K] -> [M, D] + int64_t scale_view_ne[2] = { weight_m, scale_d }; + size_t scale_view_nb[2] = { selected_scale_nb[1], selected_scale_nb[0] }; + const size_t scale_view_offset = expert_idx * selected_scale_nb[2]; + + acl_tensor_ptr scale_view = + ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne, + scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset); + + // Create input activation tensor [D, 1] + int64_t input_ne[2] = { src1->ne[0], 1 }; + size_t input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size }; + + acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne, + input_nb, 2, ACL_FORMAT_ND, input_offset); + + // Create output tensor [M, 1] + int64_t output_ne[2] = { dst->ne[0], 1 }; + size_t output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size }; + + acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne, + output_nb, 2, ACL_FORMAT_ND, output_offset); + + // Perform quantized matrix multiplication + GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(), + scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size, + output_tensor.get()); } } - return; + + // Cast output back to original type if we used a temporary F16 buffer + if (dst->type != GGML_TYPE_F16) { + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS] = { f16_elem_size }; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + ne[i] = dst->ne[i]; + if (i > 0) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + } + + acl_tensor_ptr f16_output = + ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS); + acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst); + + aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type)); + } } void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) { @@ -3502,6 +3932,44 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst acl_k_tensor = ggml_cann_create_tensor(src1, src1_bsnd_ne, src1_bsnd_nb, GGML_MAX_DIMS); acl_v_tensor = ggml_cann_create_tensor(src2, src2_bsnd_ne, src2_bsnd_nb, GGML_MAX_DIMS); + // Step 2.5: Pad Q, K, V along head dimension if D is not a multiple of 16 + // (required by FusedInferAttentionScoreV2) + const int64_t D = src0->ne[0]; + const int64_t D_padded = GGML_PAD(D, 16); + const bool needs_padding = (D != D_padded); + + ggml_cann_pool_alloc q_pad_allocator(ctx.pool()); + ggml_cann_pool_alloc k_pad_allocator(ctx.pool()); + ggml_cann_pool_alloc v_pad_allocator(ctx.pool()); + + if (needs_padding) { + int64_t paddings[] = { 0, D_padded - D, 0, 0, 0, 0, 0, 0 }; + + auto pad_fa_tensor = [&](acl_tensor_ptr & tensor, const int64_t * bsnd_ne, + ggml_cann_pool_alloc & allocator) { + int64_t pad_ne[GGML_MAX_DIMS] = { D_padded, bsnd_ne[1], bsnd_ne[2], bsnd_ne[3] }; + size_t pad_nb[GGML_MAX_DIMS]; + pad_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + pad_nb[i] = pad_nb[i - 1] * pad_ne[i - 1]; + } + int64_t nelements = pad_ne[0] * pad_ne[1] * pad_ne[2] * pad_ne[3]; + void * buffer = allocator.alloc(nelements * faElemSize); + acl_tensor_ptr padded = + ggml_cann_create_tensor(buffer, faDataType, faElemSize, pad_ne, pad_nb, GGML_MAX_DIMS); + aclnn_pad(ctx, tensor.get(), padded.get(), paddings); + tensor = std::move(padded); + }; + + pad_fa_tensor(acl_q_tensor, src0_bsnd_ne, q_pad_allocator); + pad_fa_tensor(acl_k_tensor, src1_bsnd_ne, k_pad_allocator); + pad_fa_tensor(acl_v_tensor, src2_bsnd_ne, v_pad_allocator); + + src0_bsnd_ne[0] = D_padded; + src1_bsnd_ne[0] = D_padded; + src2_bsnd_ne[0] = D_padded; + } + // Step 3: create the PSEShift tensor if needed // this tensor is considered as mask (f16) in the llama.cpp acl_tensor_ptr bcast_pse_tensor; @@ -3591,17 +4059,16 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); acl_tensor_ptr fa_dst_tensor; - acl_tensor_ptr acl_dst_tensor; ggml_cann_pool_alloc out_f16_allocator(ctx.pool()); - if (dst->type == GGML_TYPE_F32) { - void * out_f16_buffer = out_f16_allocator.alloc(ggml_nelements(dst) * faElemSize); - + if (dst->type == GGML_TYPE_F32 || needs_padding) { int64_t * out_f16_ne = src0_bsnd_ne; size_t out_f16_nb[GGML_MAX_DIMS]; out_f16_nb[0] = faElemSize; for (int i = 1; i < GGML_MAX_DIMS; ++i) { out_f16_nb[i] = out_f16_nb[i - 1] * out_f16_ne[i - 1]; } + int64_t out_nelements = out_f16_ne[0] * out_f16_ne[1] * out_f16_ne[2] * out_f16_ne[3]; + void * out_f16_buffer = out_f16_allocator.alloc(out_nelements * faElemSize); fa_dst_tensor = ggml_cann_create_tensor(out_f16_buffer, faDataType, faElemSize, out_f16_ne, out_f16_nb, GGML_MAX_DIMS); @@ -3633,8 +4100,33 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst nullptr // softmaxLse ); - if (dst->type == GGML_TYPE_F32) { - // Step 6: post-processing, permute and cast to f32 + // Step 6: post-processing — slice padded output and/or cast to f32 + if (needs_padding) { + ggml_cann_pool_alloc sliced_f16_allocator(ctx.pool()); + + if (dst->type == GGML_TYPE_F32) { + int64_t sliced_ne[GGML_MAX_DIMS] = { D, src0_bsnd_ne[1], src0_bsnd_ne[2], src0_bsnd_ne[3] }; + size_t sliced_nb[GGML_MAX_DIMS]; + sliced_nb[0] = faElemSize; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + sliced_nb[i] = sliced_nb[i - 1] * sliced_ne[i - 1]; + } + int64_t sliced_nelements = sliced_ne[0] * sliced_ne[1] * sliced_ne[2] * sliced_ne[3]; + void * sliced_buffer = sliced_f16_allocator.alloc(sliced_nelements * faElemSize); + acl_tensor_ptr sliced_f16_tensor = ggml_cann_create_tensor(sliced_buffer, faDataType, faElemSize, + sliced_ne, sliced_nb, GGML_MAX_DIMS); + + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(), + (int64_t) -1, (int64_t) 0, D, (int64_t) 1, sliced_f16_tensor.get()); + + acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); + aclnn_cast(ctx, sliced_f16_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); + } else { + acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); + GGML_CANN_CALL_ACLNN_OP(ctx, Slice, fa_dst_tensor.get(), + (int64_t) -1, (int64_t) 0, D, (int64_t) 1, acl_dst_tensor.get()); + } + } else if (dst->type == GGML_TYPE_F32) { acl_tensor_ptr acl_dst_tensor = ggml_cann_create_tensor(dst); aclnn_cast(ctx, fa_dst_tensor.get(), acl_dst_tensor.get(), ggml_cann_type_mapping(dst->type)); } @@ -3644,46 +4136,65 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst } static void ggml_cann_out_prod_fp(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - ggml_tensor * src0 = dst->src[0]; // weight - ggml_tensor * src1 = dst->src[1]; // input + ggml_tensor * src0 = dst->src[0]; // weight [ne00=m, ne01=K, ne02, ne03] + ggml_tensor * src1 = dst->src[1]; // input [ne10=n, ne11=K, ne12, ne13] GGML_TENSOR_BINARY_OP_LOCALS - acl_tensor_ptr acl_dst = ggml_cann_create_tensor(dst); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceZero, acl_dst.get()); + // dst[i,j] = sum_k src0[i,k] * src1[j,k] i.e. dst = src0 @ src1^T. + // + // ggml_cann_create_tensor reverses dimension order, so ACL sees: + // acl_src0 slice: ggml[m,K] -> ACL[K,m] + // acl_src1 slice: ggml[n,K] -> ACL[K,n] + // acl_dst slice: ggml[m,n] -> ACL[n,m] + // + // Build a transposed view of src1 by swapping ne[0]/ne[1]: + // src1_t: ggml[K,n] (swapped strides) -> ACL[n,K] + // + // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst ✓ + // + // The outer batch loop is kept because src0 may have fewer batch slices than + // dst (ne02 <= ne2, ne03 <= ne3): this is a strided-broadcast not supported + // by standard CANN Matmul broadcasting. + + const aclDataType src0_acl_type = ggml_cann_type_mapping(src0->type); + const aclDataType src1_acl_type = ggml_cann_type_mapping(src1->type); + const aclDataType dst_acl_type = ggml_cann_type_mapping(dst->type); + const size_t src0_type_sz = ggml_type_size(src0->type); + const size_t src1_type_sz = ggml_type_size(src1->type); + const size_t dst_type_sz = ggml_type_size(dst->type); const int64_t dps2 = ne2 / ne02; const int64_t dps3 = ne3 / ne03; + for (int64_t i3 = 0; i3 < ne3; i3++) { for (int64_t i2 = 0; i2 < ne2; i2++) { const int64_t i02 = i2 / dps2; const int64_t i03 = i3 / dps3; - const int64_t i12 = i2; - const int64_t i13 = i3; - acl_tensor_ptr accumulator = - ggml_cann_create_tensor((char *) dst->data + i2 * nb2 + i3 * nb3, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dst->ne, dst->nb, 2); - - // The outer product needs to be accumulated in this dimension. - for (int64_t i1 = 0; i1 < ne11; i1++) { - acl_tensor_ptr acl_input = ggml_cann_create_tensor( - (char *) src1->data + i1 * nb11 + i12 * nb12 + i13 * nb13, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src1->ne, src1->nb, 1); - - acl_tensor_ptr acl_weight = ggml_cann_create_tensor( - (char *) src0->data + i1 * nb01 + i02 * nb02 + i03 * nb03, ggml_cann_type_mapping(src0->type), - ggml_type_size(src0->type), src0->ne, src0->nb, 1); - - ggml_cann_pool_alloc output_allocator(ctx.pool()); - void * output_buffer = output_allocator.alloc(ggml_nbytes(dst)); - acl_tensor_ptr acl_out = ggml_cann_create_tensor(output_buffer, ggml_cann_type_mapping(dst->type), - ggml_type_size(dst->type), dst->ne, dst->nb, 2); - - GGML_CANN_CALL_ACLNN_OP(ctx, Ger, acl_input.get(), acl_weight.get(), acl_out.get()); - float alpha_value = 1.0f; - aclScalar * alpha = aclCreateScalar(&alpha_value, ACL_FLOAT); - GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdd, accumulator.get(), acl_out.get(), alpha); - } + // src0 2D slice at [i02, i03]: ggml [m, K] -> ACL [K, m] + int64_t src0_ne[2] = { ne00, ne01 }; + size_t src0_nb[2] = { nb00, nb01 }; + acl_tensor_ptr acl_src0_s = ggml_cann_create_tensor( + (char *) src0->data + i02 * nb02 + i03 * nb03, + src0_acl_type, src0_type_sz, src0_ne, src0_nb, 2); + + // src1 transposed 2D slice at [i2, i3]: swap ne/nb -> ggml[K,n] -> ACL[n,K] + int64_t src1_t_ne[2] = { ne11, ne10 }; + size_t src1_t_nb[2] = { nb11, nb10 }; + acl_tensor_ptr acl_src1_t = ggml_cann_create_tensor( + (char *) src1->data + i2 * nb12 + i3 * nb13, + src1_acl_type, src1_type_sz, src1_t_ne, src1_t_nb, 2); + + // dst 2D slice at [i2, i3]: ggml [m, n] -> ACL [n, m] + int64_t dst_ne[2] = { ne0, ne1 }; + size_t dst_nb[2] = { nb0, nb1 }; + acl_tensor_ptr acl_dst_s = ggml_cann_create_tensor( + (char *) dst->data + i2 * nb2 + i3 * nb3, + dst_acl_type, dst_type_sz, dst_ne, dst_nb, 2); + + // Matmul(src1_t [n,K], src0 [K,m]) = [n,m] = acl_dst_s ✓ + GGML_CANN_CALL_ACLNN_OP(ctx, Matmul, + acl_src1_t.get(), acl_src0_s.get(), acl_dst_s.get(), (int8_t) 1); } } } @@ -3742,15 +4253,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // we want a view: ne_w = { nc, 1, nr } // [K, 1, C] // so that reversed dims -> [C, 1, K] which matches // [out_channels, in_channels/groups, kernel_size] - int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups] + int64_t w_ne[GGML_MAX_DIMS] = { nc, 1, nr, 1 }; // [K, 1 input ch. per group, C groups] // Layout: src1 data is [K, C] with // offset(k, c) = k*nb0 + c*nb1 // We want offset_w(k, 0, c) = k*nb0 + c*nb1, // so we can reuse nb0 and nb1, and set nb2 = nb1. - size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1 + size_t w_nb[GGML_MAX_DIMS] = { src1->nb[0], src1->nb[1], src1->nb[1], src1->nb[3] }; // same as src1 - acl_tensor_ptr acl_w = ggml_cann_create_tensor( - src1->data, ggml_cann_type_mapping(src1->type), ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_w = ggml_cann_create_tensor(src1->data, ggml_cann_type_mapping(src1->type), + ggml_type_size(src1->type), w_ne, w_nb, 3, ACL_FORMAT_NCL); // 3) Output: dst is { d_inner, n_t, n_s } (CLN) // @@ -3768,11 +4279,12 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { // nb_y[0] = nr * sizeof(float); // step in L // nb_y[1] = sizeof(float); // step in C // nb_y[2] = nr * n_t * sizeof(float); // step in N - int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N] - size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), dst->nb[3] }; // [nr, 1, nr * n_t] + int64_t y_ne[GGML_MAX_DIMS] = { n_t, nr, n_s, 1 }; // [L_out, C, N] + size_t y_nb[GGML_MAX_DIMS] = { dst->ne[0] * sizeof(float), sizeof(float), dst->ne[0] * dst->ne[1] * sizeof(float), + dst->nb[3] }; // [nr, 1, nr * n_t] - acl_tensor_ptr acl_y = ggml_cann_create_tensor( - dst->data, ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL); + acl_tensor_ptr acl_y = ggml_cann_create_tensor(dst->data, ggml_cann_type_mapping(dst->type), + ggml_type_size(dst->type), y_ne, y_nb, 3, ACL_FORMAT_NCL); // --- Conv1d parameters: depthwise, stride 1, no padding ("valid") --- int64_t strideVal[1] = { 1 }; @@ -3791,22 +4303,15 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { cubeMathType = 1; #endif - GGML_CANN_CALL_ACLNN_OP(ctx, - Convolution, + GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_x.get(), // input: N, C, L_in = ncs acl_w.get(), // weight: [C, 1, K] with groups=nr nullptr, // bias - stride.get(), - padding.get(), - dilation.get(), - transposed, - padding.get(), // output padding (unused for non-transposed) - groups, - acl_y.get(), - cubeMathType); + stride.get(), padding.get(), dilation.get(), transposed, + padding.get(), // output padding (unused for non-transposed) + groups, acl_y.get(), cubeMathType); } - void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node) { @@ -3860,3 +4365,72 @@ void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, eps, // double type acl_yout.get(), acl_rstd.get(), acl_xout.get()); } + +void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst) { + ggml_tensor * k = dst->src[0]; + ggml_tensor * v = dst->src[1]; + ggml_tensor * q = dst->src[2]; + ggml_tensor * g = dst->src[3]; + ggml_tensor * s = dst->src[4]; + + int64_t B = dst->src[4]->ne[1]; + int64_t T = dst->src[0]->ne[2]; + int64_t H = dst->src[0]->ne[1]; + int64_t C = dst->ne[0]; + int64_t D = C / H; + int64_t L = T / B; + + int64_t ne_qkg[2] = { 1, D }; + int64_t ne_s[2] = { D, D }; + int64_t ne_st[2] = { ne_s[1], ne_s[0] }; + int64_t ne_vo[2] = { D, 1 }; + int64_t ne_q[1] = { D }; + size_t nb_base = ggml_type_size(k->type); + size_t nb_qkg[2] = { nb_base, nb_base }; + size_t nb_s[2] = { nb_base, D * nb_base }; + size_t nb_st[2] = { nb_s[1], nb_s[0] }; + size_t nb_vo[2] = { nb_base, D * nb_base }; + size_t nb_q[1] = { nb_base }; + + const float scale = ggml_get_op_params_f32(dst, 0); + + acl_tensor_ptr acl_s = ggml_cann_create_tensor(s, s->ne, s->nb, 2, ACL_FORMAT_ND); + acl_tensor_ptr new_state = ggml_cann_create_tensor(dst, s->ne, s->nb, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base); + cann_copy(ctx, acl_s.get(), new_state.get()); + + for (int64_t b = 0; b < B; b++) { + for (int64_t h = 0; h < H; h++) { + size_t s_offset = (b * (H * D * D) + h * (D * D)) * nb_base; + // D * D + acl_tensor_ptr acl_s_new = + ggml_cann_create_tensor(dst, ne_s, nb_s, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset); + acl_tensor_ptr acl_s_new_t = + ggml_cann_create_tensor(dst, ne_st, nb_st, 2, ACL_FORMAT_ND, (B * L * H * D) * nb_base + s_offset); + for (int64_t l = 0; l < L; l++) { + size_t qkvgo_offset = (b * (L * H * D) + l * (H * D) + h * (D)) * nb_base; + // D * 1 + acl_tensor_ptr acl_k = ggml_cann_create_tensor(k, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset); + acl_tensor_ptr acl_g = ggml_cann_create_tensor(g, ne_qkg, nb_qkg, 2, ACL_FORMAT_ND, qkvgo_offset); + // D + acl_tensor_ptr acl_q = ggml_cann_create_tensor(q, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset); + // 1 * D + acl_tensor_ptr acl_v = ggml_cann_create_tensor(v, ne_vo, nb_vo, 2, ACL_FORMAT_ND, qkvgo_offset); + // D + acl_tensor_ptr acl_o = ggml_cann_create_tensor(dst, ne_q, nb_q, 1, ACL_FORMAT_ND, qkvgo_offset); + // k ⊗ v + size_t buf_size = D * D * nb_base; + ggml_cann_pool_alloc buffer_allocator(ctx.pool(), buf_size); + acl_tensor_ptr tmp_tensor = ggml_cann_create_tensor( + buffer_allocator.get(), ggml_cann_type_mapping(k->type), nb_base, ne_s, nb_s, 2); + aclnn_mul(ctx, acl_k.get(), acl_v.get(), tmp_tensor.get()); + //s_new = g ⊗ s_old + k ⊗ v + aclnn_mul(ctx, acl_s_new.get(), acl_g.get(), nullptr); + aclnn_add(ctx, acl_s_new.get(), tmp_tensor.get(), nullptr); + // compute output + GGML_CANN_CALL_ACLNN_OP(ctx, Mv, acl_s_new_t.get(), acl_q.get(), acl_o.get(), 1); + aclnn_muls(ctx, acl_o.get(), scale, nullptr, true); + } + } + } +} + diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 08ee7b1fbdf..cdbf9260f85 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to @@ -32,6 +32,9 @@ #include <aclnnop/aclnn_cat.h> #include <aclnnop/aclnn_clamp.h> #include <aclnnop/aclnn_cos.h> +#include <aclnnop/aclnn_cumsum.h> +#include <aclnnop/aclnn_tril.h> +#include <aclnnop/aclnn_triu.h> #include <aclnnop/aclnn_exp.h> #include <aclnnop/aclnn_gelu.h> #include <aclnnop/aclnn_gelu_v2.h> @@ -47,6 +50,9 @@ #include <aclnnop/aclnn_sign.h> #include <aclnnop/aclnn_silu.h> #include <aclnnop/aclnn_sin.h> +#include <aclnnop/aclnn_softplus.h> +#include <aclnnop/aclnn_swi_glu.h> +#include <aclnnop/aclnn_geglu.h> #include <aclnnop/aclnn_slice.h> #include <aclnnop/aclnn_sqrt.h> #include <aclnnop/aclnn_tanh.h> @@ -69,6 +75,9 @@ */ void ggml_cann_repeat(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_swiglu(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_geglu(ggml_backend_cann_context & ctx, ggml_tensor * dst, int64_t approximate); + /** * @brief Applies the Leaky ReLU activation function to a tensor using the CANN * backend. @@ -325,6 +334,48 @@ void ggml_cann_sum_rows(ggml_backend_cann_context & ctx, ggml_tensor * dst); void ggml_cann_sum(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Computes the cumulative sum of a ggml tensor along dim 0 using the + * CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_CUMSUM`. + */ +void ggml_cann_cumsum(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Computes a triangular mask (tril/triu) of a square ggml tensor + * using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_TRI`. + */ +void ggml_cann_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Solves a triangular linear system AX=B using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_SOLVE_TRI`. + */ +void ggml_cann_solve_tri(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Creates a diagonal matrix from a vector using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_DIAG`. + */ +void ggml_cann_diag(ggml_backend_cann_context & ctx, ggml_tensor * dst); + +/** + * @brief Fills a tensor with a constant scalar value using the CANN backend. + * + * @param ctx The CANN context used for operations. + * @param dst The destination tensor. dst->op is `GGML_OP_FILL`. + */ +void ggml_cann_fill(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Upsamples a ggml tensor using nearest neighbor interpolation using * the CANN backend. @@ -461,6 +512,9 @@ void ggml_cann_timestep_embedding(ggml_backend_cann_context & ctx, ggml_tensor * // @see ggml_cann_dup. void ggml_cann_cpy(ggml_backend_cann_context & ctx, ggml_tensor * dst); +// @see ggml_cann_acc, but copies src1 into dst instead of adding. +void ggml_cann_set(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Computes the softmax activation with optional masking. * @@ -543,6 +597,21 @@ void ggml_cann_mul_mat(ggml_backend_cann_context & ctx, ggml_tensor * dst); */ void ggml_cann_rope(ggml_backend_cann_context & ctx, ggml_tensor * dst); +/** + * @brief Pre-load the RoPE cache before ACL graph capture. + * + * This function must be called outside of graph capture to perform + * host-to-device memory copies and device memory allocations that are + * not allowed on a captured stream. After pre-loading, the rope cache + * metadata is updated so that the subsequent call to + * aclnn_rope_cache_init (inside graph capture) skips these operations + * and only records the on-device computations into the captured graph. + * + * @param ctx CANN backend context. + * @param dst A ROPE destination tensor from the computation graph. + */ +void ggml_cann_rope_cache_preload(ggml_backend_cann_context & ctx, ggml_tensor * dst); + /** * @brief Computes the index of the maximum value along the specified dimension * of a ggml tensor using the CANN backend. @@ -798,6 +867,8 @@ void ggml_cann_count_equal(ggml_backend_cann_context & ctx, ggml_tensor * dst); * dst->op is expected to be `GGML_OP_STEP`. */ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_softplus(ggml_backend_cann_context & ctx, ggml_tensor * dst); +void ggml_cann_geglu_quick(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Performs the Flash Attention extended operator using the CANN backend. @@ -814,67 +885,20 @@ void ggml_cann_step(ggml_backend_cann_context & ctx, ggml_tensor * dst); */ void ggml_cann_flash_attn_ext(ggml_backend_cann_context & ctx, ggml_tensor * dst); -/* - * @brief A generic wrapper for ACL resources with custom deleter support. - */ -using any_acl_resource = std::unique_ptr<void, std::function<void(void *)>>; - /** - * @brief Trait structure used to define how to destroy a given ACL resource type. + * @brief Forward Gated Linear Attention on the CANN backend. * - * @tparam T ACL resource type. - */ -template <typename T> struct acl_resource_traits; - -/** - * @brief Specialization for aclTensor, defines how to destroy an aclTensor resource. - */ -template <> struct acl_resource_traits<aclTensor> { - static void destroy(void * p) { ACL_CHECK(aclDestroyTensor(static_cast<aclTensor *>(p))); } -}; - -/** - * @brief Specialization for aclIntArray, defines how to destroy an aclIntArray resource. - */ -template <> struct acl_resource_traits<aclIntArray> { - static void destroy(void * p) { ACL_CHECK(aclDestroyIntArray(static_cast<aclIntArray *>(p))); } -}; - -/** - * @brief Specialization for aclScalar, defines how to destroy an aclScalar resource. - */ -template <> struct acl_resource_traits<aclScalar> { - static void destroy(void * p) { ACL_CHECK(aclDestroyScalar(static_cast<aclScalar *>(p))); } -}; - -/** - * @brief Specialization for aclTensorList, defines how to destroy an aclTensorList resource. - */ -template <> struct acl_resource_traits<aclTensorList> { - static void destroy(void * p) { ACL_CHECK(aclDestroyTensorList(static_cast<aclTensorList *>(p))); } -}; - -/** - * @brief Creates a generic ACL resource wrapper with proper destruction logic. + * Expects dst->src[0..4] = {k, v, q, g, s} with shape conventions: + * k, v, q, g: [D] with outer dims T x H batched as ne[2]=T, ne[1]=H + * s: initial state [B, H, D, D], where B is batch and D=C/H + * dst holds both outputs (o) and updated state; a scale factor is read from op params. * - * @tparam T ACL resource type. - * @param ptr Raw pointer to ACL resource. - * @return any_acl_resource Smart pointer that handles destruction. - */ -template <typename T> any_acl_resource make_acl_resource(T * ptr) { - return any_acl_resource(static_cast<void *>(ptr), [](void * p) { acl_resource_traits<T>::destroy(p); }); -} - -/** - * @brief Registers multiple ACL resources into a vector for lifetime management. + * The kernel updates per time step l: S_new = g ⊗ S_old + k ⊗ v, then computes o = (S_new^T q) * scale. * - * @tparam Args Variadic list of ACL resource types. - * @param vec Target vector to hold ACL resources. - * @param args Raw pointers to ACL resources. + * @param ctx Backend context providing stream/allocator utilities. + * @param dst Output tensor; src deps are k, v, q, g, s as above. */ -template <typename... Args> void register_acl_resources(std::vector<any_acl_resource> & vec, Args *... args) { - (vec.emplace_back(make_acl_resource(args)), ...); -} +void ggml_cann_gated_linear_attn(ggml_backend_cann_context & ctx, ggml_tensor * dst); /** * @brief Launches an asynchronous task using the memory allocator. @@ -894,19 +918,19 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso * same stream are executed in queue order. */ -#define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \ - do { \ - uint64_t workspaceSize = 0; \ - aclOpExecutor * executor; \ - void * workspaceAddr = nullptr; \ - ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \ - /* workspace should alloced in main thread to keep malloc order when using vmm. */ \ - if (workspaceSize > 0) { \ - ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \ - workspaceAddr = workspace_allocator.get(); \ - } \ - ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \ - } while (0) +# define GGML_CANN_CALL_ACLNN_OP(CTX, OP_NAME, ...) \ + do { \ + uint64_t workspaceSize = 0; \ + aclOpExecutor * executor; \ + void * workspaceAddr = nullptr; \ + ACL_CHECK(aclnn##OP_NAME##GetWorkspaceSize(__VA_ARGS__, &workspaceSize, &executor)); \ + /* workspace should alloced in main thread to keep malloc order when using vmm. */ \ + if (workspaceSize > 0) { \ + ggml_cann_pool_alloc workspace_allocator(CTX.pool(), workspaceSize); \ + workspaceAddr = workspace_allocator.get(); \ + } \ + ACL_CHECK(aclnn##OP_NAME(workspaceAddr, workspaceSize, executor, CTX.stream())); \ + } while (0) /** * @brief Performs sparse expert-based matrix multiplication using the CANN backend. @@ -947,7 +971,9 @@ void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst); * @param rms_norm_tensor The RMS_NORM operation node, contains the gamma weights * and epsilon parameter. */ -void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node); +void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, + ggml_tensor * add_node, + ggml_tensor * rms_norm_node); /** * @brief Check whether a tensor is a weight tensor for matrix multiplication. @@ -1104,13 +1130,13 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac * @see ggml_cann_op_unary * @see GGML_CANN_CALL_ACLNN_OP */ -#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \ - do { \ - auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \ - GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ - }; \ - ggml_cann_op_unary(lambda, ctx, dst); \ - } while (0) +# define GGML_CANN_CALL_OP_UNARY(OP_NAME) \ + do { \ + auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \ + GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ + }; \ + ggml_cann_op_unary(lambda, ctx, dst); \ + } while (0) /** * @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated. @@ -1133,13 +1159,13 @@ void ggml_cann_op_unary_gated(std::function<void(ggml_backend_cann_context &, ac * @see ggml_cann_op_unary_gated * @see GGML_CANN_CALL_ACLNN_OP */ -#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \ - do { \ - auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \ - GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ - }; \ - ggml_cann_op_unary_gated(lambda, ctx, dst); \ - } while (0) +# define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \ + do { \ + auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { \ + GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ + }; \ + ggml_cann_op_unary_gated(lambda, ctx, dst); \ + } while (0) #endif // CANN_ACLNN_OPS diff --git a/ggml/src/ggml-cann/common.h b/ggml/src/ggml-cann/common.h index 6895349b207..1c6e685c38c 100644 --- a/ggml/src/ggml-cann/common.h +++ b/ggml/src/ggml-cann/common.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to @@ -101,7 +101,6 @@ struct ggml_cann_device_info { const ggml_cann_device_info & ggml_cann_info(); void ggml_cann_set_device(int32_t device); -int32_t ggml_cann_get_device(); std::optional<std::string> get_env_as_lowercase(const std::string & name); bool parse_bool(const std::string & value); @@ -217,14 +216,16 @@ struct ggml_cann_pool_alloc { #ifdef USE_ACL_GRAPH struct ggml_graph_node_properties { // dst tensor - void * node_address; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; + void * node_address; + ggml_type node_type; + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS]; // src tensor - void * src_address[GGML_MAX_SRC]; - int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; - size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; + void * src_address[GGML_MAX_SRC]; + ggml_type src_type[GGML_MAX_SRC]; + int64_t src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; // op ggml_op node_op; @@ -248,6 +249,10 @@ struct ggml_graph_node_properties { return false; } + if (node->type != this->node_type) { + return false; + } + for (int i = 0; i < GGML_MAX_DIMS; i++) { if (node->ne[i] != this->ne[i]) { return false; @@ -263,6 +268,10 @@ struct ggml_graph_node_properties { return false; } + if (node->src[i]->type != this->src_type[i]) { + return false; + } + for (int d = 0; d < GGML_MAX_DIMS; d++) { if (node->src[i]->ne[d] != this->src_ne[i][d]) { return false; @@ -278,10 +287,7 @@ struct ggml_graph_node_properties { } } - if (node->op == GGML_OP_SCALE || node->op == GGML_OP_UNARY || node->op == GGML_OP_GLU) { - return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; - } - return true; + return memcmp(this->op_params, node->op_params, GGML_MAX_OP_PARAMS) == 0; } }; @@ -323,6 +329,7 @@ struct ggml_cann_graph { prop.node_address = node->data; prop.node_op = node->op; + prop.node_type = node->type; std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne); std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb); @@ -330,10 +337,12 @@ struct ggml_cann_graph { for (int src = 0; src < GGML_MAX_SRC; ++src) { if (node->src[src]) { prop.src_address[src] = node->src[src]->data; + prop.src_type[src] = node->src[src]->type; std::copy_n(node->src[src]->ne, GGML_MAX_DIMS, prop.src_ne[src]); std::copy_n(node->src[src]->nb, GGML_MAX_DIMS, prop.src_nb[src]); } else { prop.src_address[src] = nullptr; + prop.src_type[src] = GGML_TYPE_COUNT; std::fill_n(prop.src_ne[src], GGML_MAX_DIMS, 0); std::fill_n(prop.src_nb[src], GGML_MAX_DIMS, 0); } @@ -382,7 +391,7 @@ struct ggml_cann_graph_lru_cache { std::list<ggml_cann_graph *> cache_list; /**< List storing cached graphs as raw pointers. */ - ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); } + ggml_cann_graph_lru_cache() { capacity = parse_integer(get_env_as_lowercase("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12")); } /** * @brief Push a new graph to the front of the cache. @@ -574,7 +583,7 @@ struct ggml_backend_cann_context { description = aclrtGetSocName(); #ifdef USE_ACL_GRAPH - acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on")); + acl_graph_mode = parse_bool(get_env_as_lowercase("GGML_CANN_ACL_GRAPH").value_or("on")); GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n", __func__, device, acl_graph_mode ? "GRAPH" : "EAGER", acl_graph_mode ? "acl graph enabled" : "acl graph disabled"); #endif diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index d7a93848df8..5f51ea3bb3c 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023-2024 The ggml authors + * Copyright (c) 2023-2026 The ggml authors * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to @@ -36,10 +36,13 @@ #include <cmath> #include <cstdio> #include <cstring> +#include <memory> #include <mutex> #include <optional> #include <queue> +#include <unordered_map> #include <unordered_set> +#include <vector> #define GGML_COMMON_DECL_C @@ -93,17 +96,6 @@ void ggml_cann_set_device(const int32_t device) { g_current_cann_device = device; } -/** - * @brief Retrieves the current device ID. - * - * @return The current device ID. - */ -int32_t ggml_cann_get_device() { - int32_t id; - ACL_CHECK(aclrtGetDevice(&id)); - return id; -} - /** * @brief Get the value of the specified environment variable (name) as lowercase. * if not empty, return a std::string object @@ -781,6 +773,21 @@ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(i } // cann buffer + +/** + * @brief Tracks multi-threaded write progress for a single tensor. + * + * When multiple threads call set_tensor on different chunks of the same tensor, + * this tracker accumulates progress and defers post-processing (quantized format + * transform or ND-to-NZ conversion) until all data has been written. + */ +struct TensorSetTracker { + std::mutex mtx; ///< Protects concurrent access to this tracker + size_t bytes_written = 0; ///< Accumulated bytes written so far + size_t total_bytes = 0; ///< Target size (full tensor) + std::vector<uint8_t> host_buffer; ///< Host staging buffer for quantized tensors +}; + /** * @brief Context for managing a CANN buffer associated with a specific device. * @@ -791,6 +798,9 @@ struct ggml_backend_cann_buffer_context { int32_t device; ///< The device ID associated with this buffer context. void * dev_ptr = nullptr; ///< Pointer to the device memory allocated for the buffer. + std::mutex tracker_mutex; ///< Protects the trackers map + std::unordered_map<void *, std::unique_ptr<TensorSetTracker>> trackers; + /** * @brief Constructor to initialize the CANN buffer context. * @@ -803,21 +813,71 @@ struct ggml_backend_cann_buffer_context { * @brief Destructor to free the device memory allocated for the buffer. */ ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); } + + /** + * @brief Get or create a tracker for the given tensor. + */ + TensorSetTracker * get_or_create_tracker(ggml_tensor * tensor) { + std::lock_guard<std::mutex> lock(tracker_mutex); + auto key = tensor->data; + auto it = trackers.find(key); + if (it == trackers.end()) { + auto tracker = std::make_unique<TensorSetTracker>(); + tracker->total_bytes = ggml_nbytes(tensor); + auto * ptr = tracker.get(); + trackers[key] = std::move(tracker); + return ptr; + } + return it->second.get(); + } + + /** + * @brief Remove the tracker for the given tensor. + */ + void remove_tracker(ggml_tensor * tensor) { + std::lock_guard<std::mutex> lock(tracker_mutex); + trackers.erase(tensor->data); + } +}; + +// cann buffer type +/** + * @brief Structure representing context information for a specific backend + * buffer type. + */ +struct ggml_backend_cann_buffer_type_context { + int32_t device; /**< Device identifier associated with the buffer context. */ + std::string name; /**< Name associated with the buffer context. */ }; /** - * @brief Check if a buffer is a CANN buffer. + * @brief Retrieves the name associated with a CANN buffer type. * - * This function checks if a given buffer is a CANN buffer by comparing its - * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`. + * This function returns the descriptive name associated with the specified + * CANN buffer type context. * - * @param buffer The buffer to check. - * @return true if the buffer is a CANN buffer, false otherwise. + * @param buft Pointer to the buffer type context. + * @return Const pointer to the C-style string containing the name. */ -static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft); +static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; -static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) { - return ggml_backend_buft_is_cann(buffer->buft); + return buft_ctx->name.c_str(); +} + +/** + * @brief Checks if the backend buffer type is associated with the CANN backend. + * + * This function checks whether the provided backend buffer type is associated + * with the CANN backend based on the comparison of its name retrieval function + * pointer. + * + * @param buft Pointer to the backend buffer type to check. + * @return bool Returns true if the buffer type is associated with the CANN + * backend, otherwise false. + */ +static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cann_buffer_type_name; } /** @@ -1110,6 +1170,7 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(ggml_backend_buffer * designed to be used with a global array, one per device. */ struct ggml_cann_nz_workspace { + std::mutex mtx; // Protects ptr/allocated from concurrent access void * ptr; // Pointer to allocated device buffer size_t allocated; // Size of currently allocated buffer in bytes @@ -1176,13 +1237,15 @@ static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES]; * @note The workspace buffer used in this function is managed globally and reused * across calls. This reduces overhead from repeated memory allocation and deallocation. */ -static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) { - acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, offset); +static void weight_format_to_nz(ggml_tensor * tensor, int device) { + acl_tensor_ptr weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, tensor->nb, 2, ACL_FORMAT_ND, 0); uint64_t workspaceSize = 0; aclOpExecutor * executor; // TransMatmulWeight ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed.get(), &workspaceSize, &executor)); + + std::lock_guard<std::mutex> lock(g_nz_workspaces[device].mtx); // Avoid frequent malloc/free of the workspace. g_nz_workspaces[device].realloc(workspaceSize); @@ -1196,7 +1259,13 @@ static void weight_format_to_nz(ggml_tensor * tensor, size_t offset, int device) * @brief Set tensor data in a CANN buffer. * * This function sets tensor data in a CANN buffer, handling transformations - * if needed based on the tensor's type. + * if needed based on the tensor's type. It supports multi-threaded calls + * where different threads write different chunks of the same tensor. + * + * For quantized tensors (Q4_0/Q8_0), data is staged in a host buffer and + * the format transform is deferred until all chunks are written. + * For NZ weight tensors, chunks are uploaded directly but the ND-to-NZ + * conversion is deferred until all chunks are written. * * @param buffer The CANN buffer where the tensor data will be set. * @param tensor Pointer to the tensor whose data will be set. @@ -1212,25 +1281,72 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; ggml_cann_set_device(ctx->device); - // TODO: refer to cann(#6017), it use thread's default stream. - // For acl, synchronous functions use this default stream. - // Why aclrtSynchronizeDevice? // Only check env once. static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); - if (!need_transform(tensor->type)) { + + bool is_quantized = need_transform(tensor->type); + bool is_nz = !is_quantized && tensor->type != GGML_TYPE_BF16 && weight_to_nz && + is_matmul_weight((const ggml_tensor *) tensor); + + // Plain tensor (not quantized, not NZ): direct copy, no tracking needed + if (!is_quantized && !is_nz) { ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); - if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { + return; + } + + // Single-shot write (full tensor at once): handle directly without tracking overhead + if (offset == 0 && size == ggml_nbytes(tensor)) { + if (is_quantized) { + void * transform_buffer = malloc(size); + ggml_backend_cann_transform(tensor, data, transform_buffer); + ACL_CHECK(aclrtMemcpy(tensor->data, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); + free(transform_buffer); + } else { + // NZ weight GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[3] == 1); - weight_format_to_nz(tensor, offset, ctx->device); + ACL_CHECK(aclrtMemcpy(tensor->data, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); + weight_format_to_nz(tensor, ctx->device); + } + return; + } + + // Chunked write: use tracker to accumulate progress and defer transform/conversion + TensorSetTracker * tracker = ctx->get_or_create_tracker(tensor); + std::unique_lock<std::mutex> lock(tracker->mtx); + + if (is_quantized) { + // Stage data in host buffer; transform requires full tensor data + if (tracker->host_buffer.empty()) { + tracker->host_buffer.resize(tracker->total_bytes); } + memcpy(tracker->host_buffer.data() + offset, data, size); } else { - void * transform_buffer = malloc(size); - ggml_backend_cann_transform(tensor, data, transform_buffer); + // NZ weight: upload chunk to device immediately, defer conversion + ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); + } - ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, transform_buffer, size, ACL_MEMCPY_HOST_TO_DEVICE)); - free(transform_buffer); + tracker->bytes_written += size; + + // All chunks received: perform deferred transform/conversion + if (tracker->bytes_written >= tracker->total_bytes) { + if (is_quantized) { + void * transform_buffer = malloc(tracker->total_bytes); + ggml_backend_cann_transform(tensor, tracker->host_buffer.data(), transform_buffer); + ACL_CHECK(aclrtMemcpy(tensor->data, tracker->total_bytes, transform_buffer, tracker->total_bytes, ACL_MEMCPY_HOST_TO_DEVICE)); + free(transform_buffer); + } + + if (is_nz) { + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); + weight_format_to_nz(tensor, ctx->device); + } + + // Unlock before removing tracker, as remove_tracker destroys the mutex + lock.unlock(); + ctx->remove_tracker(tensor); } } @@ -1282,7 +1398,7 @@ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer, static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { - if (ggml_backend_buffer_is_cann(src->buffer)) { + if (ggml_backend_buft_is_cann(src->buffer->buft)) { ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context; ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context; @@ -1312,6 +1428,22 @@ static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer, return false; } +/** + * @brief Set a region of a tensor's device memory to a specified value. + * + * @param buffer The CANN buffer containing the tensor. + * @param tensor Pointer to the tensor whose memory will be set. + * @param value The value to which each byte in the region will be set. + * @param offset Byte offset within the tensor's data to start setting. + * @param size Number of bytes to set. + */ +static void ggml_backend_cann_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_backend_cann_buffer_context * ctx = (ggml_backend_cann_buffer_context *) buffer->context; + + ggml_cann_set_device(ctx->device); + ACL_CHECK(aclrtMemset((char *) tensor->data + offset, size, value, size)); +} + /** * @brief Clear a CANN buffer by setting all its memory to a specified value. * @@ -1338,39 +1470,16 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer, /* .get_base = */ ggml_backend_cann_buffer_get_base, /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_cann_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor, /* .clear = */ ggml_backend_cann_buffer_clear, /* .reset = */ NULL, }; -// cann buffer type -/** - * @brief Structure representing context information for a specific backend - * buffer type. - */ -struct ggml_backend_cann_buffer_type_context { - int32_t device; /**< Device identifier associated with the buffer context. */ - std::string name; /**< Name associated with the buffer context. */ -}; - -/** - * @brief Retrieves the name associated with a CANN buffer type. - * - * This function returns the descriptive name associated with the specified - * CANN buffer type context. - * - * @param buft Pointer to the buffer type context. - * @return Const pointer to the C-style string containing the name. - */ -static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) { - ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; - - return buft_ctx->name.c_str(); -} - /** * @brief Allocates a new CANN buffer of the specified type and size. * @@ -1454,7 +1563,8 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t if (ne0 % MATRIX_ROW_PADDING != 0) { size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - } else if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { + } else if (weight_to_nz && tensor->type != GGML_TYPE_BF16 + && is_matmul_weight((const ggml_tensor *) tensor)) { // NZ format weight are not support quantized yet. // If ND tensor transform to NZ, size may changed. int64_t shape[] = { tensor->ne[1], tensor->ne[0] }; @@ -1741,6 +1851,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_UNARY_OP_STEP: ggml_cann_step(ctx, dst); break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cann_softplus(ctx, dst); + break; default: return false; } @@ -1751,20 +1864,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg GGML_CANN_CALL_OP_UNARY_GATED(Relu); break; case GGML_GLU_OP_GEGLU: + ggml_cann_geglu(ctx, dst, 0); // approximate=0 → tanh + break; case GGML_GLU_OP_GEGLU_ERF: - // aclnnGelu internally uses the erf-based approximation. - GGML_CANN_CALL_OP_UNARY_GATED(Gelu); + ggml_cann_geglu(ctx, dst, 1); // approximate=1 → erf break; case GGML_GLU_OP_SWIGLU: - GGML_CANN_CALL_OP_UNARY_GATED(Silu); + ggml_cann_swiglu(ctx, dst); break; case GGML_GLU_OP_GEGLU_QUICK: - { - auto lambda = [](ggml_backend_cann_context & ctx, aclTensor * acl_src, aclTensor * acl_dst) { - GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); - }; - ggml_cann_op_unary_gated(lambda, ctx, dst); - } + ggml_cann_geglu_quick(ctx, dst); break; default: return false; @@ -1826,6 +1935,9 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_CPY: ggml_cann_cpy(ctx, dst); break; + case GGML_OP_SET: + ggml_cann_set(ctx, dst); + break; case GGML_OP_CONT: ggml_cann_dup(ctx, dst); break; @@ -1889,9 +2001,27 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg case GGML_OP_OUT_PROD: ggml_cann_out_prod(ctx, dst); break; + case GGML_OP_GATED_LINEAR_ATTN: + ggml_cann_gated_linear_attn(ctx, dst); + break; case GGML_OP_SSM_CONV: ggml_cann_ssm_conv(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cann_cumsum(ctx, dst); + break; + case GGML_OP_TRI: + ggml_cann_tri(ctx, dst); + break; + case GGML_OP_FILL: + ggml_cann_fill(ctx, dst); + break; + case GGML_OP_DIAG: + ggml_cann_diag(ctx, dst); + break; + case GGML_OP_SOLVE_TRI: + ggml_cann_solve_tri(ctx, dst); + break; default: return false; } @@ -2005,7 +2135,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src, GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src)); - if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) { + if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) { return false; } @@ -2154,6 +2284,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + bool ok = ggml_cann_compute_forward(*cann_ctx, node); if (!ok) { GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); @@ -2223,10 +2357,24 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, if (use_cann_graph) { // If no matching graph is found, the graph needs to be recaptured. graph_capture_required = !cann_ctx->graph_lru_cache.find_and_move_to_front(cgraph); + if (graph_capture_required) { // If no matching graph is found, add a new ACL graph. ggml_cann_graph * new_graph = ggml_cann_graph::create_from_cgraph(cgraph); cann_ctx->graph_lru_cache.push(new_graph); + + // Pre-load rope cache before graph capture. During capture the + // stream cannot perform host-to-device memcpy or device memory + // malloc/free. Running the full cache init now populates the + // cache metadata so these branches are skipped during capture, + // while also warming up the memory pool. + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + if (node->op == GGML_OP_ROPE) { + ggml_cann_rope_cache_preload(*cann_ctx, node); + break; + } + } } } #else @@ -2268,6 +2416,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_UNARY_OP_SGN: case GGML_UNARY_OP_STEP: case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_SOFTPLUS: return true; default: return false; @@ -2287,6 +2436,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_MUL_MAT: { switch (op->src[0]->type) { +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif case GGML_TYPE_F16: case GGML_TYPE_F32: return true; @@ -2324,6 +2476,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif case GGML_TYPE_Q8_0: return true; default: @@ -2336,6 +2491,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten switch (op->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif return true; default: return false; @@ -2345,20 +2503,30 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_CPY: { ggml_tensor * src = op->src[0]; +#ifdef ASCEND_310P if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16) || (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16)) { - // only support F32 and F16. + // only support F32 and F16 on 310P. + return false; + } +#else + if ((op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_BF16) || + (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_F16 && src->type != GGML_TYPE_BF16)) { + // only support F32, F16 and BF16. return false; } +#endif return true; } break; case GGML_OP_CONT: { - // TODO: support GGML_TYPE_BF16 switch (op->src[0]->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: +#ifndef ASCEND_310P + case GGML_TYPE_BF16: +#endif return true; default: return false; @@ -2439,6 +2607,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_SUM_ROWS: case GGML_OP_ARGSORT: case GGML_OP_ACC: + case GGML_OP_SET: case GGML_OP_GROUP_NORM: return true; case GGML_OP_PAD: @@ -2454,6 +2623,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten case GGML_OP_MEAN: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_COUNT_EQUAL: + case GGML_OP_GATED_LINEAR_ATTN: return true; case GGML_OP_OUT_PROD: { @@ -2506,10 +2676,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten // different head sizes of K and V are not supported yet return false; } - if (op->src[0]->ne[0] % 16 != 0) { - // TODO: padding to support - return false; - } float logitSoftcap = 0.0f; memcpy(&logitSoftcap, (const float *) (op->op_params) + 2, sizeof(float)); if (logitSoftcap != 0.0f) { @@ -2519,6 +2685,16 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten } case GGML_OP_SSM_CONV: return true; + case GGML_OP_CUMSUM: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_TRI: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_FILL: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_DIAG: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SOLVE_TRI: + return op->src[0]->type == GGML_TYPE_F32; default: return false; } @@ -2526,21 +2702,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten GGML_UNUSED(dev); } -/** - * @brief Checks if the backend buffer type is associated with the CANN backend. - * - * This function checks whether the provided backend buffer type is associated - * with the CANN backend based on the comparison of its name retrieval function - * pointer. - * - * @param buft Pointer to the backend buffer type to check. - * @return bool Returns true if the buffer type is associated with the CANN - * backend, otherwise false. - */ -static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_cann_buffer_type_name; -} - /** * @brief Records an event on the CANN backend stream. * @@ -2585,6 +2746,8 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .free = */ ggml_backend_cann_free, /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async, /* .synchronize = */ ggml_backend_cann_synchronize, /* .graph_plan_create = */ NULL, diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 93ab7ea446e..f05683b44cd 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -93,6 +93,10 @@ typedef sycl::half2 ggml_half2; // QR = QK / number of values before dequantization // QI = number of 32 bit integers before dequantization +#define QI1_0 (QK1_0 / 32) +#define QR1_0 1 + + #define QI4_0 (QK4_0 / (4 * QR4_0)) #define QR4_0 2 @@ -102,6 +106,9 @@ typedef sycl::half2 ggml_half2; #define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4)) #define QR_MXFP4 2 +#define QI_NVFP4 (QK_NVFP4 / (4 * QR_NVFP4)) +#define QR_NVFP4 2 + #define QI5_0 (QK5_0 / (4 * QR5_0)) #define QR5_0 2 @@ -167,6 +174,13 @@ typedef sycl::half2 ggml_half2; #define GGML_EXTENSION __extension__ #endif // _MSC_VER +#define QK1_0 128 +typedef struct { + ggml_half d; // delta + uint8_t qs[QK1_0 / 8]; // bits / quants +} block_q1_0; +static_assert(sizeof(block_q1_0) == sizeof(ggml_half) + QK1_0 / 8, "wrong q1_0 block size/padding"); + #define QK4_0 32 typedef struct { ggml_half d; // delta @@ -194,6 +208,14 @@ typedef struct { } block_mxfp4; static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding"); +#define QK_NVFP4 64 +#define QK_NVFP4_SUB 16 // sub-block size for per-group scales +typedef struct { + uint8_t d[QK_NVFP4/QK_NVFP4_SUB]; // UE4M3 scales (4 bytes, one per 16-element sub-block) + uint8_t qs[QK_NVFP4/2]; // packed 4-bit E2M1 values (32 bytes) +} block_nvfp4; +static_assert(sizeof(block_nvfp4) == sizeof(uint8_t)*(QK_NVFP4/QK_NVFP4_SUB) + QK_NVFP4/2, "wrong nvfp4 block size/padding"); + #define QK5_0 32 typedef struct { ggml_half d; // delta diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 7622d0bf49b..8c735a045b3 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch) target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN}) target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED) set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) + # Disable LTO for the feature detection code to prevent cross-module optimization + # from inlining architecture-specific instructions into the score function. + # Without this, LTO can cause SIGILL when loading backends on older CPUs + # (e.g., loading power10 backend on power9 crashes before feature check runs). + target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto) target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME}) endfunction() @@ -67,17 +72,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) endif() endif() - if (GGML_OPENMP) - find_package(OpenMP) - if (OpenMP_FOUND) - set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") - target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) - - target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) - else() - set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") - message(WARNING "OpenMP not found") - endif() + if (GGML_OPENMP_ENABLED) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) + target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) endif() if (GGML_LLAMAFILE) @@ -445,16 +442,30 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/repack.cpp ) if (GGML_CPU_RISCV64_SPACEMIT) + include(ggml-cpu/cmake/FindSMTIME.cmake) target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC}) list(APPEND GGML_CPU_SOURCES ggml-cpu/spacemit/ime.cpp ggml-cpu/spacemit/ime.h + ggml-cpu/spacemit/spine_mem_pool.cpp + ggml-cpu/spacemit/spine_mem_pool.h + ggml-cpu/spacemit/repack.cpp + ggml-cpu/spacemit/repack.h + ggml-cpu/spacemit/ime_env.cpp + ggml-cpu/spacemit/ime_env.h ggml-cpu/spacemit/ime1_kernels.cpp + ggml-cpu/spacemit/ime2_kernels.cpp ggml-cpu/spacemit/ime_kernels.h + ggml-cpu/spacemit/rvv_kernels.cpp + ggml-cpu/spacemit/rvv_kernels.h ) endif() if(NOT GGML_CPU_ALL_VARIANTS) set(MARCH_STR "rv64gc") + if (GGML_RVV) + string(APPEND MARCH_STR "v") + endif() + if (GGML_RV_ZFH) string(APPEND MARCH_STR "_zfh") endif() @@ -462,7 +473,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_XTHEADVECTOR) string(APPEND MARCH_STR "_xtheadvector") elseif (GGML_RVV) - string(APPEND MARCH_STR "_v") if (GGML_RV_ZVFH) string(APPEND MARCH_STR "_zvfh") endif() @@ -470,12 +480,24 @@ function(ggml_add_cpu_backend_variant_impl tag_name) string(APPEND MARCH_STR "_zvfbfwma") endif() endif() + if (GGML_RV_ZICBOP) string(APPEND MARCH_STR "_zicbop") endif() if (GGML_RV_ZIHINTPAUSE) string(APPEND MARCH_STR "_zihintpause") endif() + if (GGML_RV_ZBA) + string(APPEND MARCH_STR "_zba") + endif() + if (GGML_CPU_RISCV64_SPACEMIT) + # `xsmtvdotii' is only required for GCC >= 15. + if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND + CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15) + string(APPEND MARCH_STR "_xsmtvdotii") + endif() + endif() + list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d) else() # Begin with the lowest baseline @@ -561,35 +583,44 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.16.0") - set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "0a9e9008adb6031f9e8cf70dff4a3321") + set(KLEIDIAI_COMMIT_TAG "v1.24.0") + set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/releases/download/${KLEIDIAI_COMMIT_TAG}/kleidiai-${KLEIDIAI_COMMIT_TAG}-src.tar.gz") + set(KLEIDIAI_RELEASE_ARCHIVE_MD5 "2f02ebe29573d45813e671eb304f2a00") - if (POLICY CMP0135) - cmake_policy(SET CMP0135 NEW) + set(KLEIDIAI_FETCH_ARGS + URL ${KLEIDIAI_DOWNLOAD_URL} + URL_HASH MD5=${KLEIDIAI_RELEASE_ARCHIVE_MD5} + ) + if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.24") + list(APPEND KLEIDIAI_FETCH_ARGS DOWNLOAD_EXTRACT_TIMESTAMP NEW) endif() - FetchContent_Declare(KleidiAI_Download - URL ${KLEIDIAI_DOWNLOAD_URL} - DOWNLOAD_EXTRACT_TIMESTAMP NEW - URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) + if (CMAKE_VERSION VERSION_GREATER_EQUAL "3.28") + FetchContent_Declare(KleidiAI_Download + ${KLEIDIAI_FETCH_ARGS} + EXCLUDE_FROM_ALL + ) + + FetchContent_MakeAvailable(KleidiAI_Download) + FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) + else() + FetchContent_Declare(KleidiAI_Download + ${KLEIDIAI_FETCH_ARGS} + ) - FetchContent_MakeAvailable(KleidiAI_Download) - FetchContent_GetProperties(KleidiAI_Download - SOURCE_DIR KLEIDIAI_SRC - POPULATED KLEIDIAI_POPULATED) + FetchContent_GetProperties(KleidiAI_Download + SOURCE_DIR KLEIDIAI_SRC + POPULATED KLEIDIAI_POPULATED + ) - if (NOT KLEIDIAI_POPULATED) - message(FATAL_ERROR "KleidiAI source downloaded failed.") + if (NOT KLEIDIAI_POPULATED) + FetchContent_Populate(KleidiAI_Download) + FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) + endif() endif() add_compile_definitions(GGML_USE_CPU_KLEIDIAI) - # Remove kleidiai target after fetching it - if (TARGET kleidiai) - set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE) - endif() - list(APPEND GGML_CPU_SOURCES ggml-cpu/kleidiai/kleidiai.cpp ggml-cpu/kleidiai/kernels.cpp @@ -606,6 +637,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}") @@ -646,7 +678,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (NOT SME_ENABLED MATCHES -1) list(APPEND GGML_KLEIDIAI_SOURCES - ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S @@ -654,10 +685,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa_asm.S ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_f16pmrx2_f32_neon.c ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S) - set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2") + set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2+sme2+fp16") endif() if (NOT SVE_ENABLED MATCHES -1) diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 895a5713753..1118f7169c9 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -111,6 +111,8 @@ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = { /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor, /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, /* .cpy_tensor = */ nullptr, /* .clear = */ ggml_backend_amx_buffer_clear, /* .reset = */ nullptr, @@ -141,27 +143,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ namespace ggml::cpu::amx { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { - // handle only 2d gemm for now - auto is_contiguous_2d = [](const struct ggml_tensor * t) { - return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; - }; - - if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous - is_contiguous_2d(op->src[1]) && // src1 must be contiguous - op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && - op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) - op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x - (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { - // src1 must be host buffer - if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + if (op->op != GGML_OP_MUL_MAT) { + return false; + } + auto * src0 = op->src[0]; + auto * src1 = op->src[1]; + + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + return false; + } + if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) { + return false; + } + if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) { + return false; + } + if (op->ne[0] % (TILE_N * 2)) { + return false; + } + int alignment; + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + alignment = TILE_K; + break; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_XS: + alignment = 256; // QK_K + break; + case GGML_TYPE_F16: + alignment = 16; + break; + default: return false; - } - // src1 must be float32 - if (op->src[1]->type == GGML_TYPE_F32) { - return true; - } } - return false; + if (src0->ne[0] % alignment) { + return false; + } + if (src1->type != GGML_TYPE_F32) { + return false; + } + return true; } ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { diff --git a/ggml/src/ggml-cpu/amx/common.h b/ggml/src/ggml-cpu/amx/common.h index f392e898518..26a6ec1a2d0 100644 --- a/ggml/src/ggml-cpu/amx/common.h +++ b/ggml/src/ggml-cpu/amx/common.h @@ -9,6 +9,8 @@ #if defined(GGML_USE_OPENMP) #include <omp.h> +#else +#include <thread> #endif #define TILE_M 16 @@ -56,18 +58,40 @@ inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { } template <typename func_t> -inline void parallel_for(int n, const func_t& f) { +inline void parallel_for(int n, const func_t & f) { + if (n <= 0) { + return; + } #if defined(GGML_USE_OPENMP) -#pragma omp parallel -{ - int nth = omp_get_num_threads(); - int ith = omp_get_thread_num(); - int tbegin, tend; - balance211(n, nth, ith, tbegin, tend); - f(tbegin, tend); -} + #pragma omp parallel + { + int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); + } #else - f(0, n); + int nth = std::thread::hardware_concurrency(); + if (nth <= 1) { + f(0, n); + return; + } + if (nth > n) { + nth = n; + } + std::vector<std::thread> threads; + threads.reserve(nth); + for (int ith = 0; ith < nth; ++ith) { + threads.emplace_back([&f, n, ith, nth] { + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); + }); + } + for (auto & t : threads) { + t.join(); + } #endif } diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index 47c61b88164..d9383a04be8 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -1,4 +1,3 @@ - #if defined(__GNUC__) #pragma GCC diagnostic ignored "-Wpedantic" #pragma GCC diagnostic ignored "-Wunused-local-typedefs" @@ -196,41 +195,33 @@ struct tile_config_t{ // will be needed. // // Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <=16; -// and the sinlge batch gemm (m=1) has a special fast path with `avx512-vnni`. +// and the single batch gemm (m=1) has a special fast path with `avx512-vnni`. // // ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/ // advanced-matrix-extensions-intrinsics-functions.html // -#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb -void ggml_tile_config_init(void) { - static thread_local bool is_first_time = true; +inline void ggml_tile_config_init(void) { + static thread_local bool done = false; - if (!is_first_time) { + if (done) { return; } - static thread_local tile_config_t tc; - tile_config_t current_tc; - _tile_storeconfig(¤t_tc); - - // load only when config changes - if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 && - memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { - tc.palette_id = 1; - tc.start_row = 0; - TC_CONFIG_TILE(TMM0, 8, 64); - TC_CONFIG_TILE(TMM1, 8, 64); - TC_CONFIG_TILE(TMM2, 16, 32); - TC_CONFIG_TILE(TMM3, 16, 32); - TC_CONFIG_TILE(TMM4, 16, 64); - TC_CONFIG_TILE(TMM5, 16, 64); - TC_CONFIG_TILE(TMM6, 16, 64); - TC_CONFIG_TILE(TMM7, 16, 64); - _tile_loadconfig(&tc); - } - - is_first_time = false; + alignas(64) tile_config_t tc = {}; + tc.palette_id = 1; + tc.start_row = 0; + tc.rows[0] = 8; tc.colsb[0] = 64; + tc.rows[1] = 8; tc.colsb[1] = 64; + tc.rows[2] = 16; tc.colsb[2] = 32; + tc.rows[3] = 16; tc.colsb[3] = 32; + tc.rows[4] = 16; tc.colsb[4] = 64; + tc.rows[5] = 16; tc.colsb[5] = 64; + tc.rows[6] = 16; tc.colsb[6] = 64; + tc.rows[7] = 16; tc.colsb[7] = 64; + + _tile_loadconfig(&tc); + done = true; } // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation. @@ -268,33 +259,6 @@ int get_row_size(int K) { return row_size; } -// vectorized dtype conversion -inline float FP16_TO_FP32(ggml_half val) { - __m256i v = _mm256_setr_epi16( - val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); - __m512 o = _mm512_cvtph_ps(v); - return _mm512_cvtss_f32(o); -} - -inline __m512 FP16_TO_FP32_VEC(ggml_half val) { - __m256i v = _mm256_set1_epi16(val); - return _mm512_cvtph_ps(v); -} - -// horizontal reduce -inline float _mm512_reduce_max_ps(const __m512 x) { - __m512 v = x; - __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_f32x4(v, v, 0xB1); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_ps(v, v, 0x4E); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_ps(v, v, 0xB1); - v = _mm512_max_ps(v, v1); - return _mm512_cvtss_f32(v); -} - // transpose utils #define SHUFFLE_EPI32(a, b, mask) \ _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask)) @@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \ tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply( \ - K, (const float *)src1->data + mb_start * K, \ - (const type *)src0->data + nb_start * K, \ - (float *)dst->data + mb_start * ldc + nb_start, ldc); + K, (const float *)src1->data + src1_offset + mb_start * K, \ + (const type *)src0->data + src0_offset + nb_start * K, \ + (float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc) // re-organize in the format {NB, KB, TILE_SIZE}: @@ -1415,8 +1379,8 @@ struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLO // sum of offsets, shared across COLS // // avx512-vnni does not have `_mm512_dpbssd_epi32`, - // need to transfrom ss to us: - // a * (b - 8) is equavilent to b * a - 8 * a + // need to transform ss to us: + // a * (b - 8) is equivalent to b * a - 8 * a // s u u u s u s // __m512i vcomp; @@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, B } }; -#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ - tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \ - KB, (const char *)wdata + 0 * row_size_A, \ - (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ - (float *) dst->data + 0 * N + nb_start, ldc) +#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ + tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \ + KB, wdata_batch, \ + (const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ + (float *) dst->data + dst_offset + nb_start, ldc) template <typename TA, typename TB, typename TC, int BLOCK_K, typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0> @@ -2041,12 +2005,12 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v const int lda = KB * sizeof(TA); //const int ldb = KB * sizeof(TB); - static thread_local packed_B_t Tile0[TILE_N * TILE_K]; - static thread_local packed_B_t Tile1[TILE_N * TILE_K]; - static thread_local int8_t Tile23[TILE_M * TILE_K]; + alignas(64) static thread_local packed_B_t Tile0[TILE_N * TILE_K]; + alignas(64) static thread_local packed_B_t Tile1[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K]; - static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; - static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; + alignas(64) static thread_local int32_t TileC0[TILE_M * TILE_N * 4]; + alignas(64) static thread_local int32_t TileC1[TILE_M * TILE_N * 4]; // double buffering C to interleave avx512 and amx int32_t * C_cur = TileC0; @@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); if (need_unpack) { - unpack_B<TB>(Tile1, B_blk0); + unpack_B<TB>(Tile1, B_blk1); _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); } else { _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); @@ -2223,21 +2187,21 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v const int m1 = std::max(M - TILE_M, 0); //const int lda = KB * sizeof(TA); - static thread_local int8_t Tile0[TILE_N * TILE_K]; - static thread_local int8_t Tile1[TILE_N * TILE_K]; - static thread_local int8_t Tile23[TILE_M * TILE_K]; + alignas(64) static thread_local int8_t Tile0[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile1[TILE_N * TILE_K]; + alignas(64) static thread_local int8_t Tile23[TILE_M * TILE_K]; // mat mul result for each group - static thread_local int32_t Tile4[TILE_M * TILE_N]; - static thread_local int32_t Tile5[TILE_M * TILE_N]; - static thread_local int32_t Tile6[TILE_M * TILE_N]; - static thread_local int32_t Tile7[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile4[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile5[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile6[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Tile7[TILE_M * TILE_N]; // sum of each QK_K block, contains 8 groups, int32 - static thread_local int32_t Sumi4[TILE_M * TILE_N]; - static thread_local int32_t Sumi5[TILE_M * TILE_N]; - static thread_local int32_t Sumi6[TILE_M * TILE_N]; - static thread_local int32_t Sumi7[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi4[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi5[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi6[TILE_M * TILE_N]; + alignas(64) static thread_local int32_t Sumi7[TILE_M * TILE_N]; const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32; for (int i = 0; i < KB; ++i) { @@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d }); } +// ne2 is passed explicitly to help compiler optimize repeated calls +inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) { + const int64_t i2 = batch_idx % ne2; + const int64_t i3 = batch_idx / ne2; + return i3 * t->nb[3] + i2 * t->nb[2]; +} + size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { struct ggml_tensor * src0 = dst->src[0]; @@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { const int M = dst->ne[1]; const int K = src0->ne[0]; + const int64_t n_batch = dst->ne[2] * dst->ne[3]; size_t desired_wsize = 0; GGML_DISPATCH_QTYPES(TYPE, [&] { const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); - desired_wsize = M * row_size_A; + desired_wsize = n_batch * M * row_size_A; }); return desired_wsize; @@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { // src1: input in shape of {M, K}, float32 // dst: output in shape of {M, N}, float32 // -// the function performs: dst = src1 @ src0.T +// the function performs: dst = src1 @ src0.T for each batch // void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) { struct ggml_tensor * src0 = dst->src[0]; @@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int K = src0->ne[0]; const int ldc = dst->nb[1] / dst->nb[0]; + const int64_t ne2 = dst->ne[2]; + const int64_t n_batch = ne2 * dst->ne[3]; + if (is_floating_type) { constexpr int BLOCK_M = 4; constexpr int BLOCK_N = 6; const int MB = div_up(M, BLOCK_M); const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) { GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] { for (int i = begin; i < end; ++i) { - int mb = i / NB; - int nb = i % NB; + int batch_idx = i / (MB * NB); + int remaining = i % (MB * NB); + int mb = remaining / NB; + int nb = remaining % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); int mb_start = mb * BLOCK_M; int mb_size = std::min(BLOCK_M, M - mb_start); @@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te void * wdata = params->wdata; //TODO: performance improvement: merge quant A - if (params->ith == 0) { + // if (params->ith == 0) { GGML_DISPATCH_QTYPES(TYPE, [&] { const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); - const size_t desired_wsize = M * row_size_A; + const size_t desired_wsize = n_batch * M * row_size_A; if (params->wsize < desired_wsize) { GGML_ABORT("insufficient work space size"); } @@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); - const float * A_data = static_cast<const float *>(src1->data); - for (int m = 0; m < M; ++m) { - from_float<vec_dot_type>(A_data + m * K, (char *)wdata + m * row_size_A, K); - } + parallel_for_ggml(params, n_batch, [&](int begin, int end) { + for (int batch_idx = begin; batch_idx < end; ++batch_idx) { + int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); + const float * A_data = (const float *)((const char *)src1->data + src1_offset); + char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A; + + for (int m = 0; m < M; ++m) { + from_float<vec_dot_type>(A_data + m * K, wdata_batch + m * row_size_A, K); + } + } + }); }); - } + // } ggml_barrier(params->threadpool); @@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te constexpr int BLOCK_N = TILE_N * kTilesN; const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) { GGML_DISPATCH_QTYPES(TYPE, [&] { const int KB = K / blck_size; const int TILE_SIZE = get_tile_size<type>(); const int row_size_A = KB * sizeof(vec_dot_type); for (int i = begin; i < end; ++i) { - int nb = i; + int batch_idx = i / NB; + int nb = i % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); + const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A; + int nb_start = nb * BLOCK_N; int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96 @@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int MB = div_up(M, BLOCK_M); const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) { // init tile config for each thread ggml_tile_config_init(); @@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int row_size_A = KB * sizeof(vec_dot_type); for (int i = begin; i < end; ++i) { - int mb = i / NB; - int nb = i % NB; + int batch_idx = i / (MB * NB); + int remaining = i % (MB * NB); + int mb = remaining / NB; + int nb = remaining % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); + const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A; int mb_start = mb * BLOCK_M; int mb_size = std::min(BLOCK_M, M - mb_start); @@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te tinygemm_kernel_amx<vec_dot_type, type, float, blck_size>( mb_size, nb_size, KB, - (const char *)wdata + mb_start * row_size_A, - (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), - (float *) dst->data + mb_start * N + nb_start, ldc); + wdata_batch + mb_start * row_size_A, + (const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), + (float *) dst->data + dst_offset + mb_start * N + nb_start, ldc); } }); }); diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 3f8946ac701..b0391a67c88 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -1,3 +1,4 @@ + #pragma once // Rename `_generic` functions if no native implementation is available. @@ -14,6 +15,8 @@ #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -38,21 +41,33 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) @@ -60,29 +75,45 @@ #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) +// quants.c +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679 // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K @@ -94,21 +125,33 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__loongarch64) @@ -118,6 +161,8 @@ #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -126,64 +171,79 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__riscv) // quants.c -#define quantize_row_q8_K_generic quantize_row_q8_K -#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K -#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K -#define ggml_vec_dot_iq2_xxs_q8_K_generic ggml_vec_dot_iq2_xxs_q8_K -#define ggml_vec_dot_iq2_xs_q8_K_generic ggml_vec_dot_iq2_xs_q8_K -#define ggml_vec_dot_iq2_s_q8_K_generic ggml_vec_dot_iq2_s_q8_K -#define ggml_vec_dot_iq3_xxs_q8_K_generic ggml_vec_dot_iq3_xxs_q8_K -#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K -#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K -#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K -#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 -#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K -#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 // repack.cpp +#define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 -#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 +#define ggml_quantize_mat_q8_K_4x1_generic ggml_quantize_mat_q8_K_4x1 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__s390x__) // quants.c #define quantize_row_q8_K_generic quantize_row_q8_K +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K @@ -202,21 +262,33 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__wasm__) @@ -234,6 +306,8 @@ #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0 +#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0 +#define ggml_vec_dot_q1_0_q8_0_generic ggml_vec_dot_q1_0_q8_0 // repack.cpp #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8 @@ -242,21 +316,33 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0 +#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K -#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K +#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K +#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0 +#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K -#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K +#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #endif diff --git a/ggml/src/ggml-cpu/arch/arm/quants.c b/ggml/src/ggml-cpu/arch/arm/quants.c index b390ab61c78..fe621332970 100644 --- a/ggml/src/ggml-cpu/arch/arm/quants.c +++ b/ggml/src/ggml-cpu/arch/arm/quants.c @@ -137,6 +137,89 @@ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in //===================================== Dot products ================================= +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; // 128 + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__ARM_NEON) + float32x4_t sumv = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); + + // Process 4 Q8_0 blocks (each has 32 elements) + for (int k = 0; k < 4; k++) { + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); + + // Get the 4 bytes of bits for this Q8_0 block (32 bits = 4 bytes) + // Bits are at offset k*4 bytes in x[i].qs + const uint8_t * bits = &x[i].qs[k * 4]; + + // Load 32 int8 values from y + const int8x16_t y0 = vld1q_s8(yb->qs); + const int8x16_t y1 = vld1q_s8(yb->qs + 16); + + // Byte 0-1: bits for y0[0..15] + const uint64_t expand0 = table_b2b_0[bits[0]]; + const uint64_t expand1 = table_b2b_0[bits[1]]; + // Byte 2-3: bits for y1[0..15] + const uint64_t expand2 = table_b2b_0[bits[2]]; + const uint64_t expand3 = table_b2b_0[bits[3]]; + + // Build the sign vectors by reinterpreting the table values + uint8x8_t e0 = vcreate_u8(expand0); + uint8x8_t e1 = vcreate_u8(expand1); + uint8x8_t e2 = vcreate_u8(expand2); + uint8x8_t e3 = vcreate_u8(expand3); + + // Shift right by 4 to get 0 or 1 + int8x8_t s0 = vreinterpret_s8_u8(vshr_n_u8(e0, 4)); + int8x8_t s1 = vreinterpret_s8_u8(vshr_n_u8(e1, 4)); + int8x8_t s2 = vreinterpret_s8_u8(vshr_n_u8(e2, 4)); + int8x8_t s3 = vreinterpret_s8_u8(vshr_n_u8(e3, 4)); + + // Convert 0/1 to -1/+1: sign = 2*val - 1 + int8x8_t one = vdup_n_s8(1); + s0 = vsub_s8(vadd_s8(s0, s0), one); // 2*s0 - 1 + s1 = vsub_s8(vadd_s8(s1, s1), one); + s2 = vsub_s8(vadd_s8(s2, s2), one); + s3 = vsub_s8(vadd_s8(s3, s3), one); + + // Combine into 16-element vectors + int8x16_t signs0 = vcombine_s8(s0, s1); + int8x16_t signs1 = vcombine_s8(s2, s3); + + // Multiply signs with y values and accumulate + // dot(signs, y) where signs are +1/-1 + int32x4_t p0 = ggml_vdotq_s32(vdupq_n_s32(0), signs0, y0); + int32x4_t p1 = ggml_vdotq_s32(p0, signs1, y1); + + // Scale by d1 and accumulate + sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(p1), d0 * d1); + } + } + + *s = vaddvq_f32(sumv); +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + + void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -650,6 +733,116 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo *s = sumf; } +void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_NVFP4 == 0); + + const block_nvfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + // Each NVFP4 super-block (64 elements) spans 2 q8_0 blocks + const int nb = n / QK_NVFP4; + + float sumf = 0; + +#if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) + const int8x16_t values = vld1q_s8(kvalues_mxfp4); + const uint8x16_t m4b = vdupq_n_u8(0x0f); + float32x4_t acc = vdupq_n_f32(0.0f); + + for (int ib = 0; ib < nb; ++ib) { + const uint8x16_t q4bits_0 = vld1q_u8(x[ib].qs); + const uint8x16_t q4bits_1 = vld1q_u8(x[ib].qs + 16); + + const int8x16_t q4_lo_0 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_0, m4b)); + const int8x16_t q4_hi_0 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_0, 4)); + const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits_1, m4b)); + const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4)); + +#if defined(__ARM_FEATURE_DOTPROD) + const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs); + const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16); + const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b)); + const int8x16_t q8_hi_0 = vcombine_s8(vget_high_s8(q8_0a), vget_high_s8(q8_0b)); + + const int8x16_t q8_1a = vld1q_s8(y[2*ib+1].qs); + const int8x16_t q8_1b = vld1q_s8(y[2*ib+1].qs + 16); + const int8x16_t q8_lo_1 = vcombine_s8(vget_low_s8(q8_1a), vget_low_s8(q8_1b)); + const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b)); + + const int32x4_t p0 = vaddq_s32( + vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0), + vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0)); + const int32x4_t p1 = vaddq_s32( + vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1), + vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1)); + + const int32x4_t sumi = vpaddq_s32(p0, p1); +#else + const int8x8_t q4_0_lo = vget_low_s8(q4_lo_0); + const int8x8_t q4_0_hi = vget_low_s8(q4_hi_0); + const int8x8_t q4_1_lo = vget_high_s8(q4_lo_0); + const int8x8_t q4_1_hi = vget_high_s8(q4_hi_0); + const int8x8_t q4_2_lo = vget_low_s8(q4_lo_1); + const int8x8_t q4_2_hi = vget_low_s8(q4_hi_1); + const int8x8_t q4_3_lo = vget_high_s8(q4_lo_1); + const int8x8_t q4_3_hi = vget_high_s8(q4_hi_1); + + const int8x8_t q8_0_lo = vld1_s8(y[2*ib].qs); + const int8x8_t q8_0_hi = vld1_s8(y[2*ib].qs + 8); + const int8x8_t q8_1_lo = vld1_s8(y[2*ib].qs + 16); + const int8x8_t q8_1_hi = vld1_s8(y[2*ib].qs + 24); + const int8x8_t q8_2_lo = vld1_s8(y[2*ib+1].qs); + const int8x8_t q8_2_hi = vld1_s8(y[2*ib+1].qs + 8); + const int8x8_t q8_3_lo = vld1_s8(y[2*ib+1].qs + 16); + const int8x8_t q8_3_hi = vld1_s8(y[2*ib+1].qs + 24); + + const int32x4_t sumi = (int32x4_t){ + vaddvq_s32(ggml_nvfp4_dot8(q4_0_lo, q8_0_lo, q4_0_hi, q8_0_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_1_lo, q8_1_lo, q4_1_hi, q8_1_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_2_lo, q8_2_lo, q4_2_hi, q8_2_hi)), + vaddvq_s32(ggml_nvfp4_dot8(q4_3_lo, q8_3_lo, q4_3_hi, q8_3_hi)), + }; +#endif + + const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d); + const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d); + const float32x4_t nvsc = { + ggml_ue4m3_to_fp32(x[ib].d[0]), + ggml_ue4m3_to_fp32(x[ib].d[1]), + ggml_ue4m3_to_fp32(x[ib].d[2]), + ggml_ue4m3_to_fp32(x[ib].d[3]) + }; + const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1}); + + acc = vfmaq_f32(acc, vcvtq_f32_s32(sumi), scales); + } + sumf = vaddvq_f32(acc); +#else + for (int ib = 0; ib < nb; ++ib) { + for (int si = 0; si < 4; ++si) { + const float d = ggml_ue4m3_to_fp32(x[ib].d[si]); + const int q8b = si / 2; + const int q8o = (si % 2) * QK_NVFP4_SUB; + const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8b].d); + + int sumi_lo = 0, sumi_hi = 0; + for (int j = 0; j < QK_NVFP4_SUB/2; ++j) { + const uint8_t qv = x[ib].qs[si*(QK_NVFP4_SUB/2) + j]; + sumi_lo += y[2*ib + q8b].qs[q8o + j + 0] * kvalues_mxfp4[qv & 0xf]; + sumi_hi += y[2*ib + q8b].qs[q8o + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >> 4]; + } + sumf += dy * d * (sumi_lo + sumi_hi); + } + } +#endif + *s = sumf; +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -968,7 +1161,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi const int vector_length = ggml_cpu_get_sve_cnt()*8; - //VLA Implemenation for SVE + //VLA Implementation for SVE switch (vector_length) { case 128: { diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index b61220a189a..a7534443091 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -25,9 +25,8 @@ #define UNUSED GGML_UNUSED #if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_FEATURE_MATMUL_INT8) || defined(__ARM_FEATURE_DOTPROD)) -static inline void decode_q4_Kx8_scales_mins(const uint8_t * scales_in, - int16x8_t * out_mins, - int8_t * out_scales) { +// Helper for decoding scales and mins of Q4_K and Q5_K block formats +static inline void decode_q_Kx8_6bit_scales(const uint8_t * scales_in, int16x8_t * out_mins, int8_t * out_scales) { constexpr uint32_t kmask1 = 0x3f3f3f3f; constexpr uint32_t kmask2 = 0x0f0f0f0f; constexpr uint32_t kmask3 = 0x03030303; @@ -499,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float * res_ptr = s; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + + float32x4_t sumf = vdupq_n_f32(0); + for (int l = 0; l < nb; l++) { + uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0); + uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16); + uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32); + uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48); + + int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4); + int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F); + int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4); + int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F); + int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4); + int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F); + int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4); + int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F); + + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16); + + int32x4_t sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0); + sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0); + sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1); + sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1); + sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2); + sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2); + sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3); + sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3); + + float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d)); + float32x4_t b_d = { + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]), + }; + float32x4_t d = a_d * b_d; + + sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi)); + } + + vst1q_f32(res_ptr + x * 4, sumf); + } + return; +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { constexpr int qk = QK_K; const int nb = n / qk; @@ -561,7 +635,7 @@ void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } @@ -701,13 +775,13 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, for (int i = 0; i < 2; i++) { int8_t aux_q4sb[8]; const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); } const uint8_t * q4_base = q4_ptr[b].qs + sb * QK_K; - // Load the 64 quants from q8K duplicated to use vecdots with the interelaved columns + // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns // but still need the qs to use the low and hi bits from q4 const int8_t * q8_base = q8_ptr[b].qs + sb * 64; int8x16_t q8_qs[8]; @@ -786,17 +860,18 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_q8_0_4x4_q8_0(int n, +void ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; assert(n % qk == 0); assert(nc % ncols_interleaved == 0); @@ -806,55 +881,156 @@ void ggml_gemv_q8_0_4x4_q8_0(int n, UNUSED(blocklen); #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx; + constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567 + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[col_groups]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < col_groups; i++) { + acc_f32[i] = vdupq_n_f32(0); + } - for (int c = 0; c < nc; c += ncols_interleaved) { - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - float32x4_t acc = vdupq_n_f32(0); for (int b = 0; b < nb; b++) { - int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs); - int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64); - float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d); + float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d); + float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3 + float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d); + float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d); - int8x16x2_t a = vld1q_s8_x2(a_ptr->qs); - float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567 + int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int32x4_t acc_lo[col_groups]; + int32x4_t acc_hi[col_groups]; - int32x4_t ret = vdupq_n_s32(0); + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); - ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0); - ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1); - ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2); - ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3); + uint8x16_t qh[col_groups][8]; + for (int c = 0; c < col_groups; c++) { + for (int i = 0; i < 8; i++) { + qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c); + } + } - ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0); - ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1); - ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2); - ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3); + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_groups; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_mins[2]; + int16x8_t q5sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } - acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); - a_ptr++; - b_ptr++; - } - vst1q_f32(s, acc); - s += ncols_interleaved; - } - return; + int8x16_t q8_qs[4]; + for (int i = 0; i < 4; i++) { + q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16); + } + + for (int c = 0; c < col_groups; c++) { + uint8x16_t q5_cols[8]; + uint8x16_t hbit_lo[8]; + uint8x16_t hbit_hi[8]; + int8x16_t q5_lo[8]; + int8x16_t q5_hi[8]; + + for (int i = 0; i < 8; i++) { + q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c); + hbit_lo[i] = vandq_u8(qh[c][i], mone); + hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3); + qh[c][i] = vshrq_n_u8(qh[c][i], 2); + q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4)); + q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i])); + } + + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3); + + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3); + } + + // Scales + // row c0123 blk0 and blk1 + const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]); + const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0]))); + acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123); + // row c4567 blk0 and blk1 + const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]); + const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]); + const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1]))); + acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567); + + // Bias Correction + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } // for sb + + acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123); + acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567); + } // for b + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); + ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemv_q8_0_4x8_q8_0(int n, +void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; assert(n % qk == 0); assert(nc % ncols_interleaved == 0); @@ -864,269 +1040,1003 @@ void ggml_gemv_q8_0_4x8_q8_0(int n, UNUSED(blocklen); #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx; + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); - for (int c = 0; c < nc; c += ncols_interleaved) { - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - float32x4_t acc = vdupq_n_f32(0); + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[ncols_interleaved / 4]; - for (int b = 0; b < nb; b++) { - int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs); - int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64); - float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; - int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs); - int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]); - int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]); - int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]); - int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]); - float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); - int32x4_t ret0 = vdupq_n_s32(0); - int32x4_t ret1 = vdupq_n_s32(0); + for (int i = 0; i < ncols_interleaved / 4; i++) { + acc_f32[i] = vdupq_n_f32(0); + } - // 0..7 - ret0 = vdotq_s32(ret0, b_low.val[0], a0); - ret1 = vdotq_s32(ret1, b_low.val[1], a0); - // 8..15 - ret0 = vdotq_s32(ret0, b_low.val[2], a1); - ret1 = vdotq_s32(ret1, b_low.val[3], a1); - // 16..23 - ret0 = vdotq_s32(ret0, b_high.val[0], a2); - ret1 = vdotq_s32(ret1, b_high.val[1], a2); - // 24..31 - ret0 = vdotq_s32(ret0, b_high.val[2], a3); - ret1 = vdotq_s32(ret1, b_high.val[3], a3); + for (int b = 0; b < nb; b++) { + float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q5_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q5_d_1, q8_d); + float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3 + float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0 = vmulq_f32(q5_dmin_0, q8_d); + float32x4_t sb_min_1 = vmulq_f32(q5_dmin_1, q8_d); - int32x4_t ret = vpaddq_s32(ret0, ret1); + // 2 sb each iteration + int32x4_t acc_lo[col_pairs]; + int32x4_t acc_hi[col_pairs]; - acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); - a_ptr++; - b_ptr++; - } - vst1q_f32(s, acc); - s += ncols_interleaved; - } - return; + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); -#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); -} + // Load qh once per block and shift after each subblock + const uint8_t * qh_base = q5_ptr[b].qh; + uint8x16_t qh[col_pairs][4]; + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vld1q_u8(qh_base + 16 * cp); + qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64); + qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128); + qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192); + } -void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_pairs; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later + int16x8_t q5sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); + const uint8_t * qs_base = q5_ptr[b].qs + sb * QK_K; + + // Load the 64 quants from q8K duplicated to use vecdots with the interleaved columns + const int8_t * q8_base = q8_ptr[b].qs + sb * 64; + int8x16_t q8_qs[8]; + for (int i = 0; i < 8; i++) { + q8_qs[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base + i * 8)); + } + + // Q5s column pair loop unrolled + { + // Cols 01 + uint8x16_t qs_0 = vld1q_u8(qs_base); + uint8x16_t qs_1 = vld1q_u8(qs_base + 64); + uint8x16_t qs_2 = vld1q_u8(qs_base + 128); + uint8x16_t qs_3 = vld1q_u8(qs_base + 192); + + uint8x16_t hbit_lo_0 = vandq_u8(qh[0][0], mone); + uint8x16_t hbit_lo_1 = vandq_u8(qh[0][1], mone); + uint8x16_t hbit_lo_2 = vandq_u8(qh[0][2], mone); + uint8x16_t hbit_lo_3 = vandq_u8(qh[0][3], mone); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[0][0], mtwo), 3); + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[0][1], mtwo), 3); + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[0][2], mtwo), 3); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[0][3], mtwo), 3); + + qh[0][0] = vshrq_n_u8(qh[0][0], 2); + qh[0][1] = vshrq_n_u8(qh[0][1], 2); + qh[0][2] = vshrq_n_u8(qh[0][2], 2); + qh[0][3] = vshrq_n_u8(qh[0][3], 2); + + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[0] = ggml_vdotq_s32( + acc_lo[0], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[0] = ggml_vdotq_s32(acc_hi[0], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 23 + qs_0 = vld1q_u8(qs_base + 16); + qs_1 = vld1q_u8(qs_base + 80); + qs_2 = vld1q_u8(qs_base + 144); + qs_3 = vld1q_u8(qs_base + 208); + + hbit_lo_0 = vandq_u8(qh[1][0], mone); + hbit_lo_1 = vandq_u8(qh[1][1], mone); + hbit_lo_2 = vandq_u8(qh[1][2], mone); + hbit_lo_3 = vandq_u8(qh[1][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[1][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[1][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[1][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[1][3], mtwo), 3); + + qh[1][0] = vshrq_n_u8(qh[1][0], 2); + qh[1][1] = vshrq_n_u8(qh[1][1], 2); + qh[1][2] = vshrq_n_u8(qh[1][2], 2); + qh[1][3] = vshrq_n_u8(qh[1][3], 2); + + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[1] = ggml_vdotq_s32( + acc_lo[1], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[1] = ggml_vdotq_s32(acc_hi[1], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 45 + qs_0 = vld1q_u8(qs_base + 32); + qs_1 = vld1q_u8(qs_base + 96); + qs_2 = vld1q_u8(qs_base + 160); + qs_3 = vld1q_u8(qs_base + 224); + + hbit_lo_0 = vandq_u8(qh[2][0], mone); + hbit_lo_1 = vandq_u8(qh[2][1], mone); + hbit_lo_2 = vandq_u8(qh[2][2], mone); + hbit_lo_3 = vandq_u8(qh[2][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[2][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[2][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[2][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[2][3], mtwo), 3); + + qh[2][0] = vshrq_n_u8(qh[2][0], 2); + qh[2][1] = vshrq_n_u8(qh[2][1], 2); + qh[2][2] = vshrq_n_u8(qh[2][2], 2); + qh[2][3] = vshrq_n_u8(qh[2][3], 2); + + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[2] = ggml_vdotq_s32( + acc_lo[2], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[2] = ggml_vdotq_s32(acc_hi[2], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + + // Cols 45 + qs_0 = vld1q_u8(qs_base + 48); + qs_1 = vld1q_u8(qs_base + 112); + qs_2 = vld1q_u8(qs_base + 176); + qs_3 = vld1q_u8(qs_base + 240); + + hbit_lo_0 = vandq_u8(qh[3][0], mone); + hbit_lo_1 = vandq_u8(qh[3][1], mone); + hbit_lo_2 = vandq_u8(qh[3][2], mone); + hbit_lo_3 = vandq_u8(qh[3][3], mone); + hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[3][0], mtwo), 3); + hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[3][1], mtwo), 3); + hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[3][2], mtwo), 3); + hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[3][3], mtwo), 3); + + qh[3][0] = vshrq_n_u8(qh[3][0], 2); + qh[3][1] = vshrq_n_u8(qh[3][1], 2); + qh[3][2] = vshrq_n_u8(qh[3][2], 2); + qh[3][3] = vshrq_n_u8(qh[3][3], 2); + + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_0, m4b), hbit_lo_0, 4)), q8_qs[0]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_1, m4b), hbit_lo_1, 4)), q8_qs[1]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_2, m4b), hbit_lo_2, 4)), q8_qs[2]); + acc_lo[3] = ggml_vdotq_s32( + acc_lo[3], vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_3, m4b), hbit_lo_3, 4)), q8_qs[3]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_0, 4), hbit_hi_0)), + q8_qs[4]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_1, 4), hbit_hi_1)), + q8_qs[5]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_2, 4), hbit_hi_2)), + q8_qs[6]); + acc_hi[3] = ggml_vdotq_s32(acc_hi[3], vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_3, 4), hbit_hi_3)), + q8_qs[7]); + } + + // Prepare bsum vectors for bias computation + // Each pair of subblocks share the same bsums + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + // Iterates over a pair of column pairs (4 columns) to use a single 128 register + // p = 0 -> 0123 p2 -> 4567 + for (int i = 0, p = 0; p < col_pairs; i++, p += 2) { + int16x4_t group_scales_lo = p == 0 ? vget_low_s16(q5sb_scales[0]) : vget_high_s16(q5sb_scales[0]); + int16x4_t group_scales_hi = p == 0 ? vget_low_s16(q5sb_scales[1]) : vget_high_s16(q5sb_scales[1]); + int16x4_t group_mins_lo = p == 0 ? vget_low_s16(q5sb_mins[0]) : vget_high_s16(q5sb_mins[0]); + int16x4_t group_mins_hi = p == 0 ? vget_low_s16(q5sb_mins[1]) : vget_high_s16(q5sb_mins[1]); + float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1; + float32x4_t sb_min = p == 0 ? sb_min_0 : sb_min_1; + + // 0123 or 4567 + float32x4_t sumf_0 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0); + + float32x4_t sumf_1 = + vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_hi), vpaddq_s32(acc_hi[p], acc_hi[p + 1]))); + acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_1); + + // FUSED BIAS: Compute and subtract bias immediately + // bias = (bsums_lo * mins_lo + bsums_hi * mins_hi) * sb_min + int32x4_t bias = vmull_s16(bsums_vec_lo, group_mins_lo); + bias = vmlal_s16(bias, bsums_vec_hi, group_mins_hi); + float32x4_t bias_f32 = vcvtq_f32_s32(bias); + acc_f32[i] = vmlsq_f32(acc_f32[i], sb_min, bias_f32); + } + } // for sb + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q6_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); UNUSED(nb); UNUSED(ncols_interleaved); UNUSED(blocklen); -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - const void * b_ptr = vx; - const void * a_ptr = vy; - float * res_ptr = s; - size_t res_stride = bs * sizeof(float); +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_groups = ncols_interleaved / 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); - __asm__ __volatile__( - "mov x10, %x[nr]\n" - "mov x9, #0x88\n" - "cmp x10, #0x10\n" - "mul x9, %x[nb], x9\n" - "blt 4f\n" - "1:" // Row loop - "add x28, %x[b_ptr], #0x8\n" - "mov x27, %x[nc]\n" - "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" - "2:" // Column loop - "add x25, %x[a_ptr], #0x8\n" - "movi v15.16b, #0x0\n" - "movi v19.16b, #0x0\n" - "mov x24, %x[nb]\n" - "add x23, x25, x9\n" - "movi v18.16b, #0x0\n" - "movi v14.16b, #0x0\n" - "add x22, x23, x9\n" - "movi v11.16b, #0x0\n" - "movi v13.16b, #0x0\n" - "add x21, x22, x9\n" - "movi v23.16b, #0x0\n" - "movi v16.16b, #0x0\n" - "movi v25.16b, #0x0\n" - "movi v7.16b, #0x0\n" - "movi v0.16b, #0x0\n" - "movi v4.16b, #0x0\n" - "movi v5.16b, #0x0\n" - "movi v21.16b, #0x0\n" - "movi v8.16b, #0x0\n" - "movi v1.16b, #0x0\n" - "3:" // Block loop - "ldr q3, [x28, #0x0]\n" - "ldr q31, [x25, #0x0]\n" - "movi v28.16b, #0x4\n" - "movi v10.4s, #0x0\n" - "ldr q22, [x28, #0x10]\n" - "ldr q6, [x25, #0x10]\n" - "movi v29.4s, #0x0\n" - "movi v9.4s, #0x0\n" - "ldr q27, [x28, #0x20]\n" - "ldr q30, [x28, #0x30]\n" - "movi v20.4s, #0x0\n" - "movi v24.16b, #0xf0\n" - "ldr d2, [x25, #-0x8]\n" - "ldr d26, [x23, #-0x8]\n" - "sshl v12.16b, v3.16b, v28.16b\n" - "sub x20, x28, #0x8\n" - "ldr d17, [x20, #0x0]\n" - "and v3.16b, v3.16b, v24.16b\n" - "subs x24, x24, #0x1\n" - "add x28, x28, #0x48\n" - ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" - ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" - ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" - ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" - "sshl v31.16b, v22.16b, v28.16b\n" - "and v22.16b, v22.16b, v24.16b\n" - "fcvtl v17.4s, v17.4h\n" - "fcvtl v2.4s, v2.4h\n" - "fcvtl v26.4s, v26.4h\n" - ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" - ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" - ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" - ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" - "sshl v6.16b, v27.16b, v28.16b\n" - "sshl v28.16b, v30.16b, v28.16b\n" - "and v27.16b, v27.16b, v24.16b\n" - "and v30.16b, v30.16b, v24.16b\n" - "ldr q24, [x25, #0x20]\n" - ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x30]\n" - ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" - ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" - ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" - ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x40]\n" - ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x50]\n" - ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" - ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" - ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" - ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x60]\n" - ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x25, #0x70]\n" - "add x25, x25, #0x88\n" - ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" - ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" - ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" - ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" - "fmul v24.4s, v17.4s, v2.s[0]\n" - "scvtf v10.4s, v10.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "fmla v15.4s, v10.4s, v24.4s\n" - "ldr q24, [x23, #0x0]\n" - "fmul v10.4s, v17.4s, v2.s[1]\n" - "fmla v19.4s, v29.4s, v10.4s\n" - "ldr q10, [x23, #0x10]\n" - "fmul v29.4s, v17.4s, v2.s[2]\n" - "fmul v2.4s, v17.4s, v2.s[3]\n" - "fmla v18.4s, v9.4s, v29.4s\n" - "movi v9.4s, #0x0\n" - "movi v29.4s, #0x0\n" - ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" - "fmla v14.4s, v20.4s, v2.4s\n" - "movi v20.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x20]\n" - ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" - ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" - ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" - ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x30]\n" - ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" - ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" - ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" - ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x40]\n" - ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" - ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" - ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" - ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x50]\n" - ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" - ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" - ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" - ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" - "ldr q24, [x23, #0x60]\n" - ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" - ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" - ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" - ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" - "ldr q10, [x23, #0x70]\n" - "add x23, x23, #0x88\n" - ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" - ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" - ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" - ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x0]\n" - ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" - ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" - ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" - ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" - "fmul v10.4s, v17.4s, v26.s[0]\n" - "scvtf v9.4s, v9.4s, #0x4\n" - "scvtf v29.4s, v29.4s, #0x4\n" - "scvtf v20.4s, v20.4s, #0x4\n" - "scvtf v2.4s, v2.4s, #0x4\n" - "fmla v11.4s, v9.4s, v10.4s\n" - "ldr q9, [x22, #0x10]\n" - "fmul v10.4s, v17.4s, v26.s[1]\n" - "fmla v13.4s, v29.4s, v10.4s\n" - "ldr d29, [x22, #-0x8]\n" - "fmul v10.4s, v17.4s, v26.s[2]\n" - "fmul v26.4s, v17.4s, v26.s[3]\n" - "fcvtl v29.4s, v29.4h\n" - "fmla v23.4s, v20.4s, v10.4s\n" - "movi v20.4s, #0x0\n" - "movi v10.4s, #0x0\n" - "fmla v16.4s, v2.4s, v26.4s\n" - "movi v26.4s, #0x0\n" - "movi v2.4s, #0x0\n" - ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" - ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" - ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" - ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" - "ldr q24, [x22, #0x20]\n" - ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" - ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" - ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" - ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" - "ldr q9, [x22, #0x30]\n" + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[2]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int i = 0; i < col_groups; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d); + + int32x4_t acc[col_groups]; + for (int i = 0; i < col_groups; i++) { + acc[i] = vdupq_n_s32(0); + } + + // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift + int32x4_t bias_lo = vdupq_n_s32(0); + int32x4_t bias_hi = vdupq_n_s32(0); + + // Load bsums in chunks of 4 to process with vectorized operations + for (int i = 0; i < 16; i += 4) { + int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i); + int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8); + int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4); + int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8); + int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4); + int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8); + int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4); + int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8); + int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4); + + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3); + } + bias_lo = vshlq_n_s32(bias_lo, 5); + bias_hi = vshlq_n_s32(bias_hi, 5); + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16; + const int8_t * q8_base_h = q8_base_l + 64; + + // Load and duplicate q8 values (each register covers four interleaved columns of q6) + int8x16_t q8_l[4]; + int8x16_t q8_h[4]; + for (int i = 0; i < 4; i++) { + q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4)); + q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4)); + } + + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes + + // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) + uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base); + uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64); + uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base); + uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64); + + // Adjust qh for subblocks 2 and 3 (shift right by 2) + if (sb > 1) { + q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2); + q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2); + q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2); + q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2); + q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2); + q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2); + q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2); + q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2); + } + + const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3], + q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] }; + const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3], + q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] }; + + // Process column groups (0-3, 4-7) + for (int g = 0; g < col_groups; g++) { + int32x4_t sb_acc_l = vdupq_n_s32(0); + int32x4_t sb_acc_h = vdupq_n_s32(0); + + for (int chunk = 0; chunk < 4; chunk++) { + const int idx = chunk * 2 + g; + + const uint8x16_t q6_qs_l = q6_ql[idx]; + const uint8x16_t q6_qs_h = q6_qh[idx]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi); + + // q6 = (low4 | high2<<4), without -32 bias (handled via bsums) + const int8x16_t q6_l = + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4)); + const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh)); + + sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]); + } + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4)); + const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4)); + + acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l); + acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h); + } + } + } // for half + + // Bias correction + acc[0] = vsubq_s32(acc[0], bias_lo); + acc[1] = vsubq_s32(acc[1], bias_hi); + + // Apply superblock scale (no mins for q6_K) + // acc[g] has [c0, c1, c2, c3] + float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0); + float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1); + + acc_f32[0] = vaddq_f32(acc_f32[0], w_0123); + acc_f32[1] = vaddq_f32(acc_f32[1], w_4567); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q6_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[2]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + acc_f32[0] = vdupq_n_f32(0); + acc_f32[1] = vdupq_n_f32(0); + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d); + + int32x2_t acc[col_pairs]; + for (int i = 0; i < col_pairs; i++) { + acc[i] = vdup_n_s32(0); + } + + // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift + int32x4_t bias_lo = vdupq_n_s32(0); + int32x4_t bias_hi = vdupq_n_s32(0); + + // Load bsums in chunks of 4 to process with vectorized operations + for (int i = 0; i < 16; i += 4) { + int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i); + int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8); + int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4); + int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8); + int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4); + int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8); + int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4); + int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8); + int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4); + + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3); + } + bias_lo = vshlq_n_s32(bias_lo, 5); + bias_hi = vshlq_n_s32(bias_hi, 5); + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16; + const int8_t * q8_base_h = q8_base_l + 64; + + // Load and duplicate q8 values (each register covers two interleaved columns of q6) + int8x16_t q8_l[2]; + int8x16_t q8_h[2]; + for (int i = 0; i < 2; i++) { + q8_l[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_l + i * 8)); + q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8)); + } + + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes + + // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) + uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base); + uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64); + uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base); + uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64); + + // Adjust qh for subblocks 2 and 3 (shift right by 2) + if (sb > 1) { + q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2); + q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2); + q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2); + q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2); + q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2); + q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2); + q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2); + q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2); + } + + // Process column pairs (0-1, 2-3, 4-5, 6-7) + for (int cp = 0; cp < col_pairs; cp++) { + const uint8x16_t q6_qs_cp_0_l = q6_ql_0.val[cp]; + const uint8x16_t q6_qs_cp_1_l = q6_ql_1.val[cp]; + const uint8x16_t q6_qs_cp_0_h = q6_qh_0.val[cp]; + const uint8x16_t q6_qs_cp_1_h = q6_qh_1.val[cp]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi); + const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi); + + // q6 = (low4 | high2<<4), without -32 bias (handled via bsums) + const int8x16_t q6_l0 = vreinterpretq_s8_u8( + vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)); + const int8x16_t q6_l1 = vreinterpretq_s8_u8( + vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)); + const int8x16_t q6_h0 = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)); + const int8x16_t q6_h1 = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)); + + int32x4_t sb_acc_l = vdupq_n_s32(0); + sb_acc_l = vdotq_s32(sb_acc_l, q6_l0, q8_l[0]); + sb_acc_l = vdotq_s32(sb_acc_l, q6_l1, q8_l[1]); + + int32x4_t sb_acc_h = vdupq_n_s32(0); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h0, q8_h[0]); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h1, q8_h[1]); + + // Pairwise add to get per-column sums: [col0, col1] + int32x2_t sum_l = vpadd_s32(vget_low_s32(sb_acc_l), vget_high_s32(sb_acc_l)); + int32x2_t sum_h = vpadd_s32(vget_low_s32(sb_acc_h), vget_high_s32(sb_acc_h)); + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + // Access scales using array indexing (scales are interleaved by column) + const int32x2_t scale_vec_l = { (int32_t) q6_scales[scale_idx_l * 8 + cp * 2], + (int32_t) q6_scales[scale_idx_l * 8 + cp * 2 + 1] }; + const int32x2_t scale_vec_h = { (int32_t) q6_scales[scale_idx_h * 8 + cp * 2], + (int32_t) q6_scales[scale_idx_h * 8 + cp * 2 + 1] }; + + // Accumulate scaled results + acc[cp] = vmla_s32(acc[cp], sum_l, scale_vec_l); + acc[cp] = vmla_s32(acc[cp], sum_h, scale_vec_h); + } + } + } // for half + + // Bias correction + acc[0] = vsub_s32(acc[0], vget_low_s32(bias_lo)); + acc[1] = vsub_s32(acc[1], vget_high_s32(bias_lo)); + acc[2] = vsub_s32(acc[2], vget_low_s32(bias_hi)); + acc[3] = vsub_s32(acc[3], vget_high_s32(bias_hi)); + + // Apply superblock scale (no mins for q6_K) + // acc[cp] has [c0, c1] + float32x2_t w_01 = vmul_f32(vcvt_f32_s32(acc[0]), vget_low_f32(sb_scale_0)); + float32x2_t w_23 = vmul_f32(vcvt_f32_s32(acc[1]), vget_high_f32(sb_scale_0)); + float32x2_t w_45 = vmul_f32(vcvt_f32_s32(acc[2]), vget_low_f32(sb_scale_1)); + float32x2_t w_67 = vmul_f32(vcvt_f32_s32(acc[3]), vget_high_f32(sb_scale_1)); + + acc_f32[0] = vaddq_f32(acc_f32[0], vcombine_f32(w_01, w_23)); + acc_f32[1] = vaddq_f32(acc_f32[1], vcombine_f32(w_45, w_67)); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q8_0_4x4_q8_0(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx; + + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float32x4_t acc = vdupq_n_f32(0); + for (int b = 0; b < nb; b++) { + int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs); + int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64); + float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + + int8x16x2_t a = vld1q_s8_x2(a_ptr->qs); + float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + + int32x4_t ret = vdupq_n_s32(0); + + ret = vdotq_laneq_s32(ret, b_low.val[0], a.val[0], 0); + ret = vdotq_laneq_s32(ret, b_low.val[1], a.val[0], 1); + ret = vdotq_laneq_s32(ret, b_low.val[2], a.val[0], 2); + ret = vdotq_laneq_s32(ret, b_low.val[3], a.val[0], 3); + + ret = vdotq_laneq_s32(ret, b_high.val[0], a.val[1], 0); + ret = vdotq_laneq_s32(ret, b_high.val[1], a.val[1], 1); + ret = vdotq_laneq_s32(ret, b_high.val[2], a.val[1], 2); + ret = vdotq_laneq_s32(ret, b_high.val[3], a.val[1], 3); + + acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); + a_ptr++; + b_ptr++; + } + vst1q_f32(s, acc); + s += ncols_interleaved; + } + return; + +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q8_0_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q8_0_4x8_q8_0(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx; + + for (int c = 0; c < nc; c += ncols_interleaved) { + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float32x4_t acc = vdupq_n_f32(0); + + for (int b = 0; b < nb; b++) { + int8x16x4_t b_low = vld1q_s8_x4((const int8_t *) b_ptr->qs); + int8x16x4_t b_high = vld1q_s8_x4((const int8_t *) b_ptr->qs + 64); + float16x4_t bd = vld1_f16((const __fp16 *) b_ptr->d); + + int8x8x4_t a_chunks = vld1_s8_x4(a_ptr->qs); + int8x16_t a0 = vcombine_s8(a_chunks.val[0], a_chunks.val[0]); + int8x16_t a1 = vcombine_s8(a_chunks.val[1], a_chunks.val[1]); + int8x16_t a2 = vcombine_s8(a_chunks.val[2], a_chunks.val[2]); + int8x16_t a3 = vcombine_s8(a_chunks.val[3], a_chunks.val[3]); + float16x4_t ad = vld1_dup_f16((const __fp16 *) &a_ptr->d); + + int32x4_t ret0 = vdupq_n_s32(0); + int32x4_t ret1 = vdupq_n_s32(0); + + // 0..7 + ret0 = vdotq_s32(ret0, b_low.val[0], a0); + ret1 = vdotq_s32(ret1, b_low.val[1], a0); + // 8..15 + ret0 = vdotq_s32(ret0, b_low.val[2], a1); + ret1 = vdotq_s32(ret1, b_low.val[3], a1); + // 16..23 + ret0 = vdotq_s32(ret0, b_high.val[0], a2); + ret1 = vdotq_s32(ret1, b_high.val[1], a2); + // 24..31 + ret0 = vdotq_s32(ret0, b_high.val[2], a3); + ret1 = vdotq_s32(ret1, b_high.val[3], a3); + + int32x4_t ret = vpaddq_s32(ret0, ret1); + + acc = vfmaq_f32(acc, vcvtq_f32_s32(ret), vmulq_f32(vcvt_f32_f16(ad), vcvt_f32_f16(bd))); + a_ptr++; + b_ptr++; + } + vst1q_f32(s, acc); + s += ncols_interleaved; + } + return; + +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q8_0_4x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const void * b_ptr = vx; + const void * a_ptr = vy; + float * res_ptr = s; + size_t res_stride = bs * sizeof(float); + + __asm__ __volatile__( + "mov x10, %x[nr]\n" + "mov x9, #0x88\n" + "cmp x10, #0x10\n" + "mul x9, %x[nb], x9\n" + "blt 4f\n" + "1:" // Row loop + "add x28, %x[b_ptr], #0x8\n" + "mov x27, %x[nc]\n" + "add x26, %x[res_ptr], %x[res_stride], LSL #4\n" + "2:" // Column loop + "add x25, %x[a_ptr], #0x8\n" + "movi v15.16b, #0x0\n" + "movi v19.16b, #0x0\n" + "mov x24, %x[nb]\n" + "add x23, x25, x9\n" + "movi v18.16b, #0x0\n" + "movi v14.16b, #0x0\n" + "add x22, x23, x9\n" + "movi v11.16b, #0x0\n" + "movi v13.16b, #0x0\n" + "add x21, x22, x9\n" + "movi v23.16b, #0x0\n" + "movi v16.16b, #0x0\n" + "movi v25.16b, #0x0\n" + "movi v7.16b, #0x0\n" + "movi v0.16b, #0x0\n" + "movi v4.16b, #0x0\n" + "movi v5.16b, #0x0\n" + "movi v21.16b, #0x0\n" + "movi v8.16b, #0x0\n" + "movi v1.16b, #0x0\n" + "3:" // Block loop + "ldr q3, [x28, #0x0]\n" + "ldr q31, [x25, #0x0]\n" + "movi v28.16b, #0x4\n" + "movi v10.4s, #0x0\n" + "ldr q22, [x28, #0x10]\n" + "ldr q6, [x25, #0x10]\n" + "movi v29.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "ldr q27, [x28, #0x20]\n" + "ldr q30, [x28, #0x30]\n" + "movi v20.4s, #0x0\n" + "movi v24.16b, #0xf0\n" + "ldr d2, [x25, #-0x8]\n" + "ldr d26, [x23, #-0x8]\n" + "sshl v12.16b, v3.16b, v28.16b\n" + "sub x20, x28, #0x8\n" + "ldr d17, [x20, #0x0]\n" + "and v3.16b, v3.16b, v24.16b\n" + "subs x24, x24, #0x1\n" + "add x28, x28, #0x48\n" + ".inst 0x4f9fe18a // sdot v10.4s, v12.16b, v31.4b[0]\n" + ".inst 0x4fbfe19d // sdot v29.4s, v12.16b, v31.4b[1]\n" + ".inst 0x4f9fe989 // sdot v9.4s, v12.16b, v31.4b[2]\n" + ".inst 0x4fbfe994 // sdot v20.4s, v12.16b, v31.4b[3]\n" + "sshl v31.16b, v22.16b, v28.16b\n" + "and v22.16b, v22.16b, v24.16b\n" + "fcvtl v17.4s, v17.4h\n" + "fcvtl v2.4s, v2.4h\n" + "fcvtl v26.4s, v26.4h\n" + ".inst 0x4f86e3ea // sdot v10.4s, v31.16b, v6.4b[0]\n" + ".inst 0x4fa6e3fd // sdot v29.4s, v31.16b, v6.4b[1]\n" + ".inst 0x4f86ebe9 // sdot v9.4s, v31.16b, v6.4b[2]\n" + ".inst 0x4fa6ebf4 // sdot v20.4s, v31.16b, v6.4b[3]\n" + "sshl v6.16b, v27.16b, v28.16b\n" + "sshl v28.16b, v30.16b, v28.16b\n" + "and v27.16b, v27.16b, v24.16b\n" + "and v30.16b, v30.16b, v24.16b\n" + "ldr q24, [x25, #0x20]\n" + ".inst 0x4f98e0ca // sdot v10.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8c9 // sdot v9.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8d4 // sdot v20.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x30]\n" + ".inst 0x4f98e38a // sdot v10.4s, v28.16b, v24.4b[0]\n" + ".inst 0x4fb8e39d // sdot v29.4s, v28.16b, v24.4b[1]\n" + ".inst 0x4f98eb89 // sdot v9.4s, v28.16b, v24.4b[2]\n" + ".inst 0x4fb8eb94 // sdot v20.4s, v28.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x40]\n" + ".inst 0x4f98e06a // sdot v10.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e869 // sdot v9.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e874 // sdot v20.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x50]\n" + ".inst 0x4f98e2ca // sdot v10.4s, v22.16b, v24.4b[0]\n" + ".inst 0x4fb8e2dd // sdot v29.4s, v22.16b, v24.4b[1]\n" + ".inst 0x4f98eac9 // sdot v9.4s, v22.16b, v24.4b[2]\n" + ".inst 0x4fb8ead4 // sdot v20.4s, v22.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x60]\n" + ".inst 0x4f98e36a // sdot v10.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb69 // sdot v9.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb74 // sdot v20.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x25, #0x70]\n" + "add x25, x25, #0x88\n" + ".inst 0x4f98e3ca // sdot v10.4s, v30.16b, v24.4b[0]\n" + ".inst 0x4fb8e3dd // sdot v29.4s, v30.16b, v24.4b[1]\n" + ".inst 0x4f98ebc9 // sdot v9.4s, v30.16b, v24.4b[2]\n" + ".inst 0x4fb8ebd4 // sdot v20.4s, v30.16b, v24.4b[3]\n" + "fmul v24.4s, v17.4s, v2.s[0]\n" + "scvtf v10.4s, v10.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "fmla v15.4s, v10.4s, v24.4s\n" + "ldr q24, [x23, #0x0]\n" + "fmul v10.4s, v17.4s, v2.s[1]\n" + "fmla v19.4s, v29.4s, v10.4s\n" + "ldr q10, [x23, #0x10]\n" + "fmul v29.4s, v17.4s, v2.s[2]\n" + "fmul v2.4s, v17.4s, v2.s[3]\n" + "fmla v18.4s, v9.4s, v29.4s\n" + "movi v9.4s, #0x0\n" + "movi v29.4s, #0x0\n" + ".inst 0x4f98e189 // sdot v9.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e19d // sdot v29.4s, v12.16b, v24.4b[1]\n" + "fmla v14.4s, v20.4s, v2.4s\n" + "movi v20.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e994 // sdot v20.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x20]\n" + ".inst 0x4f8ae3e9 // sdot v9.4s, v31.16b, v10.4b[0]\n" + ".inst 0x4faae3fd // sdot v29.4s, v31.16b, v10.4b[1]\n" + ".inst 0x4f8aebf4 // sdot v20.4s, v31.16b, v10.4b[2]\n" + ".inst 0x4faaebe2 // sdot v2.4s, v31.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x30]\n" + ".inst 0x4f98e0c9 // sdot v9.4s, v6.16b, v24.4b[0]\n" + ".inst 0x4fb8e0dd // sdot v29.4s, v6.16b, v24.4b[1]\n" + ".inst 0x4f98e8d4 // sdot v20.4s, v6.16b, v24.4b[2]\n" + ".inst 0x4fb8e8c2 // sdot v2.4s, v6.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x40]\n" + ".inst 0x4f8ae389 // sdot v9.4s, v28.16b, v10.4b[0]\n" + ".inst 0x4faae39d // sdot v29.4s, v28.16b, v10.4b[1]\n" + ".inst 0x4f8aeb94 // sdot v20.4s, v28.16b, v10.4b[2]\n" + ".inst 0x4faaeb82 // sdot v2.4s, v28.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x50]\n" + ".inst 0x4f98e069 // sdot v9.4s, v3.16b, v24.4b[0]\n" + ".inst 0x4fb8e07d // sdot v29.4s, v3.16b, v24.4b[1]\n" + ".inst 0x4f98e874 // sdot v20.4s, v3.16b, v24.4b[2]\n" + ".inst 0x4fb8e862 // sdot v2.4s, v3.16b, v24.4b[3]\n" + "ldr q24, [x23, #0x60]\n" + ".inst 0x4f8ae2c9 // sdot v9.4s, v22.16b, v10.4b[0]\n" + ".inst 0x4faae2dd // sdot v29.4s, v22.16b, v10.4b[1]\n" + ".inst 0x4f8aead4 // sdot v20.4s, v22.16b, v10.4b[2]\n" + ".inst 0x4faaeac2 // sdot v2.4s, v22.16b, v10.4b[3]\n" + "ldr q10, [x23, #0x70]\n" + "add x23, x23, #0x88\n" + ".inst 0x4f98e369 // sdot v9.4s, v27.16b, v24.4b[0]\n" + ".inst 0x4fb8e37d // sdot v29.4s, v27.16b, v24.4b[1]\n" + ".inst 0x4f98eb74 // sdot v20.4s, v27.16b, v24.4b[2]\n" + ".inst 0x4fb8eb62 // sdot v2.4s, v27.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x0]\n" + ".inst 0x4f8ae3c9 // sdot v9.4s, v30.16b, v10.4b[0]\n" + ".inst 0x4faae3dd // sdot v29.4s, v30.16b, v10.4b[1]\n" + ".inst 0x4f8aebd4 // sdot v20.4s, v30.16b, v10.4b[2]\n" + ".inst 0x4faaebc2 // sdot v2.4s, v30.16b, v10.4b[3]\n" + "fmul v10.4s, v17.4s, v26.s[0]\n" + "scvtf v9.4s, v9.4s, #0x4\n" + "scvtf v29.4s, v29.4s, #0x4\n" + "scvtf v20.4s, v20.4s, #0x4\n" + "scvtf v2.4s, v2.4s, #0x4\n" + "fmla v11.4s, v9.4s, v10.4s\n" + "ldr q9, [x22, #0x10]\n" + "fmul v10.4s, v17.4s, v26.s[1]\n" + "fmla v13.4s, v29.4s, v10.4s\n" + "ldr d29, [x22, #-0x8]\n" + "fmul v10.4s, v17.4s, v26.s[2]\n" + "fmul v26.4s, v17.4s, v26.s[3]\n" + "fcvtl v29.4s, v29.4h\n" + "fmla v23.4s, v20.4s, v10.4s\n" + "movi v20.4s, #0x0\n" + "movi v10.4s, #0x0\n" + "fmla v16.4s, v2.4s, v26.4s\n" + "movi v26.4s, #0x0\n" + "movi v2.4s, #0x0\n" + ".inst 0x4f98e194 // sdot v20.4s, v12.16b, v24.4b[0]\n" + ".inst 0x4fb8e18a // sdot v10.4s, v12.16b, v24.4b[1]\n" + ".inst 0x4f98e99a // sdot v26.4s, v12.16b, v24.4b[2]\n" + ".inst 0x4fb8e982 // sdot v2.4s, v12.16b, v24.4b[3]\n" + "ldr q24, [x22, #0x20]\n" + ".inst 0x4f89e3f4 // sdot v20.4s, v31.16b, v9.4b[0]\n" + ".inst 0x4fa9e3ea // sdot v10.4s, v31.16b, v9.4b[1]\n" + ".inst 0x4f89ebfa // sdot v26.4s, v31.16b, v9.4b[2]\n" + ".inst 0x4fa9ebe2 // sdot v2.4s, v31.16b, v9.4b[3]\n" + "ldr q9, [x22, #0x30]\n" ".inst 0x4f98e0d4 // sdot v20.4s, v6.16b, v24.4b[0]\n" ".inst 0x4fb8e0ca // sdot v10.4s, v6.16b, v24.4b[1]\n" ".inst 0x4f98e8da // sdot v26.4s, v6.16b, v24.4b[2]\n" @@ -2247,89 +3157,1372 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ); return; } -#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) +#endif // #if defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) + ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + float32x4_t sumf[4]; + for (int m = 0; m < 4; m++) { + sumf[m] = vdupq_n_f32(0); + } + + for (int l = 0; l < nb; l++) { + float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d)); + float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d)); + + int32x4_t sumi_0 = vdupq_n_s32(0); + int32x4_t sumi_1 = vdupq_n_s32(0); + int32x4_t sumi_2 = vdupq_n_s32(0); + int32x4_t sumi_3 = vdupq_n_s32(0); + + for (int k = 0; k < 4; k++) { + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64); + + uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k); + int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4); + int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF); + + sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3); + sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3); + } + + sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0)); + sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1)); + sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2)); + sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3)); + } + + for (int m = 0; m < 4; m++) { + vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]); + } + } + } + return; +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + + float32x4_t sumf[4]; + for (int m = 0; m < 4; m++) { + sumf[m] = vdupq_n_f32(0); + } + + for (int l = 0; l < nb; l++) { + float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d)); + float32x4_t b_d = { + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]), + }; + + int32x4_t sumi_0 = vdupq_n_s32(0); + int32x4_t sumi_1 = vdupq_n_s32(0); + int32x4_t sumi_2 = vdupq_n_s32(0); + int32x4_t sumi_3 = vdupq_n_s32(0); + + for (int k = 0; k < 4; k++) { + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64); + + uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k); + int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4); + int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF); + + sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3); + sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3); + } + + sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0)); + sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1)); + sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2)); + sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3)); + } + + for (int m = 0; m < 4; m++) { + vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]); + } + } + } + return; +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int q8_k_blocklen = 4; + constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[acc_size]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // d4 0 1 2 3, 4 5 6 7 + float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); + float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); + // d8 0 1 2 3 + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); + // mins + float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); + float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); + + // Precomputation of scales and mins + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; + float32x4_t sbd_min_0123[q8_k_blocklen]; + float32x4_t sbd_min_4567[q8_k_blocklen]; + + sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0); + sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0); + sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0); + + sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1); + sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1); + sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1); + + sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2); + sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2); + sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2); + + sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3); + sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3); + sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3); + + // Precomputation of bsums, each vpaddq calcs all the bsums for each row + const int16x8_t bsums[q8_k_blocklen] = { + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[QK_K / 64][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 .. + int32x4_t bias_acc[acc_size]; + for (int i = 0; i < acc_size; i++) { + bias_acc[i] = vdupq_n_s32(0); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Int accumulators for qs vecdot (4 row x 2 col quartets) + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q4sb_scales[2]; + int16x8_t q4sb_mins[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q4sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); + q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); + } + + constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows + for (int k = 0; k < reads_per_sb; k++) { + const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k); + const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128); + + // 0..3 & 32..35 + const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k); + const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16); + + const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b)); + const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4)); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123 + + const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b)); + const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4)); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567 + } + + // Scale and bias application + // acc is stored interleaved to match output layout + const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]); + const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]); + const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]); + for (int row = 0; row < q8_k_blocklen; row++) { + // Bias correction + // row c0123 blk0 and blk1 + const float32x4_t sumf_0123 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row]))); + acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123); + + // row c4567 blk0 and blk1 + const float32x4_t sumf_4567 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4]))); + acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567); + + // Bias + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]); + + // row c0123 blk0 and blk1 + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + + // row c4567 blk0 and blk1 + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } + } // for sb + + for (int row = 0; row < q8_k_blocklen; row++) { + acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]); + acc_f32[2 * row + 1] = + vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]); + } + } // for b + + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q5_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int q8_k_blocklen = 4; + constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs + constexpr int col_groups = ncols_interleaved / 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 8 accumulators: 2 row pairs, 4 col pairs + float32x4_t acc_f32[acc_size]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // d5 0 1 2 3, 4 5 6 7 + float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); + float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); + // d8 0 1 2 3 + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); + // mins + float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); + float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); + + // Precomputation of scales and mins + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; + float32x4_t sbd_min_0123[q8_k_blocklen]; + float32x4_t sbd_min_4567[q8_k_blocklen]; + + sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0); + sbd_min_0123[0] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0); + sbd_min_4567[0] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0); + + sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1); + sbd_min_0123[1] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1); + sbd_min_4567[1] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1); + + sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2); + sbd_min_0123[2] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2); + sbd_min_4567[2] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2); + + sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3); + sbd_min_0123[3] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3); + sbd_min_4567[3] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3); + + // Precomputation of bsums, each vpaddq calcs all the bsums for each row + const int16x8_t bsums[q8_k_blocklen] = { + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[QK_K / 64][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 .. + int32x4_t bias_acc[acc_size]; + for (int i = 0; i < acc_size; i++) { + bias_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t qh[col_groups][8]; + for (int c = 0; c < col_groups; c++) { + for (int i = 0; i < 8; i++) { + qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c); + } + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Int accumulators for qs vecdot (4 row * 2 col quartets) + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_scales[2]; + int16x8_t q5sb_mins[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } + + constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows + for (int k = 0; k < reads_per_sb; k++) { + const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k); + const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128); + + // 0..3 & 32..35 + const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k); + const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16); + + // NOTE: This is the only difference with q4_K + const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone); + const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3); + qh[0][k] = vshrq_n_u8(qh[0][k], 2); + const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone); + const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3); + qh[1][k] = vshrq_n_u8(qh[1][k], 2); + // From here, same as q4_K + + const int8x16_t q5_0123_lo = + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4)); + const int8x16_t q5_0123_hi = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123)); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0); // 32..35 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1); // 32..35 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123 + + const int8x16_t q5_4567_lo = + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4)); + const int8x16_t q5_4567_hi = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567)); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0); // 32..35 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1); // 32..35 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2); // 32..35 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3); // 32..35 r3 c4567 + } + + // Scale and bias application + // acc is stored interleaved to match output layout + const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]); + const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]); + const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]); + for (int row = 0; row < q8_k_blocklen; row++) { + // Bias correction + // row c0123 blk0 and blk1 + const float32x4_t sumf_0123 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row]))); + acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123); + + // row c4567 blk0 and blk1 + const float32x4_t sumf_4567 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4]))); + acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567); + + // Bias + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]); + + // row c0123 blk0 and blk1 + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + + // row c4567 blk0 and blk1 + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } + } // for sb + + for (int row = 0; row < q8_k_blocklen; row++) { + acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]); + acc_f32[2 * row + 1] = + vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]); + } + } // for b + + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemm_q4_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntb() * 8 == 256) { + constexpr int q8_k_blocklen = 4; + const svuint8_t m4b_1 = svdup_n_u8(0x0f); + // 8 accumulators: 2 row pairs × 4 col pairs + svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67; + uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 }; + svbool_t pg = svptrue_pat_b32(SV_VL8); + svuint32_t idx = svld1(pg, idx_arr); + + static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7}; + svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data); + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + acc_f32_01 = svdup_n_f32(0); + acc_f32_23 = svdup_n_f32(0); + acc_f32_45 = svdup_n_f32(0); + acc_f32_67 = svdup_n_f32(0); + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + // 64 elements loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + + int32_t bsums_arr32[4][8]; + + for (int q8_row = 0; q8_row < 4; q8_row++) { + int16x8_t v16 = bsums[q8_row]; + + // low 4 + int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][0], v32_lo); + + // high 4 + int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][4], v32_hi); + } + + svint32_t sb_acc_0 = svdup_n_s32(0); + svint32_t sb_acc_2 = svdup_n_s32(0); + + svint32_t acc_00 = svdup_n_s32(0); + svint32_t acc_11 = svdup_n_s32(0); + svint32_t acc_22 = svdup_n_s32(0); + svint32_t acc_33 = svdup_n_s32(0); + svint32_t acc_44 = svdup_n_s32(0); + svint32_t acc_55 = svdup_n_s32(0); + svint32_t acc_66 = svdup_n_s32(0); + svint32_t acc_77 = svdup_n_s32(0); + + svint32_t bias_acc_00 = svdup_n_s32(0); + svint32_t bias_acc_22 = svdup_n_s32(0); + svint32_t bias_acc_44 = svdup_n_s32(0); + svint32_t bias_acc_66 = svdup_n_s32(0); + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3; + svint32_t q4sb_mins_0, q4sb_mins_1; + { + // 2-superblock I am working on + const int offset = sb * 24 + 0 * 12; + const uint8_t * scales_in = &q4_ptr[b].scales[offset]; + + const int offset1 = sb * 24 + 12; + const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1]; + + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + constexpr uint8_t scales_size = 12; + + uint32_t sm[3]; + memcpy(sm, scales_in, scales_size); + + uint32_t sm1[3]; + memcpy(sm1, scales_in1, scales_size); + + const uint32_t mins_0_3 = sm[1] & kmask1; + const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4); + + const uint32_t mins_0_3_1 = sm1[1] & kmask1; + const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4); + + svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7)); + svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1)); + + /* reinterpret u32 → u8 */ + svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp); + svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1); + + /* widen u8 → u16->u32 (lower half only) */ + svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8)); + svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1)); + + q4sb_mins_0 = svreinterpret_s32_u32(mins_u16); + q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1); + + uint32_t scales_u32_0 = sm[0] & kmask1; + uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); + uint32_t scales_u32_2 = sm1[0] & kmask1; + uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4); + + svuint32_t S01 = svdup_n_u32(scales_u32_0); + svuint32_t S23 = svdup_n_u32(scales_u32_1); + svuint32_t R01 = svdup_n_u32(scales_u32_2); + svuint32_t R23 = svdup_n_u32(scales_u32_3); + + svint8_t S01_b = svreinterpret_s8_u32(S01); + svint8_t S23_b = svreinterpret_s8_u32(S23); + svint8_t R01_b = svreinterpret_s8_u32(R01); + svint8_t R23_b = svreinterpret_s8_u32(R23); + + svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b))); + svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b))); + svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b))); + svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b))); + + block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx); + block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx); + block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx); + block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx); + } + + const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256; + + // Load 32-byte per row pair, 1 subblock each time + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + + svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112)); + svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144)); + svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176)); + svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208)); + + svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128)); + svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160)); + svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192)); + svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224)); + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + + sb_acc_0 = svdup_n_s32(0); + sb_acc_2 = svdup_n_s32(0); + + svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); + svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); + svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); + svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); + + svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4)); + svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4)); + svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4)); + svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4)); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7); + + if(cp == 0) { + acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0); + acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0); + } + if(cp == 1) { + acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1); + acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1); + } + if(cp == 2) { + acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2); + acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2); + } + if(cp == 3) { + acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3); + acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3); + } + } + + bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0); + bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1); + + bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0); + bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1); + + bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0); + bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1); + + bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0); + bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1); + } // for sb + + + acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4)); + acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4)); + acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4)); + acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4)); + acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4)); + acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4)); + acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4)); + acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4)); + + svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1); + svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1); + + svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1); + svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1); + + // Broadcast q8 scalar + svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]); + + svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0))); + + svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0))); + + svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1); + acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[1]); + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1); + acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[2]); + + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1); + acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[3]); + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1); + acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1); + + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. + // Predicate for exactly 4 lanes + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + + if (i == 0 && j == 0) { + // acc_f32_0 → lower half of acc_f32_01 + svst1_f32(pg4, s + offset, acc_f32_01); + } else if (i == 0 && j == 1) { + // acc_f32_1 → upper half of acc_f32_01 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4)); + } else if (i == 1 && j == 0) { + // acc_f32_2 + svst1_f32(pg4, s + offset, acc_f32_23); + } else if (i == 1 && j == 1) { + // acc_f32_3 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4)); + } else if (i == 2 && j == 0) { + // acc_f32_4 + svst1_f32(pg4, s + offset, acc_f32_45); + } else if (i == 2 && j == 1) { + // acc_f32_5 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4)); + } else if (i == 3 && j == 0) { + // acc_f32_6 + svst1_f32(pg4, s + offset, acc_f32_67); + } else if (i == 3 && j == 1) { + // acc_f32_7 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4)); + } + } + } + } // for x + } // for y + return; + } +#endif // SVE compile-time end + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[4][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7] + int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ... + for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + bias_acc[i] = vdupq_n_s32(0); + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int8_t q4sb_scales[2][8]; + int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later + for (int i = 0; i < 2; i++) { + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); + } + + // q8_ptr[b].qs has interleaved Q8 rows (01, 23) + const int8_t * q8_base = q8_ptr[b].qs + sb * 256; + + int8x16_t q8_qs_01[8]; + int8x16_t q8_qs_23[8]; + + // Load 32-byte per row pair, 1 subblock each time + for (int i = 0; i < 8; i++) { + const int offset = i * 32; // 16 for row 01, 16 for row 23 + q8_qs_01[i] = vld1q_s8(q8_base + offset); + q8_qs_23[i] = vld1q_s8(q8_base + offset + 16); + } + + const int8x16_t q8s[2][8] = { + { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], + q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] }, + { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], + q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] }, + }; + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + for (int i = 0; i < 4; i++) { + sb_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39 + uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47 + uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55 + uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63 + const int8x16_t q4_nibbles[2][4] = { + { + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), + vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), + }, + { + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), + vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), + } + }; + + // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8 + // for each of the internal 32 qs subblock (blk) + for (int rp = 0; rp < 2; rp++) { + for (int blk = 0; blk < 2; blk++) { + const int8x16_t * q8 = &q8s[rp][4 * blk]; + const int8x16_t * q4 = q4_nibbles[blk]; + int32x4_t acc = sb_acc[2 * rp + blk]; + // mul add for each qs in the same subblock + for (int qs_offset = 0; qs_offset < 4; qs_offset++) { + acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]); + } + sb_acc[2 * rp + blk] = acc; + } + } + + // Scales[i] corresponds to column i + const int scale_offset = cp * 2; + const int32_t scale_00 = q4sb_scales[0][scale_offset]; + const int32_t scale_01 = q4sb_scales[0][scale_offset + 1]; + const int32_t scale_10 = q4sb_scales[1][scale_offset]; + const int32_t scale_11 = q4sb_scales[1][scale_offset + 1]; + const int32x4_t block_scale_0 = vcombine_s32(vdup_n_s32(scale_00), vdup_n_s32(scale_01)); + const int32x4_t block_scale_1 = vcombine_s32(vdup_n_s32(scale_10), vdup_n_s32(scale_11)); + + acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale_0); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale_0); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale_1); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale_1); + } + + // Multiply Acc bsum + mins + for (int q8_row = 0; q8_row < 4; q8_row++) { + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]); + + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + } + } // for sb + + // Reorder of i8mm output with bias and output layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); + } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; + + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4))); + const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d); + + float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q4_d, q8_d); + + acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins); + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) - ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); + // With the previous reorder, the tile is already in the correct memory layout. + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; +void ggml_gemm_q5_K_8x8_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); UNUSED(nb); UNUSED(ncols_interleaved); UNUSED(blocklen); -#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - const int8x16_t kvalues = vld1q_s8(kvalues_iq4nl); +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + constexpr int q8_k_blocklen = 4; + constexpr int col_pairs = ncols_interleaved / 2; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 8 accumulators: 2 row pairs × 4 col pairs + float32x4_t acc_f32[blocklen]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); - float32x4_t sumf[4]; - for (int m = 0; m < 4; m++) { - sumf[m] = vdupq_n_f32(0); + for (int i = 0; i < blocklen; i++) { + acc_f32[i] = vdupq_n_f32(0); } - for (int l = 0; l < nb; l++) { - float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d)); - float32x4_t b_d = vcvt_f32_f16(vld1_f16((const float16_t *)b_ptr[l].d)); + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[4][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } - int32x4_t sumi_0 = vdupq_n_s32(0); - int32x4_t sumi_1 = vdupq_n_s32(0); - int32x4_t sumi_2 = vdupq_n_s32(0); - int32x4_t sumi_3 = vdupq_n_s32(0); + int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7] + int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ... + for (int i = 0; i < 8; i++) { + acc[i] = vdupq_n_s32(0); + bias_acc[i] = vdupq_n_s32(0); + } - for (int k = 0; k < 4; k++) { - int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0); - int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64); + // Load qh once per block and shift after each subblock + const uint8_t * qh_base = q5_ptr[b].qh; + uint8x16_t qh[col_pairs][4]; + for (int cp = 0; cp < col_pairs; cp++) { + qh[cp][0] = vld1q_u8(qh_base + 16 * cp); + qh[cp][1] = vld1q_u8(qh_base + 16 * cp + 64); + qh[cp][2] = vld1q_u8(qh_base + 16 * cp + 128); + qh[cp][3] = vld1q_u8(qh_base + 16 * cp + 192); + } - uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k); - int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4); - int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF); + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int8_t q5sb_scales[2][8]; + int16x8_t q5sb_mins[2]; // int16 as its needed for bias_acc later + for (int i = 0; i < 2; i++) { + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], q5sb_scales[i]); + } - sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0); - sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1); - sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2); - sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3); - sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0); - sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1); - sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2); - sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3); + // q8_ptr[b].qs has interleaved Q8 rows (01, 23) + const int8_t * q8_base = q8_ptr[b].qs + sb * 256; + + int8x16_t q8_qs_01[8]; + int8x16_t q8_qs_23[8]; + + // Load 32-byte per row pair, 1 subblock each time + for (int i = 0; i < 8; i++) { + const int offset = i * 32; // 16 for row 01, 16 for row 23 + q8_qs_01[i] = vld1q_s8(q8_base + offset); + q8_qs_23[i] = vld1q_s8(q8_base + offset + 16); + } + + const int8x16_t q8s[2][8] = { + { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], + q8_qs_01[7] }, + { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], + q8_qs_23[7] }, + }; + + // Q5s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < col_pairs; cp++) { + for (int i = 0; i < 4; i++) { + sb_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t qs_cp_0 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39 + uint8x16_t qs_cp_1 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47 + uint8x16_t qs_cp_2 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55 + uint8x16_t qs_cp_3 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63 + + // This is the only part of the algorithm that differs with Q4_K + // Extract High bits and pack into 5 bit weights + uint8x16_t hbit_lo_0 = vandq_u8(qh[cp][0], mone); + uint8x16_t hbit_hi_0 = vshlq_n_u8(vandq_u8(qh[cp][0], mtwo), 3); + qh[cp][0] = vshrq_n_u8(qh[cp][0], 2); + // Same as Q4_K, i8mm to dequantize the weights. + const int8x16_t qs_lo_0 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_0, m4b), hbit_lo_0, 4)); + int32x4_t acc_0 = sb_acc[0]; + acc_0 = vmmlaq_s32(acc_0, qs_lo_0, q8s[0][0]); + int32x4_t acc_2 = sb_acc[2]; + acc_2 = vmmlaq_s32(acc_2, qs_lo_0, q8s[1][0]); + const int8x16_t qs_hi_0 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_0, 4), hbit_hi_0)); + int32x4_t acc_1 = sb_acc[1]; + acc_1 = vmmlaq_s32(acc_1, qs_hi_0, q8s[0][4]); + int32x4_t acc_3 = sb_acc[3]; + acc_3 = vmmlaq_s32(acc_3, qs_hi_0, q8s[1][4]); + + // Repeat for the other 3 columns (8..15, 16..23, 24..31) + uint8x16_t hbit_hi_1 = vshlq_n_u8(vandq_u8(qh[cp][1], mtwo), 3); + uint8x16_t hbit_lo_1 = vandq_u8(qh[cp][1], mone); + qh[cp][1] = vshrq_n_u8(qh[cp][1], 2); + const int8x16_t qs_lo_1 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_1, m4b), hbit_lo_1, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_1, q8s[0][1]); + acc_2 = vmmlaq_s32(acc_2, qs_lo_1, q8s[1][1]); + const int8x16_t qs_hi_1 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_1, 4), hbit_hi_1)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_1, q8s[0][5]); + acc_3 = vmmlaq_s32(acc_3, qs_hi_1, q8s[1][5]); + + uint8x16_t hbit_hi_2 = vshlq_n_u8(vandq_u8(qh[cp][2], mtwo), 3); + uint8x16_t hbit_lo_2 = vandq_u8(qh[cp][2], mone); + qh[cp][2] = vshrq_n_u8(qh[cp][2], 2); + const int8x16_t qs_lo_2 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_2, m4b), hbit_lo_2, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_2, q8s[0][2]); + acc_2 = vmmlaq_s32(acc_2, qs_lo_2, q8s[1][2]); + const int8x16_t qs_hi_2 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_2, 4), hbit_hi_2)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_2, q8s[0][6]); + acc_3 = vmmlaq_s32(acc_3, qs_hi_2, q8s[1][6]); + + uint8x16_t hbit_lo_3 = vandq_u8(qh[cp][3], mone); + uint8x16_t hbit_hi_3 = vshlq_n_u8(vandq_u8(qh[cp][3], mtwo), 3); + qh[cp][3] = vshrq_n_u8(qh[cp][3], 2); + const int8x16_t qs_lo_3 = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(qs_cp_3, m4b), hbit_lo_3, 4)); + acc_0 = vmmlaq_s32(acc_0, qs_lo_3, q8s[0][3]); + sb_acc[0] = acc_0; + acc_2 = vmmlaq_s32(acc_2, qs_lo_3, q8s[1][3]); + sb_acc[2] = acc_2; + + // Scales[i] corresponds to column i + const int scale_offset = cp * 2; + const int32_t s0 = q5sb_scales[0][scale_offset]; + const int32_t s1 = q5sb_scales[0][scale_offset + 1]; + const int32x4_t block_scale = vcombine_s32(vdup_n_s32(s0), vdup_n_s32(s1)); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[0], block_scale); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[2], block_scale); + + const int8x16_t qs_hi_3 = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(qs_cp_3, 4), hbit_hi_3)); + acc_1 = vmmlaq_s32(acc_1, qs_hi_3, q8s[0][7]); + sb_acc[1] = acc_1; + acc_3 = vmmlaq_s32(acc_3, qs_hi_3, q8s[1][7]); + sb_acc[3] = acc_3; + + const int32_t s2 = q5sb_scales[1][scale_offset]; + const int32_t s3 = q5sb_scales[1][scale_offset + 1]; + const int32x4_t block_scale2 = vcombine_s32(vdup_n_s32(s2), vdup_n_s32(s3)); + acc[cp] = vmlaq_s32(acc[cp], sb_acc[1], block_scale2); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[3], block_scale2); + } + + // Multiply Acc bsum + mins + for (int q8_row = 0; q8_row < 4; q8_row++) { + // Each pair of subblocks share the same bsums + // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). + int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]); + int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]); + + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[2 * q8_row] = + vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[2 * q8_row + 1] = + vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } + } // for sb + + // Reorder of i8mm output with bias and output layout + for (int i = 0; i < 8; i++) { + int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); + acc[i] = vcombine_s32(aux.val[0], aux.val[1]); } + int32x4_t reorder_acc[8] = { + vcombine_s32(vget_low_s32(acc[0]), vget_low_s32(acc[1])), + vcombine_s32(vget_low_s32(acc[2]), vget_low_s32(acc[3])), + vcombine_s32(vget_high_s32(acc[0]), vget_high_s32(acc[1])), + vcombine_s32(vget_high_s32(acc[2]), vget_high_s32(acc[3])), + vcombine_s32(vget_low_s32(acc[4]), vget_low_s32(acc[5])), + vcombine_s32(vget_low_s32(acc[6]), vget_low_s32(acc[7])), + vcombine_s32(vget_high_s32(acc[4]), vget_high_s32(acc[5])), + vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), + }; - sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0)); - sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1)); - sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2)); - sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3)); - } + for (int i = 0; i < q8_k_blocklen; i++) { + for (int j = 0; j < 2; j++) { + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q5_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].dmin + j * 4))); + const float32x4_t dmins = vmulq_f32(q5_dmin, q8_d); - for (int m = 0; m < 4; m++) { - vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]); + float32x4_t q5_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q5_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q5_d, q8_d); + + acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins); + acc_f32[2 * i + j] = + vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); + } + } + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } } - } - } + } // for x + } // for y return; -#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) - ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) + ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_q6_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { constexpr int qk = QK_K; const int nb = n / qk; @@ -2346,171 +4539,167 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) constexpr int q8_k_blocklen = 4; - constexpr int acc_size = 2 * 4; // 2 row pairs × 4 col pairs - const uint8x16_t m4b = vdupq_n_u8(0x0f); - - // 8 accumulators: 2 row pairs × 4 col pairs - float32x4_t acc_f32[acc_size]; - - for (int y = 0; y < nr / q8_k_blocklen; y++) { - const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); - - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); - - for (int i = 0; i < acc_size; i++) { - acc_f32[i] = vdupq_n_f32(0); - } - - for (int b = 0; b < nb; b++) { - // d4 0 1 2 3, 4 5 6 7 - float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); - float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); - // d8 0 1 2 3 - float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); - // mins - float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); - float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); - - // Precomputation of scales and mins - float32x4_t sbd_scale_0123[q8_k_blocklen]; - float32x4_t sbd_scale_4567[q8_k_blocklen]; - float32x4_t sbd_min_0123[q8_k_blocklen]; - float32x4_t sbd_min_4567[q8_k_blocklen]; - - sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0); - sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0); - sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0); - sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0); - - sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1); - sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1); - sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1); - sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1); - - sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2); - sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2); - sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2); - sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2); + constexpr int col_groups = ncols_interleaved / 4; + constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + const int8x16_t m32s = vdupq_n_s8(32); - sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3); - sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3); - sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3); - sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3); + float32x4_t acc_f32[acc_size]; - // Precomputation of bsums, each vpaddq calcs all the bsums for each row - const int16x8_t bsums[q8_k_blocklen] = { - vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), - vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), - vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), - vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), - }; - int16_t bsums_arr[QK_K / 64][8]; - for (int q8_row = 0; q8_row < 4; q8_row++) { - vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); - } + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); - // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 .. - int32x4_t bias_acc[acc_size]; - for (int i = 0; i < acc_size; i++) { - bias_acc[i] = vdupq_n_s32(0); - } + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); - for (int sb = 0; sb < QK_K / 64; sb++) { - // Int accumulators for qs vecdot (4 row x 2 col quartets) - int32x4_t acc_lo[acc_size]; - int32x4_t acc_hi[acc_size]; - for (int i = 0; i < acc_size; i++) { - acc_lo[i] = vdupq_n_s32(0); - acc_hi[i] = vdupq_n_s32(0); - } - // Need scales for the low and high nibbles - // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total - int16x8_t q4sb_scales[2]; - int16x8_t q4sb_mins[2]; - for (int i = 0; i < 2; i++) { - int8_t aux_q4sb[8]; - const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb); - q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb)); - } + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } - constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows - for (int k = 0; k < reads_per_sb; k++) { - const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k); - const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128); + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); + float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); - // 0..3 & 32..35 - const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k); - const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16); + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; - const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b)); - const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4)); + sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0); + sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1); + sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2); + sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3); - acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123 - acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123 - acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123 - acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123 + int32x4_t acc_s32[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_s32[i] = vdupq_n_s32(0); + } - acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123 - acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123 - acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123 - acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123 + int16_t q6_scales[8 * 16]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } - const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b)); - const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4)); + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; - acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567 - acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567 - acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567 - acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567 + for (int sb = 0; sb < QK_K / 64; sb++) { + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } - acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567 - acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567 - acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567 - acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567 - } + const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64; + const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64; + + // 4 rows * 16 elements per scale + // 4 reads of 16 bytes each + constexpr int reads_per_sb = 4; + int8x16_t q8_l[reads_per_sb]; + int8x16_t q8_h[reads_per_sb]; + for (int k = 0; k < reads_per_sb; k++) { + q8_l[k] = vld1q_s8(q8_base_l + 16 * k); + q8_h[k] = vld1q_s8(q8_base_h + 16 * k); + } - // Scale and bias application - // acc is stored interleaved to match output layout - const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]); - const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]); - const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]); - const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]); - for (int row = 0; row < q8_k_blocklen; row++) { - // Bias correction - // row c0123 blk0 and blk1 - const float32x4_t sumf_0123 = - vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]), - vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row]))); - acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123); + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; - // row c4567 blk0 and blk1 - const float32x4_t sumf_4567 = - vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]), - vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4]))); - acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567); + uint8x16_t q6_ql_0123[reads_per_sb]; + uint8x16_t q6_ql_4567[reads_per_sb]; + uint8x16_t q6_qh_0123[reads_per_sb]; + uint8x16_t q6_qh_4567[reads_per_sb]; - // Bias - const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]); - const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]); + for (int k = 0; k < reads_per_sb; k++) { + q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32); + q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16); + q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32); + q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16); + } - // row c0123 blk0 and blk1 - bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); - bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); + if (sb > 1) { + for (int k = 0; k < reads_per_sb; k++) { + q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2); + q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2); + } + } - // row c4567 blk0 and blk1 - bias_acc[2 * row + 1] = - vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); - bias_acc[2 * row + 1] = - vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + for (int k = 0; k < reads_per_sb; k++) { + // q = (ql | qh) - 32 + const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo); + const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi); + const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo); + const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi); + + const int8x16_t q6_0123_lo = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s); + const int8x16_t q6_0123_hi = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123 + + const int8x16_t q6_4567_lo = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s); + const int8x16_t q6_4567_hi = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567 + } + + // Scale and bias + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + for (int g = 0; g < col_groups; g++) { + const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4); + const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4); + const int32x4_t scale_vec_l = vmovl_s16(scales_l16); + const int32x4_t scale_vec_h = vmovl_s16(scales_h16); + const int acc_offset = g * q8_k_blocklen; + + for (int row = 0; row < q8_k_blocklen; row++) { + const int idx = row * 2 + g; + acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l); + acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h); + } + } } - } // for sb + } + // Finally we apply the superblock scales for (int row = 0; row < q8_k_blocklen; row++) { - acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]); - acc_f32[2 * row + 1] = - vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]); + const int idx0 = 2 * row; + const int idx1 = 2 * row + 1; + const int32x4_t acc_0123 = acc_s32[idx0]; + const int32x4_t acc_4567 = acc_s32[idx1]; + + acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]); + acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]); } } // for b @@ -2526,10 +4715,10 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo } // for y return; #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) - ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); + ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_K_8x8_q8_K(int n, +void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, @@ -2553,144 +4742,155 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) constexpr int q8_k_blocklen = 4; const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + const int8x16_t m32s = vdupq_n_s8(32); - // 8 accumulators: 2 row pairs × 4 col pairs + // 8 accumulators: 4 q8 rows × 2 col groups (0-3, 4-7) float32x4_t acc_f32[blocklen]; for (int y = 0; y < nr / q8_k_blocklen; y++) { const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); for (int i = 0; i < blocklen; i++) { acc_f32[i] = vdupq_n_f32(0); } for (int b = 0; b < nb; b++) { - // bsums pairs belongs to the same q8_k subblock - const int16x8_t bsums[4]{ - vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), - vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), - vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), - vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), - }; - int16_t bsums_arr[4][8]; - for (int q8_row = 0; q8_row < 4; q8_row++) { - vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); - } - - int32x4_t sb_acc[4]; // Aux accumulators to store subblock (partial) results - int32x4_t acc[8]; // rows 01 stored in [0][1][2][3] rows 23 stored in [4][5][6][7] - int32x4_t bias_acc[8]; // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567, [2]->r1 0123 ... + int32x4_t acc[8]; // rows 01 stored in [0][1][2][3], rows 23 stored in [4][5][6][7] for (int i = 0; i < 8; i++) { - acc[i] = vdupq_n_s32(0); - bias_acc[i] = vdupq_n_s32(0); + acc[i] = vdupq_n_s32(0); } - for (int sb = 0; sb < QK_K / 64; sb++) { - // Need scales for the low and high nibbles - // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total - int8_t q4sb_scales[2][8]; - int16x8_t q4sb_mins[2]; // int16 as its needed for bias_acc later - for (int i = 0; i < 2; i++) { - const int offset = sb * 24 + i * 12; - decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], q4sb_scales[i]); - } - - // q8_ptr[b].qs has interleaved Q8 rows (01, 23) - const int8_t * q8_base = q8_ptr[b].qs + sb * 256; + // Q6_K has simple 8-bit scales, 16 per block (one per 16 values) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; ++i) { + int16x8_t s16 = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, s16); + } - int8x16_t q8_qs_01[8]; - int8x16_t q8_qs_23[8]; + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + // Q6_K weight index increasing by 64 instead of 32 requires + // loading various q8 memory regions + const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64; + const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64; + + int8x16_t q8_l_01[2]; + int8x16_t q8_l_23[2]; + for (int i = 0; i < 2; i++) { + const int offset = i * 32; + q8_l_01[i] = vld1q_s8(q8_base_l + offset); // 0..7 & 8..15 (r01) + q8_l_23[i] = vld1q_s8(q8_base_l + offset + 16); // 0..7 & 8..15 (r23) + } - // Load 32-byte per row pair, 1 subblock each time - for (int i = 0; i < 8; i++) { - const int offset = i * 32; // 16 for row 01, 16 for row 23 - q8_qs_01[i] = vld1q_s8(q8_base + offset); - q8_qs_23[i] = vld1q_s8(q8_base + offset + 16); - } + int8x16_t q8_h_01[2]; + int8x16_t q8_h_23[2]; + for (int i = 0; i < 2; i++) { + const int offset = i * 32; + q8_h_01[i] = vld1q_s8(q8_base_h + offset); + q8_h_23[i] = vld1q_s8(q8_base_h + offset + 16); + } - const int8x16_t q8s[2][8] = { - { q8_qs_01[0], q8_qs_01[1], q8_qs_01[2], q8_qs_01[3], - q8_qs_01[4], q8_qs_01[5], q8_qs_01[6], q8_qs_01[7] }, - { q8_qs_23[0], q8_qs_23[1], q8_qs_23[2], q8_qs_23[3], - q8_qs_23[4], q8_qs_23[5], q8_qs_23[6], q8_qs_23[7] }, - }; + const int ql_off_base = sb * QK_K / 2; - // Q4s columns iterated in pairs (01, 23, 45, 67) - for (int cp = 0; cp < ncols_interleaved / 2; cp++) { - for (int i = 0; i < 4; i++) { - sb_acc[i] = vdupq_n_s32(0); + uint8x16_t q6_ql_0[4]; + uint8x16_t q6_ql_1[4]; + for (int k = 0; k < 4; k++) { + q6_ql_0[k] = vld1q_u8(ql_base + ql_off_base + 16 * k); + q6_ql_1[k] = vld1q_u8(ql_base + ql_off_base + 64 + 16 * k); } - uint8x16_t q4_qs_cp_0 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); // 0 .. 7 & 32..39 - uint8x16_t q4_qs_cp_1 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); // 8 ..15 & 40..47 - uint8x16_t q4_qs_cp_2 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); // 16..23 & 48..55 - uint8x16_t q4_qs_cp_3 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); // 24..31 & 56..63 - const int8x16_t q4_nibbles[2][4] = { - { - vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_0, m4b)), - vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_1, m4b)), - vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_2, m4b)), - vreinterpretq_s8_u8(vandq_u8(q4_qs_cp_3, m4b)), - }, - { - vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_0, 4)), - vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_1, 4)), - vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_2, 4)), - vreinterpretq_s8_u8(vshrq_n_u8(q4_qs_cp_3, 4)), - } - }; + const int qh_off_base = (sb * QK_K / 2) & 255; // wrap after 256 bytes + uint8x16_t q6_qh_0[4]; + uint8x16_t q6_qh_1[4]; + for (int k = 0; k < 4; k++) { + q6_qh_0[k] = vld1q_u8(qh_base + qh_off_base + 16 * k); + q6_qh_1[k] = vld1q_u8(qh_base + qh_off_base + 64 + 16 * k); + } - // Calculates the Qs muladd of every row pair (rp) rows 01 and 23 of q8 - // for each of the internal 32 qs subblock (blk) - for (int rp = 0; rp < 2; rp++) { - for (int blk = 0; blk < 2; blk++) { - const int8x16_t * q8 = &q8s[rp][4 * blk]; - const int8x16_t * q4 = q4_nibbles[blk]; - int32x4_t acc = sb_acc[2 * rp + blk]; - // mul add for each qs in the same subblock - for (int qs_offset = 0; qs_offset < 4; qs_offset++) { - acc = vmmlaq_s32(acc, q4[qs_offset], q8[qs_offset]); - } - sb_acc[2 * rp + blk] = acc; + // Adjust for the proper high bits (Sb 2 and 3) + if (sb > 1) { + for (int k = 0; k < 4; k++) { + q6_qh_0[k] = vshrq_n_u8(q6_qh_0[k], 2); + q6_qh_1[k] = vshrq_n_u8(q6_qh_1[k], 2); } } - // Scales[i] corresponds to column i - const int scale_offset = cp * 2; - for (int blk = 0; blk < 2; blk++) { - const int32x4_t block_scale = { - (int32_t) q4sb_scales[blk][scale_offset], - (int32_t) q4sb_scales[blk][scale_offset], - (int32_t) q4sb_scales[blk][scale_offset + 1], - (int32_t) q4sb_scales[blk][scale_offset + 1], + // Process column pairs (0-1, 2-3, 4-5, 6-7) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + const uint8x16_t q6_qs_cp_0_l = q6_ql_0[cp]; + const uint8x16_t q6_qs_cp_1_l = q6_ql_1[cp]; + const uint8x16_t q6_qs_cp_0_h = q6_qh_0[cp]; + const uint8x16_t q6_qs_cp_1_h = q6_qh_1[cp]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_cp_0_hh = vandq_u8(q6_qs_cp_0_h, mask_hi); + const uint8x16_t q6_qs_cp_1_hh = vandq_u8(q6_qs_cp_1_h, mask_hi); + + // q6 = (low4 | high2<<4) - 32 + // Use vsliq_n_u8 to combine shift-left-insert in one instruction (like Q5_K) + const int8x16_t q6_l0 = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_0_l, m4b), vandq_u8(q6_qs_cp_0_h, mask_lo), 4)), + m32s); + const int8x16_t q6_l1 = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_cp_1_l, m4b), vandq_u8(q6_qs_cp_1_h, mask_lo), 4)), + m32s); + const int8x16_t q6_h0 = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_0_l, 4), q6_qs_cp_0_hh)), m32s); + const int8x16_t q6_h1 = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_cp_1_l, 4), q6_qs_cp_1_hh)), m32s); + + // row pair 0, base_l + int32x4_t sb_acc_0l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_01[0]); + sb_acc_0l = vmmlaq_s32(sb_acc_0l, q6_l1, q8_l_01[1]); + // row pair 0, base_h + int32x4_t sb_acc_0h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_01[0]); + sb_acc_0h = vmmlaq_s32(sb_acc_0h, q6_h1, q8_h_01[1]); + // row pair 1, base_l + int32x4_t sb_acc_1l = vmmlaq_s32(vdupq_n_s32(0), q6_l0, q8_l_23[0]); + sb_acc_1l = vmmlaq_s32(sb_acc_1l, q6_l1, q8_l_23[1]); + // row pair 1, base_h + int32x4_t sb_acc_1h = vmmlaq_s32(vdupq_n_s32(0), q6_h0, q8_h_23[0]); + sb_acc_1h = vmmlaq_s32(sb_acc_1h, q6_h1, q8_h_23[1]); + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + const int32x4_t scale_vec_l = { + q6_scales[scale_idx_l * 8 + cp * 2 + 0], + q6_scales[scale_idx_l * 8 + cp * 2 + 0], + q6_scales[scale_idx_l * 8 + cp * 2 + 1], + q6_scales[scale_idx_l * 8 + cp * 2 + 1], + }; + const int32x4_t scale_vec_h = { + q6_scales[scale_idx_h * 8 + cp * 2 + 0], + q6_scales[scale_idx_h * 8 + cp * 2 + 0], + q6_scales[scale_idx_h * 8 + cp * 2 + 1], + q6_scales[scale_idx_h * 8 + cp * 2 + 1], }; - acc[cp] = vmlaq_s32(acc[cp], sb_acc[blk], block_scale); - acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc[blk + 2], block_scale); - } - } - - // Multiply Acc bsum + mins - for (int q8_row = 0; q8_row < 4; q8_row++) { - // Each pair of subblocks share the same bsums - // Load scalar bsum → broadcast to a vector (vdupq_n_s16(s)). - int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][q8_row * 2]); - int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][q8_row * 2 + 1]); - bias_acc[2 * q8_row] = - vmlal_s16(bias_acc[2 * q8_row], bsums_vec_lo, vget_low_s16(q4sb_mins[0])); - bias_acc[2 * q8_row] = - vmlal_s16(bias_acc[2 * q8_row], bsums_vec_hi, vget_low_s16(q4sb_mins[1])); - bias_acc[2 * q8_row + 1] = - vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0])); - bias_acc[2 * q8_row + 1] = - vmlal_s16(bias_acc[2 * q8_row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1])); + acc[cp] = vmlaq_s32(acc[cp], sb_acc_0l, scale_vec_l); + acc[cp] = vmlaq_s32(acc[cp], sb_acc_0h, scale_vec_h); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1l, scale_vec_l); + acc[cp + 4] = vmlaq_s32(acc[cp + 4], sb_acc_1h, scale_vec_h); + } } - } // for sb + } // for half - // Reorder of i8mm output with bias and output layout + // Reorder i8mm output to match memory layout for (int i = 0; i < 8; i++) { int32x2x2_t aux = vzip_s32(vget_low_s32(acc[i]), vget_high_s32(acc[i])); acc[i] = vcombine_s32(aux.val[0], aux.val[1]); @@ -2706,23 +4906,20 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, vcombine_s32(vget_high_s32(acc[6]), vget_high_s32(acc[7])), }; + // Apply superblock scale (no mins for q6_K) for (int i = 0; i < q8_k_blocklen; i++) { for (int j = 0; j < 2; j++) { - float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); - float32x4_t q4_dmin = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].dmin + j * 4))); - const float32x4_t dmins = vmulq_f32(q4_dmin, q8_d); - - float32x4_t q4_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q4_ptr[b].d + j * 4))); - const float32x4_t scale = vmulq_f32(q4_d, q8_d); + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d[i]); + float32x4_t q6_d = vcvt_f32_f16(vld1_f16((const __fp16 *) (q6_ptr[b].d + j * 4))); + const float32x4_t scale = vmulq_f32(q6_d, q8_d); - acc_f32[2 * i + j] = vmlsq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(bias_acc[2 * i + j]), dmins); acc_f32[2 * i + j] = vmlaq_f32(acc_f32[2 * i + j], vcvtq_f32_s32(reorder_acc[2 * i + j]), scale); } } } // for b - // With the previous reorder, the tile is already in the correct memory layout. + // Store results for (int i = 0; i < q8_k_blocklen; i++) { int row = y * q8_k_blocklen + i; for (int j = 0; j < 2; j++) { @@ -2735,10 +4932,9 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, } // for y return; #endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) - ggml_gemm_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); + ggml_gemm_q6_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } - void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, @@ -2827,6 +5023,71 @@ void ggml_gemm_q8_0_4x8_q8_0(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntb() * 8 == 256) { + const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx; + + static const uint32_t idx_arr[8] = {0, 1, 4, 5, 2, 3, 6, 7}; + svuint32_t idx = svld1(svptrue_b32(), idx_arr); + static const uint32_t idx_arr1[8] = {0, 1, 2, 3, 1, 2, 3, 0}; + svuint32_t idx_sc1 = svld1(svptrue_b32(), idx_arr1); + static const uint32_t idx_arr2[8] = {0, 1, 2, 3, 0, 1, 2, 3}; + svuint32_t idx_sc2 = svld1(svptrue_b32(), idx_arr2); + + for (int y = 0; y < nr; y += 4) { + const block_q8_0x4 * a_ptr_base = (const block_q8_0x4 *) vy + (y / 4) * nb; + + for (int x = 0; x < nc; x += ncols_interleaved) { + const block_q8_0x4 * b_ptr = b_ptr_base + (x / 4) * nb; + const block_q8_0x4 * a_ptr = a_ptr_base; + + svfloat32_t acc_f32_01 = svdup_f32(0); + svfloat32_t acc_f32_23 = svdup_f32(0); + + for (int b = 0; b < nb; b++) { + + svint32_t acc_01 = svdup_s32(0); + svint32_t acc_23 = svdup_s32(0); + + // Process 4 chunks of 8 positions each + for (int chunk = 0; chunk < 4; chunk++) { + svint8_t s_a01 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32); + svint8_t s_a23 = svld1rq_s8(svptrue_b8(), a_ptr->qs + chunk * 32 + 16); + svint8_t s_b0123 = svld1_s8(svptrue_b8(), b_ptr->qs + chunk * 32); + + acc_01 = svmmla_s32(acc_01, s_a01, s_b0123); + acc_23 = svmmla_s32(acc_23, s_a23, s_b0123); + } + + // Reorder outputs from 2×2 tiles to row-major + // acc[01] = [r0c0, r0c1, r1c0, r1c1, r0c2, r0c3, r1c2, r1c3] + // acc[23] = [r2c0, r2c1, r3c0, r3c1, r2c2, r2c3, r3c2, r3c3] + + svint32_t row01 = svtbl_s32(acc_01, idx); + svint32_t row23 = svtbl_s32(acc_23, idx); + + svfloat16_t temp1 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) a_ptr->d); + svfloat16_t temp2 = svld1_f16(svptrue_pat_b16(SV_VL4), (const __fp16 *) b_ptr->d); + svfloat32_t sv_a_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp1, temp1)), idx_sc1); + svfloat32_t sv_b_d = svtbl_f32(svcvt_f32_f16_x(svptrue_b32(), svzip1_f16(temp2, temp2)), idx_sc2); + + acc_f32_01 = svmla_f32_x(svptrue_b32(), acc_f32_01, svcvt_f32_s32_x(svptrue_b32(), row01), svmul_lane_f32(sv_b_d, sv_a_d, 0)); + acc_f32_23 = svmla_f32_x(svptrue_b32(), acc_f32_23, svcvt_f32_s32_x(svptrue_b32(), row23), svmul_lane_f32(sv_b_d, sv_a_d, 2)); + a_ptr++; + b_ptr++; + } + + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + svst1_f32(pg4, s + (y+0) * bs + x, acc_f32_01); + svst1_f32(pg4, s + (y+1) * bs + x, svext_f32(acc_f32_01, acc_f32_01, 4)); + svst1_f32(pg4, s + (y+2) * bs + x, acc_f32_23); + svst1_f32(pg4, s + (y+3) * bs + x, svext_f32(acc_f32_23, acc_f32_23, 4)); + } + } + return; + } +#endif // SVE compile-time end + #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) const block_q8_0x4 * b_ptr_base = (const block_q8_0x4 *) vx; diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index f531e916b9e..9c43da6cf89 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -977,6 +977,35 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi sumf = hsum_float_8(acc); *s = sumf; + +#elif defined(__loongarch_sx) + + __m128 acc = (__m128)__lsx_vldi(0); + + for (; ib < nb; ++ib) { + const float d = GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d); + const __m128i qx_0 = __lsx_vld((const __m128i *)x[ib].qs, 0); + const __m128i qx_1 = __lsx_vld((const __m128i *)x[ib].qs + 1, 0); + const __m128i qy_0 = __lsx_vld((const __m128i *)y[ib].qs, 0); + const __m128i qy_1 = __lsx_vld((const __m128i *)y[ib].qs + 1, 0); + + const __m128i p16_0 = lsx_maddubs_h(qx_0, qy_0); + const __m128i p16_1 = lsx_maddubs_h(qx_1, qy_1); + + // Sum int16 pairs → int32 + const __m128i s_0 = __lsx_vaddwev_w_h(p16_0, p16_1); + const __m128i s_1 = __lsx_vaddwod_w_h(p16_0, p16_1); + + const __m128 q = __lsx_vffint_s_w(__lsx_vadd_w(s_0, s_1)); + acc = __lsx_vfmadd_s(__lsx_vreplfr2vr_s(d), q, acc); + } + + __m128 res = lsx_hadd_s(acc, acc); + res = lsx_hadd_s(res, res); + sumf = ((v4f32)res)[0]; + + *s = sumf; + #else UNUSED(nb); UNUSED(ib); @@ -1443,6 +1472,99 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = hsum_float_8(acc); +#elif defined(__loongarch_sx) + + const __m128i m32s = __lsx_vreplgr2vr_b(32); + + __m128 acc_0 = (__m128)__lsx_vldi(0); + __m128 acc_1 = (__m128)__lsx_vldi(0); + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + + const uint8_t * GGML_RESTRICT q4 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const __m128i scale_i8 = __lsx_vld(x[i].scales, 0); + const __m128i scales_lo = __lsx_vsllwil_h_b(scale_i8, 0); + const __m128i scales_hi = __lsx_vsllwil_h_b(__lsx_vbsrl_v(scale_i8, 8), 0); + + __m128i sumi_0 = __lsx_vldi(0); + __m128i sumi_1 = __lsx_vldi(0); + + for (int j = 0; j < QK_K/128; ++j) { + + const __m128i q4bitsH_0 = __lsx_vld((const __m128i*)qh, 0); qh += 16; + const __m128i q4bitsH_1 = __lsx_vld((const __m128i*)qh, 0); qh += 16; + + const __m128i q4h_0 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_0, 3), 4); + const __m128i q4h_1 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_1, 3), 4); + const __m128i q4h_2 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_0, 3 << 2), 2); + const __m128i q4h_3 = __lsx_vslli_b(__lsx_vandi_b(q4bitsH_1, 3 << 2), 2); + const __m128i q4h_4 = __lsx_vandi_b(q4bitsH_0, 3 << 4); + const __m128i q4h_5 = __lsx_vandi_b(q4bitsH_1, 3 << 4); + const __m128i q4h_6 = __lsx_vsrli_b(__lsx_vandi_b(q4bitsH_0, 3 << 6), 2); + const __m128i q4h_7 = __lsx_vsrli_b(__lsx_vandi_b(q4bitsH_1, 3 << 6), 2); + + const __m128i q4bits1_0 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits1_1 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits2_0 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + const __m128i q4bits2_1 = __lsx_vld((const __m128i*)q4, 0); q4 += 16; + + const __m128i q4_0 = __lsx_vor_v(__lsx_vandi_b(q4bits1_0, 0xf), q4h_0); + const __m128i q4_1 = __lsx_vor_v(__lsx_vandi_b(q4bits1_1, 0xf), q4h_1); + const __m128i q4_2 = __lsx_vor_v(__lsx_vandi_b(q4bits2_0, 0xf), q4h_2); + const __m128i q4_3 = __lsx_vor_v(__lsx_vandi_b(q4bits2_1, 0xf), q4h_3); + const __m128i q4_4 = __lsx_vor_v(__lsx_vsrli_b(q4bits1_0, 4), q4h_4); + const __m128i q4_5 = __lsx_vor_v(__lsx_vsrli_b(q4bits1_1, 4), q4h_5); + const __m128i q4_6 = __lsx_vor_v(__lsx_vsrli_b(q4bits2_0, 4), q4h_6); + const __m128i q4_7 = __lsx_vor_v(__lsx_vsrli_b(q4bits2_1, 4), q4h_7); + + const __m128i q8_0 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_1 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_2 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_3 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_4 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_5 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_6 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8_7 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + + __m128i p16_0 = lsx_maddubs_h(__lsx_vsub_b(q4_0, m32s), q8_0); + __m128i p16_1 = lsx_maddubs_h(__lsx_vsub_b(q4_1, m32s), q8_1); + __m128i p16_2 = lsx_maddubs_h(__lsx_vsub_b(q4_2, m32s), q8_2); + __m128i p16_3 = lsx_maddubs_h(__lsx_vsub_b(q4_3, m32s), q8_3); + __m128i p16_4 = lsx_maddubs_h(__lsx_vsub_b(q4_4, m32s), q8_4); + __m128i p16_5 = lsx_maddubs_h(__lsx_vsub_b(q4_5, m32s), q8_5); + __m128i p16_6 = lsx_maddubs_h(__lsx_vsub_b(q4_6, m32s), q8_6); + __m128i p16_7 = lsx_maddubs_h(__lsx_vsub_b(q4_7, m32s), q8_7); + + const __m128i sc_vec = j == 0 ? scales_lo : scales_hi; + + p16_0 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 0), p16_0); + p16_1 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 1), p16_1); + p16_2 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 2), p16_2); + p16_3 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 3), p16_3); + p16_4 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 4), p16_4); + p16_5 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 5), p16_5); + p16_6 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 6), p16_6); + p16_7 = lsx_madd_h(__lsx_vreplvei_h(sc_vec, 7), p16_7); + + sumi_0 = __lsx_vadd_w(sumi_0, __lsx_vadd_w(p16_0, p16_2)); + sumi_1 = __lsx_vadd_w(sumi_1, __lsx_vadd_w(p16_1, p16_3)); + sumi_0 = __lsx_vadd_w(sumi_0, __lsx_vadd_w(p16_4, p16_6)); + sumi_1 = __lsx_vadd_w(sumi_1, __lsx_vadd_w(p16_5, p16_7)); + } + + __m128 p_0 = __lsx_vfmul_s(__lsx_vreplfr2vr_s(d), __lsx_vffint_s_w(sumi_0)); + __m128 p_1 = __lsx_vfmul_s(__lsx_vreplfr2vr_s(d), __lsx_vffint_s_w(sumi_1)); + acc_0 = __lsx_vfadd_s(p_0, acc_0); + acc_1 = __lsx_vfadd_s(p_1, acc_1); + } + + *s = hsum_float_4x4(acc_0, acc_1, (__m128)__lsx_vldi(0), (__m128)__lsx_vldi(0)); + #else UNUSED(x); UNUSED(y); @@ -2149,6 +2271,35 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v *s = hsum_float_8(accum); +#elif defined(__loongarch_sx) + + const __m128i values128 = __lsx_vld((const __m128i*)kvalues_iq4nl, 0); + + __m128 accum = (__m128)__lsx_vldi(0); + for (int ibl = 0; ibl < nb; ++ibl) { + const uint8_t * qs = x[ibl].qs; + const int8_t * q8 = y[ibl].qs; + uint16_t sh = x[ibl].scales_h; + __m128i sumi = __lsx_vldi(0); + for (int ib = 0; ib < QK_K/32; ++ib) { + const __m128i q4bits = __lsx_vld((const __m128i*)qs, 0); qs += 16; + const __m128i q8b_0 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q8b_1 = __lsx_vld((const __m128i*)q8, 0); q8 += 16; + const __m128i q4b_0 = __lsx_vshuf_b(values128, values128, __lsx_vandi_b(q4bits, 0xf)); + const __m128i q4b_1 = __lsx_vshuf_b(values128, values128, __lsx_vsrli_b(q4bits, 4)); + const __m128i p16_0 = lsx_maddubs_h(q4b_0, q8b_0); + const __m128i p16_1 = lsx_maddubs_h(q4b_1, q8b_1); + const int16_t ls = (((x[ibl].scales_l[ib/2] >> ((ib & 1) * 4)) & 0xf) | ((sh & 0x3) << 4)) - 32; + sh >>= 2; + sumi = __lsx_vadd_w(lsx_madd_h(p16_0, __lsx_vreplgr2vr_h(ls)), sumi); + sumi = __lsx_vadd_w(lsx_madd_h(p16_1, __lsx_vreplgr2vr_h(ls)), sumi); + } + const float ds = GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d; + accum = __lsx_vfadd_s(__lsx_vfmul_s(__lsx_vreplfr2vr_s(ds), __lsx_vffint_s_w(sumi)), accum); + } + + *s = ((v4f32)lsx_hadd_s(lsx_hadd_s(accum, accum), lsx_hadd_s(accum, accum)))[0]; + #else UNUSED(x); UNUSED(y); @@ -2156,4 +2307,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/powerpc/quants.c b/ggml/src/ggml-cpu/arch/powerpc/quants.c index d3dfd049eaf..644c380c738 100644 --- a/ggml/src/ggml-cpu/arch/powerpc/quants.c +++ b/ggml/src/ggml-cpu/arch/powerpc/quants.c @@ -2302,4 +2302,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/riscv/quants.c b/ggml/src/ggml-cpu/arch/riscv/quants.c index ae0ebb3cad1..47e9180bf9b 100644 --- a/ggml/src/ggml-cpu/arch/riscv/quants.c +++ b/ggml/src/ggml-cpu/arch/riscv/quants.c @@ -15,6 +15,12 @@ #include <stdlib.h> // for qsort #include <stdio.h> // for GGML_ASSERT +#ifdef _MSC_VER +#define NOINLINE __declspec(noinline) +#else +#define NOINLINE __attribute__((__noinline__)) +#endif + #define GROUP_MAX_EPS 1e-15f #define GROUP_MAX_EPS_IQ3_XXS 1e-8f #define GROUP_MAX_EPS_IQ2_S 1e-8f @@ -113,6 +119,104 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i #endif } +void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + assert(k % QK_K == 0); + size_t nb = k / QK_K; + +#if defined __riscv_v + block_q8_K * y_blocks = (block_q8_K *)y; + const size_t vlmax_f32m8 = __riscv_vsetvlmax_e32m8(); + + for (size_t i = 0; i < nb; i++) { + const float* x_block = x + i * QK_K; + block_q8_K* y_block = &y_blocks[i]; + + // 1. Calculate Min/Max + vfloat32m8_t max_v = __riscv_vfmv_v_f_f32m8(-__builtin_inff(), vlmax_f32m8); + vfloat32m8_t min_v = __riscv_vfmv_v_f_f32m8(__builtin_inff(), vlmax_f32m8); + + size_t rem = QK_K; + size_t offset = 0; + while (rem > 0) { + size_t vl = __riscv_vsetvl_e32m8(rem); + vfloat32m8_t v_curr = __riscv_vle32_v_f32m8(x_block + offset, vl); + max_v = __riscv_vfmax_vv_f32m8(max_v, v_curr, vl); + min_v = __riscv_vfmin_vv_f32m8(min_v, v_curr, vl); + rem -= vl; + offset += vl; + } + + vfloat32m1_t v_init_max = __riscv_vfmv_s_f_f32m1(-__builtin_inff(), 1); + vfloat32m1_t v_init_min = __riscv_vfmv_s_f_f32m1(__builtin_inff(), 1); + + vfloat32m1_t v_scalar_max = __riscv_vfredmax_vs_f32m8_f32m1(max_v, v_init_max, vlmax_f32m8); + vfloat32m1_t v_scalar_min = __riscv_vfredmin_vs_f32m8_f32m1(min_v, v_init_min, vlmax_f32m8); + + float max_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_max); + float min_val = __riscv_vfmv_f_s_f32m1_f32(v_scalar_min); + + float amax = fabsf(max_val) > fabsf(min_val) ? fabsf(max_val) : fabsf(min_val); + + if (amax == 0.0f) { + y_block->d = 0.0f; + memset(y_block->qs, 0, QK_K); + memset(y_block->bsums, 0, sizeof(y_block->bsums)); + continue; + } + + const float iscale = -127.f / (fabsf(max_val) > fabsf(min_val) ? max_val : min_val); + y_block->d = 1.0f / iscale; + + // 2. Quantize and Calculate Sums + offset = 0; + rem = QK_K; + vint16m1_t v_zero_sum = __riscv_vmv_v_x_i16m1(0, 1); + + while (rem > 0) { + size_t vl = __riscv_vsetvl_e32m8(rem); + vfloat32m8_t v_f = __riscv_vle32_v_f32m8(x_block + offset, vl); + + v_f = __riscv_vfmul_vf_f32m8(v_f, iscale, vl); + + vint32m8_t v_i32 = __riscv_vfcvt_x_f_v_i32m8_rm(v_f, __RISCV_FRM_RNE, vl); + vint16m4_t v_i16 = __riscv_vnclip_wx_i16m4(v_i32, 0, __RISCV_VXRM_RNE, vl); + vint8m2_t v_q = __riscv_vnclip_wx_i8m2(v_i16, 0, __RISCV_VXRM_RNE, vl); + + __riscv_vse8_v_i8m2(y_block->qs + offset, v_q, vl); + + // first iteration clear + + int sum_idx; + vint8m1_t chunk_m1; + vint16m1_t v_sum; + sum_idx = offset / 16; + chunk_m1 = __riscv_vget_v_i8m2_i8m1(v_q, 0); + v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16); + y_block->bsums[sum_idx] = (int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum); + + // remaining iterations + vint8m2_t slid_q = v_q; + for (size_t k = 16; k < vl; k += 16) { + slid_q = __riscv_vslidedown_vx_i8m2(slid_q, 16, vl); + + sum_idx = (offset + k) / 16; + chunk_m1 = __riscv_vget_v_i8m2_i8m1(slid_q, 0); + + v_sum = __riscv_vwredsum_vs_i8m1_i16m1(chunk_m1, v_zero_sum, 16); + y_block->bsums[sum_idx] =(int16_t)__riscv_vmv_x_s_i16m1_i16(v_sum); + } + + rem -= vl; + offset += vl; + } + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_K_ref(x, y, k); +#endif +} + //===================================== Dot products ================================= void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { @@ -376,7 +480,106 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined(__riscv_v) +static NOINLINE void ggml_vec_dot_q1_0_q8_0_vl256(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy) { + const int qk = QK1_0; + const int nb = n / qk; + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + //LMUL = 1, VLMAX = 32 + const size_t vl32 = __riscv_vsetvl_e8m1(32); + assert(vl32 == 32); + + const vint16m1_t zero = __riscv_vmv_v_x_i16m1(0, 1); + + float sumf = 0; + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + + float acc = 0; + + for (int k = 0; k < 4; ++k) { + const block_q8_0 * GGML_RESTRICT yb = &y[ib * 4 + k]; + const vbool8_t is_not_zero = __riscv_vlm_v_b8(x[ib].qs + 4 * k, vl32); + + const vint8m1_t qy = __riscv_vle8_v_i8m1(yb->qs, vl32); + const vint8m1_t neg_qy = __riscv_vneg_v_i8m1(qy, vl32); + const vint8m1_t sy = __riscv_vmerge_vvm_i8m1(neg_qy, qy, is_not_zero, vl32); + + const vint16m1_t red = __riscv_vwredsum_vs_i8m1_i16m1(sy, zero, vl32); + acc += GGML_CPU_FP16_TO_FP32(yb->d) * (float)__riscv_vmv_x_s_i16m1_i16(red); + } + + sumf += d0 * acc; + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q1_0_q8_0_vl128(const int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy) { + const int qk = QK1_0; + const int nb = n / qk; + assert(n % qk == 0); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + //LMUL = 2, VLMAX = 32 + const size_t vl32 = __riscv_vsetvl_e8m2(32); + assert(vl32 == 32); + + const vint16m1_t zero = __riscv_vmv_v_x_i16m1(0, 1); + + float sumf = 0; + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + + float acc = 0; + + for (int k = 0; k < 4; ++k) { + const block_q8_0 * GGML_RESTRICT yb = &y[ib * 4 + k]; + const vbool4_t is_not_zero = __riscv_vlm_v_b4(x[ib].qs + 4 * k, vl32); + + const vint8m2_t qy = __riscv_vle8_v_i8m2(yb->qs, vl32); + const vint8m2_t neg_qy =__riscv_vneg_v_i8m2(qy, vl32); + const vint8m2_t sy = __riscv_vmerge_vvm_i8m2(neg_qy, qy, is_not_zero, vl32); + + const vint16m1_t red = __riscv_vwredsum_vs_i8m2_i16m1(sy, zero, vl32); + acc += GGML_CPU_FP16_TO_FP32(yb->d) * (float)__riscv_vmv_x_s_i16m1_i16(red); + } + + sumf += d0 * acc; + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined(__riscv_v) + assert(nrc == 1); + + const size_t vlen_bits = __riscv_vlenb() * 8; + + if (vlen_bits >= 256) { + ggml_vec_dot_q1_0_q8_0_vl256(n, s, vx, vy); + } else if (vlen_bits >= 128) { + ggml_vec_dot_q1_0_q8_0_vl128(n, s, vx, vy); + } else { + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); + } +#else + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_xtheadvector +void ggml_vec_dot_q2_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); UNUSED(bx); @@ -388,8 +591,6 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; uint8_t atmp[16]; @@ -484,246 +685,281 @@ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +void ggml_vec_dot_q2_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; uint8_t atmp[16]; - const int vector_length = __riscv_vlenb() * 8; uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - - size_t vl = 16; - - vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); - vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + uint8_t *patmp = atmp; + int vsums; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 16, e8, m1\n\t" + "vmv.v.x v8, zero\n\t" + "lb zero, 15(%[sc])\n\t" + "vle8.v v1, (%[sc])\n\t" + "vle8.v v2, (%[bsums])\n\t" + "addi %[tmp], %[bsums], 16\n\t" + "vand.vi v0, v1, 0xF\n\t" + "vsrl.vi v1, v1, 4\n\t" + "vle8.v v3, (%[tmp])\n\t" + "vse8.v v0, (%[scale])\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vzext.vf2 v0, v1\n\t" + "vwmul.vv v4, v0, v2\n\t" + "vsetivli zero, 16, e32, m4\n\t" + "vredsum.vs v8, v4, v8\n\t" + "vmv.x.s %[vsums], v8" + : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) + : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + sumf += dmin * vsums; + int isum = 0; - vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "lb zero, 31(%[q2])\n\t" + "addi %[tmp], %[q2], 16\n\t" + "addi %[t1], %[q8], 16\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vle8.v v0, (%[q2])\n\t" + "vle8.v v1, (%[tmp])\n\t" + "vsrl.vi v2, v0, 2\n\t" + "vsrl.vi v3, v1, 2\n\t" + "vsrl.vi v4, v0, 4\n\t" + "addi %[tmp], %[q8], 32\n\t" + "vle8.v v8, (%[q8])\n\t" + "vle8.v v9, (%[t1])\n\t" + "addi %[t1], %[t1], 32\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v6, v0, 6\n\t" + "vsrl.vi v7, v1, 6\n\t" + "vle8.v v10, (%[tmp])\n\t" + "vle8.v v11, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v0, v0, 0x3\n\t" + "vand.vi v1, v1, 0x3\n\t" + "vand.vi v2, v2, 0x3\n\t" + "vle8.v v12, (%[tmp])\n\t" + "vle8.v v13, (%[t1])\n\t" + "addi %[tmp], %[tmp], 32\n\t" + "addi %[t1], %[t1], 32\n\t" + "vand.vi v3, v3, 0x3\n\t" + "vand.vi v4, v4, 0x3\n\t" + "vand.vi v5, v5, 0x3\n\t" + "vle8.v v14, (%[tmp])\n\t" + "vle8.v v15, (%[t1])\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v18, v1, v9\n\t" + "vwmul.vv v20, v2, v10\n\t" + "vwmul.vv v22, v3, v11\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vwmul.vv v26, v5, v13\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmul.vv v30, v7, v15\n\t" + "vsetivli zero, 8, e16, m1\n\t" + "vmv.v.x v0, zero\n\t" + "lbu %[tmp], 0(%[scale])\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "lbu %[t1], 1(%[scale])\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "lbu %[t2], 2(%[scale])\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "lbu %[t3], 3(%[scale])\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "lbu %[t4], 4(%[scale])\n\t" + "vwredsum.vs v8, v17, v8\n\t" + "vwredsum.vs v9, v19, v9\n\t" + "lbu %[t5], 5(%[scale])\n\t" + "vwredsum.vs v10, v21, v10\n\t" + "vwredsum.vs v11, v23, v11\n\t" + "lbu %[t6], 6(%[scale])\n\t" + "vwredsum.vs v12, v25, v12\n\t" + "vwredsum.vs v13, v27, v13\n\t" + "lbu %[t7], 7(%[scale])\n\t" + "vwredsum.vs v14, v29, v14\n\t" + "vwredsum.vs v15, v31, v15\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [isum] "+&r" (isum) + : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q2 += 32; q8 += 128; patmp += 8; + } - vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); - vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); - vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); - vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); - vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf += dall * isum; + } - sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); + *s = sumf; +} - vl = 32; +void ggml_vec_dot_q2_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + const block_q2_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - uint8_t is = 0; - int isum = 0; + const int nb = n / QK_K; - for (int j = 0; j < QK_K / 128; ++j) { - // load Q2 - vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + float sumf = 0; + uint8_t atmp[16]; - vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); - vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); - vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); - vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); + uint8_t temp_01[32] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }; - // duplicate scale elements for product - vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); - vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); - vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); - vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); + for (int i = 0; i < nb; ++i) { + const uint8_t * q2 = x[i].qs; + const int8_t * q8 = y[i].qs; + const uint8_t * sc = x[i].scales; - vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); - vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); - vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); - vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - // load Q8 - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); - vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); - vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); + size_t vl = 16; - vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); - vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); - vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); - vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + vuint8m1_t scales = __riscv_vle8_v_u8m1(sc, vl); + vuint8m1_t aux = __riscv_vand_vx_u8m1(scales, 0x0F, vl); - vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + vint16m1_t q8sums = __riscv_vle16_v_i16m1(y[i].bsums, vl); - isum += __riscv_vmv_x_s_i32m1_i32(isum1); + vuint8mf2_t scales_2 = __riscv_vle8_v_u8mf2(sc, vl); + vuint8mf2_t mins8 = __riscv_vsrl_vx_u8mf2(scales_2, 0x4, vl); + vint16m1_t mins = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(mins8, vl)); + vint32m2_t prod = __riscv_vwmul_vv_i32m2(q8sums, mins, vl); + vint32m1_t vsums = __riscv_vredsum_vs_i32m2_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - q2 += 32; - q8 += 128; - is = 8; - } + sumf += dmin * __riscv_vmv_x_s_i32m1_i32(vsums); - sumf += dall * isum; - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * q2 = x[i].qs; - const int8_t * q8 = y[i].qs; - const uint8_t * sc = x[i].scales; - const float dall = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = -y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - uint8_t *patmp = atmp; - int vsums; - int tmp, t1, t2, t3, t4, t5, t6, t7; - __asm__ __volatile__( - "vsetivli zero, 16, e8, m1\n\t" - "vmv.v.x v8, zero\n\t" - "lb zero, 15(%[sc])\n\t" - "vle8.v v1, (%[sc])\n\t" - "vle8.v v2, (%[bsums])\n\t" - "addi %[tmp], %[bsums], 16\n\t" - "vand.vi v0, v1, 0xF\n\t" - "vsrl.vi v1, v1, 4\n\t" - "vle8.v v3, (%[tmp])\n\t" - "vse8.v v0, (%[scale])\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vzext.vf2 v0, v1\n\t" - "vwmul.vv v4, v0, v2\n\t" - "vsetivli zero, 16, e32, m4\n\t" - "vredsum.vs v8, v4, v8\n\t" - "vmv.x.s %[vsums], v8" - : [tmp] "=&r" (tmp), [vsums] "=&r" (vsums) - : [sc] "r" (sc), [scale] "r" (atmp), [bsums] "r" (y[i].bsums) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - sumf += dmin * vsums; - int isum = 0; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "lb zero, 31(%[q2])\n\t" - "addi %[tmp], %[q2], 16\n\t" - "addi %[t1], %[q8], 16\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vle8.v v0, (%[q2])\n\t" - "vle8.v v1, (%[tmp])\n\t" - "vsrl.vi v2, v0, 2\n\t" - "vsrl.vi v3, v1, 2\n\t" - "vsrl.vi v4, v0, 4\n\t" - "addi %[tmp], %[q8], 32\n\t" - "vle8.v v8, (%[q8])\n\t" - "vle8.v v9, (%[t1])\n\t" - "addi %[t1], %[t1], 32\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v6, v0, 6\n\t" - "vsrl.vi v7, v1, 6\n\t" - "vle8.v v10, (%[tmp])\n\t" - "vle8.v v11, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v0, v0, 0x3\n\t" - "vand.vi v1, v1, 0x3\n\t" - "vand.vi v2, v2, 0x3\n\t" - "vle8.v v12, (%[tmp])\n\t" - "vle8.v v13, (%[t1])\n\t" - "addi %[tmp], %[tmp], 32\n\t" - "addi %[t1], %[t1], 32\n\t" - "vand.vi v3, v3, 0x3\n\t" - "vand.vi v4, v4, 0x3\n\t" - "vand.vi v5, v5, 0x3\n\t" - "vle8.v v14, (%[tmp])\n\t" - "vle8.v v15, (%[t1])\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v18, v1, v9\n\t" - "vwmul.vv v20, v2, v10\n\t" - "vwmul.vv v22, v3, v11\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vwmul.vv v26, v5, v13\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmul.vv v30, v7, v15\n\t" - "vsetivli zero, 8, e16, m1\n\t" - "vmv.v.x v0, zero\n\t" - "lbu %[tmp], 0(%[scale])\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "lbu %[t1], 1(%[scale])\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "lbu %[t2], 2(%[scale])\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "lbu %[t3], 3(%[scale])\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "lbu %[t4], 4(%[scale])\n\t" - "vwredsum.vs v8, v17, v8\n\t" - "vwredsum.vs v9, v19, v9\n\t" - "lbu %[t5], 5(%[scale])\n\t" - "vwredsum.vs v10, v21, v10\n\t" - "vwredsum.vs v11, v23, v11\n\t" - "lbu %[t6], 6(%[scale])\n\t" - "vwredsum.vs v12, v25, v12\n\t" - "vwredsum.vs v13, v27, v13\n\t" - "lbu %[t7], 7(%[scale])\n\t" - "vwredsum.vs v14, v29, v14\n\t" - "vwredsum.vs v15, v31, v15\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [isum] "+&r" (isum) - : [q2] "r" (q2), [scale] "r" (patmp), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q2 += 32; q8 += 128; patmp += 8; - } + vl = 32; - sumf += dall * isum; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t v_b = __riscv_vle8_v_u8m1(temp_01, vl); + + uint8_t is = 0; + int isum = 0; + + for (int j = 0; j < QK_K / 128; ++j) { + // load Q2 + vuint8m1_t q2_x = __riscv_vle8_v_u8m1(q2, vl); + + vuint8m1_t q2_0 = __riscv_vand_vx_u8m1(q2_x, 0x03, vl); + vuint8m1_t q2_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x2, vl), 0x03, vl); + vuint8m1_t q2_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x4, vl), 0x03, vl); + vuint8m1_t q2_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q2_x, 0x6, vl), 0x03, vl); + + // duplicate scale elements for product + vuint8m1_t sc0 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 0 + is, vl), vl); + vuint8m1_t sc1 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 2 + is, vl), vl); + vuint8m1_t sc2 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 4 + is, vl), vl); + vuint8m1_t sc3 = __riscv_vrgather_vv_u8m1(aux, __riscv_vadd_vx_u8m1(v_b, 6 + is, vl), vl); + + vint16m2_t p0 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_0, sc0, vl)); + vint16m2_t p1 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_1, sc1, vl)); + vint16m2_t p2 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_2, sc2, vl)); + vint16m2_t p3 = __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vwmulu_vv_u16m2(q2_3, sc3, vl)); + + // load Q8 + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8 + 32, vl); + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8 + 64, vl); + vint8m1_t q8_3 = __riscv_vle8_v_i8m1(q8 + 96, vl); + + vint32m4_t s0 = __riscv_vwmul_vv_i32m4(p0, __riscv_vwcvt_x_x_v_i16m2(q8_0, vl), vl); + vint32m4_t s1 = __riscv_vwmul_vv_i32m4(p1, __riscv_vwcvt_x_x_v_i16m2(q8_1, vl), vl); + vint32m4_t s2 = __riscv_vwmul_vv_i32m4(p2, __riscv_vwcvt_x_x_v_i16m2(q8_2, vl), vl); + vint32m4_t s3 = __riscv_vwmul_vv_i32m4(p3, __riscv_vwcvt_x_x_v_i16m2(q8_3, vl), vl); + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s0, s1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m4_i32m1(__riscv_vadd_vv_i32m4(s2, s3, vl), isum0, vl); + + isum += __riscv_vmv_x_s_i32m1_i32(isum1); + + q2 += 32; + q8 += 128; + is = 8; } - break; - default: - assert(false && "Unsupported vector length"); - break; + + sumf += dall * isum; } *s = sumf; +} +#endif +void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q2_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q2_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q2_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(nb); - ggml_vec_dot_q2_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } -void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +void ggml_vec_dot_q3_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -739,8 +975,6 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - uint32_t utmp[4]; float sumf = 0; @@ -866,257 +1100,538 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +void ggml_vec_dot_q3_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; uint32_t utmp[4]; float sumf = 0; uint32_t aux[3]; - const int vector_length = __riscv_vlenb() * 8; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { - - const uint8_t * GGML_RESTRICT q3 = x[i].qs; - const uint8_t * GGML_RESTRICT qh = x[i].hmask; - const int8_t * GGML_RESTRICT q8 = y[i].qs; - - memcpy(aux, x[i].scales, 12); - utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); - utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); - utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); - utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + for (int i = 0; i < nb; ++i) { + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; - int8_t * scale = (int8_t *)utmp; - for (int j = 0; j < 16; ++j) scale[j] -= 32; + int8_t * scale = (int8_t *)utmp; + int tmp, t1, t2, t3, t4, t5, t6, t7; + __asm__ __volatile__( + "vsetivli zero, 12, e8, m1\n\t" + "vle8.v v0, (%[s6b])\n\t" + "vmv1r.v v2, v0\n\t" + "vsetivli zero, 2, e64, m1\n\t" + "vmv.v.x v9, %[sh]\n\t"\ + "vslidedown.vi v1, v0, 1\n\t" + "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} + "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} + "vsetivli zero, 4, e32, m1\n\t" + "vid.v v9\n\t" + "vmv.x.s %[tmp], v1\n\t" + "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} + "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} + "vsrl.vv v4, v1, v9\n\t" + "vsrl.vv v2, v0, v8\n\t" + "vand.vx v5, v4, %[kmask1]\n\t" + "vand.vx v3, v2, %[kmask2]\n\t" + "vsll.vi v6, v5, 4\n\t" + "vor.vv v7, v6, v3\n\t" + "vsetivli zero, 16, e8, m1\n\t" + "vsub.vx v0, v7, %[c]\n\t" + "vse8.v v0, (%[scale])" + : [tmp] "=&r" (tmp) + : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) + , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + uint8_t m = 1; + int isum = 0; + for (int j = 0; j < QK_K; j += 128) { + __asm__ __volatile__( + "lb zero, 31(%[q3])\n\t" + "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" + "vle8.v v8, (%[q3])\n\t" + "vsrl.vi v10, v8, 2\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v8, 6\n\t" + "lb zero, 64(%[q8])\n\t" + "vand.vi v8, v8, 3\n\t" + "vand.vi v10, v10, 3\n\t" + "vand.vi v12, v12, 3\n\t" + "vle8.v v2, (%[qh])\n\t" + "lb zero, 127(%[q8])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v8, v8, -4, v0.t\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v10, v10, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v12, v12, -4, v0.t\n\t" + "vand.vx v4, v2, %[m]\n\t" + "slli %[m], %[m], 1\n\t" + "vmseq.vx v0, v4, zero\n\t" + "vadd.vi v14, v14, -4, v0.t\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "lb %[tmp], 0(%[scale])\n\t" + "lb %[t1], 1(%[scale])\n\t" + "lb %[t2], 2(%[scale])\n\t" + "lb %[t3], 3(%[scale])\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v8, v16, v0\n\t" + "lb %[t4], 4(%[scale])\n\t" + "lb %[t5], 5(%[scale])\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v10, v20, v0\n\t" + "vwredsum.vs v11, v22, v0\n\t" + "vwredsum.vs v12, v24, v0\n\t" + "lb %[t6], 6(%[scale])\n\t" + "lb %[t7], 7(%[scale])\n\t" + "vwredsum.vs v13, v26, v0\n\t" + "vwredsum.vs v14, v28, v0\n\t" + "vwredsum.vs v15, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v8, %[tmp]\n\t" + "vmul.vx v1, v9, %[t1]\n\t" + "vmacc.vx v0, %[t2], v10\n\t" + "vmacc.vx v1, %[t3], v11\n\t" + "vmacc.vx v0, %[t4], v12\n\t" + "vmacc.vx v1, %[t5], v13\n\t" + "vmacc.vx v0, %[t6], v14\n\t" + "vmacc.vx v1, %[t7], v15\n\t" + "vmv.x.s %[tmp], v0\n\t" + "vmv.x.s %[t1], v1\n\t" + "add %[isum], %[isum], %[tmp]\n\t" + "add %[isum], %[isum], %[t1]" + : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) + , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) + , [m] "+&r" (m), [isum] "+&r" (isum) + : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) + , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + q3 += 32; q8 += 128; scale += 8; + } - size_t vl = 32; - uint8_t m = 1; + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + sumf += d * isum; + } - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); + *s = sumf; +} - int sum_t = 0; +void ggml_vec_dot_q3_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - for (int j = 0; j < QK_K; j += 128) { + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; - vl = 32; - // load Q3 - vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8m1_t vqh = __riscv_vle8_v_u8m1(qh, vl); - vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); - vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); - vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); - vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); + int sum_t = 0; - // compute mask for subtraction - vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); - vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); - m <<= 1; + for (int j = 0; j < QK_K; j += 128) { - vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); - vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); - m <<= 1; + vl = 32; - vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); - vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); - m <<= 1; + // load Q3 + vuint8m1_t q3_x = __riscv_vle8_v_u8m1(q3, vl); - vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); - vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); - vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); - m <<= 1; + vint8m1_t q3_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q3_x, 0x03, vl)); + vint8m1_t q3_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x2, vl), 0x03 , vl)); + vint8m1_t q3_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x4, vl), 0x03 , vl)); + vint8m1_t q3_3 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(q3_x, 0x6, vl), 0x03 , vl)); - // load Q8 and take product with Q3 - vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + // compute mask for subtraction + vuint8m1_t qh_m0 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_0 = __riscv_vmseq_vx_u8m1_b8(qh_m0, 0, vl); + vint8m1_t q3_m0 = __riscv_vsub_vx_i8m1_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; - vl = 16; + vuint8m1_t qh_m1 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_1 = __riscv_vmseq_vx_u8m1_b8(qh_m1, 0, vl); + vint8m1_t q3_m1 = __riscv_vsub_vx_i8m1_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; - // retrieve lane to multiply with scale - vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); - vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); - vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); - vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); - vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); - vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); - vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); - vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); + vuint8m1_t qh_m2 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_2 = __riscv_vmseq_vx_u8m1_b8(qh_m2, 0, vl); + vint8m1_t q3_m2 = __riscv_vsub_vx_i8m1_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); + vuint8m1_t qh_m3 = __riscv_vand_vx_u8m1(vqh, m, vl); + vbool8_t vmask_3 = __riscv_vmseq_vx_u8m1_b8(qh_m3, 0, vl); + vint8m1_t q3_m3 = __riscv_vsub_vx_i8m1_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + // load Q8 and take product with Q3 + vint16m2_t a0 = __riscv_vwmul_vv_i16m2(q3_m0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t a1 = __riscv_vwmul_vv_i16m2(q3_m1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t a2 = __riscv_vwmul_vv_i16m2(q3_m2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t a3 = __riscv_vwmul_vv_i16m2(q3_m3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - q3 += 32; q8 += 128; scale += 8; + vl = 16; - } + // retrieve lane to multiply with scale + vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl); + vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl); + vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl); + vint32m2_t aux1_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 1), (scale[3]), vl); + vint32m2_t aux2_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 0), (scale[4]), vl); + vint32m2_t aux2_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a2, 1), (scale[5]), vl); + vint32m2_t aux3_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 0), (scale[6]), vl); + vint32m2_t aux3_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a3, 1), (scale[7]), vl); - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux0_0, aux0_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux1_0, aux1_1, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux2_0, aux2_1, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(aux3_0, aux3_1, vl), isum2, vl); - sumf += d*sum_t; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + + q3 += 32; q8 += 128; scale += 8; } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const uint8_t * restrict q3 = x[i].qs; - const uint8_t * restrict qh = x[i].hmask; - const int8_t * restrict q8 = y[i].qs; - int8_t * scale = (int8_t *)utmp; - int tmp, t1, t2, t3, t4, t5, t6, t7; - __asm__ __volatile__( - "vsetivli zero, 12, e8, m1\n\t" - "vle8.v v0, (%[s6b])\n\t" - "vmv1r.v v2, v0\n\t" - "vsetivli zero, 2, e64, m1\n\t" - "vmv.v.x v9, %[sh]\n\t"\ - "vslidedown.vi v1, v0, 1\n\t" - "vslide1up.vx v8, v9, zero\n\t" // {0, 0, 4, 4} - "vslideup.vi v0, v2, 1\n\t" // {aux[0], aux[1], aux[0], aux[1]} - "vsetivli zero, 4, e32, m1\n\t" - "vid.v v9\n\t" - "vmv.x.s %[tmp], v1\n\t" - "vsll.vi v9, v9, 1\n\t" // {0, 2, 4, 6} - "vmv.v.x v1, %[tmp]\n\t" // {aux[2], aux[2], aux[2], aux[2]} - "vsrl.vv v4, v1, v9\n\t" - "vsrl.vv v2, v0, v8\n\t" - "vand.vx v5, v4, %[kmask1]\n\t" - "vand.vx v3, v2, %[kmask2]\n\t" - "vsll.vi v6, v5, 4\n\t" - "vor.vv v7, v6, v3\n\t" - "vsetivli zero, 16, e8, m1\n\t" - "vsub.vx v0, v7, %[c]\n\t" - "vse8.v v0, (%[scale])" - : [tmp] "=&r" (tmp) - : [sh] "r" (0x0000000400000004), [s6b] "r" (x[i].scales), [c] "r" (32) - , [scale] "r" (scale), [kmask1] "r" (kmask1), [kmask2] "r" (kmask2) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - uint8_t m = 1; - int isum = 0; - for (int j = 0; j < QK_K; j += 128) { - __asm__ __volatile__( - "lb zero, 31(%[q3])\n\t" - "vsetvli zero, %[vl32], e8, m2, ta, mu\n\t" - "vle8.v v8, (%[q3])\n\t" - "vsrl.vi v10, v8, 2\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v8, 6\n\t" - "lb zero, 64(%[q8])\n\t" - "vand.vi v8, v8, 3\n\t" - "vand.vi v10, v10, 3\n\t" - "vand.vi v12, v12, 3\n\t" - "vle8.v v2, (%[qh])\n\t" - "lb zero, 127(%[q8])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v8, v8, -4, v0.t\n\t" - "lb zero, 0(%[q8])\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v10, v10, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v12, v12, -4, v0.t\n\t" - "vand.vx v4, v2, %[m]\n\t" - "slli %[m], %[m], 1\n\t" - "vmseq.vx v0, v4, zero\n\t" - "vadd.vi v14, v14, -4, v0.t\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "lb %[tmp], 0(%[scale])\n\t" - "lb %[t1], 1(%[scale])\n\t" - "lb %[t2], 2(%[scale])\n\t" - "lb %[t3], 3(%[scale])\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v8, v16, v0\n\t" - "lb %[t4], 4(%[scale])\n\t" - "lb %[t5], 5(%[scale])\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v10, v20, v0\n\t" - "vwredsum.vs v11, v22, v0\n\t" - "vwredsum.vs v12, v24, v0\n\t" - "lb %[t6], 6(%[scale])\n\t" - "lb %[t7], 7(%[scale])\n\t" - "vwredsum.vs v13, v26, v0\n\t" - "vwredsum.vs v14, v28, v0\n\t" - "vwredsum.vs v15, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v8, %[tmp]\n\t" - "vmul.vx v1, v9, %[t1]\n\t" - "vmacc.vx v0, %[t2], v10\n\t" - "vmacc.vx v1, %[t3], v11\n\t" - "vmacc.vx v0, %[t4], v12\n\t" - "vmacc.vx v1, %[t5], v13\n\t" - "vmacc.vx v0, %[t6], v14\n\t" - "vmacc.vx v1, %[t7], v15\n\t" - "vmv.x.s %[tmp], v0\n\t" - "vmv.x.s %[t1], v1\n\t" - "add %[isum], %[isum], %[tmp]\n\t" - "add %[isum], %[isum], %[t1]" - : [tmp] "=&r" (tmp), [t1] "=&r" (t1), [t2] "=&r" (t2), [t3] "=&r" (t3) - , [t4] "=&r" (t4), [t5] "=&r" (t5), [t6] "=&r" (t6), [t7] "=&r" (t7) - , [m] "+&r" (m), [isum] "+&r" (isum) - : [vl128] "r" (128), [vl64] "r" (64), [vl32] "r" (32) - , [q3] "r" (q3), [qh] "r" (qh), [scale] "r" (scale), [q8] "r" (q8) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - q3 += 32; q8 += 128; scale += 8; - } + sumf += d*sum_t; + + } + + *s = sumf; +} + +void ggml_vec_dot_q3_K_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16m1_t va_index = __riscv_vid_v_u16m1(32); + const vbool16_t va_mask = __riscv_vmsgtu_vx_u16m1_b16(va_index, 15, 32); + + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8mf2_t vqh = __riscv_vle8_v_u8mf2(qh, vl); + + int sum_t = 0; + + vint32m2_t vaux_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8mf2_t q3_x = __riscv_vle8_v_u8mf2(q3, vl); + + vint8mf2_t q3_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(q3_x, 0x03, vl)); + vint8mf2_t q3_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x2, vl), 0x03 , vl)); + vint8mf2_t q3_2 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x4, vl), 0x03 , vl)); + vint8mf2_t q3_3 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8mf2_t qh_m0 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_0 = __riscv_vmseq_vx_u8mf2_b16(qh_m0, 0, vl); + vint8mf2_t q3_m0 = __riscv_vsub_vx_i8mf2_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m1 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_1 = __riscv_vmseq_vx_u8mf2_b16(qh_m1, 0, vl); + vint8mf2_t q3_m1 = __riscv_vsub_vx_i8mf2_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m2 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_2 = __riscv_vmseq_vx_u8mf2_b16(qh_m2, 0, vl); + vint8mf2_t q3_m2 = __riscv_vsub_vx_i8mf2_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8mf2_t qh_m3 = __riscv_vand_vx_u8mf2(vqh, m, vl); + vbool16_t vmask_3 = __riscv_vmseq_vx_u8mf2_b16(qh_m3, 0, vl); + vint8mf2_t q3_m3 = __riscv_vsub_vx_i8mf2_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product + vint16m1_t va_q_0 = __riscv_vwmul_vv_i16m1(q3_m0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t va_q_1 = __riscv_vwmul_vv_i16m1(q3_m1, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t va_q_2 = __riscv_vwmul_vv_i16m1(q3_m2, __riscv_vle8_v_i8mf2(q8+64, vl), vl); + vint16m1_t va_q_3 = __riscv_vwmul_vv_i16m1(q3_m3, __riscv_vle8_v_i8mf2(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m2(vaux_0, scale[0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m2(vaux_1, scale[2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m2(vaux_2, scale[4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m2(vaux_3, scale[6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_0, scale[1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_1, scale[3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_2, scale[5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_3, scale[7], va_q_3, vl); - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - sumf += d * isum; + q3 += 32; q8 += 128; scale += 8; } - break; - default: - assert(false && "Unsupported vector length"); - break; + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; } *s = sumf; +} -#else +void ggml_vec_dot_q3_K_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(x); - UNUSED(y); - UNUSED(nb); + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; - ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + const block_q3_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16mf2_t va_index = __riscv_vid_v_u16mf2(32); + const vbool32_t va_mask = __riscv_vmsgtu_vx_u16mf2_b32(va_index, 15, 32); + + uint32_t utmp[4]; + float sumf = 0; + uint32_t aux[3]; + + for (int i = 0; i < nb; ++i) { + const uint8_t * GGML_RESTRICT q3 = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].hmask; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + int8_t * scale = (int8_t *)utmp; + for (int j = 0; j < 16; ++j) scale[j] -= 32; + + + size_t vl = 32; + uint8_t m = 1; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vuint8mf4_t vqh = __riscv_vle8_v_u8mf4(qh, vl); + + int sum_t = 0; + + vint32m1_t vaux_0 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_1 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_2 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_3 = __riscv_vmv_v_x_i32m1(0, vl); + + for (int j = 0; j < QK_K; j += 128) { + + vl = 32; + + // load Q3 + vuint8mf4_t q3_x = __riscv_vle8_v_u8mf4(q3, vl); + + vint8mf4_t q3_0 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(q3_x, 0x03, vl)); + vint8mf4_t q3_1 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x2, vl), 0x03 , vl)); + vint8mf4_t q3_2 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x4, vl), 0x03 , vl)); + vint8mf4_t q3_3 = __riscv_vreinterpret_v_u8mf4_i8mf4(__riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(q3_x, 0x6, vl), 0x03 , vl)); + + // compute mask for subtraction + vuint8mf4_t qh_m0 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_0 = __riscv_vmseq_vx_u8mf4_b32(qh_m0, 0, vl); + vint8mf4_t q3_m0 = __riscv_vsub_vx_i8mf4_mu(vmask_0, q3_0, q3_0, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m1 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_1 = __riscv_vmseq_vx_u8mf4_b32(qh_m1, 0, vl); + vint8mf4_t q3_m1 = __riscv_vsub_vx_i8mf4_mu(vmask_1, q3_1, q3_1, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m2 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_2 = __riscv_vmseq_vx_u8mf4_b32(qh_m2, 0, vl); + vint8mf4_t q3_m2 = __riscv_vsub_vx_i8mf4_mu(vmask_2, q3_2, q3_2, 0x4, vl); + m <<= 1; + + vuint8mf4_t qh_m3 = __riscv_vand_vx_u8mf4(vqh, m, vl); + vbool32_t vmask_3 = __riscv_vmseq_vx_u8mf4_b32(qh_m3, 0, vl); + vint8mf4_t q3_m3 = __riscv_vsub_vx_i8mf4_mu(vmask_3, q3_3, q3_3, 0x4, vl); + m <<= 1; + + // load Q8 and take product + vint16mf2_t va_q_0 = __riscv_vwmul_vv_i16mf2(q3_m0, __riscv_vle8_v_i8mf4(q8, vl), vl); + vint16mf2_t va_q_1 = __riscv_vwmul_vv_i16mf2(q3_m1, __riscv_vle8_v_i8mf4(q8+32, vl), vl); + vint16mf2_t va_q_2 = __riscv_vwmul_vv_i16mf2(q3_m2, __riscv_vle8_v_i8mf4(q8+64, vl), vl); + vint16mf2_t va_q_3 = __riscv_vwmul_vv_i16mf2(q3_m3, __riscv_vle8_v_i8mf4(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m1(vaux_0, scale[0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m1(vaux_1, scale[2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m1(vaux_2, scale[4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m1(vaux_3, scale[6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_0, scale[1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_1, scale[3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_2, scale[5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_3, scale[7], va_q_3, vl); + + q3 += 32; q8 += 128; scale += 8; + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + sumf += d*sum_t; + } + + *s = sumf; +} #endif +void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q3_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q3_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_q3_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_q3_K_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_q3_K_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_q3_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif } -void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +static NOINLINE void ggml_vec_dot_q4_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -1135,8 +1650,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi uint32_t utmp[4]; -#if defined __riscv_xtheadvector - const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; @@ -1250,277 +1763,317 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + uint32_t utmp[4]; const uint8_t * scales = (const uint8_t*)&utmp[0]; const uint8_t * mins = (const uint8_t*)&utmp[2]; float sumf = 0; - const int vector_length = __riscv_vlenb() * 8; + for (int i = 0; i < nb; ++i) { + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + + float ftmp, ft2; + const uint8_t * restrict q40; + const uint8_t * restrict q41; + const uint8_t * restrict q42; + const uint8_t * restrict q43; + const int8_t * restrict q80; + const int8_t * restrict q81; + const int8_t * restrict q82; + const int8_t * restrict q83; + int s0, s1, s2, s3; + + __asm__ __volatile__( + "li %[s1], 8\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vle32.v v1, (%[s6b])\n\t" + "vslide1down.vx v1, v1, zero\n\t" + "vmv.v.x v16, zero\n\t" + "vslidedown.vi v2, v1, 2\n\t" + "vmv1r.v v3, v2\n\t" + "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} + "vsetivli zero, 2, e32, m1, ta, ma\n\t" + "vmv.v.i v4, 4\n\t" + "vand.vx v8, v1, %[kmask1]\n\t" + "vslide1up.vx v5, v4, zero\n\t" // {0, 4} + "vsrl.vi v6, v1, 6\n\t" + "vsrl.vv v7, v2, v5\n\t" + "vsse32.v v8, (%[utmp]), %[s1]\n\t" + "vand.vx v0, v6, %[kmask3]\n\t" + "vand.vx v2, v7, %[kmask2]\n\t" + "vsll.vi v6, v0, 4\n\t" + "addi %[s0], %[utmp], 4\n\t" + "vor.vv v1, v6, v2\n\t" + "vsse32.v v1, (%[s0]), %[s1]\n\t" + "vsetivli zero, 8, e16, m1, ta, ma\n\t" + "vle32.v v2, (%[bsums])\n\t" + "vnsrl.wi v0, v2, 0\n\t" + "vnsrl.wi v1, v2, 16\n\t" + "vadd.vv v2, v0, v1\n\t" + "vle8.v v3, (%[mins])\n\t" + "vzext.vf2 v4, v3\n\t" + "vwmul.vv v6, v4, v2\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vredsum.vs v0, v6, v16\n\t" + "vredsum.vs v0, v7, v0\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "vsetivli zero, 16, e8, m1, ta, ma\n\t" + "vle8.v v0, (%[xs])\n\t" + "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" + "addi %[q40], %[xs], 64\n\t" + "addi %[q41], %[xs], 16\n\t" + "addi %[q42], %[xs], 32\n\t" + "addi %[q43], %[xs], 48\n\t" + "addi %[q80], %[ys], 64\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "addi %[q81], %[ys], 16\n\t" + "addi %[q41], %[q41], 64\n\t" + "addi %[q82], %[ys], 32\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[ys])\n\t" + "addi %[q42], %[q42], 64\n\t" + "addi %[q83], %[ys], 48\n\t" + "addi %[q43], %[q43], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v5, v1, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q80])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vsrl.vi v6, v2, 4\n\t" + "addi %[q80], %[q80], 64\n\t" + "vle8.v v13, (%[q81])\n\t" + "vle8.v v14, (%[q82])\n\t" + "vand.vi v2, v2, 0xF\n\t" + "addi %[q81], %[q81], 64\n\t" + "vsrl.vi v7, v3, 4\n\t" + "addi %[q82], %[q82], 64\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vle8.v v15, (%[q83])\n\t" + "vle8.v v0, (%[q40])\n\t" + "vand.vi v3, v3, 0xF\n\t" + "addi %[q83], %[q83], 64\n\t" + "vwmul.vv v24, v2, v12\n\t" + "vwmul.vv v20, v4, v10\n\t" + "vwmul.vv v28, v6, v14\n\t" + "vwmacc.vv v16, v1, v9\n\t" + "vle8.v v1, (%[q41])\n\t" + "vle8.v v2, (%[q42])\n\t" + "vwmacc.vv v24, v3, v13\n\t" + "vwmacc.vv v20, v5, v11\n\t" + "vwmacc.vv v28, v7, v15\n\t" + "addi %[q40], %[q80], 64\n\t" + "addi %[q41], %[q81], 64\n\t" + "vle8.v v3, (%[q43])\n\t" + "vle8.v v8, (%[q80])\n\t" + "addi %[q42], %[q82], 64\n\t" + "addi %[q43], %[q83], 64\n\t" + "vsrl.vi v4, v0, 4\n\t" + "vle8.v v9, (%[q81])\n\t" + "vle8.v v10, (%[q82])\n\t" + "vand.vi v0, v0, 0xF\n\t" + "vsrl.vi v5, v1, 4\n\t" + "vsrl.vi v7, v3, 4\n\t" + "vand.vi v3, v3, 0xF\n\t" + "vle8.v v11, (%[q83])\n\t" + "vle8.v v12, (%[q40])\n\t" + "vand.vi v1, v1, 0xF\n\t" + "vsrl.vi v6, v2, 4\n\t" + "vand.vi v2, v2, 0xF\n\t" + "vwmul.vv v18, v0, v8\n\t" + "vle8.v v13, (%[q41])\n\t" + "vle8.v v14, (%[q42])\n\t" + "vwmul.vv v26, v2, v12\n\t" + "vwmul.vv v22, v4, v10\n\t" + "vwmul.vv v30, v6, v14\n\t" + "vwmacc.vv v18, v1, v9\n\t" + "vle8.v v15, (%[q43])\n\t" + "vwmacc.vv v26, v3, v13\n\t" + "vwmacc.vv v22, v5, v11\n\t" + "vwmacc.vv v30, v7, v15\n\t" + "vmv.v.x v0, zero\n\t" + "vsetivli zero, 16, e16, m2, ta, ma\n\t" + "vwredsum.vs v4, v16, v0\n\t" + "lbu %[s0], 0(%[scale])\n\t" + "vwredsum.vs v5, v20, v0\n\t" + "lbu %[s1], 1(%[scale])\n\t" + "vwredsum.vs v6, v24, v0\n\t" + "lbu %[s2], 2(%[scale])\n\t" + "vwredsum.vs v7, v28, v0\n\t" + "lbu %[s3], 3(%[scale])\n\t" + "vwredsum.vs v8, v18, v0\n\t" + "lbu %[q40], 4(%[scale])\n\t" + "vwredsum.vs v9, v22, v0\n\t" + "lbu %[q41], 5(%[scale])\n\t" + "vwredsum.vs v10, v26, v0\n\t" + "lbu %[q42], 6(%[scale])\n\t" + "vwredsum.vs v11, v30, v0\n\t" + "lbu %[q43], 7(%[scale])\n\t" + "vsetivli zero, 4, e32, m1, ta, ma\n\t" + "vmul.vx v0, v4, %[s0]\n\t" + "vmul.vx v1, v8, %[q40]\n\t" + "vmacc.vx v0, %[s1], v5\n\t" + "vmacc.vx v1, %[q41], v9\n\t" + "vmacc.vx v0, %[s2], v6\n\t" + "vmacc.vx v1, %[q42], v10\n\t" + "vmacc.vx v0, %[s3], v7\n\t" + "vmacc.vx v1, %[q43], v11\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfcvt.f.x.v v1, v1\n\t" + "vfmv.f.s %[ft2], v0\n\t" + "vfmv.f.s %[ftmp], v1\n\t" + "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" + "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" + : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) + , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) + , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) + , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) + : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) + , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) + , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) + , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + ); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q4_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + const block_q4_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; - size_t vl = 8; + const int nb = n / QK_K; - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; - vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); - vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); - vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); + uint32_t utmp[4]; - memcpy(utmp, x[i].scales, 12); - utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); - const uint32_t uaux = utmp[1] & kmask1; - utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); - utmp[2] = uaux; - utmp[0] &= kmask1; + const uint8_t * scales = (const uint8_t*)&utmp[0]; + const uint8_t * mins = (const uint8_t*)&utmp[2]; - vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); - vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); - vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + float sumf = 0; + for (int i = 0; i < nb; ++i) { + size_t vl = 8; - vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); - sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - const uint8_t * GGML_RESTRICT q4 = x[i].qs; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + vint16mf2_t q8sums_0 = __riscv_vlse16_v_i16mf2(y[i].bsums, 4, vl); + vint16mf2_t q8sums_1 = __riscv_vlse16_v_i16mf2(y[i].bsums+1, 4, vl); + vint16mf2_t q8sums = __riscv_vadd_vv_i16mf2(q8sums_0, q8sums_1, vl); - vl = 32; + memcpy(utmp, x[i].scales, 12); + utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4); + const uint32_t uaux = utmp[1] & kmask1; + utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4); + utmp[2] = uaux; + utmp[0] &= kmask1; + + vuint8mf4_t mins8 = __riscv_vle8_v_u8mf4(mins, vl); + vint16mf2_t v_mins = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vzext_vf2_u16mf2(mins8, vl)); + vint32m1_t prod = __riscv_vwmul_vv_i32m1(q8sums, v_mins, vl); + + vint32m1_t sumi = __riscv_vredsum_vs_i32m1_i32m1(prod, __riscv_vmv_v_x_i32m1(0, 1), vl); + sumf -= dmin * __riscv_vmv_x_s_i32m1_i32(sumi); - int32_t sum_1 = 0; - int32_t sum_2 = 0; + const uint8_t * GGML_RESTRICT q4 = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); + vl = 32; - for (int j = 0; j < QK_K/64; ++j) { - // load Q4 - vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); + int32_t sum_1 = 0; + int32_t sum_2 = 0; - // load Q8 and multiply it with lower Q4 nibble - vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); - vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); - vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); - vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); + vint16m1_t vzero = __riscv_vmv_v_x_i16m1(0, 1); - sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; + for (int j = 0; j < QK_K/64; ++j) { + // load Q4 + vuint8m1_t q4_x = __riscv_vle8_v_u8m1(q4, vl); - // load Q8 and multiply it with upper Q4 nibble - vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); - vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); - vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); - vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); + // load Q8 and multiply it with lower Q4 nibble + vint8m1_t q8_0 = __riscv_vle8_v_i8m1(q8, vl); + vint8m1_t q4_0 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vand_vx_u8m1(q4_x, 0x0F, vl)); + vint16m2_t qv_0 = __riscv_vwmul_vv_i16m2(q4_0, q8_0, vl); + vint16m1_t vs_0 = __riscv_vredsum_vs_i16m2_i16m1(qv_0, vzero, vl); - sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; + sum_1 += __riscv_vmv_x_s_i16m1_i16(vs_0) * scales[2*j+0]; - q4 += 32; q8 += 64; + // load Q8 and multiply it with upper Q4 nibble + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8+32, vl); + vint8m1_t q4_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vsrl_vx_u8m1(q4_x, 0x04, vl)); + vint16m2_t qv_1 = __riscv_vwmul_vv_i16m2(q4_1, q8_1, vl); + vint16m1_t vs_1 = __riscv_vredsum_vs_i16m2_i16m1(qv_1, vzero, vl); - } + sum_2 += __riscv_vmv_x_s_i16m1_i16(vs_1) * scales[2*j+1]; - sumf += d*(sum_1 + sum_2); + q4 += 32; q8 += 64; } - break; - case 128: - for (int i = 0; i < nb; ++i) { - const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); - const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); - float ftmp, ft2; - const uint8_t * restrict q40; - const uint8_t * restrict q41; - const uint8_t * restrict q42; - const uint8_t * restrict q43; - const int8_t * restrict q80; - const int8_t * restrict q81; - const int8_t * restrict q82; - const int8_t * restrict q83; - int s0, s1, s2, s3; + sumf += d*(sum_1 + sum_2); - __asm__ __volatile__( - "li %[s1], 8\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vle32.v v1, (%[s6b])\n\t" - "vslide1down.vx v1, v1, zero\n\t" - "vmv.v.x v16, zero\n\t" - "vslidedown.vi v2, v1, 2\n\t" - "vmv1r.v v3, v2\n\t" - "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} - "vsetivli zero, 2, e32, m1, ta, ma\n\t" - "vmv.v.i v4, 4\n\t" - "vand.vx v8, v1, %[kmask1]\n\t" - "vslide1up.vx v5, v4, zero\n\t" // {0, 4} - "vsrl.vi v6, v1, 6\n\t" - "vsrl.vv v7, v2, v5\n\t" - "vsse32.v v8, (%[utmp]), %[s1]\n\t" - "vand.vx v0, v6, %[kmask3]\n\t" - "vand.vx v2, v7, %[kmask2]\n\t" - "vsll.vi v6, v0, 4\n\t" - "addi %[s0], %[utmp], 4\n\t" - "vor.vv v1, v6, v2\n\t" - "vsse32.v v1, (%[s0]), %[s1]\n\t" - "vsetivli zero, 8, e16, m1, ta, ma\n\t" - "vle32.v v2, (%[bsums])\n\t" - "vnsrl.wi v0, v2, 0\n\t" - "vnsrl.wi v1, v2, 16\n\t" - "vadd.vv v2, v0, v1\n\t" - "vle8.v v3, (%[mins])\n\t" - "vzext.vf2 v4, v3\n\t" - "vwmul.vv v6, v4, v2\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vredsum.vs v0, v6, v16\n\t" - "vredsum.vs v0, v7, v0\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfmv.f.s %[ftmp], v0\n\t" - "vsetivli zero, 16, e8, m1, ta, ma\n\t" - "vle8.v v0, (%[xs])\n\t" - "fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t" - "addi %[q40], %[xs], 64\n\t" - "addi %[q41], %[xs], 16\n\t" - "addi %[q42], %[xs], 32\n\t" - "addi %[q43], %[xs], 48\n\t" - "addi %[q80], %[ys], 64\n\t" - "vle8.v v1, (%[q41])\n\t" - "vle8.v v2, (%[q42])\n\t" - "addi %[q81], %[ys], 16\n\t" - "addi %[q41], %[q41], 64\n\t" - "addi %[q82], %[ys], 32\n\t" - "vle8.v v3, (%[q43])\n\t" - "vle8.v v8, (%[ys])\n\t" - "addi %[q42], %[q42], 64\n\t" - "addi %[q83], %[ys], 48\n\t" - "addi %[q43], %[q43], 64\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vle8.v v9, (%[q81])\n\t" - "vle8.v v10, (%[q82])\n\t" - "vand.vi v0, v0, 0xF\n\t" - "addi %[q81], %[q81], 64\n\t" - "vsrl.vi v5, v1, 4\n\t" - "addi %[q82], %[q82], 64\n\t" - "vle8.v v11, (%[q83])\n\t" - "vle8.v v12, (%[q80])\n\t" - "vand.vi v1, v1, 0xF\n\t" - "addi %[q83], %[q83], 64\n\t" - "vsrl.vi v6, v2, 4\n\t" - "addi %[q80], %[q80], 64\n\t" - "vle8.v v13, (%[q81])\n\t" - "vle8.v v14, (%[q82])\n\t" - "vand.vi v2, v2, 0xF\n\t" - "addi %[q81], %[q81], 64\n\t" - "vsrl.vi v7, v3, 4\n\t" - "addi %[q82], %[q82], 64\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vle8.v v15, (%[q83])\n\t" - "vle8.v v0, (%[q40])\n\t" - "vand.vi v3, v3, 0xF\n\t" - "addi %[q83], %[q83], 64\n\t" - "vwmul.vv v24, v2, v12\n\t" - "vwmul.vv v20, v4, v10\n\t" - "vwmul.vv v28, v6, v14\n\t" - "vwmacc.vv v16, v1, v9\n\t" - "vle8.v v1, (%[q41])\n\t" - "vle8.v v2, (%[q42])\n\t" - "vwmacc.vv v24, v3, v13\n\t" - "vwmacc.vv v20, v5, v11\n\t" - "vwmacc.vv v28, v7, v15\n\t" - "addi %[q40], %[q80], 64\n\t" - "addi %[q41], %[q81], 64\n\t" - "vle8.v v3, (%[q43])\n\t" - "vle8.v v8, (%[q80])\n\t" - "addi %[q42], %[q82], 64\n\t" - "addi %[q43], %[q83], 64\n\t" - "vsrl.vi v4, v0, 4\n\t" - "vle8.v v9, (%[q81])\n\t" - "vle8.v v10, (%[q82])\n\t" - "vand.vi v0, v0, 0xF\n\t" - "vsrl.vi v5, v1, 4\n\t" - "vsrl.vi v7, v3, 4\n\t" - "vand.vi v3, v3, 0xF\n\t" - "vle8.v v11, (%[q83])\n\t" - "vle8.v v12, (%[q40])\n\t" - "vand.vi v1, v1, 0xF\n\t" - "vsrl.vi v6, v2, 4\n\t" - "vand.vi v2, v2, 0xF\n\t" - "vwmul.vv v18, v0, v8\n\t" - "vle8.v v13, (%[q41])\n\t" - "vle8.v v14, (%[q42])\n\t" - "vwmul.vv v26, v2, v12\n\t" - "vwmul.vv v22, v4, v10\n\t" - "vwmul.vv v30, v6, v14\n\t" - "vwmacc.vv v18, v1, v9\n\t" - "vle8.v v15, (%[q43])\n\t" - "vwmacc.vv v26, v3, v13\n\t" - "vwmacc.vv v22, v5, v11\n\t" - "vwmacc.vv v30, v7, v15\n\t" - "vmv.v.x v0, zero\n\t" - "vsetivli zero, 16, e16, m2, ta, ma\n\t" - "vwredsum.vs v4, v16, v0\n\t" - "lbu %[s0], 0(%[scale])\n\t" - "vwredsum.vs v5, v20, v0\n\t" - "lbu %[s1], 1(%[scale])\n\t" - "vwredsum.vs v6, v24, v0\n\t" - "lbu %[s2], 2(%[scale])\n\t" - "vwredsum.vs v7, v28, v0\n\t" - "lbu %[s3], 3(%[scale])\n\t" - "vwredsum.vs v8, v18, v0\n\t" - "lbu %[q40], 4(%[scale])\n\t" - "vwredsum.vs v9, v22, v0\n\t" - "lbu %[q41], 5(%[scale])\n\t" - "vwredsum.vs v10, v26, v0\n\t" - "lbu %[q42], 6(%[scale])\n\t" - "vwredsum.vs v11, v30, v0\n\t" - "lbu %[q43], 7(%[scale])\n\t" - "vsetivli zero, 4, e32, m1, ta, ma\n\t" - "vmul.vx v0, v4, %[s0]\n\t" - "vmul.vx v1, v8, %[q40]\n\t" - "vmacc.vx v0, %[s1], v5\n\t" - "vmacc.vx v1, %[q41], v9\n\t" - "vmacc.vx v0, %[s2], v6\n\t" - "vmacc.vx v1, %[q42], v10\n\t" - "vmacc.vx v0, %[s3], v7\n\t" - "vmacc.vx v1, %[q43], v11\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfcvt.f.x.v v1, v1\n\t" - "vfmv.f.s %[ft2], v0\n\t" - "vfmv.f.s %[ftmp], v1\n\t" - "fadd.s %[ft2], %[ft2], %[ftmp]\n\t" - "fmadd.s %[sumf], %[d], %[ft2], %[sumf]" - : [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2) - , [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3) - , [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43) - , [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83) - : [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales) - , [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) - , [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin) - , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - ); - } - break; - default: - assert(false && "Unsupported vector length"); - break; } *s = sumf; +} +#endif +void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q4_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q4_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 256 and above + ggml_vec_dot_q4_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } #else - - UNUSED(x); - UNUSED(y); - UNUSED(kmask1); - UNUSED(kmask2); - UNUSED(kmask3); - UNUSED(nb); - UNUSED(utmp); - ggml_vec_dot_q4_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } @@ -1621,7 +2174,6 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi aux32 += __riscv_vmv_x_s_i32m1_i32(vacc2); q5 += 32; q8 += 64; - } sums += aux32 * d; @@ -1644,7 +2196,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #endif } -void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector +static NOINLINE void ggml_vec_dot_q6_K_q8_K_xtheadvector(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); UNUSED(nrc); @@ -1657,8 +2210,6 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int nb = n / QK_K; -#if defined __riscv_xtheadvector - float sumf = 0; for (int i = 0; i < nb; ++i) { @@ -1737,220 +2288,4309 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } *s = sumf; +} +#endif -#elif defined __riscv_v +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + __builtin_prefetch(&x[i + 1].d, 0, 1); + + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * restrict q6 = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict q8 = y[i].qs; + + const int8_t * restrict scale = x[i].scales; + + int q6h; + float ftmp; + + for (int j = 0; j < QK_K/128; ++j) { + __asm__ __volatile__( + "addi %[q6h], %[q6], 32\n\t" + "ld t0, 0(%[scale])\n\t" + "addi %[scale], %[scale], 8\n\t" + "slli t6, t0, 1 * 8\n\t" + "lb zero, 0(%[q6])\n\t" + "slli t5, t0, 2 * 8\n\t" + "slli t4, t0, 3 * 8\n\t" + "lb zero, 0(%[q6h])\n\t" + "slli t3, t0, 4 * 8\n\t" + "slli t2, t0, 5 * 8\n\t" + "lb zero, 0(%[qh])\n\t" + "lb zero, 31(%[q6h])\n\t" + "slli t1, t0, 6 * 8\n\t" + "srai a7, t0, 56\n\t" + "vsetvli zero, %[vl32], e8, m2\n\t" + "vle8.v v8, (%[q6])\n\t" + "srai t6, t6, 56\n\t" + "srai t5, t5, 56\n\t" + "srai t4, t4, 56\n\t" + "srai t3, t3, 56\n\t" + "vle8.v v10, (%[q6h])\n\t" + "addi %[q6], %[q6], 64\n\t" + "slli t0, t0, 7 * 8\n\t" + "srai t2, t2, 56\n\t" + "srai t1, t1, 56\n\t" + "srai t0, t0, 56\n\t" + "vle8.v v4, (%[qh])\n\t" + "vsrl.vi v12, v8, 4\n\t" + "vsrl.vi v14, v10, 4\n\t" + "lb zero, 0(%[q8])\n\t" + "vand.vi v8, v8, 0xF\n\t" + "vand.vi v10, v10, 0xF\n\t" + "lb zero, 32(%[q8])\n\t" + "vsll.vi v0, v4, 4\n\t" + "vsll.vi v2, v4, 2\n\t" + "lb zero, 64(%[q8])\n\t" + "vsrl.vi v6, v4, 2\n\t" + "vand.vx v0, v0, %[mask]\n\t" + "lb zero, 96(%[q8])\n\t" + "vand.vx v2, v2, %[mask]\n\t" + "vand.vx v4, v4, %[mask]\n\t" + "vand.vx v6, v6, %[mask]\n\t" + "vor.vv v8, v8, v0\n\t" + "lb zero, 127(%[q8])\n\t" + "vor.vv v10, v10, v2\n\t" + "vor.vv v12, v12, v4\n\t" + "vor.vv v14, v14, v6\n\t" + "vsetvli zero, %[vl128], e8, m8\n\t" + "vle8.v v0, (%[q8])\n\t" + "vsub.vx v8, v8, %[vl32]\n\t" + "vsetvli zero, %[vl64], e8, m4\n\t" + "vwmul.vv v16, v0, v8\n\t" + "vwmul.vv v24, v4, v12\n\t" + "vsetivli zero, 16, e16, m2\n\t" + "vmv.v.x v0, zero\n\t" + "vwredsum.vs v10, v16, v0\n\t" + "vwredsum.vs v9, v18, v0\n\t" + "vwredsum.vs v8, v20, v0\n\t" + "vwredsum.vs v7, v22, v0\n\t" + "vwredsum.vs v11, v24, v0\n\t" + "vwredsum.vs v12, v26, v0\n\t" + "vwredsum.vs v13, v28, v0\n\t" + "vwredsum.vs v14, v30, v0\n\t" + "vsetivli zero, 4, e32, m1\n\t" + "vmul.vx v0, v10, t0\n\t" + "vmul.vx v1, v9, t1\n\t" + "vmacc.vx v0, t2, v8\n\t" + "vmacc.vx v1, t3, v7\n\t" + "vmacc.vx v0, t4, v11\n\t" + "vmacc.vx v1, t5, v12\n\t" + "vmacc.vx v0, t6, v13\n\t" + "vmacc.vx v1, a7, v14\n\t" + "vadd.vv v0, v0, v1\n\t" + "vfcvt.f.x.v v0, v0\n\t" + "vfmv.f.s %[ftmp], v0\n\t" + "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" + : [q6] "+&r" (q6), [q6h] "=&r" (q6h) + , [scale] "+&r" (scale) + , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) + : [qh] "r" (qh), [q8] "r" (q8) + , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) + , [mask] "r" (0x30), [d] "f" (d) + : "memory" + , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" + , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" + , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" + , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" + , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" + , "a6", "a5", "a4", "a3" + ); + qh += 32; q8 += 128; + } + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; float sumf = 0; - const int vector_length = __riscv_vlenb() * 8; + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; - switch (vector_length) { - case 256: - for (int i = 0; i < nb; ++i) { + const int8_t * GGML_RESTRICT scale = x[i].scales; - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + size_t vl; - const uint8_t * GGML_RESTRICT q6 = x[i].ql; - const uint8_t * GGML_RESTRICT qh = x[i].qh; - const int8_t * GGML_RESTRICT q8 = y[i].qs; + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); - const int8_t * GGML_RESTRICT scale = x[i].scales; + int sum_t = 0; + int is = 0; - size_t vl; + for (int j = 0; j < QK_K/128; ++j) { - vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + vl = 32; - int sum_t = 0; - int is = 0; + // load qh + vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); - for (int j = 0; j < QK_K/128; ++j) { + // load Q6 + vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); + vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); - vl = 32; + vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); + vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); + vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); + vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); - // load qh - vuint8m1_t qh_x = __riscv_vle8_v_u8m1(qh, vl); + vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); + vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); + vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); + vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); - // load Q6 - vuint8m1_t q6_0 = __riscv_vle8_v_u8m1(q6, vl); - vuint8m1_t q6_1 = __riscv_vle8_v_u8m1(q6+32, vl); + vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); + vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); + vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); + vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); - vuint8m1_t q6a_0 = __riscv_vand_vx_u8m1(q6_0, 0x0F, vl); - vuint8m1_t q6a_1 = __riscv_vand_vx_u8m1(q6_1, 0x0F, vl); - vuint8m1_t q6s_0 = __riscv_vsrl_vx_u8m1(q6_0, 0x04, vl); - vuint8m1_t q6s_1 = __riscv_vsrl_vx_u8m1(q6_1, 0x04, vl); + vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); + vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); + vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); + vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); - vuint8m1_t qh_0 = __riscv_vand_vx_u8m1(qh_x, 0x03, vl); - vuint8m1_t qh_1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x2, vl), 0x03 , vl); - vuint8m1_t qh_2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x4, vl), 0x03 , vl); - vuint8m1_t qh_3 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(qh_x, 0x6, vl), 0x03 , vl); + // load Q8 and take product + vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); + vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); + vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); + vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); - vuint8m1_t qhi_0 = __riscv_vor_vv_u8m1(q6a_0, __riscv_vsll_vx_u8m1(qh_0, 0x04, vl), vl); - vuint8m1_t qhi_1 = __riscv_vor_vv_u8m1(q6a_1, __riscv_vsll_vx_u8m1(qh_1, 0x04, vl), vl); - vuint8m1_t qhi_2 = __riscv_vor_vv_u8m1(q6s_0, __riscv_vsll_vx_u8m1(qh_2, 0x04, vl), vl); - vuint8m1_t qhi_3 = __riscv_vor_vv_u8m1(q6s_1, __riscv_vsll_vx_u8m1(qh_3, 0x04, vl), vl); + vl = 16; - vint8m1_t a_0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_0), 32, vl); - vint8m1_t a_1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_1), 32, vl); - vint8m1_t a_2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_2), 32, vl); - vint8m1_t a_3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(qhi_3), 32, vl); + vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); + vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); + vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); + vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); + vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); + vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); + vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); + vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); - // load Q8 and take product - vint16m2_t va_q_0 = __riscv_vwmul_vv_i16m2(a_0, __riscv_vle8_v_i8m1(q8, vl), vl); - vint16m2_t va_q_1 = __riscv_vwmul_vv_i16m2(a_1, __riscv_vle8_v_i8m1(q8+32, vl), vl); - vint16m2_t va_q_2 = __riscv_vwmul_vv_i16m2(a_2, __riscv_vle8_v_i8m1(q8+64, vl), vl); - vint16m2_t va_q_3 = __riscv_vwmul_vv_i16m2(a_3, __riscv_vle8_v_i8m1(q8+96, vl), vl); + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); + vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); - vl = 16; + sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); - vint32m2_t vaux_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 0), scale[is+0], vl); - vint32m2_t vaux_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_0, 1), scale[is+1], vl); - vint32m2_t vaux_2 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 0), scale[is+2], vl); - vint32m2_t vaux_3 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_1, 1), scale[is+3], vl); - vint32m2_t vaux_4 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 0), scale[is+4], vl); - vint32m2_t vaux_5 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_2, 1), scale[is+5], vl); - vint32m2_t vaux_6 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 0), scale[is+6], vl); - vint32m2_t vaux_7 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(va_q_3, 1), scale[is+7], vl); + q6 += 64; qh += 32; q8 += 128; is=8; - vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); - vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); - vint32m1_t isum2 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_4, vaux_5, vl), isum1, vl); - vint32m1_t isum3 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_6, vaux_7, vl), isum2, vl); + } - sum_t += __riscv_vmv_x_s_i32m1_i32(isum3); + sumf += d * sum_t; - q6 += 64; qh += 32; q8 += 128; is=8; + } - } + *s = sumf; +} - sumf += d * sum_t; - - } - break; - case 128: - for (int i = 0; i < nb; ++i) { - - __builtin_prefetch(&x[i + 1].d, 0, 1); - - const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; - - const uint8_t * restrict q6 = x[i].ql; - const uint8_t * restrict qh = x[i].qh; - const int8_t * restrict q8 = y[i].qs; - - const int8_t * restrict scale = x[i].scales; - - int q6h; - float ftmp; - - for (int j = 0; j < QK_K/128; ++j) { - __asm__ __volatile__( - "addi %[q6h], %[q6], 32\n\t" - "ld t0, 0(%[scale])\n\t" - "addi %[scale], %[scale], 8\n\t" - "slli t6, t0, 1 * 8\n\t" - "lb zero, 0(%[q6])\n\t" - "slli t5, t0, 2 * 8\n\t" - "slli t4, t0, 3 * 8\n\t" - "lb zero, 0(%[q6h])\n\t" - "slli t3, t0, 4 * 8\n\t" - "slli t2, t0, 5 * 8\n\t" - "lb zero, 0(%[qh])\n\t" - "lb zero, 31(%[q6h])\n\t" - "slli t1, t0, 6 * 8\n\t" - "srai a7, t0, 56\n\t" - "vsetvli zero, %[vl32], e8, m2\n\t" - "vle8.v v8, (%[q6])\n\t" - "srai t6, t6, 56\n\t" - "srai t5, t5, 56\n\t" - "srai t4, t4, 56\n\t" - "srai t3, t3, 56\n\t" - "vle8.v v10, (%[q6h])\n\t" - "addi %[q6], %[q6], 64\n\t" - "slli t0, t0, 7 * 8\n\t" - "srai t2, t2, 56\n\t" - "srai t1, t1, 56\n\t" - "srai t0, t0, 56\n\t" - "vle8.v v4, (%[qh])\n\t" - "vsrl.vi v12, v8, 4\n\t" - "vsrl.vi v14, v10, 4\n\t" - "lb zero, 0(%[q8])\n\t" - "vand.vi v8, v8, 0xF\n\t" - "vand.vi v10, v10, 0xF\n\t" - "lb zero, 32(%[q8])\n\t" - "vsll.vi v0, v4, 4\n\t" - "vsll.vi v2, v4, 2\n\t" - "lb zero, 64(%[q8])\n\t" - "vsrl.vi v6, v4, 2\n\t" - "vand.vx v0, v0, %[mask]\n\t" - "lb zero, 96(%[q8])\n\t" - "vand.vx v2, v2, %[mask]\n\t" - "vand.vx v4, v4, %[mask]\n\t" - "vand.vx v6, v6, %[mask]\n\t" - "vor.vv v8, v8, v0\n\t" - "lb zero, 127(%[q8])\n\t" - "vor.vv v10, v10, v2\n\t" - "vor.vv v12, v12, v4\n\t" - "vor.vv v14, v14, v6\n\t" - "vsetvli zero, %[vl128], e8, m8\n\t" - "vle8.v v0, (%[q8])\n\t" - "vsub.vx v8, v8, %[vl32]\n\t" - "vsetvli zero, %[vl64], e8, m4\n\t" - "vwmul.vv v16, v0, v8\n\t" - "vwmul.vv v24, v4, v12\n\t" - "vsetivli zero, 16, e16, m2\n\t" - "vmv.v.x v0, zero\n\t" - "vwredsum.vs v10, v16, v0\n\t" - "vwredsum.vs v9, v18, v0\n\t" - "vwredsum.vs v8, v20, v0\n\t" - "vwredsum.vs v7, v22, v0\n\t" - "vwredsum.vs v11, v24, v0\n\t" - "vwredsum.vs v12, v26, v0\n\t" - "vwredsum.vs v13, v28, v0\n\t" - "vwredsum.vs v14, v30, v0\n\t" - "vsetivli zero, 4, e32, m1\n\t" - "vmul.vx v0, v10, t0\n\t" - "vmul.vx v1, v9, t1\n\t" - "vmacc.vx v0, t2, v8\n\t" - "vmacc.vx v1, t3, v7\n\t" - "vmacc.vx v0, t4, v11\n\t" - "vmacc.vx v1, t5, v12\n\t" - "vmacc.vx v0, t6, v13\n\t" - "vmacc.vx v1, a7, v14\n\t" - "vadd.vv v0, v0, v1\n\t" - "vfcvt.f.x.v v0, v0\n\t" - "vfmv.f.s %[ftmp], v0\n\t" - "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]" - : [q6] "+&r" (q6), [q6h] "=&r" (q6h) - , [scale] "+&r" (scale) - , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp) - : [qh] "r" (qh), [q8] "r" (q8) - , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) - , [mask] "r" (0x30), [d] "f" (d) - : "memory" - , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" - , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" - , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" - , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" - , "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7" - , "a6", "a5", "a4", "a3" - ); - qh += 32; q8 += 128; - } +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16m1_t va_index = __riscv_vid_v_u16m1(32); + const vbool16_t va_mask = __riscv_vmsgtu_vx_u16m1_b16(va_index, 15, 32); + + float sumf = 0; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + size_t vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + vint32m2_t vaux_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t vaux_3 = __riscv_vmv_v_x_i32m2(0, vl); + + for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8mf2_t qh_x = __riscv_vle8_v_u8mf2(qh, vl); + + // load Q6 + vuint8mf2_t q6_0 = __riscv_vle8_v_u8mf2(q6, vl); + vuint8mf2_t q6_1 = __riscv_vle8_v_u8mf2(q6+32, vl); + + vuint8mf2_t q6a_0 = __riscv_vand_vx_u8mf2(q6_0, 0x0F, vl); + vuint8mf2_t q6a_1 = __riscv_vand_vx_u8mf2(q6_1, 0x0F, vl); + vuint8mf2_t q6s_0 = __riscv_vsrl_vx_u8mf2(q6_0, 0x04, vl); + vuint8mf2_t q6s_1 = __riscv_vsrl_vx_u8mf2(q6_1, 0x04, vl); + + vuint8mf2_t qh_0 = __riscv_vand_vx_u8mf2(qh_x, 0x03, vl); + vuint8mf2_t qh_1 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x2, vl), 0x03 , vl); + vuint8mf2_t qh_2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x4, vl), 0x03 , vl); + vuint8mf2_t qh_3 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(qh_x, 0x6, vl), 0x03 , vl); + + vuint8mf2_t qhi_0 = __riscv_vor_vv_u8mf2(q6a_0, __riscv_vsll_vx_u8mf2(qh_0, 0x04, vl), vl); + vuint8mf2_t qhi_1 = __riscv_vor_vv_u8mf2(q6a_1, __riscv_vsll_vx_u8mf2(qh_1, 0x04, vl), vl); + vuint8mf2_t qhi_2 = __riscv_vor_vv_u8mf2(q6s_0, __riscv_vsll_vx_u8mf2(qh_2, 0x04, vl), vl); + vuint8mf2_t qhi_3 = __riscv_vor_vv_u8mf2(q6s_1, __riscv_vsll_vx_u8mf2(qh_3, 0x04, vl), vl); + + vint8mf2_t a_0 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_0), 32, vl); + vint8mf2_t a_1 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_1), 32, vl); + vint8mf2_t a_2 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_2), 32, vl); + vint8mf2_t a_3 = __riscv_vsub_vx_i8mf2(__riscv_vreinterpret_v_u8mf2_i8mf2(qhi_3), 32, vl); + + // load Q8 and take product + vint16m1_t va_q_0 = __riscv_vwmul_vv_i16m1(a_0, __riscv_vle8_v_i8mf2(q8, vl), vl); + vint16m1_t va_q_1 = __riscv_vwmul_vv_i16m1(a_1, __riscv_vle8_v_i8mf2(q8+32, vl), vl); + vint16m1_t va_q_2 = __riscv_vwmul_vv_i16m1(a_2, __riscv_vle8_v_i8mf2(q8+64, vl), vl); + vint16m1_t va_q_3 = __riscv_vwmul_vv_i16m1(a_3, __riscv_vle8_v_i8mf2(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m2(vaux_0, scale[is+0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m2(vaux_1, scale[is+2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m2(vaux_2, scale[is+4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m2(vaux_3, scale[is+6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_0, scale[is+1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_1, scale[is+3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_2, scale[is+5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m2_m(va_mask, vaux_3, scale[is+7], va_q_3, vl); + + q6 += 64; qh += 32; q8 += 128; is=8; } - break; - default: - assert(false && "Unsupported vector length"); - break; + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m2_i32m1(__riscv_vadd_vv_i32m2(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + sumf += d * sum_t; + } *s = sumf; +} -#else +static NOINLINE void ggml_vec_dot_q6_K_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); - UNUSED(x); - UNUSED(y); - UNUSED(nb); + const block_q6_K * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // mask for processing 16 elements per prod register + const vuint16mf2_t va_index = __riscv_vid_v_u16mf2(32); + const vbool32_t va_mask = __riscv_vmsgtu_vx_u16mf2_b32(va_index, 15, 32); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q6 = x[i].ql; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const int8_t * GGML_RESTRICT scale = x[i].scales; + + size_t vl = 32; + + vint32m1_t vzero = __riscv_vmv_v_x_i32m1(0, 1); + + int sum_t = 0; + int is = 0; + + vint32m1_t vaux_0 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_1 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_2 = __riscv_vmv_v_x_i32m1(0, vl); + vint32m1_t vaux_3 = __riscv_vmv_v_x_i32m1(0, vl); + + for (int j = 0; j < QK_K/128; ++j) { + // load qh + vuint8mf4_t qh_x = __riscv_vle8_v_u8mf4(qh, vl); + + // load Q6 + vuint8mf4_t q6_0 = __riscv_vle8_v_u8mf4(q6, vl); + vuint8mf4_t q6_1 = __riscv_vle8_v_u8mf4(q6+32, vl); + + vuint8mf4_t q6a_0 = __riscv_vand_vx_u8mf4(q6_0, 0x0F, vl); + vuint8mf4_t q6a_1 = __riscv_vand_vx_u8mf4(q6_1, 0x0F, vl); + vuint8mf4_t q6s_0 = __riscv_vsrl_vx_u8mf4(q6_0, 0x04, vl); + vuint8mf4_t q6s_1 = __riscv_vsrl_vx_u8mf4(q6_1, 0x04, vl); + + vuint8mf4_t qh_0 = __riscv_vand_vx_u8mf4(qh_x, 0x03, vl); + vuint8mf4_t qh_1 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x2, vl), 0x03 , vl); + vuint8mf4_t qh_2 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x4, vl), 0x03 , vl); + vuint8mf4_t qh_3 = __riscv_vand_vx_u8mf4(__riscv_vsrl_vx_u8mf4(qh_x, 0x6, vl), 0x03 , vl); + + vuint8mf4_t qhi_0 = __riscv_vor_vv_u8mf4(q6a_0, __riscv_vsll_vx_u8mf4(qh_0, 0x04, vl), vl); + vuint8mf4_t qhi_1 = __riscv_vor_vv_u8mf4(q6a_1, __riscv_vsll_vx_u8mf4(qh_1, 0x04, vl), vl); + vuint8mf4_t qhi_2 = __riscv_vor_vv_u8mf4(q6s_0, __riscv_vsll_vx_u8mf4(qh_2, 0x04, vl), vl); + vuint8mf4_t qhi_3 = __riscv_vor_vv_u8mf4(q6s_1, __riscv_vsll_vx_u8mf4(qh_3, 0x04, vl), vl); + + vint8mf4_t a_0 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_0), 32, vl); + vint8mf4_t a_1 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_1), 32, vl); + vint8mf4_t a_2 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_2), 32, vl); + vint8mf4_t a_3 = __riscv_vsub_vx_i8mf4(__riscv_vreinterpret_v_u8mf4_i8mf4(qhi_3), 32, vl); + + // load Q8 and take product + vint16mf2_t va_q_0 = __riscv_vwmul_vv_i16mf2(a_0, __riscv_vle8_v_i8mf4(q8, vl), vl); + vint16mf2_t va_q_1 = __riscv_vwmul_vv_i16mf2(a_1, __riscv_vle8_v_i8mf4(q8+32, vl), vl); + vint16mf2_t va_q_2 = __riscv_vwmul_vv_i16mf2(a_2, __riscv_vle8_v_i8mf4(q8+64, vl), vl); + vint16mf2_t va_q_3 = __riscv_vwmul_vv_i16mf2(a_3, __riscv_vle8_v_i8mf4(q8+96, vl), vl); + + // accumulate + vaux_0 = __riscv_vwmacc_vx_i32m1(vaux_0, scale[is+0], va_q_0, 16); + vaux_1 = __riscv_vwmacc_vx_i32m1(vaux_1, scale[is+2], va_q_1, 16); + vaux_2 = __riscv_vwmacc_vx_i32m1(vaux_2, scale[is+4], va_q_2, 16); + vaux_3 = __riscv_vwmacc_vx_i32m1(vaux_3, scale[is+6], va_q_3, 16); + // + vaux_0 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_0, scale[is+1], va_q_0, vl); + vaux_1 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_1, scale[is+3], va_q_1, vl); + vaux_2 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_2, scale[is+5], va_q_2, vl); + vaux_3 = __riscv_vwmacc_vx_i32m1_m(va_mask, vaux_3, scale[is+7], va_q_3, vl); + + q6 += 64; qh += 32; q8 += 128; is=8; + + } + + vint32m1_t isum0 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_0, vaux_1, vl), vzero, vl); + vint32m1_t isum1 = __riscv_vredsum_vs_i32m1_i32m1(__riscv_vadd_vv_i32m1(vaux_2, vaux_3, vl), isum0, vl); + + sum_t += __riscv_vmv_x_s_i32m1_i32(isum1); + + sumf += d * sum_t; + + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_xtheadvector + ggml_vec_dot_q6_K_q8_K_xtheadvector(n, s, bs, vx, bx, vy, by, nrc); +#elif defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_q6_K_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_q6_K_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_q6_K_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_q6_K_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16m1_t qh = __riscv_vle16_v_u16m1(x[i].qh, 8); + + // Calculate ls. + vuint16m1_t temp = __riscv_vsrl_vx_u16m1(qh, 12, 8); + temp = __riscv_vand_vx_u16m1(temp, 7, 8); + vint32m2_t ls = __riscv_vreinterpret_v_u32m2_i32m2(__riscv_vwmulu_vx_u32m2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32m2(ls, 1, 8); + + // Calculate delta. + vbool16_t mask = __riscv_vmseq_vx_u16m1_b16(__riscv_vand_vx_u16m1(qh, 0x8000, 8), 0, 8); + vint32m2_t delta_neg = __riscv_vmv_v_x_i32m2(-1, 8); + vint32m2_t delta_pos = __riscv_vmv_v_x_i32m2(1, 8); + vint32m2_t delta = __riscv_vmerge_vvm_i32m2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8m2_t qs = __riscv_vle8_v_u8m2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m4_t qh_shift = __riscv_vreinterpret_v_u64m4_u16m4(__riscv_vmv_v_x_u64m4(shift, 8)); + vuint16m4_t qh_gather_index = __riscv_vreinterpret_v_i16m4_u16m4( + __riscv_vdiv_vx_i16m4(__riscv_vreinterpret_v_u16m4_i16m4(__riscv_vid_v_u16m4(32)), 4, 32)); + vuint16m4_t qh_ext = __riscv_vlmul_ext_v_u16m2_u16m4(__riscv_vlmul_ext_v_u16m1_u16m2(qh)); + vuint16m4_t qh_index = __riscv_vrgather_vv_u16m4(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m4(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m4(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m4(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m4(qh_index, __riscv_vzext_vf2_u16m4(qs, 32), 32); + vuint16m4_t index = __riscv_vsll_vx_u16m4(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-2 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 0); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[0], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 3-4 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 1); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[64], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 5-6 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 2); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[128], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 7-8 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m4_u16m1(index, 3); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 8)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(&y[i].qs[192], 64); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 0), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(lsum0, 1), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32m2_t lsums = __riscv_vle32_v_i32m2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16m2_t bsums_0 = __riscv_vle16_v_i16m2(y[i].bsums, 16); + const vuint32m2_t bsums_i32 = __riscv_vreinterpret_v_u16m2_u32m2(__riscv_vreinterpret_v_i16m2_u16m2(bsums_0)); + const vint16m1_t bsums_i32_0 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(bsums_i32, 0, 8)); + const vint16m1_t bsums_i32_1 = __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vnsrl_wx_u16m1(bsums_i32, 16, 8)); + const vint32m2_t bsums = __riscv_vwadd_vv_i32m2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32m2_t sumi_v = __riscv_vmul_vv_i32m2(ls, lsums, 8); + vint32m2_t sumi1_v = __riscv_vmul_vv_i32m2(__riscv_vmul_vv_i32m2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf2_t qh = __riscv_vle16_v_u16mf2(x[i].qh, 8); + + // Calculate ls. + vuint16mf2_t temp = __riscv_vsrl_vx_u16mf2(qh, 12, 8); + temp = __riscv_vand_vx_u16mf2(temp, 7, 8); + vint32m1_t ls = __riscv_vreinterpret_v_u32m1_i32m1(__riscv_vwmulu_vx_u32m1(temp, 2, 8)); + ls = __riscv_vadd_vx_i32m1(ls, 1, 8); + + // Calculate delta. + vbool32_t mask = __riscv_vmseq_vx_u16mf2_b32(__riscv_vand_vx_u16mf2(qh, 0x8000, 8), 0, 8); + vint32m1_t delta_neg = __riscv_vmv_v_x_i32m1(-1, 8); + vint32m1_t delta_pos = __riscv_vmv_v_x_i32m1(1, 8); + vint32m1_t delta = __riscv_vmerge_vvm_i32m1(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8m1_t qs = __riscv_vle8_v_u8m1(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m2_t qh_shift = __riscv_vreinterpret_v_u64m2_u16m2(__riscv_vmv_v_x_u64m2(shift, 8)); + vuint16m2_t qh_gather_index = __riscv_vreinterpret_v_i16m2_u16m2( + __riscv_vdiv_vx_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vid_v_u16m2(32)), 4, 32)); + vuint16m2_t qh_ext = __riscv_vlmul_ext_v_u16m1_u16m2(__riscv_vlmul_ext_v_u16mf2_u16m1(qh)); + vuint16m2_t qh_index = __riscv_vrgather_vv_u16m2(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m2(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m2(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m2(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m2(qh_index, __riscv_vzext_vf2_u16m2(qs, 32), 32); + vuint16m2_t index = __riscv_vsll_vx_u16m2(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-4 + { + vuint16m1_t grid_index0 = __riscv_vget_v_u16m2_u16m1(index, 0); + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index0, 16)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 128); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 128); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 1), one_scalar, 32)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 2), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum0, 3), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 5-8 + { + vuint16m1_t grid_index1 = __riscv_vget_v_u16m2_u16m1(index, 1); + vint8m4_t grid1 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, grid_index1, 16)); + vint8m4_t q81 = __riscv_vle8_v_i8m4(&y[i].qs[128], 128); + vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(grid1, q81, 128); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 0), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 1), one_scalar, 32)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 2), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(lsum1, 3), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32m1_t lsums = __riscv_vle32_v_i32m1(&lsums_s[0], 8); + + // Calculate the bsums. + vint16m1_t bsums_0 = __riscv_vle16_v_i16m1(y[i].bsums, 16); + const vuint32m1_t bsums_i32 = __riscv_vreinterpret_v_u16m1_u32m1(__riscv_vreinterpret_v_i16m1_u16m1(bsums_0)); + const vint16mf2_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 0, 8)); + const vint16mf2_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vnsrl_wx_u16mf2(bsums_i32, 16, 8)); + const vint32m1_t bsums = __riscv_vwadd_vv_i32m1(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32m1_t sumi_v = __riscv_vmul_vv_i32m1(ls, lsums, 8); + vint32m1_t sumi1_v = __riscv_vmul_vv_i32m1(__riscv_vmul_vv_i32m1(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m1_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf4_t qh = __riscv_vle16_v_u16mf4(x[i].qh, 8); + + // Calculate ls. + vuint16mf4_t temp = __riscv_vsrl_vx_u16mf4(qh, 12, 8); + temp = __riscv_vand_vx_u16mf4(temp, 7, 8); + vint32mf2_t ls = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vwmulu_vx_u32mf2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32mf2(ls, 1, 8); + + // Calculate delta. + vbool64_t mask = __riscv_vmseq_vx_u16mf4_b64(__riscv_vand_vx_u16mf4(qh, 0x8000, 8), 0, 8); + vint32mf2_t delta_neg = __riscv_vmv_v_x_i32mf2(-1, 8); + vint32mf2_t delta_pos = __riscv_vmv_v_x_i32mf2(1, 8); + vint32mf2_t delta = __riscv_vmerge_vvm_i32mf2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8mf2_t qs = __riscv_vle8_v_u8mf2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m1_t qh_shift = __riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(shift, 8)); + vuint16m1_t qh_gather_index = __riscv_vreinterpret_v_i16m1_u16m1( + __riscv_vdiv_vx_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(32)), 4, 32)); + vuint16m1_t qh_ext = __riscv_vlmul_ext_v_u16mf2_u16m1(__riscv_vlmul_ext_v_u16mf4_u16mf2(qh)); + vuint16m1_t qh_index = __riscv_vrgather_vv_u16m1(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m1(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m1(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m1(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m1(qh_index, __riscv_vzext_vf2_u16m1(qs, 32), 32); + vuint16m1_t index = __riscv_vsll_vx_u16m1(qh_index, 3, 32); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-8 + { + vint8m4_t grid0 = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vluxei16_v_i64m4((const int64_t*)iq1s_grid, index, 32)); + vint8m4_t q80 = __riscv_vle8_v_i8m4(y[i].qs, 256); + vint16m8_t lsum0 = __riscv_vwmul_vv_i16m8(grid0, q80, 256); + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 1), one_scalar, 32)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 2), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 3), one_scalar, 32)); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 4), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 5), one_scalar, 32)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 6), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(lsum0, 7), one_scalar, 32)); + } + __asm__ __volatile__("" ::: "memory"); + vint32mf2_t lsums = __riscv_vle32_v_i32mf2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16mf2_t bsums_0 = __riscv_vle16_v_i16mf2(y[i].bsums, 16); + const vuint32mf2_t bsums_i32 = __riscv_vreinterpret_v_u16mf2_u32mf2(__riscv_vreinterpret_v_i16mf2_u16mf2(bsums_0)); + const vint16mf4_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 0, 8)); + const vint16mf4_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 16, 8)); + const vint32mf2_t bsums = __riscv_vwadd_vv_i32mf2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32mf2_t sumi_v = __riscv_vmul_vv_i32mf2(ls, lsums, 8); + vint32mf2_t sumi1_v = __riscv_vmul_vv_i32mf2(__riscv_vmul_vv_i32mf2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + // Mask for processing 32 elements per lsum register. + vuint16m1_t l_index = __riscv_vid_v_u16m1(64); + vbool16_t l_mask = __riscv_vmsgtu_vx_u16m1_b16(l_index, 31, 64); + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + // Load qh once for the entire superblock. + vuint16mf4_t qh = __riscv_vle16_v_u16mf4(x[i].qh, 8); + + // Calculate ls. + vuint16mf4_t temp = __riscv_vsrl_vx_u16mf4(qh, 12, 8); + temp = __riscv_vand_vx_u16mf4(temp, 7, 8); + vint32mf2_t ls = __riscv_vreinterpret_v_u32mf2_i32mf2(__riscv_vwmulu_vx_u32mf2(temp, 2, 8)); + ls = __riscv_vadd_vx_i32mf2(ls, 1, 8); + + // Calculate delta. + vbool64_t mask = __riscv_vmseq_vx_u16mf4_b64(__riscv_vand_vx_u16mf4(qh, 0x8000, 8), 0, 8); + vint32mf2_t delta_neg = __riscv_vmv_v_x_i32mf2(-1, 8); + vint32mf2_t delta_pos = __riscv_vmv_v_x_i32mf2(1, 8); + vint32mf2_t delta = __riscv_vmerge_vvm_i32mf2(delta_neg, delta_pos, mask, 8); + + // Load qs. + vuint8mf2_t qs = __riscv_vle8_v_u8mf2(x[i].qs, 32); + + // Prepare the indices. + const uint64_t shift = 0x0009000600030000; + vuint16m1_t qh_shift = __riscv_vreinterpret_v_u64m1_u16m1(__riscv_vmv_v_x_u64m1(shift, 8)); + vuint16m1_t qh_gather_index = __riscv_vreinterpret_v_i16m1_u16m1( + __riscv_vdiv_vx_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vid_v_u16m1(32)), 4, 32)); + vuint16m1_t qh_ext = __riscv_vlmul_ext_v_u16mf2_u16m1(__riscv_vlmul_ext_v_u16mf4_u16mf2(qh)); + vuint16m1_t qh_index = __riscv_vrgather_vv_u16m1(qh_ext, qh_gather_index, 32); + qh_index = __riscv_vsrl_vv_u16m1(qh_index, qh_shift, 32); + qh_index = __riscv_vand_vx_u16m1(qh_index, 7, 32); + qh_index = __riscv_vsll_vx_u16m1(qh_index, 8, 32); + qh_index = __riscv_vor_vv_u16m1(qh_index, __riscv_vzext_vf2_u16m1(qs, 32), 32); + vuint16mf2_t index = __riscv_vlmul_trunc_v_u16m1_u16mf2(__riscv_vsll_vx_u16m1(qh_index, 3, 32)); + + // Final lsums. + int32_t lsums_s[8]; + vint32m1_t one_scalar = __riscv_vmv_v_x_i32m1(0, 1); + + // Sub-blocks 1-8 + { + vint8m2_t grid0 = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vluxei16_v_i64m2((const int64_t*)iq1s_grid, index, 32)); + vint8m2_t q80 = __riscv_vle8_v_i8m2(y[i].qs, 256); + vint16m4_t lsum0 = __riscv_vwmul_vv_i16m4(grid0, q80, 256); + + // Reduce. + lsums_s[0] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 0), one_scalar, 32)); + lsums_s[1] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 0), one_scalar, 64)); + lsums_s[2] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 1), one_scalar, 32)); + lsums_s[3] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 1), one_scalar, 64)); + lsums_s[4] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 2), one_scalar, 32)); + lsums_s[5] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 2), one_scalar, 64)); + lsums_s[6] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(lsum0, 3), one_scalar, 32)); + lsums_s[7] = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(l_mask, __riscv_vget_v_i16m4_i16m1(lsum0, 3), one_scalar, 64)); + } + __asm__ __volatile__("" ::: "memory"); + vint32mf2_t lsums = __riscv_vle32_v_i32mf2(&lsums_s[0], 8); + + // Calculate the bsums. + vint16mf2_t bsums_0 = __riscv_vle16_v_i16mf2(y[i].bsums, 16); + const vuint32mf2_t bsums_i32 = __riscv_vreinterpret_v_u16mf2_u32mf2(__riscv_vreinterpret_v_i16mf2_u16mf2(bsums_0)); + const vint16mf4_t bsums_i32_0 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 0, 8)); + const vint16mf4_t bsums_i32_1 = __riscv_vreinterpret_v_u16mf4_i16mf4(__riscv_vnsrl_wx_u16mf4(bsums_i32, 16, 8)); + const vint32mf2_t bsums = __riscv_vwadd_vv_i32mf2(bsums_i32_0, bsums_i32_1, 8); + + // Accumulation. + vint32mf2_t sumi_v = __riscv_vmul_vv_i32mf2(ls, lsums, 8); + vint32mf2_t sumi1_v = __riscv_vmul_vv_i32mf2(__riscv_vmul_vv_i32mf2(ls, delta, 8), bsums, 8); + + // Update sumf. + int sumi = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32mf2_i32m1(sumi1_v, __riscv_vmv_v_x_i32m1(0.0f, 1), 8)); + sumf += GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d * (sumi + IQ1S_DELTA * sumi1); + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq1_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_iq1_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_iq1_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq1_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq1_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m4_t acc1 = __riscv_vmv_v_x_i32m4(0, 16); + vint32m4_t acc2 = __riscv_vmv_v_x_i32m4(0, 16); + + // We process 8 16-element sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/128; ib++) { + // Load qh for 8 sub-blocks. + const vuint8mf2_t qh_8 = __riscv_vle8_v_u8mf2(qh, 8); + const vuint16m1_t qh_16_lo = __riscv_vzext_vf2_u16m1(qh_8, 8); + const vuint16m1_t qh_16_hi = __riscv_vsll_vx_u16m1(qh_16_lo, 8, 8); + const vuint16m2_t qhb = __riscv_vzext_vf2_u16m2( + __riscv_vreinterpret_v_u16m1_u8m1(__riscv_vor_vv_u16m1(qh_16_lo, qh_16_hi, 8)), 16); + qh += 8; + + // Prepare grid indices. + const vuint16m2_t qsb = __riscv_vzext_vf2_u16m2(__riscv_vle8_v_u8m1(&qs[0], 16), 16); + const vuint16m2_t shift = __riscv_vreinterpret_v_u32m2_u16m2(__riscv_vmv_v_x_u32m2(0x00040008, 8)); + vuint16m2_t index = __riscv_vor_vv_u16m2(qsb, __riscv_vand_vx_u16m2(__riscv_vsll_vv_u16m2(qhb, shift, 16), 0x700, 16), 16); + index = __riscv_vsll_vx_u16m2(index, 3, 16); + qs += 16; + + // Prepare the deltas. + const vbool8_t mask = __riscv_vmsgtu_vx_u16m2_b8( + __riscv_vand_vv_u16m2(qhb, __riscv_vreinterpret_v_u32m2_u16m2(__riscv_vmv_v_x_u32m2(0x00800008, 8)), 16), 0, 16); + const vint64m8_t delta_pos = __riscv_vmv_v_x_i64m8(0x0101010101010101, 16); + const vint8m8_t delta = __riscv_vreinterpret_v_i64m8_i8m8( + __riscv_vmerge_vxm_i64m8(delta_pos, 0xffffffffffffffff, mask, 16)); + + // Sub-blocks 0-3 + { + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, __riscv_vget_v_u16m2_u16m1(index, 0), 8))); + + // Calculate the lsums. + // + // Sub-block 0, 1 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 0), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 0), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-block 2, 3 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 1), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 1), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + sc += 1; + } + __asm__ __volatile__("" ::: "memory"); + // Sub-blocks 4-7 + { + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, __riscv_vget_v_u16m2_u16m1(index, 1), 8))); + + // Calculate the lsums. + // + // Sub-block 4, 5 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 0), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 2), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + __asm__ __volatile__("" ::: "memory"); + // Sub-block 6, 7 + { + // Load q8 for each sub-block. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m4_i8m2(iq1b, 1), q8b, 32); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(__riscv_vget_v_i8m8_i8m2(delta, 3), q8b, 32); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_0, __riscv_vget_v_i16m4_i16m2(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m4(acc1, ls_1, __riscv_vget_v_i16m4_i16m2(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_0, __riscv_vget_v_i16m4_i16m2(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m4(acc2, ls_1, __riscv_vget_v_i16m4_i16m2(lsum2, 1), 16); + } + sc += 1; + } + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(acc1, one, 16)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(acc2, one, 16)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 16); + + // We process 8 16-element sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/128; ib++) { + // Load qh for 8 sub-blocks. + const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 8); + const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 8); + const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 8); + const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1( + __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 8)), 16); + qh += 8; + + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 16), 16); + const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 8)); + vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 16), 0x700, 16), 16); + index = __riscv_vsll_vx_u16m1(index, 3, 16); + qs += 16; + + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, index, 16))); + + // Prepare the deltas. + const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( + __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 8)), 16), 0, 16); + const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 16); + const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( + __riscv_vmerge_vxm_i64m4(delta_pos, 0xffffffffffffffff, mask, 16)); + + // Load q8 for sub-blocks. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); + q8 += 128; + + // Calculate the lsums. + const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 128); + const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 128); + + // Prepare the scales. + const int16_t ls_0_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_0_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_1_0 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_1_1 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_2_0 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_2_1 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_3_0 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_3_1 = 2*((sc[1] >> 9) & 0x7) + 1; + sc += 2; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0_1, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_0, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_1_1, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_0, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_2_1, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16); + // + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_0, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_3_1, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 16)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 16)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + + // Mask for processing 16 elements per lsum register. + const vuint16m1_t l_index = __riscv_vid_v_u16m1(32); + const vbool16_t l_mask = __riscv_vmsgtu_vx_u16m1_b16(l_index, 15, 32); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 32); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 32); + + // We process all the sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/256; ib++) { + // Load qh for all 16 sub-blocks. + const vuint8mf4_t qh_8 = __riscv_vle8_v_u8mf4(qh, 16); + const vuint16mf2_t qh_16_lo = __riscv_vzext_vf2_u16mf2(qh_8, 16); + const vuint16mf2_t qh_16_hi = __riscv_vsll_vx_u16mf2(qh_16_lo, 8, 16); + const vuint16m1_t qhb = __riscv_vzext_vf2_u16m1( + __riscv_vreinterpret_v_u16mf2_u8mf2(__riscv_vor_vv_u16mf2(qh_16_lo, qh_16_hi, 16)), 32); + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16m1_t qsb = __riscv_vzext_vf2_u16m1(__riscv_vle8_v_u8mf2(&qs[0], 32), 32); + const vuint16m1_t shift = __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00040008, 16)); + vuint16m1_t index = __riscv_vor_vv_u16m1(qsb, __riscv_vand_vx_u16m1(__riscv_vsll_vv_u16m1(qhb, shift, 32), 0x700, 32), 32); + index = __riscv_vsll_vx_u16m1(index, 3, 32); + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m4_t iq1b = __riscv_vreinterpret_v_i64m4_i8m4(__riscv_vreinterpret_v_u64m4_i64m4( + __riscv_vluxei16_v_u64m4(iq1s_grid, index, 32))); + + // Prepare the deltas. + const vbool16_t mask = __riscv_vmsgtu_vx_u16m1_b16( + __riscv_vand_vv_u16m1(qhb, __riscv_vreinterpret_v_u32m1_u16m1(__riscv_vmv_v_x_u32m1(0x00800008, 16)), 32), 0, 32); + const vint64m4_t delta_pos = __riscv_vmv_v_x_i64m4(0x0101010101010101, 32); + const vint8m4_t delta = __riscv_vreinterpret_v_i64m4_i8m4( + __riscv_vmerge_vxm_i64m4(delta_pos, 0xffffffffffffffff, mask, 32)); + + // Load q8 for sub-blocks. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256); + + // Calculate the lsums. + const vint16m8_t lsum1 = __riscv_vwmul_vv_i16m8(iq1b, q8b, 256); + const vint16m8_t lsum2 = __riscv_vwmul_vv_i16m8(delta, q8b, 256); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_2 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_3 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_4 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_5 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_6 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_7 = 2*((sc[1] >> 9) & 0x7) + 1; + const int16_t ls_8 = 2*((sc[2] >> 0) & 0x7) + 1; + const int16_t ls_9 = 2*((sc[2] >> 3) & 0x7) + 1; + const int16_t ls_10 = 2*((sc[2] >> 6) & 0x7) + 1; + const int16_t ls_11 = 2*((sc[2] >> 9) & 0x7) + 1; + const int16_t ls_12 = 2*((sc[3] >> 0) & 0x7) + 1; + const int16_t ls_13 = 2*((sc[3] >> 3) & 0x7) + 1; + const int16_t ls_14 = 2*((sc[3] >> 6) & 0x7) + 1; + const int16_t ls_15 = 2*((sc[3] >> 9) & 0x7) + 1; + + // Accumulate in acc0 and acc1 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_0, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_1, __riscv_vget_v_i16m8_i16m1(lsum1, 0), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_0, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_1, __riscv_vget_v_i16m8_i16m1(lsum2, 0), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_2, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_3, __riscv_vget_v_i16m8_i16m1(lsum1, 1), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_2, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_3, __riscv_vget_v_i16m8_i16m1(lsum2, 1), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_4, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_5, __riscv_vget_v_i16m8_i16m1(lsum1, 2), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_4, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_5, __riscv_vget_v_i16m8_i16m1(lsum2, 2), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_6, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_7, __riscv_vget_v_i16m8_i16m1(lsum1, 3), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_6, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_7, __riscv_vget_v_i16m8_i16m1(lsum2, 3), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_8, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_9, __riscv_vget_v_i16m8_i16m1(lsum1, 4), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_8, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_9, __riscv_vget_v_i16m8_i16m1(lsum2, 4), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_10, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_11, __riscv_vget_v_i16m8_i16m1(lsum1, 5), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_10, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_11, __riscv_vget_v_i16m8_i16m1(lsum2, 5), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_12, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_13, __riscv_vget_v_i16m8_i16m1(lsum1, 6), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_12, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_13, __riscv_vget_v_i16m8_i16m1(lsum2, 6), 32); + // + acc1 = __riscv_vwmacc_vx_i32m2( acc1, ls_14, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 16); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc1, ls_15, __riscv_vget_v_i16m8_i16m1(lsum1, 7), 32); + acc2 = __riscv_vwmacc_vx_i32m2( acc2, ls_14, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 16); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask, acc2, ls_15, __riscv_vget_v_i16m8_i16m1(lsum2, 7), 32); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 32)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 32)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq1_m_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq1_m * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + iq1m_scale_t scale; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * qh = x[i].qh; + const uint16_t * sc = (const uint16_t *)x[i].scales; + + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + // Accumulators. + vint32m2_t acc1 = __riscv_vmv_v_x_i32m2(0, 64); + vint32m2_t acc2 = __riscv_vmv_v_x_i32m2(0, 64); + + // We process all the sub-blocks together. + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K/256; ib++) { + // Load qh for all 16 sub-blocks. + const vuint8mf8_t qh_8 = __riscv_vle8_v_u8mf8(qh, 16); + const vuint16mf4_t qh_16_lo = __riscv_vzext_vf2_u16mf4(qh_8, 16); + const vuint16mf4_t qh_16_hi = __riscv_vsll_vx_u16mf4(qh_16_lo, 8, 16); + const vuint16mf2_t qhb = __riscv_vzext_vf2_u16mf2( + __riscv_vreinterpret_v_u16mf4_u8mf4(__riscv_vor_vv_u16mf4(qh_16_lo, qh_16_hi, 16)), 32); + __asm__ __volatile__("" ::: "memory"); + + // Prepare grid indices. + const vuint16mf2_t qsb = __riscv_vzext_vf2_u16mf2(__riscv_vle8_v_u8mf4(&qs[0], 32), 32); + const vuint16mf2_t shift = __riscv_vreinterpret_v_u32mf2_u16mf2(__riscv_vmv_v_x_u32mf2(0x00040008, 16)); + vuint16mf2_t index = __riscv_vor_vv_u16mf2(qsb, __riscv_vand_vx_u16mf2(__riscv_vsll_vv_u16mf2(qhb, shift, 32), 0x700, 32), 32); + index = __riscv_vsll_vx_u16mf2(index, 3, 32); + __asm__ __volatile__("" ::: "memory"); + + // Load the grid. + const vint8m2_t iq1b = __riscv_vreinterpret_v_i64m2_i8m2(__riscv_vreinterpret_v_u64m2_i64m2( + __riscv_vluxei16_v_u64m2(iq1s_grid, index, 32))); + + // Prepare the deltas. + const vbool32_t mask = __riscv_vmsgtu_vx_u16mf2_b32( + __riscv_vand_vv_u16mf2(qhb, __riscv_vreinterpret_v_u32mf2_u16mf2(__riscv_vmv_v_x_u32mf2(0x00800008, 16)), 32), 0, 32); + const vint64m2_t delta_pos = __riscv_vmv_v_x_i64m2(0x0101010101010101, 32); + const vint8m2_t delta = __riscv_vreinterpret_v_i64m2_i8m2( + __riscv_vmerge_vxm_i64m2(delta_pos, 0xffffffffffffffff, mask, 32)); + + // Load q8 for sub-blocks. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256); + + // Calculate the lsums. + const vint16m4_t lsum1 = __riscv_vwmul_vv_i16m4(iq1b, q8b, 256); + const vint16m4_t lsum2 = __riscv_vwmul_vv_i16m4(delta, q8b, 256); + + // Prepare the scales. + const int16_t ls_0 = 2*((sc[0] >> 0) & 0x7) + 1; + const int16_t ls_1 = 2*((sc[0] >> 3) & 0x7) + 1; + const int16_t ls_2 = 2*((sc[0] >> 6) & 0x7) + 1; + const int16_t ls_3 = 2*((sc[0] >> 9) & 0x7) + 1; + const int16_t ls_4 = 2*((sc[1] >> 0) & 0x7) + 1; + const int16_t ls_5 = 2*((sc[1] >> 3) & 0x7) + 1; + const int16_t ls_6 = 2*((sc[1] >> 6) & 0x7) + 1; + const int16_t ls_7 = 2*((sc[1] >> 9) & 0x7) + 1; + const int16_t ls_8 = 2*((sc[2] >> 0) & 0x7) + 1; + const int16_t ls_9 = 2*((sc[2] >> 3) & 0x7) + 1; + const int16_t ls_10 = 2*((sc[2] >> 6) & 0x7) + 1; + const int16_t ls_11 = 2*((sc[2] >> 9) & 0x7) + 1; + const int16_t ls_12 = 2*((sc[3] >> 0) & 0x7) + 1; + const int16_t ls_13 = 2*((sc[3] >> 3) & 0x7) + 1; + const int16_t ls_14 = 2*((sc[3] >> 6) & 0x7) + 1; + const int16_t ls_15 = 2*((sc[3] >> 9) & 0x7) + 1; + + // Mask for processing 16 elements per lsum register. + const vuint16m1_t l_index = __riscv_vid_v_u16m1(64); + + // Accumulate in acc1 and acc2 for each sub-block. + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_0, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_0, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_4, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_4, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_8, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_8, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 16); + acc1 = __riscv_vwmacc_vx_i32m2(acc1, ls_12, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 16); + acc2 = __riscv_vwmacc_vx_i32m2(acc2, ls_12, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 16); + // + const vbool16_t l_mask_16_32 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 15, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_1, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_1, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_5, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_5, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_9, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_9, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 32); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc1, ls_13, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 32); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_16_32, acc2, ls_13, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 32); + // + const vbool16_t l_mask_32_48 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 31, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_2, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_2, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_6, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_6, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_10, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_10, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 48); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc1, ls_14, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 48); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_32_48, acc2, ls_14, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 48); + // + const vbool16_t l_mask_48_64 = __riscv_vmsgtu_vx_u16m1_b16(l_index, 47, 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_3, __riscv_vget_v_i16m4_i16m1(lsum1, 0), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_3, __riscv_vget_v_i16m4_i16m1(lsum2, 0), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_7, __riscv_vget_v_i16m4_i16m1(lsum1, 1), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_7, __riscv_vget_v_i16m4_i16m1(lsum2, 1), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_11, __riscv_vget_v_i16m4_i16m1(lsum1, 2), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_11, __riscv_vget_v_i16m4_i16m1(lsum2, 2), 64); + acc1 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc1, ls_15, __riscv_vget_v_i16m4_i16m1(lsum1, 3), 64); + acc2 = __riscv_vwmacc_vx_i32m2_m(l_mask_48_64, acc2, ls_15, __riscv_vget_v_i16m4_i16m1(lsum2, 3), 64); + + __asm__ __volatile__("" ::: "memory"); + } + + // Reduce and accumulate in `sumf`. + vint32m1_t one = __riscv_vmv_v_x_i32m1(0, 1); + int sumi1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc1, one, 64)); + int sumi2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m2_i32m1(acc2, one, 64)); + sumf += y[i].d * GGML_CPU_FP16_TO_FP32(scale.f16) * (sumi1 + IQ1M_DELTA * sumi2); + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_iq1_m_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq1_m_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_iq1_m_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_iq1_m_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq1_m_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq1_m_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static const uint8_t sign_gather_indices_arr[64] = { + 0,0,0,0,0,0,0,0, 1,1,1,1,1,1,1,1, 2,2,2,2,2,2,2,2, 3,3,3,3,3,3,3,3, + 4,4,4,4,4,4,4,4, 5,5,5,5,5,5,5,5, 6,6,6,6,6,6,6,6, 7,7,7,7,7,7,7,7 +}; + +static const uint8_t sign_bit_masks_arr[64] = { + 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, + 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128, 1,2,4,8,16,32,64,128 +}; + +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + // Pre-load Constants + vuint8m2_t v_ids = __riscv_vid_v_u8m2(32); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 32); + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 32); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 32); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 32); + uint16_t shift_qh_arr[4] = {11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 4); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + float sum_block = 0.0f; + + for (int ib = 0; ib < 8; ++ib) { + + // Load Low Bits [4 bytes] + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 4); + qs += 4; + + // Load 1 byte. It contains bits for 4 mini-blocks. + uint8_t qh_val = *qh++; + + // Combine Low + High bits of 10bit indices + vuint8mf4_t v_qh_raw = __riscv_vmv_v_x_u8mf4(qh_val, 4); + vuint16mf2_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qh_raw, 4); + vuint16mf2_t v_qh_mf2 = __riscv_vsll_vv_u16mf2(v_qh_u16, v_shift_qh, 4); + v_qh_mf2 = __riscv_vand_vx_u16mf2(v_qh_mf2, 0x1800, 4); + vuint16mf2_t v_qs_u16_mf2 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 4); + vuint16mf2_t v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16_mf2, 3, 4); + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_mf2, 4); + + // Lookup Grid + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(__riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 4))); + + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 4); + signs_ptr += 4; + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32); + + // generating sign mask + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + // apply signs + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative,v_q8, v_q8, 0, 32); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 32); + + // Reduction + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + // Reduce 0-15 (First Half) + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m4_i16m2(v_dot, 0), v_zero, 16)); + + // Reduce 16-31 (Second Half) + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m4_i16m2(v_dot, 1), v_zero, 16)); + + // Apply sub Scales + uint8_t sc = *scales++; + + sum_block += s0 * (2 * (sc & 0xF) + 1); + sum_block += s1 * (2 * (sc >> 4) + 1); + } + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + // --- Pre-load Constants --- + uint16_t gather_qh_arr[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 8); + uint16_t shift_qh_arr[8] = {11, 9, 7, 5, 11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 8); + + // Constants for sign extraction + vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); + vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 4; ++ib) { + // Combine low + high bits + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 8); + qs += 8; + uint16_t qh_val; + memcpy(&qh_val, qh, 2); + qh += 2; + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8((const uint8_t*)&qh_val, 2); + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 2); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 8); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 8); + + // Mask: We want bits 11-12. 0x1800 = 0001 1000 0000 0000 + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 8); + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 8); + + // Multiply by 8 to get byte offset, instead of element offset + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 8); + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 8); + + // Lookup Grid using Byte Offsets + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 8); + + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + // Load signs and generate sign mask + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 8); + signs_ptr += 8; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); + + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 64); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 0), v_zero, 16)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 1), v_zero, 16)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 2), v_zero, 16)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(v_dot, 3), v_zero, 16)); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + scales += 2; + + sum_block += s0 * (2 * (sc0 & 0xF) + 1); + sum_block += s1 * (2 * (sc0 >> 4) + 1); + sum_block += s2 * (2 * (sc1 & 0xF) + 1); + sum_block += s3 * (2 * (sc1 >> 4) + 1); + } + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + + vuint8m2_t v_ids = __riscv_vid_v_u8m2(128); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 128); + + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 128); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 128); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 128); + + uint16_t gather_qh_arr[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 16); + + uint16_t shift_qh_arr[16] = {11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5}; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 16); + + // Masks for selecting lower/upper 16 lanes within a 32-lane i16m1 register + vuint16m1_t v_ids16 = __riscv_vid_v_u16m1(32); + vbool16_t m_hi16 = __riscv_vmsgeu_vx_u16m1_b16(v_ids16, 16, 32); + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 2; ++ib) { + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 16); + qs += 16; + + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8(qh, 4); + qh += 4; + + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 4); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 16); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 16); + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 16); + + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 16); + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 16); + + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 16); + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 16); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs_ptr, 16); + signs_ptr += 16; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 128); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 128); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 128); + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 128); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 128); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint16m1_t v0 = __riscv_vget_v_i16m4_i16m1(v_dot, 0); + vint16m1_t v1 = __riscv_vget_v_i16m4_i16m1(v_dot, 1); + vint16m1_t v2 = __riscv_vget_v_i16m4_i16m1(v_dot, 2); + vint16m1_t v3 = __riscv_vget_v_i16m4_i16m1(v_dot, 3); + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v0, v_zero, 16)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v0, v_zero, 32)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v1, v_zero, 16)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v1, v_zero, 32)); + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(v2, v_zero, 16)); + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v2, v_zero, 32)); + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( v3, v_zero, 16)); + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(m_hi16, v3, v_zero, 32)); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + uint8_t sc2 = scales[2]; + uint8_t sc3 = scales[3]; + scales += 4; + + sum_block += s0 * (2 * (sc0 & 0xF) + 1); + sum_block += s1 * (2 * (sc0 >> 4) + 1); + sum_block += s2 * (2 * (sc1 & 0xF) + 1); + sum_block += s3 * (2 * (sc1 >> 4) + 1); + sum_block += s4 * (2 * (sc2 & 0xF) + 1); + sum_block += s5 * (2 * (sc2 >> 4) + 1); + sum_block += s6 * (2 * (sc3 & 0xF) + 1); + sum_block += s7 * (2 * (sc3 >> 4) + 1); + } + + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_s_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * grid64 = (const uint64_t *)iq2s_grid; + vuint8m2_t v_ids = __riscv_vid_v_u8m2(256); + vuint8m2_t v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_ids, 3, 256); + + vuint8m2_t v_ones = __riscv_vmv_v_x_u8m2(1, 256); + vuint8m2_t v_shift_amts = __riscv_vand_vx_u8m2(v_ids, 7, 256); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_ones, v_shift_amts, 256); + + uint16_t gather_qh_arr[32] = { + 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, + 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7 + }; + vuint16mf2_t v_gather_qh = __riscv_vle16_v_u16mf2(gather_qh_arr, 32); + + uint16_t shift_qh_arr[32] = { + 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, + 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5, 11, 9, 7, 5 + }; + vuint16mf2_t v_shift_qh = __riscv_vle16_v_u16mf2(shift_qh_arr, 32); + + // Masks for 4 groups of 16 lanes within a 64-lane i16m4 chunk + vuint16m4_t v_ids64 = __riscv_vid_v_u16m4(64); + vbool4_t m_g0 = __riscv_vmsltu_vx_u16m4_b4(v_ids64, 16, 64); + vbool4_t m_g1 = __riscv_vmand_mm_b4( + __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 16, 64), + __riscv_vmsltu_vx_u16m4_b4(v_ids64, 32, 64), 64); + vbool4_t m_g2 = __riscv_vmand_mm_b4( + __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 32, 64), + __riscv_vmsltu_vx_u16m4_b4(v_ids64, 48, 64), 64); + vbool4_t m_g3 = __riscv_vmsgeu_vx_u16m4_b4(v_ids64, 48, 64); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + const uint8_t * signs_ptr = qs + 32; + + float sum_block = 0.0f; + + vuint8mf4_t v_qs_u8 = __riscv_vle8_v_u8mf4(qs, 32); + qs += 32; + + vuint8mf8_t v_qh_raw = __riscv_vle8_v_u8mf8(qh, 8); + qh += 8; + + vuint16mf4_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16mf4(v_qh_raw, 8); + vuint16mf2_t v_qh_u16_ext = __riscv_vlmul_ext_v_u16mf4_u16mf2(v_qh_u16); + vuint16mf2_t v_qh_expanded = __riscv_vrgather_vv_u16mf2(v_qh_u16_ext, v_gather_qh, 32); + v_qh_expanded = __riscv_vsll_vv_u16mf2(v_qh_expanded, v_shift_qh, 32); + v_qh_expanded = __riscv_vand_vx_u16mf2(v_qh_expanded, 0x1800, 32); + + vuint16mf2_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16mf2(v_qs_u8, 32); + v_qs_u16 = __riscv_vsll_vx_u16mf2(v_qs_u16, 3, 32); + + vuint16mf2_t v_grid_offsets = __riscv_vor_vv_u16mf2(v_qs_u16, v_qh_expanded, 32); + vuint64m2_t v_grid_vals = __riscv_vluxei16_v_u64m2(grid64, v_grid_offsets, 32); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u64m2_u8m2(v_grid_vals); + vint8m2_t v_grid_i8 = __riscv_vreinterpret_v_u8m2_i8m2(v_grid_u8); + + //loading signs + vuint8mf2_t v_signs_raw = __riscv_vle8_v_u8mf2(signs_ptr, 32); + signs_ptr += 32; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf2_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 256); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 256); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 256); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 256); + q8 += 256; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 256); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_grid_i8, v_q8_signed, 256); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + vint16m4_t c = v_dot; + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s8 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s9 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s10 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s11 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + c = __riscv_vslidedown_vx_i16m4(c, 64, 256); + int32_t s12 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g0, c, v_zero, 64)); + int32_t s13 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g1, c, v_zero, 64)); + int32_t s14 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g2, c, v_zero, 64)); + int32_t s15 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1_m(m_g3, c, v_zero, 64)); + + int32_t sums_arr[16] = { s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12, s13, s14, s15 }; + + // Load 8 scale bytes and split into 16 nibbles + vuint8mf2_t v_sc8 = __riscv_vle8_v_u8mf2(scales, 8); + scales += 8; + + vuint8mf2_t v_lo8 = __riscv_vand_vx_u8mf2(v_sc8, 0x0F, 8); + vuint8mf2_t v_hi8 = __riscv_vsrl_vx_u8mf2(v_sc8, 4, 8); + + vuint8m1_t v_idx16 = __riscv_vid_v_u8m1(16); + vuint8m1_t v_half = __riscv_vsrl_vx_u8m1(v_idx16, 1, 16); + vbool8_t m_even = __riscv_vmseq_vx_u8m1_b8(__riscv_vand_vx_u8m1(v_idx16, 1, 16), 0, 16); + + vuint8m1_t v_lo_ext = __riscv_vlmul_ext_v_u8mf2_u8m1(v_lo8); + vuint8m1_t v_hi_ext = __riscv_vlmul_ext_v_u8mf2_u8m1(v_hi8); + vuint8m1_t v_lo_g = __riscv_vrgather_vv_u8m1(v_lo_ext, v_half, 16); + vuint8m1_t v_hi_g = __riscv_vrgather_vv_u8m1(v_hi_ext, v_half, 16); + vuint8m1_t v_nib = __riscv_vmerge_vvm_u8m1(v_lo_g, v_hi_g, m_even, 16); + + static const uint8_t iq2s_scale_lut_16_local[16] = { + 1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31 + }; + vuint8m1_t v_lut = __riscv_vle8_v_u8m1(iq2s_scale_lut_16_local, 16); + vuint8m1_t v_sc8v = __riscv_vrgather_vv_u8m1(v_lut, v_nib, 16); + + vint32m4_t v_sums = __riscv_vle32_v_i32m4(sums_arr, 16); + vuint16m2_t v_sc16 = __riscv_vwcvtu_x_x_v_u16m2(v_sc8v, 16); + vuint32m4_t v_sc32u = __riscv_vwcvtu_x_x_v_u32m4(v_sc16, 16); + vint32m4_t v_sc32 = __riscv_vreinterpret_v_u32m4_i32m4(v_sc32u); + vint32m4_t v_prod = __riscv_vmul_vv_i32m4(v_sums, v_sc32, 16); + + vint32m1_t v_zero32 = __riscv_vmv_v_x_i32m1(0, 1); + int32_t sum_part = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m4_i32m1(v_prod, v_zero32, 16)); + sum_block += sum_part; + + sumf += sum_block * combined_scale; + } + *s = 0.125f * sumf; +} +#endif + +void ggml_vec_dot_iq2_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq2_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_iq2_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_iq2_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq2_s_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq2_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static const int8_t keven_signs_q2xs[1024] = { + 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, 1, + 1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, + 1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, 1, 1, -1, 1, -1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, -1, + 1, 1, -1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, 1, + 1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, -1, + 1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, 1, + 1, 1, 1, -1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, 1, + 1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, -1, + 1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, 1, -1, 1, 1, -1, -1, 1, 1, 1, -1, 1, -1, + 1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, 1, + 1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, 1, + 1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, -1, + 1, 1, 1, 1, -1, -1, 1, 1, -1, 1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, 1, + 1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, -1, + 1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, 1, 1, 1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, -1, + 1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, 1, + 1, 1, 1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, 1, -1, 1, 1, -1, 1, 1, 1, 1, -1, 1, -1, -1, 1, 1, 1, 1, -1, -1, + 1, 1, -1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, -1, 1, -1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, 1, + 1, 1, 1, -1, 1, 1, -1, 1, -1, 1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, 1, + 1, 1, -1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, -1, + 1, 1, 1, 1, -1, 1, -1, 1, -1, 1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, 1, + 1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, -1, + 1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, -1, + 1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, 1, + 1, 1, 1, 1, 1, -1, -1, 1, -1, 1, 1, 1, 1, -1, -1, -1, 1, -1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, 1, + 1, 1, -1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, -1, -1, -1, 1, 1, -1, -1, -1, + 1, 1, 1, -1, 1, -1, -1, -1, -1, 1, 1, -1, 1, -1, -1, 1, 1, -1, 1, -1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, -1, + 1, 1, -1, -1, 1, -1, -1, 1, -1, 1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, 1, + 1, 1, 1, 1, -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, -1, 1, 1, -1, -1, -1, 1, -1, -1, 1, 1, -1, -1, -1, -1, + 1, 1, -1, 1, -1, -1, -1, 1, -1, 1, -1, 1, -1, -1, -1, -1, 1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, 1, + 1, 1, 1, -1, -1, -1, -1, 1, -1, 1, 1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, 1, + 1, 1, -1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, 1, 1, -1, -1, -1, -1, -1, -1, 1, -1, -1, -1, -1, -1, -1, -1, -1, +}; + +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; +#pragma GCC unroll 1 + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + + int32_t sum_int = 0; + + // Loop over 4 subblocks of 64 elements + for (int ib64 = 0; ib64 < QK_K / 64; ++ib64) { + + // Load indices. + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 8); + qs += 8; + + // Prepare offsets + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 8), 3, 8); + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 8), 3, 8); + + // load values and signs from the lookup tables + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 8); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 8); + vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 64); + asm volatile("" ::: "memory"); + vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + + vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 64); + asm volatile("" ::: "memory"); + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t sum0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 0), zero_vec, 16)); + + int32_t sum1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 1), zero_vec, 16)); + + int32_t sum2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 2), zero_vec, 16)); + + int32_t sum3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1( + __riscv_vget_v_i16m8_i16m2(prod, 3), zero_vec, 16)); + + const uint8_t scale_byte_1 = scales[0]; + const uint8_t scale_byte_2 = scales[1]; + scales += 2; + + sum_int += sum0 * ((scale_byte_1 & 0x0F) * 2 + 1); + sum_int += sum1 * ((scale_byte_1 >> 4) * 2 + 1); + sum_int += sum2 * ((scale_byte_2 & 0x0F) * 2 + 1); + sum_int += sum3 * ((scale_byte_2 >> 4) * 2 + 1); + } + + sumf += d * sum_int; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + + int32_t sum_int = 0; + + for (int ib128 = 0; ib128 < 2; ++ib128) { + + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 16); + qs += 16; + + // Prepare offsets for grid and signs + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 16), 3, 16); + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 16), 3, 16); + + // Indexed load 128 weights (16 x 8-byte chunks) + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 16); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 16); + + vint8m4_t q2u = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t q2s = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + // Apply signs to get dequantized IQ2 values + vint8m4_t q2_final = __riscv_vmul_vv_i8m4(q2u, q2s, 128); + asm volatile("" ::: "memory"); + + // Load corresponding Q8 weights + vint8m4_t q8v = __riscv_vle8_v_i8m4(q8, 128); + q8 += 128; + + vint16m8_t prod = __riscv_vwmul_vv_i16m8(q2_final, q8v, 128); + asm volatile("" ::: "memory"); + + uint8_t sc0 = scales[0]; + uint8_t sc1 = scales[1]; + uint8_t sc2 = scales[2]; + uint8_t sc3 = scales[3]; + scales += 4; + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + // 9. Reduce each 16-element chunk and apply corresponding nibble scale + + int32_t s0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), zero_vec, 16)); + sum_int += s0 * ((sc0 & 0x0F) * 2 + 1); + + int32_t s1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), zero_vec, 16)); + sum_int += s1 * ((sc0 >> 4) * 2 + 1); + + int32_t s2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), zero_vec, 16)); + sum_int += s2 * ((sc1 & 0x0F) * 2 + 1); + + int32_t s3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), zero_vec, 16)); + sum_int += s3 * ((sc1 >> 4) * 2 + 1); + + int32_t s4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), zero_vec, 16)); + sum_int += s4 * ((sc2 & 0x0F) * 2 + 1); + + int32_t s5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), zero_vec, 16)); + sum_int += s5 * ((sc2 >> 4) * 2 + 1); + + int32_t s6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), zero_vec, 16)); + sum_int += s6 * ((sc3 & 0x0F) * 2 + 1); + + int32_t s7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), zero_vec, 16)); + sum_int += s7 * ((sc3 >> 4) * 2 + 1); + } + + sumf += d * (float)sum_int; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xs_grid; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint16_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vint8m4_t q8_all = __riscv_vle8_v_i8m4(q8, 256); + + // Load indices --- + vuint16m1_t v_qs = __riscv_vle16_v_u16m1(qs, 32); + + // Extract low 9 bits and multiply by 8 (shift left 3) for byte offset into uint64 table + vuint16m1_t vidx_grid = __riscv_vsll_vx_u16m1(__riscv_vand_vx_u16m1(v_qs, 511, 32), 3, 32); + + // Extract high 7 bits (shift right 9) and multiply by 8 (shift left 3) for byte offset + vuint16m1_t vidx_sign = __riscv_vsll_vx_u16m1(__riscv_vsrl_vx_u16m1(v_qs, 9, 32), 3, 32); + + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_grid, 32); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_sign, 32); + + vint8m4_t q2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t s2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + vint8m4_t q2_signed = __riscv_vmul_vv_i8m4(q2_all, s2_all, 256); + vint16m8_t dot_all = __riscv_vwmul_vv_i16m8(q2_signed, q8_all, 256); + float sum = 0.0f; + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + +#pragma GCC unroll 1 + for (int j = 0; j < 8; ++j) { + uint8_t sc = scales[j]; + int16_t sc_lo = 2 * (sc & 0x0F) + 1; + int16_t sc_hi = 2 * (sc >> 4) + 1; + + vint32m1_t sum_v0 = __riscv_vwredsum_vs_i16m8_i32m1( + __riscv_vslidedown_vx_i16m8(dot_all, j * 32, 16), zero_vec, 16); + int32_t isum0 = __riscv_vmv_x_s_i32m1_i32(sum_v0); + + vint32m1_t sum_v1 = __riscv_vwredsum_vs_i16m8_i32m1( + __riscv_vslidedown_vx_i16m8(dot_all, j * 32 + 16, 16), zero_vec, 16); + int32_t isum1 = __riscv_vmv_x_s_i32m1_i32(sum_v1); + + sum += (float)isum0 * sc_lo + (float)isum1 * sc_hi; + } + + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} +#endif + +void ggml_vec_dot_iq2_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq2_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_iq2_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 512 and above + ggml_vec_dot_iq2_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq2_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + + uint32_t shift_constants[4] = {0, 7, 14, 21}; + vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shift_constants, 4); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum = 0.0f; + + #pragma GCC unroll 1 + for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) { + vint8m2_t q8_1 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32; + vint8m2_t q8_2 = __riscv_vle8_v_i8m2(q8, 32); q8 += 32; + + vuint8mf4_t v_raw_q2_1 = __riscv_vle8_v_u8mf4(q2_ptr, 4); + vuint8mf4_t v_raw_q2_2 = __riscv_vle8_v_u8mf4(q2_ptr + 8, 4); + + vuint16mf2_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_1, 4); + vuint16mf2_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf2(v_raw_q2_2, 4); + + vidx_q2_1 = __riscv_vsll_vx_u16mf2(vidx_q2_1, 3, 4); + vidx_q2_2 = __riscv_vsll_vx_u16mf2(vidx_q2_2, 3, 4); + + uint32_t s_packed_1, s_packed_2; + memcpy(&s_packed_1, q2_ptr + 4, 4); + memcpy(&s_packed_2, q2_ptr + 12, 4); + + vuint32m1_t v_s_1 = __riscv_vmv_v_x_u32m1(s_packed_1, 4); + vuint32m1_t v_s_2 = __riscv_vmv_v_x_u32m1(s_packed_2, 4); + v_s_1 = __riscv_vsrl_vv_u32m1(v_s_1, v_shifts, 4); + v_s_2 = __riscv_vsrl_vv_u32m1(v_s_2, v_shifts, 4); + + v_s_1 = __riscv_vand_vx_u32m1(v_s_1, 127, 4); + v_s_2 = __riscv_vand_vx_u32m1(v_s_2, 127, 4); + + vuint16mf2_t vidx_s2_1 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_1, 4), 3, 4); + vuint16mf2_t vidx_s2_2 = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_2, 4), 3, 4); + + vuint64m2_t vq2_64_1 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_1, 4); + vuint64m2_t vq2_64_2 = __riscv_vluxei16_v_u64m2(grid64, vidx_q2_2, 4); + + vint8m2_t q2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_1)); + vint8m2_t q2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vq2_64_2)); + + vuint64m2_t vs2_64_1 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_1, 4); + vuint64m2_t vs2_64_2 = __riscv_vluxei16_v_u64m2(signs64, vidx_s2_2, 4); + vint8m2_t s2_1 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_1)); + vint8m2_t s2_2 = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(vs2_64_2)); + + vint8m2_t q8s_1 = __riscv_vmul_vv_i8m2(q8_1, s2_1, 32); + vint8m2_t q8s_2 = __riscv_vmul_vv_i8m2(q8_2, s2_2, 32); + + vint16m4_t dot1 = __riscv_vwmul_vv_i16m4(q8s_1, q2_1, 32); + vint16m4_t dot2 = __riscv_vwmul_vv_i16m4(q8s_2, q2_2, 32); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m4_i32m1(dot1, zero_vec, 32); + vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m4_i32m1(dot2, zero_vec, 32); + + int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1); + int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2); + + int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1; + int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1; + + sum += scalar_sum1 * scale1 + scalar_sum2 * scale2; + q2_ptr += 16; + } + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + + uint32_t shift_constants[4] = {0, 7, 14, 21}; + vuint32mf2_t v_shifts = __riscv_vle32_v_u32mf2(shift_constants, 4); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum = 0.0f; + + for (int ib32 = 0; ib32 < QK_K / 32; ib32 += 2) { + vint8m1_t q8_1 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32; + vint8m1_t q8_2 = __riscv_vle8_v_i8m1(q8, 32); q8 += 32; + + vuint8mf8_t v_raw_q2_1 = __riscv_vle8_v_u8mf8(q2_ptr, 4); + vuint8mf8_t v_raw_q2_2 = __riscv_vle8_v_u8mf8(q2_ptr + 8, 4); + + vuint16mf4_t vidx_q2_1 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_1, 4); + vuint16mf4_t vidx_q2_2 = __riscv_vwcvtu_x_x_v_u16mf4(v_raw_q2_2, 4); + + vidx_q2_1 = __riscv_vsll_vx_u16mf4(vidx_q2_1, 3, 4); + vidx_q2_2 = __riscv_vsll_vx_u16mf4(vidx_q2_2, 3, 4); + + uint32_t s_packed_1, s_packed_2; + memcpy(&s_packed_1, q2_ptr + 4, 4); + memcpy(&s_packed_2, q2_ptr + 12, 4); + + vuint32mf2_t v_s_1 = __riscv_vmv_v_x_u32mf2(s_packed_1, 4); + vuint32mf2_t v_s_2 = __riscv_vmv_v_x_u32mf2(s_packed_2, 4); + + v_s_1 = __riscv_vsrl_vv_u32mf2(v_s_1, v_shifts, 4); + v_s_2 = __riscv_vsrl_vv_u32mf2(v_s_2, v_shifts, 4); + + v_s_1 = __riscv_vand_vx_u32mf2(v_s_1, 127, 4); + v_s_2 = __riscv_vand_vx_u32mf2(v_s_2, 127, 4); + + // Narrow u32 -> u16 (vncvt) and Scale by 8 to get byte offsets + vuint16mf4_t vidx_s2_1 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_1, 4), 3, 4); + vuint16mf4_t vidx_s2_2 = __riscv_vsll_vx_u16mf4(__riscv_vncvt_x_x_w_u16mf4(v_s_2, 4), 3, 4); + + // Load q2 values from lookup grid + vuint64m1_t vq2_64_1 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_1, 4); + vuint64m1_t vq2_64_2 = __riscv_vluxei16_v_u64m1(grid64, vidx_q2_2, 4); + vint8m1_t q2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_1)); + vint8m1_t q2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vq2_64_2)); + + // Load sign values + vuint64m1_t vs2_64_1 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_1, 4); + vuint64m1_t vs2_64_2 = __riscv_vluxei16_v_u64m1(signs64, vidx_s2_2, 4); + vint8m1_t s2_1 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_1)); + vint8m1_t s2_2 = __riscv_vreinterpret_v_u8m1_i8m1(__riscv_vreinterpret_v_u64m1_u8m1(vs2_64_2)); + + // Apply signs to q8 + vint8m1_t q8s_1 = __riscv_vmul_vv_i8m1(q8_1, s2_1, 32); + vint8m1_t q8s_2 = __riscv_vmul_vv_i8m1(q8_2, s2_2, 32); + + // multiplying q2 with q8 + vint16m2_t dot1 = __riscv_vwmul_vv_i16m2(q8s_1, q2_1, 32); + vint16m2_t dot2 = __riscv_vwmul_vv_i16m2(q8s_2, q2_2, 32); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t sumv1 = __riscv_vwredsum_vs_i16m2_i32m1(dot1, zero_vec, 32); + vint32m1_t sumv2 = __riscv_vwredsum_vs_i16m2_i32m1(dot2, zero_vec, 32); + int32_t scalar_sum1 = __riscv_vmv_x_s_i32m1_i32(sumv1); + int32_t scalar_sum2 = __riscv_vmv_x_s_i32m1_i32(sumv2); + int16_t scale1 = 2 * ((s_packed_1 >> 28) & 0xF) + 1; + int16_t scale2 = 2 * ((s_packed_2 >> 28) & 0xF) + 1; + + sum += scalar_sum1 * scale1 + scalar_sum2 * scale2; + q2_ptr += 16; + } + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq2_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq2_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint64_t * grid64 = (const uint64_t *)iq2xxs_grid; + // Shift pattern {0,7,14,21} repeated 8 times for all 8 sub-blocks + uint8_t shift_arr[32] = { + 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, + 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21, 0, 7, 14, 21 + }; + vuint8mf2_t v_shifts = __riscv_vle8_v_u8mf2(shift_arr, 32); + + // Gather pattern to broadcast the 8 sub-block scales across the 32 lookup slots + uint8_t gather_arr[32] = { + 0,0,0,0, 1,1,1,1, 2,2,2,2, 3,3,3,3, + 4,4,4,4, 5,5,5,5, 6,6,6,6, 7,7,7,7 + }; + vuint8mf2_t v_sign_gather_idx = __riscv_vle8_v_u8mf2(gather_arr, 32); + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q2_ptr = (const uint8_t *) x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + vint8m4_t q8_all = __riscv_vle8_v_i8m4(q8, 256); + + // De-interleave all 8 Index/Scale pairs for the 8x32-element sub-blocks + vuint32mf2x2_t tuple = __riscv_vlseg2e32_v_u32mf2x2((const uint32_t*)q2_ptr, 8); + vuint32mf2_t v_ind32 = __riscv_vget_v_u32mf2x2_u32mf2(tuple, 0); + vuint32mf2_t v_sc32 = __riscv_vget_v_u32mf2x2_u32mf2(tuple, 1); + + vuint8mf2_t v_raw_q2 = __riscv_vreinterpret_v_u32mf2_u8mf2(v_ind32); + vuint16m1_t vidx_q2 = __riscv_vwcvtu_x_x_v_u16m1(v_raw_q2, 32); + vidx_q2 = __riscv_vsll_vx_u16m1(vidx_q2, 3, 32); + + vuint32m2_t v_s = __riscv_vrgatherei16_vv_u32m2(__riscv_vlmul_ext_v_u32mf2_u32m2(v_sc32), __riscv_vwcvtu_x_x_v_u16m1(v_sign_gather_idx,32), 32); + v_s = __riscv_vsrl_vv_u32m2(v_s, __riscv_vwcvtu_x_x_v_u32m2(__riscv_vwcvtu_x_x_v_u16m1(v_shifts,32),32), 32); + v_s = __riscv_vand_vx_u32m2(v_s, 127, 32); + vuint16m1_t vidx_s2 = __riscv_vsll_vx_u16m1(__riscv_vncvt_x_x_w_u16m1(v_s, 32), 3, 32); + + vuint64m4_t vq2_64 = __riscv_vluxei16_v_u64m4(grid64, vidx_q2, 32); + vuint64m4_t vs2_64 = __riscv_vluxei16_v_u64m4(signs64, vidx_s2, 32); + vint8m4_t q2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vq2_64)); + vint8m4_t s2_all = __riscv_vreinterpret_v_u8m4_i8m4(__riscv_vreinterpret_v_u64m4_u8m4(vs2_64)); + + vint8m4_t q8s_all = __riscv_vmul_vv_i8m4(q8_all, s2_all, 256); + vint16m8_t dot_all = __riscv_vwmul_vv_i16m8(q8s_all, q2_all, 256); + + float sum = 0.0f; + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + + for (int j = 0; j < 8; ++j) { + uint32_t s_p = __riscv_vmv_x_s_u32mf2_u32(__riscv_vslidedown_vx_u32mf2(v_sc32, j, 8)); + int16_t sc = 2 * ((s_p >> 28) & 0xF) + 1; + dot_all=__riscv_vslidedown_vx_i16m8(dot_all,j*32,32); + vint32m1_t sum_v = __riscv_vwredsum_vs_i16m8_i32m1(dot_all, zero_vec, 32); + int32_t isum = __riscv_vmv_x_s_i32m1_i32(sum_v); + sum += (float)isum * sc; + } + + sumf += sum * combined_scale; + } + *s = 0.125f * sumf; +} +#endif + +void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq2_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_iq2_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 512 and above + ggml_vec_dot_iq2_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq2_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint32_t * grid32 = (const uint32_t *)iq3s_grid; + + vuint8mf2_t v_id_8 = __riscv_vid_v_u8mf2(8); + vuint8m2_t v_id_32 = __riscv_vid_v_u8m2(32); + + // Keeping these in a tight scope to hint they're only needed for the mask computation. + vuint8m2_t v_sign_gather_indices, v_sign_masks; + { + vuint8m2_t v_shifts = __riscv_vand_vx_u8m2(v_id_32, 7, 32); + vuint8m2_t v_one_32 = __riscv_vmv_v_x_u8m2(1, 32); + v_sign_gather_indices = __riscv_vsrl_vx_u8m2(v_id_32, 3, 32); + v_sign_masks = __riscv_vsll_vv_u8m2(v_one_32, v_shifts, 32); + } + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d); + const float combined_scale = d * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + + for (int ib = 0; ib < 8; ++ib) { + + // Grid lookup + vuint8m2_t v_grid_u8; + { + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 8); + qs += 8; + + uint8_t qh_val = *qh++; + vuint8mf2_t v_qh_val = __riscv_vmv_v_x_u8mf2(qh_val, 8); + v_qh_val = __riscv_vsrl_vv_u8mf2(v_qh_val, v_id_8, 8); + v_qh_val = __riscv_vand_vx_u8mf2(v_qh_val, 1, 8); + + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 8); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 8); + + vuint16m1_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qh_val, 8); + v_qh_u16 = __riscv_vsll_vx_u16m1(v_qh_u16, 10, 8); + + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_u16, 8); + + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2(grid32, v_grid_offsets, 8); + v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + } + __asm__ volatile ("" ::: "memory"); + + //Sign application and dot product + int32_t s_val; + { + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 4); + signs += 4; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 32); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 32); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 32); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 32); + q8 += 32; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 32); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 32); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + s_val = __riscv_vmv_x_s_i32m1_i32( + __riscv_vwredsum_vs_i16m4_i32m1(v_dot, v_zero, 32)); + } + __asm__ volatile ("" ::: "memory"); + { + uint8_t sc_byte = scales[ib >> 1]; + int sc_val = (ib & 1) ? (sc_byte >> 4) : (sc_byte & 0xF); + sc_val = sc_val * 2 + 1; + sum_block += (float)(s_val * sc_val); + } + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const uint64_t * grid64 = (const uint64_t *)iq3s_grid; + + // --- Pre-load Constants --- + const uint16_t qh_bit_shifts_arr[16] = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + }; + vuint8m2_t v_sign_gather_indices = __riscv_vle8_v_u8m2(sign_gather_indices_arr, 64); + vuint8m2_t v_sign_masks = __riscv_vle8_v_u8m2(sign_bit_masks_arr, 64); + vuint16m1_t v_qh_shifts = __riscv_vle16_v_u16m1(qh_bit_shifts_arr, 16); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d); + const float combined_scale = d * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + + // Loop: Process 64 weights (16 mini-blocks of 4) per iteration + for (int ib = 0; ib < 4; ++ib) { + + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 16); + qs += 16; + + uint16_t qh_val; + memcpy(&qh_val, qh, 2); + qh += 2; + + vuint16m1_t v_qh_val = __riscv_vmv_v_x_u16m1(qh_val, 16); + // Extract bits: (qh >> i) & 1 + v_qh_val = __riscv_vsrl_vv_u16m1(v_qh_val, v_qh_shifts, 16); + v_qh_val = __riscv_vand_vx_u16m1(v_qh_val, 1, 16); + + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 16); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 16); + v_qh_val = __riscv_vsll_vx_u16m1(v_qh_val, 10, 16); + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_val, 16); + + // Grid value is 4xuint8 + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2((const uint32_t *)grid64, v_grid_offsets, 16); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + vuint8mf4_t v_signs_raw = __riscv_vle8_v_u8mf4(signs, 8); + signs += 8; + + // Generate sign mask + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf4_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather_indices, 64); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 64); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 64); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + // Apply Signs + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 64); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 64); + + // Reduction + vint16m2_t v_dot_lo = __riscv_vget_v_i16m4_i16m2(v_dot, 0); + vint16m2_t v_dot_hi = __riscv_vget_v_i16m4_i16m2(v_dot, 1); + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + int32_t s_lo = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_lo, v_zero, 32)); + int32_t s_hi = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(v_dot_hi, v_zero, 32)); + + // Apply sub-scales + uint8_t sc_byte = *scales++; + int sc_lo = (sc_byte & 0xF) * 2 + 1; + int sc_hi = (sc_byte >> 4) * 2 + 1; + + sum_block += s_lo * sc_lo + s_hi * sc_hi; + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_s_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_s * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + const uint32_t * grid32 = (const uint32_t *)iq3s_grid; + + // Generate Constants + vuint8mf2_t v_id_32 = __riscv_vid_v_u8mf2(32); + vuint8mf2_t v_qh_gather = __riscv_vsrl_vx_u8mf2(v_id_32, 3, 32); + vuint8mf2_t v_qh_shifts = __riscv_vand_vx_u8mf2(v_id_32, 7, 32); + vuint8m2_t v_id_128 = __riscv_vid_v_u8m2(128); + vuint8m2_t v_sign_gather = __riscv_vsrl_vx_u8m2(v_id_128, 3, 128); // byte index + vuint8m2_t v_sign_shift_amts = __riscv_vand_vx_u8m2(v_id_128, 7, 128); // bit shift + vuint8m2_t v_one_128 = __riscv_vmv_v_x_u8m2(1, 128); + vuint8m2_t v_sign_masks = __riscv_vsll_vv_u8m2(v_one_128, v_sign_shift_amts, 128); + vuint8m2_t v_scale_indices = __riscv_vsrl_vx_u8m2(v_id_128, 5, 128); + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float combined_scale = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT qs = x[i].qs; + const uint8_t * GGML_RESTRICT qh = x[i].qh; + const uint8_t * GGML_RESTRICT scales = x[i].scales; + const uint8_t * GGML_RESTRICT signs = x[i].signs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float sum_block = 0.0f; + for (int ib = 0; ib < 2; ++ib) { + vuint8mf2_t v_qs_u8 = __riscv_vle8_v_u8mf2(qs, 32); + qs += 32; + vuint8mf2_t v_qh_loaded = __riscv_vle8_v_u8mf2(qh, 4); + qh += 4; + vuint8mf2_t v_qh_expanded = __riscv_vrgather_vv_u8mf2(v_qh_loaded, v_qh_gather, 32); + v_qh_expanded = __riscv_vsrl_vv_u8mf2(v_qh_expanded, v_qh_shifts, 32); + v_qh_expanded = __riscv_vand_vx_u8mf2(v_qh_expanded, 1, 32); + vuint16m1_t v_qs_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qs_u8, 32); + v_qs_u16 = __riscv_vsll_vx_u16m1(v_qs_u16, 2, 32); // * 4 + + vuint16m1_t v_qh_u16 = __riscv_vwcvtu_x_x_v_u16m1(v_qh_expanded, 32); + v_qh_u16 = __riscv_vsll_vx_u16m1(v_qh_u16, 10, 32); // * 256 * 4 + + vuint16m1_t v_grid_offsets = __riscv_vor_vv_u16m1(v_qs_u16, v_qh_u16, 32); + vuint32m2_t v_grid_packed = __riscv_vluxei16_v_u32m2(grid32, v_grid_offsets, 32); + vuint8m2_t v_grid_u8 = __riscv_vreinterpret_v_u32m2_u8m2(v_grid_packed); + vuint8mf2_t v_signs_raw = __riscv_vle8_v_u8mf2(signs, 16); + signs += 16; + + vuint8m2_t v_signs_source = __riscv_vlmul_ext_v_u8mf2_u8m2(v_signs_raw); + vuint8m2_t v_signs_bcast = __riscv_vrgather_vv_u8m2(v_signs_source, v_sign_gather, 128); + vuint8m2_t v_sign_bits = __riscv_vand_vv_u8m2(v_signs_bcast, v_sign_masks, 128); + vbool4_t m_negative = __riscv_vmsne_vx_u8m2_b4(v_sign_bits, 0, 128); + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + + vint8m2_t v_q8_signed = __riscv_vrsub_vx_i8m2_mu(m_negative, v_q8, v_q8, 0, 128); + vint16m4_t v_dot = __riscv_vwmulsu_vv_i16m4(v_q8_signed, v_grid_u8, 128); + uint16_t sc_raw; + memcpy(&sc_raw, scales, 2); + scales += 2; // Advance 2 bytes + + uint8_t sc_unpacked[4]; + sc_unpacked[0] = (sc_raw & 0xF); + sc_unpacked[1] = (sc_raw >> 4) & 0xF; + sc_unpacked[2] = (sc_raw >> 8) & 0xF; + sc_unpacked[3] = (sc_raw >> 12) & 0xF; + + vuint8mf2_t v_sc_4 = __riscv_vle8_v_u8mf2(sc_unpacked, 4); + v_sc_4 = __riscv_vmul_vx_u8mf2(v_sc_4, 2, 4); + v_sc_4 = __riscv_vadd_vx_u8mf2(v_sc_4, 1, 4); + vuint8m2_t v_sc_4_expanded = __riscv_vlmul_ext_v_u8mf2_u8m2(v_sc_4); + vuint8m2_t v_scales_bcast = __riscv_vrgather_vv_u8m2(v_sc_4_expanded, v_scale_indices, 128); + vint16m4_t v_scales_i16 = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vwcvtu_x_x_v_u16m4(v_scales_bcast, 128)); + vint32m8_t v_weighted_sum = __riscv_vwmul_vv_i32m8(v_dot, v_scales_i16, 128); + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + int32_t s_val = __riscv_vmv_x_s_i32m1_i32(__riscv_vredsum_vs_i32m8_i32m1(v_weighted_sum, v_zero, 128)); + + sum_block += s_val; + } + sumf += sum_block * combined_scale; + } + *s = sumf; +} +#endif + +void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq3_s_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_iq3_s_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 512 and above + ggml_vec_dot_iq3_s_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq3_s_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // constants for unpacking logic + const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21}; + vuint32m2_t v_shifts = __riscv_vle32_v_u32m2(shifts_val, 8); + + const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint32m2_t v_gather_idx = __riscv_vle32_v_u32m2(gather_idx_val, 8); + + uint32_t aux32[2]; + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + + // Process 64 weights per loop + for (int ib = 0; ib < QK_K / 64; ++ib) { + + // load of metadata via memcpy + memcpy(aux32, metadata, 2 * sizeof(uint32_t)); + metadata += 2 * sizeof(uint32_t); + + vuint8m1_t v_q3_idx_u8 = __riscv_vle8_v_u8m1(q3_indices, 16); + q3_indices += 16; + + vuint16m2_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m2(v_q3_idx_u8, 4, 16); + + vuint32m4_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m4(grid32, v_q3_idx_u16, 16); + + vint8m4_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m4_i8m4( + __riscv_vreinterpret_v_u32m4_u8m4(v_q3_magnitudes_u32)); + + vuint32m2_t v_aux = __riscv_vle32_v_u32m2(aux32, 2); + + vuint32m2_t v_aux_expanded = __riscv_vrgather_vv_u32m2(v_aux, v_gather_idx, 8); + + vuint32m2_t v_s_vals_raw = __riscv_vand_vx_u32m2( + __riscv_vsrl_vv_u32m2(v_aux_expanded, v_shifts, 8), 127, 8); + + vuint16m1_t sign_indices_byte_offset = __riscv_vsll_vx_u16m1( + __riscv_vncvt_x_x_w_u16m1(v_s_vals_raw, 8), 3, 8); + + vuint64m4_t v_s_vals_u64 = __riscv_vluxei16_v_u64m4(signs64, sign_indices_byte_offset, 8); + + vint8m4_t v_s_vals = __riscv_vreinterpret_v_u8m4_i8m4( + __riscv_vreinterpret_v_u64m4_u8m4(v_s_vals_u64)); + + vint8m4_t v_q3_signed = __riscv_vmul_vv_i8m4(v_q3_magnitudes, v_s_vals, 64); + asm volatile("" ::: "memory"); + vint8m4_t v_q8 = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + + vint16m8_t v_dot = __riscv_vwmul_vv_i16m8(v_q8, v_q3_signed, 64); + + asm volatile("" ::: "memory"); + + vint16m4_t v_dot_1 = __riscv_vget_v_i16m8_i16m4(v_dot, 0); + vint16m4_t v_dot_2 = __riscv_vget_v_i16m8_i16m4(v_dot, 1); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + + vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m4_i32m1(v_dot_1, v_zero, 32); + vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m4_i32m1(v_dot_2, v_zero, 32); + + int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1); + int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2); + + const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1); + const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1); + + block_sum += sum1_i * scale1_f + sum2_i * scale2_f; + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // constants for unpacking logic + const uint32_t shifts_val[8] = {0, 7, 14, 21, 0, 7, 14, 21}; + vuint32m1_t v_shifts = __riscv_vle32_v_u32m1(shifts_val, 8); + + const uint32_t gather_idx_val[8] = {0, 0, 0, 0, 1, 1, 1, 1}; + vuint32m1_t v_gather_idx = __riscv_vle32_v_u32m1(gather_idx_val, 8); + + uint32_t aux32[2]; + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + + for (int ib = 0; ib < QK_K / 64; ++ib) { + // Load q8 (64 bytes) + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 64); + q8 += 64; + + // load of metadata via memcpy + memcpy(aux32, metadata, 2 * sizeof(uint32_t)); + metadata += 2 * sizeof(uint32_t); + + // Load q3 indices and gather magnitudes + vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 16); + q3_indices += 16; + + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 16); + vuint32m2_t v_q3_magnitudes_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 16); + vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u32m2_u8m2(v_q3_magnitudes_u32)); + + // --- Unpacking of Sign Indices --- + + // 1. Load the 2 auxiliary 32-bit integers into a vector + vuint32m1_t v_aux = __riscv_vle32_v_u32m1(aux32, 2); + + // 2. Broadcast/Gather: replicate aux[0] to first 4 lanes, aux[1] to next 4 lanes + vuint32m1_t v_aux_expanded = __riscv_vrgather_vv_u32m1(v_aux, v_gather_idx, 8); + + // 3. Apply Shifts and Mask: ((val >> shift) & 127) + vuint32m1_t v_s_vals_raw = __riscv_vand_vx_u32m1(__riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 8), 127, 8); + + // 4. Narrow to u16 (required for vluxei index) and multiply by 8 (byte offset for u64 table) + vuint16mf2_t sign_indices_byte_offset = __riscv_vsll_vx_u16mf2(__riscv_vncvt_x_x_w_u16mf2(v_s_vals_raw, 8), 3, 8); + + // 5. Gather Signs + vuint64m2_t v_s_vals_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_indices_byte_offset, 8); + vint8m2_t v_s_vals = __riscv_vreinterpret_v_u8m2_i8m2(__riscv_vreinterpret_v_u64m2_u8m2(v_s_vals_u64)); + + vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_s_vals, 64); + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_signed, 64); + + vint16m2_t v_dot_1 = __riscv_vget_v_i16m4_i16m2(v_dot, 0); + vint16m2_t v_dot_2 = __riscv_vget_v_i16m4_i16m2(v_dot, 1); + + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t v_sum_1 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_1, v_zero, 32); + vint32m1_t v_sum_2 = __riscv_vwredsum_vs_i16m2_i32m1(v_dot_2, v_zero, 32); + + int32_t sum1_i = __riscv_vmv_x_s_i32m1_i32(v_sum_1); + int32_t sum2_i = __riscv_vmv_x_s_i32m1_i32(v_sum_2); + + const float scale1_f = (float)(2 * (aux32[0] >> 28) + 1); + const float scale2_f = (float)(2 * (aux32[1] >> 28) + 1); + + block_sum += sum1_i * scale1_f + sum2_i * scale2_f; + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + // generate constants for unpacking metadata words into sign indices + vuint32m1_t v_shifts; + { + vuint32m1_t v_base = __riscv_vid_v_u32m1(16); + vuint32m1_t v_mod4 = __riscv_vand_vx_u32m1(v_base, 3, 16); + v_shifts = __riscv_vmul_vx_u32m1(v_mod4, 7, 16); + } + + vuint16mf2_t v_gather_idx; + { + vuint16mf2_t v_idx = __riscv_vid_v_u16mf2(16); + v_gather_idx = __riscv_vsrl_vx_u16mf2(v_idx, 2, 16); + } + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + float block_sum = 0.0f; + for (int ib128 = 0; ib128 < 2; ++ib128) { + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 128); + q8 += 128; + vuint8mf2_t v_q3_idx_u8 = __riscv_vle8_v_u8mf2(q3_indices, 32); + q3_indices += 32; + + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_u8, 4, 32); + vuint32m2_t v_q3_mag_u32 = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 32); + vint8m2_t v_q3_magnitudes = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u32m2_u8m2(v_q3_mag_u32)); + vuint32m1_t v_aux = __riscv_vreinterpret_v_u8m1_u32m1(__riscv_vle8_v_u8m1(metadata, 16)); + metadata += 4 * sizeof(uint32_t); + + vuint32m1_t v_aux_expanded = __riscv_vrgatherei16_vv_u32m1(v_aux, v_gather_idx, 16); + + vuint32m1_t v_s_raw = __riscv_vand_vx_u32m1( + __riscv_vsrl_vv_u32m1(v_aux_expanded, v_shifts, 16), 127, 16); + vuint16mf2_t sign_byte_offset = __riscv_vsll_vx_u16mf2( + __riscv_vncvt_x_x_w_u16mf2(v_s_raw, 16), 3, 16); + vuint64m2_t v_s_u64 = __riscv_vluxei16_v_u64m2(signs64, sign_byte_offset, 16); + vint8m2_t v_signs = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u64m2_u8m2(v_s_u64)); + vint8m2_t v_q3_signed = __riscv_vmul_vv_i8m2(v_q3_magnitudes, v_signs, 128); + vint16m4_t prod = __riscv_vwmul_vv_i16m4(v_q3_signed, v_q8, 128); + + vint32m1_t zero_vec = __riscv_vmv_v_x_i32m1(0, 1); + int32_t group0_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 0), zero_vec, 32)); + int32_t group1_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 1), zero_vec, 32)); + int32_t group2_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 2), zero_vec, 32)); + int32_t group3_sum = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( + __riscv_vget_v_i16m4_i16m1(prod, 3), zero_vec, 32)); + + vuint32m1_t v_scales_raw = __riscv_vsrl_vx_u32m1(v_aux, 28, 4); + vuint32m1_t v_scales = __riscv_vadd_vx_u32m1( + __riscv_vsll_vx_u32m1(v_scales_raw, 1, 4), + 1, 4); + int32_t scale0 = (int32_t)__riscv_vmv_x_s_u32m1_u32(v_scales); + int32_t scale1 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 1, 4)); + int32_t scale2 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 2, 4)); + int32_t scale3 = (int32_t)__riscv_vmv_x_s_u32m1_u32(__riscv_vslidedown_vx_u32m1(v_scales, 3, 4)); + + block_sum += (float)(group0_sum * scale0 + group1_sum * scale1 + + group2_sum * scale2 + group3_sum * scale3); + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} + +static NOINLINE void ggml_vec_dot_iq3_xxs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); UNUSED(bx); UNUSED(by); UNUSED(bs); + + const block_iq3_xxs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + const int nb = n / QK_K; + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + const uint32_t * grid32 = (const uint32_t *)iq3xxs_grid; + + vuint32m1_t v_shifts; + { + vuint32m1_t v_id = __riscv_vid_v_u32m1(32); + vuint32m1_t v_mod4 = __riscv_vand_vx_u32m1(v_id, 3, 32); + v_shifts = __riscv_vmul_vx_u32m1(v_mod4, 7, 32); + } + vuint16mf2_t v_gather_idx; + { + vuint16mf2_t v_id_16 = __riscv_vid_v_u16mf2(32); + v_gather_idx = __riscv_vsrl_vx_u16mf2(v_id_16, 2, 32); + } + + float sumf = 0.0f; + uint32_t aux32[8]; // Buffer for block metadata + + for (int i = 0; i < nb; ++i) { + const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; + + const uint8_t * GGML_RESTRICT q3_indices = x[i].qs; + const uint8_t * GGML_RESTRICT metadata = x[i].qs + QK_K/4; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + vint8m2_t v_q8 = __riscv_vle8_v_i8m2(q8, 256); + vuint8mf2_t v_q3_idx_raw = __riscv_vle8_v_u8mf2(q3_indices, 64); + vuint16m1_t v_q3_idx_u16 = __riscv_vwmulu_vx_u16m1(v_q3_idx_raw, 4, 64); + + vuint32m2_t v_q3_grid_vals = __riscv_vluxei16_v_u32m2(grid32, v_q3_idx_u16, 64); + + vint8m2_t v_q3_mags = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u32m2_u8m2(v_q3_grid_vals)); + + memcpy(aux32, metadata, 8 * sizeof(uint32_t)); + vuint32m1_t v_aux_8 = __riscv_vle32_v_u32m1(aux32, 8); + + vuint32m1_t v_aux_32 = __riscv_vrgatherei16_vv_u32m1(v_aux_8, v_gather_idx, 32); + + vuint32m1_t v_sign_idx_raw = __riscv_vand_vx_u32m1( + __riscv_vsrl_vv_u32m1(v_aux_32, v_shifts, 32), 127, 32); + + vuint16mf2_t v_sign_offsets = __riscv_vsll_vx_u16mf2( + __riscv_vncvt_x_x_w_u16mf2(v_sign_idx_raw, 32), 3, 32); + + vuint64m2_t v_signs_u64 = __riscv_vluxei16_v_u64m2(signs64, v_sign_offsets, 32); + + vint8m2_t v_signs = __riscv_vreinterpret_v_u8m2_i8m2( + __riscv_vreinterpret_v_u64m2_u8m2(v_signs_u64)); + + vint8m2_t v_q3_final = __riscv_vmul_vv_i8m2(v_q3_mags, v_signs, 256); + + vint16m4_t v_dot = __riscv_vwmul_vv_i16m4(v_q8, v_q3_final, 256); + float block_sum = 0.0f; + vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, 1); + vint16m4_t v_accum = v_dot; + + for (int j = 0; j < 8; ++j) { + float scale = (float)(2 * (aux32[j] >> 28) + 1); + + vint32m1_t v_partial_sum = __riscv_vwredsum_vs_i16m4_i32m1(v_accum, v_zero, 32); + + int32_t partial_sum_i = __riscv_vmv_x_s_i32m1_i32(v_partial_sum); + block_sum += partial_sum_i * scale; + v_accum = __riscv_vslidedown_vx_i16m4(v_accum, 32, 32); + + } + + sumf += d * block_sum; + } + *s = 0.25f * sumf; +} +#endif + +void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq3_xxs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_iq3_xxs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_iq3_xxs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 1024 and above + ggml_vec_dot_iq3_xxs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq3_xxs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + + // Load the lookup table once. + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib += 2) { + // Weights and activations. + vuint8m1_t iq4_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16); + vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32); + vuint8m1_t iq4_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16); + vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); + + // Unpack the weight blocks. + vuint8m2_t iq4bits1 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(iq4_packed1, 0xf, 16), + __riscv_vsrl_vx_u8m1(iq4_packed1, 4, 16) + ); + vuint8m2_t iq4bits2 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(iq4_packed2, 0xf, 16), + __riscv_vsrl_vx_u8m1(iq4_packed2, 4, 16) + ); + + // Gather values from the lookup table. + vint8m2_t iq4b1 = __riscv_vrgather_vv_i8m2(values, iq4bits1, 32); + vint8m2_t iq4b2 = __riscv_vrgather_vv_i8m2(values, iq4bits2, 32); + + // Accumulation. + vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, iq4b1, 32); + vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, iq4b2, 32); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_nl_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK4_NL == 0); + static_assert(QK4_NL == QK8_0, "QK4_NL and QK8_0 must be the same"); + + const block_iq4_nl * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK4_NL; + + int ib = 0; + float sumf = 0; + + // Load the lookup table once. + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib += 2) { + // Weights and activations. + vuint8mf2_t iq4_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16); + vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16); + vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16); + vuint8mf2_t iq4_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16); + vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16); + vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16); + + // Unpack the weight blocks. + vuint8mf2_t iq4bits_lo1 = __riscv_vand_vx_u8mf2(iq4_packed1, 0xf, 16); + vuint8mf2_t iq4bits_hi1 = __riscv_vsrl_vx_u8mf2(iq4_packed1, 4, 16); + vuint8mf2_t iq4bits_lo2 = __riscv_vand_vx_u8mf2(iq4_packed2, 0xf, 16); + vuint8mf2_t iq4bits_hi2 = __riscv_vsrl_vx_u8mf2(iq4_packed2, 4, 16); + + // Gather values from the lookup table. + vint8mf2_t iq4b_lo1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo1, 16); + vint8mf2_t iq4b_hi1 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi1, 16); + vint8mf2_t iq4b_lo2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_lo2, 16); + vint8mf2_t iq4b_hi2 = __riscv_vrgather_vv_i8mf2(values, iq4bits_hi2, 16); + + // Accumulation. + vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, iq4b_lo1, 16); + sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, iq4b_hi1, 16); + vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, iq4b_lo2, 16); + sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, iq4b_hi2, 16); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 0].d) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_CPU_FP16_TO_FP32(x[ib + 1].d) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq4_nl_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 256 and above + ggml_vec_dot_iq4_nl_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq4_nl_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + // We process 2 sub-blocks together. + int sumi1 = 0, sumi2 = 0; + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K / 64; ++ib) { + // Load the packed weights. + const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 32); + iq4 += 32; + + // Unpack the weight blocks. + const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 32); + const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 32); + const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + const vuint8m4_t iq4bits_reorder = __riscv_vcreate_v_u8m1_u8m4( + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 0), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 2), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 1), 16), + __riscv_vmv_v_v_u8m1(__riscv_vget_v_u8m4_u8m1(iq4bits, 3), 16) + ); + const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 64); + + // Multiply with activations. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 64); + q8 += 64; + const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 64); + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m4_i32m1(__riscv_vget_v_i16m8_i16m4(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + + const int ls1 = ((x[ibl].scales_l[ib] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[ib] >> 4) | ((h << 2) & 0x30)) - 32; + h >>= 4; + + sumi1 += acc0 * ls1; + sumi2 += acc1 * ls2; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + uint16_t index[16] = { + 0, 1, 8, 9, + 2, 3, 10, 11, + 4, 5,12, 13, + 6, 7, 14, 15, + }; + vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 16); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0; + + #pragma GCC unroll 1 + for (int ib = 0; ib < QK_K / 128; ++ib) { + // Weights and activations. + vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 64); + iq4 += 64; + + // Unpack the weight blocks. + vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 64); + vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 64); + vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 16)); + vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 128); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 128); + vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 128); + q8 += 128; + + __asm__ __volatile__("" ::: "memory"); + + // Reduce separately. + int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m2_i32m1(__riscv_vget_v_i16m8_i16m2(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + + int ls1 = ((x[ibl].scales_l[ib * 2 + 0] & 0xf) | ((h << 4) & 0x30)) - 32; + int ls2 = ((x[ibl].scales_l[ib * 2 + 0] >> 4) | ((h << 2) & 0x30)) - 32; + int ls3 = ((x[ibl].scales_l[ib * 2 + 1] & 0xf) | ((h << 0) & 0x30)) - 32; + int ls4 = ((x[ibl].scales_l[ib * 2 + 1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + + sumi1 += acc0 * ls1; + sumi2 += acc1 * ls2; + sumi3 += acc2 * ls3; + sumi4 += acc3 * ls4; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2 + sumi3 + sumi4); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m4_t values = __riscv_vle8_v_i8m4(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + const uint16_t index[32] = { + 0, 1, 16, 17, + 2, 3, 18, 19, + 4, 5,20, 21, + 6, 7, 22, 23, + 8, 9, 24, 25, + 10, 11, 26, 27, + 12, 13,28, 29, + 14, 15, 30, 31, + }; + const vuint16m1_t i_vec = __riscv_vle16_v_u16m1(index, 32); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi = 0; + + #pragma GCC unroll 1 + // Process the entire super-block together. + for (int ib = 0; ib < QK_K / 256; ++ib) { + // Weights and activations. + const vuint8m2_t iq4_packed = __riscv_vle8_v_u8m2(iq4, 128); + iq4 += 128; + + // Unpack the weight blocks. + const vuint8m2_t iq4bits_lo = __riscv_vand_vx_u8m2(iq4_packed, 0xf, 128); + const vuint8m2_t iq4bits_hi = __riscv_vsrl_vx_u8m2(iq4_packed, 4, 128); + const vuint8m4_t iq4bits = __riscv_vcreate_v_u8m2_u8m4(iq4bits_lo, iq4bits_hi); + const vuint8m4_t iq4bits_reorder = __riscv_vreinterpret_v_u64m4_u8m4(__riscv_vrgatherei16_vv_u64m4(__riscv_vreinterpret_v_u8m4_u64m4(iq4bits), i_vec, 32)); + const vint8m4_t iq4b = __riscv_vrgather_vv_i8m4(values, iq4bits_reorder, 256); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + const vint8m4_t q8b = __riscv_vle8_v_i8m4(q8, 256); + const vint16m8_t prod = __riscv_vwmul_vv_i16m8(iq4b, q8b, 256); + q8 += 256; + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 4), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 5), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 6), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1(__riscv_vget_v_i16m8_i16m1(prod, 7), __riscv_vmv_v_x_i32m1(0, 1), 32)); + + + const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32; + + sumi += acc0 * ls0; + sumi += acc1 * ls1; + sumi += acc2 * ls2; + sumi += acc3 * ls3; + sumi += acc4 * ls4; + sumi += acc5 * ls5; + sumi += acc6 * ls6; + sumi += acc7 * ls7; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_iq4_xs_q8_K_vl1024(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_K == 0); + + const block_iq4_xs * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_iq4nl, 16); + float sumf = 0; + + // Indices for re-ordering IQ4 data. + const uint16_t index[32] = { + 0, 1, 16, 17, + 2, 3, 18, 19, + 4, 5,20, 21, + 6, 7, 22, 23, + 8, 9, 24, 25, + 10, 11, 26, 27, + 12, 13,28, 29, + 14, 15, 30, 31, + }; + const vuint16mf2_t i_vec = __riscv_vle16_v_u16mf2(index, 32); + + for (int ibl = 0; ibl < nb; ++ibl) { + const int8_t * q8 = y[ibl].qs; + const uint8_t * iq4 = x[ibl].qs; + uint16_t h = x[ibl].scales_h; + + int sumi = 0; + + #pragma GCC unroll 1 + // Process the entire super-block together. + for (int ib = 0; ib < QK_K / 256; ++ib) { + // Weights and activations. + const vuint8m1_t iq4_packed = __riscv_vle8_v_u8m1(iq4, 128); + iq4 += 128; + + // Unpack the weight blocks. + const vuint8m1_t iq4bits_lo = __riscv_vand_vx_u8m1(iq4_packed, 0xf, 128); + const vuint8m1_t iq4bits_hi = __riscv_vsrl_vx_u8m1(iq4_packed, 4, 128); + const vuint8m2_t iq4bits = __riscv_vcreate_v_u8m1_u8m2(iq4bits_lo, iq4bits_hi); + const vuint8m2_t iq4bits_reorder = __riscv_vreinterpret_v_u64m2_u8m2(__riscv_vrgatherei16_vv_u64m2(__riscv_vreinterpret_v_u8m2_u64m2(iq4bits), i_vec, 32)); + const vint8m2_t iq4b = __riscv_vrgather_vv_i8m2(values, iq4bits_reorder, 256); + + __asm__ __volatile__("" ::: "memory"); + + // Multiply with activations. + const vint8m2_t q8b = __riscv_vle8_v_i8m2(q8, 256); + const vint16m4_t prod = __riscv_vwmul_vv_i16m4(iq4b, q8b, 256); + q8 += 256; + + // Mask for processing 32 elements per prod register. + const vuint16m1_t p_index = __riscv_vid_v_u16m1(64); + const vbool16_t p_mask = __riscv_vmsgtu_vx_u16m1_b16(p_index, 31, 64); + + // Reduce separately. + const int acc0 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc1 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 0), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc2 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc3 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 1), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc4 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc5 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 2), __riscv_vmv_v_x_i32m1(0, 1), 64)); + const int acc6 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1( __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 32)); + const int acc7 = __riscv_vmv_x_s_i32m1_i32(__riscv_vwredsum_vs_i16m1_i32m1_m(p_mask, __riscv_vget_v_i16m4_i16m1(prod, 3), __riscv_vmv_v_x_i32m1(0, 1), 64)); + + const int ls0 = ((x[ibl].scales_l[0] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls1 = ((x[ibl].scales_l[0] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls2 = ((x[ibl].scales_l[1] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls3 = ((x[ibl].scales_l[1] >> 4) | ((h >> 2) & 0x30)) - 32; + h >>= 8; + const int ls4 = ((x[ibl].scales_l[2] & 0xf) | ((h << 4) & 0x30)) - 32; + const int ls5 = ((x[ibl].scales_l[2] >> 4) | ((h << 2) & 0x30)) - 32; + const int ls6 = ((x[ibl].scales_l[3] & 0xf) | ((h << 0) & 0x30)) - 32; + const int ls7 = ((x[ibl].scales_l[3] >> 4) | ((h >> 2) & 0x30)) - 32; + + sumi += acc0 * ls0; + sumi += acc1 * ls1; + sumi += acc2 * ls2; + sumi += acc3 * ls3; + sumi += acc4 * ls4; + sumi += acc5 * ls5; + sumi += acc6 * ls6; + sumi += acc7 * ls7; + + __asm__ __volatile__("" ::: "memory"); + } + + sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi); + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_iq4_xs_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_iq4_xs_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + case 512: + ggml_vec_dot_iq4_xs_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + case 1024: + ggml_vec_dot_iq4_xs_q8_K_vl1024(n, s, bs, vx, bx, vy, by, nrc); + break; + default: + ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + const uint8_t * GGML_RESTRICT tq = x[i].qs; + const int8_t * GGML_RESTRICT q8 = y[i].qs; + + // First loop. + vint16m4_t suml1; + { + const int vl = 32; + const vuint8m2_t tqb = __riscv_vle8_v_u8m2(tq, vl); + tq += 32; + + { + const vuint16m4_t tq0 = __riscv_vsrl_vx_u16m4(__riscv_vwmulu_vx_u16m4(tqb, 3, vl), 8, vl); + const vint16m4_t q80 = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(q8, vl), vl); + suml1 = __riscv_vmul_vv_i16m4(__riscv_vreinterpret_v_u16m4_i16m4(__riscv_vsub_vx_u16m4(tq0, 1, vl)), q80, vl); + q8 += 32; + } + + uint8_t pow3 = 3; + #pragma GCC unroll 1 + for (int t = 0; t < 4; t++) { + const vuint16m4_t tqn = __riscv_vsrl_vx_u16m4(__riscv_vwmulu_vx_u16m4(__riscv_vmul_vx_u8m2(tqb, pow3, vl), 3, vl), 8, vl); + const vint16m4_t q8n = __riscv_vwcvt_x_x_v_i16m4(__riscv_vle8_v_i8m2(q8, vl), vl); + suml1 = __riscv_vmacc_vv_i16m4(suml1, __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vsub_vx_u16m4(tqn, 1, vl)), q8n, vl); + pow3 *= 3; + q8 += 32; + } + } + + // Second loop. + vint16m2_t suml2; + { + const int vl = 16; + const vuint8m1_t tqb = __riscv_vle8_v_u8m1(tq, vl); + + { + const vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tqb, 3, vl), 8, vl); + const vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + suml2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + q8 += 16; + } + + uint8_t pow3 = 3; + #pragma GCC unroll 1 + for (int t = 0; t < 4; t++) { + const vuint16m2_t tqn = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tqb, pow3, vl), 3, vl), 8, vl); + const vint16m2_t q8n = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + suml2 = __riscv_vmacc_vv_i16m2(suml2, __riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tqn, 1, vl)), q8n, vl); + pow3 *= 3; + q8 += 16; + } + } + + // Third loop. + vint16m2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + const vuint8m1_t tqb = __riscv_vreinterpret_v_u32m1_u8m1(__riscv_vmv_v_x_u32m1(qh, vl / 4)); + + const vuint8m1_t p = __riscv_vle8_v_u8m1(pow, vl); + + const vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vv_u8m1(tqb, p, vl), 3, vl), 8, vl); + + const vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(q8, vl), vl); + + suml3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + } + + vint16m2_t sumb = __riscv_vadd_vv_i16m2(__riscv_vget_v_i16m4_i16m2(suml1, 0), __riscv_vget_v_i16m4_i16m2(suml1, 1), 16); + sumb = __riscv_vadd_vv_i16m2(sumb, suml2, 16); + sumb = __riscv_vadd_vv_i16m2(sumb, suml3, 16); + + vint32m1_t sum = __riscv_vwredsum_vs_i16m2_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + // First loop. + vint16m2_t suml1; + { + const int vl = 32; + vuint8m1_t tq = __riscv_vle8_v_u8m1(x[i].qs, vl); + + vuint16m2_t tq0 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(tq, 3, vl), 8, vl); + vuint16m2_t tq1 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 3, vl), 3, vl), 8, vl); + vuint16m2_t tq2 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 9, vl), 3, vl), 8, vl); + vuint16m2_t tq3 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 27, vl), 3, vl), 8, vl); + vuint16m2_t tq4 = __riscv_vsrl_vx_u16m2(__riscv_vwmulu_vx_u16m2(__riscv_vmul_vx_u8m1(tq, 81, vl), 3, vl), 8, vl); + + vint16m2_t q80 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 0, vl), vl); + vint16m2_t q81 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 32, vl), vl); + vint16m2_t q82 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 64, vl), vl); + vint16m2_t q83 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 96, vl), vl); + vint16m2_t q84 = __riscv_vwcvt_x_x_v_i16m2(__riscv_vle8_v_i8m1(y[i].qs + 128, vl), vl); + + vint16m2_t sum0 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq0, 1, vl)), q80, vl); + vint16m2_t sum1 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq1, 1, vl)), q81, vl); + vint16m2_t sum2 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq2, 1, vl)), q82, vl); + vint16m2_t sum3 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq3, 1, vl)), q83, vl); + vint16m2_t sum4 = __riscv_vmul_vv_i16m2(__riscv_vreinterpret_v_u16m2_i16m2(__riscv_vsub_vx_u16m2(tq4, 1, vl)), q84, vl); + + vint16m2_t sumi0 = __riscv_vadd_vv_i16m2(sum0, sum1, vl); + vint16m2_t sumi1 = __riscv_vadd_vv_i16m2(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i16m2(sum4, __riscv_vadd_vv_i16m2(sumi0, sumi1, vl), vl); + } + + // Second loop. + vint16m1_t suml2; + { + const int vl = 16; + vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs + 32, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3 * 1, vl), 8, vl); + vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl); + vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl); + vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl); + vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 160, vl), vl); + vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 176, vl), vl); + vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 192, vl), vl); + vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 208, vl), vl); + vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 224, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl); + vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl); + vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); + vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); + + vint16m1_t sumi0 = __riscv_vadd_vv_i16m1(sum0, sum1, vl); + vint16m1_t sumi1 = __riscv_vadd_vv_i16m1(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i16m1(sum4, __riscv_vadd_vv_i16m1(sumi0, sumi1, vl), vl); + } + + // Third loop. + vint16m1_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + vuint8mf2_t tq = __riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4)); + + vuint8mf2_t p = __riscv_vle8_v_u8mf2(pow, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vv_u8mf2(tq, p, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 240, vl), vl); + + suml3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + } + + vint16m1_t sumb = __riscv_vadd_vv_i16m1(__riscv_vget_v_i16m2_i16m1(suml1, 0), __riscv_vget_v_i16m2_i16m1(suml1, 1), 16); + sumb = __riscv_vadd_vv_i16m1(sumb, __riscv_vadd_vv_i16m1(suml2, suml3, 16), 16); + + vint32m1_t sum = __riscv_vwredsum_vs_i16m1_i32m1(sumb, __riscv_vmv_v_x_i32m1(0, 1), 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_tq1_0_q8_K_vl512(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + uint8_t pow[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + for (int i = 0; i < nb; i++) { + // First loop. + vint16m1_t suml1; + { + const int vl = 32; + vuint8mf2_t tq = __riscv_vle8_v_u8mf2(x[i].qs, vl); + + vuint16m1_t tq0 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(tq, 3, vl), 8, vl); + vuint16m1_t tq1 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 3, vl), 3, vl), 8, vl); + vuint16m1_t tq2 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 9, vl), 3, vl), 8, vl); + vuint16m1_t tq3 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 27, vl), 3, vl), 8, vl); + vuint16m1_t tq4 = __riscv_vsrl_vx_u16m1(__riscv_vwmulu_vx_u16m1(__riscv_vmul_vx_u8mf2(tq, 81, vl), 3, vl), 8, vl); + + vint16m1_t q80 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 0, vl), vl); + vint16m1_t q81 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 32, vl), vl); + vint16m1_t q82 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 64, vl), vl); + vint16m1_t q83 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 96, vl), vl); + vint16m1_t q84 = __riscv_vwcvt_x_x_v_i16m1(__riscv_vle8_v_i8mf2(y[i].qs + 128, vl), vl); + + vint16m1_t sum0 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq0, 1, vl)), q80, vl); + vint16m1_t sum1 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq1, 1, vl)), q81, vl); + vint16m1_t sum2 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq2, 1, vl)), q82, vl); + vint16m1_t sum3 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq3, 1, vl)), q83, vl); + vint16m1_t sum4 = __riscv_vmul_vv_i16m1(__riscv_vreinterpret_v_u16m1_i16m1(__riscv_vsub_vx_u16m1(tq4, 1, vl)), q84, vl); + + vint16m1_t sumi0 = __riscv_vadd_vv_i16m1(sum0, sum1, vl); + vint16m1_t sumi1 = __riscv_vadd_vv_i16m1(sum2, sum3, vl); + suml1 = __riscv_vadd_vv_i16m1(sum4, __riscv_vadd_vv_i16m1(sumi0, sumi1, vl), vl); + } + + // Second loop. + vint16mf2_t suml2; + { + const int vl = 16; + vuint8mf4_t tq = __riscv_vle8_v_u8mf4(x[i].qs + 32, vl); + + vuint16mf2_t tq0 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(tq, 3 * 1, vl), 8, vl); + vuint16mf2_t tq1 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 3, vl), 3, vl), 8, vl); + vuint16mf2_t tq2 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 9, vl), 3, vl), 8, vl); + vuint16mf2_t tq3 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 27, vl), 3, vl), 8, vl); + vuint16mf2_t tq4 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vx_u8mf4(tq, 81, vl), 3, vl), 8, vl); + + vint16mf2_t q80 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 160, vl), vl); + vint16mf2_t q81 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 176, vl), vl); + vint16mf2_t q82 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 192, vl), vl); + vint16mf2_t q83 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 208, vl), vl); + vint16mf2_t q84 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 224, vl), vl); + + vint16mf2_t sum0 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq0, 1, vl)), q80, vl); + vint16mf2_t sum1 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq1, 1, vl)), q81, vl); + vint16mf2_t sum2 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq2, 1, vl)), q82, vl); + vint16mf2_t sum3 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq3, 1, vl)), q83, vl); + vint16mf2_t sum4 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq4, 1, vl)), q84, vl); + + vint16mf2_t sumi0 = __riscv_vadd_vv_i16mf2(sum0, sum1, vl); + vint16mf2_t sumi1 = __riscv_vadd_vv_i16mf2(sum2, sum3, vl); + suml2 = __riscv_vadd_vv_i16mf2(sum4, __riscv_vadd_vv_i16mf2(sumi0, sumi1, vl), vl); + } + + // Third loop. + vint16mf2_t suml3; + { + const int vl = 16; + + uint32_t qh; + memcpy(&qh, &x[i].qh[0], 4); + // Prevent fusion with vmv. + __asm__ __volatile__("" : "+r"(qh)); + vuint8mf4_t tq = __riscv_vlmul_trunc_v_u8mf2_u8mf4(__riscv_vreinterpret_v_u32mf2_u8mf2(__riscv_vmv_v_x_u32mf2(qh, vl / 4))); + + vuint8mf4_t p = __riscv_vle8_v_u8mf4(pow, vl); + + vuint16mf2_t tq0 = __riscv_vsrl_vx_u16mf2(__riscv_vwmulu_vx_u16mf2(__riscv_vmul_vv_u8mf4(tq, p, vl), 3, vl), 8, vl); + + vint16mf2_t q80 = __riscv_vwcvt_x_x_v_i16mf2(__riscv_vle8_v_i8mf4(y[i].qs + 240, vl), vl); + + suml3 = __riscv_vmul_vv_i16mf2(__riscv_vreinterpret_v_u16mf2_i16mf2(__riscv_vsub_vx_u16mf2(tq0, 1, vl)), q80, vl); + } + + vint32m1_t sum = __riscv_vwredsum_vs_i16m1_i32m1(suml1, __riscv_vmv_v_x_i32m1(0, 1), 32); + sum = __riscv_vwredsum_vs_i16mf2_i32m1(__riscv_vadd_vv_i16mf2(suml2, suml3, 16), sum, 16); + sumf += __riscv_vmv_x_s_i32m1_i32(sum) * y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_tq1_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + case 256: + ggml_vec_dot_tq1_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 512 and above + ggml_vec_dot_tq1_0_q8_K_vl512(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_tq1_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl128(const int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x[0].qs); j += 32) { + const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32]; + const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32]; + const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32]; + const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32]; + const uint8_t* px = &x[i].qs[j]; + + size_t vl = __riscv_vsetvl_e16m4(32); + vint16m4_t vacc16 = __riscv_vmv_v_x_i16m4(0, vl); + + // Load Raw Packed elements + vl = __riscv_vsetvl_e8m2(32); + vuint8m2_t vx_u8 = __riscv_vle8_v_u8m2(px, vl); + + // Process bits 1:0 + { + // Unpack + vuint8m2_t t0 = __riscv_vand_vx_u8m2(vx_u8, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t0), 1, vl); + vint8m2_t vy = __riscv_vle8_v_i8m2(py0, vl); + // Accumulate + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 3:2 + { + vuint8m2_t t1 = __riscv_vsrl_vx_u8m2(vx_u8, 2, vl); + t1 = __riscv_vand_vx_u8m2(t1, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t1), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py1, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 5:4 + { + vuint8m2_t t2 = __riscv_vsrl_vx_u8m2(vx_u8, 4, vl); + t2 = __riscv_vand_vx_u8m2(t2, 0x03, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t2), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py2, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + // Process bits 7:6 + { + vuint8m2_t t3 = __riscv_vsrl_vx_u8m2(vx_u8, 6, vl); + vint8m2_t vq = __riscv_vsub_vx_i8m2(__riscv_vreinterpret_v_u8m2_i8m2(t3), 1, vl); + + vint8m2_t vy = __riscv_vle8_v_i8m2(py3, vl); + vacc16 = __riscv_vwmacc_vv_i16m4(vacc16, vq, vy, vl); + } + __asm__ volatile("" ::: "memory"); + vl = __riscv_vsetvl_e16m4(32); + vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t vred32 = __riscv_vwredsum_vs_i16m4_i32m1(vacc16, vzero32, vl); + sumi += __riscv_vmv_x_s_i32m1_i32(vred32); + } + + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + sumf += (float)sumi * d; + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_tq2_0_q8_K_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(n % QK_K == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * GGML_RESTRICT x = vx; + const block_q8_K * GGML_RESTRICT y = vy; + + const int nb = n / QK_K; + + float sumf = 0.0f; + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x[0].qs); j += 32) { + const int8_t * py0 = &y[i].qs[j * 4 + 0 * 32]; + const int8_t * py1 = &y[i].qs[j * 4 + 1 * 32]; + const int8_t * py2 = &y[i].qs[j * 4 + 2 * 32]; + const int8_t * py3 = &y[i].qs[j * 4 + 3 * 32]; + const uint8_t* px = &x[i].qs[j]; + + size_t vlmax_16m2 = __riscv_vsetvl_e16m2(32); + vint16m2_t vacc16 = __riscv_vmv_v_x_i16m2(0, vlmax_16m2); + + size_t vl = __riscv_vsetvl_e8m1(32); + + vuint8m1_t vx_u8 = __riscv_vle8_v_u8m1(px, vl); + + vint8m1_t vy0 = __riscv_vle8_v_i8m1(py0 , vl); + vint8m1_t vy1 = __riscv_vle8_v_i8m1(py1, vl); + vint8m1_t vy2 = __riscv_vle8_v_i8m1(py2, vl); + vint8m1_t vy3 = __riscv_vle8_v_i8m1(py3, vl); + + // l=0 (bits 1:0) + vuint8m1_t t0 = __riscv_vand_vx_u8m1(vx_u8, 0x03, vl); + vint8m1_t vq0 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t0), 1, vl); + + // l=1 (bits 3:2) + vuint8m1_t t1 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 2, vl), 0x03, vl); + vint8m1_t vq1 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t1), 1, vl); + + // l=2 (bits 5:4) + vuint8m1_t t2 = __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vx_u8, 4, vl), 0x03, vl); + vint8m1_t vq2 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t2), 1, vl); + + // l=3 (bits 7:6) + vuint8m1_t t3 = __riscv_vsrl_vx_u8m1(vx_u8, 6, vl); // No final AND needed as vsrl shifts in zeros + vint8m1_t vq3 = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(t3), 1, vl); + + // 4. Multiply and accumulate + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq0, vy0, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq1, vy1, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq2, vy2, vl); + vacc16 = __riscv_vwmacc_vv_i16m2(vacc16, vq3, vy3, vl); + + vlmax_16m2 = __riscv_vsetvl_e16m2(32); + vint32m1_t vzero32 = __riscv_vmv_v_x_i32m1(0, 1); + vint32m1_t vred32 = __riscv_vwredsum_vs_i16m2_i32m1(vacc16, vzero32, vlmax_16m2); + + sumi += __riscv_vmv_x_s_i32m1_i32(vred32); + } + const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); + sumf += (float)sumi * d; + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_tq2_0_q8_K_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 256 and above + ggml_vec_dot_tq2_0_q8_K_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_tq2_0_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + +#if defined __riscv_v +static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl128(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_MXFP4; + + int ib = 0; + float sumf = 0; + + // Load the lookup table once. + const vint8m2_t values = __riscv_vle8_v_i8m2(kvalues_mxfp4, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib += 2) { + // Weights and activations. + vuint8m1_t mx_packed1 = __riscv_vle8_v_u8m1(x[ib + 0].qs, 16); + vint8m2_t q8b1 = __riscv_vle8_v_i8m2(y[ib + 0].qs, 32); + vuint8m1_t mx_packed2 = __riscv_vle8_v_u8m1(x[ib + 1].qs, 16); + vint8m2_t q8b2 = __riscv_vle8_v_i8m2(y[ib + 1].qs, 32); + + // Unpack the weight blocks. + vuint8m2_t mxbits1 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(mx_packed1, 0xf, 16), + __riscv_vsrl_vx_u8m1(mx_packed1, 4, 16) + ); + vuint8m2_t mxbits2 = __riscv_vcreate_v_u8m1_u8m2( + __riscv_vand_vx_u8m1(mx_packed2, 0xf, 16), + __riscv_vsrl_vx_u8m1(mx_packed2, 4, 16) + ); + + // Gather values from the lookup table. + vint8m2_t mxb1 = __riscv_vrgather_vv_i8m2(values, mxbits1, 32); + vint8m2_t mxb2 = __riscv_vrgather_vv_i8m2(values, mxbits2, 32); + + // Accumulation. + vint16m4_t sum1 = __riscv_vwmul_vv_i16m4(q8b1, mxb1, 32); + vint16m4_t sum2 = __riscv_vwmul_vv_i16m4(q8b2, mxb2, 32); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m4_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m4_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 32), 1); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); + } + + *s = sumf; +} + +static NOINLINE void ggml_vec_dot_mxfp4_q8_0_vl256(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_MXFP4 == 0); + static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same"); + + const block_mxfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_MXFP4; + + int ib = 0; + float sumf = 0; + + // Load the lookup table once. + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_mxfp4, 16); + int acc1, acc2; + + // We process 2 blocks at once. + for (; ib + 1 < nb; ib+=2) { + // Weights and activations. + vuint8mf2_t mx_packed1 = __riscv_vle8_v_u8mf2(x[ib + 0].qs, 16); + vint8mf2_t q8b_lo1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs, 16); + vint8mf2_t q8b_hi1 = __riscv_vle8_v_i8mf2(y[ib + 0].qs + 16, 16); + vuint8mf2_t mx_packed2 = __riscv_vle8_v_u8mf2(x[ib + 1].qs, 16); + vint8mf2_t q8b_lo2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs, 16); + vint8mf2_t q8b_hi2 = __riscv_vle8_v_i8mf2(y[ib + 1].qs + 16, 16); + + // Unpack the weight blocks. + vuint8mf2_t mxbits_lo1 = __riscv_vand_vx_u8mf2(mx_packed1, 0xf, 16); + vuint8mf2_t mxbits_hi1 = __riscv_vsrl_vx_u8mf2(mx_packed1, 4, 16); + vuint8mf2_t mxbits_lo2 = __riscv_vand_vx_u8mf2(mx_packed2, 0xf, 16); + vuint8mf2_t mxbits_hi2 = __riscv_vsrl_vx_u8mf2(mx_packed2, 4, 16); + + // Gather values from the lookup table. + vint8mf2_t mxb_lo1 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo1, 16); + vint8mf2_t mxb_hi1 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi1, 16); + vint8mf2_t mxb_lo2 = __riscv_vrgather_vv_i8mf2(values, mxbits_lo2, 16); + vint8mf2_t mxb_hi2 = __riscv_vrgather_vv_i8mf2(values, mxbits_hi2, 16); + + // Accumulation. + vint16m1_t sum1 = __riscv_vwmul_vv_i16m1(q8b_lo1, mxb_lo1, 16); + sum1 = __riscv_vwmacc_vv_i16m1(sum1, q8b_hi1, mxb_hi1, 16); + vint16m1_t sum2 = __riscv_vwmul_vv_i16m1(q8b_lo2, mxb_lo2, 16); + sum2 = __riscv_vwmacc_vv_i16m1(sum2, q8b_hi2, mxb_hi2, 16); + __riscv_vse32_v_i32m1(&acc1,__riscv_vwredsum_vs_i16m1_i32m1(sum1, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + __riscv_vse32_v_i32m1(&acc2,__riscv_vwredsum_vs_i16m1_i32m1(sum2, __riscv_vmv_v_x_i32m1(0, 1), 16), 1); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * acc1)); + sumf += ((GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * acc2)); + } + + *s = sumf; +} +#endif + +void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { +#if defined __riscv_v + switch (__riscv_vlenb() * 8) { + case 128: + ggml_vec_dot_mxfp4_q8_0_vl128(n, s, bs, vx, bx, vy, by, nrc); + break; + default: // 256 and above + ggml_vec_dot_mxfp4_q8_0_vl256(n, s, bs, vx, bx, vy, by, nrc); + break; + } +#else + ggml_vec_dot_mxfp4_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} diff --git a/ggml/src/ggml-cpu/arch/riscv/repack.cpp b/ggml/src/ggml-cpu/arch/riscv/repack.cpp index 2a35ff9ad87..c37488cae54 100644 --- a/ggml/src/ggml-cpu/arch/riscv/repack.cpp +++ b/ggml/src/ggml-cpu/arch/riscv/repack.cpp @@ -24,6 +24,93 @@ #define UNUSED GGML_UNUSED +void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + +#if defined(__riscv_v_intrinsic) + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + const size_t vl_calc = __riscv_vsetvl_e32m8(QK8_0); + const size_t vl_save = __riscv_vsetvl_e64m2(4); + vfloat32m1_t v_scalar_zero = __riscv_vfmv_s_f_f32m1(0.0f, __riscv_vsetvl_e32m1(1)); + + for (int i = 0; i < nb; i++) { + const float *x_block_base = x + i * QK8_0; + vint8m2_t q_r0, q_r1, q_r2, q_r3; + { + vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 0 * k, vl_calc); + vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc); + vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc); + float amax = __riscv_vfmv_f_s_f32m1_f32(v_max); + + float d = amax / 127.0f; + y[i].d[0] = GGML_CPU_FP32_TO_FP16(d); + + float id = d ? 1.0f / d : 0.0f; + vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc); + vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc); + q_r0 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc); + } + asm volatile ("" ::: "memory"); + + { + vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 1 * k, vl_calc); + vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc); + vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc); + float amax = __riscv_vfmv_f_s_f32m1_f32(v_max); + + float d = amax / 127.0f; + y[i].d[1] = GGML_CPU_FP32_TO_FP16(d); + float id = d ? 1.0f / d : 0.0f; + + vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc); + vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc); + q_r1 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc); + } + asm volatile ("" ::: "memory"); + { + vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 2 * k, vl_calc); + vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc); + vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc); + float amax = __riscv_vfmv_f_s_f32m1_f32(v_max); + + float d = amax / 127.0f; + y[i].d[2] = GGML_CPU_FP32_TO_FP16(d); + float id = d ? 1.0f / d : 0.0f; + + vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc); + vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc); + q_r2 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc); + } + asm volatile ("" ::: "memory"); + { + vfloat32m8_t v_src = __riscv_vle32_v_f32m8(x_block_base + 3 * k, vl_calc); + vfloat32m8_t v_abs = __riscv_vfabs_v_f32m8(v_src, vl_calc); + vfloat32m1_t v_max = __riscv_vfredmax_vs_f32m8_f32m1(v_abs, v_scalar_zero, vl_calc); + float amax = __riscv_vfmv_f_s_f32m1_f32(v_max); + + float d = amax / 127.0f; + y[i].d[3] = GGML_CPU_FP32_TO_FP16(d); + float id = d ? 1.0f / d : 0.0f; + + vfloat32m8_t v_scaled = __riscv_vfmul_vf_f32m8(v_src, id, vl_calc); + vint16m4_t v_i16 = __riscv_vfncvt_x_f_w_i16m4_rm(v_scaled, 4, vl_calc); + q_r3 = __riscv_vncvt_x_x_w_i8m2(v_i16, vl_calc); + } + vint64m2_t v_q64_r0 = __riscv_vreinterpret_v_i8m2_i64m2(q_r0); + vint64m2_t v_q64_r1 = __riscv_vreinterpret_v_i8m2_i64m2(q_r1); + vint64m2_t v_q64_r2 = __riscv_vreinterpret_v_i8m2_i64m2(q_r2); + vint64m2_t v_q64_r3 = __riscv_vreinterpret_v_i8m2_i64m2(q_r3); + vint64m2x4_t v_quant_tuple = __riscv_vcreate_v_i64m2x4(v_q64_r0, v_q64_r1, v_q64_r2, v_q64_r3); + __riscv_vsseg4e64_v_i64m2x4((int64_t*)y[i].qs, v_quant_tuple, vl_save); + } +#else + UNUSED(nb); + ggml_quantize_mat_q8_0_4x8_generic(x, vy, k); +#endif +} + void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -115,6 +202,471 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemv_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh +void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 1x16 Integer Accumulator + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK4_0 / 2; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i], b_0_lo, 16); + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[16 + i], b_0_hi, 16); + } + + const vint32m2_t sumi = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } +} + +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0, 16); + + // Load `dmin`. + const vfloat32m2_t dmins_d = __riscv_vfmul_vf_f32m2( + __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16), a_ptr[l].d, 16); + + // We process 4 sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); + + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums = __riscv_vmv_v_x_i32m2(0, 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8] + a_ptr[l].bsums[j * 8 + 1], __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 2] + a_ptr[l].bsums[j * 8 + 3], __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 4] + a_ptr[l].bsums[j * 8 + 5], __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums = __riscv_vwmacc_vx_i32m2(bsums, a_ptr[l].bsums[j * 8 + 6] + a_ptr[l].bsums[j * 8 + 7], __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + + sumf = __riscv_vfsub_vv_f32m2(sumf, __riscv_vfmul_vv_f32m2(dmins_d, __riscv_vfcvt_f_x_v_f32m2(bsums, 16), 16), 16); + + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + i], b_s_0, 16); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 32 + i], b_s_1, 16); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_s_0_16, 16); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_s_1_16, 16); + } + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + vint16m1_t sumi_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_s_0_16, a_ptr[l].qs[j * 128 + 64 + i], b_s_0, 16); + sumi_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_s_1_16, a_ptr[l].qs[j * 128 + 96 + i], b_s_1, 16); + } + + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_s_0_16, 16); + sumi = __riscv_vwmacc_vv_i32m2(sumi, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_s_1_16, 16); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)&b_ptr[l].d[0], 16), 16); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } +} + +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + + // 1x16 Accumulator1 + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 1x16 integer accumulator + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK4_NL / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); + // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); + // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); + + const vint16m1_t sumi_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i], 16); + const vint16m1_t sumi_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[16 + i], 16); + sumi = __riscv_vadd_vv_i32m2(sumi, __riscv_vwadd_vv_i32m2(sumi_lo, sumi_hi, 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } +} + +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + UNUSED(bs); + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + // 1x16 Accumulator + vfloat32m2_t sumf = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 1x16 Integer Accumulator + vint32m2_t sumi = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK8_0; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + + sumi = __riscv_vwadd_wv_i32m2(sumi, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i], 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d, 16); + + sumf = __riscv_vfmacc_vv_f32m2(sumf, __riscv_vfcvt_f_x_v_f32m2(sumi, 16), d_0, 16); + } + + __riscv_vse32_v_f32m2(s + x * 16, sumf, 16); + } +} + +void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % 16 == 0); + + UNUSED(bs); + + const int N_COLS_TILE = 16; + const int num_k_blocks = n / QK_K; + + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + + const block_q8_K* lhs_base_ptr = (const block_q8_K*)vy; + const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; + + vfloat32m2_t v_sumf = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_K* lhs_current = &lhs_base_ptr[k_block]; + const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; + + // 1. Prepare Global Min Scales + vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); + vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl); + + vfloat32m2_t v_g_min_final = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d, vl); + + vint32m2_t v_isum = __riscv_vmv_v_x_i32m2(0, vl); + + const uint8_t* rhs_qs_ptr = rhs_current->qs; + const uint8_t* rhs_sc_ptr = rhs_current->scales; + const int8_t* lhs_qs_ptr = lhs_current->qs; + + // --- Phase Loop (4 phases x 64 elements) --- + for (int phase = 0; phase < 4; ++phase) { + + // A. Load Scales/Mins + vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3; + vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3; + + { + vuint8mf2_t v_raw; + // Sub-block 0 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); + v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 1 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); + v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 2 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); + v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 3 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); + v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + rhs_sc_ptr += 64; + } + + int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16); + int k_offsets[4] = {0, 32, 64, 96}; + + // B. Inner Dot Product Loop + for (int l = 0; l < 16; ++l) { + vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += 16; + + // Sub-block 0 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[0] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 1 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[1] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 2 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[2] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + // Sub-block 3 + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl); + + int8_t q8 = lhs_qs_ptr[base_k_phase + k_offsets[3] + l]; + v_isum = __riscv_vwmacc_vx_i32m2(v_isum, (int16_t)q8, v_w, vl); + } + } + + // correction + int sb_base_abs = base_k_phase / 16; + + // Sub-block 0 + { + int sb_idx = sb_base_abs + (k_offsets[0] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 1 + { + int sb_idx = sb_base_abs + (k_offsets[1] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 2 + { + int sb_idx = sb_base_abs + (k_offsets[2] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + // Sub-block 3 + { + int sb_idx = sb_base_abs + (k_offsets[3] / 16); + int16_t bsum = lhs_current->bsums[sb_idx]; + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, bsum, vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min_final, vl); + v_sumf = __riscv_vfsub_vv_f32m2(v_sumf, vf_c, vl); + } + + } // End Phase Loop + + // Apply global Scales + vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl); + vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl); + + vfloat32m2_t v_g_all_final = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d, vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all_final, vl); + v_sumf = __riscv_vfadd_vv_f32m2(v_sumf, v_sum, vl); + + } // End K-Block + __riscv_vse32_v_f32m2(s + col_tile, v_sumf, vl); + } +} +#endif + void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; @@ -340,3 +892,812 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif ggml_gemm_q4_0_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } + +#if defined __riscv_zvfh +void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 4x16 integer accumulators + vint16m1_t sumi_0_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_lo_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_hi_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK4_0 / 2; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0_packed = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vsra_vx_i8mf2(__riscv_vsll_vx_i8mf2(b_0_packed, 4, 16), 4, 16); + const vint8mf2_t b_0_hi = __riscv_vsra_vx_i8mf2(b_0_packed, 4, 16); + + sumi_0_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_0_lo_16, a_ptr[l].qs[i * 4], b_0_lo, 16); + sumi_1_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_1_lo_16, a_ptr[l].qs[i * 4 + 1], b_0_lo, 16); + sumi_2_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_2_lo_16, a_ptr[l].qs[i * 4 + 2], b_0_lo, 16); + sumi_3_lo_16 = __riscv_vwmacc_vx_i16m1(sumi_3_lo_16, a_ptr[l].qs[i * 4 + 3], b_0_lo, 16); + + sumi_0_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_0_hi_16, a_ptr[l].qs[64 + i * 4], b_0_hi, 16); + sumi_1_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_1_hi_16, a_ptr[l].qs[64 + i * 4 + 1], b_0_hi, 16); + sumi_2_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_2_hi_16, a_ptr[l].qs[64 + i * 4 + 2], b_0_hi, 16); + sumi_3_hi_16 = __riscv_vwmacc_vx_i16m1(sumi_3_hi_16, a_ptr[l].qs[64 + i * 4 + 3], b_0_hi, 16); + } + + // Do the final accumulation in i32 to prevent overflow. + const vint32m2_t sumi_0 = __riscv_vwadd_vv_i32m2(sumi_0_lo_16, sumi_0_hi_16, 16); + const vint32m2_t sumi_1 = __riscv_vwadd_vv_i32m2(sumi_1_lo_16, sumi_1_hi_16, 16); + const vint32m2_t sumi_2 = __riscv_vwadd_vv_i32m2(sumi_2_lo_16, sumi_2_hi_16, 16); + const vint32m2_t sumi_3 = __riscv_vwadd_vv_i32m2(sumi_3_lo_16, sumi_3_hi_16, 16); + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } +} + +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0, 16); + + // Load `dmin`. + const vfloat32m2_t dmins = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].dmin, 16), 16); + + // We process 4 sub-blocks at once. + for (int j = 0; j < QK_K / 128; j++) { + // Extract the scales and the mins. + // + // Low bits. + vuint8m2_t scales_mins_lo = __riscv_vle8_v_u8m2(&b_ptr[l].scales[j * 64], 64); + vuint8m2_t scales_lo = __riscv_vand_vx_u8m2(scales_mins_lo, 0x0F, 64); + vuint8m2_t mins_lo = __riscv_vsrl_vx_u8m2(scales_mins_lo, 4, 64); + + // High bits. + vuint8m2_t scales_mins_hi = __riscv_vle8_v_u8m2(&b_ptr[l].scales[128], 64); + vuint8m2_t scales_hi; + vuint8m2_t mins_hi; + if (!j) { + scales_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x03, 64), 4, 64); + mins_hi = __riscv_vsll_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0x0C, 64), 2, 64); + } else { + scales_hi = __riscv_vand_vx_u8m2(scales_mins_hi, 0x30, 64); + mins_hi = __riscv_vsrl_vx_u8m2(__riscv_vand_vx_u8m2(scales_mins_hi, 0xC0, 64), 2, 64); + } + vuint16m4_t scales = __riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(scales_hi, scales_lo, 64), 64); + vint16m4_t mins = __riscv_vreinterpret_v_u16m4_i16m4(__riscv_vzext_vf2_u16m4(__riscv_vor_vv_u8m2(mins_hi, mins_lo, 64), 64)); + + // Reduce the mins and multiply with `dmin`. + // + // Correct in `sumf`. + vint32m2_t bsums_0 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_1 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_2 = __riscv_vmv_v_x_i32m2(0, 16); + vint32m2_t bsums_3 = __riscv_vmv_v_x_i32m2(0, 16); + + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32] + a_ptr[l].bsums[j * 32 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 1] + a_ptr[l].bsums[j * 32 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 2] + a_ptr[l].bsums[j * 32 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 3] + a_ptr[l].bsums[j * 32 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 0), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 8] + a_ptr[l].bsums[j * 32 + 8 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 8 + 1] + a_ptr[l].bsums[j * 32 + 8 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 8 + 2] + a_ptr[l].bsums[j * 32 + 8 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 8 + 3] + a_ptr[l].bsums[j * 32 + 8 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 1), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 16] + a_ptr[l].bsums[j * 32 + 16 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 16 + 1] + a_ptr[l].bsums[j * 32 + 16 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 16 + 2] + a_ptr[l].bsums[j * 32 + 16 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 16 + 3] + a_ptr[l].bsums[j * 32 + 16 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 2), 16); + bsums_0 = __riscv_vwmacc_vx_i32m2(bsums_0, + a_ptr[l].bsums[j * 32 + 24 + 0] + a_ptr[l].bsums[j * 32 + 24 + 4], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_1 = __riscv_vwmacc_vx_i32m2(bsums_1, + a_ptr[l].bsums[j * 32 + 24 + 1] + a_ptr[l].bsums[j * 32 + 24 + 5], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_2 = __riscv_vwmacc_vx_i32m2(bsums_2, + a_ptr[l].bsums[j * 32 + 24 + 2] + a_ptr[l].bsums[j * 32 + 24 + 6], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + bsums_3 = __riscv_vwmacc_vx_i32m2(bsums_3, + a_ptr[l].bsums[j * 32 + 24 + 3] + a_ptr[l].bsums[j * 32 + 24 + 7], + __riscv_vget_v_i16m4_i16m1(mins, 3), 16); + + const vfloat32m2_t dmins_d_0 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[0], 16); + const vfloat32m2_t dmins_d_1 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[1], 16); + const vfloat32m2_t dmins_d_2 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[2], 16); + const vfloat32m2_t dmins_d_3 = __riscv_vfmul_vf_f32m2(dmins, a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfsub_vv_f32m2(sumf_0, __riscv_vfmul_vv_f32m2(dmins_d_0, __riscv_vfcvt_f_x_v_f32m2(bsums_0, 16), 16), 16); + sumf_1 = __riscv_vfsub_vv_f32m2(sumf_1, __riscv_vfmul_vv_f32m2(dmins_d_1, __riscv_vfcvt_f_x_v_f32m2(bsums_1, 16), 16), 16); + sumf_2 = __riscv_vfsub_vv_f32m2(sumf_2, __riscv_vfmul_vv_f32m2(dmins_d_2, __riscv_vfcvt_f_x_v_f32m2(bsums_2, 16), 16), 16); + sumf_3 = __riscv_vfsub_vv_f32m2(sumf_3, __riscv_vfmul_vv_f32m2(dmins_d_3, __riscv_vfcvt_f_x_v_f32m2(bsums_3, 16), 16), 16); + + + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4x16 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + i * 4], b_s_0, 16); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 1], b_s_0, 16); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 2], b_s_0, 16); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + i * 4 + 3], b_s_0, 16); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4], b_s_1, 16); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 1], b_s_1, 16); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 2], b_s_1, 16); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 128 + i * 4 + 3], b_s_1, 16); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_0_s_0_16, 16); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_0_s_1_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_1_s_0_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_1_s_1_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_2_s_0_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_2_s_1_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 0)), + sumi_3_s_0_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 1)), + sumi_3_s_1_16, 16); + } + // Accumulation for 2 sub-blocks. + // + // This might overflow, so we accumulate in two steps. + // + // Recheck. + for (int k = 0; k < 2; k++) { + // 4x16 integer accumulators + vint16m1_t sumi_0_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_0_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_0_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_1_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_2_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + vint16m1_t sumi_3_s_1_16 = __riscv_vmv_v_x_i16m1(0.0f, 16); + + for (int i = k * 16; i < k * 16 + QK4_0 / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2(&b_ptr[l].qs[j * 1024 + 512 + i * 16], 16); + const vint8mf2_t b_s_0 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vand_vx_u8mf2(b_0_packed, 0xF, 16)); + const vint8mf2_t b_s_1 = __riscv_vreinterpret_v_u8mf2_i8mf2(__riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16)); + + sumi_0_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4], b_s_0, 16); + sumi_1_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 1], b_s_0, 16); + sumi_2_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 2], b_s_0, 16); + sumi_3_s_0_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_0_16, a_ptr[l].qs[j * 512 + 256 + i * 4 + 3], b_s_0, 16); + + sumi_0_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_0_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4], b_s_1, 16); + sumi_1_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_1_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 1], b_s_1, 16); + sumi_2_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_2_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 2], b_s_1, 16); + sumi_3_s_1_16 = __riscv_vwmacc_vx_i16m1(sumi_3_s_1_16, a_ptr[l].qs[j * 512 + 384 + i * 4 + 3], b_s_1, 16); + } + + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_0_s_0_16, 16); + sumi_0 = __riscv_vwmacc_vv_i32m2(sumi_0, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_0_s_1_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_1_s_0_16, 16); + sumi_1 = __riscv_vwmacc_vv_i32m2(sumi_1, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_1_s_1_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_2_s_0_16, 16); + sumi_2 = __riscv_vwmacc_vv_i32m2(sumi_2, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_2_s_1_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 2)), + sumi_3_s_0_16, 16); + sumi_3 = __riscv_vwmacc_vv_i32m2(sumi_3, + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vget_v_u16m4_u16m1(scales, 3)), + sumi_3_s_1_16, 16); + } + } + + const vfloat32m2_t b_d = __riscv_vfwcvt_f_f_v_f32m2(__riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16), 16); + const vfloat32m2_t d_0 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfmul_vf_f32m2(b_d, a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } +} + +void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + const vint8mf2_t values = __riscv_vle8_v_i8mf2(kvalues_iq4nl, 16); + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 4x16 integer accumulators + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK4_NL / 2; i++) { + // Load `b_ptr`. + const vuint8mf2_t b_0_packed = __riscv_vle8_v_u8mf2((const uint8_t *)&b_ptr[l].qs[i * 16], 16); + const vint8mf2_t b_0_lo = __riscv_vrgather_vv_i8mf2(values, __riscv_vand_vx_u8mf2(b_0_packed, 0xf, 16), 16); + const vint8mf2_t b_0_hi = __riscv_vrgather_vv_i8mf2(values, __riscv_vsrl_vx_u8mf2(b_0_packed, 4, 16), 16); + // const vint16m1_t b_0_lo_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_lo, 16); + // const vint16m1_t b_0_hi_16 = __riscv_vwcvt_x_x_v_i16m1(b_0_hi, 16); + + const vint16m1_t sumi_0_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4], 16); + const vint16m1_t sumi_1_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 1], 16); + const vint16m1_t sumi_2_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 2], 16); + const vint16m1_t sumi_3_lo = __riscv_vwmul_vx_i16m1(b_0_lo, a_ptr[l].qs[i * 4 + 3], 16); + + const vint16m1_t sumi_0_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4], 16); + const vint16m1_t sumi_1_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 1], 16); + const vint16m1_t sumi_2_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 2], 16); + const vint16m1_t sumi_3_hi = __riscv_vwmul_vx_i16m1(b_0_hi, a_ptr[l].qs[64 + i * 4 + 3], 16); + + sumi_0 = __riscv_vadd_vv_i32m2(sumi_0, __riscv_vwadd_vv_i32m2(sumi_0_lo, sumi_0_hi, 16), 16); + sumi_1 = __riscv_vadd_vv_i32m2(sumi_1, __riscv_vwadd_vv_i32m2(sumi_1_lo, sumi_1_hi, 16), 16); + sumi_2 = __riscv_vadd_vv_i32m2(sumi_2, __riscv_vwadd_vv_i32m2(sumi_2_lo, sumi_2_hi, 16), 16); + sumi_3 = __riscv_vadd_vv_i32m2(sumi_3, __riscv_vwadd_vv_i32m2(sumi_3_lo, sumi_3_hi, 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } +} + +void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + // 4x16 Accumulators + vfloat32m2_t sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + vfloat32m2_t sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, 16); + + for (int l = 0; l < nb; l++) { + // 4x16 Integer Accumulators + vint32m2_t sumi_0 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_1 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_2 = __riscv_vmv_v_x_i32m2(0.0f, 16); + vint32m2_t sumi_3 = __riscv_vmv_v_x_i32m2(0.0f, 16); + + // Accumulation loop. + for (int i = 0; i < QK8_0; i++) { + // Load `b_ptr`. + const vint8mf2_t b_0 = __riscv_vle8_v_i8mf2((const int8_t *)&b_ptr[l].qs[i * 16], 16); + // const vint16m1_t b_0_16 = __riscv_vwcvt_x_x_v_i16m1(b_0, 16); + + sumi_0 = __riscv_vwadd_wv_i32m2(sumi_0, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 0], 16), 16); + sumi_1 = __riscv_vwadd_wv_i32m2(sumi_1, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 1], 16), 16); + sumi_2 = __riscv_vwadd_wv_i32m2(sumi_2, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 2], 16), 16); + sumi_3 = __riscv_vwadd_wv_i32m2(sumi_3, __riscv_vwmul_vx_i16m1(b_0, a_ptr[l].qs[i * 4 + 3], 16), 16); + } + + const vfloat16m1_t b_d = __riscv_vle16_v_f16m1((const _Float16 *)b_ptr[l].d, 16); + const vfloat32m2_t d_0 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[0], 16); + const vfloat32m2_t d_1 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[1], 16); + const vfloat32m2_t d_2 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[2], 16); + const vfloat32m2_t d_3 = __riscv_vfwmul_vf_f32m2(b_d, *(const _Float16 *)&a_ptr[l].d[3], 16); + + sumf_0 = __riscv_vfmacc_vv_f32m2(sumf_0, __riscv_vfcvt_f_x_v_f32m2(sumi_0, 16), d_0, 16); + sumf_1 = __riscv_vfmacc_vv_f32m2(sumf_1, __riscv_vfcvt_f_x_v_f32m2(sumi_1, 16), d_1, 16); + sumf_2 = __riscv_vfmacc_vv_f32m2(sumf_2, __riscv_vfcvt_f_x_v_f32m2(sumi_2, 16), d_2, 16); + sumf_3 = __riscv_vfmacc_vv_f32m2(sumf_3, __riscv_vfcvt_f_x_v_f32m2(sumi_3, 16), d_3, 16); + } + + __riscv_vse32_v_f32m2(s + (y * 4 + 0) * bs + x * 16, sumf_0, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 1) * bs + x * 16, sumf_1, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 2) * bs + x * 16, sumf_2, 16); + __riscv_vse32_v_f32m2(s + (y * 4 + 3) * bs + x * 16, sumf_3, 16); + } + } +} + +void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + const int num_k_blocks = n / QK_K; + const int N_ROWS_TILE = 4; + const int N_COLS_TILE = 16; + assert(nr % N_ROWS_TILE == 0); + assert(nc % N_COLS_TILE == 0); + + const size_t vl = __riscv_vsetvl_e32m2(N_COLS_TILE); + // --- Tiling Loops --- +#pragma GCC unroll 1 + for (int row_tile = 0; row_tile < nr; row_tile += N_ROWS_TILE) { +#pragma GCC unroll 1 + for (int col_tile = 0; col_tile < nc; col_tile += N_COLS_TILE) { + // Base Pointers + const block_q8_Kx4* lhs_base_ptr = (const block_q8_Kx4*)vy + (row_tile / N_ROWS_TILE) * num_k_blocks; + const block_q2_Kx16* rhs_base_ptr = (const block_q2_Kx16*)vx + (col_tile / N_COLS_TILE) * num_k_blocks; + + // Persistent Float Accumulators + vfloat32m2_t v_sumf_0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t v_sumf_3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + // --- Super-Block Loop (K=0..255) --- +#pragma GCC unroll 1 + for (int k_block = 0; k_block < num_k_blocks; ++k_block) { + const block_q8_Kx4* lhs_current = &lhs_base_ptr[k_block]; + const block_q2_Kx16* rhs_current = &rhs_base_ptr[k_block]; + + // 1. Load Global Min Scales (Keep as F16/LMUL=1 to save registers) + vfloat16m1_t v_g_min_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->dmin, vl); + vfloat32m2_t v_g_min_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_min_f16, vl); + + // 2. Initialize Integer Accumulators + vint32m2_t v_isum_0 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_1 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_2 = __riscv_vmv_v_x_i32m2(0, vl); + vint32m2_t v_isum_3 = __riscv_vmv_v_x_i32m2(0, vl); + + const uint8_t* rhs_qs_ptr = rhs_current->qs; + const uint8_t* rhs_sc_ptr = rhs_current->scales; + const int8_t* lhs_qs_ptr = lhs_current->qs; + + // --- Phase Loop (4 phases x 64 elements) --- +#pragma GCC unroll 1 + for (int phase = 0; phase < 4; ++phase) { + + // A. Load Scales/Mins for the 4 interleaved sub-blocks + vuint16m1_t v_d_sb_0, v_d_sb_1, v_d_sb_2, v_d_sb_3; + vuint16m1_t v_m_sb_0, v_m_sb_1, v_m_sb_2, v_m_sb_3; + + // Unrolled Load Logic + { + vuint8mf2_t v_raw; + // Sub-block 0 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 0, vl); + v_d_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_0 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 1 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 16, vl); + v_d_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_1 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 2 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 32, vl); + v_d_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_2 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + // Sub-block 3 + v_raw = __riscv_vle8_v_u8mf2(rhs_sc_ptr + 48, vl); + v_d_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vand_vx_u8mf2(v_raw, 0xF, vl), vl); + v_m_sb_3 = __riscv_vzext_vf2_u16m1(__riscv_vsrl_vx_u8mf2(v_raw, 4, vl), vl); + + rhs_sc_ptr += 64; + } + + int base_k_phase = (phase < 2) ? (phase * 16) : (128 + (phase-2)*16); + int k_offsets[4] = {0, 32, 64, 96}; + + // B. Inner Dot Product Loop +#pragma GCC unroll 1 + for (int l = 0; l < 16; ++l) { + vuint8mf2_t v_rhs_data = __riscv_vle8_v_u8mf2(rhs_qs_ptr, vl); + rhs_qs_ptr += 16; + + // Unroll over 4 sub-blocks (0, 1, 2, 3 relative to phase) + + // --- Sub-block 0 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(v_rhs_data, 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_0), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[0] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 1 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 2, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_1), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[1] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 2 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 4, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_2), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[2] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + // --- Sub-block 3 --- + { + vuint8mf2_t v_q2 = __riscv_vand_vx_u8mf2(__riscv_vsrl_vx_u8mf2(v_rhs_data, 6, vl), 3, vl); + vint16m1_t v_w = __riscv_vmul_vv_i16m1( + __riscv_vreinterpret_v_u16m1_i16m1(__riscv_vzext_vf2_u16m1(v_q2, vl)), + __riscv_vreinterpret_v_u16m1_i16m1(v_d_sb_3), vl); + + const int8_t* q8 = &lhs_qs_ptr[(base_k_phase + k_offsets[3] + l) * 4]; + v_isum_0 = __riscv_vwmacc_vx_i32m2(v_isum_0, (int16_t)q8[0], v_w, vl); + v_isum_1 = __riscv_vwmacc_vx_i32m2(v_isum_1, (int16_t)q8[1], v_w, vl); + v_isum_2 = __riscv_vwmacc_vx_i32m2(v_isum_2, (int16_t)q8[2], v_w, vl); + v_isum_3 = __riscv_vwmacc_vx_i32m2(v_isum_3, (int16_t)q8[3], v_w, vl); + } + } + + // C CORRECTION + int sb_base_abs = base_k_phase / 16; + + // --- Correction Sub-block 0 --- + { + int sb_abs = sb_base_abs + (k_offsets[0] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_0); + + // Row 0 + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + // Row 1 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + // Row 2 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + // Row 3 + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); + } + + // --- Correction Sub-block 1 --- + { + int sb_abs = sb_base_abs + (k_offsets[1] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_1); + + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); + } + + // --- Correction Sub-block 2 --- + { + int sb_abs = sb_base_abs + (k_offsets[2] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_2); + + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); + } + + // --- Correction Sub-block 3 --- + { + int sb_abs = sb_base_abs + (k_offsets[3] / 16); + vint16m1_t v_min = __riscv_vreinterpret_v_u16m1_i16m1(v_m_sb_3); + + vfloat32m2_t v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[0], vl); + vint32m2_t v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 0], vl); + vfloat32m2_t vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_0 = __riscv_vfsub_vv_f32m2(v_sumf_0, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[1], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 1], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_1 = __riscv_vfsub_vv_f32m2(v_sumf_1, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[2], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 2], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_2 = __riscv_vfsub_vv_f32m2(v_sumf_2, vf_c, vl); + + v_g_min = __riscv_vfmul_vf_f32m2(v_g_min_base, lhs_current->d[3], vl); + v_c = __riscv_vwmul_vx_i32m2(v_min, lhs_current->bsums[sb_abs * 4 + 3], vl); + vf_c = __riscv_vfmul_vv_f32m2(__riscv_vfcvt_f_x_v_f32m2(v_c, vl), v_g_min, vl); + v_sumf_3 = __riscv_vfsub_vv_f32m2(v_sumf_3, vf_c, vl); + } + + } // End Phase Loop + + // --- Apply Main Scales --- + vfloat16m1_t v_g_all_f16 = __riscv_vle16_v_f16m1((const _Float16*)rhs_current->d, vl); + vfloat32m2_t v_g_all_base = __riscv_vfwcvt_f_f_v_f32m2(v_g_all_f16, vl); + + { + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[0], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_0, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_0 = __riscv_vfadd_vv_f32m2(v_sumf_0, v_sum, vl); + } + // Row 1 + { + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[1], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_1, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_1 = __riscv_vfadd_vv_f32m2(v_sumf_1, v_sum, vl); + } + // Row 2 + { + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[2], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_2, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_2 = __riscv_vfadd_vv_f32m2(v_sumf_2, v_sum, vl); + } + // Row 3 + { + vfloat32m2_t v_g_all = __riscv_vfmul_vf_f32m2(v_g_all_base, lhs_current->d[3], vl); + vfloat32m2_t v_sum = __riscv_vfcvt_f_x_v_f32m2(v_isum_3, vl); + v_sum = __riscv_vfmul_vv_f32m2(v_sum, v_g_all, vl); + v_sumf_3 = __riscv_vfadd_vv_f32m2(v_sumf_3, v_sum, vl); + } + + } // End K-Block + + __riscv_vse32_v_f32m2(s + (row_tile + 0) * bs + col_tile, v_sumf_0, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 1) * bs + col_tile, v_sumf_1, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 2) * bs + col_tile, v_sumf_2, vl); + __riscv_vse32_v_f32m2(s + (row_tile + 3) * bs + col_tile, v_sumf_3, vl); + } + } +} +#endif diff --git a/ggml/src/ggml-cpu/arch/s390/quants.c b/ggml/src/ggml-cpu/arch/s390/quants.c index 19d225a4837..500857579a7 100644 --- a/ggml/src/ggml-cpu/arch/s390/quants.c +++ b/ggml/src/ggml-cpu/arch/s390/quants.c @@ -181,11 +181,11 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs); const int16x8_t v_xylso = vec_mulo(v_xls, v_yl); - const int16x8_t v_xylse = vec_mule(v_xls, v_yl); + const int16x8_t v_xyl = vec_meadd(v_xls, v_yl, v_xylso); const int16x8_t v_xyhso = vec_mulo(v_xhs, v_yh); - const int16x8_t v_xyhse = vec_mule(v_xhs, v_yh); + const int16x8_t v_xyh = vec_meadd(v_xhs, v_yh, v_xyhso); - int16x8_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); + int16x8_t v_xy_ = v_xyl + v_xyh; v_xy_ += vec_reve(v_xy_); const float32x4_t v_xy = vec_float(vec_unpackh(v_xy_)); const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); @@ -890,8 +890,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int16x8_t v_minsh = (int16x8_t)vec_unpackh((uint8x16_t)v_mins8); const int32x4_t v_minso = vec_mulo(v_ysums, v_minsh); - const int32x4_t v_minse = vec_mule(v_ysums, v_minsh); - const int32x4_t v_mins = v_minso + v_minse; + const int32x4_t v_mins = vec_meadd(v_ysums, v_minsh, v_minso); sumf -= dmin * (v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]); const uint8_t * scales = (const uint8_t *)utmp; @@ -1004,8 +1003,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int16x8_t v_minsh = (int16x8_t)vec_unpackh(v_mins8); const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh); - const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh); - const int32x4_t v_mins = vec_add(v_minsho, v_minshe); + const int32x4_t v_mins = vec_meadd(v_ysums, v_minsh, v_minsho); const int32_t mins = vec_hsum_i32x4(v_mins); const uint8_t * scales = (const uint8_t *)utmp; @@ -1110,10 +1108,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const int16x8_t v_scaleh = vec_unpackl(v_scale); const int32x4_t v_minslo = vec_mulo(v_ysumsl, v_scalel); - const int32x4_t v_minsle = vec_mule(v_ysumsl, v_scalel); + const int32x4_t v_minsl = vec_meadd(v_ysumsl, v_scalel, v_minslo); const int32x4_t v_minsho = vec_mulo(v_ysumsh, v_scaleh); - const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh); - const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe; + const int32x4_t v_minsh = vec_meadd(v_ysumsh, v_scaleh, v_minsho); + const int32x4_t v_mins = vec_add(v_minsl, v_minsh); const int32_t mins = vec_hsum_i32x4(v_mins); @@ -1465,4 +1463,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/wasm/quants.c b/ggml/src/ggml-cpu/arch/wasm/quants.c index 74a359e6d12..0a7119b4e1f 100644 --- a/ggml/src/ggml-cpu/arch/wasm/quants.c +++ b/ggml/src/ggml-cpu/arch/wasm/quants.c @@ -355,6 +355,78 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi *s = sumf; } +void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK8_1; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_1 * GGML_RESTRICT x = vx; + const block_q8_1 * GGML_RESTRICT y = vy; + + float sumf = 0; + +#if defined __wasm_simd128__ + v128_t sumv = wasm_f32x4_splat(0.0f); + float summs = 0.0f; + + for (int ib = 0; ib < nb; ++ib) { + const block_q4_1 * GGML_RESTRICT x0 = &x[ib]; + const block_q8_1 * GGML_RESTRICT y0 = &y[ib]; + + summs += GGML_CPU_FP16_TO_FP32(x0->m) * GGML_CPU_FP16_TO_FP32(y0->s); + + const v128_t raw = wasm_v128_load(x0->qs); + const v128_t v0s = wasm_v128_and(raw, wasm_i8x16_splat(0x0F)); + const v128_t v1s = wasm_u8x16_shr(raw, 4); + + const v128_t ys_lo = wasm_v128_load(y0->qs); + const v128_t ys_hi = wasm_v128_load(y0->qs + 16); + + const v128_t v0s_l = wasm_u16x8_extend_low_u8x16(v0s); + const v128_t v0s_h = wasm_u16x8_extend_high_u8x16(v0s); + const v128_t ylo_l = wasm_i16x8_extend_low_i8x16(ys_lo); + const v128_t ylo_h = wasm_i16x8_extend_high_i8x16(ys_lo); + const v128_t v1s_l = wasm_u16x8_extend_low_u8x16(v1s); + const v128_t v1s_h = wasm_u16x8_extend_high_u8x16(v1s); + const v128_t yhi_l = wasm_i16x8_extend_low_i8x16(ys_hi); + const v128_t yhi_h = wasm_i16x8_extend_high_i8x16(ys_hi); + + const v128_t acc = wasm_i32x4_add( + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(v0s_l, ylo_l), + wasm_i32x4_dot_i16x8(v0s_h, ylo_h)), + wasm_i32x4_add( + wasm_i32x4_dot_i16x8(v1s_l, yhi_l), + wasm_i32x4_dot_i16x8(v1s_h, yhi_h))); + + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul( + wasm_f32x4_convert_i32x4(acc), + wasm_f32x4_splat(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)))); + } + + sumf = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + + wasm_f32x4_extract_lane(sumv, 2) + wasm_f32x4_extract_lane(sumv, 3) + summs; + + *s = sumf; + +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + UNUSED(sumf); + + ggml_vec_dot_q4_1_q8_1_generic( + n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -1218,4 +1290,3 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi ggml_vec_dot_q6_K_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/x86/quants.c b/ggml/src/ggml-cpu/arch/x86/quants.c index cb49320a67f..94b19b82bbc 100644 --- a/ggml/src/ggml-cpu/arch/x86/quants.c +++ b/ggml/src/ggml-cpu/arch/x86/quants.c @@ -268,12 +268,24 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0))); } -static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) { - return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)), - _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0))); +static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const uint8_t x1, const float y1) { + return _mm256_set_m128(_mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)), + _mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0))); } #endif #elif defined(__SSSE3__) +static inline __m128i bytes_from_bits_16(const uint8_t * x) { + uint16_t x16; + memcpy(&x16, x, sizeof(uint16_t)); + + const __m128i shuf_mask = _mm_set_epi64x(0x0101010101010101, 0x0000000000000000); + __m128i bytes = _mm_shuffle_epi8(_mm_set1_epi16((short) x16), shuf_mask); + const __m128i bit_mask = _mm_set_epi64x(0x7fbfdfeff7fbfdfe, 0x7fbfdfeff7fbfdfe); + bytes = _mm_or_si128(bytes, bit_mask); + + return _mm_cmpeq_epi8(bytes, _mm_set1_epi64x(-1)); +} + // horizontally add 4x4 floats static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) { __m128 res_0 =_mm_hadd_ps(a, b); @@ -540,6 +552,152 @@ static inline __m128i get_scale_shuffle(int i) { } #endif +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + +#if defined(__AVX2__) + const __m256i ones_8 = _mm256_set1_epi8(1); + const __m256i ones_16 = _mm256_set1_epi16(1); + const __m256i byte_shuf = _mm256_setr_epi8( + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3); + const __m256i bit_masks = _mm256_setr_epi8( + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128, + 1, 2, 4, 8, 16, 32, 64, (char) -128, 1, 2, 4, 8, 16, 32, 64, (char) -128); + const __m256i zero = _mm256_setzero_si256(); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const uint32_t * GGML_RESTRICT qs32 = (const uint32_t *) x[ib].qs; + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + + __m256 acc_block; + { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[0].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[0]), byte_shuf), bit_masks), zero); + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), _mm256_cvtepi32_ps(s32)); + } + for (int K = 1; K < 4; ++K) { + const __m256i qy = _mm256_loadu_si256((const __m256i *) y_ptr[K].qs); + const __m256i sm = _mm256_cmpeq_epi8( + _mm256_and_si256(_mm256_shuffle_epi8(_mm256_set1_epi32((int) qs32[K]), byte_shuf), bit_masks), zero); + const __m256i sy = _mm256_sub_epi8(_mm256_xor_si256(qy, sm), sm); + const __m256i s32 = _mm256_madd_epi16(_mm256_maddubs_epi16(ones_8, sy), ones_16); + acc_block = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[K].d)), _mm256_cvtepi32_ps(s32), acc_block); + } + acc = _mm256_fmadd_ps(_mm256_set1_ps(d0), acc_block, acc); + } + + *s = hsum_float_8(acc); +#elif defined(__AVX__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + __m256 acc = _mm256_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[ib].d); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + __m256 acc_block; + { + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[0]); + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[0]); + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[0].qs[16]); + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); + acc_block = _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[0].d)), q); + } + for(int K = 1; K < 4; ++K) { + const __m256i bit_mask = bytes_from_bits_32(&x[ib].qs[(K) * 4]); + const __m128i bit_mask_0 = _mm256_castsi256_si128(bit_mask); + const __m128i bit_mask_1 = _mm256_extractf128_si256(bit_mask, 1); + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[0]); + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(K)].qs[16]); + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); + const __m128i sum16_0 = _mm_maddubs_epi16(ones_8, sy_0); + const __m128i sum16_1 = _mm_maddubs_epi16(ones_8, sy_1); + const __m128i sum32_0 = _mm_madd_epi16(sum16_0, ones_16); + const __m128i sum32_1 = _mm_madd_epi16(sum16_1, ones_16); + const __m256 q = _mm256_cvtepi32_ps(MM256_SET_M128I(sum32_1, sum32_0)); + acc_block = _mm256_add_ps(acc_block, _mm256_mul_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(K)].d)), q)); + } +#undef Q1_AVX_BLOCK + + acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_set1_ps(d0), acc_block)); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + const __m128i ones_8 = _mm_set1_epi8(1); + const __m128i ones_16 = _mm_set1_epi16(1); + const __m128i zero = _mm_setzero_si128(); + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + for (int ib = 0; ib < nb; ++ib) { + const __m128 d0 = _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x[ib].d)); + const block_q8_0 * GGML_RESTRICT y_ptr = &y[ib * 4]; + +#define Q1_SSSE3_BLOCK(QS_OFF, Y_IDX, ACC) \ + { \ + const __m128i bit_mask_0 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 0]); \ + const __m128i bit_mask_1 = bytes_from_bits_16(&x[ib].qs[(QS_OFF) + 2]); \ + const __m128i qy_0 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[0]); \ + const __m128i qy_1 = _mm_loadu_si128((const __m128i *) &y_ptr[(Y_IDX)].qs[16]); \ + const __m128i sign_mask_0 = _mm_cmpeq_epi8(bit_mask_0, zero); \ + const __m128i sign_mask_1 = _mm_cmpeq_epi8(bit_mask_1, zero); \ + const __m128i sy_0 = _mm_sub_epi8(_mm_xor_si128(qy_0, sign_mask_0), sign_mask_0); \ + const __m128i sy_1 = _mm_sub_epi8(_mm_xor_si128(qy_1, sign_mask_1), sign_mask_1); \ + const __m128i sum_0 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_0), ones_16); \ + const __m128i sum_1 = _mm_madd_epi16(_mm_maddubs_epi16(ones_8, sy_1), ones_16); \ + const __m128 q = _mm_cvtepi32_ps(_mm_add_epi32(sum_0, sum_1)); \ + (ACC) = _mm_add_ps((ACC), _mm_mul_ps(_mm_mul_ps(d0, _mm_set1_ps(GGML_CPU_FP16_TO_FP32(y_ptr[(Y_IDX)].d))), q)); \ + } + Q1_SSSE3_BLOCK(0, 0, acc_0) + Q1_SSSE3_BLOCK(4, 1, acc_1) + Q1_SSSE3_BLOCK(8, 2, acc_2) + Q1_SSSE3_BLOCK(12, 3, acc_3) +#undef Q1_SSSE3_BLOCK + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#else + UNUSED(nb); + UNUSED(x); + UNUSED(y); + ggml_vec_dot_q1_0_q8_0_generic(n, s, bs, vx, bx, vy, by, nrc); +#endif +} + void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -782,6 +940,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo __m256 accum1 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps(); + for (; ib + 1 < nb; ib += 2) { const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); @@ -795,10 +954,10 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); - accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)), - _mm256_cvtepi32_ps(p_1), accum1); - accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)), - _mm256_cvtepi32_ps(p_2), accum2); + const __m256 scale0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 0].e)); + const __m256 scale1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 1].e)); + accum1 = _mm256_fmadd_ps(scale0, _mm256_cvtepi32_ps(p_1), accum1); + accum2 = _mm256_fmadd_ps(scale1, _mm256_cvtepi32_ps(p_2), accum2); } sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); @@ -830,7 +989,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo #endif for (; ib < nb; ++ib) { - const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); + const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e); int sumi1 = 0; int sumi2 = 0; for (int j = 0; j < QK_MXFP4/2; ++j) { @@ -2141,9 +2300,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi #if defined __AVX2__ - const __m256i m4 = _mm256_set1_epi8(0xF); - const __m256i m2 = _mm256_set1_epi8(3); - const __m256i m32s = _mm256_set1_epi8(32); + const __m256i m3 = _mm256_set1_epi8(3); + const __m256i m15 = _mm256_set1_epi8(15); __m256 acc = _mm256_setzero_ps(); @@ -2155,53 +2313,45 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi const uint8_t * GGML_RESTRICT qh = x[i].qh; const int8_t * GGML_RESTRICT q8 = y[i].qs; + const __m256i q8sums = _mm256_loadu_si256((const __m256i*)y[i].bsums); const __m128i scales = _mm_loadu_si128((const __m128i*)x[i].scales); + const __m256i scales_16 = _mm256_cvtepi8_epi16(scales); + const __m256i q8sclsub = _mm256_slli_epi32(_mm256_madd_epi16(q8sums, scales_16), 5); __m256i sumi = _mm256_setzero_si256(); int is = 0; for (int j = 0; j < QK_K/128; ++j) { - - const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); - const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); - const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); - const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); - is += 4; - const __m256i q4bits1 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; const __m256i q4bits2 = _mm256_loadu_si256((const __m256i*)q4); q4 += 32; const __m256i q4bitsH = _mm256_loadu_si256((const __m256i*)qh); qh += 32; - const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m2), 4); - const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 2), m2), 4); - const __m256i q4h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 4), m2), 4); - const __m256i q4h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q4bitsH, 6), m2), 4); + const __m256i q4h_0 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, m3), 4); + const __m256i q4h_1 = _mm256_slli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(12)), 2); + const __m256i q4h_2 = _mm256_and_si256(q4bitsH, _mm256_set1_epi8(48)); + const __m256i q4h_3 = _mm256_srli_epi16(_mm256_and_si256(q4bitsH, _mm256_set1_epi8(-64)), 2); - const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m4), q4h_0); - const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m4), q4h_1); - const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m4), q4h_2); - const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m4), q4h_3); + const __m256i q4_0 = _mm256_or_si256(_mm256_and_si256(q4bits1, m15), q4h_0); + const __m256i q4_1 = _mm256_or_si256(_mm256_and_si256(q4bits2, m15), q4h_1); + const __m256i q4_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits1, 4), m15), q4h_2); + const __m256i q4_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q4bits2, 4), m15), q4h_3); const __m256i q8_0 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8_3 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; - __m256i q8s_0 = _mm256_maddubs_epi16(m32s, q8_0); - __m256i q8s_1 = _mm256_maddubs_epi16(m32s, q8_1); - __m256i q8s_2 = _mm256_maddubs_epi16(m32s, q8_2); - __m256i q8s_3 = _mm256_maddubs_epi16(m32s, q8_3); - __m256i p16_0 = _mm256_maddubs_epi16(q4_0, q8_0); __m256i p16_1 = _mm256_maddubs_epi16(q4_1, q8_1); __m256i p16_2 = _mm256_maddubs_epi16(q4_2, q8_2); __m256i p16_3 = _mm256_maddubs_epi16(q4_3, q8_3); - p16_0 = _mm256_sub_epi16(p16_0, q8s_0); - p16_1 = _mm256_sub_epi16(p16_1, q8s_1); - p16_2 = _mm256_sub_epi16(p16_2, q8s_2); - p16_3 = _mm256_sub_epi16(p16_3, q8s_3); + const __m128i scale_0 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 0)); + const __m128i scale_1 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 1)); + const __m128i scale_2 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 2)); + const __m128i scale_3 = _mm_shuffle_epi8(scales, get_scale_shuffle(is + 3)); + is += 4; p16_0 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_0), p16_0); p16_1 = _mm256_madd_epi16(_mm256_cvtepi8_epi16(scale_1), p16_1); @@ -2213,6 +2363,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi } + sumi = _mm256_sub_epi32(sumi, q8sclsub); acc = _mm256_fmadd_ps(_mm256_broadcast_ss(&d), _mm256_cvtepi32_ps(sumi), acc); } @@ -3817,4 +3968,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); #endif } - diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index 7dda9eea0c5..af1cebad131 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -423,7 +423,7 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR quants_interleaved[j] = i0; } - // Masks to shuffle the quants of corresonding sub blocks for rearraning quants for vectorized bsums computation + // Masks to shuffle the quants of corresponding sub blocks for rearranging quants for vectorized bsums computation __m256i shuffle_mask_sb2 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 0, 1, 4, 5, 6, 7, 8, 9, 8, 9, 12, 13, 14, 15)); shuffle_mask_sb2 = _mm256_permute2f128_si256(shuffle_mask_sb2, shuffle_mask_sb2, 0); __m256i shuffle_mask_sb3 = _mm256_castsi128_si256(_mm_setr_epi8(0, 1, 2, 3, 0, 1, 6, 7, 8, 9, 10, 11, 8, 9, 14, 15)); @@ -522,7 +522,8 @@ template<typename block_tx8> static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) { static_assert( std::is_same_v<block_tx8, block_q4_0x8> || - std::is_same_v<block_tx8, block_iq4_nlx8>, + std::is_same_v<block_tx8, block_iq4_nlx8> || + std::is_same_v<block_tx8, block_mxfp4x8>, "Unsupported block type"); const int qk = QK8_0; @@ -530,7 +531,6 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t UNUSED(bs); - __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); __m256i finalpermutemask = _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0); // Permute mask used for easier vector processing at later stages @@ -579,7 +579,20 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t if constexpr ( std::is_same_v<block_tx8, block_q4_0x8> || std::is_same_v<block_tx8, block_iq4_nlx8>) { + const __m128i changemask = _mm_set_epi8(15, 14, 7, 6, 13, 12, 5, 4, 11, 10, 3, 2, 9, 8, 1, 0); col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); + } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) { + // Load 8 E8M0 exponents and convert to float via LUT + // Rearranged to match changemask order: 0,4,1,5,2,6,3,7 + col_scale_f32 = _mm256_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0])); } // Load and convert to FP32 scale from block_q8_0 @@ -612,7 +625,7 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(rhs_vec_0123_3 ,_mm256_shuffle_epi32(rhs_vec_4567_3, 177), 170), _mm256_shuffle_epi32(lhs_vec_1, 170)); iacc = mul_sum_i8_pairs_acc_int32x8(iacc, _mm256_blend_epi32(_mm256_shuffle_epi32(rhs_vec_0123_3, 177) ,rhs_vec_4567_3, 170), _mm256_shuffle_epi32(lhs_vec_1, 255)); - // Accumulated values multipled with appropriate scales + // Accumulated values multiplied with appropriate scales acc_row = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc), _mm256_mul_ps(col_scale_f32, row_scale_f32), acc_row); } @@ -628,7 +641,8 @@ template<typename block_tx8> static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) { static_assert( std::is_same_v<block_tx8, block_q4_0x8> || - std::is_same_v<block_tx8, block_iq4_nlx8>, + std::is_same_v<block_tx8, block_iq4_nlx8> || + std::is_same_v<block_tx8, block_mxfp4x8>, "Unsupported block type"); const int qk = QK8_0; @@ -749,6 +763,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v<block_tx8, block_q4_0x8> || std::is_same_v<block_tx8, block_iq4_nlx8>) { col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) { + //TODO: simd-ify + col_scale_f32 = _mm512_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0])); } // Process LHS in pairs of rows @@ -835,7 +868,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptrs[rp][b].d), loadMask), 68); const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); - // Multiply with appropiate scales and accumulate + // Multiply with appropriate scales and accumulate acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); @@ -941,6 +974,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v<block_tx8, block_q4_0x8> || std::is_same_v<block_tx8, block_iq4_nlx8>) { col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) { + //TODO: simd-ify + col_scale_f32 = _mm512_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0])); } // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 @@ -1024,7 +1076,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t const __m128i row_scale_f16 = _mm_shuffle_epi32(_mm_maskload_epi32((int const*)(a_ptr[b].d), loadMask), 68); const __m512 row_scale_f32 = GGML_F32Cx16_REPEAT_LOAD(row_scale_f16); - // Multiply with appropiate scales and accumulate + // Multiply with appropriate scales and accumulate acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); @@ -1123,6 +1175,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v<block_tx8, block_q4_0x8> || std::is_same_v<block_tx8, block_iq4_nlx8>) { col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) { + col_scale_f32 = _mm256_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0])); } // Process LHS in groups of four @@ -1195,7 +1257,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); - // Multiply with appropiate scales and accumulate + // Multiply with appropriate scales and accumulate acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); @@ -1283,6 +1345,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v<block_tx8, block_q4_0x8> || std::is_same_v<block_tx8, block_iq4_nlx8>) { col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) { + col_scale_f32 = _mm256_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0])); } // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 @@ -1356,7 +1428,7 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t // Load the scale(d) values for all the 4 Q8_0 blocks and repeat it across lanes const __m256 row_scale_f32 = GGML_F32Cx8_REPEAT_LOAD(a_ptr[b].d, loadMask); - // Multiply with appropiate scales and accumulate + // Multiply with appropriate scales and accumulate acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); @@ -1540,7 +1612,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo lhs_vec_11 = _mm256_permute2f128_si256(lhs_vec_11, lhs_vec_11, 0); // Dot product done within 32 bit lanes and accumulated in the same vector - // First done for first sub block and thenn for second sub block in each sb + // First done for first sub block and then for second sub block in each sb // B0(0-3) B4(0-3) B1(0-3) B5(0-3) B2(0-3) B6(0-3) B3(0-3) B7(0-3) with A0(0-3) // B0(4-7) B4(4-7) B1(4-7) B5(4-7) B2(4-7) B6(4-7) B3(4-7) B7(4-7) with A0(4-7) // ........................................................................... @@ -1625,6 +1697,19 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) + __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + + gemv_q4_b32_8x8_q8_0_lut_avx<block_mxfp4x8>(n, s, bs, vx, vy, nr, nc, signextendlut); + + return; +#endif + + ggml_gemv_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -2337,7 +2422,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1); - // Multiply with appropiate scales and accumulate (for both d and dmin) below + // Multiply with appropriate scales and accumulate (for both d and dmin) below acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); @@ -2700,7 +2785,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1); - // Multiply with appropiate scales and accumulate (for both d and dmin) below + // Multiply with appropriate scales and accumulate (for both d and dmin) below acc_rows[0] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); acc_rows[1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); acc_rows[2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); @@ -2717,7 +2802,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]); } } - // Store accumlated values + // Store accumulated values for (int i = 0; i < 4; i++) { _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i])); } @@ -3045,7 +3130,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d); const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse);//GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); - // Multiply with appropiate scales and accumulate (for both d and dmin) below + // Multiply with appropriate scales and accumulate (for both d and dmin) below acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); @@ -3375,7 +3460,7 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d); const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); //GGML_F32Cx8_REPEAT_LOAD(a_ptrs[rp][b].d, loadMask); - // Multiply with appropiate scales and accumulate (for both d and dmin) below + // Multiply with appropriate scales and accumulate (for both d and dmin) below acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); @@ -3423,6 +3508,21 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) || defined(__AVX512F__) + { + __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + + gemm_q4_b32_8x8_q8_0_lut_avx<block_mxfp4x8>(n, s, bs, vx, vy, nr, nc, signextendlut); + + return; + } +#endif // defined(__AVX2__) || defined(__AVX512F__) + + ggml_gemm_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -4168,7 +4268,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m256 row_scale_f32_ymm = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); const __m512 row_scale_f32 = _mm512_insertf32x8(_mm512_castps256_ps512(row_scale_f32_ymm), row_scale_f32_ymm, 1); - // Multiply with appropiate scales and accumulate (for both d and dmin) below + // Multiply with appropriate scales and accumulate (for both d and dmin) below acc_rows[rp * 4] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_0), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); acc_rows[rp * 4 + 1] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_1), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); acc_rows[rp * 4 + 2] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_2), _mm512_mul_ps(col_scale_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); @@ -4935,7 +5035,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo acc_min_rows[3] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(iacc_row_min_3), _mm512_mul_ps(col_dmin_f32, _mm512_shuffle_ps(row_scale_f32, row_scale_f32, 255)), acc_min_rows[3]); } } - // Store accumlated values + // Store accumulated values for (int i = 0; i < 4; i++) { _mm512_storeu_ps((float * )(s + ((y * 4 + i) * bs + x * 8)), _mm512_sub_ps(acc_rows[i], acc_min_rows[i])); } @@ -5577,7 +5677,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m128 row_scale_f32_sse = _mm_load_ps(a_ptrs[rp][b].d); const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); - // Multiply with appropiate scales and accumulate (for both d and dmin) below + // Multiply with appropriate scales and accumulate (for both d and dmin) below acc_rows[rp * 4] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[rp * 4]); acc_rows[rp * 4 + 1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[rp * 4 + 1]); acc_rows[rp * 4 + 2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[rp * 4 + 2]); @@ -6249,7 +6349,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo const __m128 row_scale_f32_sse = _mm_load_ps(a_ptr[b].d); const __m256 row_scale_f32 = _mm256_set_m128(row_scale_f32_sse, row_scale_f32_sse); - // Multiply with appropiate scales and accumulate (for both d and dmin) below + // Multiply with appropriate scales and accumulate (for both d and dmin) below acc_rows[0] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_0), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 0)), acc_rows[0]); acc_rows[1] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_1), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 85)), acc_rows[1]); acc_rows[2] = _mm256_fmadd_ps(_mm256_cvtepi32_ps(iacc_row_2), _mm256_mul_ps(col_scale_f32, _mm256_shuffle_ps(row_scale_f32, row_scale_f32, 170)), acc_rows[2]); diff --git a/ggml/src/ggml-cpu/binary-ops.cpp b/ggml/src/ggml-cpu/binary-ops.cpp index 14f5b43ae0e..75e38290015 100644 --- a/ggml/src/ggml-cpu/binary-ops.cpp +++ b/ggml/src/ggml-cpu/binary-ops.cpp @@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds GGML_ASSERT(nb00 == sizeof(src0_t)); const auto [ir0, ir1] = get_thread_range(params, src0); - const bool is_src1_contiguous = (nb10 == sizeof(src1_t)); - - if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - } + const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1); #ifdef GGML_USE_ACCELERATE vDSP_fn_t vDSP_op = nullptr; @@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - if (is_src1_contiguous) { + if (is_src1_contiguous_rows) { // src1 is broadcastable across src0 and dst in i1, i2, i3 const int64_t nr0 = ne00 / ne10; diff --git a/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake new file mode 100644 index 00000000000..c8a4d4b4ec9 --- /dev/null +++ b/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake @@ -0,0 +1,32 @@ +include(CheckCSourceRuns) + +if (CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv)" AND GGML_CPU_RISCV64_SPACEMIT) + set(SMT_MARCH_STR "-march=rv64gcv_zfh_zvfh_zba_zicbop") + if (CMAKE_C_COMPILER_ID STREQUAL "GNU" AND + CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 15) + string(APPEND SMT_MARCH_STR "_xsmtvdotii") + endif() + set(CMAKE_REQUIRED_FLAGS "${SMT_MARCH_STR}") + + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot v2, v0, v1, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vfwmadot v2, v0, v1, fp16\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFWMADOT_FP16) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i4\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S4) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot.hp v2, v0, v1, v0, 0, i8\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VFMADOT_S8) + check_c_source_compiles("int main() {__asm__ volatile(\"vmadot1 v2, v0, v1\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOTN) + check_c_source_compiles("int main() {__asm__ volatile(\"vpack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK) + check_c_source_compiles("int main() {__asm__ volatile(\"vnspack.vv v2, v0, v1, 2\");}" SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + unset(CMAKE_REQUIRED_FLAGS) + + list(APPEND RISCV64_SPACEMIT_IME_SPEC "") + if (SPACEMIT_RISCV_COMPILER_SUPPORT_IME1) + set(RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME1") + endif() + + if (SPACEMIT_RISCV_COMPILER_SUPPORT_VMADOT_S4 AND SPACEMIT_RISCV_COMPILER_SUPPORT_VPACK AND SPACEMIT_RISCV_COMPILER_SUPPORT_VNPACK) + list(APPEND RISCV64_SPACEMIT_IME_SPEC "RISCV64_SPACEMIT_IME2") + endif() + + message("RISCV64_SPACEMIT_IME_SPEC: ${RISCV64_SPACEMIT_IME_SPEC}") +endif() diff --git a/ggml/src/ggml-cpu/common.h b/ggml/src/ggml-cpu/common.h index 6adca5437f8..abbadc359c5 100644 --- a/ggml/src/ggml-cpu/common.h +++ b/ggml/src/ggml-cpu/common.h @@ -6,6 +6,9 @@ #include "ggml-impl.h" #include "simd-mappings.h" +#define GGML_FA_TILE_Q 64 +#define GGML_FA_TILE_KV 64 + #ifdef __cplusplus #include <utility> @@ -84,4 +87,9 @@ static std::pair<int64_t, int64_t> get_thread_range(const struct ggml_compute_pa return {ir0, ir1}; } +struct ggml_fa_tile_config { + static constexpr size_t Q = GGML_FA_TILE_Q; + static constexpr size_t KV = GGML_FA_TILE_KV; +}; + #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ggml/src/ggml-cpu/ggml-cpu-impl.h index 0e8dd0ae053..5d1ca5ffcc3 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -24,6 +24,9 @@ struct ggml_compute_params { void * wdata; struct ggml_threadpool * threadpool; + + // use reference implementation + bool use_ref; }; @@ -303,6 +306,7 @@ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) { #if !defined(__ARM_FEATURE_DOTPROD) +// NOTE: this fallback produces the same total sum as native vdotq_s32 but with different per-lane grouping — do not use when individual lane values matter. inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) { const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b)); const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b)); @@ -316,6 +320,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) #endif // !defined(__ARM_FEATURE_DOTPROD) +static inline int32x4_t ggml_nvfp4_dot8(const int8x8_t q4_lo, const int8x8_t q8_lo, + const int8x8_t q4_hi, const int8x8_t q8_hi) { + const int16x8_t p_lo = vmull_s8(q4_lo, q8_lo); + const int16x8_t p_hi = vmull_s8(q4_hi, q8_hi); + const int32x4_t sum_lo = vpaddlq_s16(p_lo); + const int32x4_t sum_hi = vpaddlq_s16(p_hi); + return vaddq_s32(sum_lo, sum_hi); +} + #endif // defined(__ARM_NEON) #ifdef __wasm_simd128__ diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index f7ba1fe317d..eb8341c9aec 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -5,7 +5,6 @@ #include "ggml-backend.h" #include "traits.h" #include "ggml-cpu-impl.h" -#include "ggml-cpu.h" #include "ggml-impl.h" #include "quants.h" #include "ggml-threading.h" @@ -14,6 +13,7 @@ #include "vec.h" #include "ops.h" #include "ggml.h" +#include "common.h" #if defined(_MSC_VER) || defined(__MINGW32__) #include <malloc.h> // using malloc.h with MSC/MINGW @@ -50,6 +50,10 @@ #include "llamafile/sgemm.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ime.h" +#endif + // Note: once we move threading into a separate C++ file // will use std::hardware_destructive_interference_size instead of hardcoding it here // and we'll use C++ attribute syntax. @@ -75,6 +79,9 @@ // precomputed f32 table for f16 (256 KB) (simd-mappings.h) float ggml_table_f32_f16[1 << 16]; +// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h) +float ggml_table_f32_e8m0_half[1 << 8]; + #if defined(__ARM_ARCH) struct ggml_arm_arch_features_type { int sve_cnt; @@ -214,6 +221,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_F16, .nrows = 1, }, + [GGML_TYPE_Q1_0] = { + .from_float = quantize_row_q1_0, + .vec_dot = ggml_vec_dot_q1_0_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q4_0] = { .from_float = quantize_row_q4_0, .vec_dot = ggml_vec_dot_q4_0_q8_0, @@ -267,6 +280,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_0, .nrows = 1, }, + [GGML_TYPE_NVFP4] = { + .from_float = quantize_row_nvfp4, + .vec_dot = ggml_vec_dot_nvfp4_q8_0, + .vec_dot_type = GGML_TYPE_Q8_0, + .nrows = 1, + }, [GGML_TYPE_Q2_K] = { .from_float = quantize_row_q2_K, .vec_dot = ggml_vec_dot_q2_K_q8_K, @@ -1230,6 +1249,12 @@ void ggml_compute_forward_mul_mat( const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; + const int32_t hint = ggml_get_op_params_i32(dst, 1); + if (hint == GGML_HINT_SRC0_IS_HADAMARD && !params->use_ref) { + ggml_compute_forward_fwht(params, dst); + return; + } + GGML_TENSOR_BINARY_OP_LOCALS const int ith = params->ith; @@ -1887,6 +1912,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_im2col_3d(params, tensor); } break; + case GGML_OP_COL2IM_1D: + { + ggml_compute_forward_col2im_1d(params, tensor); + } break; case GGML_OP_CONV_2D: { ggml_compute_forward_conv_2d(params, tensor); @@ -2018,6 +2047,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_solve_tri(params, tensor); } break; + case GGML_OP_GATED_DELTA_NET: + { + ggml_compute_forward_gated_delta_net(params, tensor); + } break; case GGML_OP_MAP_CUSTOM1: { ggml_compute_forward_map_custom1(params, tensor); @@ -2197,6 +2230,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_COUNT_EQUAL: case GGML_OP_SOLVE_TRI: + case GGML_OP_GATED_DELTA_NET: { n_tasks = n_threads; } break; @@ -2313,6 +2347,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_CONV_2D: case GGML_OP_CONV_3D: case GGML_OP_CONV_2D_DW: + case GGML_OP_COL2IM_1D: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_2D: { @@ -2336,11 +2371,15 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: + { + n_tasks = n_threads; + } break; case GGML_OP_RWKV_WKV6: case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: { - n_tasks = n_threads; + const int64_t n_heads = node->src[1]->ne[1]; + n_tasks = MIN(n_threads, n_heads); } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: @@ -2474,7 +2513,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { if (prio != GGML_SCHED_PRIO_LOW) { // Tell Windows that this thread should not be throttled (needs its own CPU core). - // Newer Windows 11 versions aggresively park (offline) CPU cores and often place + // Newer Windows 11 versions aggressively park (offline) CPU cores and often place // all our threads onto the first 4 cores which results in terrible performance with // n_threads > 4 #if _WIN32_WINNT >= 0x0602 @@ -2857,8 +2896,12 @@ struct ggml_cplan ggml_graph_plan( const int64_t ne11 = node->src[1]->ne[1]; // H const int64_t ne12 = node->src[1]->ne[2]; // Channels In - cur += sizeof(ggml_fp16_t)*ne00*ne01*ne02*ne03; - cur += sizeof(ggml_fp16_t)*ne10*ne11*ne12; + GGML_ASSERT(node->src[0]->type == GGML_TYPE_F16 || node->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(node->src[1]->type == GGML_TYPE_F32); + + cur += ggml_type_size(node->src[0]->type) * ne00 * ne01 * ne02 * ne03; + cur += ggml_type_size(node->src[0]->type) * ne10 * ne11 * ne12; + } break; case GGML_OP_TOP_K: { @@ -2866,10 +2909,20 @@ struct ggml_cplan ggml_graph_plan( } break; case GGML_OP_FLASH_ATTN_EXT: { - const int64_t ne10 = node->src[1]->ne[0]; // DK - const int64_t ne20 = node->src[2]->ne[0]; // DV + const int64_t neq2 = node->src[0]->ne[2]; // number of query heads + const int64_t DK = node->src[1]->ne[0]; + const int64_t DV = node->src[2]->ne[0]; + + // Tiled flash attention scratch (tile sizes defined in common.h) + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding + size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks; - cur = sizeof(float)*(1*ne10 + 2*ne20)*n_tasks; // 1x head size K + 2x head size V (per thread) + // Decode path: n_kv_chunks = n_tasks (one chunk per thread) + // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ + size_t n_chunks = n_tasks; + size_t decode = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV)); + + cur += MAX(prefill, decode); } break; case GGML_OP_FLASH_ATTN_BACK: { @@ -2892,6 +2945,13 @@ struct ggml_cplan ggml_graph_plan( { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); } break; + case GGML_OP_GATED_DELTA_NET: + { + const int64_t S_v = node->src[2]->ne[0]; + const int64_t K = ggml_get_op_params_i32(node, 0); + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); + cur = per_thread * sizeof(float) * n_tasks; + } break; case GGML_OP_COUNT: { GGML_ABORT("fatal error"); @@ -2916,6 +2976,45 @@ struct ggml_cplan ggml_graph_plan( return cplan; } + +// Try to fuse the current node with subsequent nodes for better performance. +// Returns the number of nodes skipped by fusion (>=1), or 0 if no fusion was applied. +static bool ggml_cpu_disable_fusion = false; // initialized once in ggml_cpu_init(), read-only afterwards + +static int ggml_cpu_try_fuse_ops( + const struct ggml_cgraph * cgraph, + const int node_n, + const struct ggml_compute_params * params, + const struct ggml_cplan * cplan) { + + if (ggml_cpu_disable_fusion || cplan->use_ref) { + return 0; + } + + struct ggml_tensor * node = cgraph->nodes[node_n]; + + if (node->op == GGML_OP_RMS_NORM) { + // RMS_NORM + MUL fusion + const enum ggml_op fuse_ops[] = { GGML_OP_RMS_NORM, GGML_OP_MUL }; + if (ggml_can_fuse(cgraph, node_n, fuse_ops, 2)) { + struct ggml_tensor * mul_node = cgraph->nodes[node_n + 1]; + const struct ggml_tensor * mul_w = (mul_node->src[0] == node) + ? mul_node->src[1] : mul_node->src[0]; + if (node->src[0]->type == GGML_TYPE_F32 && + mul_node->type == GGML_TYPE_F32 && + mul_w->type == GGML_TYPE_F32 && + mul_w->ne[0] == node->ne[0] && + mul_w->nb[0] == sizeof(float)) { + + ggml_compute_forward_rms_norm_mul_fused(params, node, mul_node); + return 1; + } + } + } + + return 0; +} + static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_compute_state * state = (struct ggml_compute_state *) data; struct ggml_threadpool * tp = state->threadpool; @@ -2923,17 +3022,26 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { const struct ggml_cgraph * cgraph = tp->cgraph; const struct ggml_cplan * cplan = tp->cplan; +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(state->ith); +#else set_numa_thread_affinity(state->ith); +#endif struct ggml_compute_params params = { - /*.ith =*/ state->ith, - /*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK, - /*.wsize =*/ cplan->work_size, - /*.wdata =*/ cplan->work_data, - /*.threadpool=*/ tp, + /*.ith =*/ state->ith, + /*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK, + /*.wsize =*/ cplan->work_size, + /*.wdata =*/ cplan->work_data, + /*.threadpool =*/ tp, + /*.use_ref =*/ cplan->use_ref, }; - GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); +#ifdef GGML_USE_OPENMP + GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan); +#else + GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph); +#endif for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) { struct ggml_tensor * node = cgraph->nodes[node_n]; @@ -2943,7 +3051,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { continue; } - ggml_compute_forward(¶ms, node); + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + + // TODO: move fused-op detection into ggml_graph_plan so fusion decisions are made once at planning time + // Try fused ops, fall back to normal compute + const int n_fused = ggml_cpu_try_fuse_ops(cgraph, node_n, ¶ms, cplan); + if (n_fused > 0) { + node_n += n_fused; + } else { + ggml_compute_forward(¶ms, node); + } if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { @@ -2956,10 +3075,18 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } } - GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); +#ifdef GGML_USE_OPENMP + GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan); +#else + GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph); +#endif ggml_barrier(state->threadpool); +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(state->ith); +#endif + return 0; } @@ -3666,6 +3793,11 @@ void ggml_cpu_init(void) { ggml_table_gelu_quick_f16[i] = GGML_CPU_FP32_TO_FP16(ggml_gelu_quick_f32(f)); } + // initialize E8M0 half table (256 entries) + for (int i = 0; i < (1 << 8); ++i) { + ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i); + } + const uint64_t t_end = ggml_time_us(); UNUSED(t_end); GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0); @@ -3696,6 +3828,11 @@ void ggml_cpu_init(void) { ggml_init_riscv_arch_features(); #endif + { + const char * env = getenv("GGML_CPU_DISABLE_FUSION"); + ggml_cpu_disable_fusion = (env != NULL && atoi(env) == 1); + } + is_first_call = false; } diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index f4713a42185..128883b41ce 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -105,6 +105,8 @@ struct ggml_backend_cpu_context { ggml_abort_callback abort_callback; void * abort_callback_data; + + bool use_ref; // use reference implementation }; static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) { @@ -143,6 +145,7 @@ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; + cpu_plan->cplan.use_ref = cpu_ctx->use_ref; return cpu_plan; } @@ -182,6 +185,7 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s cplan.abort_callback = cpu_ctx->abort_callback; cplan.abort_callback_data = cpu_ctx->abort_callback_data; + cplan.use_ref = cpu_ctx->use_ref; return ggml_graph_compute(cgraph, &cplan); } @@ -191,6 +195,8 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .free = */ ggml_backend_cpu_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create, @@ -223,6 +229,7 @@ ggml_backend_t ggml_backend_cpu_init(void) { ctx->work_size = 0; ctx->abort_callback = NULL; ctx->abort_callback_data = NULL; + ctx->use_ref = false; ggml_backend_t cpu_backend = new ggml_backend { /* .guid = */ ggml_backend_cpu_guid(), @@ -270,6 +277,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_ ctx->abort_callback_data = abort_callback_data; } +void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref) { + GGML_ASSERT(ggml_backend_is_cpu(backend_cpu)); + + struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context; + ctx->use_ref = use_ref; +} + // CPU backend - device struct ggml_backend_cpu_device_context { @@ -646,6 +660,9 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) { return (void *)ggml_is_numa; } + if (strcmp(name, "ggml_backend_cpu_set_use_ref") == 0) { + return (void *)ggml_backend_cpu_set_use_ref; + } // threadpool - TODO: move to ggml-base if (strcmp(name, "ggml_threadpool_new") == 0) { diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index d114f2d49bf..8c4d7bc925f 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> +// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> // SPDX-License-Identifier: MIT // @@ -9,7 +9,6 @@ #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" -#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" @@ -20,6 +19,7 @@ #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h" #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h" +#include "kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.h" #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" #include "kai_lhs_quant_pack_qsi8d32p_f32.h" @@ -31,6 +31,7 @@ #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" #include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h" +#include "kai_lhs_pack_f16pmrx2_f32_neon.h" #include "kai_common.h" @@ -309,24 +310,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { { /* SME GEMM */ /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>, - /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>, - /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>, + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>, + /* .run_kernel_ex = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa>, }, /* .gemm_lhs_info = */ { - /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>, - /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>, - /* .pack_func_ex = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>, + /* .get_offset = */ kai_get_lhs_offset_lhs_pack_f16pmrx2_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_pack_f16pmrx2_f32_neon>, + /* .packed_size_ex = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_pack_f16pmrx2_f32_neon>, + /* .pack_func_ex = */ &lhs_pack_void_fn10<kai_run_lhs_pack_f16pmrx2_f32_neon>, }, /* SME GEMV */ /* .kern_info = */ { @@ -519,7 +520,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, /* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .required_cpu = */ CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q4_0, /* .op_type = */ GGML_TYPE_F32, @@ -630,7 +631,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, /* .pack_func_ex = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .required_cpu = */ CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q4_0, /* .op_type = */ GGML_TYPE_F32, @@ -800,7 +801,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = { /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>, /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>, }, - /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, + /* .required_cpu = */ CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q8_0, /* .op_type = */ GGML_TYPE_F32, diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index ad23e73184e..9e54b676b93 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -1,20 +1,31 @@ -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com> +// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates <open-source-office@arm.com> // SPDX-License-Identifier: MIT // #include <arm_neon.h> #include <assert.h> +#include <stdio.h> #include <atomic> #include <cfloat> -#include <cmath> #include <algorithm> +#include <cmath> #include <stdexcept> #include <stdint.h> #include <string.h> #include <string> #include <vector> +#include <array> +#include <cstddef> +#include <cstdint> +#include <fstream> +#include <set> +#include <iostream> +#include <climits> #if defined(__linux__) #include <asm/hwcap.h> #include <sys/auxv.h> +#include <sys/types.h> +#include <sys/stat.h> +#include <unistd.h> #elif defined(__APPLE__) #include <string_view> #include <sys/sysctl.h> @@ -27,6 +38,7 @@ #include "kleidiai.h" #include "ggml-cpu.h" +#include "ggml-cpu-impl.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-threading.h" @@ -39,11 +51,19 @@ #define GGML_COMMON_DECL_CPP #include "ggml-common.h" +static constexpr int GGML_KLEIDIAI_MAX_KERNEL_SLOTS = 2; +static constexpr uint32_t GGML_KLEIDIAI_PACK_MAGIC = 0x4b4c4149; // "KLAI" +static constexpr uint16_t GGML_KLEIDIAI_PACK_VERSION = 1; +static constexpr size_t GGML_KLEIDIAI_PACK_ALIGN = 64; + struct ggml_kleidiai_context { cpu_feature features; ggml_kleidiai_kernels * kernels_q4; ggml_kleidiai_kernels * kernels_q8; -} static ctx = { CPU_FEATURE_NONE, NULL, NULL }; + int sme_thread_cap; // <= 0 means “SME disabled/unknown”; + int thread_hint; // <= 0 means “no hint” + int chunk_multiplier; +} static ctx = { CPU_FEATURE_NONE, nullptr, nullptr, 0, -1, 4 }; static const char* cpu_feature_to_string(cpu_feature f) { if (f == CPU_FEATURE_NONE) { @@ -63,41 +83,388 @@ static const char* cpu_feature_to_string(cpu_feature f) { } } -static void init_kleidiai_context(void) { +static size_t detect_num_smcus() { + if (!ggml_cpu_has_sme()) { + return 0; + } + +#if defined(__linux__) && defined(__aarch64__) + // Linux/aarch64: Best-effort count of Streaming Mode Compute Units (SMCUs) via SMIDR_EL1 sysfs. + size_t num_private = 0; + std::set<uint32_t> shared_ids; + + for (size_t cpu = 0;; ++cpu) { + const std::string path = + "/sys/devices/system/cpu/cpu" + std::to_string(cpu) + + "/regs/identification/smidr_el1"; + + std::ifstream file(path); + if (!file.is_open()) { + break; + } + + uint64_t smidr = 0; + if (!(file >> std::hex >> smidr)) { + continue; + } + + // Arm ARM: SMIDR_EL1 + const uint32_t sh = (uint32_t)((smidr >> 13) & 0x3); + // Build an "affinity-like" identifier for shared SMCUs. + // Keep the original packing logic, but isolate it here. + const uint32_t id = (uint32_t)((smidr & 0xFFFu) | ((smidr >> 20) & 0xFFFFF000u)); + + switch (sh) { + case 0b10: // private SMCU + ++num_private; + break; + case 0b11: // shared SMCU + shared_ids.emplace(id); + break; + case 0b00: + // Ambiguous / implementation-defined. Be conservative: + // treat id==0 as private, otherwise as shared. + if (id == 0) ++num_private; + else shared_ids.emplace(id); + break; + default: + break; + } + } + + return num_private + shared_ids.size(); + +#elif defined(__APPLE__) && defined(__aarch64__) + // table for known M4 variants. Users can override via GGML_KLEIDIAI_SME=<n>. + char chip_name[256] = {}; + size_t size = sizeof(chip_name); + + if (sysctlbyname("machdep.cpu.brand_string", chip_name, &size, nullptr, 0) == 0) { + const std::string brand(chip_name); + + struct ModelSMCU { const char *match; size_t smcus; }; + static const ModelSMCU table[] = { + { "M4 Ultra", 2 }, + { "M4 Max", 2 }, + { "M4 Pro", 2 }, + { "M4", 1 }, + }; + + for (const auto &e : table) { + if (brand.find(e.match) != std::string::npos) { + return e.smcus; + } + } + } + return 1; + +#else + return 1; +#endif +} +static int parse_uint_env(const char *s, const char *name, bool *ok) { + if (!s) { *ok = false; return 0; } + char *end = nullptr; + long v = strtol(s, &end, 10); + if (end == s || *end != '\0') { + GGML_LOG_WARN("kleidiai: invalid %s='%s' (expected integer)\n", name, s); + *ok = false; + return 0; + } + if (v < 0 || v > INT_MAX) { + GGML_LOG_WARN("kleidiai: out-of-range %s='%s'\n", name, s); + *ok = false; + return 0; + } + *ok = true; + return (int)v; +} + +static void init_kleidiai_context(void) { ggml_critical_section_start(); static bool initialized = false; if (!initialized) { initialized = true; - const char *env_var = getenv("GGML_KLEIDIAI_SME"); - int sme_enabled = 0; + + const char *env_sme = getenv("GGML_KLEIDIAI_SME"); + const char *env_threads = getenv("GGML_TOTAL_THREADS"); + const char *env_chunk_mult = getenv("GGML_KLEIDIAI_CHUNK_MULTIPLIER"); + + const bool cpu_has_sme = ggml_cpu_has_sme(); + size_t detected_smcus = 0; ctx.features = (ggml_cpu_has_dotprod() ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) | (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM : CPU_FEATURE_NONE) | ((ggml_cpu_has_sve() && ggml_cpu_get_sve_cnt() == QK8_0) ? CPU_FEATURE_SVE : CPU_FEATURE_NONE); - if (env_var) { - sme_enabled = atoi(env_var); + if (env_threads) { + bool ok = false; + int hint = parse_uint_env(env_threads, "GGML_TOTAL_THREADS", &ok); + if (ok && hint > 0) { + ctx.thread_hint = hint; + } } - if (sme_enabled != 0) { - ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE; + if (env_chunk_mult) { + bool ok = false; + int multiplier = parse_uint_env(env_chunk_mult, "GGML_KLEIDIAI_CHUNK_MULTIPLIER", &ok); + if (ok && multiplier > 0) { + ctx.chunk_multiplier = multiplier; + } } + + // SME policy: + // - If CPU doesn't support SME: SME always off. + // - Else: + // - env unset => auto-detect cores; enable if detected > 0. + // - env=0 => force off. + // - env>0 => force N cores (skip detection). + int sme_cores = 0; + bool sme_env_ok = false; + bool sme_env_set = (env_sme != nullptr); + + if (!cpu_has_sme) { + if (sme_env_set) { + bool ok = false; + int req = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok); + if (ok && req > 0) { + GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME=%d but SME is not supported on this CPU; disabling SME\n", req); + } + } + sme_cores = 0; + } else { + if (sme_env_set) { + bool ok = false; + int v = parse_uint_env(env_sme, "GGML_KLEIDIAI_SME", &ok); + sme_env_ok = ok; + + if (!ok) { + GGML_LOG_WARN("kleidiai: GGML_KLEIDIAI_SME set but parsing failed; falling back to runtime SME-core detection\n"); + detected_smcus = detect_num_smcus(); + sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0; + } else if (v == 0) { + sme_cores = 0; + } else { + sme_cores = v; + } + } else { + detected_smcus = detect_num_smcus(); + sme_cores = detected_smcus > 0 ? (int)detected_smcus : 0; + } + + if (!sme_env_set && sme_cores == 0) { + GGML_LOG_WARN("kleidiai: SME supported but runtime SME-core detection returned 0; falling back to NEON\n"); + } + + if (sme_cores > 0) { + ctx.features |= CPU_FEATURE_SME; + } + } + + // Kernel selection ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features); ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features); -#ifndef NDEBUG - if (ctx.kernels_q4) { - GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu)); + + if (!ctx.kernels_q4) { + GGML_LOG_INFO("kleidiai: no compatible q4 kernels found for CPU features mask %d\n", (int)ctx.features); + } else { + GGML_LOG_INFO("kleidiai: primary q4 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu)); } - if (ctx.kernels_q8) { - GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu)); + + if (!ctx.kernels_q8) { + GGML_LOG_INFO("kleidiai: no compatible q8 kernels found for CPU features mask %d\n", (int)ctx.features); + } else { + GGML_LOG_INFO("kleidiai: primary q8 kernel feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu)); + } + + ctx.sme_thread_cap = (ctx.features & CPU_FEATURE_SME) ? sme_cores : 0; + + if (ctx.features & CPU_FEATURE_SME) { + if (sme_env_set && sme_env_ok && sme_cores > 0) { + GGML_LOG_INFO("kleidiai: SME enabled (GGML_KLEIDIAI_SME=%d override)\n", sme_cores); + } else { + GGML_LOG_INFO("kleidiai: SME enabled (runtime-detected SME cores=%d)\n", sme_cores); + } + } else { + GGML_LOG_INFO("kleidiai: SME disabled\n"); } -#endif } + ggml_critical_section_end(); } +static inline int kleidiai_sme_thread_cap() { + return ctx.sme_thread_cap; +} + +static inline size_t align_up(size_t value, size_t alignment) { + if (alignment == 0) { + return value; + } + const size_t remainder = value % alignment; + return remainder == 0 ? value : value + (alignment - remainder); +} + +static inline size_t gcd_size(size_t a, size_t b) { + while (b != 0) { + const size_t t = a % b; + a = b; + b = t; + } + return a; +} + +static inline bool lcm_size(size_t a, size_t b, size_t & result) { + if (a == 0 || b == 0) { + result = 0; + return false; + } + const size_t g = gcd_size(a, b); + const size_t q = a / g; + if (q > SIZE_MAX / b) { + return false; + } + result = q * b; + return true; +} + +static inline size_t ceil_div_size(size_t a, size_t b) { + return b == 0 ? 0 : (a + b - 1) / b; +} + +struct kleidiai_block_args { + size_t lhs_bl; + size_t rhs_bl; + size_t pack_bl; +}; + +static inline kleidiai_block_args kleidiai_get_block_args(ggml_type rhs_type) { + switch (rhs_type) { + case GGML_TYPE_Q4_0: + return { QK4_0, QK4_0, QK4_0 }; + case GGML_TYPE_Q8_0: + return { 0, 0, QK8_0 }; + default: + return { 0, 0, 0 }; + } +} + +static inline bool kleidiai_pack_fallback_allowed() { + if (ctx.sme_thread_cap <= 0) { + return false; + } + if (ctx.thread_hint <= 0) { + return true; + } + return ctx.thread_hint > ctx.sme_thread_cap; +} + +struct kleidiai_weight_header { + uint32_t magic; + uint16_t version; + uint16_t slot_count; + uint64_t offsets[GGML_KLEIDIAI_MAX_KERNEL_SLOTS]; + uint64_t sizes[GGML_KLEIDIAI_MAX_KERNEL_SLOTS]; +}; + +static inline kleidiai_weight_header * kleidiai_weight_header_from_ptr(void * data) { + return reinterpret_cast<kleidiai_weight_header *>(data); +} + +static inline const kleidiai_weight_header * kleidiai_weight_header_from_ptr(const void * data) { + return reinterpret_cast<const kleidiai_weight_header *>(data); +} + +static inline bool kleidiai_is_weight_header_valid(const kleidiai_weight_header * header) { + if (!header) { + return false; + } + if (header->magic != GGML_KLEIDIAI_PACK_MAGIC || header->version != GGML_KLEIDIAI_PACK_VERSION) { + return false; + } + if (header->slot_count == 0 || header->slot_count > GGML_KLEIDIAI_MAX_KERNEL_SLOTS) { + return false; + } + return true; +} + +static inline uint8_t * kleidiai_weight_slot_ptr(kleidiai_weight_header * header, int slot) { + if (!kleidiai_is_weight_header_valid(header)) { + return nullptr; + } + if (slot < 0 || slot >= header->slot_count) { + return nullptr; + } + return reinterpret_cast<uint8_t *>(header) + header->offsets[slot]; +} + +static inline const uint8_t * kleidiai_weight_slot_ptr(const kleidiai_weight_header * header, int slot) { + if (!kleidiai_is_weight_header_valid(header)) { + return nullptr; + } + if (slot < 0 || slot >= header->slot_count) { + return nullptr; + } + return reinterpret_cast<const uint8_t *>(header) + header->offsets[slot]; +} + +static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q4() { + return ctx.kernels_q4; +} + +static inline ggml_kleidiai_kernels * kleidiai_primary_kernel_q8() { + return ctx.kernels_q8; +} + +template <typename SelectFallback> +static int kleidiai_collect_kernel_chain_common( + ggml_kleidiai_kernels * primary, + cpu_feature features, + std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out, + SelectFallback select_fallback) { + int count = 0; + if (!primary) { + return 0; + } + out[count++] = primary; + + if ((primary->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) { + const cpu_feature fallback_mask = static_cast<cpu_feature>(features & ~CPU_FEATURE_SME); + if (fallback_mask != CPU_FEATURE_NONE) { + ggml_kleidiai_kernels * fallback = select_fallback(fallback_mask); + if (fallback && fallback != primary && + fallback->lhs_type == primary->lhs_type && + fallback->rhs_type == primary->rhs_type && + fallback->op_type == primary->op_type) { + out[count++] = fallback; + } + } + } + + return count; +} + +static int kleidiai_collect_kernel_chain(const struct ggml_tensor * op, + std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) { + ggml_kleidiai_kernels * primary = ggml_kleidiai_select_kernels(ctx.features, op); + return kleidiai_collect_kernel_chain_common(primary, ctx.features, out, + [&](cpu_feature mask) { return ggml_kleidiai_select_kernels(mask, op); }); +} + +static int kleidiai_collect_q4_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) { + ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q4(); + return kleidiai_collect_kernel_chain_common(primary, ctx.features, out, + [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q4_0(mask); }); +} + +static int kleidiai_collect_q8_chain(std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> & out) { + ggml_kleidiai_kernels * primary = kleidiai_primary_kernel_q8(); + return kleidiai_collect_kernel_chain_common(primary, ctx.features, out, + [&](cpu_feature mask) { return ggml_kleidiai_select_kernels_q8_0(mask); }); +} + static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { GGML_ASSERT(dim >= 0 && dim < GGML_MAX_DIMS); return tensor->ne[dim]; @@ -126,49 +493,108 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (op->op != GGML_OP_MUL_MAT) { return false; } - ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); - if (!kernels) { + + std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain; + const int slot_count = kleidiai_collect_kernel_chain(op, kernel_chain); + if (slot_count == 0) { return false; } - bool is_gemv = op->src[1]->ne[1] == 1; - kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; - lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; - size_t k = op->src[0]->ne[0]; - size_t n = op->src[0]->ne[1]; - size_t m = op->src[1]->ne[1]; - - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); - - if (kernels->rhs_type == GGML_TYPE_Q4_0) { - if (!lhs_info->packed_size_ex) return false; - size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr); - } else if (kernels->rhs_type == GGML_TYPE_Q8_0) { - if (!lhs_info->packed_size_ex) return false; - size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr); - } else if (kernels->rhs_type == GGML_TYPE_F16) { - if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false; + const bool is_gemv = op->src[1]->ne[1] == 1; + const size_t k = op->src[0]->ne[0]; + const size_t n = op->src[0]->ne[1]; + const size_t m = op->src[1]->ne[1]; + + if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) { + const size_t qk = (op->src[0]->type == GGML_TYPE_Q4_0) ? QK4_0 : QK8_0; + + size_t cursor = 0; + bool any_slot = false; + + for (int slot = 0; slot < slot_count; ++slot) { + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; + kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; + + if (!lhs_info || !lhs_info->packed_size_ex || !kernel) { + return false; + } + + const size_t mr = kernel->get_mr(); + const size_t kr = kernel->get_kr(); + const size_t sr = kernel->get_sr(); + + const size_t packed = lhs_info->packed_size_ex(m, k, qk, mr, kr, sr); + + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += packed; + any_slot = true; + } + + if (!any_slot) { + return false; + } + + size = cursor; + return true; + } + + if (op->src[0]->type == GGML_TYPE_F16) { const int64_t lhs_batch_size0 = op->src[1]->ne[2]; const int64_t rhs_batch_size0 = op->src[0]->ne[2]; + GGML_ASSERT(rhs_batch_size0 > 0); const int64_t r = lhs_batch_size0 / rhs_batch_size0; - size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) + - kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) + - k * n * sizeof(float) + n * sizeof(float); - } else { - return false; + + size_t cursor = 0; + bool any_slot = false; + + for (int slot = 0; slot < slot_count; ++slot) { + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; + kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; + if (!lhs_info || !lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex || !kernel) { + return false; + } + + const size_t mr = kernel->get_mr(); + const size_t kr = kernel->get_kr(); + const size_t sr = kernel->get_sr(); + + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr); + any_slot = true; + } + + for (int slot = 0; slot < slot_count; ++slot) { + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; + if (!kernel || !kernels->rhs_info.packed_size_ex) { + return false; + } + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0); + } + + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += k * n * sizeof(float); + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += n * sizeof(float); + + if (!any_slot) { + return false; + } + + size = cursor; + return true; } - return true; + return false; } bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { if (dst->op == GGML_OP_MUL_MAT) { - if (dst->src[0]->type == GGML_TYPE_Q4_0) { - return compute_forward_q4_0(params, dst); - } else if (dst->src[0]->type == GGML_TYPE_Q8_0) { - return compute_forward_q8_0(params, dst); + if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) { + return compute_forward_qx(params, dst); } else if (dst->src[0]->type == GGML_TYPE_F16) { return compute_forward_fp16(params, dst); } @@ -331,204 +757,412 @@ class tensor_traits : public ggml::cpu::tensor_traits { return true; } - bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0); + bool compute_forward_qx(struct ggml_compute_params * params, struct ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0); const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; GGML_TENSOR_BINARY_OP_LOCALS - ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - if (!kernels) { - return false; - } + const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data); + const bool has_header = kleidiai_is_weight_header_valid(header); + const bool is_gemv = src1->ne[1] == 1; + std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain; + const int slot_total = kleidiai_collect_kernel_chain(dst, kernel_chain); - bool is_gemv = src1->ne[1] == 1; - kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; - lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; + auto weight_for_slot = [&](int slot_index, size_t & size_out) -> const uint8_t * { + if (slot_index < 0 || slot_index >= slot_total) { + return nullptr; + } + if (has_header) { + if (slot_index < header->slot_count) { + size_out = static_cast<size_t>(header->sizes[slot_index]); + return kleidiai_weight_slot_ptr(header, slot_index); + } + return nullptr; + } + if (slot_index == 0) { + size_out = ggml_nbytes(src0); + return static_cast<const uint8_t *>(src0->data); + } + return nullptr; + }; + + struct runtime_slot { + int slot_index; + ggml_kleidiai_kernels * kernels; + kernel_info * kernel; + lhs_packing_info * lhs_info; + size_t mr; + size_t nr; + size_t kr; + size_t sr; + size_t n_step; + size_t lhs_packed_size; + size_t lhs_offset; + size_t lhs_bl; + size_t rhs_bl; + size_t pack_bl; + size_t lhs_packed_offset0; + int assigned_threads; + int thread_begin; + int thread_end; + const uint8_t * rhs_base; + }; + + std::array<runtime_slot, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> runtime{}; + int runtime_count = 0; + + for (int slot = 0; slot < slot_total && runtime_count < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) { + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + kernel_info * kinfo = is_gemv ? &kernels->gemv : &kernels->gemm; + lhs_packing_info * linfo = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; + if (!kinfo || !linfo || !linfo->packed_size_ex || !linfo->pack_func_ex || !linfo->get_offset || + !kinfo->get_rhs_packed_offset_ex || !kinfo->run_kernel_ex || !kinfo->get_dst_offset) { + continue; + } - GGML_ASSERT(kernel); - if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex || - !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) { + size_t rhs_size = 0; + const uint8_t * rhs_ptr = weight_for_slot(slot, rhs_size); + if (!rhs_ptr || rhs_size == 0) { + continue; + } + + const kleidiai_block_args block_args = kleidiai_get_block_args(kernels->rhs_type); + + runtime[runtime_count] = { + slot, + kernels, + kinfo, + linfo, + kinfo->get_mr(), + kinfo->get_nr(), + kinfo->get_kr(), + kinfo->get_sr(), + kinfo->get_n_step(), + 0, + 0, + block_args.lhs_bl, + block_args.rhs_bl, + block_args.pack_bl, + 0, + 0, + 0, + 0, + rhs_ptr + }; + ++runtime_count; + } + + if (runtime_count == 0) { + GGML_LOG_WARN("kleidiai: no runtime kernel slot available for supported op %s\n", dst->name); return false; } - const int ith = params->ith; - const int nth_raw = params->nth; - const int nth = nth_raw > 0 ? nth_raw : 1; + const int nth_total = params->nth > 0 ? params->nth : 1; + const int ith_total = params->ith; - const size_t k = ne00; - const size_t m = ne11; - const size_t n = ne01; + int sme_slot = -1; + for (int i = 0; i < runtime_count; ++i) { + if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) { + sme_slot = i; + break; + } + } + int non_sme_slot = -1; + for (int i = 0; i < runtime_count; ++i) { + if ((runtime[i].kernels->required_cpu & CPU_FEATURE_SME) != CPU_FEATURE_SME) { + non_sme_slot = i; + break; + } + } - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); + const int sme_cap_limit = ctx.sme_thread_cap; + const bool use_hybrid = sme_cap_limit > 0 && + runtime_count > 1 && + nth_total > sme_cap_limit; + // Heuristic: disable hybrid for very small workloads where per-slot overhead dominates. + // If rows are small or average columns per thread are small, keep single-slot. + size_t min_cols_per_thread = 0; + if (runtime_count > 0 && nth_total > 0) { + min_cols_per_thread = (size_t) std::max<int64_t>(1, (int64_t)ne01 / (int64_t)nth_total); + } + const bool too_small_for_hybrid = (min_cols_per_thread < 2) || (ne11 < 128); - const uint8_t * lhs = static_cast<const uint8_t *>(src1->data); - uint8_t * lhs_packed = (uint8_t*)params->wdata; - const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data); + const bool hybrid_enabled = use_hybrid && !too_small_for_hybrid; - const size_t n_step = kernel->get_n_step(); - const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); - const size_t n_start = ith * num_n_per_thread; + if (!hybrid_enabled) { + int chosen_slot = 0; + if (too_small_for_hybrid && sme_slot != -1) { + chosen_slot = nth_total > sme_cap_limit && non_sme_slot != -1 ? non_sme_slot : sme_slot; + } else if (runtime_count > 1 && ctx.sme_thread_cap > 0 && nth_total > ctx.sme_thread_cap) { + chosen_slot = 1; + } + if (chosen_slot != 0 && chosen_slot < runtime_count) { + runtime[0] = runtime[chosen_slot]; + runtime[0].assigned_threads = 0; + runtime[0].thread_begin = 0; + runtime[0].thread_end = 0; + } + runtime_count = runtime_count > 0 ? 1 : 0; - size_t n_to_process = 0; - if (n_start < n) { - n_to_process = num_n_per_thread; - if ((n_start + n_to_process) > n) { - n_to_process = n - n_start; + // Recompute SME slot based on the collapsed runtime[0] + sme_slot = -1; + if (runtime_count > 0 && + (runtime[0].kernels->required_cpu & CPU_FEATURE_SME) == CPU_FEATURE_SME) { + sme_slot = 0; } } - // Calculate number of columns to be processed per thread - const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; - const size_t m_start = ith * num_m_per_thread; - size_t m_to_process = num_m_per_thread; - if ((m_start + m_to_process) > m) { - m_to_process = m - m_start; + int sme_cap = kleidiai_sme_thread_cap(); + if (sme_cap < 0) { + sme_cap = nth_total; } + sme_cap = std::min(sme_cap, nth_total); - if (m_start < m) { - // Transform LHS - const size_t src_stride = src1->nb[1]; - const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr); - void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset); - - // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer - lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); + int threads_remaining = nth_total; + if (sme_slot != -1) { + int sme_threads = std::min(std::max(sme_cap, 0), threads_remaining); + runtime[sme_slot].assigned_threads = sme_threads; + threads_remaining -= sme_threads; } - ggml_barrier(params->threadpool); + int fallback_indices[GGML_KLEIDIAI_MAX_KERNEL_SLOTS]; + int fallback_count = 0; + // The current hybrid chain is bounded to SME + one non-SME fallback slot. + GGML_ASSERT(GGML_KLEIDIAI_MAX_KERNEL_SLOTS == 2); + for (int i = 0; i < runtime_count; ++i) { + if (i == sme_slot) { + continue; + } + fallback_indices[fallback_count++] = i; + } - // Perform the operation - const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr); - const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0); - const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); - const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset); - const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset); - float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset); + for (int fi = 0; fi < fallback_count; ++fi) { + if (threads_remaining <= 0) { + break; + } + const int slot_index = fallback_indices[fi]; + const int slots_left = fallback_count - fi; + int share = (threads_remaining + slots_left - 1) / slots_left; + share = std::min(share, threads_remaining); + runtime[slot_index].assigned_threads = share; + threads_remaining -= share; + } - if (n_to_process > 0) { - kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, - sizeof(float), -FLT_MAX, FLT_MAX); + if (threads_remaining > 0) { + const int fallback_slot = (sme_slot != -1) ? sme_slot : 0; + runtime[fallback_slot].assigned_threads += threads_remaining; + threads_remaining = 0; } - return true; - } + int thread_cursor = 0; + for (int i = 0; i < runtime_count; ++i) { + runtime[i].thread_begin = thread_cursor; + thread_cursor += runtime[i].assigned_threads; + runtime[i].thread_end = thread_cursor; + } - bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0); + if (thread_cursor < nth_total && runtime_count > 0) { + runtime[runtime_count - 1].assigned_threads += nth_total - thread_cursor; + runtime[runtime_count - 1].thread_end = nth_total; + } - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; + int local_slot = -1; + int local_ith = 0; + for (int i = 0; i < runtime_count; ++i) { + if (ith_total >= runtime[i].thread_begin && ith_total < runtime[i].thread_end) { + local_slot = i; + local_ith = ith_total - runtime[i].thread_begin; + break; + } + } + if (local_slot == -1) { + return false; + } - GGML_TENSOR_BINARY_OP_LOCALS + const size_t k = ne00; + const size_t m = ne11; + const size_t n = ne01; - ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); - if (!kernels) { - return false; + size_t cursor = 0; + for (int i = 0; i < runtime_count; ++i) { + runtime[i].lhs_packed_size = runtime[i].lhs_info->packed_size_ex(m, k, runtime[i].pack_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr); + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + runtime[i].lhs_offset = cursor; + runtime[i].lhs_packed_offset0 = runtime[i].lhs_info->get_packed_offset_ex(0, k, runtime[i].lhs_bl, runtime[i].mr, runtime[i].kr, runtime[i].sr); + cursor += runtime[i].lhs_packed_size; } - bool is_gemv = src1->ne[1] == 1; - kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; - lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; + GGML_ASSERT(cursor <= params->wsize); + uint8_t * scratch = static_cast<uint8_t *>(params->wdata); - if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex || - !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) { - return false; + size_t common_step = 1; + for (int i = 0; i < runtime_count; ++i) { + if (runtime[i].assigned_threads == 0) { + continue; + } + size_t next_step = 0; + if (!lcm_size(common_step, runtime[i].n_step ? runtime[i].n_step : 1, next_step)) { + return false; + } + common_step = next_step; + } + GGML_ASSERT(common_step > 0); + + const bool disable_chunking = ggml_is_numa(); + const size_t chunk_multiplier = std::max(1, ctx.chunk_multiplier); + const size_t chunk_divisor = (nth_total == 1 || disable_chunking) ? (size_t)nth_total : (size_t)nth_total * chunk_multiplier; + size_t chunk_cols = align_up(std::max<size_t>(1, ceil_div_size(n, chunk_divisor)), common_step); + if (chunk_cols == 0) { + chunk_cols = common_step; } + // If common_step is larger than n, the loop below runs one valid tail chunk + // with cols == n. + const size_t nchunk_size = std::max<size_t>(1, ceil_div_size(n, chunk_cols)); + GGML_ASSERT(nchunk_size <= (size_t)INT_MAX); + const int nchunk = (int)nchunk_size; + const size_t dst_stride = dst->nb[1]; - const int ith = params->ith; - const int nth_raw = params->nth; - const int nth = nth_raw > 0 ? nth_raw : 1; + auto run_chunk = [&](runtime_slot & slot, size_t global_start, size_t cols, uint8_t * dst_batch_base) { + const size_t rhs_packed_offset = slot.kernel->get_rhs_packed_offset_ex(global_start, k, slot.rhs_bl); + const size_t dst_offset = slot.kernel->get_dst_offset(0, global_start, dst_stride); + + const uint8_t * lhs_ptr = scratch + slot.lhs_offset + slot.lhs_packed_offset0; + const uint8_t * rhs_ptr = slot.rhs_base + rhs_packed_offset; + float * dst_ptr = reinterpret_cast<float *>(dst_batch_base + dst_offset); + + slot.kernel->run_kernel_ex(m, cols, k, slot.rhs_bl, + lhs_ptr, + rhs_ptr, + dst_ptr, + dst_stride, + sizeof(float), + -FLT_MAX, + FLT_MAX); + }; + + for (int64_t batch_idx = 0; batch_idx < ne12; ++batch_idx) { + const uint8_t * lhs_batch_base = static_cast<const uint8_t *>(src1->data) + batch_idx * src1->nb[2]; + uint8_t * dst_batch_base = static_cast<uint8_t *>(dst->data) + batch_idx * dst->nb[2]; - const size_t k = ne00; - const size_t m = ne11; - const size_t n = ne01; + if (runtime[local_slot].assigned_threads > 0) { + runtime_slot & slot = runtime[local_slot]; + const int64_t m_roundup_mr = kai_roundup((int64_t)m, (int64_t)slot.mr); + int64_t max_threads = slot.mr ? (m_roundup_mr / (int64_t)slot.mr) : slot.assigned_threads; + max_threads = std::max<int64_t>(1, max_threads); + const int64_t use_threads = std::min<int64_t>(slot.assigned_threads, max_threads); + + if (local_ith < use_threads) { + const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / use_threads), slot.mr); + const int64_t num_m_per_threadN_1 = (int64_t)m - (use_threads - 1) * num_m_per_thread0; + + const int64_t m_start = (int64_t)local_ith * num_m_per_thread0; + const int64_t m_count = (local_ith == use_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; + + const size_t base_packed_off = slot.lhs_info->get_packed_offset_ex(m_start, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr); + const size_t next_block_off = slot.lhs_info->get_packed_offset_ex(m_start + slot.mr, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr); + const size_t row_stride_bytes = slot.mr ? (next_block_off - base_packed_off) / slot.mr : 0; + + int64_t remaining = m_count; + int64_t cur = m_start; - size_t mr = kernel->get_mr(); - size_t kr = kernel->get_kr(); - size_t sr = kernel->get_sr(); + uint8_t * lhs_packed = scratch + slot.lhs_offset; + while (remaining > 0) { + const int64_t row_in_group = cur; + const int64_t avail = (int64_t)m - row_in_group; + const int64_t take = std::min(avail, remaining); - const uint8_t * lhs = static_cast<const uint8_t *>(src1->data); - uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata); - const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data); + const size_t src_off = slot.lhs_info->get_offset(row_in_group, src1->nb[1]); + const void * src_ptr = lhs_batch_base + src_off; + const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; + void * dst_ptr = lhs_packed + dst_off; - const size_t n_step = kernel->get_n_step(); - const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step); - const size_t n_start = ith * num_n_per_thread; + slot.lhs_info->pack_func_ex(take, k, slot.lhs_bl, slot.mr, slot.kr, slot.sr, 0, src_ptr, src1->nb[1], dst_ptr); - size_t n_to_process = 0; - if (n_start < n) { - n_to_process = num_n_per_thread; - if ((n_start + n_to_process) > n) { - n_to_process = n - n_start; + cur += take; + remaining -= take; + } + } } - } - const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth; - const size_t m_start = ith * num_m_per_thread; - size_t m_to_process = num_m_per_thread; - if ((m_start + m_to_process) > m) { - m_to_process = m - m_start; - } + if (ith_total == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth_total); + } - if (m_start < m) { - const size_t src_stride = src1->nb[1]; - const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); - const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr); - void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset); + // Publishes both LHS packing and the initialized dynamic chunk queue. + ggml_barrier(params->threadpool); - lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); - } + runtime_slot & slot = runtime[local_slot]; + int current_chunk = ith_total; + while (current_chunk < nchunk) { + const size_t global_start = (size_t)current_chunk * chunk_cols; + if (global_start >= n) { + break; + } - ggml_barrier(params->threadpool); + const size_t cols = std::min(chunk_cols, n - global_start); + if (cols > 0) { + // KleidiAI GEMM/GEMV kernels accept arbitrary final tail widths; + // only non-tail chunks are guaranteed to be n_step-aligned. + run_chunk(slot, global_start, cols, dst_batch_base); + } - const size_t dst_stride = dst->nb[1]; - const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr); - const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0); - const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride); - const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset); - const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset); - float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset); + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + } - if (n_to_process > 0) { - kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, - sizeof(float), -FLT_MAX, FLT_MAX); + if (batch_idx != ne12 - 1) { + ggml_barrier(params->threadpool); + } } return true; } bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0); const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; GGML_TENSOR_BINARY_OP_LOCALS + const kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(src0->data); + const bool has_header = kleidiai_is_weight_header_valid(header); + + std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain; + const bool want_q8 = src0->type == GGML_TYPE_Q8_0; + const int chain_count = want_q8 ? kleidiai_collect_q8_chain(kernel_chain) + : kleidiai_collect_q4_chain(kernel_chain); + ggml_kleidiai_kernels * kernels = nullptr; - size_t block_len = 0; - size_t num_bytes_multiplier = 0; + const uint8_t * packed_base = static_cast<const uint8_t *>(src0->data); - if (dst->src[0]->type == GGML_TYPE_Q4_0) { - if (!ctx.kernels_q4) { - return false; + if (has_header && chain_count > 0) { + int select_slot = 0; + if (select_slot >= header->slot_count) { + select_slot = header->slot_count - 1; } - kernels = ctx.kernels_q4; - block_len = QK4_0; - num_bytes_multiplier = sizeof(uint16_t); - } else if (dst->src[0]->type == GGML_TYPE_Q8_0) { - if (!ctx.kernels_q8) { - return false; + if (select_slot >= 0 && select_slot < chain_count) { + kernels = kernel_chain[select_slot]; + const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, select_slot); + if (slot_ptr) { + packed_base = slot_ptr; + } } - kernels = ctx.kernels_q8; - block_len = QK8_0; - num_bytes_multiplier = sizeof(float); - } else { + } + + if (!kernels && chain_count > 0) { + kernels = kernel_chain[0]; + if (has_header) { + const uint8_t * slot_ptr = kleidiai_weight_slot_ptr(header, 0); + if (slot_ptr) { + packed_base = slot_ptr; + } + } + } + + if (!kernels) { return false; } @@ -541,6 +1175,19 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t nc = ne00; const int64_t nr = ggml_nelements(src1); + const ggml_type rhs_type = kernels->rhs_type; + size_t block_len = 0; + size_t num_bytes_multiplier = 0; + if (rhs_type == GGML_TYPE_Q4_0) { + block_len = QK4_0; + num_bytes_multiplier = sizeof(uint16_t); + } else if (rhs_type == GGML_TYPE_Q8_0) { + block_len = QK8_0; + num_bytes_multiplier = sizeof(float); + } else { + return false; + } + const size_t block_rows = kernel->get_nr(); const size_t kr = kernel->get_kr(); @@ -559,7 +1206,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]); float *out = (float *)((char *)dst->data + i * nb1); - rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier); + rhs_info->to_float(packed_base, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier); } return true; @@ -567,36 +1214,39 @@ class tensor_traits : public ggml::cpu::tensor_traits { public: int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) { + GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q8_0); const size_t n = tensor->ne[1]; const size_t k = tensor->ne[0]; - if (tensor->type == GGML_TYPE_Q4_0) { - if (!ctx.kernels_q4) { - return -1; - } - size_t nr = ctx.kernels_q4->gemm.get_nr(); - size_t kr = ctx.kernels_q4->gemm.get_kr(); - size_t sr = ctx.kernels_q4->gemm.get_sr(); + kleidiai_weight_header * header = kleidiai_weight_header_from_ptr(tensor->data); + if (!header) { + return -1; + } - struct kai_rhs_pack_qs4cxs1s0_param params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, - static_cast<const uint8_t *>(data), - nullptr, nullptr, tensor->data, 0, ¶ms); - GGML_UNUSED(data_size); - return 0; - } else if (tensor->type == GGML_TYPE_Q8_0) { - if (!ctx.kernels_q8) { - return -1; - } + header->magic = GGML_KLEIDIAI_PACK_MAGIC; + header->version = GGML_KLEIDIAI_PACK_VERSION; + header->slot_count = 0; + + uint8_t * base_ptr = static_cast<uint8_t *>(tensor->data); + size_t cursor = sizeof(kleidiai_weight_header); + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + + std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain; + const bool want_q8 = tensor->type == GGML_TYPE_Q8_0; + const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain) + : kleidiai_collect_q4_chain(kernel_chain); + const bool allow_fallback = kleidiai_pack_fallback_allowed(); + + std::vector<int8_t> qdata; + std::vector<float> scales; + + if (want_q8 && slot_total > 0) { + qdata.resize(n * k, 0); + scales.resize(n, 0.0f); const size_t row_stride = tensor->nb[1]; const size_t k_blocks = (k + QK8_0 - 1) / QK8_0; - std::vector<int8_t> qdata(n * k, 0); - std::vector<float> scales(n, 0.0f); - for (size_t row = 0; row < n; ++row) { const auto * row_blocks = reinterpret_cast<const block_q8_0 *>( static_cast<const uint8_t *>(data) + row * row_stride); @@ -610,7 +1260,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (linear_idx >= k) { break; } - const float value = d * blk.qs[l]; + const float value = d * static_cast<float>(blk.qs[l]); max_abs = std::max(max_abs, std::fabs(value)); } } @@ -627,31 +1277,73 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (linear_idx >= k) { break; } - const float value = d * blk.qs[l]; + const float value = d * static_cast<float>(blk.qs[l]); int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0; q = std::clamp(q, -127, 127); qdata[row * k + linear_idx] = static_cast<int8_t>(q); } } } + } - size_t nr = ctx.kernels_q8->gemm.get_nr(); - size_t kr = ctx.kernels_q8->gemm.get_kr(); - size_t sr = ctx.kernels_q8->gemm.get_sr(); + for (int slot = 0; slot < slot_total && slot < GGML_KLEIDIAI_MAX_KERNEL_SLOTS; ++slot) { + if (!allow_fallback && slot > 0) { + break; + } + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + kernel_info * kernel = &kernels->gemm; + rhs_packing_info * rhs_info = &kernels->rhs_info; + if (!rhs_info || !rhs_info->pack_func_ex || !rhs_info->packed_size_ex || !kernel) { + continue; + } - struct kai_rhs_pack_qsi8cx_params params; - params.lhs_zero_point = 1; - params.scale_multiplier = 1.0f; + const size_t nr = kernel->get_nr(); + const size_t kr = kernel->get_kr(); + const size_t sr = kernel->get_sr(); + const ggml_type rhs_type = kernels->rhs_type; + const size_t block_len = rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : + rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : 0; + if (block_len == 0) { + continue; + } + + const size_t packed_size = rhs_info->packed_size_ex(n, k, nr, kr, block_len); + const size_t aligned_cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + + uint8_t * dst_ptr = base_ptr + aligned_cursor; + + if (rhs_type == GGML_TYPE_Q4_0) { + struct kai_rhs_pack_qs4cxs1s0_param params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + rhs_info->pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, + static_cast<const uint8_t *>(data), nullptr, nullptr, + dst_ptr, 0, ¶ms); + } else if (rhs_type == GGML_TYPE_Q8_0) { + struct kai_rhs_pack_qsi8cx_params params; + params.lhs_zero_point = 1; + params.scale_multiplier = 1.0f; + rhs_info->pack_func_ex(1, n, k, nr, kr, sr, 0, 0, + qdata.data(), nullptr, scales.data(), + dst_ptr, 0, ¶ms); + } else { + continue; + } - ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0, - qdata.data(), nullptr, scales.data(), - tensor->data, 0, ¶ms); - GGML_UNUSED(data_size); - return 0; + header->offsets[header->slot_count] = aligned_cursor; + header->sizes[header->slot_count] = packed_size; + ++header->slot_count; + + cursor = aligned_cursor + packed_size; + } + + if (header->slot_count == 0) { + header->magic = 0; + header->version = 0; + memcpy(tensor->data, data, data_size); } - GGML_UNUSED(data_size); - return -1; + return 0; } }; @@ -681,9 +1373,8 @@ static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t bu } static const char * ggml_backend_cpu_kleidiai_buffer_type_get_name(ggml_backend_buffer_type_t buft) { - return "CPU_KLEIDIAI"; - GGML_UNUSED(buft); + return "CPU_KLEIDIAI"; } static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -702,56 +1393,85 @@ static ggml_backend_buffer_t ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer( } static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return TENSOR_ALIGNMENT; - GGML_UNUSED(buft); + return TENSOR_ALIGNMENT; } static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { GGML_UNUSED(buft); + if (tensor->type != GGML_TYPE_Q4_0 && tensor->type != GGML_TYPE_Q8_0) { + return ggml_nbytes(tensor); + } + const size_t n = tensor->ne[1]; const size_t k = tensor->ne[0]; - ggml_kleidiai_kernels * kernels = nullptr; - size_t block_len = 0; - - if (tensor->type == GGML_TYPE_Q4_0) { - GGML_ASSERT(ctx.kernels_q4); - kernels = ctx.kernels_q4; - block_len = QK4_0; - } else if (tensor->type == GGML_TYPE_Q8_0) { - GGML_ASSERT(ctx.kernels_q8); - kernels = ctx.kernels_q8; - block_len = QK8_0; - } else { - return 0; + size_t cursor = sizeof(kleidiai_weight_header); + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + + std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain; + const bool want_q8 = tensor->type == GGML_TYPE_Q8_0; + const int slot_total = want_q8 ? kleidiai_collect_q8_chain(kernel_chain) + : kleidiai_collect_q4_chain(kernel_chain); + const bool allow_fallback = kleidiai_pack_fallback_allowed(); + + size_t slot_count = 0; + for (int slot = 0; slot < slot_total; ++slot) { + if (!allow_fallback && slot > 0) { + break; + } + ggml_kleidiai_kernels * kernels = kernel_chain[slot]; + if (!kernels) { + continue; + } + kernel_info * kernel = &kernels->gemm; + rhs_packing_info * rhs_info = &kernels->rhs_info; + if (!kernel || !rhs_info || !rhs_info->packed_size_ex) { + continue; + } + + const ggml_type rhs_type = kernels->rhs_type; + const size_t block_len = rhs_type == GGML_TYPE_Q4_0 ? QK4_0 : + rhs_type == GGML_TYPE_Q8_0 ? QK8_0 : 0; + if (block_len == 0) { + continue; + } + + cursor = align_up(cursor, GGML_KLEIDIAI_PACK_ALIGN); + cursor += rhs_info->packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), block_len); + ++slot_count; } - const size_t nr = kernels->gemm.get_nr(); - const size_t kr = kernels->gemm.get_kr(); - const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len); - const size_t raw = ggml_nbytes(tensor); + if (slot_count == 0) { + return ggml_nbytes(tensor); + } - return packed > raw ? packed : raw; + return std::max(cursor, ggml_nbytes(tensor)); } namespace ggml::cpu::kleidiai { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain; + const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain); if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) && (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) && op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && - op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { - if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) { + op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && + slot_total > 0) { + if (op->src[0]->type == GGML_TYPE_Q4_0 && ctx.kernels_q4 == nullptr) { + return false; + } + if (op->src[0]->type == GGML_TYPE_Q8_0 && ctx.kernels_q8 == nullptr) { return false; } if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { return false; } if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) && - ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) { + ggml_ne(op->src[1], 3) == 1) { return true; } } @@ -762,14 +1482,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; - } - else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) { - if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || - (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { + } else { + if (op->src[0]->type != GGML_TYPE_F16) { return nullptr; } - - return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL); + std::array<ggml_kleidiai_kernels *, GGML_KLEIDIAI_MAX_KERNEL_SLOTS> kernel_chain; + const int slot_total = kleidiai_collect_kernel_chain(op, kernel_chain); + if (slot_total > 0 && op->src[1]->ne[1] > 1) { + if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) || + (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { + return nullptr; + } + return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL); + } } } return nullptr; diff --git a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h b/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h deleted file mode 100644 index a7078687288..00000000000 --- a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +++ /dev/null @@ -1,333 +0,0 @@ -#pragma once - -typedef vector unsigned char vec_t; -typedef __vector_quad acc_t; - -template <typename TA> -class tinyBLAS_Q0_PPC { - public: - tinyBLAS_Q0_PPC(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth); - - void matmul(int64_t m, int64_t n); - void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) { - vec_t A_pack[mc*kc*2]; - vec_t B_pack[nc*kc*2]; - int comparray[mc*kc]; - constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>; - int64_t ytiles = m / mc; - int64_t xtiles = n / nc; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) { - end = tiles; - } - for (int64_t job = start; job < end; ++job) { - int64_t ii = (job / xtiles) * mc; - int64_t jj = (job % xtiles) * nc; - for (int64_t kk = 0; kk < k; kk += kc) { - if constexpr(is_Ablock_q4) { - packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray); - } else { - packNormal_large<int8_t, vector signed char>(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray); - } - packNormal_large<uint8_t, vector unsigned char>(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true); - KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray); - } - } - } - - private: - inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { - for (int I = 0; I < RM; I++) { - for (int J = 0; J < RN; J++) { - *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); - } - } - } - - inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { - for (int I = 0; I < RM; I++) { - for (int J = 0; J < RN; J++) { - float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I); - *c_ptr += *((float*)&fin_res[idx+I]+J); - } - } - } - - template<typename ArrayType> - inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) { - vector signed int vec_C[4]; - vector float CA[4] = {0}; - vector float res[4] = {0}; - __builtin_mma_disassemble_acc(vec_C, ACC); - for (int i = 0; i < 4; i++) { - CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0)); - res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); - fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); - } - } - - inline void process_q4_elements(vector signed char (&c)[2], int* ca) { - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector signed char v8 = vec_splats((signed char)0x8); - vector signed int vsum = {0}; - vector signed int vsum2 = {0}; - c[0] = vec_and(c[1], lowMask); - c[1] = vec_sr(c[1], v4); - c[0] = vec_sub(c[0], v8); - c[1] = vec_sub(c[1], v8); - vsum = vec_sum4s(c[0], vsum); - vsum2 = vec_sum4s(c[1], vsum2); - vsum = vec_add(vsum, vsum2); - *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - } - - template <typename V1, typename V2> - inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) { - vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; - vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; - vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; - vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - V2 t1, t2, t3, t4, t5, t6, t7, t8; - vector unsigned char xor_vector; - uint8_t flip_vec = 0x80; - xor_vector = vec_splats(flip_vec); - t1 = vec_perm(s1, s2, swiz1); - t2 = vec_perm(s1, s2, swiz2); - t3 = vec_perm(s3, s4, swiz1); - t4 = vec_perm(s3, s4, swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - } - - template<int RM, int RN> - inline void kernel(int64_t ii, int64_t jj) { - if constexpr(RM == 4 && RN == 8) { - KERNEL_4x8(ii,jj); - } else if constexpr(RM == 8 && RN == 4) { - KERNEL_8x4(ii,jj); - } else if constexpr(RM == 8 && RN == 8) { - KERNEL_8x8(ii,jj); - } else { - assert(false && "RN/RM values not supported"); - } - } - template<int size> - void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray); - template<typename VA, typename VB> - void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip); - void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n); - void KERNEL_4x8(int64_t ii, int64_t jj); - void KERNEL_8x4(int64_t ii, int64_t jj); - void KERNEL_8x8(int64_t ii, int64_t jj); - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN); - template <int RM, int RN> - void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n); - - void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){ - for (int I = 0; I<8; I++) { - float a_scale = unhalf((A+((ii+I)*lda)+blk)->d); - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d)); - *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d)); - } - } - } - - inline void process_q8_elements(const int8_t *qs, int *ca) { - vector signed char c1 = vec_xl(0, qs); - vector signed char c2 = vec_xl(16, qs); - vector signed int vsum1 = {0}; - vector signed int vsum2 = {0}; - vsum1 = vec_sum4s(c1, vsum1); - vsum2 = vec_sum4s(c2, vsum2); - vector signed int vsum = vec_add(vsum1, vsum2); - *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - } - - template<typename VA, typename VB> - void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) { - int64_t i, j; - block_q8_0 *aoffset = NULL; - VA *vecOffset = NULL; - block_q8_0* aoffsets[8]; - __vector_pair arr[8]; - VB c[8][2] = {0}; - VB c1[8] = {0}; VB c2[8] = {0}; - aoffset = const_cast<block_q8_0*>(a); - vecOffset = vec; - j = (rows >> 3); - int index = 0; - if (j > 0) { - do { - for (int it = 0; it < 8; it++) - aoffsets[it] = aoffset + it*lda; - aoffset += 8 * lda; - for (int blk = 0; blk < kc; blk++) { - for (int it = 0; it < 8; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); - c1[it] = c[it][0]; - c2[it] = c[it][1]; - if (comparray){ - process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]); - } - } - vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); - vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); - vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); - vecOffset += 256; - } - j--; - index += 8*kc; - } while(j > 0); - } - - } - - void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) { - int64_t i, j; - TA *aoffset = NULL; - int8_t *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - aoffset = const_cast<TA*>(a); - vecOffset = vec; - int index = 0; - j = (rows >> 3); - if (j > 0) { - do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; - aoffset += 8 * lda; - for (int blk = 0; blk < kc; blk++) { - c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset1+blk)->qs)); - c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset2+blk)->qs)); - c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset3+blk)->qs)); - c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset4+blk)->qs)); - c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset5+blk)->qs)); - c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset6+blk)->qs)); - c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset7+blk)->qs)); - c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, (aoffset8+blk)->qs)); - - process_q4_elements(c1, &comparray[index + 8*blk+0]); - process_q4_elements(c2, &comparray[index + 8*blk+1]); - process_q4_elements(c3, &comparray[index + 8*blk+2]); - process_q4_elements(c4, &comparray[index + 8*blk+3]); - process_q4_elements(c5, &comparray[index + 8*blk+4]); - process_q4_elements(c6, &comparray[index + 8*blk+5]); - process_q4_elements(c7, &comparray[index + 8*blk+6]); - process_q4_elements(c8, &comparray[index + 8*blk+7]); - vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); - vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); - vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); - vecOffset += 256; - } - j--; - index += 8*kc; - } while (j > 0); - } - } - - void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) { - acc_t acc[8]; - for (int i = 0; i < mc ; i += 8) { - for (int j = 0; j < nc; j += 8) { - vector float fin_res[16] = {0}; - vector float vs[16] = {0}; - for (int64_t kk = 0; kk < kc; kk+=2) { - for (int x = 0; x < 8; x++) { - __builtin_mma_xxsetaccz(&acc[x]); - } - int A_block_idx = (i/8)*(16*kc) + kk*16; - int B_block_idx = (j/8)*(16*kc)+ kk*16; - vec_t *A_block = &vec_A[A_block_idx]; - vec_t *B_block = &vec_B[B_block_idx]; - for (int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]); - __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]); - } - compute_scale(ii+i, jj+j, l+kk, vs); - int c_index = (i/8)*(8*kc)+ kk*8; - int* c_block = &comparray[c_index]; - compute(&acc[0], 0, 0, c_block, vs, fin_res); - compute(&acc[1], 4, 4, c_block, vs, fin_res); - compute(&acc[2], 0, 8, c_block, vs, fin_res); - compute(&acc[3], 4, 12, c_block, vs, fin_res); - - A_block_idx = (i/8)*(16*kc) + (kk+1)*16; - B_block_idx = (j/8)*(16*kc)+ (kk+1)*16; - A_block = &vec_A[A_block_idx]; - B_block = &vec_B[B_block_idx]; - for (int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]); - __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]); - } - compute_scale(ii+i, jj+j, l+kk+1, vs); - c_index = (i/8)*(8*kc)+ (kk+1)*8; - c_block = &comparray[c_index]; - compute(&acc[4], 0, 0, c_block, vs, fin_res); - compute(&acc[5], 4, 4, c_block, vs, fin_res); - compute(&acc[6], 0, 8, c_block, vs, fin_res); - compute(&acc[7], 4, 12, c_block, vs, fin_res); - - } - if (l == 0) { - save_res(ii+i, jj+j, 0, fin_res); - save_res(ii+i+4, jj+j, 4, fin_res); - save_res(ii+i, jj+j+4, 8, fin_res); - save_res(ii+i+4, jj+j+4, 12, fin_res); - } else { - add_save_res(ii+i, jj+j, 0, fin_res); - add_save_res(ii+i+4, jj+j, 4, fin_res); - add_save_res(ii+i, jj+j+4, 8, fin_res); - add_save_res(ii+i+4, jj+j+4, 12, fin_res); - } - } - } - } - - const TA *const A; - const block_q8_0 *const B; - float *C; - const int64_t k; - int64_t kc; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 7dc36d4f8ad..e13828e3be6 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); } #endif #if defined(__MMA__) -#include "sgemm-ppc.h" +typedef vector unsigned char vec_t; +typedef __vector_quad acc_t; #endif //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED FUSED MULTIPLY ADD @@ -179,44 +180,49 @@ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { } #endif +#if defined(__riscv_v_intrinsic) +template <> inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) { + return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) { + return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) { + return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) { + return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} +#endif + #if defined(__riscv_zvfh) -template <> -inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) { +template <> inline vfloat32m1_t madd(vfloat16mf2_t a, vfloat16mf2_t b, vfloat32m1_t c) { return __riscv_vfwmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); } -inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) { +template <> inline vfloat32m2_t madd(vfloat16m1_t a, vfloat16m1_t b, vfloat32m2_t c) { return __riscv_vfwmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); } -inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) { +template <> inline vfloat32m4_t madd(vfloat16m2_t a, vfloat16m2_t b, vfloat32m4_t c) { return __riscv_vfwmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); } -inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) { +template <> inline vfloat32m8_t madd(vfloat16m4_t a, vfloat16m4_t b, vfloat32m8_t c) { return __riscv_vfwmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); } -inline vfloat32m1_t madd(vfloat32m1_t a, vfloat32m1_t b, vfloat32m1_t c) { - return __riscv_vfmacc_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); -} -inline vfloat32m2_t madd(vfloat32m2_t a, vfloat32m2_t b, vfloat32m2_t c) { - return __riscv_vfmacc_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); -} -inline vfloat32m4_t madd(vfloat32m4_t a, vfloat32m4_t b, vfloat32m4_t c) { - return __riscv_vfmacc_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); -} -inline vfloat32m8_t madd(vfloat32m8_t a, vfloat32m8_t b, vfloat32m8_t c) { - return __riscv_vfmacc_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); -} #endif #if defined(__riscv_zvfbfwma) -inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) { +template <> inline vfloat32m1_t madd(vbfloat16mf2_t a, vbfloat16mf2_t b, vfloat32m1_t c) { return __riscv_vfwmaccbf16_vv_f32m1(c, a, b, __riscv_vsetvlmax_e32m1()); } -inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) { +template <> inline vfloat32m2_t madd(vbfloat16m1_t a, vbfloat16m1_t b, vfloat32m2_t c) { return __riscv_vfwmaccbf16_vv_f32m2(c, a, b, __riscv_vsetvlmax_e32m2()); } -inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) { +template <> inline vfloat32m4_t madd(vbfloat16m2_t a, vbfloat16m2_t b, vfloat32m4_t c) { return __riscv_vfwmaccbf16_vv_f32m4(c, a, b, __riscv_vsetvlmax_e32m4()); } +template <> inline vfloat32m8_t madd(vbfloat16m4_t a, vbfloat16m4_t b, vfloat32m8_t c) { + return __riscv_vfwmaccbf16_vv_f32m8(c, a, b, __riscv_vsetvlmax_e32m8()); +} #endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -271,7 +277,7 @@ inline float hsum(__m512 x) { } #endif // __AVX512F__ -#if defined(__riscv_zvfh) +#if defined(__riscv_v_intrinsic) inline float hsum(vfloat32m1_t x) { return __riscv_vfmv_f_s_f32m1_f32( __riscv_vfredusum_vs_f32m1_f32m1(x, __riscv_vfmv_v_f_f32m1(0, 1), __riscv_vsetvlmax_e32m1())); @@ -378,6 +384,21 @@ template <> inline __m256bh load(const float *p) { } #endif +#if defined(__riscv_v_intrinsic) +template <> inline vfloat32m1_t load(const float *p) { + return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1()); +} +template <> inline vfloat32m2_t load(const float *p) { + return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2()); +} +template <> inline vfloat32m4_t load(const float *p) { + return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4()); +} +template <> inline vfloat32m8_t load(const float *p) { + return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8()); +} +#endif + #if defined(__riscv_zvfh) template <> inline vfloat16mf2_t load(const ggml_fp16_t *p) { return __riscv_vle16_v_f16mf2(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16mf2()); @@ -391,18 +412,6 @@ template <> inline vfloat16m2_t load(const ggml_fp16_t *p) { template <> inline vfloat16m4_t load(const ggml_fp16_t *p) { return __riscv_vle16_v_f16m4(reinterpret_cast<const _Float16 *>(p), __riscv_vsetvlmax_e16m4()); } -template <> inline vfloat32m1_t load(const float *p) { - return __riscv_vle32_v_f32m1(p, __riscv_vsetvlmax_e32m1()); -} -template <> inline vfloat32m2_t load(const float *p) { - return __riscv_vle32_v_f32m2(p, __riscv_vsetvlmax_e32m2()); -} -template <> inline vfloat32m4_t load(const float *p) { - return __riscv_vle32_v_f32m4(p, __riscv_vsetvlmax_e32m4()); -} -template <> inline vfloat32m8_t load(const float *p) { - return __riscv_vle32_v_f32m8(p, __riscv_vsetvlmax_e32m8()); -} #endif #if defined(__riscv_zvfbfwma) @@ -415,23 +424,14 @@ template <> inline vbfloat16m1_t load(const ggml_bf16_t *p) { template <> inline vbfloat16m2_t load(const ggml_bf16_t *p) { return __riscv_vle16_v_bf16m2(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m2()); } +template <> inline vbfloat16m4_t load(const ggml_bf16_t *p) { + return __riscv_vle16_v_bf16m4(reinterpret_cast<const __bf16*>(p), __riscv_vsetvlmax_e16m4()); +} #endif -#if defined(__riscv_zvfh) +#if defined(__riscv_v_intrinsic) template <typename T> T set_zero(); -template <> inline vfloat16mf2_t set_zero() { - return __riscv_vfmv_v_f_f16mf2(0, __riscv_vsetvlmax_e16mf2()); -} -template <> inline vfloat16m1_t set_zero() { - return __riscv_vfmv_v_f_f16m1(0, __riscv_vsetvlmax_e16m1()); -} -template <> inline vfloat16m2_t set_zero() { - return __riscv_vfmv_v_f_f16m2(0, __riscv_vsetvlmax_e16m2()); -} -template <> inline vfloat16m4_t set_zero() { - return __riscv_vfmv_v_f_f16m4(0, __riscv_vsetvlmax_e16m4()); -} template <> inline vfloat32m1_t set_zero() { return __riscv_vfmv_v_f_f32m1(0.0f, __riscv_vsetvlmax_e32m1()); } @@ -448,14 +448,22 @@ template <> inline vfloat32m8_t set_zero() { #if defined(__riscv_v_intrinsic) template <typename T> size_t vlmax() { - if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); } - else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); } - else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); } - else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); } - else if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); } + if constexpr (std::is_same_v<T, vfloat32m1_t>) { return __riscv_vsetvlmax_e32m1(); } else if constexpr (std::is_same_v<T, vfloat32m2_t>) { return __riscv_vsetvlmax_e32m2(); } else if constexpr (std::is_same_v<T, vfloat32m4_t>) { return __riscv_vsetvlmax_e32m4(); } else if constexpr (std::is_same_v<T, vfloat32m8_t>) { return __riscv_vsetvlmax_e32m8(); } + #if defined (__riscv_zvfh) + else if constexpr (std::is_same_v<T, vfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); } + else if constexpr (std::is_same_v<T, vfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); } + else if constexpr (std::is_same_v<T, vfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); } + else if constexpr (std::is_same_v<T, vfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); } + #endif + #if defined (__riscv_zvfbfwma) + else if constexpr (std::is_same_v<T, vbfloat16mf2_t>) { return __riscv_vsetvlmax_e16mf2(); } + else if constexpr (std::is_same_v<T, vbfloat16m1_t>) { return __riscv_vsetvlmax_e16m1(); } + else if constexpr (std::is_same_v<T, vbfloat16m2_t>) { return __riscv_vsetvlmax_e16m2(); } + else if constexpr (std::is_same_v<T, vbfloat16m4_t>) { return __riscv_vsetvlmax_e16m4(); } + #endif return 0; } #endif @@ -532,7 +540,7 @@ class tinyBLAS { if constexpr (RN > 1) { return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN); } else { - GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N); + GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N); GGML_ASSERT(false); // we have miss something. } } @@ -710,7 +718,7 @@ class tinyBLAS_RVV { if constexpr (RN > 1) { return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN); } else { - GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N); + GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N); GGML_ASSERT(false); // we have miss something. } } @@ -1797,10 +1805,27 @@ class tinyBLAS_Q0_AVX { } \ } \ +template<typename T> +struct mma_instr; + +template<> +struct mma_instr<ggml_bf16_t> { + static inline void outer_product(acc_t *acc, vec_t a, vec_t b) { + __builtin_mma_xvbf16ger2pp(acc, a, b); + } +}; + +template<> +struct mma_instr<ggml_fp16_t> { + static inline void outer_product(acc_t *acc, vec_t a, vec_t b) { + __builtin_mma_xvf16ger2pp(acc, a, b); + } +}; + template <typename TA, typename TB, typename TC> -class tinyBLAS_BF16_PPC { +class tinyBLAS_HP16_PPC { public: - tinyBLAS_BF16_PPC(int64_t k, + tinyBLAS_HP16_PPC(int64_t k, const TA *A, int64_t lda, const TB *B, int64_t ldb, TC *C, int64_t ldc, @@ -2118,8 +2143,8 @@ class tinyBLAS_BF16_PPC { packNormal((A+(ii*lda)+l), lda, 4, 8, (uint8_t*)vec_A); packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2135,8 +2160,8 @@ class tinyBLAS_BF16_PPC { packNormal((A+(ii*lda)+l), lda, 8, 8, (uint8_t*)vec_A); packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x+4], vec_B[x]); + mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr<TA>::outer_product(&acc_1, vec_A[x+4], vec_B[x]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2155,10 +2180,10 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, 8, 8, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, 8, 8, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, (vec_t)vec_A[x], (vec_t)vec_B[x+4]); - __builtin_mma_xvbf16ger2pp(&acc_2, (vec_t)vec_A[x+4], (vec_t)vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_3, (vec_t)vec_A[x+4], (vec_t)vec_B[x+4]); + mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr<TA>::outer_product(&acc_2, vec_A[x+4], vec_B[x]); + mma_instr<TA>::outer_product(&acc_3, vec_A[x+4], vec_B[x+4]); } } @@ -2189,7 +2214,7 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, RM, 4, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, RN, 4, (uint8_t*)vec_B); for (int x = 0; x<2; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); + mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]); } } __builtin_mma_disassemble_acc(vec_C, &acc_0); @@ -2224,8 +2249,8 @@ class tinyBLAS_BF16_PPC { packNormal(A+(ii*lda)+l, lda, RM, 8, (uint8_t*)vec_A); packNormal(B+(jj*ldb)+l, ldb, RN, 8, (uint8_t*)vec_B); for (int x = 0; x<4; x++) { - __builtin_mma_xvbf16ger2pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvbf16ger2pp(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr<TA>::outer_product(&acc_0, vec_A[x], vec_B[x]); + mma_instr<TA>::outer_product(&acc_1, vec_A[x], vec_B[x+4]); } } __builtin_mma_disassemble_acc(vec_C, &acc_0); @@ -2284,43 +2309,302 @@ class tinyBLAS_BF16_PPC { const int nth; }; - template <typename TA> - tinyBLAS_Q0_PPC<TA>::tinyBLAS_Q0_PPC(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth) +template <typename TA> +class tinyBLAS_Q0_PPC { + public: + tinyBLAS_Q0_PPC(int64_t k, + const TA * A, int64_t lda, + const block_q8_0 * B, int64_t ldb, + float * C, int64_t ldc, + int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { - kc = 64; } - template<typename TA> - void tinyBLAS_Q0_PPC<TA>::matmul(int64_t m, int64_t n) { - int mc = 64; int nc = 64; - if (n % 8 == 0 && n < nc) { - nc = n; - mc = 32 ; - kc = 32; - } - const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0); - if (is_aligned) { - this->matmul_tiled_q0(m, n, mc, nc, kc); + void matmul(int64_t m, int64_t n) { + #if defined(_AIX) || defined(__BIG_ENDIAN__) + mnpack(0, m, 0, n); + #else + const int64_t mc = 64; + const int64_t kc = 64; + int64_t nc = 64; + int64_t n_aligned = 0; + if (n % 64 == 0) { + n_aligned = n; + } else if (n == 4) { + n_aligned = 4; + } else if (n < 64) { + n_aligned = (n / 8) * 8; + } else { + n_aligned = (n / 64) * 64; + } + if (n_aligned > 0) { + if (n_aligned % 64 == 0) nc = 64; + else if (n_aligned == n) nc = n; + else if (n_aligned % 32 == 0) nc = 32; + else if (n_aligned % 24 == 0) nc = 24; + else if (n_aligned % 16 == 0) nc = 16; + else nc = 8; + } + bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0); + if (can_use_tiled) { + matmul_tiled(m, n_aligned, mc, nc, kc); + if (n > n_aligned) { + mnpack(0, m, n_aligned, n); + } } else { mnpack(0, m, 0, n); } + #endif + } + + private: + inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) { + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J); + } + } } - template<typename TA> - template<int size> - void tinyBLAS_Q0_PPC<TA>::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array<int, size>& comparray) { + inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J); + } + } + } + + inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I); + *c_ptr += *((float *)&vec_C[I] + J); + } + } + } + + template<typename ArrayType> + inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) { + vector signed int vec_C[4]; + vector float CA[4] = {0}; + vector float res[4] = {0}; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int i = 0; i < 4; i++) { + CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0)); + res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); + fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]); + } + } + + inline void process_q4_elements(vector signed char (&c)[2], int * ca) { + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + c[0] = vec_and(c[1], lowMask); + c[1] = vec_sr(c[1], v4); + c[0] = vec_sub(c[0], v8); + c[1] = vec_sub(c[1], v8); + vsum = vec_sum4s(c[0], vsum); + vsum2 = vec_sum4s(c[1], vsum2); + vsum = vec_add(vsum, vsum2); + *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + } + + template <typename V1, typename V2> + inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) { + vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + V2 t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = vec_splats(flip_vec); + t1 = vec_perm(s1, s2, swiz1); + t2 = vec_perm(s1, s2, swiz2); + t3 = vec_perm(s3, s4, swiz1); + t4 = vec_perm(s3, s4, swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 16); + vec_xst(t7, 0, vecOffset + 32); + vec_xst(t8, 0, vecOffset + 48); + } + + inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) { + const vector signed char lowMask = vec_splats((signed char)0x0F); + const vector signed char v8 = vec_splats((signed char)0x08); + const vector unsigned char v4 = vec_splats((unsigned char)4); + lo = vec_and(packed, lowMask); + hi = vec_sr(packed, v4); + lo = vec_sub(lo, v8); + hi = vec_sub(hi, v8); + } + + inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) { + vec_t t[8], s[8]; + vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; + vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + for (int i = 0; i < 4; i += 2) { + t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1); + t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2); + } + for (int i = 4; i < 8; i += 2) { + t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1); + t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2); + } + s[0] = vec_perm(t[0], t[2], swiz3); + s[1] = vec_perm(t[0], t[2], swiz4); + s[2] = vec_perm(t[1], t[3], swiz3); + s[3] = vec_perm(t[1], t[3], swiz4); + s[4] = vec_perm(t[4], t[6], swiz3); + s[5] = vec_perm(t[4], t[6], swiz4); + s[6] = vec_perm(t[5], t[7], swiz3); + s[7] = vec_perm(t[5], t[7], swiz4); + for (int i = 0; i < 8; ++i) { + vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16)); + } + } + + static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) { + vector signed short i16_hi = vec_unpackh(raw); + vector signed short i16_lo = vec_unpackl(raw); + + vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0); + vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0); + vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0); + vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0); + out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale)); + out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale)); + } + + void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + unsigned char * vecOffset = vec; + for (int i = 0; i < rows; i += 8) { + const block_q4_0 * rows_base[8]; + for (int r = 0; r < 8; r++) { + rows_base[r] = a + (i + r) * lda; + } + for (int blk = 0; blk < blocks; blk++) { + vector unsigned short hp_res[8][4]; + for (int r = 0; r < 8; r++) { + const block_q4_0 * current_blk = rows_base[r] + blk; + vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d)); + vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs); + vector signed char c1, c2; + unpack_q4_to_q8(v_qs, c1, c2); + convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]); + convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]); + } + for (int c = 0; c < 4; c++) { + vector unsigned char c_arr[8]; + for (int r = 0; r < 8; r++) { + c_arr[r] = (vector unsigned char)hp_res[r][c]; + } + vector_permute_store_fp16((vec_t *)c_arr, vecOffset); + vecOffset += 128; + } + } + } + } + + template <int chunk_size> + static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + unsigned char * vecOffset = vec; + const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; + const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + + for (int i = 0; i < rows; i += chunk_size) { + const block_q8_0 * rows_base[chunk_size]; + for (int r = 0; r < chunk_size; r++) { + rows_base[r] = a + (i + r) * lda; + } + for (int blk = 0; blk < blocks; blk++) { + vector unsigned short hp_res[chunk_size][4]; + for (int r = 0; r < chunk_size; r++) { + const block_q8_0 * b = rows_base[r] + blk; + vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d)); + vector signed char c[2]; + __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs); + __builtin_vsx_disassemble_pair(c, & pair); + convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]); + convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]); + } + for (int col = 0; col < 4; col++) { + if constexpr (chunk_size == 8) { + vec_t t[8]; + t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1); + t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2); + t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1); + t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2); + t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1); + t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2); + t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1); + t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2); + + vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0)); + vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16)); + vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32)); + vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48)); + vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64)); + vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80)); + vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96)); + vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112)); + vecOffset += 128; + } else { + vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1); + vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2); + vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1); + vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2); + + vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0)); + vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16)); + vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32)); + vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48)); + vecOffset += 64; + } + } + } + } + } + + void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + if (rows == 4) { + pack_q8_block<4>(a, lda, rows, blocks, vec); + } else { + pack_q8_block<8>(a, lda, rows, blocks, vec); + } + } + + template<int size> + void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array<int, size> & comparray) { int64_t i, j; - TA *aoffset = NULL; - int8_t *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + TA * aoffset = NULL; + int8_t * vecOffset = NULL; + TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL; + TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL; vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - aoffset = const_cast<TA*>(a); + aoffset = const_cast<TA *>(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { @@ -2337,27 +2621,27 @@ class tinyBLAS_BF16_PPC { i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs)); - c5[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset5->qs)); - c6[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset6->qs)); - c7[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset7->qs)); - c8[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset8->qs)); - - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); - process_q4_elements(c5, &comparray[4]); - process_q4_elements(c6, &comparray[5]); - process_q4_elements(c7, &comparray[6]); - process_q4_elements(c8, &comparray[7]); + c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs); + c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs); + c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs); + c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs); + c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs); + c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs); + c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs); + c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs); + + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); + process_q4_elements(c5, & comparray[4]); + process_q4_elements(c6, & comparray[5]); + process_q4_elements(c7, & comparray[6]); + process_q4_elements(c8, & comparray[7]); vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); - vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); - vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); + vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); + vector_permute_store<int8_t, vector signed char>(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false); + vector_permute_store<int8_t, vector signed char>(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2383,17 +2667,17 @@ class tinyBLAS_BF16_PPC { i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset4->qs)); - - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); + c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs); + c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs); + c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs); + c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs); + + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2412,17 +2696,17 @@ class tinyBLAS_BF16_PPC { if (i > 0) { do { switch(rows) { - case 3: c3[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset3->qs)); - case 2: c2[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset2->qs)); - case 1: c1[1] = reinterpret_cast<vector signed char>(vec_xl(0, aoffset1->qs)); + case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs); + case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs); + case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs); break; } - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); vector_permute_store<int8_t, vector signed char>(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store<int8_t, vector signed char>(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2433,39 +2717,38 @@ class tinyBLAS_BF16_PPC { } } - template<typename TA> template<typename VA, typename VB> - void tinyBLAS_Q0_PPC<TA>::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) { int64_t i, j; - block_q8_0 *aoffset = NULL; - VA *vecOffset = NULL; - block_q8_0* aoffsets[8]; + block_q8_0 * aoffset = NULL; + VA * vecOffset = NULL; + block_q8_0 * aoffsets[8]; __vector_pair arr[8]; VB c[8][2] = {0}; VB c1[8] = {0}; VB c2[8] = {0}; - aoffset = const_cast<block_q8_0*>(a); + aoffset = const_cast<block_q8_0 *>(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { do { aoffsets[0] = aoffset; for (int it = 1; it < 8; it++) - aoffsets[it] = aoffsets[it-1] + lda; + aoffsets[it] = aoffsets[it - 1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { for (int it = 0; it < 8; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], & arr[it]); c1[it] = c[it][0]; c2[it] = c[it][1]; } vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); - vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); - vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); + vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); + vector_permute_store<VA, VB>(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip); + vector_permute_store<VA, VB>(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip); for (int it = 0; it < 8; it++) aoffsets[it] += lda; vecOffset += 256; @@ -2484,13 +2767,13 @@ class tinyBLAS_BF16_PPC { if (i > 0) { do { for (int it = 0; it < 4; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], & arr[it]); c1[it] = c[it][0]; c2[it] = c[it][1]; } vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); for (int it = 0; it < 4; it++) { aoffsets[it] += lda; } @@ -2503,24 +2786,24 @@ class tinyBLAS_BF16_PPC { if (rows & 3) { aoffsets[0] = aoffset; for (int it = 1; it < 3; it++ ) - aoffsets[it] = aoffsets[it-1] + lda; + aoffsets[it] = aoffsets[it - 1] + lda; i = (cols >> 3); if (i > 0) { do { switch(rows) { - case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs); - __builtin_vsx_disassemble_pair(c[2], &arr[2]); + case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs); + __builtin_vsx_disassemble_pair(c[2], & arr[2]); c1[2] = c[2][0]; c2[2] = c[2][1]; - case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs); - __builtin_vsx_disassemble_pair(c[1], &arr[1]); + case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs); + __builtin_vsx_disassemble_pair(c[1], & arr[1]); c1[1] = c[1][0]; c2[1] = c[1][1]; - case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs); - __builtin_vsx_disassemble_pair(c[0], &arr[0]); + case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs); + __builtin_vsx_disassemble_pair(c[0], & arr[0]); c1[0] = c[0][0]; c2[0] = c[0][1]; break; } vector_permute_store<VA, VB>(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store<VA, VB>(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); for (int it = 0; it < 3; it++) aoffsets[it] += lda; vecOffset += 128; @@ -2530,8 +2813,7 @@ class tinyBLAS_BF16_PPC { } } - template<typename TA> - void tinyBLAS_Q0_PPC<TA>::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { int m_rem = MIN(m - m0, 16); int n_rem = MIN(n - n0, 16); @@ -2568,8 +2850,7 @@ class tinyBLAS_BF16_PPC { } - template<typename TA> - void tinyBLAS_Q0_PPC<TA>::KERNEL_4x8(int64_t ii, int64_t jj) { + void KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; std::array<int, 4> comparray {}; @@ -2577,26 +2858,26 @@ class tinyBLAS_BF16_PPC { vector float vs[8] = {0}; bool isAblock_q4 = std::is_same_v<TA, block_q4_0>; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); if (std::is_same_v<TA, block_q4_0>) { - packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray); } else { - packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false); } - packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]); } for (int I = 0; I<4; I++) { for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); - *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); + *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 4; i++) { comparray[i] = 0; int ca = 0; @@ -2607,15 +2888,14 @@ class tinyBLAS_BF16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 0, 4, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 0, 4, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii, jj+4, 4, fin_res); + save_res(ii, jj + 4, 4, fin_res); } - template<typename TA> - void tinyBLAS_Q0_PPC<TA>::KERNEL_8x4(int64_t ii, int64_t jj) { + void KERNEL_8x4(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[8] = {0}; acc_t acc_0, acc_1; std::array<int, 8> comparray {}; @@ -2623,25 +2903,25 @@ class tinyBLAS_BF16_PPC { vector float vs[8] = {0}; bool isAblock_q4 = std::is_same_v<TA, block_q4_0>; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); if (std::is_same_v<TA, block_q4_0>) { - packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray); } else { - packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false); } - packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); + packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]); } - for (int I = 0; I<8; I++) { - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + for (int I = 0; I < 8; I++) { + for (int J = 0; J < 4; J++) { + *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 8; i++) { comparray[i] = 0; int ca = 0; @@ -2652,15 +2932,14 @@ class tinyBLAS_BF16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 4, 4, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 4, 4, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii+4, jj, 4, fin_res); + save_res(ii + 4, jj, 4, fin_res); } - template<typename TA> - void tinyBLAS_Q0_PPC<TA>::KERNEL_8x8(int64_t ii, int64_t jj) { + void KERNEL_8x8(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[16] = {0}; acc_t acc_0, acc_1, acc_2, acc_3; acc_t acc_4, acc_5, acc_6, acc_7; @@ -2669,30 +2948,30 @@ class tinyBLAS_BF16_PPC { vector float vs[16] = {0}; bool isAblock_q4 = std::is_same_v<TA, block_q4_0>; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); - __builtin_mma_xxsetaccz(&acc_2); - __builtin_mma_xxsetaccz(&acc_3); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); + __builtin_mma_xxsetaccz(& acc_2); + __builtin_mma_xxsetaccz(& acc_3); if (std::is_same_v<TA, block_q4_0>) { - packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray); } else { - packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false); } - packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]); - __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]); + __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]); } - for (int I = 0; I<8; I++) { - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); - *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + for (int I = 0; I < 8 ; I++) { + for (int J = 0; J < 4; J++) { + *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); + *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 8; i++) { comparray[i] = 0; int ca = 0; @@ -2703,19 +2982,99 @@ class tinyBLAS_BF16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 4, 4, comparray, vs, fin_res); - compute(&acc_2, 0, 8, comparray, vs, fin_res); - compute(&acc_3, 4, 12, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 4, 4, comparray, vs, fin_res); + compute(& acc_2, 0, 8, comparray, vs, fin_res); + compute(& acc_3, 4, 12, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii+4, jj, 4, fin_res); - save_res(ii, jj+4, 8, fin_res); - save_res(ii+4, jj+4, 12, fin_res); + save_res(ii + 4, jj, 4, fin_res); + save_res(ii, jj + 4, 8, fin_res); + save_res(ii + 4, jj + 4, 12, fin_res); + } + + void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) { + acc_t acc[8]; + for (int i = 0; i < mc ; i += 16) { + for (int j = 0; j < nc; j += 8) { + int A0_base = (i / 16) * (2 * 32 * kc); + int B0_base = (j / 8) * (32 * kc); + for (int x = 0; x < 8; x++) { + __builtin_mma_xxsetaccz(&acc[x]); + } + for (int64_t kk = 0; kk < kc; kk++) { + int A0_block_idx = A0_base + kk * 32; + int B0_block_idx = B0_base + kk * 32; + int A1_block_idx = A0_block_idx + 32 * kc; + int B1_block_idx = B0_block_idx + 32 * kc; + vec_t * A0_block = & vec_A[A0_block_idx]; + vec_t * B0_block = & vec_B[B0_block_idx]; + vec_t * A1_block = & vec_A[A1_block_idx]; + for (int it = 0; it < 4; it++) { + for (int x = 0; x < 4; x++) { + __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]); + __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]); + __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]); + __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]); + } + } + } + if (l == 0) { + save_acc(& acc[0], ii + i, jj + j); + save_acc(& acc[1], ii + i, jj + j + 4); + save_acc(& acc[2], ii + i + 4, jj + j); + save_acc(& acc[3], ii + i + 4, jj + j + 4); + save_acc(& acc[4], ii + i + 8, jj + j); + save_acc(& acc[5], ii + i + 8, jj + j + 4); + save_acc(& acc[6], ii + i + 12, jj + j); + save_acc(& acc[7], ii + i + 12, jj + j + 4); + } else { + add_save_acc(& acc[0], ii + i, jj + j); + add_save_acc(& acc[1], ii + i, jj + j + 4); + add_save_acc(& acc[2], ii + i + 4, jj + j); + add_save_acc(& acc[3], ii + i + 4, jj + j + 4); + add_save_acc(& acc[4], ii + i + 8, jj + j); + add_save_acc(& acc[5], ii + i + 8, jj + j + 4); + add_save_acc(& acc[6], ii + i + 12, jj + j); + add_save_acc(& acc[7], ii + i + 12, jj + j + 4); + } + } + } + } + + void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) { + vec_t A_pack[mc * kc * 4]; + vec_t B_pack[nc * kc * 4]; + constexpr bool is_Ablock_q4 = std::is_same_v<TA, block_q4_0>; + int64_t ytiles = m / mc; + int64_t xtiles = n / nc; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) { + end = tiles; + } + for (int64_t job = start; job < end; ++job) { + int64_t ii = (job / xtiles) * mc; + int64_t jj = (job % xtiles) * nc; + for (int64_t kk = 0; kk < k; kk += kc) { + if constexpr(is_Ablock_q4) { + packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + } else { + packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + } + packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack); + KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack); + } + } } - template<typename TA> - void tinyBLAS_Q0_PPC<TA>::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2737,32 +3096,32 @@ class tinyBLAS_BF16_PPC { vector float fin_res[4] = {0}; vector float vs[4] = {0}; vector float CA[4] = {0}; - __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value - __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value + __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value + __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value for (int l = 0; l < k; l++) { - __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead - __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead - __builtin_mma_xxsetaccz(&acc_0); + __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead + __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead + __builtin_mma_xxsetaccz(& acc_0); if (isAblock_q4) { - packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray); } else { - packNormal<int8_t, vector signed char>((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + packNormal<int8_t, vector signed char>((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false); } - packNormal<uint8_t, vector unsigned char>((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); - for(int x = 0; x < 8; x+=4) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]); + packNormal<uint8_t, vector unsigned char>((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true); + for (int x = 0; x < 8; x += 4) { + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]); } - for (int I = 0; I<RM; I++) { - for (int J = 0; J<RN; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); } } - __builtin_mma_disassemble_acc(vec_C, &acc_0); + __builtin_mma_disassemble_acc(vec_C, & acc_0); if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < RM; i++) { comparray[i] = 0; int ca = 0; @@ -2783,9 +3142,21 @@ class tinyBLAS_BF16_PPC { } } - template<typename TA> + template<int RM, int RN> + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii,jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii,jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii,jj); + } else { + assert(false && "RN/RM values not supported"); + } + } + template <int RM, int RN> - NOINLINE void tinyBLAS_Q0_PPC<TA>::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2797,12 +3168,20 @@ class tinyBLAS_BF16_PPC { for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; - this->kernel<RM, RN>(ii, jj); + kernel<RM, RN>(ii, jj); } } - -template class tinyBLAS_Q0_PPC<block_q4_0>; -template class tinyBLAS_Q0_PPC<block_q8_0>; + const TA * const A; + const block_q8_0 * const B; + float * C; + const int64_t k; + int64_t kc; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; class tinyBLAS_PPC { public: @@ -2815,16 +3194,21 @@ class tinyBLAS_PPC { } void matmul(int64_t m, int64_t n) { + #if defined(_AIX) || defined(__BIG_ENDIAN__) + mnpack(0, m, 0, n); + #else int64_t mc = 256; int64_t nc = 256; int64_t kc = 256; if (m % mc == 0 && n % nc == 0 && k % kc == 0) { matmul_tiled(m, n, mc, nc, kc); } else { mnpack(0, m, 0, n); } + #endif } private: + __attribute__((always_inline)) inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { vec_t vec_C[4]; __builtin_mma_disassemble_acc(vec_C, ACC); @@ -2835,6 +3219,7 @@ class tinyBLAS_PPC { } } + __attribute__((always_inline)) inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { vec_t vec_C[4]; __builtin_mma_disassemble_acc(vec_C, ACC); @@ -3369,7 +3754,7 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 params->ith, params->nth}; tb.matmul(m, n); return true; -#elif defined(__riscv_zvfh) +#elif defined(__riscv_v_intrinsic) #if LMUL == 1 tinyBLAS_RVV<vfloat32m1_t, vfloat32m1_t, float, float, float> tb{ params, k, (const float *)A, lda, @@ -3418,35 +3803,40 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 return tb.matmul(m, n); } #elif defined(__MMA__) - if ((k % 8)) - return false; - if(Btype == GGML_TYPE_BF16) { - tinyBLAS_BF16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k, - (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc, - params->ith, params->nth}; - tb.matmul(m, n); - return true; + if (k % 8) { + return false; } -#elif defined(__riscv_zvfbfwma) - #if LMUL == 1 - tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #elif LMUL == 2 - tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params, - k, (const ggml_bf16_t *)A, lda, - (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #else // LMUL = 4 - tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params, - k, (const ggml_bf16_t *)A, lda, + + if (Btype == GGML_TYPE_BF16) { + tinyBLAS_HP16_PPC<ggml_bf16_t, ggml_bf16_t, float> tb{ k, + (const ggml_bf16_t *)A, lda, (const ggml_bf16_t *)B, ldb, - (float *)C, ldc}; - #endif - return tb.matmul(m, n); + (float *)C, ldc, + params->ith, params->nth }; + + tb.matmul(m, n); + return true; + } +#elif defined(__riscv_zvfbfwma) + if (Btype == GGML_TYPE_BF16) { + #if LMUL == 1 + tinyBLAS_RVV<vfloat32m1_t, vbfloat16mf2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #elif LMUL == 2 + tinyBLAS_RVV<vfloat32m2_t, vbfloat16m1_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #else // LMUL = 4 + tinyBLAS_RVV<vfloat32m4_t, vbfloat16m2_t, ggml_bf16_t, ggml_bf16_t, float> tb{ params, + k, (const ggml_bf16_t *)A, lda, + (const ggml_bf16_t *)B, ldb, + (float *)C, ldc}; + #endif + return tb.matmul(m, n); + } #endif return false; } @@ -3516,6 +3906,21 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64 #endif return tb.matmul(m, n); } +#elif defined(__MMA__) + if (k % 8) { + return false; + } + + if (Btype == GGML_TYPE_F16) { + tinyBLAS_HP16_PPC<ggml_fp16_t, ggml_fp16_t, float> tb{ k, + (const ggml_fp16_t *)A, lda, + (const ggml_fp16_t *)B, ldb, + (float *)C, ldc, + params->ith, params->nth }; + + tb.matmul(m, n); + return true; + } #endif return false; } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3032783971d..74611dce7f1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3,14 +3,14 @@ #include "ggml-cpu.h" #include "ggml-impl.h" #include "binary-ops.h" +#include "simd-gemm.h" #include "ggml.h" #include "unary-ops.h" #include "vec.h" -#include <cfloat> #include <algorithm> +#include <cfloat> #include <cmath> -#include <functional> // ggml_compute_forward_dup @@ -375,7 +375,7 @@ static void ggml_compute_forward_dup_bytes( const size_t rs = ne00 * type_size; if (nb00 == type_size) { - // src0 is contigous on first dimension, copy by rows + // src0 is contiguous on first dimension, copy by rows for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { id += rs * ir0; @@ -664,12 +664,14 @@ void ggml_compute_forward_add( { ggml_compute_forward_add_non_quantized(params, dst); } break; + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1112,6 +1114,7 @@ void ggml_compute_forward_add1( GGML_ABORT("fatal error"); } } break; + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1119,6 +1122,7 @@ void ggml_compute_forward_add1( case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1240,6 +1244,7 @@ void ggml_compute_forward_acc( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1247,6 +1252,7 @@ void ggml_compute_forward_acc( case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -1795,7 +1801,7 @@ void ggml_compute_forward_repeat( { ggml_compute_forward_repeat_f32(params, dst); } break; - // TODO: templateify the implemenation and support for I64 + // TODO: templateify the implementation and support for I64 // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225 //case GGML_TYPE_I64: // { @@ -2097,10 +2103,14 @@ static void ggml_compute_forward_gelu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2114,19 +2124,23 @@ static void ggml_compute_forward_gelu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2136,10 +2150,14 @@ static void ggml_compute_forward_gelu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2153,20 +2171,24 @@ static void ggml_compute_forward_gelu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2213,8 +2235,42 @@ static void ggml_compute_forward_fill_f32(const ggml_compute_params * params, gg } } +static void ggml_compute_forward_fill_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_fp16_t c = GGML_CPU_FP32_TO_FP16(ggml_get_op_params_f32(dst, 0)); + + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne); + GGML_TENSOR_LOCALS(size_t, nb, dst, nb); + + const auto [ir0, ir1] = get_thread_range(params, dst); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne2*ne1); + const int64_t i02 = (ir - i03*ne2*ne1)/ne1; + const int64_t i01 = (ir - i03*ne2*ne1 - i02*ne1); + + ggml_fp16_t * dst_ptr = (ggml_fp16_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1); + + ggml_vec_set_f16(ne0, dst_ptr, c); + } +} + void ggml_compute_forward_fill(const ggml_compute_params * params, ggml_tensor * dst) { - ggml_compute_forward_fill_f32(params, dst); + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fill_f32(params, dst); + } break; + case GGML_TYPE_F16: + { + ggml_compute_forward_fill_f16(params, dst); + } break; + default: + { + GGML_ABORT("unsupported type for ggml_compute_forward_fill: %s", ggml_type_name(src0->type)); + } + } } // ggml_compute_tri @@ -2277,10 +2333,14 @@ static void ggml_compute_forward_gelu_erf_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2294,19 +2354,23 @@ static void ggml_compute_forward_gelu_erf_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2316,10 +2380,14 @@ static void ggml_compute_forward_gelu_erf_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2333,20 +2401,24 @@ static void ggml_compute_forward_gelu_erf_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2380,10 +2452,14 @@ static void ggml_compute_forward_gelu_quick_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2397,19 +2473,23 @@ static void ggml_compute_forward_gelu_quick_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2419,10 +2499,14 @@ static void ggml_compute_forward_gelu_quick_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2436,20 +2520,24 @@ static void ggml_compute_forward_gelu_quick_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2483,10 +2571,14 @@ static void ggml_compute_forward_silu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2500,19 +2592,23 @@ static void ggml_compute_forward_silu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2522,10 +2618,14 @@ static void ggml_compute_forward_silu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2539,20 +2639,24 @@ static void ggml_compute_forward_silu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2702,7 +2806,7 @@ static void ggml_compute_forward_silu_back_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2738,7 +2842,7 @@ static void ggml_compute_forward_silu_back_f16( (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])), (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1]))); - #ifndef NDEBUG +#ifndef NDEBUG for (int k = 0; k < nc; k++) { const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); @@ -2746,7 +2850,7 @@ static void ggml_compute_forward_silu_back_f16( assert(!isnan(v)); assert(!isinf(v)); } - #endif +#endif // NDEBUG } } @@ -2829,7 +2933,7 @@ static void ggml_compute_forward_reglu_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2889,7 +2993,7 @@ static void ggml_compute_forward_reglu_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2972,7 +3076,7 @@ static void ggml_compute_forward_geglu_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3032,7 +3136,7 @@ static void ggml_compute_forward_geglu_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3115,7 +3219,7 @@ static void ggml_compute_forward_swiglu_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3175,7 +3279,7 @@ static void ggml_compute_forward_swiglu_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3266,7 +3370,7 @@ static void ggml_compute_forward_swiglu_oai_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3345,7 +3449,7 @@ static void ggml_compute_forward_geglu_erf_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3405,7 +3509,7 @@ static void ggml_compute_forward_geglu_erf_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3488,7 +3592,7 @@ static void ggml_compute_forward_geglu_quick_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3548,7 +3652,7 @@ static void ggml_compute_forward_geglu_quick_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3643,11 +3747,27 @@ void ggml_compute_forward_norm( // ggml_compute_forward_group_rms_norm +// fusion kinds that can be combined with the rms_norm computation in a single pass. +// extend this enum when adding new fused variants (e.g. FUSE_ADD, FUSE_MUL_ADD, ...). +enum ggml_rms_norm_fuse_op { + GGML_RMS_NORM_FUSE_OP_NONE, + GGML_RMS_NORM_FUSE_OP_MUL, +}; + +template <ggml_rms_norm_fuse_op FUSE_OP> static void ggml_compute_forward_rms_norm_f32( const ggml_compute_params * params, - ggml_tensor * dst) { + ggml_tensor * dst_rms_norm, + ggml_tensor * dst_fused = nullptr) { - const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0 = dst_rms_norm->src[0]; + const ggml_tensor * src1 = nullptr; + ggml_tensor * dst = dst_rms_norm; + + if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) { + src1 = (dst_fused->src[0] == dst_rms_norm) ? dst_fused->src[1] : dst_fused->src[0]; + dst = dst_fused; + } GGML_ASSERT(ggml_are_same_shape(src0, dst)); @@ -3656,11 +3776,10 @@ static void ggml_compute_forward_rms_norm_f32( const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS + GGML_TENSOR_BINARY_OP_LOCALS float eps; - memcpy(&eps, dst->op_params, sizeof(float)); - + memcpy(&eps, dst_rms_norm->op_params, sizeof(float)); GGML_ASSERT(eps >= 0.0f); // TODO: optimize @@ -3670,25 +3789,32 @@ static void ggml_compute_forward_rms_norm_f32( const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); ggml_float sum = 0.0; + // worth switching to explicit SIMD? for (int64_t i00 = 0; i00 < ne00; i00++) { sum += (ggml_float)(x[i00] * x[i00]); } - const float mean = sum/ne00; - - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - // for (int i00 = 0; i00 < ne00; i00++) { - // y[i00] = x[i00]; - // } - + const float mean = sum/ne00; const float scale = 1.0f/sqrtf(mean + eps); // if you hit this, likely you got an inf somewhere earlier assert(scale > 0.0f); - ggml_vec_scale_f32(ne00, y, scale); + float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + + if constexpr (FUSE_OP == GGML_RMS_NORM_FUSE_OP_MUL) { + const int64_t i11 = i01 % ne11; + const int64_t i12 = i02 % ne12; + const int64_t i13 = i03 % ne13; + const float * w = (float *) ((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13); + + for (int64_t i00 = 0; i00 < ne00; i00++) { + y[i00] = x[i00] * scale * w[i00]; + } + } else { + memcpy(y, x, ne00 * sizeof(float)); + ggml_vec_scale_f32(ne00, y, scale); + } } } } @@ -3703,7 +3829,31 @@ void ggml_compute_forward_rms_norm( switch (src0->type) { case GGML_TYPE_F32: { - ggml_compute_forward_rms_norm_f32(params, dst); + ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_NONE>(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// Fused RMS_NORM + MUL: computes dst = rms_norm(src0) * src1 in a single pass. +// This avoids materializing the intermediate rms_norm result in memory. +void ggml_compute_forward_rms_norm_mul_fused( + const ggml_compute_params * params, + ggml_tensor * dst_rms_norm, + ggml_tensor * dst_mul) { + + GGML_ASSERT(dst_mul != nullptr); + GGML_ASSERT(dst_mul->src[0] == dst_rms_norm || dst_mul->src[1] == dst_rms_norm); + + const ggml_tensor * src0 = dst_rms_norm->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_rms_norm_f32<GGML_RMS_NORM_FUSE_OP_MUL>(params, dst_rms_norm, dst_mul); } break; default: { @@ -3858,12 +4008,12 @@ static void ggml_compute_forward_rms_norm_back_f32( // dx := scale(dx, rrms) float * dx = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - // dx[i00] = (x*(-sum_xdz/sum_eps) + dz) / sqrtf(mean_eps) - ggml_vec_cpy_f32 (ne00, dx, x); - // ggml_vec_scale_f32(ne00, dx, -mean_xdz/mean_eps); - ggml_vec_scale_f32(ne00, dx, (float)(-sum_xdz)/sum_eps); - ggml_vec_acc_f32 (ne00, dx, dz); - ggml_vec_scale_f32(ne00, dx, rrms); + // dx[i00] = (dz + x*(-sum_xdz/sum_eps)) * rrms + // note: https://github.com/ggml-org/ggml/issues/1491 + const float scale_x = (float) (-sum_xdz) / sum_eps; + for (int64_t i00 = 0; i00 < ne00; i00++) { + dx[i00] = (dz[i00] + x[i00] * scale_x) * rrms; + } } } } @@ -4264,12 +4414,14 @@ void ggml_compute_forward_out_prod( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4538,6 +4690,7 @@ void ggml_compute_forward_set( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4545,6 +4698,7 @@ void ggml_compute_forward_set( case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4760,6 +4914,7 @@ void ggml_compute_forward_get_rows( const ggml_tensor * src0 = dst->src[0]; switch (src0->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4767,6 +4922,7 @@ void ggml_compute_forward_get_rows( case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -5239,7 +5395,7 @@ static void ggml_compute_forward_soft_max_f32( //printf("p[%d] = %f\n", i, p[i]); assert(!isnan(wp[i])); } -#endif +#endif // NDEBUG float max = -INFINITY; ggml_vec_max_f32(ne00, &max, wp); @@ -5264,7 +5420,7 @@ static void ggml_compute_forward_soft_max_f32( assert(!isnan(dp[i])); assert(!isinf(dp[i])); } -#endif +#endif // NDEBUG } } } @@ -5338,7 +5494,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32( assert(!isnan(dy[i])); assert(!isnan(y[i])); } -#endif +#endif // NDEBUG // Jii = yi - yi*yi // Jij = -yi*yj // J = diag(y)-y.T*y @@ -5371,7 +5527,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32( assert(!isnan(dx[i])); assert(!isinf(dx[i])); } -#endif +#endif // NDEBUG } } @@ -5484,6 +5640,7 @@ void ggml_compute_forward_clamp( ggml_compute_forward_clamp_f16(params, dst); } break; case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5491,6 +5648,7 @@ void ggml_compute_forward_clamp( case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_1: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -5739,28 +5897,33 @@ static void ggml_compute_forward_rope_flt( const int32_t * pos = (const int32_t *) src1->data; + int64_t last_i2 = -1; + for (int64_t i3 = 0; i3 < ne3; i3++) { // batch for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len - - float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!mrope_used) { - const int64_t p = pos[i2]; - ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } - else { - const int64_t p_t = pos[i2]; - const int64_t p_h = pos[i2 + ne2]; - const int64_t p_w = pos[i2 + ne2 * 2]; - const int64_t p_e = pos[i2 + ne2 * 3]; - ggml_mrope_cache_init( - p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, - freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } - for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads - if (ir++ < ir0) continue; + if (ir++ < ir0) continue; // skip rows mapped to other threads if (ir > ir1) break; + float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; + if (last_i2 != i2) { + if (!mrope_used) { + const int64_t p = pos[i2]; + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + else { + const int64_t p_t = pos[i2]; + const int64_t p_h = pos[i2 + ne2]; + const int64_t p_w = pos[i2 + ne2 * 2]; + const int64_t p_e = pos[i2 + ne2 * 3]; + ggml_mrope_cache_init( + p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, + freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + + last_i2 = i2; + } + T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); @@ -6129,7 +6292,7 @@ static void ggml_compute_forward_im2col_f16( const ggml_tensor * src1 = dst->src[1]; GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F16); GGML_TENSOR_BINARY_OP_LOCALS; @@ -6160,7 +6323,7 @@ static void ggml_compute_forward_im2col_f16( int ofs1 = is_2D ? nb12 : nb11; GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); - GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] { @@ -6173,7 +6336,12 @@ static void ggml_compute_forward_im2col_f16( // micro kernel ggml_fp16_t * dst_data = wdata + (in*OH*OW + ioh*OW + iow)*(IC*KH*KW); // [IC, KH, KW] - const float * const src_data = (float *)((char *) src1->data + in*ofs0 + iic*ofs1); // [IH, IW] + const float * const src_data_f32 = src1->type == GGML_TYPE_F32 + ? (const float *)((const char *) src1->data + in*ofs0 + iic*ofs1) + : nullptr; // [IH, IW] + const ggml_fp16_t * const src_data_f16 = src1->type == GGML_TYPE_F16 + ? (const ggml_fp16_t *)((const char *) src1->data + in*ofs0 + iic*ofs1) + : nullptr; // [IH, IW] for (int64_t ikh = 0; ikh < KH; ikh++) { // 1 for (int64_t ikw = 0; ikw < KW; ikw++) { @@ -6183,7 +6351,11 @@ static void ggml_compute_forward_im2col_f16( if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { dst_data[iic*(KH*KW) + ikh*KW + ikw] = 0; } else { - dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iih*IW + iiw]); + if (src_data_f32 != nullptr) { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data_f32[iih*IW + iiw]); + } else { + dst_data[iic*(KH*KW) + ikh*KW + ikw] = src_data_f16[iih*IW + iiw]; + } } } } @@ -6558,6 +6730,78 @@ static inline int64_t ggml_wrap_around(int64_t coord, int64_t size) { return (coord + size) % size; // adding size avoids negative number weirdness } +// ggml_compute_forward_col2im_1d +// +// Scatter-add columns [K*OC, T_in] -> signal [T_out, OC] +// where T_out = (T_in - 1)*s + K - 2*p. Gather approach: each output reads ceil(K/s) inputs. +// Parallelized over the time axis so the split stays balanced whatever OC is. +// Supports F32, F16, BF16 input/output (same type), F32 accumulator. + +template <typename elem_t> +static void ggml_compute_forward_col2im_1d_impl( + const ggml_compute_params * params, + ggml_tensor * dst) { + + const ggml_tensor * src = dst->src[0]; // [K*OC, T_in] + + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t OC = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + + const int64_t K_OC = src->ne[0]; + const int64_t T_in = src->ne[1]; + const int64_t K = K_OC / OC; + const int64_t T_out = dst->ne[0]; + + const elem_t * col_data = (const elem_t *) src->data; + elem_t * dst_data = (elem_t *) dst->data; + + const int ith = params->ith; + const int nth = params->nth; + + // Parallelize over the time axis: the split stays balanced whatever OC is, + // down to OC = 1 for mono audio, and threads read disjoint column bands + const int64_t dr = (T_out + nth - 1) / nth; + const int64_t it0 = dr * ith; + const int64_t it1 = it0 + dr < T_out ? it0 + dr : T_out; + + for (int64_t oc = 0; oc < OC; oc++) { + for (int64_t t_out = it0; t_out < it1; t_out++) { + const int64_t t_abs = t_out + p0; // absolute position in uncropped signal + // Gather: find all (t_in, k) where t_in * s + k == t_abs, 0 <= k < K + int64_t t_in_min = (t_abs - K + 1 + s0 - 1) / s0; // ceil((t_abs-K+1)/s) + if (t_in_min < 0) t_in_min = 0; + int64_t t_in_max = t_abs / s0; + if (t_in_max >= T_in) t_in_max = T_in - 1; + + float sum = 0.0f; + for (int64_t t_in = t_in_min; t_in <= t_in_max; t_in++) { + int64_t k = t_abs - t_in * s0; + if (k >= 0 && k < K) { + // col layout: [K*OC, T_in], element (oc*K+k, t_in) + sum += type_conversion_table<elem_t>::to_f32(col_data[(oc * K + k) + t_in * K_OC]); + } + } + // dst layout: [T_out, OC], element (t_out, oc) + dst_data[t_out + oc * T_out] = type_conversion_table<elem_t>::from_f32(sum); + } + } +} + +void ggml_compute_forward_col2im_1d( + const ggml_compute_params * params, + ggml_tensor * dst) { + switch (dst->src[0]->type) { + case GGML_TYPE_F32: ggml_compute_forward_col2im_1d_impl<float> (params, dst); break; + case GGML_TYPE_F16: ggml_compute_forward_col2im_1d_impl<ggml_fp16_t>(params, dst); break; + case GGML_TYPE_BF16: ggml_compute_forward_col2im_1d_impl<ggml_bf16_t>(params, dst); break; + default: GGML_ABORT("col2im_1d: unsupported type %d", dst->src[0]->type); + } +} + // ggml_compute_forward_conv_2d @@ -6838,16 +7082,15 @@ void ggml_compute_forward_conv_3d( ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type); } -// ggml_compute_forward_conv_transpose_2d - -void ggml_compute_forward_conv_transpose_2d( - const ggml_compute_params * params, - ggml_tensor * dst) { +template <typename kernel_t> +static void ggml_compute_forward_conv_transpose_2d_impl( + const ggml_compute_params * params, + ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; - GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -6858,7 +7101,7 @@ void ggml_compute_forward_conv_transpose_2d( const int nk = ne00*ne01*ne02*ne03; - GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); GGML_ASSERT(nb10 == sizeof(float)); if (ith == 0) { @@ -6866,12 +7109,12 @@ void ggml_compute_forward_conv_transpose_2d( // permute kernel data (src0) from (Kw x Kh x Cout x Cin) to (Cin x Kw x Kh x Cout) { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; + kernel_t * const wdata = (kernel_t *) params->wdata + 0; for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i03*nb03 + i02*nb02); - ggml_fp16_t * dst_data = wdata + i02*ne01*ne00*ne03; + const kernel_t * const src = (kernel_t *)((char *) src0->data + i03*nb03 + i02*nb02); + kernel_t * dst_data = wdata + i02*ne01*ne00*ne03; for (int64_t i01 = 0; i01 < ne01; i01++) { for (int64_t i00 = 0; i00 < ne00; i00++) { dst_data[i01*ne00*ne03 + i00*ne03 + i03] = src[i01 * ne00 + i00]; @@ -6883,13 +7126,17 @@ void ggml_compute_forward_conv_transpose_2d( // permute source data (src1) from (Sw x Sh x Cin) to (Cin x Sw x Sh) { - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + nk; + kernel_t * const wdata = (kernel_t *) params->wdata + nk; for (int i12 = 0; i12 < ne12; i12++) { for (int i11 = 0; i11 < ne11; i11++) { const float * const src = (float *)((char *) src1->data + i12*nb12 + i11*nb11); - ggml_fp16_t * dst_data = wdata + i11*ne10*ne12; + kernel_t * dst_data = wdata + i11*ne10*ne12; for (int i10 = 0; i10 < ne10; i10++) { - dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]); + if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) { + dst_data[i10*ne12 + i12] = GGML_CPU_FP32_TO_FP16(src[i10]); + } else { + dst_data[i10*ne12 + i12] = src[i10]; + } } } } @@ -6911,21 +7158,27 @@ void ggml_compute_forward_conv_transpose_2d( const int ip0 = dp*ith; const int ip1 = MIN(ip0 + dp, np); - ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; - ggml_fp16_t * const wdata_src = wdata + nk; + kernel_t * const wdata = (kernel_t *) params->wdata + 0; + kernel_t * const wdata_src = wdata + nk; for (int i2 = ip0; i2 < ip1; i2++) { // Cout float * dst_data = (float *)((char *) dst->data + i2*nb2); - ggml_fp16_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; + kernel_t * wdata_kernel = wdata + i2*ne01*ne00*ne03; for (int i11 = 0; i11 < ne11; i11++) { for (int i10 = 0; i10 < ne10; i10++) { const int i1n = i11*ne10*ne12 + i10*ne12; for (int i01 = 0; i01 < ne01; i01++) { for (int i00 = 0; i00 < ne00; i00++) { float v = 0; - ggml_vec_dot_f16(ne03, &v, 0, - wdata_src + i1n, 0, - wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + if constexpr (std::is_same_v<kernel_t, ggml_fp16_t>) { + ggml_vec_dot_f16(ne03, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + } else { + ggml_vec_dot_f32(ne03, &v, 0, + wdata_src + i1n, 0, + wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1); + } dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v; } } @@ -6934,19 +7187,41 @@ void ggml_compute_forward_conv_transpose_2d( } } -// ggml_compute_forward_conv_2d_dw +void ggml_compute_forward_conv_transpose_2d( + const ggml_compute_params * params, + ggml_tensor * dst) { -struct ggml_conv_2d_dw_params { - int64_t channels; - int64_t batch; - int64_t src_w; - int64_t src_h; - int64_t dst_w; - int64_t dst_h; - int64_t knl_w; - int64_t knl_h; - int stride_x; - int stride_y; + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F16: + { + ggml_compute_forward_conv_transpose_2d_impl<ggml_fp16_t>(params, dst); + } break; + case GGML_TYPE_F32: + { + ggml_compute_forward_conv_transpose_2d_impl<float>(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + +// ggml_compute_forward_conv_2d_dw + +struct ggml_conv_2d_dw_params { + int64_t channels; + int64_t batch; + int64_t src_w; + int64_t src_h; + int64_t dst_w; + int64_t dst_h; + int64_t knl_w; + int64_t knl_h; + int stride_x; + int stride_y; int pad_x; int pad_y; int dilation_x; @@ -7110,12 +7385,13 @@ void ggml_compute_forward_conv_2d_dw( } } -// ggml_compute_forward_pool_1d_sk_p0 - -static void ggml_compute_forward_pool_1d_sk_p0( +// ggml_compute_forward_pool_1d_ksp +static void ggml_compute_forward_pool_1d_ksp( const ggml_compute_params * params, const ggml_op_pool op, const int k, + const int s, + const int p, ggml_tensor * dst) { const ggml_tensor * src = dst->src[0]; @@ -7126,39 +7402,56 @@ static void ggml_compute_forward_pool_1d_sk_p0( return; } - const char * cdata = (const char *)src->data; - const char * const data_end = cdata + ggml_nbytes(src); - float * drow = (float *)dst->data; + const int64_t IW = src->ne[0]; + const int64_t OW = dst->ne[0]; - const int64_t rs = dst->ne[0]; + const int64_t nr = ggml_nrows(src); - while (cdata < data_end) { - const void * srow = (const void *)cdata; - int j = 0; - for (int64_t i = 0; i < rs; ++i) { + for (int64_t ir = 0; ir < nr; ++ir) { + const char * srow_bytes = (const char *) src->data + ir * src->nb[1]; + float * drow = (float *) (( char *) dst->data + ir * dst->nb[1]); + + for (int64_t ow = 0; ow < OW; ++ow) { + float res = 0; switch (op) { - case GGML_OP_POOL_AVG: drow[i] = 0; break; - case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break; + case GGML_OP_POOL_AVG: res = 0.0f; break; + case GGML_OP_POOL_MAX: res = -FLT_MAX; break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } + + int count = 0; + const int base = (int) ow * s - p; + for (int ki = 0; ki < k; ++ki) { - const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); + const int j = base + ki; + if (j < 0 || j >= (int) IW) { + continue; + } + + float v; + if (src->type == GGML_TYPE_F32) { + v = ((const float *) srow_bytes)[j]; + } else { + v = GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t *) srow_bytes)[j]); + } + switch (op) { - case GGML_OP_POOL_AVG: drow[i] += srow_j; break; - case GGML_OP_POOL_MAX: if (srow_j > drow[i]) drow[i] = srow_j; break; - case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); + case GGML_OP_POOL_AVG: res += v; break; + case GGML_OP_POOL_MAX: res = std::max(v, res); break; + case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } - ++j; + + ++count; } + switch (op) { - case GGML_OP_POOL_AVG: drow[i] /= k; break; - case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_AVG: res = (count > 0) ? (res / count) : 0.0f; break; + case GGML_OP_POOL_MAX: break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } - } - cdata += src->nb[1]; - drow += rs; + drow[ow] = res; + } } } @@ -7173,10 +7466,8 @@ void ggml_compute_forward_pool_1d( const int k0 = opts[1]; const int s0 = opts[2]; const int p0 = opts[3]; - GGML_ASSERT(p0 == 0); // padding not supported - GGML_ASSERT(k0 == s0); // only s = k supported - ggml_compute_forward_pool_1d_sk_p0(params, op, k0, dst); + ggml_compute_forward_pool_1d_ksp(params, op, k0, s0, p0, dst); } // ggml_compute_forward_pool_2d @@ -7194,6 +7485,7 @@ void ggml_compute_forward_pool_2d( } const int32_t * opts = (const int32_t *)dst->op_params; + ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]); const int k0 = opts[1]; const int k1 = opts[2]; @@ -7217,11 +7509,13 @@ void ggml_compute_forward_pool_2d( while (cdata < data_end) { for (int oy = 0; oy < py; ++oy) { float * const drow = dplane + oy * px; + float * const out = drow; + for (int ox = 0; ox < px; ++ox) { - float * const out = drow + ox; + float res = 0; switch (op) { - case GGML_OP_POOL_AVG: *out = 0; break; - case GGML_OP_POOL_MAX: *out = -FLT_MAX; break; + case GGML_OP_POOL_AVG: res = 0; break; + case GGML_OP_POOL_MAX: res = -FLT_MAX; break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } @@ -7229,24 +7523,32 @@ void ggml_compute_forward_pool_2d( const int iy = offset1 + oy * s1; for (int ky = 0; ky < k1; ++ky) { - if (iy + ky < 0 || iy + ky >= src->ne[1]) continue; + if (iy + ky < 0 || iy + ky >= src->ne[1]) { + continue; + } + const void * srow = (const void *)(cdata + src->nb[1] * (iy + ky)); for (int kx = 0; kx < k0; ++kx) { int j = ix + kx; - if (j < 0 || j >= src->ne[0]) continue; + if (j < 0 || j >= src->ne[0]) { + continue; + } + const float srow_j = (src->type == GGML_TYPE_F32) ? ((const float*)srow)[j] : GGML_CPU_FP16_TO_FP32(((const ggml_fp16_t*)srow)[j]); switch (op) { - case GGML_OP_POOL_AVG: *out += srow_j; break; - case GGML_OP_POOL_MAX: if (srow_j > *out) *out = srow_j; break; + case GGML_OP_POOL_AVG: res += srow_j; break; + case GGML_OP_POOL_MAX: res = std::max(srow_j, res); break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } } } switch (op) { - case GGML_OP_POOL_AVG: *out /= ka; break; - case GGML_OP_POOL_MAX: break; + case GGML_OP_POOL_AVG: res /= ka; break; + case GGML_OP_POOL_MAX: break; case GGML_OP_POOL_COUNT: GGML_ABORT("fatal error"); } + + out[ox] = res; } } @@ -7603,8 +7905,7 @@ static void ggml_compute_forward_pad_f32( const ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT( dst->nb[0] == sizeof(float)); + assert(dst->nb[0] == sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -8016,12 +8317,14 @@ void ggml_compute_forward_top_k( } } -// ggml_compute_forward_flash_attn_ext - static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( const ggml_compute_params * params, ggml_tensor * dst, - int ir0, int ir1) { + int ir0, int ir1, + int64_t ic_start, int64_t ic_end, + float * partials, int64_t partial_stride) { + + const bool write_partials = (partials != nullptr); const ggml_tensor * q = dst->src[0]; const ggml_tensor * k = dst->src[1]; const ggml_tensor * v = dst->src[2]; @@ -8098,7 +8401,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( int ith = params->ith; - // loop over n_batch and n_head for (int ir = ir0; ir < ir1; ++ir) { // q indices const int iq3 = ir/(neq2*neq1); @@ -8138,7 +8440,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( // online softmax / attention // loop over n_kv and n_head_kv // ref: https://arxiv.org/pdf/2112.05682.pdf - for (int64_t ic = 0; ic < nek1; ++ic) { + + for (int64_t ic = ic_start; ic < ic_end; ++ic) { const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f; if (mv == -INFINITY) { continue; @@ -8211,8 +8514,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( } } - // sinks - if (sinks) { + // sinks - apply only on the first kv-chunk + if (sinks && ic_start == 0) { const float s = ((float *)((char *) sinks->data))[h]; float ms = 1.0f; @@ -8220,6 +8523,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( if (s > M) { ms = expf(M - s); + M = s; ggml_vec_scale_f32(DV, VKQ32, ms); } else { vs = expf(s - M); @@ -8228,20 +8532,386 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( S = S*ms + vs; } - // V /= S - const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; - ggml_vec_scale_f32(DV, VKQ32, S_inv); + if (write_partials) { + // Write M, S, VKQ to partials for later reduction + // partials layout: [M, S, VKQ[DV]] per query head + float * partial = partials + ir * partial_stride; + partial[0] = M; + partial[1] = S; + memcpy(partial + 2, VKQ32, DV * sizeof(float)); + } else { + // V /= S + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; + ggml_vec_scale_f32(DV, VKQ32, S_inv); - // dst indices - const int i1 = iq1; - const int i2 = iq2; - const int i3 = iq3; + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); + } + } +} + +static void ggml_compute_forward_flash_attn_ext_tiled( + const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, int ir1) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(ne0 == DV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); + + GGML_ASSERT(neq1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(k->type == v->type); + const ggml_type kv_type = k->type; + + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + int ith = params->ith; + + static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; + static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + + int ir = ir0; + while (ir < ir1) { + // q indices for the start of this tile + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + // Number of valid rows in this tile: + // - limited by tile size (Q_TILE_SZ) + // - limited by chunk boundary (ir1 - ir) + // - limited by head boundary (neq1 - iq1) to avoid crossing into next head + const int tile_rows = MIN(Q_TILE_SZ, MIN((int)(ir1 - ir), (int)(neq1 - iq1))); + GGML_ASSERT(tile_rows > 0); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f; + + float S[Q_TILE_SZ]; + float M[Q_TILE_SZ]; + + for (int i = 0 ; i < Q_TILE_SZ; ++i) { + S[i] = 0.; + M[i] = -INFINITY; + } + + // Per-thread scratch layout: + // Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar) + // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float) + // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float) + // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator) + // V32: KV_TILE_SZ * DV (F32 buffer for V tile) + // K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path) + float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32); + + void * Q_q = base; + float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float)); + float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; + float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; + float * V32 = VKQ32 + Q_TILE_SZ * DV; + float * K_f32 = V32 + KV_TILE_SZ * DV; + + memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float)); + memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + { + float * Q_f32 = (float *)Q_q; + for (int tq = 0; tq < tile_rows; tq++) { + const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); + memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float)); + } + for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { + memset(Q_f32 + tq * DK, 0, DK * sizeof(float)); + } + } + + memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float)); + memset(V32, 0, KV_TILE_SZ * DV * sizeof(float)); + + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic); + + // skip the tile entirely if all the masks are -inf + if (mask) { + bool can_skip = true; + for (int tq = 0; tq < tile_rows; tq++) { + const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]); + for (int tk = 0; tk < kv_tile; tk++) { + mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]); + if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { + can_skip = false; + } + } + // Pad remaining mask entries with -inf + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + + if (can_skip) { + continue; + } + } + + // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim) + // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns + for (int tk = 0; tk < kv_tile; tk++) { + const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3; + if (kv_type == GGML_TYPE_F16) { + const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data; + for (int64_t dk = 0; dk < DK; dk++) { + K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]); + } + } else { + const float * k_f32_src = (const float *)k_data; + for (int64_t dk = 0; dk < DK; dk++) { + K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; + } + } + } + memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); + simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ); + ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale); + + // Set padded KQ entries to -inf so softmax gives them zero weight + if (kv_tile < KV_TILE_SZ) { + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + KQ[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + } + + if (logit_softcap != 0.0f) { + ggml_vec_tanh_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, KQ); + ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, logit_softcap); + } + + if (mask) { + ggml_vec_add_f32(tile_rows * KV_TILE_SZ, KQ, KQ, mask32); + } + + bool skip[Q_TILE_SZ] = {}; + + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + float * kq_row = KQ + tq * KV_TILE_SZ; + + float tile_max; + ggml_vec_max_f32(KV_TILE_SZ, &tile_max, kq_row); + + if (tile_max == -INFINITY) { + skip[tq] = true; + continue; + } + + const float Mold = M[tq]; + const float Mnew = fmaxf(Mold, tile_max); + + if (Mnew > Mold) { + const float ms = expf(Mold - Mnew); + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms); + S[tq] *= ms; + } + M[tq] = Mnew; + + + S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew); + } + + // V accumulation: VKQ32 += softmax(KQ) * V + // Pack V tile to contiguous F32, zero-padded + for (int tk = 0; tk < kv_tile; tk++) { + const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3; + if (kv_type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV); + } else { + memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); + } + } + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + if (skip[tq]) { + memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float)); + } + } + simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV); + } + + // sinks (apply only to valid rows in the tile) + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + for (int tq = 0; tq < tile_rows; tq++) { + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[tq]) { + ms = expf(M[tq] - s); + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, ms); + } else { + vs = expf(s - M[tq]); + } + + S[tq] = S[tq] * ms + vs; + } + } + + for (int tq = 0; tq < tile_rows; tq++) { + // V /= S + const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq]; + ggml_vec_scale_f32(DV, VKQ32 + tq * DV, S_inv); + + // dst indices + const int i1 = iq1 + tq; + const int i2 = iq2; + const int i3 = iq3; + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32 + tq * DV, nb1); + } + + ir += tile_rows; + } +} + +// Reduction function: combines partial results across KV chunks +// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV] +static void ggml_flash_attn_ext_reduce_partials( + const ggml_compute_params * params, + ggml_tensor * dst, + const int64_t n_chunks, + const int64_t chunk_size) { + + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + + const int64_t DK = k->ne[0]; + const int64_t DV = v->ne[0]; + const int64_t nek1 = k->ne[1]; + const int64_t n_q_heads = q->ne[2]; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32; + float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread; + + const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32); + const int64_t partial_size = 2 + DV; + const float * partials_base = (const float *) params->wdata + partials_offset; + + // Output layout + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const size_t nb1 = dst->nb[1]; - // original - //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + // Each thread reduces a subset of query heads + for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) { + float M_final = -INFINITY; + float S_final = 0.0f; + float * VKQ_final = thread_wdata; + memset(VKQ_final, 0, DV * sizeof(float)); - // permute(0, 2, 1, 3) - memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1); + // Combine partials from all chunks + for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) { + const int64_t ic_start = chunk_idx * chunk_size; + if (ic_start >= nek1) continue; + + const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size; + const float M_chunk = partial[0]; + const float S_chunk = partial[1]; + const float * VKQ_chunk = partial + 2; + + if (S_chunk == 0.0f) continue; + + const float M_new = fmaxf(M_final, M_chunk); + const float scale_old = expf(M_final - M_new); + const float scale_new = expf(M_chunk - M_new); + + for (int64_t d = 0; d < DV; ++d) { + VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new; + } + S_final = S_final * scale_old + S_chunk * scale_new; + M_final = M_new; + } + + // Normalize and write to output + if (S_final != 0.0f) { + const float S_inv = 1.0f / S_final; + ggml_vec_scale_f32(DV, VKQ_final, S_inv); + } + // iq1=0, iq3=0 for decode + memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1); } } @@ -8266,6 +8936,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t DV = nev0; const int64_t N = neq1; + GGML_ASSERT(ne0 == DV); GGML_ASSERT(ne2 == N); @@ -8286,47 +8957,97 @@ static void ggml_compute_forward_flash_attn_ext_f16( GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb2 <= nb3); - // parallelize by q rows using ggml_vec_dot_f32 - - // total rows in q - const int64_t nr = neq1*neq2*neq3; - - // rows per thread const int ith = params->ith; const int nth = params->nth; - // disable for NUMA - const bool disable_chunking = ggml_is_numa(); + // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking) + const bool use_ref = params->use_ref; - // 4x chunks per thread - int nth_scaled = nth * 4; - int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; - int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16); + const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512; - if (nth == 1 || nchunk < nth || disable_chunking) { - nchunk = nth; - } + if (use_split_kv_path) { + const int64_t chunk_size = (nek1 + nth - 1) / nth; - if (ith == 0) { - // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. - ggml_threadpool_chunk_set(params->threadpool, nth); - } + // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ] + const int64_t partial_size = 2 + DV; + float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32); - ggml_barrier(params->threadpool); + const int64_t ic_start = ith * chunk_size; + const int64_t ic_end = std::min(ic_start + chunk_size, nek1); - // The number of elements in each chunk - const int64_t dr = (nr + nchunk - 1) / nchunk; + const int64_t partial_stride = nth * partial_size; + float * chunk_partials = partials_base + ith * partial_size; - // The first chunk comes from our thread_id, the rest will get auto-assigned. - int current_chunk = ith; + if (ic_start < nek1) { + for (int64_t q_head = 0; q_head < neq2; q_head++) { + ggml_compute_forward_flash_attn_ext_f16_one_chunk( + params, dst, q_head, q_head + 1, ic_start, ic_end, + chunk_partials, partial_stride); + } + } else { + for (int64_t q_head = 0; q_head < neq2; q_head++) { + float * q_partials = chunk_partials + q_head * partial_stride; + q_partials[0] = -INFINITY; // M + q_partials[1] = 0.0f; // S + } + } - while (current_chunk < nchunk) { - const int64_t ir0 = dr * current_chunk; - const int64_t ir1 = MIN(ir0 + dr, nr); + ggml_barrier(params->threadpool); + ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size); + } else { - ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1); + // total rows in q + const int64_t nr = neq1*neq2*neq3; - current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); + + // 4x chunks per thread + int nth_scaled = nth * 4; + int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + + if (nth == 1 || nchunk < nth || disable_chunking) { + nchunk = nth; + } + + if (ith == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth); + } + + ggml_barrier(params->threadpool); + + const int64_t dr = (nr + nchunk - 1) / nchunk; + + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + bool use_tiled = !use_ref && + (q->type == GGML_TYPE_F32 && + kv_is_f32_or_f16 && + k->type == v->type && + neq1 >= Q_TILE_SZ); +#ifdef GGML_SIMD +#if defined(__ARM_FEATURE_SVE) + const int64_t f32_epr = svcntw(); +#else + const int64_t f32_epr = GGML_F32_EPR; +#endif + use_tiled &= (DV % f32_epr == 0); +#endif + int current_chunk = ith; + + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); + + if (use_tiled) { + ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1); + } else { + ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0); + } + + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + } } } @@ -9107,7 +9828,7 @@ void ggml_compute_forward_win_unpart( } } -//gmml_compute_forward_unary +//ggml_compute_forward_unary void ggml_compute_forward_unary( const ggml_compute_params * params, @@ -9396,13 +10117,9 @@ static void ggml_compute_forward_rwkv_wkv6_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -9613,13 +10330,9 @@ static void ggml_compute_forward_gla_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * k = (float *) dst->src[0]->data; float * v = (float *) dst->src[1]->data; @@ -9870,6 +10583,219 @@ void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, s } } +// ggml_compute_forward_gated_delta_net +static void ggml_compute_forward_gated_delta_net_one_chunk( + const ggml_compute_params * params, + ggml_tensor * dst, + int64_t ir0, + int64_t ir1) { + + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + const int64_t S_v = src_v->ne[0]; + const int64_t H = src_v->ne[1]; + const int64_t n_tokens = src_v->ne[2]; + const int64_t n_seqs = src_v->ne[3]; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + GGML_ASSERT(src_g->ne[0] == 1 || src_g->ne[0] == S_v); + GGML_ASSERT(src_beta->ne[0] == 1); + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t, nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t, nbk, src_k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(int64_t, neg, src_g, ne); + GGML_TENSOR_LOCALS(size_t, nbg, src_g, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const bool kda = (neg0 == S_v); + + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int64_t K = ggml_get_op_params_i32(dst, 0); + GGML_ASSERT(K >= 1); + // per-seq stride in floats (seq s starts at state + s * seq_stride) + const int64_t state_seq_stride = src_state->nb[3] / sizeof(float); + + const int64_t per_thread = S_v + (K > 1 ? S_v * S_v : 0); + const int ith = params->ith; + + float * delta = (float *)params->wdata + ith * per_thread + CACHE_LINE_SIZE_F32; + float * state_work = K > 1 ? (delta + S_v) : nullptr; + + // output layout: [attn_scores | new_states] + // attn_scores: S_v * H * n_tokens * n_seqs floats + // new_states: S_v * S_v * H * n_seqs * K floats (K snapshot slots; last min(n_tokens, K)) + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + const int64_t state_size_per_snap = S_v * S_v * H * n_seqs; + float * attn_out_base = (float *)dst->data; + float * state_out_base = (float *)dst->data + attn_score_elems; + + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. + + const float * state_in_base = (const float *)src_state->data; + + //const int64_t rq1 = nev1 / neq1; + //const int64_t rk1 = nev1 / nek1; + const int64_t rq3 = nev3 / neq3; + const int64_t rk3 = nev3 / nek3; + + const float scale = 1.0f / sqrtf((float) S_v); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t iv1 = ir % H; // head_index + const int64_t iv3 = ir / H; // sequence + + const int64_t iq1 = iv1 % neq1; + const int64_t ik1 = iv1 % nek1; + + const int64_t iq3 = iv3 / rq3; + const int64_t ik3 = iv3 / rk3; + + // For K=1, write directly to the single output slot to avoid an extra memcpy at the end. + // For K>1, work in scratch and copy out per-token when the slot is in range. + float * s_out = (K > 1) + ? state_work + : state_out_base + (iv3 * H + iv1) * S_v * S_v; + + // copy input state into the working buffer and operate in-place + // state layout [S_v, S_v, H, n_seqs]: seq iv3 starts at iv3 * state_seq_stride. + const float * s_in = state_in_base + iv3 * state_seq_stride + iv1 * S_v * S_v; + memcpy(s_out, s_in, S_v * S_v * sizeof(float)); + + // attn output pointer for first token of this (head, seq) + float * attn_data = attn_out_base + (iv3 * n_tokens * H + iv1) * S_v; + + for (int64_t t = 0; t < n_tokens; t++) { + const float * q_d = (const float *)((const char *)src_q->data + iq3 * nbq3 + t * nbq2 + iq1 * nbq1); + const float * k_d = (const float *)((const char *)src_k->data + ik3 * nbk3 + t * nbk2 + ik1 * nbk1); + const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); + + const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); + const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + + // state is stored transposed: s_out[j*S_v + i] = S[i][j] + // so row j of s_out = column j of S (contiguous access) + + if (kda) { + // precompute exp(g) into delta scratch (reused below) + for (int64_t i = 0; i < S_v; ++i) { + delta[i] = expf(g_d[i]); + } + // S[i][:] *= exp(g[i]) => for each row j of M: M[j][i] *= exp(g[i]) + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_mul_f32(S_v, &s_out[j * S_v], &s_out[j * S_v], delta); + } + } else { + ggml_vec_scale_f32(S_v * S_v, s_out, expf(g_d[0])); + } + + // delta[j] = sum_i S[i][j] * k[i] = dot(row j of M, k) + for (int64_t j = 0; j < S_v; ++j) { + float sum = 0.0f; + ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, k_d, 0, 1); + delta[j] = (v_d[j] - sum) * beta_val; + } + + // outer product: S[i][j] += k[i] * delta[j] => M[j][i] += delta[j] * k[i] + for (int64_t j = 0; j < S_v; ++j) { + ggml_vec_mad_f32(S_v, &s_out[j * S_v], k_d, delta[j]); + } + + // attn_out[j] = sum_i S[i][j] * q[i] = dot(row j of M, q) + for (int64_t j = 0; j < S_v; ++j) { + float sum = 0.0f; + ggml_vec_dot_f32(S_v, &sum, 0, &s_out[j * S_v], 0, q_d, 0, 1); + attn_data[j] = sum * scale; + } + + attn_data += S_v * H; // advance to next token + + if (K > 1) { + const int64_t target_slot = n_tokens - 1 - t; + if (target_slot >= 0 && target_slot < K) { + float * curr_state_o = state_out_base + target_slot * state_size_per_snap + + (iv3 * H + iv1) * S_v * S_v; + memcpy(curr_state_o, s_out, S_v * S_v * sizeof(float)); + } + } + } + } +} + + +static void ggml_compute_forward_gated_delta_net_f32( + const ggml_compute_params * params, + ggml_tensor * dst) { + + ggml_tensor * V = dst->src[2]; + int64_t nr = V->ne[1] * V->ne[3]; + + // disable for NUMA + const bool disable_chunking = ggml_is_numa(); + + int nth = params->nth; + int ith = params->ith; + + // 4x chunks per thread + int nth_scaled = nth * 4; + int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + int64_t nchunk = (nr + chunk_size - 1) / chunk_size; + + if (nth == 1 || nchunk < nth || disable_chunking) { + nchunk = nth; + } + + if (ith == 0) { + ggml_threadpool_chunk_set(params->threadpool, nth); + } + + ggml_barrier(params->threadpool); + + const int64_t dr = (nr + nchunk - 1) / nchunk; + + int current_chunk = ith; + + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); + + ggml_compute_forward_gated_delta_net_one_chunk(params, dst, ir0, ir1); + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); + } +} + +void ggml_compute_forward_gated_delta_net( + const ggml_compute_params * params, + ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_gated_delta_net_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} + // ggml_compute_forward_rwkv_wkv7 static void ggml_compute_forward_rwkv_wkv7_f32( @@ -9887,13 +10813,9 @@ static void ggml_compute_forward_rwkv_wkv7_f32( const int ith = params->ith; const int nth = params->nth; - if (ith >= HEADS) { - return; - } - - const int h_start = (HEADS * ith) / nth; - const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? - (HEADS * (ith + 1)) / nth : HEADS; + const int h_start = (HEADS * (ith )) / nth; + const int h_end = ((HEADS * (ith + 1)) / nth < HEADS) ? + (HEADS * (ith + 1)) / nth : HEADS; float * r = (float *) dst->src[0]->data; float * w = (float *) dst->src[1]->data; @@ -10195,7 +11117,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( assert(!isnan(s0[i])); assert(!isnan(s1[i])); } -#endif +#endif // NDEBUG float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); @@ -10214,7 +11136,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( assert(!isnan(st[i])); assert(!isinf(st[i])); } -#endif +#endif // NDEBUG } sums[ith] = sum_thread; ggml_barrier(params->threadpool); @@ -10287,7 +11209,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( assert(!isnan(s0[i])); assert(!isnan(s1[i])); } -#endif +#endif // NDEBUG // soft_max float max = -INFINITY; @@ -10305,7 +11227,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( assert(!isnan(ds0[i])); assert(!isinf(ds0[i])); } -#endif +#endif // NDEBUG } } @@ -10471,3 +11393,95 @@ void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_ } } } + +static void ggml_compute_forward_fwht_f32(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t n = ne10; + GGML_ASSERT((n & (n - 1)) == 0); // must be power of 2 + + const int64_t nr = ne11 * ne12 * ne13; + const int64_t rows_per_thread = (nr + nth - 1) / nth; + const int64_t start_row = ith * rows_per_thread; + const int64_t end_row = MIN(start_row + rows_per_thread, nr); + + const float scale = 1.0f / sqrtf((float)n); + +#if defined(GGML_SIMD) + const GGML_F32_VEC v_minus_one = GGML_F32_VEC_SET1(-1.0f); +#endif + + for (int64_t r = start_row; r < end_row; r++) { + const int64_t i13 = r / (ne11 * ne12); + const int64_t i12 = (r - i13 * ne11 * ne12) / ne11; + const int64_t i11 = r - i13 * ne11 * ne12 - i12 * ne11; + + const float * src_row = (const float *) ((const char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13); + float * dst_row = (float *) ((char *) dst->data + i11 * nb1 + i12 * nb2 + i13 * nb3); + + for (int64_t j = 0; j < n; j++) { + dst_row[j] = src_row[j] * scale; + } + + // Scalar passes +#if defined(GGML_SIMD) +#if defined(__ARM_FEATURE_SVE) + const int step = svcntw(); +#else + const int step = GGML_F32_EPR; +#endif +#else + const int step = n; +#endif + for (int64_t len = 1; len < step && len < n; len <<= 1) { + for (int64_t i = 0; i < n; i += 2 * len) { + for (int64_t j = 0; j < len; j++) { + float u = dst_row[i + j]; + float v = dst_row[i + len + j]; + dst_row[i + j] = u + v; + dst_row[i + len + j] = u - v; + } + } + } + + // SIMD passes using GGML_F32_VEC_* macros for multi-architecture support +#if defined(GGML_SIMD) + for (int64_t len = step; len < n; len <<= 1) { + for (int64_t i = 0; i < n; i += 2 * len) { + for (int64_t j = 0; j < len; j += step) { + GGML_F32_VEC u = GGML_F32_VEC_LOAD(dst_row + i + j); + GGML_F32_VEC v = GGML_F32_VEC_LOAD(dst_row + i + len + j); + + GGML_F32_VEC_STORE(dst_row + i + j, GGML_F32_VEC_ADD(u, v)); + GGML_F32_VEC_STORE(dst_row + i + len + j, GGML_F32_VEC_FMA(u, v, v_minus_one)); + } + } + } +#endif + } +} + +void ggml_compute_forward_fwht(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src1 = dst->src[1]; + + switch (src1->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_fwht_f32(params, dst); + } + break; + default: + { + GGML_ABORT("fatal error - fwht is F32 only"); + } + } +} diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h index 0fdfee79766..a8e18c716db 100644 --- a/ggml/src/ggml-cpu/ops.h +++ b/ggml/src/ggml-cpu/ops.h @@ -44,6 +44,7 @@ void ggml_compute_forward_concat(const struct ggml_compute_params * params, stru void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_rms_norm_mul_fused(const struct ggml_compute_params * params, struct ggml_tensor * dst_rms_norm, struct ggml_tensor * dst_mul); void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_l2_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -67,6 +68,7 @@ void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * p void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_col2im_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -102,6 +104,7 @@ void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, s void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_solve_tri(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_gated_delta_net(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); @@ -110,6 +113,7 @@ void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_mul_mat(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_fwht(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ggml/src/ggml-cpu/quants.c b/ggml/src/ggml-cpu/quants.c index 365cb36d2d7..e5f9a4083f9 100644 --- a/ggml/src/ggml-cpu/quants.c +++ b/ggml/src/ggml-cpu/quants.c @@ -22,6 +22,10 @@ #define UNUSED GGML_UNUSED +void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_q1_0_ref(x, y, k); +} + void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { quantize_row_q4_0_ref(x, y, k); } @@ -50,6 +54,10 @@ void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i quantize_row_mxfp4_ref(x, y, k); } +void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) { + quantize_row_nvfp4_ref(x, y, k); +} + // // 2-6 bit quantization in super-blocks // @@ -112,6 +120,57 @@ void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI //===================================== Dot products ================================= +void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + const int qk = QK1_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q1_0 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + const float d0 = GGML_CPU_FP16_TO_FP32(x[i].d); + + float sumi = 0.0f; + + for (int k = 0; k < 4; k++) { + const block_q8_0 * GGML_RESTRICT yb = &y[i * 4 + k]; + const float d1 = GGML_CPU_FP16_TO_FP32(yb->d); + int sumi_block = 0; + + const uint8_t * GGML_RESTRICT bits = &x[i].qs[k * 4]; + const int8_t * GGML_RESTRICT qy = yb->qs; + + for (int b = 0; b < 4; ++b, qy += 8) { + const unsigned mask = bits[b]; + sumi_block += ((mask & 0x01) ? qy[0] : -qy[0]) + + ((mask & 0x02) ? qy[1] : -qy[1]) + + ((mask & 0x04) ? qy[2] : -qy[2]) + + ((mask & 0x08) ? qy[3] : -qy[3]) + + ((mask & 0x10) ? qy[4] : -qy[4]) + + ((mask & 0x20) ? qy[5] : -qy[5]) + + ((mask & 0x40) ? qy[6] : -qy[6]) + + ((mask & 0x80) ? qy[7] : -qy[7]); + } + + sumi += d1 * sumi_block; + } + + sumf += d0 * sumi; + } + + *s = sumf; +} + + void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -216,6 +275,42 @@ void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, *s = sumf; } +// NVFP4: super-block of 64 elements = 4 sub-blocks of 16 = 2 q8_0 blocks +void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + assert(n % QK_NVFP4 == 0); + + const block_nvfp4 * GGML_RESTRICT x = vx; + const block_q8_0 * GGML_RESTRICT y = vy; + + const int nb = n / QK_NVFP4; + + float sumf = 0; + + for (int ib = 0; ib < nb; ++ib) { + for (int s_idx = 0; s_idx < 4; ++s_idx) { + const float d = ggml_ue4m3_to_fp32(x[ib].d[s_idx]); + const int q8_block = s_idx / 2; + const int q8_off = (s_idx % 2) * QK_NVFP4_SUB; + const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d); + + int sumi_lo = 0, sumi_hi = 0; + for (int j = 0; j < QK_NVFP4_SUB/2; ++j) { + const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j]; + sumi_lo += y[2*ib + q8_block].qs[q8_off + j + 0] * kvalues_mxfp4[qv & 0xf]; + sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >> 4]; + } + + sumf += dy * d * (sumi_lo + sumi_hi); + } + } + *s = sumf; +} + void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/quants.h b/ggml/src/ggml-cpu/quants.h index d83eb1b144d..d4bc87a1c05 100644 --- a/ggml/src/ggml-cpu/quants.h +++ b/ggml/src/ggml-cpu/quants.h @@ -12,6 +12,7 @@ extern "C" { #endif // Quantization +void quantize_row_q1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -20,6 +21,7 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -35,6 +37,7 @@ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dot product +void ggml_vec_dot_q1_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -42,6 +45,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -66,6 +70,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void quantize_row_q8_0_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); void quantize_row_q8_K_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void ggml_vec_dot_q1_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -73,6 +78,7 @@ void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index fbf7ed9432a..f18758f16bb 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -48,6 +48,90 @@ static inline int nearest_int(float fval) { extern "C" { +#if defined __riscv_zvfh +void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0x4 * GGML_RESTRICT y = (block_q8_0x4 *) vy; + + // scalar + const int blck_size_interleave = 1; + float srcv[4][QK8_0]; + float id[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK8_0 + j]; + amax = MAX(amax, fabsf(srcv[row_iter][j])); + } + + const float d = amax / ((1 << 7) - 1); + id[row_iter] = d ? 1.0f / d : 0.0f; + + y[i].d[row_iter] = GGML_CPU_FP32_TO_FP16(d); + } + + for (int j = 0; j < QK8_0 * 4; j++) { + int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave; + int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave; + src_offset += (j % blck_size_interleave); + + float x0 = srcv[src_id][src_offset] * id[src_id]; + y[i].qs[j] = roundf(x0); + } + } +} + +void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { + assert(QK_K == 256); + assert(k % QK_K == 0); + const int nb = k / QK_K; + + block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy; + + const int blck_size_interleave = 1; + float srcv[4][QK_K]; + float iscale[4]; + + for (int i = 0; i < nb; i++) { + for (int row_iter = 0; row_iter < 4; row_iter++) { + float amax = 0.0f; // absolute max + float max = 0; + + for (int j = 0; j < QK_K; j++) { + srcv[row_iter][j] = x[row_iter * k + i * QK_K + j]; + // Update the maximum value of the corresponding super block + if(amax < fabsf(srcv[row_iter][j])) { + amax = fabsf(srcv[row_iter][j]); + max = srcv[row_iter][j]; + } + } + + iscale[row_iter] = amax ? -127.f/max : 0; + y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0; + } + + for (int j = 0; j < QK_K / 4; j++) { + y[i].bsums[j] = 0; + } + for (int j = 0; j < QK_K * 4; j++) { + int src_id = j % 4; + int src_offset = j / 4; + int index = ((j >> 6) << 2) + (j & 3); + + float x0 = srcv[src_id][src_offset] * iscale[src_id]; + y[i].qs[j] = nearest_int(x0); + y[i].bsums[index] += y[i].qs[j]; + } + } +} +#endif + void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK8_0 == 32); assert(k % QK8_0 == 0); @@ -124,7 +208,6 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG } } - void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) { assert(QK_K == 256); assert(k % QK_K == 0); @@ -256,192 +339,289 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } -extern "C" { +#if defined __riscv_zvfh +template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_0>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_0_4x1(x, vy, n_per_row); +} -void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; +template <> void ggml_quantize_mat_t<1, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) { + assert(nrow == 4); + UNUSED(nrow); + ggml_quantize_mat_q8_K_4x1(x, vy, n_per_row); +} +#endif + +template <int M, int N> +static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; - assert(nr == 1); assert(n % qk == 0); assert(nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - float sumf[4]; - int sumi; + float sumf[8]; - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + const block_q8_K * a_ptr = (const block_q8_K *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; for (int l = 0; l < nb; l++) { for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; } } } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } } } -void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; +template <int M, int N> +static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; + const int q8_half_stride = 512; + const int q8_low_high_step = 256; - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); - UNUSED(s); UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); - float sumf[4]; - int sumi; + float sumf[4][8]; - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + sumf[m][j] = 0.0f; } } - } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; - } -} -void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; - float sumf[8]; - int sumi; + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4); - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = + qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = + qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i]; + const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step]; + + sumi_l += q_l * q8_l; + sumi_h += q_h * q8_h; + } + + sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * + a_ptr[l].d[m]; + } } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; } } } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } } -void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 4; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; +template <int M, int N> +static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); UNUSED(bs); UNUSED(nr); - float sumf[8]; - float sum_minf[8]; + float sumf[ncols_interleaved]; + float sum_minf[ncols_interleaved]; uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; + int sumi1; + int sumi2; + int sumi; const block_q8_K * a_ptr = (const block_q8_K *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; + sumf[j] = 0.0; sum_minf[j] = 0.0; } for (int l = 0; l < nb; l++) { for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; utmp[sb * 4 + 0] &= kmask1; } for (int k = 0; k < (qk / (2 * blocklen)); k++) { - uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32; - uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16; + constexpr int scale_stride = 32; + uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; + uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + + const int qh_shift = (k / (32 / blocklen)) * 2; for (int j = 0; j < ncols_interleaved; j++) { sumi1 = 0; sumi2 = 0; - sumi = 0; + sumi = 0; for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); - sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]); - sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]); + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * blocklen + i) % 32; + const int qh_chunk = qh_idx / blocklen; + const int qh_pos = qh_idx % blocklen; + const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); sumi1 = sumi1 * scales_0[j]; sumi2 = sumi2 * scales_1[j]; sumi += sumi1 + sumi2; @@ -452,7 +632,8 @@ void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, for (int sb = 0; sb < 8; sb++) { uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; } } } @@ -462,17 +643,123 @@ void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; +template <int M, int N> +static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][ncols_interleaved]; + float sum_minf[4][ncols_interleaved]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + constexpr int scale_stride = 32; + uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; + uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + + const int qh_shift = (k / (32 / blocklen)) * 2; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * blocklen + i) % 32; + const int qh_chunk = qh_idx / blocklen; + const int qh_pos = qh_idx % blocklen; + const int b_qh_offset = + qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k / (32 / blocklen)) * 256 + + (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + +extern "C" { + +void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; + const int ncols_interleaved = 4; + const int blocklen = 4; - assert (n % qk == 0); - assert (nc % ncols_interleaved == 0); + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); UNUSED(s); UNUSED(bs); @@ -484,66 +771,35 @@ void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, UNUSED(ncols_interleaved); UNUSED(blocklen); - float sumf[8]; - float sum_minf[8]; - uint32_t utmp[32]; - int sumi1; - int sumi2; + float sumf[4]; int sumi; - const block_q8_K * a_ptr = (const block_q8_K *) vy; + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - sum_minf[j] = 0.0; - } + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; - } for (int k = 0; k < (qk / (2 * blocklen)); k++) { - uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32; - uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16; for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; sumi = 0; for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); - sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]); - sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - for (int sb = 0; sb < 8; sb++) { - uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16; - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); } } } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; - } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } } -void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; +void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 8; + const int ncols_interleaved = 4; const int blocklen = 8; assert (n % qk == 0); @@ -559,82 +815,56 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, UNUSED(ncols_interleaved); UNUSED(blocklen); - float sumf[8]; - float sum_minf[8]; - int sumi1,sumi2,sumi3,sumi4; + float sumf[4]; int sumi; - const block_q8_K * a_ptr = (const block_q8_K *)vy; - for(int x = 0; x < nc / ncols_interleaved; x++) { - const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb); - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - sum_minf[j] = 0.0; - } + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (4 * blocklen)); k++) { - const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ; - const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16; - const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32; - const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48; + for (int k = 0; k < (qk / (2 * blocklen)); k++) { for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi3 = 0; - sumi4 = 0; sumi = 0; - int offset = ((k / 2) % 2) + j * 2; - for (int i = 0; i < blocklen; ++i){ - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3); - const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3); - const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3); - const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3); - sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]); - sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]); - sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]); - sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]); - - sumi1 = sumi1 * (scales_0[offset] & 0xF); - sumi2 = sumi2 * (scales_1[offset] & 0xF); - sumi3 = sumi3 * (scales_2[offset] & 0xF); - sumi4 = sumi4 * (scales_3[offset] & 0xF); - sumi += sumi1 + sumi2 + sumi3 + sumi4; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; } - sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - for(int sb = 0; sb < 8; sb++) { - const uint8_t *mins = b_ptr[l].scales + sb * 16; - for(int j = 0; j < ncols_interleaved; j++){ - sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); } } } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; - } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } } -void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; + const int ncols_interleaved = 8; + const int blocklen = 8; - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + UNUSED(s); UNUSED(bs); + UNUSED(vx); + UNUSED(vy); UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); - float sumf[4]; + float sumf[8]; int sumi; const block_q8_0 * a_ptr = (const block_q8_0 *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; for (int l = 0; l < nb; l++) { @@ -642,9 +872,9 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs for (int j = 0; j < ncols_interleaved; j++) { sumi = 0; for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; } sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); } @@ -654,139 +884,1212 @@ void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; +void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; const int nb = n / qk; const int ncols_interleaved = 8; - const int blocklen = 8; + const int blocklen = 4; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); UNUSED(bs); UNUSED(nr); float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; int sumi; - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + const block_q8_K * a_ptr = (const block_q8_K *) vy; for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); - for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16; for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + float sum_minf[8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32; + uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 64 + (k % 4) * blocklen + i + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[8]; + float sum_minf[8]; + int sumi1,sumi2,sumi3,sumi4; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *)vy; + for(int x = 0; x < nc / ncols_interleaved; x++) { + const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb); + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (4 * blocklen)); k++) { + const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ; + const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16; + const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32; + const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi3 = 0; + sumi4 = 0; + sumi = 0; + int offset = ((k / 2) % 2) + j * 2; + for (int i = 0; i < blocklen; ++i){ + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3); + const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3); + const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3); + const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3); + sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 32]); + sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 64]); + sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 128 + (k % 4) * blocklen + i + 96]); + + sumi1 = sumi1 * (scales_0[offset] & 0xF); + sumi2 = sumi2 * (scales_1[offset] & 0xF); + sumi3 = sumi3 * (scales_2[offset] & 0xF); + sumi4 = sumi4 * (scales_3[offset] & 0xF); + sumi += sumi1 + sumi2 + sumi3 + sumi4; + } + sumf[j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for(int sb = 0; sb < 8; sb++) { + const uint8_t *mins = b_ptr[l].scales + sb * 16; + for(int j = 0; j < ncols_interleaved; j++){ + sum_minf[j] += ((mins[j * 2] >> 4) * a_ptr[l].bsums[sb * 2] + (mins[(j * 2)+ 1] >> 4) * a_ptr[l].bsums[sb * 2 + 1]) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); +} + + +void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); +} + +void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; for (int i = 0; i < blocklen; ++i) { const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q8_0_4x4_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * blocklen + i]; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +void ggml_gemv_q8_0_4x8_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * blocklen + i]; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +// Only enable these for RISC-V. +#if defined __riscv_zvfh +void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])) >> 4; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + float sumf[16]; + float sum_minf[16]; + uint8_t scales[128]; + uint8_t mins[128]; + int sumi1; + int sumi2; + int sumi; + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + sum_minf[j] = 0.0f; + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < 128; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < 64; i++) { + scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); + mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * 16]; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += min[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * 16]; + uint8_t *scales_1 = &scales[(sb + 1) * 16]; + for (int i = 0; i < QK4_0; i++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 32 + i]); + sumi2 = (v1 * a_ptr[l].qs[sb * 32 + 32 + i]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 16; + const int blocklen = 1; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[16]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / blocklen); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * blocklen + i]; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); + assert(nr == 1); + assert(nc % 16 == 0); + + UNUSED(bs); + UNUSED(nr); + + const int nb = n / QK_K; + const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q8_K * y = (const block_q8_K *)vy; + + // Layout: Even-Low(0,2,4,6), Odd-Low(1,3,5,7), Even-High(8...), Odd-High(9...) + const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, // 0-7 + 8, 12, 9, 13, 10, 14, 11, 15 // 8-15 + }; + + for (int col_tile = 0; col_tile < nc; col_tile += 16) { + const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q8_K * y_ptr = y; + + float sumf[16] = {0}; + + // Loop over K-blocks + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[16] = {0}; + int32_t summs[16] = {0}; + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + // Iterate over sub-blocks 0..15 + for (int sb = 0; sb < 16; ++sb) { + // Correction Term + int16_t bsum = bs_lhs[sb]; + int scale_offset = sb_perm[sb] * 16; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + summs[col] += bsum * (sc_val >> 4); // Min is high 4 bits + } + + // Main Dot Product + // Calculate base offsets for Q2 unpacking based on SB + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; + + int shift = ((sb / 2) % 4) * 2; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; // Scale is low 4 bits + + // Process 16 elements (l=0..15) + for (int l = 0; l < 16; ++l) { + // Q2: Interleaved by column. Byte `l` contains 4 k-values. + int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; + + // Q8: Linear access + int k = sb * 16 + l; + int8_t q8_val = qs_lhs[k]; + + isum[col] += q8_val * q2_val * d_sb; + } + } + } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_lhs = y_ptr[k_block].d; + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + + sumf[col] += (isum[col] * d_all) - (summs[col] * d_min); + } + } + + for (int col = 0; col < 16; ++col) { + s[col_tile + col] = sumf[col]; + } + } +} +#endif + +void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + { + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + +void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 4; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + + float sumf[4][8]; + float sum_minf[4][8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32; + uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for(int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + +void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(bs); + + float sumf[4][8]; + float sum_minf[4][8]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32; + uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); + sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16; + for(int m = 0; m < 4; m++) { + const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; } } } - for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; } } -void ggml_gemv_q8_0_4x4_q8_0_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; +void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + UNUSED(s); UNUSED(bs); + UNUSED(vx); + UNUSED(vy); UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); - float sumf[4]; - int sumi; - - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + float sumf[4][8]; + float sum_minf[4][8]; + int sumi1, sumi2, sumi3, sumi4; + int sumi; - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * a_ptr[l].qs[k * blocklen + i]; + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (4 * blocklen)); k++) { + + const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ; + const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16; + const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32; + const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi3 = 0; + sumi4 = 0; + sumi = 0; + int offset = ((k / 2) % 2) + j * 2; + for (int i = 0; i < blocklen; ++i){ + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3); + const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3); + const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3); + const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3); + sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]); + sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]); + sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]); + sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]); + sumi1 = sumi1 * (scales_0[offset] & 0xF); + sumi2 = sumi2 * (scales_1[offset] & 0xF); + sumi3 = sumi3 * (scales_2[offset] & 0xF); + sumi4 = sumi4 * (scales_3[offset] & 0xF); + sumi += sumi1 + sumi2 + sumi3 + sumi4; + } + sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for(int sb = 0; sb < 8; sb++) { + const uint8_t *mins = b_ptr[l].scales + sb * 16; + for(int m = 0; m < 4; m++) { + const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for(int j = 0; j < ncols_interleaved; j++) { + int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]); + sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); } } - } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j]; + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } } } } -void ggml_gemv_q8_0_4x8_q8_0_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - - assert(nr == 1); - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); +void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} - float sumf[4]; - int sumi; +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); +} - const block_q8_0 * a_ptr = (const block_q8_0 *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); +void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * a_ptr[l].qs[k * blocklen + i]; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); - } - } - } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j]; - } - } +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } -void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; const int ncols_interleaved = 4; @@ -813,7 +2116,7 @@ void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; } @@ -823,10 +2126,10 @@ void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, for (int j = 0; j < ncols_interleaved; j++) { sumi = 0; for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); } sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } @@ -842,33 +2145,23 @@ void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 4; + const int ncols_interleaved = 8; const int blocklen = 8; - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); - float sumf[4][4]; + float sumf[4][8]; int sumi; for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x4 * b_ptr = (const block_q4_0x4 *) vx + (x * nb); + const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; } @@ -878,10 +2171,10 @@ void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, for (int j = 0; j < ncols_interleaved; j++) { sumi = 0; for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); } sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } @@ -896,25 +2189,59 @@ void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; + const int ncols_interleaved = 4; + const int blocklen = 4; - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); float sumf[4][8]; int sumi; @@ -922,7 +2249,7 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_0x8 * b_ptr = (const block_q4_0x8 *) vx + (x * nb); + const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; } @@ -932,12 +2259,12 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, for (int j = 0; j < ncols_interleaved; j++) { sumi = 0; for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } } } @@ -950,183 +2277,119 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 4; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); +void ggml_gemm_q8_0_4x4_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); - float sumf[4][8]; - float sum_minf[4][8]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; + float sumf[4][4]; + int sumi; for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumf[m][j] = 0.0; - sum_minf[m][j] = 0.0; } } for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; - } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32; - uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16; + for (int k = 0; k < (qk / blocklen); k++) { for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; sumi = 0; for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); - sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]); - sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; - } - } - } - for (int sb = 0; sb < 8; sb++) { - uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; - for(int m = 0; m < 4; m++) { - const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); - for(int j = 0; j < ncols_interleaved; j++) { - sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } } } } for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; } } } } } -void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; - assert (n % qk == 0); - assert (nr % 4 == 0); - assert (nc % ncols_interleaved == 0); - UNUSED(s); - UNUSED(bs); - UNUSED(vx); - UNUSED(vy); - UNUSED(nr); - UNUSED(nc); - UNUSED(nb); - UNUSED(ncols_interleaved); - UNUSED(blocklen); +void ggml_gemm_q8_0_4x8_q8_0_generic(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 8; - float sumf[4][8]; - float sum_minf[4][8]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][4]; + int sumi; for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb); + const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumf[m][j] = 0.0; - sum_minf[m][j] = 0.0; } } for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; - } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - uint8_t *scales_0 = (uint8_t*) utmp + (k / 4) * 32; - uint8_t *scales_1 = (uint8_t*) utmp + (k / 4) * 32 + 16; + for (int k = 0; k < (qk / blocklen); k++) { for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; sumi = 0; for (int i = 0; i < blocklen; ++i) { - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF); - const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4); - sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i]); - sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; + const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; + sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; - } - } - } - for (int sb = 0; sb < 8; sb++) { - uint8_t *mins = (uint8_t*) utmp + 8 + sb * 16; - for(int m = 0; m < 4; m++) { - const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); - for(int j = 0; j < ncols_interleaved; j++) { - sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + sumf[m][j] += + sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } } } } for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; } } } } } -void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK_K; +// Only enable these for RISC-V. +#if defined __riscv_zvfh +void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; + const int ncols_interleaved = 16; + const int blocklen = 1; assert (n % qk == 0); assert (nr % 4 == 0); @@ -1142,82 +2405,45 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, UNUSED(ncols_interleaved); UNUSED(blocklen); - float sumf[4][8]; - float sum_minf[4][8]; - int sumi1, sumi2, sumi3, sumi4; + float sumf[4][16]; int sumi; for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q2_Kx8 * b_ptr = (const block_q2_Kx8 *) vx + (x * nb); + const block_q4_0x16 * b_ptr = (const block_q4_0x16 *) vx + (x * nb); for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - sum_minf[m][j] = 0.0; - } + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; } for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (4 * blocklen)); k++) { - - const uint8_t *scales_0 = b_ptr[l].scales + (k / 4) * 64 ; - const uint8_t *scales_1 = b_ptr[l].scales + (k / 4) * 64 + 16; - const uint8_t *scales_2 = b_ptr[l].scales + (k / 4) * 64 + 32; - const uint8_t *scales_3 = b_ptr[l].scales + (k / 4) * 64 + 48; + for (int k = 0; k < (qk / (2 * blocklen)); k++) { for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi3 = 0; - sumi4 = 0; sumi = 0; - int offset = ((k / 2) % 2) + j * 2; - for (int i = 0; i < blocklen; ++i){ - const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 3); - const int v1 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 2 ) & 3); - const int v2 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4 ) & 3); - const int v3 = (int8_t) ((b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 6 ) & 3); - sumi1 = (v0 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i]); - sumi2 = (v1 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 128]); - sumi3 = (v2 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 256]); - sumi4 = (v3 * a_ptr[l].qs[(k >> 2) * 512 + (k % 4) * 4 * blocklen + m * blocklen + i + 384]); - sumi1 = sumi1 * (scales_0[offset] & 0xF); - sumi2 = sumi2 * (scales_1[offset] & 0xF); - sumi3 = sumi3 * (scales_2[offset] & 0xF); - sumi4 = sumi4 * (scales_3[offset] & 0xF); - sumi += sumi1 + sumi2 + sumi3 + sumi4; + for (int i = 0; i < blocklen; ++i) { + const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] << 4); + const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF0); + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])) >> 4; } - sumf[m][j] += sumi * GGML_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; - } - } - } - for(int sb = 0; sb < 8; sb++) { - const uint8_t *mins = b_ptr[l].scales + sb * 16; - for(int m = 0; m < 4; m++) { - const int16_t *bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); - for(int j = 0; j < ncols_interleaved; j++) { - int mins_prod = ((mins[j * 2] >> 4) * bsums[0] + (mins[(j * 2)+ 1] >> 4) * bsums[1]); - sum_minf[m][j] += (mins_prod) * GGML_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } } } } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; - } + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; } } } } - -void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - const int qk = QK8_0; +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK_K; const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; + const int ncols_interleaved = 16; + const int blocklen = 1; assert (n % qk == 0); assert (nr % 4 == 0); @@ -1233,59 +2459,97 @@ void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs UNUSED(ncols_interleaved); UNUSED(blocklen); - { - float sumf[4][4]; - int sumi; + float sumf[4][16]; + float sum_minf[4][16]; + uint8_t scales[128]; + uint8_t mins[128]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx16 * b_ptr = (const block_q4_Kx16 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int i = 0; i < 128; i++) { + scales[i] = b_ptr[l].scales[i] & 0x0F; + mins[i] = b_ptr[l].scales[i] >> 4; + } + for (int i = 0; i < 64; i++) { + scales[i] |= (b_ptr[l].scales[128 + i] & 0x03) << 4; + mins[i] |= (b_ptr[l].scales[128 + i] & 0x0C) << 2; + scales[i + 64] |= (b_ptr[l].scales[128 + i] & 0x30); + mins[i + 64] |= (b_ptr[l].scales[128 + i] & 0xC0) >> 2; + } - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx4 * b_ptr = (const block_iq4_nlx4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + for (int sb = 0; sb < 8; sb++) { + uint8_t *min = &mins[sb * 16]; + for(int m = 0; m < 4; m++) { + const int16_t bsums = a_ptr[l].bsums[sb * 8 + m] + a_ptr[l].bsums[sb * 8 + m + 4]; + for(int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += min[j] * bsums * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / (2 * blocklen)); k++) { + + for (int sb = 0; sb < 8; sb += 2) { + uint8_t *scales_0 = &scales[sb * 16]; + uint8_t *scales_1 = &scales[(sb + 1) * 16]; + + for (int i = 0; i < QK4_0; i++) { for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; - const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; - sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + + const int v0 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] & 0xF); + const int v1 = (int8_t) (b_ptr[l].qs[sb * 256 + i * 16 + j] >> 4); + sumi1 = (v0 * a_ptr[l].qs[sb * 4 * 32 + i * 4 + m]); + sumi2 = (v1 * a_ptr[l].qs[sb * 4 * 32 + 32 * 4 + i * 4 + m]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; } } } } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; } } } } } -void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; + const int ncols_interleaved = 16; + const int blocklen = 1; assert(n % qk == 0); assert(nr % 4 == 0); assert(nc % ncols_interleaved == 0); - float sumf[4][8]; + float sumf[4][16]; int sumi; for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_iq4_nlx8 * b_ptr = (const block_iq4_nlx8 *) vx + (x * nb); + const block_iq4_nlx16 * b_ptr = (const block_iq4_nlx16 *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; } @@ -1298,7 +2562,7 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs const int v0 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; const int v1 = kvalues_iq4nl[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + - (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + (qk / 2) * 4])); } sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } @@ -1313,29 +2577,23 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } -void ggml_gemm_q8_0_4x4_q8_0_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { +void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK8_0; const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 4; + const int ncols_interleaved = 16; + const int blocklen = 1; assert(n % qk == 0); assert(nr % 4 == 0); assert(nc % ncols_interleaved == 0); - float sumf[4][4]; + float sumf[4][16]; int sumi; for (int y = 0; y < nr / 4; y++) { const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); + const block_q8_0x16 * b_ptr = (const block_q8_0x16 *) vx + (x * nb); for (int m = 0; m < 4; m++) { for (int j = 0; j < ncols_interleaved; j++) { sumf[m][j] = 0.0; @@ -1365,57 +2623,102 @@ void ggml_gemm_q8_0_4x4_q8_0_generic(int n, } } -void ggml_gemm_q8_0_4x8_q8_0_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK8_0; - const int nb = n / qk; - const int ncols_interleaved = 4; - const int blocklen = 8; - assert(n % qk == 0); +void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + assert(n % QK_K == 0); assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); + assert(nc % 16 == 0); + const int nb = n / QK_K; + const block_q2_Kx16 * x = (const block_q2_Kx16 *)vx; + const block_q8_Kx4 * y = (const block_q8_Kx4 *)vy; + + const int sb_perm[16] = { + 0, 4, 1, 5, 2, 6, 3, 7, + 8, 12, 9, 13, 10, 14, 11, 15 + }; - float sumf[4][4]; - int sumi; + // Iterate Rows in tiles of 4 + for (int row_tile = 0; row_tile < nr; row_tile += 4) { + // Iterate Columns in tiles of 16 + for (int col_tile = 0; col_tile < nc; col_tile += 16) { + + const block_q2_Kx16 * x_ptr = x + (col_tile / 16) * nb; + const block_q8_Kx4 * y_ptr = y + (row_tile / 4) * nb; + + float sumf[4][16]; + memset(sumf, 0, sizeof(sumf)); + + for (int k_block = 0; k_block < nb; ++k_block) { + int32_t isum[4][16]; + int32_t summs[4][16]; + memset(isum, 0, sizeof(isum)); + memset(summs, 0, sizeof(summs)); + + const uint8_t * qs_rhs = x_ptr[k_block].qs; + const uint8_t * sc_rhs = x_ptr[k_block].scales; + const int8_t * qs_lhs = y_ptr[k_block].qs; + const int16_t * bs_lhs = y_ptr[k_block].bsums; + + for (int sb = 0; sb < 16; ++sb) { + int scale_offset = sb_perm[sb] * 16; + + int byte_base; + if (sb < 8) byte_base = (sb % 2 == 0) ? 0 : 16; + else byte_base = (sb % 2 == 0) ? 32 : 48; + int shift = ((sb / 2) % 4) * 2; + + for (int col = 0; col < 16; ++col) { + uint8_t sc_val = sc_rhs[scale_offset + col]; + int32_t d_sb = sc_val & 0xF; + int32_t m_sb = sc_val >> 4; + + // Correction Term + for (int r = 0; r < 4; ++r) { + int bsum_idx = (sb / 4) * 16 + r * 4 + (sb % 4); + summs[r][col] += bs_lhs[bsum_idx] * m_sb; + } - for (int y = 0; y < nr / 4; y++) { - const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q8_0x4 * b_ptr = (const block_q8_0x4 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - } - } - for (int l = 0; l < nb; l++) { - for (int k = 0; k < (qk / blocklen); k++) { - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int v0 = b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i]; - sumi += v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]; + // Main Dot Product + for (int l = 0; l < 16; ++l) { + int qs_idx = (byte_base + l) * 16 + col; + uint8_t q2_val = (qs_rhs[qs_idx] >> shift) & 3; + + // Calculate Q8 index for this specific k and row + int k = sb * 16 + l; + int q8_idx = (k / 4) * 16 + (k % 4); + + for (int r = 0; r < 4; ++r) { + // Add r*4 to jump to the correct row within the 4x4 chunk + int8_t q8_val = qs_lhs[q8_idx + r * 4]; + isum[r][col] += q8_val * q2_val * d_sb; } - sumf[m][j] += - sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); } } } + + // Finalize K-Block + for (int col = 0; col < 16; ++col) { + float d_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].d[col]); + float dm_rhs = GGML_FP16_TO_FP32(x_ptr[k_block].dmin[col]); + + for (int r = 0; r < 4; ++r) { + float d_lhs = y_ptr[k_block].d[r]; + float d_all = d_lhs * d_rhs; + float d_min = d_lhs * dm_rhs; + sumf[r][col] += (isum[r][col] * d_all) - (summs[r][col] * d_min); + } + } } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + + for (int r = 0; r < 4; ++r) { + for (int col = 0; col < 16; ++col) { + s[(row_tile + r) * bs + (col_tile + col)] = sumf[r][col]; } } } } } +#endif } // extern "C" @@ -1498,16 +2801,212 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_in uint64_t elems; memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - elems ^= xor_mask; + elems ^= xor_mask; + memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + } + + return out; +} + +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_0 * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + const uint8_t xor_mask = 0x88; + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset] ^ xor_mask; + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { + block_q4_Kx8 out; + //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 8; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 4 / blck_size_interleave; + + // Interleave Q4_K quants by taking 8 bytes at a time + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + // buffer large enough for the max interleave block size (8 bytes) + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave); + memcpy(&out.qs[dst_offset], &elems, blck_size_interleave); + } + + // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K + // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) + // The output Q4_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures + uint8_t s[8], m[8]; + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = in[j].scales[i] & 63; + m[j] = in[j].scales[i + 4] & 63; + } + + out.scales[i * 12] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 1] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 2] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 3] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 4] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 5] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 6] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 7] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 8] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); + + } + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 8; j++) { + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + + out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); + out.scales[i * 12 + 49] = (s[1] & 63) + ((s[5] & 48) << 2); + out.scales[i * 12 + 50] = (s[2] & 63) + ((s[6] & 48) << 2); + out.scales[i * 12 + 51] = (s[3] & 63) + ((s[7] & 48) << 2); + out.scales[i * 12 + 52] = (m[0] & 63) + ((m[4] & 48) << 2); + out.scales[i * 12 + 53] = (m[1] & 63) + ((m[5] & 48) << 2); + out.scales[i * 12 + 54] = (m[2] & 63) + ((m[6] & 48) << 2); + out.scales[i * 12 + 55] = (m[3] & 63) + ((m[7] & 48) << 2); + out.scales[i * 12 + 56] = (s[4] & 15) + ((m[4] & 15) << 4); + out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); + out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); + out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); + + } + + return out; +} + +static block_q4_Kx16 make_block_q4_Kx16(block_q4_K * in, unsigned int blck_size_interleave) { + block_q4_Kx16 out; + //Delta(scale) and dmin values of the 16 Q4_K structures are copied onto the output interleaved structure + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 16; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + + // RVV repacking. + // + // Extract sums and mins for all 8 sub-blocks for each block of Q4_K. + uint8_t s[128], m[128]; + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 16; j++) { + s[i * 16 + j] = in[j].scales[i] & 63; + m[i * 16 + j] = in[j].scales[i + 4] & 63; + } + } + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 16; j++) { + s[64 + i * 16 + j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); + m[64 + i * 16 + j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + } + } + + for (int i = 0; i < 128; i++) { + out.scales[i] = (s[i] & 15) | ((m[i] & 15) << 4); + } + for (int i = 0; i < 64; i++) { + out.scales[128 + i] = ((s[i] & 48) >> 4) | ((m[i] & 48) >> 2) | (s[64 + i] & 48) | ((m[64 + i] & 48) << 2); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) { + block_q2_Kx8 out; + + // Delta(scale) and dmin values of the eight Q2_K structures are copied onto the output interleaved structure + for (int i = 0; i < 8; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + } + + for (int i = 0; i < 8; i++) { + out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + + const int end = QK_K * 2 / blck_size_interleave; + + // Interleave Q2_K quants by taking 8 bytes at a time + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elems; + memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); } + // The below logic is designed so as to unpack and rearrange scales and mins values in Q2_K + // Currently the Q2_K structure has 16 scales and 16 mins packed in 16 bytes ( 4 bits for each value) + // The output Q2_Kx8 structure has 128 bytes for storing scales and mins + // Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure + // For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures + + for (int i = 0; i < 128; i++) { + // Index for selecting which q2k super block + int src1 = (i % 16) / 2; + // Index for selecting scale + int src2 = ((i / 16) * 2) + (i % 2); + + out.scales[i] = in[src1].scales[src2]; + } return out; } -static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_interleave) { - block_q4_Kx8 out; - //Delta(scale) and dmin values of the eight Q4_K structures are copied onto the output interleaved structure +static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_interleave) { + block_q5_Kx8 out; + //Delta(scale) and dmin values of the eight Q5_K structures are copied onto the output interleaved structure for (int i = 0; i < 8; i++) { out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; } @@ -1518,22 +3017,33 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in const int end = QK_K * 4 / blck_size_interleave; - // Interleave Q4_K quants by taking 8 bytes at a time + // Interleave Q5_K quants by taking blck_size_interleave bytes at a time for (int i = 0; i < end; ++i) { - int src_id = i % 8; + int src_id = i % 8; int src_offset = (i / 8) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave); } - // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K - // Currently the Q4_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) - // The output Q4_Kx8 structure has 96 bytes - // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q4_K structure - // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q4_K structures + // Repeat for high bits with the same chunk size, since + // the high bits are interleaved in Q5_K and the index is + // qh_idx = (qs_idx % 32); + // qh_val = qh[qh_idx] >> (qs_idx / 32); + for (int i = 0; i < end / 4; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qh[dst_offset], &in[src_id].qh[src_offset], blck_size_interleave); + } + + // The below logic is copied over from Q4_K + // The point is to unpack all the scales and mins for each sub block every time we load 12 bytes. + // Currently the Q5_K structure has 8 scales and 8 mins packed in 12 bytes ( 6 bits for each value) + // The output Q5_Kx8 structure has 96 bytes + // Every 12 byte is packed such that it contains scales and mins for corresponding sub blocks from Q5_K structure + // For eg - First 12 bytes contains 8 scales and 8 mins - each of first sub block from different Q5_K structures uint8_t s[8], m[8]; for (int i = 0; i < 4; i++) { @@ -1554,13 +3064,12 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in out.scales[i * 12 + 9] = (s[5] & 15) + ((m[5] & 15) << 4); out.scales[i * 12 + 10] = (s[6] & 15) + ((m[6] & 15) << 4); out.scales[i * 12 + 11] = (s[7] & 15) + ((m[7] & 15) << 4); - } for (int i = 0; i < 4; i++) { for (int j = 0; j < 8; j++) { - s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i+8] & 15); - m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i+8] & 240) >> 4); + s[j] = ((in[j].scales[i] & 192) >> 2) | (in[j].scales[i + 8] & 15); + m[j] = ((in[j].scales[i + 4] & 192) >> 2) | ((in[j].scales[i + 8] & 240) >> 4); } out.scales[i * 12 + 48] = (s[0] & 63) + ((s[4] & 48) << 2); @@ -1575,54 +3084,117 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in out.scales[i * 12 + 57] = (s[5] & 15) + ((m[5] & 15) << 4); out.scales[i * 12 + 58] = (s[6] & 15) + ((m[6] & 15) << 4); out.scales[i * 12 + 59] = (s[7] & 15) + ((m[7] & 15) << 4); - } return out; } -static block_q2_Kx8 make_block_q2_Kx8(block_q2_K * in, unsigned int blck_size_interleave) { - block_q2_Kx8 out; +static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_interleave) { + block_q6_Kx8 out; + constexpr int n_blocks = 8; // Kx8 + for (int i = 0; i < n_blocks; i++) { + out.d[i] = in[i].d; + } - // Delta(scale) and dmin values of the eight Q2_K structures are copied onto the output interleaved structure - for (int i = 0; i < 8; i++) { - out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + const int end_ls = QK_K * 4 / blck_size_interleave; + // Interleave Q6_K quants by taking blck_size_interleave bytes at a time + for (int i = 0; i < end_ls; ++i) { + int src_id = i % n_blocks; + int src_offset = (i / n_blocks) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elem_ls; + memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave); + memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave); } - for (int i = 0; i < 8; i++) { + // Interleave high bits using same chunk size as low bits + const int end_hs = end_ls / 2; + for (int i = 0; i < end_hs; ++i) { + int src_id = i % n_blocks; + int src_offset = (i / n_blocks) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + uint64_t elem_hs; + memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave); + memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave); + } + + // The below logic is designed so as to unpack and rearrange scales in Q6_K + // The output Q6_Kx8 structure interleaves the 8 bit scales in the same fashion as the quants + // Q6_K structure has an 8-bit scale per 16 elements -> 16 scales + // scales: [0 bl0 0 bl1 ... 0 bl7][1 bl0 ... 1 bl7] ... [15 bl0 ... 15 bl7] (bl = block) + constexpr int n_scales = QK_K / 16; + + for (int i = 0; i < n_blocks; i++) { + for (int j = 0; j < n_scales; j++) { + out.scales[j * n_blocks + i] = in[i].scales[j]; + } + } + + return out; +} + +static block_q2_Kx16 make_block_q2_Kx16(const block_q2_K * in, unsigned int blck_size_interleave) { + block_q2_Kx16 out; + constexpr int N_COLS = 16; + + // 1. Copy Super-Scales (d) and Super-Mins (dmin) + for (int i = 0; i < N_COLS; i++) { + out.d[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; out.dmin[i] = in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; } - const int end = QK_K * 2 / blck_size_interleave; + // 2. Interleave Q2_K Data + const int bytes_per_col = 64; + const int total_bytes = N_COLS * bytes_per_col; + const int end = total_bytes / blck_size_interleave; - // Interleave Q2_K quants by taking 8 bytes at a time for (int i = 0; i < end; ++i) { - int src_id = i % 8; - int src_offset = (i / 8) * blck_size_interleave; + int src_col_id = i % N_COLS; + int src_offset = (i / N_COLS) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - - uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &in[src_col_id].qs[src_offset], blck_size_interleave); } - // The below logic is designed so as to unpack and rearrange scales and mins values in Q2_K - // Currently the Q2_K structure has 16 scales and 16 mins packed in 16 bytes ( 4 bits for each value) - // The output Q2_Kx8 structure has 128 bytes for storing scales and mins - // Every 16 byte is packed such that it contains scales and mins for corresponding sub blocks from Q2_K structure - // For eg - First 16 bytes contains 16 scales and 16 mins - each of first and second sub blocks from different Q2_K structures + // 3. Repack Scales into the Optimized "Sequential-Parallel" Layout + int out_idx = 0; - for(int i = 0; i < 128; i++){ + // Arrays define the sub-block order for each group + const int even_low_sbs[] = {0, 2, 4, 6}; + const int odd_low_sbs[] = {1, 3, 5, 7}; + const int even_high_sbs[] = {8, 10, 12, 14}; + const int odd_high_sbs[] = {9, 11, 13, 15}; - // Index for selecting which q2k super block - int src1 = (i % 16) / 2; - // Index for selecting scale - int src2 = ((i / 16) * 2) + (i % 2); + // Pack Group 1: Even-Low + for (int sb : even_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } - out.scales[i] = in[src1].scales[src2]; + // Pack Group 2: Odd-Low + for (int sb : odd_low_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + // Pack Group 3: Even-High + for (int sb : even_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } } - return out; + // Pack Group 4: Odd-High + for (int sb : odd_high_sbs) { + for (int col = 0; col < N_COLS; col++) { + out.scales[out_idx++] = in[col].scales[sb]; + } + } + + return out; } static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { @@ -1687,6 +3259,36 @@ static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q4_K_to_q4_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + constexpr int nrows_interleaved = 16; + + block_q4_Kx16 * dst = (block_q4_Kx16*)t->data; + const block_q4_K * src = (const block_q4_K*) data; + block_q4_K dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_Kx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q2_K); GGML_ASSERT(interleave_block == 8); @@ -1706,7 +3308,7 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block for (int b = 0; b < nrow; b += nrows_interleaved) { for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++ ) { + for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } *dst++ = make_block_q2_Kx8(dst_tmp, interleave_block); @@ -1718,6 +3320,132 @@ static int repack_q2_K_to_q2_K_8_bl(struct ggml_tensor * t, int interleave_block GGML_UNUSED(data_size); } +static int repack_q2_K_to_q2_K_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + constexpr int nrows_interleaved = 16; + + block_q2_Kx16 * dst = (block_q2_Kx16*)t->data; + const block_q2_K * src = (const block_q2_K*) data; + + block_q2_K dst_tmp[nrows_interleaved]; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + // This loop gathers 16 separate blocks (one from each column) + // that correspond to the same K-dimension chunk. + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + + *dst++ = make_block_q2_Kx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16*)t->data; + const block_q4_0 * src = (const block_q4_0*) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++ ) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + block_q5_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q5_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); + constexpr int nrows_interleaved = 8; + + block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q6_K dst_tmp[8]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q6_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q6_Kx8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + static int repack_q4_0_to_q4_0_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q4_0); GGML_ASSERT(interleave_block == 8); @@ -1757,9 +3485,63 @@ static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, GGML_ASSERT(interleave_block == 4 || interleave_block == 8); constexpr int nrows_interleaved = 4; - block_q8_0x4 * dst = (block_q8_0x4 *) t->data; + block_q8_0x4 * dst = (block_q8_0x4 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[4]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q8_0x4(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static block_q8_0x16 make_block_q8_0x16(block_q8_0 * in, unsigned int blck_size_interleave) { + block_q8_0x16 out; + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + const int end = QK8_0 * 16 / blck_size_interleave; + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_q8_0_to_q8_0_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + constexpr int nrows_interleaved = 16; + + block_q8_0x16 * dst = (block_q8_0x16 *) t->data; const block_q8_0 * src = (const block_q8_0 *) data; - block_q8_0 dst_tmp[4]; + block_q8_0 dst_tmp[16]; int nrow = ggml_nrows(t); int nblocks = t->ne[0] / QK8_0; @@ -1774,7 +3556,7 @@ static int repack_q8_0_to_q8_0_4_bl(struct ggml_tensor * t, for (int i = 0; i < nrows_interleaved; i++) { dst_tmp[i] = src[x + i * nblocks]; } - *dst++ = make_block_q8_0x4(dst_tmp, interleave_block); + *dst++ = make_block_q8_0x16(dst_tmp, interleave_block); } src += nrows_interleaved * nblocks; } @@ -1906,6 +3688,177 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } +static block_iq4_nlx16 make_block_iq4_nlx16(block_iq4_nl * in, unsigned int blck_size_interleave) { + block_iq4_nlx16 out; + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + const int end = QK4_NL * 8 / blck_size_interleave; + + if (blck_size_interleave == 1) { + for (int i = 0; i < end; ++i) { + int src_id = i % 16; + int src_offset = i / 16; + int dst_offset = i; + + out.qs[dst_offset] = in[src_id].qs[src_offset]; + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_iq4_nl_to_iq4_nl_16_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_IQ4_NL); + GGML_ASSERT(interleave_block == 1); + + const block_iq4_nl * src = (const block_iq4_nl *)data; + block_iq4_nlx16 * dst = ( block_iq4_nlx16 *)t->data; + + block_iq4_nl dst_tmp[16]; + + int nrow = ggml_nrows(t); + int nrows_interleaved = 16; + int nblocks = t->ne[0] / QK4_NL; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_iq4_nl)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_iq4_nlx16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) { + block_mxfp4x4 out; + + for (int i = 0; i < 4; i++) { + out.e[i] = in[i].e; + } + + const int end = QK_MXFP4 * 2 / blck_size_interleave; + + if (blck_size_interleave == 4) { + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 4); + + const block_mxfp4 * src = (const block_mxfp4 *)data; + block_mxfp4x4 * dst = ( block_mxfp4x4 *)t->data; + + block_mxfp4 dst_tmp[4]; + + int nrow = ggml_nrows(t); + int nrows_interleaved = 4; + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_mxfp4x4(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size_interleave) { + block_mxfp4x8 out; + + for (int i = 0; i < 8; i++) { + out.e[i] = in[i].e; + } + + const int end = QK_MXFP4 * 4 / blck_size_interleave; + + if (blck_size_interleave == 8) { + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 8); + + const block_mxfp4 * src = (const block_mxfp4 *)data; + block_mxfp4x8 * dst = ( block_mxfp4x8 *)t->data; + + block_mxfp4 dst_tmp[8]; + + int nrow = ggml_nrows(t); + int nrows_interleaved = 8; + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_mxfp4x8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + namespace ggml::cpu::repack { // repack template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> @@ -1936,6 +3889,22 @@ template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * da return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); } +template <> int repack<block_q5_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_8_bl(t, 4, data, data_size); +} + +template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); +} + +template <> int repack<block_q6_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size); +} + +template <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size); +} + template <> int repack<block_iq4_nl, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_iq4_nl_to_iq4_nl_4_bl(t, 4, data, data_size); } @@ -1949,6 +3918,14 @@ template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void * return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } +template <> int repack<block_mxfp4, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size); +} + +template <> int repack<block_mxfp4, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size); +} + template <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size); } @@ -1957,6 +3934,28 @@ template <> int repack<block_q8_0, 8, 4>(struct ggml_tensor * t, const void * da return repack_q8_0_to_q8_0_4_bl(t, 8, data, data_size); } +#if defined __riscv_zvfh +template <> int repack<block_q4_0, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 1, data, data_size); +} + +template <> int repack<block_q4_K, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_K_to_q4_K_16_bl(t, 1, data, data_size); +} + +template <> int repack<block_iq4_nl, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_iq4_nl_to_iq4_nl_16_bl(t, 1, data, data_size); +} + +template <> int repack<block_q8_0, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q8_0_to_q8_0_16_bl(t, 1, data, data_size); +} + +template <> int repack<block_q2_K, 1, 16>(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_K_to_q2_K_16_bl(t, 1, data, data_size); +} +#endif + // gemv template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> void gemv(int, float *, size_t, const void *, const void *, int, int); @@ -1973,6 +3972,17 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> +void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, + float * s, + size_t bs, + const void * vx, + const void * vy, + int nr, + int nc) { + ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -1981,8 +3991,20 @@ template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemv<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemv_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemv<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemv<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { @@ -1993,6 +4015,14 @@ template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -2001,6 +4031,28 @@ template <> void gemv<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t ggml_gemv_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh +template <> void gemv<block_q4_0, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv<block_q4_K, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv<block_iq4_nl, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv<block_q8_0, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv<block_q2_K, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +#endif + // gemm template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PARAM_TYPE> void gemm(int, float *, size_t, const void *, const void *, int, int); @@ -2013,20 +4065,43 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +template <> +void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, + float * s, + size_t bs, + const void * vx, + const void * vy, + int nr, + int nc) { + ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } -template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemm<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } -template <> void gemm<block_q2_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { - ggml_gemm_q2_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +template <> void gemm<block_q5_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } template <> void gemm<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { @@ -2037,6 +4112,14 @@ template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -2045,6 +4128,28 @@ template <> void gemm<block_q8_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t ggml_gemm_q8_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc); } +#if defined __riscv_zvfh +template <> void gemm<block_q4_0, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm<block_q4_K, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q4_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm<block_iq4_nl, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_iq4_nl_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm<block_q8_0, 1, 16, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q8_0_16x1_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm<block_q2_K, 1, 16, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q2_K_16x1_q8_K(n, s, bs, vx, vy, nr, nc); +} +#endif + class tensor_traits_base : public ggml::cpu::tensor_traits { public: virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; @@ -2063,7 +4168,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR case GGML_OP_MUL_MAT_ID: { size = ggml_row_size(PARAM_TYPE, ggml_nelements(op->src[1])); - size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. + size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block. const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert const int64_t ne12 = op->src[1]->ne[2]; // n_tokens @@ -2328,7 +4433,7 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR auto * wdata = (char *)params->wdata; auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t)); - // total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t) + // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t) auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] @@ -2393,20 +4498,19 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR for (int ir1 = 0; ir1 < nr1; ir1++) { struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); - const int id = row_mapping.i1; // selected expert index + const int id = row_mapping.i1; // selected expert index const int64_t i11 = id % ne11; - const int64_t i12 = row_mapping.i2; // row index in src1 + const int64_t i12 = row_mapping.i2; // row index in src1 - const int64_t i1 = id; // selected expert index - const int64_t i2 = i12; // row + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row const auto * src1_col = (const char *) wdata + (i11 * nbw1 + i12 * nbw2); - gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>(ne00, - (float *)((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, - src0_cur + src0_cur_start * nb01, - src1_col, 1, src0_cur_end - src0_cur_start); + gemv<BLOC_TYPE, INTER_SIZE, NB_COLS, PARAM_TYPE>( + ne00, (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, ne01, + src0_cur + src0_cur_start * nb01, src1_col, 1, src0_cur_end - src0_cur_start); } } #undef MMID_MATRIX_ROW @@ -2422,7 +4526,6 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR } // namespace ggml::cpu::repack static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(const struct ggml_tensor * cur) { - // instance for Q4 static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0; @@ -2432,6 +4535,14 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K; + // instance for Q5_K + static const ggml::cpu::repack::tensor_traits<block_q5_K, 4, 8, GGML_TYPE_Q8_K> q5_K_8x4_q8_K; + static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K; + + // instance for Q6_K + static const ggml::cpu::repack::tensor_traits<block_q6_K, 4, 8, GGML_TYPE_Q8_K> q6_K_8x4_q8_K; + static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K; + // instance for Q2 static const ggml::cpu::repack::tensor_traits<block_q2_K, 8, 8, GGML_TYPE_Q8_K> q2_K_8x8_q8_K; @@ -2439,13 +4550,28 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0; static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0; + // instance for MXFP4 + static const ggml::cpu::repack::tensor_traits<block_mxfp4, 4, 4, GGML_TYPE_Q8_0> mxfp4_4x4_q8_0; + static const ggml::cpu::repack::tensor_traits<block_mxfp4, 8, 8, GGML_TYPE_Q8_0> mxfp4_8x8_q8_0; + // instance for Q8_0 static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0; + // instances for RISC-V + // + // These implement outer-product style matrix multiplication kernels with + // an interleave of 1. +#if defined __riscv_zvfh + static const ggml::cpu::repack::tensor_traits<block_q4_0, 1, 16, GGML_TYPE_Q8_0> q4_0_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits<block_q4_K, 1, 16, GGML_TYPE_Q8_K> q4_K_16x1_q8_K; + static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 1, 16, GGML_TYPE_Q8_0> iq4_nl_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits<block_q8_0, 1, 16, GGML_TYPE_Q8_0> q8_0_16x1_q8_0; + static const ggml::cpu::repack::tensor_traits<block_q2_K, 1, 16, GGML_TYPE_Q8_K> q2_K_16x1_q8_K; +#endif + if (cur->type == GGML_TYPE_Q4_0) { - if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) - || (ggml_cpu_has_riscv_v() && (ggml_cpu_get_rvv_vlen() >= QK4_0))) { + if (ggml_cpu_has_avx2() || (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0)) { if (cur->ne[1] % 8 == 0) { return &q4_0_8x8_q8_0; } @@ -2460,6 +4586,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q4_0_4x4_q8_0; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q4_0_16x1_q8_0; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_Q4_K) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { @@ -2476,12 +4613,56 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q4_K_8x4_q8_K; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q4_K_16x1_q8_K; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } else if (cur->type == GGML_TYPE_Q2_K) { if (ggml_cpu_has_avx512()) { if (cur->ne[1] % 8 == 0) { return &q2_K_8x8_q8_K; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q2_K_16x1_q8_K; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } + } else if (cur->type == GGML_TYPE_Q5_K) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q5_K_8x8_q8_K; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 8 == 0) { + return &q5_K_8x4_q8_K; + } + } + } else if (cur->type == GGML_TYPE_Q6_K) { + if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { + if (cur->ne[1] % 8 == 0) { + return &q6_K_8x8_q8_K; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 8 == 0) { + return &q6_K_8x4_q8_K; + } + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { @@ -2493,6 +4674,28 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &iq4_nl_4x4_q8_0; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &iq4_nl_16x1_q8_0; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } + } else if (cur->type == GGML_TYPE_MXFP4) { + if (ggml_cpu_has_avx2()) { + if (cur->ne[1] % 8 == 0) { + return &mxfp4_8x8_q8_0; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 4 == 0) { + return &mxfp4_4x4_q8_0; + } + } } else if (cur->type == GGML_TYPE_Q8_0) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 4 == 0) { @@ -2504,6 +4707,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q8_0_4x4_q8_0; } } + if (ggml_cpu_has_riscv_v()) { + #if defined __riscv_zvfh + switch (__riscv_vlenb() * 8) { + case 128: { break; } // TODO + case 256: { if (cur->ne[1] % 16 == 0) { return &q8_0_16x1_q8_0; } break; } + case 512: { break; } // TODO + case 1024: { break; } // TODO + default: { return nullptr; } + } + #endif + } } return nullptr; diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index af98e703442..cb21edf6239 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -28,13 +28,17 @@ template <int K, int N> struct block { // control size static_assert(sizeof(block<4, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 2, "wrong block<4,4> size/padding"); static_assert(sizeof(block<4, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<4,8> size/padding"); +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<4,16> size/padding"); static_assert(sizeof(block<8, 4>) == 4 * sizeof(ggml_half) + QK8_0 * 4, "wrong block<8,4> size/padding"); static_assert(sizeof(block<8, 8>) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong block<8,8> size/padding"); +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK8_0 * 16, "wrong block<8,16> size/padding"); using block_q4_0x4 = block<4, 4>; using block_q4_0x8 = block<4, 8>; +using block_q4_0x16 = block<4, 16>; using block_q8_0x4 = block<8, 4>; using block_q8_0x8 = block<8, 8>; +using block_q8_0x16 = block<8, 16>; struct block_q4_Kx8 { ggml_half d[8]; // super-block scale for quantized scales @@ -44,6 +48,14 @@ struct block_q4_Kx8 { }; static_assert(sizeof(block_q4_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 4, "wrong q4_K block size/padding"); +struct block_q4_Kx16 { + ggml_half d[16]; // super-block scale for quantized scales + ggml_half dmin[16]; // super-block scale for quantized mins + uint8_t scales[192]; // scales and mins, quantized with 6 bits + uint8_t qs[2048]; // 4--bit quants +}; + +static_assert(sizeof(block_q4_Kx16) == sizeof(ggml_half) * 32 + K_SCALE_SIZE * 16 + QK_K * 8, "wrong q4_K block size/padding"); struct block_q2_Kx8 { ggml_half d[8]; // super-block scale for quantized scales ggml_half dmin[8]; // super-block scale for quantized mins @@ -52,6 +64,35 @@ struct block_q2_Kx8 { }; static_assert(sizeof(block_q2_Kx8) == sizeof(ggml_half) * 16 + QK_K/2 + QK_K * 2, "wrong q2_K block size/padding"); +struct block_q2_Kx16 { + ggml_half d[16]; // Super-block scale for quantized scales + ggml_half dmin[16]; // Super-block scale for quantized mins + uint8_t scales[256]; // Sub-block scales (16 cols * 16 sub-blocks) + uint8_t qs[1024]; // Data (16 cols * 64 bytes per block) +}; +static_assert(sizeof(block_q2_Kx16) == sizeof(ggml_half) * 32 + QK_K + QK_K * 4, "wrong q2_K block size/padding"); + +struct block_q5_Kx8 { + ggml_half d[8]; // super-block scale for quantized scales + ggml_half dmin[8]; // super-block scale for quantized mins + uint8_t scales[96]; // scales and mins, quantized with 6 bits + uint8_t qh[QK_K * 8 / 8]; // high bits of 5-bit quants + uint8_t qs[QK_K * 8 / 2]; // low bits of 5-bit quants (in groups of 4) +}; + +static_assert(sizeof(block_q5_Kx8) == sizeof(ggml_half) * 16 + K_SCALE_SIZE * 8 + QK_K * 5, + "wrong q5_K block size/padding"); + +struct block_q6_Kx8 { + ggml_half d[8]; + int8_t scales[QK_K / 16 * 8]; + uint8_t ql[QK_K / 2 * 8]; // low bits of 6-bit quants (groups of 2) + uint8_t qh[QK_K / 4 * 8]; // high bits of 6-bit quants (groups of 4) +}; + +static_assert(sizeof(block_q6_Kx8) == sizeof(ggml_half) * 8 + QK_K / 16 * 8 + 3 * QK_K / 4 * 8, + "wrong q6_K block size/padding"); + struct block_q8_Kx4 { float d[4]; // delta int8_t qs[QK_K * 4]; // quants @@ -74,6 +115,24 @@ struct block_iq4_nlx8 { static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding"); +struct block_iq4_nlx16 { + ggml_half d[16]; // deltas for 16 iq4_nl blocks + uint8_t qs[QK4_NL * 8]; // nibbles / quants for 16 iq4_nl blocks +}; + +static_assert(sizeof(block_iq4_nlx16) == 16 * sizeof(ggml_half) + QK4_NL * 8, "wrong iq4_nlx16 block size/padding"); +struct block_mxfp4x4 { + uint8_t e[4]; + uint8_t qs[QK_MXFP4 * 2]; +}; +static_assert(sizeof(block_mxfp4x4) == 4 + QK_MXFP4 * 2, "wrong mxfp4x4 block size/padding"); + +struct block_mxfp4x8 { + uint8_t e[8]; + uint8_t qs[QK_MXFP4 * 4]; +}; +static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding"); + #if defined(__cplusplus) extern "C" { #endif @@ -85,23 +144,49 @@ void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTR void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#if defined __riscv_zvfh +void ggml_quantize_mat_q8_0_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_16x1_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_16x1_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#endif // Native implementations void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); @@ -111,23 +196,49 @@ void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GG void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); -void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#if defined __riscv_zvfh +void ggml_quantize_mat_q8_0_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_quantize_mat_q8_K_4x1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k); +void ggml_gemv_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q4_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q8_0_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q2_K_16x1_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_iq4_nl_16x1_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +#endif #if defined(__cplusplus) } // extern "C" diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h new file mode 100644 index 00000000000..4119d04f895 --- /dev/null +++ b/ggml/src/ggml-cpu/simd-gemm.h @@ -0,0 +1,226 @@ +#pragma once + +// Computes C[M x N] += A[M x K] * B[K x N] + +#include "simd-mappings.h" + +// TODO: add support for sizeless vector types +#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic) + +// TODO: untested on avx512 +// These are in units of GGML_F32_EPR +#if defined(__AVX512F__) || defined (__ARM_NEON__) + static constexpr int GEMM_RM = 4; + static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32 +#elif defined(__AVX2__) || defined(__AVX__) + static constexpr int GEMM_RM = 6; + static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16 +#else + static constexpr int GEMM_RM = 2; + static constexpr int GEMM_RN = 2; +#endif + +template <int RM, int RN> +static inline void simd_gemm_ukernel( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int K, int N) +{ + static constexpr int KN = GGML_F32_EPR; + + GGML_F32_VEC acc[RM][RN]; + for (int64_t i = 0; i < RM; i++) { + for (int r = 0; r < RN; r++) { + acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN); + } + } + + for (int64_t kk = 0; kk < K; kk++) { + GGML_F32_VEC Bv[RN]; + for (int r = 0; r < RN; r++) { + Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN); + } + for (int64_t i = 0; i < RM; i++) { + GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]); + for (int r = 0; r < RN; r++) { + acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p); + } + } + } + + for (int64_t i = 0; i < RM; i++) { + for (int r = 0; r < RN; r++) { + GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]); + } + } +} + +// C[M x N] += A[M x K] * B[K x N] +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int M, int K, int N) +{ + static constexpr int KN = GGML_F32_EPR; + + int64_t ii = 0; + for (; ii + GEMM_RM <= M; ii += GEMM_RM) { + int64_t jj = 0; + for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) { + simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N); + } + for (; jj + KN <= N; jj += KN) { + simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N); + } + for (; jj < N; jj++) { + for (int64_t i = 0; i < GEMM_RM; i++) { + float a = C[i * N + jj]; + for (int64_t kk = 0; kk < K; kk++) { + a += A[i + kk] * B[kk * N + jj]; + } + C[i * N + jj] = a; + } + } + + A += GEMM_RM * K; + C += GEMM_RM * N; + } + + // Tail rows: one at a time + for (; ii < M; ii++) { + int64_t jj = 0; + for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) { + simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N); + } + for (; jj + KN <= N; jj += KN) { + simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N); + } + for (; jj < N; jj++) { + float a = C[jj]; + for (int64_t kk = 0; kk < K; kk++) { + a += A[kk] * B[kk * N + jj]; + } + C[jj] = a; + } + + A += K; + C += N; + } +} +#elif defined(GGML_SIMD) && defined(__riscv_v_intrinsic) +// RM accumulators + 1 B vector = RM + 1 <= 8 => RM <= 7 +// Microkernel: C[RM x vl] += A[RM x K] * B[K x N] +template <int RM> +static inline void rvv_simd_gemm_ukernel( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int K, int N, size_t vl) +{ + static_assert(RM >= 1 && RM <= 7, "RM must be 1..7 for LMUL=4"); + + vfloat32m4_t acc_0 = __riscv_vle32_v_f32m4(C + 0 * N, vl); + vfloat32m4_t acc_1, acc_2, acc_3, acc_4, acc_5, acc_6; + if constexpr (RM > 1) acc_1 = __riscv_vle32_v_f32m4(C + 1 * N, vl); + if constexpr (RM > 2) acc_2 = __riscv_vle32_v_f32m4(C + 2 * N, vl); + if constexpr (RM > 3) acc_3 = __riscv_vle32_v_f32m4(C + 3 * N, vl); + if constexpr (RM > 4) acc_4 = __riscv_vle32_v_f32m4(C + 4 * N, vl); + if constexpr (RM > 5) acc_5 = __riscv_vle32_v_f32m4(C + 5 * N, vl); + if constexpr (RM > 6) acc_6 = __riscv_vle32_v_f32m4(C + 6 * N, vl); + + for (int kk = 0; kk < K; kk++) { + vfloat32m4_t b_0 = __riscv_vle32_v_f32m4(B + kk * N, vl); + + acc_0 = __riscv_vfmacc_vf_f32m4(acc_0, A[0 * K + kk], b_0, vl); + if constexpr (RM > 1) acc_1 = __riscv_vfmacc_vf_f32m4(acc_1, A[1 * K + kk], b_0, vl); + if constexpr (RM > 2) acc_2 = __riscv_vfmacc_vf_f32m4(acc_2, A[2 * K + kk], b_0, vl); + if constexpr (RM > 3) acc_3 = __riscv_vfmacc_vf_f32m4(acc_3, A[3 * K + kk], b_0, vl); + if constexpr (RM > 4) acc_4 = __riscv_vfmacc_vf_f32m4(acc_4, A[4 * K + kk], b_0, vl); + if constexpr (RM > 5) acc_5 = __riscv_vfmacc_vf_f32m4(acc_5, A[5 * K + kk], b_0, vl); + if constexpr (RM > 6) acc_6 = __riscv_vfmacc_vf_f32m4(acc_6, A[6 * K + kk], b_0, vl); + } + + __riscv_vse32_v_f32m4(C + 0 * N, acc_0, vl); + if constexpr (RM > 1) __riscv_vse32_v_f32m4(C + 1 * N, acc_1, vl); + if constexpr (RM > 2) __riscv_vse32_v_f32m4(C + 2 * N, acc_2, vl); + if constexpr (RM > 3) __riscv_vse32_v_f32m4(C + 3 * N, acc_3, vl); + if constexpr (RM > 4) __riscv_vse32_v_f32m4(C + 4 * N, acc_4, vl); + if constexpr (RM > 5) __riscv_vse32_v_f32m4(C + 5 * N, acc_5, vl); + if constexpr (RM > 6) __riscv_vse32_v_f32m4(C + 6 * N, acc_6, vl); +} + +template <int RM> +static inline void rvv_simd_gemm_dispatch_tail( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int K, int N, int KN, int remaining_rows) +{ + if constexpr (RM > 0) { + if (remaining_rows == RM) { + int64_t jj = 0; + for (; jj + KN <= N; jj += KN) { + rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, KN); + } + if (jj < N) { + rvv_simd_gemm_ukernel<RM>(C + jj, A, B + jj, K, N, N - jj); + } + } else { + rvv_simd_gemm_dispatch_tail<RM - 1>(C, A, B, K, N, KN, remaining_rows); + } + } +} + +static constexpr int GEMM_RM = 7; + +// C[M x N] += A[M x K] * B[K x N] +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int M, int K, int N) +{ + const int KN = (int)__riscv_vlenb(); + int64_t ii = 0; + for (; ii + GEMM_RM <= M; ii += GEMM_RM) { + int64_t jj = 0; + for (; jj + KN <= N; jj += KN) { + rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, KN); + } + if (jj < N) { + rvv_simd_gemm_ukernel<GEMM_RM>(C + jj, A, B + jj, K, N, N - jj); + } + A += GEMM_RM * K; + C += GEMM_RM * N; + } + + int remaining_rows = M - ii; + rvv_simd_gemm_dispatch_tail<GEMM_RM - 1>(C, A, B, K, N, KN, remaining_rows); +} + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + +#else // scalar path + +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int M, int K, int N) +{ + for (int64_t i = 0; i < M; i++) { + for (int64_t j = 0; j < N; j++) { + float sum = C[i * N + j]; + for (int64_t kk = 0; kk < K; kk++) { + sum += A[i * K + kk] * B[kk * N + j]; + } + C[i * N + j] = sum; + } + } +} + +#endif // GGML_SIMD diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index a7a82722052..62e687201ef 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -116,6 +116,17 @@ extern "C" { // defined in ggml-cpu.c, initialized in ggml_cpu_init() extern float ggml_table_f32_f16[1 << 16]; +// precomputed f32 table for e8m0 half (1 KB) +// defined in ggml-cpu.c, initialized in ggml_cpu_init() +extern float ggml_table_f32_e8m0_half[1 << 8]; + +// Use lookup table for E8M0 on x86 (faster than bit manipulation) +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) +#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)] +#else +#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x) +#endif + // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON. // This is also true for POWER9. @@ -468,13 +479,51 @@ do { \ // F16 AVX512 -// F16 AVX +#if defined(__AVX512FP16__) + +#define GGML_F16_STEP 128 +#define GGML_F16_EPR 32 + +#define GGML_F16x32 __m512h +#define GGML_F16x32_ZERO _mm512_setzero_ph() +#define GGML_F16x32_SET1(x) _mm512_set1_ph(__extension__(_Float16)(x)) +#define GGML_F16x32_LOAD(x) _mm512_loadu_ph(x) +#define GGML_F16x32_STORE(x, y) _mm512_storeu_ph(x, y) +#define GGML_F16x32_FMA(a, b, c) _mm512_fmadd_ph(b, c, a) +#define GGML_F16x32_ADD _mm512_add_ph +#define GGML_F16x32_MUL _mm512_mul_ph +#define GGML_F16x32_REDUCE(res, x) \ +do { \ + int offset = GGML_F16_ARR >> 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ph(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ph(x[i], x[offset+i]); \ + } \ + offset >>= 1; \ + for (int i = 0; i < offset; ++i) { \ + x[i] = _mm512_add_ph(x[i], x[offset+i]); \ + } \ + res = (ggml_float) _mm512_reduce_add_ph(x[0]); \ +} while (0) + +#define GGML_F16_VEC GGML_F16x32 +#define GGML_F16_VEC_ZERO GGML_F16x32_ZERO +#define GGML_F16_VEC_SET1 GGML_F16x32_SET1 +#define GGML_F16_VEC_LOAD(p, i) GGML_F16x32_LOAD(p) +#define GGML_F16_VEC_STORE(p, r, i) GGML_F16x32_STORE(p, r[i]) +#define GGML_F16_VEC_FMA GGML_F16x32_FMA +#define GGML_F16_VEC_ADD GGML_F16x32_ADD +#define GGML_F16_VEC_MUL GGML_F16x32_MUL +#define GGML_F16_VEC_REDUCE GGML_F16x32_REDUCE + +#else // Fallback FP16 <-> FP32 #define GGML_F16_STEP 64 #define GGML_F16_EPR 16 -// AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead - #define GGML_F32Cx16 __m512 #define GGML_F32Cx16_ZERO _mm512_setzero_ps() #define GGML_F32Cx16_SET1(x) _mm512_set1_ps(x) @@ -514,6 +563,8 @@ do { \ #define GGML_F16_VEC_MUL GGML_F32Cx16_MUL #define GGML_F16_VEC_REDUCE GGML_F32Cx16_REDUCE + +#endif // __AVX512FP16__ #elif defined(__AVX__) #define GGML_SIMD @@ -654,6 +705,14 @@ static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { vec_extract(x[0], 2) + \ vec_extract(x[0], 3); \ } +#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \ +{ \ + vector float v = vec_add(vec_add(s0, s1), \ + vec_add(s2, s3)); \ + v = vec_add(v, vec_sld(v, v, 8)); \ + v = vec_add(v, vec_sld(v, v, 4)); \ + res += (ggml_float) vec_extract(v, 0); \ +} #define GGML_F32_VEC GGML_F32x4 #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO @@ -690,6 +749,29 @@ static inline unsigned char ggml_endian_byte(int i) { r[i - GGML_ENDIAN_BYTE(0)]), \ 0, p - GGML_F16_EPR) +//BF16 POWER9 +#define GGML_BF16_STEP 16 +#define GGML_BF16_EPR 8 + +#define GGML_BF16x8 vector unsigned short +#define GGML_BF16x8_ZERO vec_splats((unsigned short)0) +#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p)) + +#define GGML_BF16_VEC GGML_BF16x8 +#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO +#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD +#if defined(__LITTLE_ENDIAN__) +#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel(GGML_BF16_VEC_ZERO, (v))) +#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh(GGML_BF16_VEC_ZERO, (v))) +#else +#define GGML_BF16_TO_F32_LO(v) ((vector float) vec_mergel((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_TO_F32_HI(v) ((vector float) vec_mergeh((v), GGML_BF16_VEC_ZERO)) +#endif +#define GGML_BF16_FMA_LO(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y)) +#define GGML_BF16_FMA_HI(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y)) + #elif defined(__wasm_simd128__) #define GGML_SIMD @@ -1043,25 +1125,12 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { #define GGML_F16_EPR 4 static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) { - float tmp[4]; - - tmp[0] = GGML_CPU_FP16_TO_FP32(x[0]); - tmp[1] = GGML_CPU_FP16_TO_FP32(x[1]); - tmp[2] = GGML_CPU_FP16_TO_FP32(x[2]); - tmp[3] = GGML_CPU_FP16_TO_FP32(x[3]); - - return (__m128)__lsx_vld(tmp, 0); + return __lsx_vfcvtl_s_h(__lsx_vld((const void *)x, 0)); } static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { - float arr[4]; - - __lsx_vst(y, arr, 0); - - x[0] = GGML_CPU_FP32_TO_FP16(arr[0]); - x[1] = GGML_CPU_FP32_TO_FP16(arr[1]); - x[2] = GGML_CPU_FP32_TO_FP16(arr[2]); - x[3] = GGML_CPU_FP32_TO_FP16(arr[3]); + __m128i a = __lsx_vfcvt_h_s(y, y); + memcpy(x, &a, sizeof(ggml_fp16_t) * 4); } #define GGML_F32Cx4 __m128 @@ -1118,6 +1187,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { float32x4_t tmp = x[0] + vec_reve(x[0]); \ res = tmp[0] + tmp[1]; \ } +#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \ +{ \ + float32x4_t v = vec_add(vec_add(s0, s1), \ + vec_add(s2, s3)); \ + v = vec_add(v, vec_sld(v, v, 8)); \ + v = vec_add(v, vec_sld(v, v, 4)); \ + res += (ggml_float)vec_extract(v, 0); \ +} #define GGML_F32_VEC GGML_F32x4 #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO @@ -1167,6 +1244,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { #define GGML_F16_VEC_MUL GGML_F32x4_MUL #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +// BF16 s390x +#define GGML_BF16_STEP 16 +#define GGML_BF16_EPR 8 + +#define GGML_BF16x8 __vector unsigned short +#define GGML_BF16x8_ZERO vec_splats((unsigned short)0) +#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p)) + +#define GGML_BF16_VEC GGML_BF16x8 +#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO +#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD +#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_FMA_LO(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y)) +#define GGML_BF16_FMA_HI(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y)) + #elif defined(__riscv_v_intrinsic) // compatible with vlen >= 128 diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp index 91fe1925eaa..9563ea3e4bd 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -3,19 +3,32 @@ #include "ime.h" +#include "binary-ops.h" +#include "common.h" #include "ggml-backend-impl.h" #include "ggml-common.h" #include "ggml-cpu.h" +#include "ime_env.h" #include "ime_kernels.h" +#include "ops.h" +#include "repack.h" +#include "rvv_kernels.h" +#include "spine_mem_pool.h" #include "traits.h" +#include "vec.h" + +#include <fcntl.h> +#include <sys/mman.h> +#include <unistd.h> #include <algorithm> +#include <atomic> #include <cassert> +#include <cerrno> #include <cmath> #include <cstdio> // for GGML_ASSERT #include <stdexcept> #include <thread> - // clang-format off #if defined(__riscv) @@ -25,13 +38,17 @@ #include <riscv_vector.h> #endif -#if !defined(__riscv_zfh) -#error "riscv zfh extension not enabled" +#if !defined(__riscv_zfh) || !defined(__riscv_zvfh) +#error "riscv zfh extension not enabled, GGML_RV_ZFH and GGML_RV_ZVFH must be defined to 1" #endif -#if defined(RISCV64_SPACEMIT_IME1) +#if !defined(__riscv_zba) +#error "riscv zba extension not enabled, GGML_RV_ZBA must be defined to 1" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) || defined(RISCV64_SPACEMIT_IME2) #else -#error "RISCV64_SPACEMIT_IME1 not defined" +#error "RISCV64_SPACEMIT_IME1 or RISCV64_SPACEMIT_IME2 not defined" #endif #else @@ -46,382 +63,490 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif -#if defined(RISCV64_SPACEMIT_IME1) -#define QGEMM_STRIDEN_THREAD_ALIGN 16 -#else -#define QGEMM_STRIDEN_THREAD_ALIGN 32 -#endif - // clang-format on -struct qnbitgemm_spacemit_ime_args { - const float * a_ptr = nullptr; - size_t lda = 0; - const std::byte * packed_quant_b_data = nullptr; - const float * quant_b_scale = nullptr; - const void * quant_b_zp = nullptr; - const float * quant_b_blksum = nullptr; - const float * bias = nullptr; - float * c_ptr = nullptr; - size_t ldc = 0; -}; - -constexpr size_t div_round_up(size_t up, size_t down) { - return (up + down - 1) / down; -} - -constexpr size_t q8_blk_size(size_t blk_len) { - const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t); - // Currently, the strictest alignment requirement of a block is for a float. - // Ensure contiguous blocks are suitably aligned. - assert(blk_size % alignof(float) == 0); - return blk_size; +extern "C" { +extern void ggml_threadpool_chunk_set(struct ggml_threadpool * tp, int value); +extern int ggml_threadpool_chunk_add(struct ggml_threadpool * tp, int value); } namespace ggml::cpu::riscv64_spacemit { -const int num_ai_cores = std::thread::hardware_concurrency() / 2; - -} // namespace ggml::cpu::riscv64_spacemit +struct TLSContext { + int cpu_id{ -1 }; + cpu_set_t cpuset; + void * tcm_buffer{ nullptr }; + size_t tcm_buffer_size{ 0 }; +}; -static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len, - const size_t gemm_k, - const qnbitgemm_spacemit_ime_args * gemm_args, - void * const per_gemm_ws, - const size_t m_start, - const size_t m_count, - const size_t n_start, - const size_t n_count) { - constexpr size_t scale_stride = sizeof(uint16_t); - constexpr size_t blk_bitwidth = 4; +thread_local TLSContext tls_context; + +template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> constexpr size_t get_repacked_block_type_size() { + if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) { + return sizeof(block_q8_0); + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) { + return sizeof(block_q4_0) * INTER_SIZE / QK4_0; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_1> || std::is_same_v<BLOC_TYPE, block_q4_K>) { + return (sizeof(block_q4_0) + sizeof(uint8_t)) * INTER_SIZE / QK4_1; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) { + return sizeof(spacemit_kernels::nrow_block_q2_k<1>); + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) { + return sizeof(spacemit_kernels::nrow_block_q3_k<1>); + } else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) { + return sizeof(spacemit_kernels::nrow_block_mxfp4<1>); + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K>) { + return sizeof(spacemit_kernels::nrow_block_q5_1<1>); + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_0>) { + return sizeof(spacemit_kernels::nrow_block_q5_0<1>); + } else { + assert(false); + return 0; + } +} - const size_t k_blks = div_round_up(gemm_k, blk_len); +template <typename BLOC_TYPE> constexpr bool block_type_has_zp() { + if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0> || + std::is_same_v<BLOC_TYPE, block_q3_K> || std::is_same_v<BLOC_TYPE, block_q4_0> || + std::is_same_v<BLOC_TYPE, block_mxfp4> || std::is_same_v<BLOC_TYPE, block_q5_0>) { + return false; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_1> || std::is_same_v<BLOC_TYPE, block_q4_K> || + std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q5_1> || + std::is_same_v<BLOC_TYPE, block_q5_K>) { + return true; + } else { + assert(false); + return false; + } +} - const size_t lda = k_blks * q8_blk_size(blk_len); - const size_t ldc = gemm_args->ldc; - const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8); - const std::byte * quant_a_ptr = static_cast<const std::byte *>(per_gemm_ws) + m_start * lda; +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(ggml_tensor * t, const void * data, size_t data_size) = 0; +}; - const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0; - const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); - const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride; +template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base { + bool work_size(int /* n_threads */, const ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q3_K>) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } - float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start; + size = GGML_PAD(size, sizeof(int64_t)); - size_t count_n = 0; - const size_t compute_block_count_n = m_count == 1 ? n_count : 16; - for (size_t n = 0; n < n_count; n += count_n) { - count_n = std::min(n_count - n, compute_block_count_n); + return true; + } + case GGML_OP_MUL_MAT_ID: + { + int64_t src1_nelements = ggml_nelements(op->src[1]); + + if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K> || std::is_same_v<BLOC_TYPE, block_q3_K>) { + size = + spacemit_kernels::div_round_up(src1_nelements, QK_K) * spacemit_kernels::q8k_blk_size(QK_K); + } else if constexpr (INTER_SIZE == QK4_0) { + size = spacemit_kernels::div_round_up(src1_nelements, QK4_0) * + spacemit_kernels::q8_blk_size(QK4_0, true); + } else if constexpr (INTER_SIZE == 256) { + size = spacemit_kernels::div_round_up(src1_nelements, 256) * + spacemit_kernels::q8_hp_blk_size(256, true, true); + } else { + GGML_ABORT("unsupported block type"); + } - const std::byte * a_row = quant_a_ptr; - const std::byte * b_col = packed_quant_b_data + n * packed_b_stride; - const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr; - float * c_blk = c_ptr + n; + size = GGML_PAD(size, sizeof(int64_t)); - int32_t rows_remaining = m_count; + const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert + const int64_t ne12 = op->src[1]->ne[2]; // n_tokens - while (rows_remaining > 0) { - const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4( - blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr, - scale_stride); + const size_t sizeof_mmid_row_mapping = sizeof(int64_t); + size += sizeof_mmid_row_mapping * ne02 * (ne12 + 1) + (ne02 + 1) * sizeof(int64_t); - c_blk += rows_handled * ldc; - a_row += rows_handled * lda; + size = GGML_PAD(size, sizeof(int64_t)); - rows_remaining -= rows_handled; + return true; + } + default: + // GGML_ABORT("fatal error"); + break; } + return false; } -} -template <int K> constexpr int QK_0() { - if constexpr (K == 4) { - return QK4_0; - } - if constexpr (K == 8) { - return QK8_0; + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT"); + return false; + } + break; + case GGML_OP_MUL_MAT_ID: + switch (op->src[0]->type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + //case GGML_TYPE_MXFP4: + forward_mul_mat_id(params, op); + return true; + default: + // GGML_ABORT("fatal error: unsupported type for src0 in MUL_MAT_ID"); + return false; + } + break; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; } - return -1; -} -template <int K, int N> struct block { - ggml_half d[N]; // deltas for N qK_0 blocks - uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks -}; + void forward_mul_mat(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; -template <int K, int N> struct block_with_zp { - ggml_half d[N]; // deltas for N qK_1 blocks - uint8_t zp[N]; // zero points for N qK_1 blocks - uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_1 blocks -}; + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; -// control size -static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); -static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), - "wrong block_with_zp<4,16> size/padding"); -static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + GGML_TENSOR_BINARY_OP_LOCALS -using block_q4_0x16 = block<4, 16>; -using block_q4_1x16 = block_with_zp<4, 16>; -using block_q8_0x16 = block<8, 16>; + int ith = params->ith; + int nth = params->nth; -static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { - block_q4_0x16 out; - GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + [[maybe_unused]] const enum ggml_type type = src0->type; - for (int i = 0; i < 16; i++) { - out.d[i] = in[i].d; - } + void * w_data = (void *) src0->data; + const float * feature = (const float *) src1->data; + float * output = (float *) dst->data; - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + const int64_t gemm_m = ne11 * ne12 * ne13; + const int64_t gemm_k = ne10; + const int64_t gemm_n = ne01; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::quantize_a_row_def quantize_a_4row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + bool set_kernel_impl = false; + + int64_t block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + + if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> || + std::is_same_v<BLOC_TYPE, block_q4_K>) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + quantize_a_4row_i8 = spacemit_kernels::rvv::quantize_a_4row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K> || + std::is_same_v<BLOC_TYPE, block_q5_0>) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + set_kernel_impl = true; + } } - } +#endif - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_0 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + quantize_a_4row_i8 = spacemit_kernels::ime1::quantize_a_4row_i8; + + if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> || + std::is_same_v<BLOC_TYPE, block_q4_K>) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); } - } - return out; -} + const int64_t a_k_blks = spacemit_kernels::div_round_up(gemm_k, a_blk_len); + const int64_t b_k_blks = spacemit_kernels::div_round_up(gemm_k, b_blk_len); -static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { - block_q4_1x16 out; - GGML_ASSERT(QK4_1 / blck_size_interleave == 2); - - for (int i = 0; i < 16; i++) { - float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); - float mid = -std::nearbyintf(m / d); - mid = std::min(15.0f, std::max(0.0f, mid)); - out.d[i] = GGML_FP32_TO_FP16(d); - out.zp[i] = static_cast<uint8_t>(mid); - } + const int64_t row_stride_a = a_k_blks * block_stride_a; + const int64_t gemm_workspace_size = GGML_PAD(gemm_m * row_stride_a, alignof(int64_t)); - for (int i = 0; i < 16; i++) { - // [0, 15], in.d & 0x0F - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b0 b8] ......... [b7 b15] - out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + if (ith == 0 && params->wsize < gemm_workspace_size) { + GGML_ABORT("wsize less than gemm_workspace_size"); } - } - for (int i = 0; i < 16; i++) { - // [16, 31], in.d & 0xF0 - for (int j = 0; j < QK4_1 / 4; j++) { - //src [b0 b16] ......... [b8 b24] ......... [b15 b31] - //dst [b16 b24] ......... [b23 b31] - out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); - } - } + uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata); - return out; -} + void * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const int64_t tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; -static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_0); - GGML_ASSERT(interleave_block == 16); + auto * quant_a_buffer = reinterpret_cast<uint8_t *>(ws_ptr); - constexpr int nrows_interleaved = 16; + constexpr int64_t row_align = 4; + const int64_t row_blks = spacemit_kernels::div_round_up(gemm_m, row_align); - block_q4_0x16 * dst = (block_q4_0x16 *) t->data; - const block_q4_0 * src = (const block_q4_0 *) data; - block_q4_0 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_0; + const int64_t row_stride_b = b_k_blks * get_repacked_block_type_size<BLOC_TYPE, INTER_SIZE, NB_COLS>(); + const int64_t per_mb_rows_wsize = row_align * row_stride_a; + const int64_t per_nb_cols_wsize = NB_COLS * row_stride_b; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + const int64_t barrier_idx = static_cast<int64_t>(ith / 2); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { - return -1; - } + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; + if (gemm_m == 1) { + int task_per_thread = spacemit_kernels::div_round_up(a_k_blks, nth); + int a_blk_start = ith * task_per_thread; + int a_blk_end = std::min(a_blk_start + task_per_thread, (int) a_k_blks); + if (a_blk_start < a_blk_end) { + quantize_a_row_i8(a_blk_len, feature + a_blk_start * a_blk_len, (a_blk_end - a_blk_start) * a_blk_len, + quant_a_buffer + a_blk_start * block_stride_a); + } + } else { + int task_per_thread = spacemit_kernels::div_round_up(row_blks, nth); + int m_row_blk_start = ith * task_per_thread; + int m_row_blk_end = std::min(m_row_blk_start + task_per_thread, (int) row_blks); + for (int m_row_blk = m_row_blk_start; m_row_blk < m_row_blk_end; m_row_blk++) { + int m_idx = m_row_blk * row_align; + int rows_tobe_handled = (gemm_m - m_idx) > row_align ? row_align : (gemm_m - m_idx); + + if (rows_tobe_handled == row_align && quantize_a_4row_i8 != nullptr) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_4row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); + } else { + while (rows_tobe_handled) { + const float * a_row_ptr = feature + m_idx * gemm_k; + auto * quant_a_row_ptr = quant_a_buffer + m_idx * row_stride_a; + quantize_a_row_i8(a_blk_len, a_row_ptr, gemm_k, quant_a_row_ptr); + rows_tobe_handled -= 1; + m_idx += 1; + } + } } - *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); } - src += nrows_interleaved * nblocks; - } - return 0; - GGML_UNUSED(data_size); -} + ggml_barrier(params->threadpool); -static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_1); - GGML_ASSERT(interleave_block == 16); + const int64_t gemm_m_stride = gemm_n / gemm_m > 64 ? gemm_m : 16; + const int64_t gemm_m_blocked = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t max_gemm_n_stride = spacemit_kernels::div_round_up(gemm_n * gemm_m_blocked, nth); - constexpr int nrows_interleaved = 16; + int64_t gemm_n_stride = gemm_n; + if (max_gemm_n_stride < gemm_n) { + gemm_n_stride = + std::min(gemm_n_stride, spacemit_kernels::div_round_up(max_gemm_n_stride, NB_COLS) * NB_COLS); + } - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_1 * src = (const block_q4_1 *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK4_1; + if (gemm_n_stride == gemm_n && tcm_buffer != nullptr && per_mb_rows_wsize <= tcm_buffer_size) { + for (int64_t m_start = ith * row_align; m_start < gemm_m; m_start += row_align * nth) { + uint8_t * b_col = reinterpret_cast<uint8_t *>(w_data); + uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr; - GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + int64_t m_row_real = std::min(gemm_m - m_start, row_align); - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { - return -1; - } + spacemit_kernels::rvv::memcpy1d(tcm_buffer, quant_a_buffer + m_start * row_stride_a, + m_row_real * row_stride_a); - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int i = 0; i < nrows_interleaved; i++) { - dst_tmp[i] = src[x + i * nblocks]; + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < gemm_n; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(gemm_n - ni, (int64_t) NB_COLS); + + uint8_t * a_row_ptr = (uint8_t *) tcm_buffer; + float * c_blk = output + m_start * gemm_n + ni; + + int32_t rows_remaining = m_row_real; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_ptr, b_col, b_col_zp, c_blk, rows_remaining, + n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_ptr += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + } } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); - } - src += nrows_interleaved * nblocks; - } - return 0; + } else if (tcm_buffer != nullptr && per_nb_cols_wsize <= tcm_buffer_size) { + uint8_t * a_row = quant_a_buffer; + uint8_t * b_col = reinterpret_cast<uint8_t *>(tcm_buffer); + if ((gemm_workspace_size + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast<uint8_t *>(tcm_buffer) + gemm_workspace_size; + } + uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr; - GGML_UNUSED(data_size); -} + int64_t ni = ith * NB_COLS; + int64_t nb_real = std::min(gemm_n - ni, NB_COLS); -static inline void get_scale_min_k4(int j, - const uint8_t * GGML_RESTRICT q, - uint8_t * GGML_RESTRICT d, - uint8_t * GGML_RESTRICT m) { - if (j < 4) { - *d = q[j] & 63; - *m = q[j + 4] & 63; - } else { - *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); - *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); - } -} + if (ith % 2 == 0 && nb_real > 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + } -static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t, - int interleave_block, - const void * GGML_RESTRICT data, - size_t data_size) { - GGML_ASSERT(t->type == GGML_TYPE_Q4_K); - GGML_ASSERT(interleave_block == 16); - GGML_ASSERT(QK_K / QK4_1 == 8); + spine_barrier_wait(cur_barrier); - constexpr int nrows_interleaved = 16; + if (ith % 2 != 0 && nb_real > 0) { + if (a_row != quant_a_buffer) { + spacemit_kernels::rvv::memcpy1d(a_row, quant_a_buffer, gemm_workspace_size); + } + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + ni * row_stride_b, + nb_real * row_stride_b); + } - block_q4_1x16 * dst = (block_q4_1x16 *) t->data; - const block_q4_K * src = (const block_q4_K *) data; - block_q4_1 dst_tmp[16]; - int nrow = ggml_nrows(t); - int nblocks = t->ne[0] / QK_K; + for (; ni < gemm_n; ni += NB_COLS * nth) { + int64_t rows_remaining = gemm_m; + float * c_blk = output + ni; + auto * a_row_cur = a_row; - if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { - return -1; - } + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } - for (int b = 0; b < nrow; b += nrows_interleaved) { - for (int64_t x = 0; x < nblocks; x++) { - for (int j = 0; j < 8; j++) { - for (int i = 0; i < nrows_interleaved; i++) { - uint8_t sc, m; - const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); - const float min = - GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); - get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); - const float d1 = d * sc; - const float m1 = min * m; - - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); - dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); - // src -> [b0, b32] [b1, b33] ... [b31, b63] - // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] - const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; - if (j % 2 == 0) { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); - } - } else { - for (int ii = 0; ii < 16; ii++) { - dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); - } - } + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row_cur, b_col, b_col_zp, c_blk, rows_remaining, + nb_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row_cur += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS * nth; + if (next_ni < gemm_n) { + nb_real = std::min(gemm_n - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(w_data) + next_ni * row_stride_b, + nb_real * row_stride_b); } - *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); } - } - src += nrows_interleaved * nblocks; - } - return 0; + } else { + const int64_t task_count_m = spacemit_kernels::div_round_up(gemm_m, gemm_m_stride); + const int64_t task_count_n = spacemit_kernels::div_round_up(gemm_n, gemm_n_stride); - GGML_UNUSED(data_size); -} + int64_t task_count = task_count_m * task_count_n; + int64_t task_per_thread = (task_count + nth - 1) / nth; + int64_t start = ith * task_per_thread; + int64_t end = std::min((ith + 1) * task_per_thread, task_count); + for (int64_t compute_idx = start; compute_idx < end; compute_idx++) { + const auto tid_n = compute_idx / task_count_m; + const auto tid_m = compute_idx % task_count_m; -namespace ggml::cpu::riscv64_spacemit { + const int64_t m_start = tid_m * gemm_m_stride; + const int64_t m_count = std::min(gemm_m - m_start, (int64_t) gemm_m_stride); -template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> -int repack(struct ggml_tensor *, const void *, size_t); + const int64_t n_start = tid_n * gemm_n_stride; + const int64_t n_count = std::min(gemm_n - n_start, (int64_t) gemm_n_stride); -template <> int repack<block_q4_0, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); -} + const int64_t n_blk = m_count == 1 ? n_count : NB_COLS; -template <> int repack<block_q4_1, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); -} + uint8_t * b_col = reinterpret_cast<uint8_t *>(w_data) + n_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr; -template <> int repack<block_q4_K, 8, 16>(struct ggml_tensor * t, const void * data, size_t data_size) { - return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); -} + int64_t n_blk_real = 0; + for (int64_t ni = 0; ni < n_count; ni += n_blk_real, b_col += n_blk_real * row_stride_b) { + n_blk_real = std::min(n_count - ni, n_blk); -class tensor_traits_base : public ggml::cpu::tensor_traits { - public: - virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; -}; + uint8_t * a_row = quant_a_buffer + m_start * row_stride_a; -template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_traits : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; - size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); - return true; - default: - // GGML_ABORT("fatal error"); - break; - } - return false; - } + float * c_blk = output + m_start * gemm_n + n_start + ni; - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { - switch (op->op) { - case GGML_OP_MUL_MAT: - if (op->src[0]->type == GGML_TYPE_Q4_0 || // - op->src[0]->type == GGML_TYPE_Q4_1 || // - op->src[0]->type == GGML_TYPE_Q4_K) { - forward_mul_mat_q4(params, op); - return true; + int64_t rows_remaining = m_count; + + uint8_t * b_col_cur = b_col; + uint8_t * b_col_zp_cur = b_col_zp; + + while (rows_remaining > 0) { + auto rows_handled = gemm_kernel(b_blk_len, a_row, b_col_cur, b_col_zp_cur, c_blk, + rows_remaining, n_blk_real, b_k_blks, gemm_n); + + c_blk += rows_handled * gemm_n; + a_row += rows_handled * row_stride_a; + + rows_remaining -= rows_handled; + } } - default: - // GGML_ABORT("fatal error"); - break; + } } - return false; } - void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) { + void forward_mul_mat_id(ggml_compute_params * params, ggml_tensor * op) { + constexpr size_t a_blk_len = INTER_SIZE; + constexpr size_t b_blk_len = INTER_SIZE; + const ggml_tensor * src0 = op->src[0]; const ggml_tensor * src1 = op->src[1]; + const ggml_tensor * ids = op->src[2]; ggml_tensor * dst = op; GGML_TENSOR_BINARY_OP_LOCALS @@ -429,133 +554,381 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_ int ith = params->ith; int nth = params->nth; - [[maybe_unused]] const enum ggml_type type = src0->type; + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_expert + + struct mmid_row_mapping { + int32_t i1; + int32_t i2; + }; + + spacemit_kernels::quantize_a_row_def quantize_a_row_i8; + spacemit_kernels::gemm_kernel_quantize_def gemm_kernel; + spacemit_kernels::moe_gemm_kernel_quantize_def moe_gemm_kernel_m2; + bool set_kernel_impl = false; + size_t block_stride_a = spacemit_kernels::q8_blk_size(QK4_0); + +#if defined(RISCV64_SPACEMIT_IME2) + if (!set_kernel_impl && (global_spine_env_info.use_ime2)) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(QK4_0, true); + + if constexpr (std::is_same_v<BLOC_TYPE, block_q6_K> || std::is_same_v<BLOC_TYPE, block_q8_0>) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i8; + set_kernel_impl = true; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> || + std::is_same_v<BLOC_TYPE, block_q4_K>) { + if constexpr (INTER_SIZE == 256) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4_hp; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8_hp; + block_stride_a = spacemit_kernels::q8_hp_blk_size(a_blk_len, true, true); + set_kernel_impl = true; + } else { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i4; + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8; + block_stride_a = spacemit_kernels::q8_blk_size(a_blk_len, true); + set_kernel_impl = true; + } + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q2_K>) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i2k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q3_K>) { + quantize_a_row_i8 = spacemit_kernels::rvv::quantize_a_row_i8k; + block_stride_a = spacemit_kernels::q8k_blk_size(a_blk_len); + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i3k; + set_kernel_impl = true; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_mxfp4>) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8mxfp4; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8mxfp4; + set_kernel_impl = true; + } else if constexpr (std::is_same_v<BLOC_TYPE, block_q5_1> || std::is_same_v<BLOC_TYPE, block_q5_K> || + std::is_same_v<BLOC_TYPE, block_q5_0>) { + gemm_kernel = spacemit_kernels::ime2::gemm_kernel_i8i5; + moe_gemm_kernel_m2 = spacemit_kernels::ime2::moe_m2_gemm_kernel_i8i5; + set_kernel_impl = true; + } + } +#endif - void * w_data = (void *) src0->data; - const float * feature = (const float *) src1->data; - float * output = (float *) dst->data; +#if defined(RISCV64_SPACEMIT_IME1) + if (!set_kernel_impl && (global_spine_env_info.use_ime1)) { + quantize_a_row_i8 = spacemit_kernels::ime1::quantize_a_row_i8; + + if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0> || std::is_same_v<BLOC_TYPE, block_q4_1> || + std::is_same_v<BLOC_TYPE, block_q4_K>) { + gemm_kernel = spacemit_kernels::ime1::gemm_kernel_i8i4; + set_kernel_impl = true; + } + } +#endif + if (!set_kernel_impl) { + GGML_ABORT("no kernel implementation found for the block type"); + } - const size_t batch_feature = ne12 * ne13; - [[maybe_unused]] const size_t batch_weight = ne02 * ne03; - const size_t gemm_m = ne11; - const size_t gemm_k = ne10; - const size_t gemm_n = ne01; + const size_t a_k_blks = spacemit_kernels::div_round_up(ne10, a_blk_len); + const size_t b_k_blks = spacemit_kernels::div_round_up(ne10, b_blk_len); - GGML_ASSERT(batch_weight == 1); + const size_t nbw1 = a_k_blks * block_stride_a; + const size_t nbw2 = ne11 * nbw1; + const size_t nbw3 = nbw2 * ne12; + const size_t gemm_workspace_size = GGML_PAD(nbw3, alignof(int64_t)); - const size_t block_count_k = div_round_up(gemm_k, QK4_0); - const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0); - const size_t per_gemm_workspace_stride = - div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t); - const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride; - const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1; + const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata); + auto * quant_a_buffer = reinterpret_cast<uint8_t *>(ws_ptr); - if (ith == 0 && params->wsize < desired_wsize) { - throw std::runtime_error("wsize less than desired_wsize"); + if (ne11 == 1) { + for (int64_t ii = ith; ii < ne12 * a_k_blks; ii += nth) { + int64_t i12 = ii / a_k_blks; + int64_t ak_blk_id = ii % a_k_blks; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12) + ak_blk_id * a_blk_len, + a_blk_len, quant_a_buffer + i12 * nbw2 + ak_blk_id * block_stride_a); + } + } else { + for (int64_t ii = ith; ii < ne12 * ne11; ii += nth) { + int64_t i12 = ii / ne11; + int64_t i11 = ii % ne11; + quantize_a_row_i8(a_blk_len, (float *) ((char *) src1->data + i12 * nb12 + i11 * nb11), ne10, + quant_a_buffer + i12 * nbw2 + i11 * nbw1); + } } - std::vector<qnbitgemm_spacemit_ime_args> qnbitgemm_args(batch_feature); +#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) *ne12 + (i1)] - for (size_t i = 0; i < batch_feature; i++) { - qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i; - qnbitgemm_args[i].lda = gemm_k; - qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data; - qnbitgemm_args[i].quant_b_scale = nullptr; + int64_t * matrix_row_counts = (int64_t *) (ws_ptr + gemm_workspace_size); + int32_t * valid_ep_count = (int32_t *) (matrix_row_counts + n_as); + int32_t * valid_act_count = (int32_t *) (valid_ep_count + 1); + int64_t * valid_matrix_row_counts = (int64_t *) (valid_act_count + 1); + mmid_row_mapping * matrix_rows = (mmid_row_mapping *) (valid_matrix_row_counts + n_as); - if constexpr (std::is_same_v<BLOC_TYPE, block_q4_0>) { - qnbitgemm_args[i].quant_b_zp = nullptr; - } else { - qnbitgemm_args[i].quant_b_zp = w_data; + if (ith == 0) { + // initialize matrix_row_counts + memset(matrix_row_counts, 0, n_as * sizeof(int64_t)); + + // group rows by src0 matrix + for (int32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int32_t id = 0; id < n_ids; ++id) { + const int32_t i02 = + *(const int32_t *) ((const char *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = { id, iid1 }; + matrix_row_counts[i02] += 1; + } } - qnbitgemm_args[i].bias = nullptr; - qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i; - qnbitgemm_args[i].ldc = gemm_n; + int32_t valid_ep_count_t = 0; + int32_t valid_act_count_t = 0; + for (int cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) { + continue; + } + valid_matrix_row_counts[valid_ep_count_t] = cur_a; + valid_act_count_t += cne1; + valid_ep_count_t += 1; + } + valid_ep_count[0] = valid_ep_count_t; + valid_act_count[0] = valid_act_count_t; } - const uintptr_t ws_ptr = reinterpret_cast<uintptr_t>(params->wdata); - void * ws = reinterpret_cast<void *>((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1))); - const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0); + const int64_t barrier_idx = static_cast<int64_t>(ith / 2); - { - constexpr size_t block_size_m = 4; - size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m); - int32_t task_count = batch_feature * per_gemm_block_count_m; - int32_t task_per_thread = (task_count + nth - 1) / nth; - int32_t start = ith * task_per_thread; - int32_t end = std::min((ith + 1) * task_per_thread, task_count); - for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { - int32_t gemm_idx = compute_idx / per_gemm_block_count_m; - int32_t block_idx_in_gemm = compute_idx % per_gemm_block_count_m; - int32_t m_idx = block_idx_in_gemm * block_size_m; - const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; - int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx); - - if (rows_tobe_handled == block_size_m) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = - static_cast<std::byte *>(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); - } else { - while (rows_tobe_handled) { - const float * a_row_ptr = data.a_ptr + m_idx * data.lda; - std::byte * quant_a_row_ptr = static_cast<std::byte *>(ws) + - gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; - sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); - rows_tobe_handled -= 1; - m_idx += 1; + GGML_ASSERT(global_spine_env_info.init_barrier != nullptr); + GGML_ASSERT(barrier_idx < spine_init_barrier_count); + spine_barrier_t * cur_barrier = &global_spine_env_info.init_barrier[barrier_idx]; + + ggml_barrier(params->threadpool); + + const size_t row_stride_b = b_k_blks * get_repacked_block_type_size<BLOC_TYPE, INTER_SIZE, NB_COLS>(); + const size_t expert_b_stride = ne01 * row_stride_b; + const size_t per_nb_cols_wsize = NB_COLS * row_stride_b; + + std::array<const uint8_t *, 2> src_workspaces; + std::array<float *, 2> dst_workspaces; + + auto * tcm_buffer = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer; + const auto tcm_buffer_size = ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size; + + const auto valid_ep_count_t = valid_ep_count[0]; + const auto valid_act_count_t = valid_act_count[0]; + + int nth_es = 1; + int nth_n = nth; + + int ith_es = ith % nth_es; + int ith_n = (ith / nth_es) % nth_n; + + if (valid_ep_count_t % nth == 0 && tcm_buffer != nullptr && valid_ep_count_t == n_as && + valid_act_count_t == n_as && per_nb_cols_wsize <= tcm_buffer_size) { + for (int64_t valid_id = ith; valid_id < valid_ep_count_t; valid_id += nth) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride; + + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, 0); + const int id = row_mapping.i1; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; + const int64_t i1 = id; + const int64_t i2 = i12; + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + float * c_blk = (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)); + + uint8_t * a_row = src1_col; + uint8_t * b_col = reinterpret_cast<uint8_t *>(tcm_buffer); + if ((nbw1 + per_nb_cols_wsize) <= tcm_buffer_size) { + a_row = (uint8_t *) tcm_buffer; + b_col = reinterpret_cast<uint8_t *>(tcm_buffer) + nbw1; + } + uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? b_col : nullptr; + + if (ith % 2 == 0) { + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(src0_cur), per_nb_cols_wsize); + + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + } + + spine_barrier_wait(cur_barrier); + + if (ith % 2 != 0) { + if (a_row != src1_col) { + spacemit_kernels::rvv::memcpy1d(a_row, src1_col, nbw1); + } + + spacemit_kernels::rvv::memcpy1d(b_col, reinterpret_cast<uint8_t *>(src0_cur), per_nb_cols_wsize); + } + + int64_t nb_real = std::min(ne01, NB_COLS); + for (int64_t ni = 0; ni < ne01; ni += NB_COLS) { + if (ith % 2 != 0) { + spine_barrier_wait(cur_barrier); + } + + gemm_kernel(b_blk_len, a_row, b_col, b_col_zp, c_blk + ni, 1, nb_real, b_k_blks, ne01); + + if (ith % 2 == 0) { + spine_barrier_wait(cur_barrier); + } + + const int64_t next_ni = ni + NB_COLS; + if (next_ni < ne01) { + nb_real = std::min(ne01 - next_ni, NB_COLS); + spacemit_kernels::rvv::memcpy1d( + b_col, reinterpret_cast<uint8_t *>(src0_cur) + next_ni * row_stride_b, per_nb_cols_wsize); } } } - } + } else { + for (int64_t valid_id = ith_es; valid_id < valid_ep_count_t; valid_id += nth_es) { + const int64_t cur_a = valid_matrix_row_counts[valid_id]; + const int64_t cne1 = matrix_row_counts[cur_a]; - ggml_barrier(params->threadpool); + int64_t src1_cur_start = 0; + int64_t src1_cur_end = cne1; - if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) { - return; - } - nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores }); - - size_t threads_per_gemm = nth / batch_feature; - constexpr size_t gemm_m_stride = 128; - size_t nc = gemm_n; - const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride); - const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm); - if (max_nc < nc) { - nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN); - } - const size_t gemm_n_stride = nc; - const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride); - const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride); - threads_per_gemm = thread_count_m * thread_count_n; + int64_t src0_cur_start = (ith_n * ne01) / nth_n; + int64_t src0_cur_end = MIN(((ith_n + 1) * ne01) / nth_n, ne01); - { - int task_count = batch_feature * threads_per_gemm; - int task_per_thread = (task_count + nth - 1) / nth; - int start = ith * task_per_thread; - int end = std::min((ith + 1) * task_per_thread, task_count); - for (int compute_idx = start; compute_idx < end; compute_idx++) { - const auto gemm_i = compute_idx / threads_per_gemm; - const auto blk_i = compute_idx % threads_per_gemm; - const auto * data = &qnbitgemm_args[gemm_i]; + if (src1_cur_start >= src1_cur_end || src0_cur_start >= src0_cur_end) { + continue; + } + + src0_cur_start = + (src0_cur_start % NB_COLS) ? src0_cur_start + NB_COLS - (src0_cur_start % NB_COLS) : src0_cur_start; + src0_cur_end = + (src0_cur_end % NB_COLS) ? src0_cur_end + NB_COLS - (src0_cur_end % NB_COLS) : src0_cur_end; + + auto * src0_cur = (uint8_t *) src0->data + cur_a * expert_b_stride + src0_cur_start * row_stride_b; + uint8_t * b_col_zp = block_type_has_zp<BLOC_TYPE>() ? src0_cur : nullptr; + + size_t extra_tcm_buffer_size = tcm_buffer_size; + void * extra_tcm_buffer = tcm_buffer; + if (tcm_buffer != nullptr && (src1_cur_end - src1_cur_start) >= 4 && + (src0_cur_end - src0_cur_start) * row_stride_b <= tcm_buffer_size) { + spacemit_kernels::rvv::memcpy1d(tcm_buffer, src0_cur, + (src0_cur_end - src0_cur_start) * row_stride_b); + src0_cur = reinterpret_cast<uint8_t *>(tcm_buffer); + b_col_zp = block_type_has_zp<BLOC_TYPE>() ? src0_cur : nullptr; + extra_tcm_buffer_size -= (src0_cur_end - src0_cur_start) * row_stride_b; + extra_tcm_buffer = reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(tcm_buffer) + + (src0_cur_end - src0_cur_start) * row_stride_b); + } - const auto tid_n = blk_i / thread_count_m; - const auto tid_m = blk_i % thread_count_m; + int ir1 = src1_cur_start; - const size_t m_start = tid_m * gemm_m_stride; - const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride); + if (extra_tcm_buffer_size >= nbw1 && extra_tcm_buffer != nullptr) { + int64_t quant_a_tile_size = extra_tcm_buffer_size / nbw1; + do { + quant_a_tile_size = MIN(quant_a_tile_size, src1_cur_end - ir1); - const size_t n_start = tid_n * gemm_n_stride; - const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride); + uint8_t * quant_a_tile_buffer = reinterpret_cast<uint8_t *>(extra_tcm_buffer); - void * per_gemm_ws = reinterpret_cast<std::byte *>(ws) + gemm_i * per_gemm_workspace_stride; + int iir1 = ir1; + for (; iir1 < (ir1 + quant_a_tile_size); ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, iir1); - sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count); + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + spacemit_kernels::rvv::memcpy1d(quant_a_tile_buffer, src1_col, nbw1); + quant_a_tile_buffer = quant_a_tile_buffer + nbw1; + } + + quant_a_tile_buffer = reinterpret_cast<uint8_t *>(extra_tcm_buffer); + iir1 = ir1; + + if (moe_gemm_kernel_m2 != nullptr) { + for (; iir1 < (ir1 + quant_a_tile_size - 1); iir1 += 2, quant_a_tile_buffer += 2 * nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + mmid_row_mapping row_mapping_1 = MMID_MATRIX_ROW(cur_a, iir1 + 1); + + src_workspaces[0] = quant_a_tile_buffer; + src_workspaces[1] = quant_a_tile_buffer + nbw1; + + dst_workspaces[0] = + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start; + dst_workspaces[1] = (float *) ((char *) dst->data + + ((row_mapping_1.i1) * nb1 + (row_mapping_1.i2) * nb2)) + + src0_cur_start; + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, + ne01); + } + } + + for (; iir1 < (ir1 + quant_a_tile_size); iir1++, quant_a_tile_buffer += nbw1) { + mmid_row_mapping row_mapping_0 = MMID_MATRIX_ROW(cur_a, iir1); + + gemm_kernel( + b_blk_len, quant_a_tile_buffer, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (row_mapping_0.i1 * nb1 + row_mapping_0.i2 * nb2)) + + src0_cur_start, + 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + + ir1 += quant_a_tile_size; + } while (ir1 < src1_cur_end); + } else { + if (moe_gemm_kernel_m2 != nullptr) { + for (; ir1 < src1_cur_end - 1; ir1 += 2) { + for (int iir1 = 0; iir1 < 2; ++iir1) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1 + iir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + src_workspaces[iir1] = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + dst_workspaces[iir1] = + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start; + } + + moe_gemm_kernel_m2(b_blk_len, src_workspaces.data(), src0_cur, b_col_zp, + dst_workspaces.data(), 1, src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } + + for (; ir1 < src1_cur_end; ir1++) { + mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, ir1); + + const int id = row_mapping.i1; // selected expert index + + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; // row index in src1 + + const int64_t i1 = id; // selected expert index + const int64_t i2 = i12; // row + + auto * src1_col = quant_a_buffer + (i11 * nbw1 + i12 * nbw2); + + gemm_kernel(b_blk_len, src1_col, src0_cur, b_col_zp, + (float *) ((char *) dst->data + (i1 * nb1 + i2 * nb2)) + src0_cur_start, 1, + src0_cur_end - src0_cur_start, b_k_blks, ne01); + } + } } } +#undef MMID_MATRIX_ROW } - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + int repack(ggml_tensor * t, const void * data, size_t data_size) override { GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), (int) NB_COLS, (int) INTER_SIZE); return ggml::cpu::riscv64_spacemit::repack<BLOC_TYPE, INTER_SIZE, NB_COLS>(t, data, data_size); @@ -563,309 +936,464 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> class tensor_ }; class tensor_traits_common : public tensor_traits_base { - bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + bool work_size(int n_threads, const ggml_tensor * op, size_t & size) override { switch (op->op) { - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - size = 0; + case GGML_OP_FLASH_ATTN_EXT: + { + const int n_tasks = n_threads; + const int64_t neq2 = op->src[0]->ne[2]; // number of query heads + const int64_t DK = op->src[1]->ne[0]; + const int64_t DV = op->src[2]->ne[0]; // DV + + // Tiled flash attention scratch (tile sizes defined in common.h) + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding + size_t prefill = sizeof(float) * + (GGML_FA_TILE_Q * DK + 2 * GGML_FA_TILE_Q * GGML_FA_TILE_KV + GGML_FA_TILE_Q * DV + + GGML_FA_TILE_KV * DV + GGML_FA_TILE_KV * DK) * + n_tasks; + + // Decode path: n_kv_chunks = n_tasks (one chunk per thread) + // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ + size_t n_chunks = n_tasks; + size_t decode = sizeof(float) * (neq2 * n_chunks * (2 + DV) + n_tasks * (DK + 2 * DV)); + + size = MAX(prefill, decode); + } return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + bool compute_forward(ggml_compute_params * params, ggml_tensor * op) override { switch (op->op) { case GGML_OP_NORM: - forward_norm_f32(params, op); - return true; + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } case GGML_OP_RMS_NORM: - forward_rms_norm_f32(params, op); + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_rms_norm_f32(params, op); + return true; + default: + GGML_ABORT("fatal error"); + } + case GGML_OP_ADD: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary<GGML_OP_ADD, float>(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary<GGML_OP_ADD, _Float16>(params, op); + return true; + default: + ggml_compute_forward_add(params, op); + return true; + } + case GGML_OP_SUB: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary<GGML_OP_SUB, float>(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary<GGML_OP_SUB, _Float16>(params, op); + return true; + default: + ggml_compute_forward_sub(params, op); + return true; + } + case GGML_OP_MUL: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary<GGML_OP_MUL, float>(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary<GGML_OP_MUL, _Float16>(params, op); + return true; + default: + ggml_compute_forward_mul(params, op); + return true; + } + case GGML_OP_DIV: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_binary<GGML_OP_DIV, float>(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_binary<GGML_OP_DIV, _Float16>(params, op); + return true; + default: + ggml_compute_forward_div(params, op); + return true; + } + case GGML_OP_FLASH_ATTN_EXT: + forward_flash_attn_ext_f16(params, op); + return true; + case GGML_OP_CONT: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] != src0->nb[0] && op->nb[0] == src0->nb[1] && + op->ne[3] * op->ne[2] * op->nb[2] == src0->ne[3] * src0->ne[2] * src0->nb[2]) { + spacemit_kernels::rvv::forward_cont_with_permute(params, op); + } else { + ggml_compute_forward_cont(params, op); + } + return true; + } + case GGML_OP_CPY: + { + const ggml_tensor * src0 = op->src[0]; + if (op->type == src0->type && op->nb[0] == src0->nb[1] && src0->nb[0] != src0->nb[1] && + ggml_nelements(src0) == ggml_nelements(op)) { + spacemit_kernels::rvv::forward_cpy_with_permute(params, op); + } else { + ggml_compute_forward_cpy(params, op); + } + return true; + } + case GGML_OP_REPEAT: + { + const bool rows_equal = ggml_nrows(op->src[0]) == ggml_nrows(op); + const bool broadcast_or_equal = op->src[0]->ne[0] == 1 || op->src[0]->ne[0] == op->ne[0]; + + if (rows_equal && broadcast_or_equal) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_nrows<int32_t>(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_nrows<int16_t>(params, op); + return true; + default: + break; + } + } + + if (op->src[0]->ne[1] == 1 && op->src[0]->ne[0] == op->ne[0]) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_repeat_dim1<int32_t>(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_repeat_dim1<int16_t>(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_repeat(params, op); + } + return true; + case GGML_OP_SUM_ROWS: + { + if (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) { + spacemit_kernels::rvv::forward_sum_rows<float>(params, op); + } else { + ggml_compute_forward_sum_rows(params, op); + } + } + return true; + case GGML_OP_GET_ROWS: + { + if (op->src[0]->type == op->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_get_rows<int32_t>(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_get_rows<int16_t>(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_get_rows(params, op); + } return true; + case GGML_OP_CONCAT: + { + const int32_t dim = ggml_get_op_params_i32(op, 0); + if (dim == 0 && op->type == op->src[0]->type) { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + spacemit_kernels::rvv::forward_concat<int32_t>(params, op); + return true; + case GGML_TYPE_F16: + spacemit_kernels::rvv::forward_concat<int16_t>(params, op); + return true; + default: + break; + } + } + + ggml_compute_forward_concat(params, op); + } + return true; + // TODO For GGML_OP_GATED_DELTA_NET + // case GGML_OP_GATED_DELTA_NET: + // return true; default: - // GGML_ABORT("fatal error"); break; } return false; } - void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); + void forward_flash_attn_ext_f16(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + + const bool supported_prec = (dst->op_params[3] == GGML_PREC_F32 || dst->op_params[3] == GGML_PREC_DEFAULT); + const bool supported_types = (q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && v->type == GGML_TYPE_F16); + const bool supported_shape = (DK > 0 && DK <= 128 && DV > 0 && DV <= 128); + const bool supported_vlen = (__riscv_vlenb() == 128); + + if (!(supported_prec && supported_types && supported_shape && supported_vlen)) { + ggml_compute_forward_flash_attn_ext(params, dst); + return; + } + + // total rows in q + const int64_t nr = neq1 * neq2 * neq3; + // rows per thread const int ith = params->ith; const int nth = params->nth; - GGML_TENSOR_UNARY_OP_LOCALS + static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; + const bool use_tiled = !params->use_ref && (neq1 >= Q_TILE_SZ); - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); + // 4x chunks per thread + // int nth_scaled = nth * 4; + // int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled; + // int64_t nchunk = (nr + chunk_size - 1) / chunk_size; - GGML_ASSERT(epsilon > 0.0f); + // if (nth == 1 || nchunk < nth) { + // nchunk = nth; + // } - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; + int64_t nchunk = nth; - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + if (ith == 0) { + // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. + ggml_threadpool_chunk_set(params->threadpool, nth); + } - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast<float *>(input + offset); + ggml_barrier(params->threadpool); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - size_t gvl = __riscv_vsetvlmax_e32m4(); - vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + // The number of elements in each chunk + const int64_t dr = (nr + nchunk - 1) / nchunk; - sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + // The first chunk comes from our thread_id, the rest will get auto-assigned. + int current_chunk = ith; - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + while (current_chunk < nchunk) { + const int64_t ir0 = dr * current_chunk; + const int64_t ir1 = MIN(ir0 + dr, nr); - p_input += gvl; - p_temp_output += gvl; - length -= gvl; + if (use_tiled) { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_tiled_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); + } else { + spacemit_kernels::rvv::forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16( + params, dst, ir0, ir1, ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer, + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size); } - gvl = __riscv_vsetvlmax_e32m1(); - - float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - vfloat32m1_t mean_v = - __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); - mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); - mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); - mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); - mean /= hidden_size; - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; - mean_square = sqrt(mean_square - mean * mean + epsilon); - - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; - - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; - } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; - } - } + current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); } } - void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { - const ggml_tensor * src0 = op->src[0]; - ggml_tensor * dst = op; - GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - - const int ith = params->ith; - const int nth = params->nth; - - GGML_TENSOR_UNARY_OP_LOCALS - - float epsilon; - memcpy(&epsilon, dst->op_params, sizeof(float)); - - GGML_ASSERT(epsilon > 0.0f); - - auto * input = (float *) src0->data; - auto * output = (float *) dst->data; - - const auto hidden_size = ne00; - const auto task_count = ne01 * ne02 * ne03; - const auto task_per_thread = (task_count + nth - 1) / nth; - - const auto task_begin = ith * task_per_thread; - const auto task_end = std::min((ith + 1) * task_per_thread, task_count); - - for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { - auto offset = task_idx * hidden_size; - auto * p_input = const_cast<float *>(input + offset); - auto * p_output = output + offset; - auto * p_temp_output = p_output; - auto * p_gamma_data = (const float *) nullptr; - auto * p_beta_data = (const float *) nullptr; - - size_t gvl = __riscv_vsetvlmax_e32m4(); - // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); - vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); - int64_t length = hidden_size; - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - // load data - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + int repack(ggml_tensor * t, const void * data, size_t data_size) override { + memcpy(t->data, data, data_size); + return 0; + } +}; - sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); +// Impl By IME1 +static const tensor_traits<block_q4_0, 32, 16> q4_0_16x32_q8_0; +static const tensor_traits<block_q4_1, 32, 16> q4_1_16x32_q8_0; +static const tensor_traits<block_q4_K, 32, 16> q4_k_16x32_q8_0; +// Impl By IME2 +static const tensor_traits<block_q2_K, 256, 32> q2_k_32x256_q8_0; +static const tensor_traits<block_q3_K, 256, 32> q3_k_32x256_q8_0; +static const tensor_traits<block_q4_0, 32, 32> q4_0_32x32_q8_0; +static const tensor_traits<block_q4_1, 32, 32> q4_1_32x32_q8_0; +static const tensor_traits<block_q4_0, 256, 32> q4_0_32x256_q8_0; +static const tensor_traits<block_q4_1, 256, 32> q4_1_32x256_q8_0; +static const tensor_traits<block_q4_K, 32, 32> q4_k_32x32_q8_0; +static const tensor_traits<block_q6_K, 32, 32> q6_k_32x32_q8_0; +static const tensor_traits<block_q8_0, 32, 32> q8_0_32x32_q8_0; +static const tensor_traits<block_mxfp4, 32, 32> mxfp4_32x32_q8_0; +static const tensor_traits<block_q5_K, 32, 32> q5_k_32x32_q8_0; +static const tensor_traits<block_q5_1, 32, 32> q5_1_32x32_q8_0; +static const tensor_traits<block_q5_0, 32, 32> q5_0_32x32_q8_0; +// Impl By RVV +static const tensor_traits_common rvv_impl; - __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); +} // namespace ggml::cpu::riscv64_spacemit - p_input += gvl; - p_temp_output += gvl; - length -= gvl; +static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const ggml_tensor * cur) { + switch (cur->type) { + case GGML_TYPE_Q2_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q2_k_32x256_q8_0; + } +#endif } + break; + case GGML_TYPE_Q3_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q3_k_32x256_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x256_q8_0; + } - gvl = __riscv_vsetvlmax_e32m1(); - - // float mean = 0.f; - vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); - - vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), - __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); - mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); - mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); - - float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); - mean_square /= hidden_size; + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_0_32x32_q8_0; + } +#endif - mean_square = sqrt(mean_square + epsilon); +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && cur->ne[0] % 256 == 0 && + // (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::q4_1_32x256_q8_0; + // } + + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_1_32x32_q8_0; + } +#endif - mean_square = 1.0f / mean_square; - length = hidden_size; - p_temp_output = p_output; +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_1_16x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q4_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q4_k_32x32_q8_0; + } +#endif - if (p_gamma_data == nullptr && p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - length -= gvl; +#if defined(RISCV64_SPACEMIT_IME1) + if (cur->ne[1] % 16 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime1)) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x32_q8_0; } - } else if (p_beta_data == nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; +#endif + } + break; + case GGML_TYPE_Q6_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q6_k_32x32_q8_0; } - } else if (p_gamma_data != nullptr) { - while (length > 0) { - gvl = __riscv_vsetvl_e32m4(length); - vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); - vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); - src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); - src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); - vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); - src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); - p_beta_data += gvl; - __riscv_vse32_v_f32m4(p_output, src_data, gvl); - p_temp_output += gvl; - p_output += gvl; - p_gamma_data += gvl; - length -= gvl; +#endif + } + break; + case GGML_TYPE_Q8_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if ((ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q8_0_32x32_q8_0; } +#endif } - } - } - - int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { - memcpy(t->data, data, data_size); - return 0; - } -}; - -static const tensor_traits<block_q4_0, 8, 16> q4_0_16x8_q8_0; -static const tensor_traits<block_q4_1, 8, 16> q4_1_16x8_q8_0; -static const tensor_traits<block_q4_K, 8, 16> q4_k_16x8_q8_0; -static const tensor_traits_common rvv_impl; - -} // namespace ggml::cpu::riscv64_spacemit - -static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) { - if (cur->type == GGML_TYPE_Q4_0) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_1) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_Q4_K) { - if (cur->ne[1] % 16 == 0) { - return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; - } - } else if (cur->type == GGML_TYPE_F32) { - return &ggml::cpu::riscv64_spacemit::rvv_impl; + break; + case GGML_TYPE_MXFP4: + { +#if defined(RISCV64_SPACEMIT_IME2) + // TODO + // if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + // return &ggml::cpu::riscv64_spacemit::mxfp4_32x32_q8_0; + // } +#endif + } + break; + case GGML_TYPE_Q5_K: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_k_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_1: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_1_32x32_q8_0; + } +#endif + } + break; + case GGML_TYPE_Q5_0: + { +#if defined(RISCV64_SPACEMIT_IME2) + if (cur->ne[1] % 32 == 0 && (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2)) { + return &ggml::cpu::riscv64_spacemit::q5_0_32x32_q8_0; + } +#endif + } + break; + default: + break; } return nullptr; } static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor) { + ggml_tensor * tensor) { tensor->extra = (void *) const_cast<ggml::cpu::tensor_traits *>(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); @@ -874,8 +1402,46 @@ static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_ba return GGML_STATUS_SUCCESS; } +static void ggml_backend_riscv64_spacemit_buffer_free_buffer(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + if (base == nullptr) { + return; + } + + ggml::cpu::riscv64_spacemit::spine_mem_pool_free(base); +} + +static void * ggml_backend_riscv64_spacemit_buffer_get_base(ggml_backend_buffer_t buffer) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + return base; +} + +static void ggml_backend_riscv64_spacemit_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { + GGML_ASSERT(tensor); + memset((char *) tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_riscv64_spacemit_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + GGML_ASSERT(buffer); + + void * base = buffer->context; + GGML_ASSERT(base != nullptr); + memset(base, value, buffer->size); +} + static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, - struct ggml_tensor * tensor, + ggml_tensor * tensor, const void * data, size_t offset, size_t size) { @@ -891,6 +1457,20 @@ static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_ GGML_UNUSED(buffer); } +static const ggml_backend_buffer_i ggml_backend_riscv64_spacemit_buffer_i = { + /* .free_buffer = */ ggml_backend_riscv64_spacemit_buffer_free_buffer, + /* .get_base = */ ggml_backend_riscv64_spacemit_buffer_get_base, + /* .init_tensor = */ ggml_backend_riscv64_spacemit_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_riscv64_spacemit_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_riscv64_spacemit_buffer_set_tensor, + /* .get_tensor = */ nullptr, + /* .set_tensor_2d = */ nullptr, + /* .get_tensor_2d = */ nullptr, + /* .cpy_tensor = */ nullptr, + /* .clear = */ ggml_backend_riscv64_spacemit_buffer_clear, + /* .reset = */ nullptr, +}; + static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { return "CPU_RISCV64_SPACEMIT"; @@ -899,18 +1479,12 @@ static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_ static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); - - if (buffer == nullptr) { + void * base = ggml::cpu::riscv64_spacemit::spine_mem_pool_alloc(size, 64); + if (base == nullptr) { return nullptr; } - buffer->buft = buft; - buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; - buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; - buffer->iface.get_tensor = nullptr; - buffer->iface.cpy_tensor = nullptr; - return buffer; + return ggml_backend_buffer_init(buft, ggml_backend_riscv64_spacemit_buffer_i, base, size); } static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { @@ -919,44 +1493,91 @@ static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_b GGML_UNUSED(buft); } -static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, - const struct ggml_tensor * tensor) { +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { for (int i = 0; i < GGML_MAX_DIMS; ++i) { if (tensor->ne[i] <= 0) { return 0; } } - size_t nbytes; + GGML_UNUSED(buft); + + const auto plain_nbytes = [&]() { + size_t total = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * tensor->nb[i]; + } + return total; + }; + const size_t blck_size = ggml_blck_size(tensor->type); if (blck_size == 1) { - nbytes = ggml_type_size(tensor->type); - for (int i = 0; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + return plain_nbytes(); + } + + const size_t row_nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + + const auto add_strided_nbytes = [&](size_t total, size_t src_block_size, size_t dst_block_size) { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + total += (tensor->ne[i] - 1) * (tensor->nb[i] / src_block_size) * dst_block_size; } - } else { - nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; - if (tensor->type == GGML_TYPE_Q4_K) { - GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); - nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; - } - } else { - for (int i = 1; i < GGML_MAX_DIMS; ++i) { - nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; - } + return total; + }; + + const auto remap_block_nbytes = [&](size_t src_block_size, size_t dst_block_size, int64_t padded_rows = 0) { + GGML_ASSERT(row_nbytes % src_block_size == 0); + + size_t total = + add_strided_nbytes((row_nbytes / src_block_size) * dst_block_size, src_block_size, dst_block_size); + + if (padded_rows > 0 && tensor->ne[1] % padded_rows != 0) { + total += (padded_rows - tensor->ne[1] % padded_rows) * (tensor->nb[1] / src_block_size) * dst_block_size; } + + return total; + }; + + size_t nbytes = row_nbytes; + switch (tensor->type) { + case GGML_TYPE_Q4_K: + nbytes = remap_block_nbytes(sizeof(block_q4_K), sizeof(block_q4_1) * 8); + break; + case GGML_TYPE_Q6_K: + nbytes = remap_block_nbytes(sizeof(block_q6_K), sizeof(block_q8_0) * 8, 32); + break; + case GGML_TYPE_Q8_0: + nbytes = remap_block_nbytes(sizeof(block_q8_0), sizeof(block_q8_0), 32); + break; + case GGML_TYPE_Q2_K: + nbytes = remap_block_nbytes(sizeof(block_q2_K), sizeof(spacemit_kernels::nrow_block_q2_k<1>)); + break; + case GGML_TYPE_Q3_K: + nbytes = remap_block_nbytes(sizeof(block_q3_K), sizeof(spacemit_kernels::nrow_block_q3_k<1>)); + break; + case GGML_TYPE_MXFP4: + nbytes = remap_block_nbytes(sizeof(block_mxfp4), sizeof(spacemit_kernels::nrow_block_mxfp4<1>)); + break; + case GGML_TYPE_Q5_K: + nbytes = remap_block_nbytes(sizeof(block_q5_K), sizeof(spacemit_kernels::nrow_block_q5_1<1>) * 8); + break; + case GGML_TYPE_Q5_1: + nbytes = remap_block_nbytes(sizeof(block_q5_1), sizeof(spacemit_kernels::nrow_block_q5_1<1>)); + break; + case GGML_TYPE_Q5_0: + nbytes = remap_block_nbytes(sizeof(block_q5_0), sizeof(spacemit_kernels::nrow_block_q5_0<1>)); + break; + default: + nbytes = add_strided_nbytes(row_nbytes, 1, 1); + break; } - GGML_UNUSED(buft); return nbytes; } namespace ggml::cpu::riscv64_spacemit { class extra_buffer_type : ggml::cpu::extra_buffer_type { - bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + bool supports_op(ggml_backend_dev_t, const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && @@ -970,10 +1591,16 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } } break; - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - if (op->src[0]->type == GGML_TYPE_F32) { - return true; + case GGML_OP_MUL_MAT_ID: + if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 3) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } } break; default: @@ -983,15 +1610,28 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { return false; } - ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + ggml::cpu::tensor_traits * get_tensor_traits(const ggml_tensor * op) override { switch (op->op) { case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_REPEAT: + case GGML_OP_SUM_ROWS: + case GGML_OP_GET_ROWS: + case GGML_OP_CONCAT: + // case GGML_OP_GATED_DELTA_NET: return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl); default: // GGML_ABORT("fatal error"); @@ -1005,7 +1645,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } // namespace ggml::cpu::riscv64_spacemit ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { - static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + static ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { /* .iface = */ { /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, @@ -1023,3 +1663,78 @@ ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { return &ggml_backend_cpu_buffer_type_riscv64_spacemit; } + +extern "C" { +static int bind_ai_thread() { + int fd, bytes; + char str[32]; + + fd = open("/proc/set_ai_thread", O_WRONLY); + if (fd < 0) { + GGML_LOG_ERROR("try open /proc/set_ai_thread failed\n"); + return -1; + } + + snprintf(str, 16, "%d", 0); + bytes = write(fd, str, strlen(str)); + if (bytes < 0) { + GGML_LOG_ERROR("try write /proc/set_ai_thread failed\n"); + close(fd); + return -1; + } + + close(fd); + return 0; +} + +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n) { + int cpu_id = sched_getcpu(); + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_ime2 && + !((1 << cpu_id) & ggml::cpu::riscv64_spacemit::global_spine_env_info.cpu_mask)) { + GGML_PRINT_DEBUG("bind_ai_thread for thread %d, pid %d\n", thread_n, getpid()); + bind_ai_thread(); + } + + if (ggml::cpu::riscv64_spacemit::global_spine_env_info.use_tcm && + ggml::cpu::riscv64_spacemit::tls_context.cpu_id == -1) { + CPU_ZERO(&(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + pthread_t main_thread = pthread_self(); + const auto & perfer_core_ids = ggml::cpu::riscv64_spacemit::global_spine_env_info.perfer_core_ids; + if (thread_n < 0 || static_cast<size_t>(thread_n) >= perfer_core_ids.size()) { + GGML_ABORT("thread_n %d exceeds perfer_core_ids size %zu\n", thread_n, perfer_core_ids.size()); + } + auto perfer_cpu_id = perfer_core_ids[static_cast<size_t>(thread_n)]; + CPU_SET(perfer_cpu_id, &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + int s = + pthread_setaffinity_np(main_thread, sizeof(cpu_set_t), &(ggml::cpu::riscv64_spacemit::tls_context.cpuset)); + if (s != 0) { + GGML_ABORT("set thread affinity error for thread_n %d, cpu_id %d\n", thread_n, perfer_cpu_id); + } + + int ai_cpu_id = perfer_cpu_id - ggml::cpu::riscv64_spacemit::global_spine_env_info.aicpu_id_offset; + ggml::cpu::riscv64_spacemit::tls_context.cpu_id = ai_cpu_id; + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_get(ai_cpu_id); + ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer_size = + ggml::cpu::riscv64_spacemit::global_spine_env_info.tcm_blk_size; + } + + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + void * rt = + ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_wait(ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt == nullptr) { + GGML_ABORT("wait tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n) { + if (ggml::cpu::riscv64_spacemit::tls_context.tcm_buffer != nullptr) { + auto rt = ggml::cpu::riscv64_spacemit::spine_mem_pool_tcm_mem_release( + ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + if (rt != 0) { + GGML_ABORT("release tcm buffer failed for cpu_id: %d", ggml::cpu::riscv64_spacemit::tls_context.cpu_id); + } + } +} +} diff --git a/ggml/src/ggml-cpu/spacemit/ime.h b/ggml/src/ggml-cpu/spacemit/ime.h index 800d91acdae..6849dd95e05 100644 --- a/ggml/src/ggml-cpu/spacemit/ime.h +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -8,6 +8,14 @@ extern "C" { ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); +void ggml_backend_cpu_riscv64_spacemit_set_numa_thread_affinity(int thread_n); + +void ggml_backend_cpu_riscv64_spacemit_clear_numa_thread_affinity_threaded(int thread_n); + +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment); + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp index cbbb6cd9160..6acc6819dfb 100644 --- a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -1,8 +1,26 @@ +#include "ggml-impl.h" #include "ggml.h" #include "ime_kernels.h" +#include "rvv_kernels.h" #include <algorithm> #include <cmath> +#include <stdexcept> + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include <riscv_vector.h> +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#else +# error "RISCV64_SPACEMIT_IME1 not defined" +#endif // clang-format off #if defined(__GNUC__) @@ -11,7 +29,7 @@ #pragma GCC diagnostic ignored "-Wunused-parameter" #endif // clang-format on -namespace sqnbitgemm_spacemit_ime { +namespace spacemit_kernels { #define QUANTIZEM4ROW_KERNEL \ "vmv.s.x v16, zero \n\t" \ @@ -76,1093 +94,208 @@ namespace sqnbitgemm_spacemit_ime { "vse8.v v31, (s1) \n\t" namespace ime1 { -void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + uint8_t * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, %[BlkLen], e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "sub t2, t2, t0 \n\t" - "slli t1, t0, 2 \n\t" - "add %[SRC], %[SRC], t1 \n\t" - "add s1, a1, %[OFFSET] \n\t" - - QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE - - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "add s1, a1, %[OFFSET] \n\t" - - QUANTIZEM4ROW_KERNEL - - "addi t3, %[BlkLen], 0 \n\t" - "addi s2, s1, 0 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "SET_ZERO%=: \n\t" - "vse8.v v8, (s2) \n\t" - "addi s2, s2, 32 \n\t" - "addi t3, t3, -8 \n\t" - "bnez t3, SET_ZERO%= \n\t" - - QUANTIZEM4ROW_STORE - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 128) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "addi t2, t2, -128 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v16, v24, v16 \n\t" - "vfredmax.vs v24, v16, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e64, m4 \n\t" - "vsse64.v v16, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t2, e32, m8 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "sub t2, t2, t2 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "jal x0, QUANTIZE%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } - } else if (BlkLen == 256) { - for (size_t row_index = 0; row_index < 4; ++row_index) { - const float * SRC = A + row_index * CountK; - std::byte * DST = QuantA + row_index * sizeof(float); - const size_t offset = (4 - row_index) * 4 + row_index * 8; - const size_t stride = 4 * (sizeof(float) + BlkLen); - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "li t6, 32 \n\t" - "addi t2, %[CountK], 0 \n\t" - "addi a1, %[DST], 0 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "blt t2, %[BlkLen], TAIL%= \n\t" - - "LOOP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "addi t2, t2, -256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v24, v16 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - - "QUANTIZE%=: \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - "add a1, a1, %[STRIDE] \n\t" - "bge t2, %[BlkLen], LOOP%= \n\t" - - "TAIL%=: \n\t" - "blez t2, QUIT%= \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t1, t2, 0 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "sub t1, t1, t0 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, t1, e32, m8 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], -768 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vfmax.vv v8, v8, v24 \n\t" - "vfredmax.vs v24, v8, v24 \n\t" - "vfmv.f.s f10, v24 \n\t" - "add s1, a1, %[OFFSET] \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (a1) \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e64, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsse64.v v0, (s1), t6 \n\t" - - "TAIL_LOOP%=: \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t2, e32, m1 \n\t" - "sub t2, t2, t0 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 32 \n\t" - "vfmul.vf v1, v0, f11 \n\t" - "vfcvt.x.f.v v2, v1 \n\t" - "vsetvli t0, zero, e16, mf2 \n\t" - "vnclip.wx v3, v2, zero \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vnclip.wx v3, v3, zero \n\t" - "vse8.v v3, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bnez t2, TAIL_LOOP%= \n\t" - - "QUIT%=: \n\t" - : [SRC] "+r"(SRC) - : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), - [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); - } + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, %[BlkLen], e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "sub t2, t2, t0 \n\t" + "slli t1, t0, 2 \n\t" + "add %[SRC], %[SRC], t1 \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE + + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL + + "addi t3, %[BlkLen], 0 \n\t" + "addi s2, s1, 0 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "SET_ZERO%=: \n\t" + "vse8.v v8, (s2) \n\t" + "addi s2, s2, 32 \n\t" + "addi t3, t3, -8 \n\t" + "bnez t3, SET_ZERO%= \n\t" + + QUANTIZEM4ROW_STORE + + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), [CountK] "r"(CountK), + [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); } } -void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { +void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, uint8_t * QuantA) { const float * SRC = A; - std::byte * DST = QuantA; + uint8_t * DST = QuantA; constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); const float fone = 1.0f; - std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); + uint8_t * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; - if (CountK <= BlkLen) { - float max_abs_A = 0.0f; - for (size_t k = 0; k < CountK; k++) { - max_abs_A = std::max(max_abs_A, fabsf(A[k])); - } - float scale_A = max_abs_A * range_max_reciprocal; - - ((float *) QuantA)[0] = scale_A; - - auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float)); - - for (size_t k = 0; k < CountK; k++) { - QuantAData_offset[k] = - (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits<int8_t>::lowest(), - (float) std::numeric_limits<int8_t>::max()); - } - for (size_t k = CountK; k < BlkLen; k++) { - QuantAData_offset[k] = 0; - } - - return; - } - - if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) { - __asm__ volatile( - "vsetvli t0, zero, e8, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "LOOP%=: \n\t" - "vsetvli t0, %[CNT], e8, m8 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "sub %[CNT], %[CNT], t0 \n\t" - "bnez %[CNT], LOOP%= \n\t" - : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset) - : - : "cc", "t0"); - } - if (BlkLen == 16) { - float buffer[64] = { 0.0f }; - __asm__ volatile( - "addi t3, zero, 16*8 \n\t" - "addi t2, zero, 16 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m2 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v2, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v4, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v6, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v10, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v12, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "vle32.v v14, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v18, v2 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v22, v6 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v26, v10 \n\t" - "vfabs.v v28, v12 \n\t" - "vfabs.v v30, v14 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v18, v18, v19 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v22, v22, v23 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v26, v26, v27 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vse32.v v16, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v18, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v20, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v22, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v24, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v26, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v28, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vse32.v v30, (a1) \n\t" - "addi a1, %[BUFFER], 0 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f11, f3, f7 \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f11, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f12, f3, f7 \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fsw f12, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f13, f3, f7 \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f13, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f14, f3, f7 \n\t" - "fmul.s f14, f14, %[RMAXREC] \n\t" - "fsw f14, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f14, %[FONE], f14 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f15, f3, f7 \n\t" - "fmul.s f15, f15, %[RMAXREC] \n\t" - "fsw f15, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f15, %[FONE], f15 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f16, f3, f7 \n\t" - "fmul.s f16, f16, %[RMAXREC] \n\t" - "fsw f16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "fdiv.s f16, %[FONE], f16 \n\t" - "flw f0, (a1) \n\t" - "flw f1, 4(a1) \n\t" - "flw f2, 8(a1) \n\t" - "flw f3, 12(a1) \n\t" - "flw f4, 16(a1) \n\t" - "flw f5, 20(a1) \n\t" - "flw f6, 24(a1) \n\t" - "flw f7, 28(a1) \n\t" - "addi a1, a1, 32 \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f17, f3, f7 \n\t" - "fmul.s f17, f17, %[RMAXREC] \n\t" - "fsw f17, (%[DST]) \n\t" - "addi %[DST], %[DST], -136 \n\t" - "fdiv.s f17, %[FONE], f17 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v18, v2, f11 \n\t" - "vfmul.vf v20, v4, f12 \n\t" - "vfmul.vf v22, v6, f13 \n\t" - "vfmul.vf v24, v8, f14 \n\t" - "vfmul.vf v26, v10, f15 \n\t" - "vfmul.vf v28, v12, f16 \n\t" - "vfmul.vf v30, v14, f17 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v18, v18 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v22, v22 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v26, v26 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vfcvt.x.f.v v30, v30 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v18, v18, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v22, v22, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v26, v26, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vnclip.wx v30, v30, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v18, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v20, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v22, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v24, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v26, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v28, (%[DST]) \n\t" - "addi %[DST], %[DST], 20 \n\t" - "vse8.v v30, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 64 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vse32.v v16, (%[BUFFER]) \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, t1, e8, mf2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 16 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m2 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer) - : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", - "f13", "f14", "f15", "f16", "f17"); - } else if (BlkLen == 32) { - __asm__ volatile( - "addi t3, zero, 32*4 \n\t" - "addi t2, zero, 32 \n\t" - - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 128 \n\t" - "addi a3, %[SRC], 256 \n\t" - "addi a4, %[SRC], 384 \n\t" - - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 36 \n\t" - "addi s3, %[DST], 72 \n\t" - "addi s4, %[DST], 108 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v4, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vle32.v v8, (a3) \n\t" - "addi a3, a3, 512 \n\t" - "vle32.v v12, (a4) \n\t" - "addi a4, a4, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v20, v4 \n\t" - "vfabs.v v24, v8 \n\t" - "vfabs.v v28, v12 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v20, v20, v22 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vfmax.vv v28, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v20, v20, v21 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfmax.vv v28, v28, v29 \n\t" - - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v21, v20, v21 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfredmax.vs v29, v28, v29 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v21 \n\t" - "vfmv.f.s f12, v25 \n\t" - "vfmv.f.s f13, v29 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fmul.s f12, f12, %[RMAXREC] \n\t" - "fmul.s f13, f13, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fsw f12, (s3) \n\t" - "addi s3, s3, 4 \n\t" - "fsw f13, (s4) \n\t" - "addi s4, s4, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "fdiv.s f12, %[FONE], f12 \n\t" - "fdiv.s f13, %[FONE], f13 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v20, v4, f11 \n\t" - "vfmul.vf v24, v8, f12 \n\t" - "vfmul.vf v28, v12, f13 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v20, v20 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vfcvt.x.f.v v28, v28 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vsetvli t0, t1, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v20, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vnclip.wx v28, v28, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 140 \n\t" - "vse8.v v20, (s2) \n\t" - "addi s2, s2, 140 \n\t" - "vse8.v v24, (s3) \n\t" - "addi s3, s3, 140 \n\t" - "vse8.v v28, (s4) \n\t" - "addi s4, s4, 140 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m4 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 128 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m4 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); - } else if (BlkLen == 64) { - __asm__ volatile( - "addi t3, zero, 64*2 \n\t" - "addi t2, zero, 64 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "addi s1, %[DST], 0 \n\t" - "addi s2, %[DST], 68 \n\t" - "blt %[K], t3, LOOP_K%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "addi %[K], %[K], -128 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vfmax.vv v24, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vfmax.vv v24, v24, v26 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfmax.vv v24, v24, v25 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfredmax.vs v25, v24, v25 \n\t" - "vfmv.f.s f10, v17 \n\t" - "vfmv.f.s f11, v25 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fmul.s f11, f11, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fsw f11, (s2) \n\t" - "addi s2, s2, 4 \n\t" - "fdiv.s f10, %[FONE], f10 \n\t" - "fdiv.s f11, %[FONE], f11 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f10 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v24, v24, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 132 \n\t" - "vse8.v v24, (s2) \n\t" - "addi s2, s2, 132 \n\t" - "bge %[K], t3, LOOP_MAIN%= \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 256 \n\t" - "sub %[K], %[K], t1 \n\t" - "vfabs.v v16, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v16, v16, v20 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v16, v16, v18 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v16, v16, v17 \n\t" - "vfredmax.vs v17, v16, v17 \n\t" - "vfmv.f.s f10, v17 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (s1) \n\t" - "addi s1, s1, 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vsetvli t0, zero, e8, m2 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t0, t3, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "jal x0, LOOP_K%= \n\t" - "END%=: \n\t" - : [K] "+r"(CountK) - : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) - : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); - } else if (BlkLen == 128) { - __asm__ volatile( - "addi t2, zero, 128 \n\t" - "addi a1, %[SRC], 0 \n\t" - "addi a2, %[SRC], 256 \n\t" - "blt %[K], t2, TAIL%= \n\t" - "LOOP_K%=: \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "addi a1, a1, 512 \n\t" - "vle32.v v8, (a2) \n\t" - "addi a2, a2, 512 \n\t" - "sub %[K], %[K], t2 \n\t" - "QUANT%=: \n\t" - "vfabs.v v16, v0 \n\t" - "vfabs.v v24, v8 \n\t" - "vfmax.vv v24, v16, v24 \n\t" - "vsetvli t1, zero, e32, m4 \n\t" - "vfmax.vv v28, v24, v28 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v30, v28, v30 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v30, v30, v31 \n\t" - "vfredmax.vs v31, v30, v31 \n\t" - "vfmv.f.s f10, v31 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v16, v0, f11 \n\t" - "vfmul.vf v24, v8, f11 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vnclip.wx v20, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v16, v16, zero \n\t" - "vse8.v v16, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bge %[K], t2, LOOP_K%= \n\t" - "TAIL%=: \n\t" - "blez %[K], END%= \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vxor.vv v8, v8, v8 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v0, (a1) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t0, %[K], e32, m8 \n\t" - "vle32.v v8, (a2) \n\t" - "sub %[K], %[K], t0 \n\t" - "vsetvli t1, zero, e32, m8 \n\t" - "jal x0, QUANT%= \n\t" - "END%=: \n\t" - - : [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC) - : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); - } else { - float buffer[8] = { 0.0f }; - size_t cnt = BlkLen / 256; - - __asm__ volatile( - "slli t3, %[BLK], 2 \n\t" - "blt %[K], %[BLK], LOOP_TAIL%= \n\t" - "LOOP_MAIN%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_CMP%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vfabs.v v0, v0 \n\t" - "vfabs.v v8, v8 \n\t" - "vfabs.v v16, v16 \n\t" - "vfabs.v v24, v24 \n\t" - "vfmax.vv v8, v0, v8 \n\t" - "vfmax.vv v16, v16, v24 \n\t" - "vfmax.vv v0, v0, v16 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, LOOP_CMP%= \n\t" - "sub %[SRC], %[SRC], t3 \n\t" - "addi t6, %[CNT], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[CNT], 0 \n\t" - "LOOP_QUANT%=: \n\t" - "addi t6, t6, -1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v8, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v16, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vle32.v v24, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfmul.vf v8, v8, f11 \n\t" - "vfmul.vf v16, v16, f11 \n\t" - "vfmul.vf v24, v24, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vfcvt.x.f.v v8, v8 \n\t" - "vfcvt.x.f.v v16, v16 \n\t" - "vfcvt.x.f.v v24, v24 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vnclip.wx v8, v16, zero \n\t" - "vnclip.wx v12, v24, zero \n\t" - "vsetvli t0, zero, e8, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vnclip.wx v4, v8, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "vse8.v v4, (%[DST]) \n\t" - "addi %[DST], %[DST], 128 \n\t" - "bnez t6, LOOP_QUANT%= \n\t" - "sub %[K], %[K], %[BLK] \n\t" - "bge %[K], %[BLK], LOOP_MAIN%= \n\t" - "blez %[K], END%= \n\t" - "LOOP_TAIL%=: \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vxor.vv v31, v31, v31 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "addi t6, %[K], 0 \n\t" - "addi s1, %[SRC], 0 \n\t" - "TAIL_CMP%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t0, t6, e32, m8 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi %[SRC], %[SRC], 256 \n\t" - "sub t6, t6, t0 \n\t" - "vfabs.v v0, v0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vfmax.vv v0, v0, v4 \n\t" - "vsetvli t0, zero, e32, m2 \n\t" - "vfmax.vv v0, v0, v2 \n\t" - "vsetvli t0, zero, e32, m1 \n\t" - "vfmax.vv v0, v0, v1 \n\t" - "vle32.v v30, (%[BUFFER]) \n\t" - "vfmax.vv v31, v30, v0 \n\t" - "vse32.v v31, (%[BUFFER]) \n\t" - "bnez t6, TAIL_CMP%= \n\t" - "addi t6, %[K], 0 \n\t" - "flw f0, (%[BUFFER]) \n\t" - "flw f1, 4(%[BUFFER]) \n\t" - "flw f2, 8(%[BUFFER]) \n\t" - "flw f3, 12(%[BUFFER]) \n\t" - "flw f4, 16(%[BUFFER]) \n\t" - "flw f5, 20(%[BUFFER]) \n\t" - "flw f6, 24(%[BUFFER]) \n\t" - "flw f7, 28(%[BUFFER]) \n\t" - "fmax.s f1, f0, f1 \n\t" - "fmax.s f3, f2, f3 \n\t" - "fmax.s f5, f4, f5 \n\t" - "fmax.s f7, f6, f7 \n\t" - "fmax.s f3, f1, f3 \n\t" - "fmax.s f7, f5, f7 \n\t" - "fmax.s f10, f3, f7 \n\t" - "fmul.s f10, f10, %[RMAXREC] \n\t" - "fsw f10, (%[DST]) \n\t" - "addi %[DST], %[DST], 4 \n\t" - "fdiv.s f11, %[FONE], f10 \n\t" - "addi t6, %[K], 0 \n\t" - "TAIL_QUANT%=: \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v0, v0, v0 \n\t" - "vsetvli t1, t6, e32, m8 \n\t" - "vle32.v v0, (s1) \n\t" - "addi s1, s1, 256 \n\t" - "sub t6, t6, t1 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vfmul.vf v0, v0, f11 \n\t" - "vfcvt.x.f.v v0, v0 \n\t" - "vsetvli t0, zero, e16, m4 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vsetvli t0, t1, e8, m2 \n\t" - "vnclip.wx v0, v0, zero \n\t" - "vse8.v v0, (%[DST]) \n\t" - "addi %[DST], %[DST], 64 \n\t" - "bnez t6, TAIL_QUANT%= \n\t" - "END%=: \n\t" - : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) - : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer), - [CNT] "r"(cnt) - : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); - } + __asm__ volatile( + "addi t3, zero, 32*4 \n\t" + "addi t2, zero, 32 \n\t" + + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 128 \n\t" + "addi a3, %[SRC], 256 \n\t" + "addi a4, %[SRC], 384 \n\t" + + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 36 \n\t" + "addi s3, %[DST], 72 \n\t" + "addi s4, %[DST], 108 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v4, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vle32.v v8, (a3) \n\t" + "addi a3, a3, 512 \n\t" + "vle32.v v12, (a4) \n\t" + "addi a4, a4, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v28, v12 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v20, v20, v22 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vfmax.vv v28, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v21, v20, v21 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfredmax.vs v29, v28, v29 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v21 \n\t" + "vfmv.f.s f12, v25 \n\t" + "vfmv.f.s f13, v29 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fsw f12, (s3) \n\t" + "addi s3, s3, 4 \n\t" + "fsw f13, (s4) \n\t" + "addi s4, s4, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v20, v4, f11 \n\t" + "vfmul.vf v24, v8, f12 \n\t" + "vfmul.vf v28, v12, f13 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vsetvli t0, t1, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 140 \n\t" + "vse8.v v20, (s2) \n\t" + "addi s2, s2, 140 \n\t" + "vse8.v v24, (s3) \n\t" + "addi s3, s3, 140 \n\t" + "vse8.v v28, (s4) \n\t" + "addi s4, s4, 140 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m4 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 128 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); } } // namespace ime1 @@ -1451,1746 +584,444 @@ namespace { "vadd.vi v1, v1, -12 \n\t" template <bool HasZeroPoint> -void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { size_t LDC = ldc * sizeof(float); const size_t INNER = BlkLen / 16; float tmp[4 * 16]; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - - "addi t3, %[BlockCountK], 0 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 32 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 32 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16_FP16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } - } - } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } -} -template <bool HasZeroPoint> -void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias, - const size_t ldc) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - size_t LDC = ldc * sizeof(float); - const size_t INNER = BlkLen / 16; - float tmp[4 * 16]; - - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - if (NBLKS < 16) { - CPtr = tmp; - LDC = 16 * sizeof(float); - } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "li s1, 24 \n\t" - "vmv.v.i v1, 3 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v1, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v1, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v1, 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - // scale offset - "addi s5, s1, 0 \n\t" - // zp offset - "addi s6, s1, 64 \n\t" - "addi s1, s6, 16 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 - - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vsub.vv v2, v2, v12 \n\t" - "vsub.vv v6, v6, v12 \n\t" - "vsub.vv v3, v3, v13 \n\t" - "vsub.vv v7, v7, v13 \n\t" - "vsub.vv v4, v4, v14 \n\t" - "vsub.vv v8, v8, v14 \n\t" - "vsub.vv v5, v5, v15 \n\t" - "vsub.vv v9, v9, v15 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; if (NBLKS < 16) { CPtr = tmp; LDC = 16 * sizeof(float); } - if (Bias != nullptr) { - const float * bias = Bias + n; - if (NBLKS < 16) { - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "vse32.v v0, (%[DST]) \n\t" - : - : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) - : "cc", "t0"); - bias = tmp; - } - __asm__ volatile(LOAD_BIAS - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", - "s2", "s3", "s4", "s5", "s6"); - - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v24, v24, v24 \n\t" - "addi t3, %[BlockCountK], 0 \n\t" - "addi a1, %[A], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "BLOCK_COUNTK_LOOP%=: \n\t" - "addi s5, s1, 0 \n\t" - "addi s1, s5, 64 \n\t" - "addi s2, s1, 32 \n\t" - "addi s3, s1, 32*2 \n\t" - "addi s4, s1, 32*3 \n\t" - "vsetvli t0, zero, e32, m8 \n\t" - "vxor.vv v16, v16, v16 \n\t" - // load a scale - "flw f1, (a1) \n\t" - "flw f2, 4(a1) \n\t" - "flw f3, 8(a1) \n\t" - "flw f4, 12(a1) \n\t" - "addi a1, a1, 16 \n\t" - "addi t2, %[INNER], 0 \n\t" - "BLOCK_INNER_LOOP%=: \n\t" - - LOAD_B_16x8x2 - - "vsetvli t0, zero, e8, m1 \n\t" - "vle8.v v10, (a1) \n\t" - - "addi a1, a1, 32 \n\t" - "vle8.v v11, (a1) \n\t" - "addi a1, a1, 32 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - "vadd.vi v8, v8, -8 \n\t" - "vadd.vi v9, v9, -8 \n\t" - - SQ4BIT_KERNEL_COMP_4x16x16 - - "addi t2, t2, -1 \n\t" - "bnez t2, BLOCK_INNER_LOOP%= \n\t" - - LOAD_SCALE_4x16 - - "vsetvli t0, zero, e32, m8 \n\t" - "vfcvt.f.x.v v16, v16 \n\t" - "vfmacc.vv v24, v16, v8 \n\t" - "addi t3, t3, -1 \n\t" - "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" - - "RESULT_SAVE%=: \n\t" - - SAVE_RESULT_4x16 - - : - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), - [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) - : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", - "s4", "s5", "s6"); - } + + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", "s4", + "s5", "s6"); } } - if (CountN % 16 != 0) { - // stroe output from tmp to C when NBLKS less than 16. - float * CPtr = C + CountN / 16 * 16; - const size_t N = CountN % 16; - LDC = ldc * sizeof(float); - __asm__ volatile( - "vsetvli t0, %[N], e32, m2 \n\t" - "vle32.v v0, (%[SRC]) \n\t" - "addi s2, %[SRC], 64 \n\t" - "addi s3, %[SRC], 64*2 \n\t" - "addi s4, %[SRC], 64*3 \n\t" - "vle32.v v2, (s2) \n\t" - "vle32.v v4, (s3) \n\t" - "vle32.v v6, (s4) \n\t" - "add t2, %[DST], %[LDC] \n\t" - "add t3, t2, %[LDC] \n\t" - "add t4, t3, %[LDC] \n\t" - "vse32.v v0, (%[DST]) \n\t" - "vse32.v v2, (t2) \n\t" - "vse32.v v4, (t3) \n\t" - "vse32.v v6, (t4) \n\t" - : - : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) - : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); - } } template <bool HasZeroPoint> -void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); +void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const uint8_t * QuantA, + const uint8_t * QuantBData, + float * C, + size_t CountN, + size_t BlockCountK, + const size_t ldc) { + GGML_UNUSED(ldc); size_t INNER = BlkLen / 16; if constexpr (HasZeroPoint) { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(_Float16); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - // zp offset - "addi s7, %[B], 32 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s7, %[B], 32 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 48 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 72 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 120 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - "addi s7, s1, 32 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } - } - } else { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(_Float16); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - - "vsetvli t0, zero, e32, mf2 \n\t" - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 8 \n\t" - "addi s3, %[B], 16 \n\t" - "addi s4, %[B], 24 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "LOOP_K%=: \n\t" - "vsetvli t0, zero, e16, mf4 \n\t" - "vle16.v v4, (s1) \n\t" - "addi s1, s1, 32 \n\t" - "vle16.v v5, (s2) \n\t" - "addi s2, s2, 56 \n\t" - "vle16.v v6, (s3) \n\t" - "addi s3, s3, 80 \n\t" - "vle16.v v7, (s4) \n\t" - "addi s4, s4, 104 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "vfwcvt.f.f.v v8, v4 \n\t" - "vfwcvt.f.f.v v9, v5 \n\t" - "vfwcvt.f.f.v v10, v6 \n\t" - "vfwcvt.f.f.v v11, v7 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_F16_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} -template <bool HasZeroPoint> -void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountN, - size_t BlockCountK, - const float * Bias) { - GGML_UNUSED(QuantBScale); - GGML_UNUSED(QuantBZeroPoint); - const size_t INNER = BlkLen / 16; - if constexpr (HasZeroPoint) { - for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(uint8_t) + // zp - n * BlockCountK * sizeof(float); // scale - float * CPtr = C + n; - size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0 - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - // zp offset - "addi s7, %[B], 64 \n\t" - // a offset - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - - // load scale - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - // load a scale - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - // a scale * b scale - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - - "vsetvli t0, zero, e8, m1 \n\t" - "vmv.v.i v13, 3 \n\t" - "li s1, 24 \n\t" - "vsetvli t0, s1, e8, m1 \n\t" - "vmv.v.i v13, 2 \n\t" - "vsetvli t0, zero, e8, mf2 \n\t" - "vmv.v.i v13, 1 \n\t" - "vsetvli t0, zero, e8, mf4 \n\t" - "vmv.v.i v13, 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s7, %[B], 64 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 80 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 96 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 112 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 128 \n\t" - - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - - SQ4BIT_KERNEL_LOAD_ZP_16X1 - - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vsub.vv v0, v0, v8 \n\t" - "vsub.vv v4, v4, v8 \n\t" - "vsub.vv v1, v1, v9 \n\t" - "vsub.vv v5, v5, v9 \n\t" - "vsub.vv v2, v2, v10 \n\t" - "vsub.vv v6, v6, v10 \n\t" - "vsub.vv v3, v3, v11 \n\t" - "vsub.vv v7, v7, v11 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - "addi s7, s1, 64 \n\t" - - "bnez %[CNT], LOOP_K%= \n\t" - - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); - } + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s7, %[B], 32 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); } } else { for (size_t n = 0; n < CountN; n += 16) { - size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; - std::byte * QuantBDataPtr = (std::byte *) QuantBData + // - n * BlockCountK * BlkLen / 2 + // b data - n * BlockCountK * sizeof(float); // scale + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + uint8_t * QuantBDataPtr = (uint8_t *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale float * CPtr = C + n; size_t cnt = BlockCountK; - if (Bias != nullptr) { - const float * bias = Bias + n; - __asm__ volatile( - "addi t3, %[NBLKS], 0 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v28, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v29, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v30, (%[BIAS]) \n\t" - "sub t3, t3, t0 \n\t" - "addi %[BIAS], %[BIAS], 16 \n\t" - "vsetvli t0, t3, e32, mf2 \n\t" - "vle32.v v31, (%[BIAS]) \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } else { - __asm__ volatile( - "vsetvli t0, zero, e32, m4 \n\t" - "vxor.vv v28, v28, v28 \n\t" - "addi s1, %[B], 0 \n\t" - "addi s2, %[B], 16 \n\t" - "addi s3, %[B], 32 \n\t" - "addi s4, %[B], 48 \n\t" - - "addi s5, %[A], 0 \n\t" - "addi s6, %[A], 12 \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - "LOOP_K%=: \n\t" - "vle32.v v8, (s1) \n\t" - "addi s1, s1, 64 \n\t" - "vle32.v v9, (s2) \n\t" - "addi s2, s2, 80 \n\t" - "vle32.v v10, (s3) \n\t" - "addi s3, s3, 96 \n\t" - "vle32.v v11, (s4) \n\t" - "addi s4, s4, 112 \n\t" - "flw f1, (s5) \n\t" - "addi s5, s5, 4 \n\t" - - "addi t5, %[INNER], 0 \n\t" - "vxor.vv v16, v16, v16 \n\t" - "vxor.vv v18, v18, v18 \n\t" - "vxor.vv v20, v20, v20 \n\t" - "vxor.vv v22, v22, v22 \n\t" - "vfmul.vf v24, v8, f1 \n\t" - "vfmul.vf v25, v9, f1 \n\t" - "vfmul.vf v26, v10, f1 \n\t" - "vfmul.vf v27, v11, f1 \n\t" - "addi %[CNT], %[CNT], -1 \n\t" - "vsetvli t0, zero, e8, m1 \n\t" - "LOOP_INNER%=: \n\t" - - SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 - - "vadd.vi v0, v0, -8 \n\t" - "vadd.vi v1, v1, -8 \n\t" - "vadd.vi v2, v2, -8 \n\t" - "vadd.vi v3, v3, -8 \n\t" - "vadd.vi v4, v4, -8 \n\t" - "vadd.vi v5, v5, -8 \n\t" - "vadd.vi v6, v6, -8 \n\t" - "vadd.vi v7, v7, -8 \n\t" - - SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 - - "bnez t5, LOOP_INNER%= \n\t" - "vsetvli t0, zero, e32, mf2 \n\t" - - SQ4BIT_KERNEL_ACC_1X4X4 - - "bnez %[CNT], LOOP_K%= \n\t" - "addi t3, zero, 16 \n\t" - "addi s1, %[C], 16 \n\t" - "addi s2, %[C], 32 \n\t" - "addi s3, %[C], 48 \n\t" - "blt %[NBLKS], t3, ST_TAIL%= \n\t" - "vse32.v v28, (%[C]) \n\t" - "vse32.v v29, (s1) \n\t" - "vse32.v v30, (s2) \n\t" - "vse32.v v31, (s3) \n\t" - "jal x0, END%= \n\t" - - "ST_TAIL%=: \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v28, (%[C]) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v29, (s1) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v30, (s2) \n\t" - "vsetvli t0, %[NBLKS], e32, mf2 \n\t" - "sub %[NBLKS], %[NBLKS], t0 \n\t" - "vse32.v v31, (s3) \n\t" - "END%=: \n\t" - - : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) - : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) - : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); - } - } - } -} - -template <bool HasZeroPoint> -inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM4Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias, ldc); - - } else if (scalestride == 2) { - SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>( - BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); - } -} -template <bool HasZeroPoint> -inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t BlockStrideQuantB, - const float * Bias, - const size_t ldc, - const size_t scalestride) { - if (scalestride == 4) { - SQ4BitGemmM1Kernel_CompInt8_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, - CountN, BlockStrideQuantB, Bias); - } else if (scalestride == 2) { - SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<HasZeroPoint>(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } } } - } // namespace namespace ime1 { -size_t gemm_kernel_i8i4(size_t BlkLen, - const std::byte * QuantA, - const std::byte * QuantBData, - const float * QuantBScale, - const std::byte * QuantBZeroPoint, - float * C, - size_t CountM, - size_t CountN, - size_t CountK, - size_t BlockCountK, - size_t ldc, - const float * Bias, - const size_t ScaleStride) { - GGML_UNUSED(CountM); - GGML_UNUSED(CountK); - GGML_UNUSED(ldc); - if (CountM >= 4) { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { + if (quant_b_zp != nullptr) { + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<true>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl<false>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 4; } else { - if (QuantBZeroPoint != nullptr) { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<true>(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, - C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + if (quant_b_zp != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<true>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, k_blks, + ldc); } else { - SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen<false>(BlkLen, QuantA, QuantBData, QuantBScale, - QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, - ldc, ScaleStride); + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl<false>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_n, + k_blks, ldc); } return 1; } } } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp new file mode 100644 index 00000000000..0c7a036a92a --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp @@ -0,0 +1,5768 @@ +#include "ggml-impl.h" +#include "ggml.h" +#include "ime_kernels.h" +#include "rvv_kernels.h" +#include "string.h" + +#include <algorithm> +#include <cmath> +#include <stdexcept> + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include <riscv_vector.h> +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME2) +#else +# error "RISCV64_SPACEMIT_IME2 not defined" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels { +namespace ime2 { + +template <size_t MB_ROWS, size_t NB_COLS> +void gemm_kernel_i8i2k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k<NB_COLS>; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + uint8_t * zeros16 = (uint8_t *) (quant_b_blk_data->zeros16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + _Float16 * zeros_fp16 = (_Float16 *) zeros16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + uint8_t * scales_temp = scales; + uint8_t * zps_temp = scales; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, zps_temp++) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + int16_t a_sum = a_sum_row[mi * 16 + kii]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0.0; + + uint8_t b_zp = zps_temp[ci * 16] >> 4; + uint8_t b_scale = scales_temp[ci] & 0x0F; + for (size_t bi = 0; bi < 16; bi++) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + uint8_t b0 = b_data_col[ci * 16 + bi]; + acc_0 += static_cast<int16_t>(a0) * static_cast<int16_t>((b0 >> b_shift) & 0x03); + } + + _Float16 scale_item = + static_cast<_Float16>(b_scale) * static_cast<_Float16>(factor_scale) * scales_fp16[ci]; + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + output[ci + mi * NB_COLS] += b_zp * a_sum * a_scale_row[mi] * zeros_fp16[ci]; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template <size_t MB_ROWS, size_t NB_COLS> +void gemm_kernel_i8i3k_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q2_k<NB_COLS>; + constexpr float refactor_scale = 16.0f; + constexpr float factor_scale = 1.0f / refactor_scale; + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = sizeof(blk_type); + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + blk_type * quant_b_blk_data = (blk_type *) (quant_b_data); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + uint8_t * b_data = quant_b_blk_data->qs; + uint8_t * b_hmask = quant_b_blk_data->hmask; + int8_t * scales = quant_b_blk_data->scales; + uint8_t * scales16 = (uint8_t *) (quant_b_blk_data->scales16); + + _Float16 * scales_fp16 = (_Float16 *) scales16; + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS * 16); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS * 16); + + memset(output_f16, 0, sizeof(output_f16)); + + int8_t * scales_temp = scales; + uint16_t * b_mask_col = (uint16_t *) b_hmask; + + float acc_0_max = 0.0f; + for (size_t kii = 0; kii < 16; kii++, scales_temp += NB_COLS, b_mask_col += NB_COLS) { + size_t b_shift = (kii % 4) * 2; + + uint8_t * b_data_col = b_data + (kii / 4) * NB_COLS * 16; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 acc_0 = 0; + // blk 2 * kii + 0 + uint16_t b_shift_mask = 1; + for (size_t bi = 0; bi < 16; bi++, b_shift_mask <<= 1) { + int8_t a0 = a_data[mi * 256 + bi + kii * 16]; + int8_t b0 = static_cast<int8_t>((b_data_col[ci * 16 + bi] >> b_shift) & 0x03); + b0 -= b_mask_col[ci] & b_shift_mask ? 0 : 4; + acc_0 += static_cast<int16_t>(a0) * static_cast<int16_t>(b0); + } + + _Float16 scale_item = static_cast<_Float16>(scales_temp[ci]) * scales_fp16[ci] * + static_cast<_Float16>(factor_scale); + + output_f16[ci + mi * NB_COLS] += acc_0 * scale_item; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + auto a_scale = a_scale_row[mi] * refactor_scale; + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += output_f16[ci + mi * NB_COLS] * a_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template <size_t MB_ROWS, size_t NB_COLS> +void gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t kblks_per_blk = 16; + GGML_ASSERT(k_blks % kblks_per_blk == 0); + + int64_t b_blk_stride = (sizeof(_Float16) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(_Float16); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + output_f16[ci + mi * NB_COLS] = static_cast<_Float16>(0.0f); + } + } + + size_t kii = 0; + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + _Float16 * b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (_Float16 *) (b_data - NB_COLS * sizeof(_Float16) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + _Float16 a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_scale_fp16[ci]; + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast<int8_t>(b & 0x0F); + int8_t b1 = static_cast<int8_t>((b & 0xF0) >> 4); + acc += static_cast<int32_t>(a0) * static_cast<int32_t>(b0) + + static_cast<int32_t>(a1) * static_cast<int32_t>(b1); + } + output_f16[ci + mi * NB_COLS] += + static_cast<float>(acc) * static_cast<float>(a_scale) * static_cast<float>(b_scale); + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast<float>(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } else { + kii++; + } + } + + if (kii == kblks_per_blk - 1) { + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] += static_cast<float>(output_f16[ci + mi * NB_COLS]); + output_f16[ci + mi * NB_COLS] = 0.0f; + } + } + kii = 0; + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template <size_t MB_ROWS, size_t NB_COLS> +void gemm_kernel_i8i4_hp_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * MB_ROWS; + const size_t a_subblk_stride = q8_hp_blk_size(32, false, false) * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + const uint8_t * b_tile_base = quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0.0f; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride) { + _Float16 output_f16[MB_ROWS * NB_COLS] = { 0 }; + + const uint8_t * b_superblk_ptr = b_tile_base + ki * b_superblk_stride; + const block_q4_0x32_layout * b_blocks = reinterpret_cast<const block_q4_0x32_layout *>(b_superblk_ptr); + const uint8_t * b_zps = + quant_b_zp ? b_superblk_ptr + sizeof(block_q4_0x32_layout) * k_subblks_per_superblk : nullptr; + + _Float16 * a_sum_row = (_Float16 *) (a_data + a_subblk_stride * k_subblks_per_superblk); + _Float16 * a_scale_avg_row = (_Float16 *) (a_data + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + _Float16 scale_factor = a_scale_avg_row[0]; + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast<const _Float16 *>(a_data + a_subblk_stride * ksi); + int8_t * a_subblk = a_data + a_subblk_stride * ksi + MB_ROWS * sizeof(_Float16); + const _Float16 a_scale = a_scale_row[0]; + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + const uint8_t * b_qs = b_block.qs + ci * 16; + _Float16 b_scale = b_block.d[ci] * a_scale; + + int16_t acc = 0; + for (size_t bi = 0; bi < 16; bi++) { + uint8_t b = b_qs[bi]; + int8_t b0 = static_cast<int8_t>(b & 0x0F); + int8_t b1 = static_cast<int8_t>((b & 0xF0) >> 4); + + acc += static_cast<int16_t>(a_subblk[mi * 32 + 2 * bi]) * static_cast<int16_t>(b0) + + static_cast<int16_t>(a_subblk[mi * 32 + 2 * bi + 1]) * static_cast<int16_t>(b1); + } + + const _Float16 scaled_acc = static_cast<_Float16>(acc) * b_scale; + output_f16[ci + mi * NB_COLS] += scaled_acc; + } + } + } + + for (size_t ksi = 0; ksi < k_subblks_per_superblk; ++ksi) { + const _Float16 * a_scale_row = reinterpret_cast<const _Float16 *>(a_data + a_subblk_stride * ksi); + const block_q4_0x32_layout & b_block = b_blocks[ksi]; + const uint8_t * b_zp_row = b_zps ? b_zps + ksi * NB_COLS : nullptr; + const _Float16 a_scale = a_scale_row[0]; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + const _Float16 a_sum = a_sum_row[mi * k_subblks_per_superblk + ksi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + _Float16 b_scale = b_block.d[ci] * a_scale; + _Float16 a_sum_bzp = a_sum; + if (b_zp_row) { + a_sum_bzp = a_sum * static_cast<_Float16>(0.125f) * static_cast<_Float16>(b_zp_row[ci]); + } + + const _Float16 scaled_acc = a_sum_bzp * b_scale; + output[ci + mi * NB_COLS] += scaled_acc * scale_factor; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + auto val = static_cast<float>(output_f16[ci + mi * NB_COLS]) * static_cast<float>(scale_factor); + output[ci + mi * NB_COLS] += val; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template <size_t MB_ROWS, size_t NB_COLS> +void moe_gemm_kernel_i8i4_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + (blk_len / 2) + (quant_b_zp ? sizeof(uint8_t) : 0)); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array<int8_t *, MB_ROWS> a_data; + std::array<float *, MB_ROWS> c_data; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] = c_ptr[mi]; + } + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + if (quant_b_zp) { + b_data += NB_COLS * sizeof(uint8_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + uint8_t * b_zp = nullptr; + if (quant_b_zp) { + b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t) - NB_COLS * sizeof(uint8_t)); + b_zp = (uint8_t *) (b_data - NB_COLS * sizeof(uint8_t)); + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + if (b_zp) { + acc += a_sum * b_zp[ci]; + } else { + acc += a_sum * 8; + } + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = (a_data[mi])[2 * bi]; + int8_t a1 = (a_data[mi])[2 * bi + 1]; + uint8_t b = b_data[ci * blk_len / 2 + bi]; + int8_t b0 = static_cast<int8_t>(b & 0x0F); + int8_t b1 = static_cast<int8_t>((b & 0xF0) >> 4); + acc += static_cast<int32_t>(a0) * static_cast<int32_t>(b0) + + static_cast<int32_t>(a1) * static_cast<int32_t>(b1); + } + output[ci + mi * NB_COLS] += static_cast<float>(acc) * a_scale * b_scale; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + (c_data[mi])[ci] = output[mi * NB_COLS + ci]; + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + c_data[mi] += NB_COLS; + } + } +} + +template <size_t MB_ROWS, size_t NB_COLS> +void moe_gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + // blk_len is expected to be 32 for Q5 types. + int64_t a_blk_stride = q8_blk_size(blk_len, true); + + float output[MB_ROWS * NB_COLS] = { 0 }; + std::array<int8_t *, MB_ROWS> a_data; + std::array<float *, MB_ROWS> c_data; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + c_data[mi] = c_ptr[mi]; + } + + if (quant_b_zp) { + using blk_type = nrow_block_q5_1<NB_COLS>; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast<int32_t>(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast<int8_t>(qs_byte & 0x0F); + int8_t b1 = static_cast<int8_t>((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast<int32_t>(a0) * static_cast<int32_t>(b0) + + static_cast<int32_t>(a1) * static_cast<int32_t>(b1); + } + + output[ci + mi * NB_COLS] += static_cast<float>(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } else { + using blk_type = nrow_block_q5_0<NB_COLS>; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data + (ni / NB_COLS) * k_blks; + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + a_data[mi] = (int8_t *) quant_a_ptr[mi] + sizeof(float) + sizeof(int16_t); + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < NB_COLS; ++ci) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ++ki, ++quant_b_blk_data) { + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + float * a_scale_row = (float *) (a_data[mi] - sizeof(float) - sizeof(int16_t)); + int16_t * a_sum_row = (int16_t *) (a_data[mi] - sizeof(int16_t)); + float a_scale = *a_scale_row; + int16_t a_sum = *a_sum_row; + + for (size_t ci = 0; ci < NB_COLS; ++ci) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; ++bi) { + int8_t a0 = a_data[mi][2 * bi]; + int8_t a1 = a_data[mi][2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast<int8_t>(qs_byte & 0x0F); + int8_t b1 = static_cast<int8_t>((qs_byte >> 4) & 0x0F); + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast<int32_t>(a0) * static_cast<int32_t>(b0) + + static_cast<int32_t>(a1) * static_cast<int32_t>(b1); + } + + output[ci + mi * NB_COLS] += static_cast<float>(acc) * a_scale * b_scale; + } + + a_data[mi] += a_blk_stride; + } + } + + for (size_t mi = 0; mi < MB_ROWS; ++mi) { + for (size_t ci = 0; ci < nb_real; ++ci) { + c_data[mi][ci] = output[mi * NB_COLS + ci]; + } + c_data[mi] += NB_COLS; + } + } + } +} + +template <size_t MB_ROWS, size_t NB_COLS> +void gemm_kernel_i8i8_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_blk_stride = (sizeof(ggml_fp16_t) + blk_len); + int64_t b_stride = k_blks * b_blk_stride; + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t b_ncol_block_stride = b_blk_stride * NB_COLS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + int8_t * b_data = (int8_t *) quant_b_data + ni * b_stride + NB_COLS * sizeof(ggml_fp16_t); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, a_data += a_nrow_block_stride, b_data += b_ncol_block_stride) { + ggml_fp16_t * b_scale_fp16 = (ggml_fp16_t *) (b_data - NB_COLS * sizeof(ggml_fp16_t)); + + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(b_scale_fp16[ci]); + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len; bi++) { + int8_t a0 = a_data[mi * blk_len + bi]; + int8_t b0 = b_data[ci * blk_len + bi]; + acc += static_cast<int32_t>(a0) * static_cast<int32_t>(b0); + } + output[ci + mi * NB_COLS] += static_cast<float>(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +template <size_t MB_ROWS, size_t NB_COLS> +void gemm_kernel_i8i5_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 for Q5 types + // quant_b_zp != nullptr => nrow_block_q5_1<NB_COLS> (has zp) + // quant_b_zp == nullptr => nrow_block_q5_0<NB_COLS> (no zp) + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + if (quant_b_zp) { + // nrow_block_q5_1<NB_COLS>: scales16[NB_COLS] + zp[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_1<NB_COLS>; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + uint8_t b_zp_val = quant_b_blk_data->zp[ci]; + int32_t acc = a_sum * static_cast<int32_t>(b_zp_val); + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast<int8_t>(qs_byte & 0x0F); + int8_t b1 = static_cast<int8_t>((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + // qh is packed as 4 bytes per column (32 bits for 32 elements) + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast<int32_t>(a0) * static_cast<int32_t>(b0) + + static_cast<int32_t>(a1) * static_cast<int32_t>(b1); + } + output[ci + mi * NB_COLS] += static_cast<float>(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } else { + // nrow_block_q5_0<NB_COLS>: scales16[NB_COLS] + qh[4*NB_COLS] + qs[16*NB_COLS] + using blk_type = nrow_block_q5_0<NB_COLS>; + int64_t b_ncol_block_stride = sizeof(blk_type); + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + int16_t a_sum = a_sum_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = ggml_fp16_to_fp32(quant_b_blk_data->scales16[ci]); + // Q5_0 has no zp, use default offset 16 (midpoint of 5-bit unsigned range) + int32_t acc = a_sum * 16; + + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + uint8_t qs_byte = quant_b_blk_data->qs[ci * (blk_len / 2) + bi]; + int8_t b0 = static_cast<int8_t>(qs_byte & 0x0F); + int8_t b1 = static_cast<int8_t>((qs_byte >> 4) & 0x0F); + + // Extract high bits from qh + uint8_t qh_byte0 = quant_b_blk_data->qh[ci * 4 + (2 * bi) / 8]; + uint8_t qh_byte1 = quant_b_blk_data->qh[ci * 4 + (2 * bi + 1) / 8]; + uint8_t h0 = (qh_byte0 >> ((2 * bi) % 8)) & 1; + uint8_t h1 = (qh_byte1 >> ((2 * bi + 1) % 8)) & 1; + + b0 |= (h0 << 4); + b1 |= (h1 << 4); + + acc += static_cast<int32_t>(a0) * static_cast<int32_t>(b0) + + static_cast<int32_t>(a1) * static_cast<int32_t>(b1); + } + output[ci + mi * NB_COLS] += static_cast<float>(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } + } +} + +template <size_t MB_ROWS, size_t NB_COLS> +void gemm_kernel_i8mxfp4_mrow_ref(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // blk_len is expected to be 32 (QK_MXFP4) + // quant_b_zp is unused for MXFP4 (symmetric quantization) + GGML_UNUSED(quant_b_zp); + + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + + float output[MB_ROWS * NB_COLS] = { 0 }; + + using blk_type = nrow_block_mxfp4<NB_COLS>; + blk_type * quant_b_blk_data = (blk_type *) quant_b_data; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + + int8_t * a_data = (int8_t *) quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS; + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < NB_COLS; ci++) { + output[ci + mi * NB_COLS] = 0; + } + } + + for (size_t ki = 0; ki < k_blks; ki++, quant_b_blk_data++, a_data += a_nrow_block_stride) { + float * a_scale_row = (float *) (a_data - sizeof(float) * MB_ROWS - sizeof(int16_t) * MB_ROWS); + int16_t * a_sum_row = (int16_t *) (a_data - sizeof(int16_t) * MB_ROWS); + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + float a_scale = a_scale_row[mi]; + + for (size_t ci = 0; ci < NB_COLS; ci++) { + float b_scale = GGML_E8M0_TO_FP32_HALF(quant_b_blk_data->e[ci]); + + // Read 32 sign bits for this column + uint32_t sign_bits; + memcpy(&sign_bits, &quant_b_blk_data->qh[ci * 4], 4); + + int32_t acc = 0; + for (size_t bi = 0; bi < blk_len / 2; bi++) { + int8_t a0 = a_data[mi * blk_len + 2 * bi]; + int8_t a1 = a_data[mi * blk_len + 2 * bi + 1]; + + // qs[ci*16 + bi] stores abs(vals[bi*2]) in low 4 bits + // and abs(vals[bi*2+1]) in high 4 bits + uint8_t qs_byte = quant_b_blk_data->qs[ci * 16 + bi]; + int8_t b_abs0 = static_cast<int8_t>(qs_byte & 0x0F); + int8_t b_abs1 = static_cast<int8_t>((qs_byte >> 4) & 0x0F); + + // Extract sign bits: bit (2*bi) for vals[2*bi], bit (2*bi+1) for vals[2*bi+1] + int8_t b0 = (sign_bits >> (2 * bi)) & 1 ? -b_abs0 : b_abs0; + int8_t b1 = (sign_bits >> (2 * bi + 1)) & 1 ? -b_abs1 : b_abs1; + + acc += static_cast<int32_t>(a0) * static_cast<int32_t>(b0) + + static_cast<int32_t>(a1) * static_cast<int32_t>(b1); + } + output[ci + mi * NB_COLS] += static_cast<float>(acc) * a_scale * b_scale; + } + } + } + + for (size_t mi = 0; mi < MB_ROWS; mi++) { + for (size_t ci = 0; ci < nb_real; ci++) { + c_ptr[mi * ldc + ci] = output[mi * NB_COLS + ci]; + } + } + } +} + +void gemm_kernel_i8i2k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k<NB_COLS>; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "addi %[A], %[A], 4 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 32 \n\t" // A data addr + "addi s3, %[B], 0 \n\t" + "vxor.vv v30, v29, v29 \n\t" // tmp result + + "INNER_K_LOOP%=: \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vxor.vv v2, v2, v2 \n\t" + "vxor.vv v3, v3, v3 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + + // A data, 1x64@i8 + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v4, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v5, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetivli t0, 16, e8, mf4 \n\t" + "vle8.v v6, (t3) \n\t" + "addi t3, t3, 16 \n\t" + + "vsetvli t0, x0, e64, mf2 \n\t" + "vslideup.vi v3, v4, 2 \n\t" + "vslideup.vi v28, v5, 4 \n\t" + "vslideup.vi v29, v6, 6 \n\t" + + // init the accumu to zero + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v18, v18 \n\t" + "vxor.vv v22, v18, v18 \n\t" + "vxor.vv v24, v18, v18 \n\t" + "vxor.vv v26, v18, v18 \n\t" + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu v20, v2, v8, i8 \n\t" + "vmadotsu v22, v2, v12, i8 \n\t" + "vmadotsu v24, v2, v16, i8 \n\t" + "vmadotsu v26, v2, v4, i8 \n\t" + + "vmadotsu v20, v3, v9, i8 \n\t" + "vmadotsu v22, v3, v13, i8 \n\t" + "vmadotsu v24, v3, v17, i8 \n\t" + "vmadotsu v26, v3, v5, i8 \n\t" + + "vmadotsu v20, v28, v10, i8 \n\t" + "vmadotsu v22, v28, v14, i8 \n\t" + "vmadotsu v24, v28, v18, i8 \n\t" + "vmadotsu v26, v28, v6, i8 \n\t" + + "vmadotsu v20, v29, v11, i8 \n\t" + "vmadotsu v22, v29, v15, i8 \n\t" + "vmadotsu v24, v29, v19, i8 \n\t" + "vmadotsu v26, v29, v7, i8 \n\t" + + "vand.vi v10, v0, 0xf \n\t" // scale + "vwadd.vx v12, v10, x0 \n\t" + "vsetvli t0, x0, e16, m2 \n\t" + "vwadd.vx v16, v12, x0 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v6, v2, v4, 3 \n\t" // 0,1 + "vpack.vv v8, v3, v5, 3 \n\t" // 2,3 + + // mul scale + "vmacc.vv v30, v6, v16 \n\t" + "vmacc.vv v30, v7, v17 \n\t" + "vmacc.vv v30, v8, v18 \n\t" + "vmacc.vv v30, v9, v19 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + // load zp B + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v4, (s3) \n\t" + "vsrl.vi v8, v4, 4 \n\t" // zp + + // asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + + "vsetvli t0, x0, e16, mf4 \n\t" + "vle16.v v2, (%[A]) \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vnsrl.wi v12, v2, 0 \n\t" // low 8 + "vnsra.wi v13, v2, 8 \n\t" // high 8 + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v20, v13, v8, i8 \n\t" + "vmadotsu v22, v13, v9, i8 \n\t" + "vmadotsu v24, v13, v10, i8 \n\t" + "vmadotsu v26, v13, v11, i8 \n\t" + + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vsll.vi v26, v26, 8 \n\t" + + "vmadotu v20, v12, v8, i8 \n\t" + "vmadotu v22, v12, v9, i8 \n\t" + "vmadotu v24, v12, v10, i8 \n\t" + "vmadotu v26, v12, v11, i8 \n\t" + + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v4, v24, v26, 2 \n\t" + "vpack.vv v28, v2, v4, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v0, (t2) \n\t" // scale16 + "addi t2, t2, 64 \n\t" + "vle16.v v1, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v2, v0 \n\t" + "vfwcvt.f.f.v v4, v1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v30, v30 \n\t" + "vfcvt.f.x.v v28, v28 \n\t" + "addi %[B], t2, 64 \n\t" + "mv %[A], t3 \n\t" + + "vfmul.vv v30, v30, v2 \n\t" // mul scale16 + "vfmacc.vv v30, v28, v4 \n\t" // + mul zero16 + "vfmacc.vf v31, fa0, v30 \n\t" + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3"); + } +} + +void gemm_kernel_i8i2k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + using blk_type = nrow_block_q2_k<NB_COLS>; + + int64_t b_ncol_block_stride = sizeof(blk_type) * k_blks; + _Float16 scale = 0.0625f; + _Float16 scale_1 = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_ncol_block_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = (float *) c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v31, v31 \n\t" // init result + "vxor.vv v29, v31, v31 \n\t" + "vxor.vv v30, v31, v31 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv s1, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + "li t1, 4 \n\t" + "addi t2, %[B], 512 \n\t" // B data addr + "addi t3, %[A], 128 \n\t" // A data addr + "addi s4, t2, 1024 \n\t" // scale16 addr + "addi s4, s4, 1024 \n\t" // TODO + "addi s3, %[B], 0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s4) \n\t" // load scale16 + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v22, v1, v1, 3 \n\t" + + "addi s4, t3, 256 \n\t" // addr 1 + "addi s5, t3, 512 \n\t" // addr 2 + "addi s6, t3, 768 \n\t" // addr 3 + + // init the accu to 0 + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v25, v25, v25 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v27, v27, v27 \n\t" + + "INNER_K_LOOP%=: \n\t" + // load scale B + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (%[B]) \n\t" + "addi %[B], %[B], 128 \n\t" + "vand.vi v1, v1, 0xf \n\t" + + "vfwcvt.f.x.v v20, v1 \n\t" // f16 scale B + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v0, v20, v22 \n\t" // mul scale16 + "vfmul.vv v1, v21, v22 \n\t" // mul scale16 + "vfmul.vf v0, v0, %[SCALE] \n\t" // mul magic + "vfmul.vf v1, v1, %[SCALE] \n\t" // mul magic + + // A data, 4x64@i8 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (t3) \n\t" + "addi t3, t3, 64 \n\t" + "vle8.v v3, (s4) \n\t" + "addi s4, s4, 64 \n\t" + "vle8.v v4, (s5) \n\t" + "addi s5, s5, 64 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 64 \n\t" + + // 4x64 => 4x16x4 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v6, v2, v3, 1 \n\t" + "vpack.vv v8, v4, v5, 1 \n\t" + "vpack.vv v2, v6, v8, 2 \n\t" // 0, 2 + + "vpack.vv v20, v2, v2, 3 \n\t" // 1 + "vor.vv v23, v21, v21 \n\t" + "vpack.vv v20, v3, v3, 3 \n\t" // 3 + + // B data, 32x64@i2 + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (t2) \n\t" + "addi t2, t2, 512 \n\t" + "vand.vi v8, v4, 0x3 \n\t" // 0-15 + "vsrl.vi v9, v4, 2 \n\t" + "vsrl.vi v10, v4, 4 \n\t" + "vsrl.vi v11, v4, 6 \n\t" // 48-63 + "vand.vi v9, v9, 0x3 \n\t" // 16-31 + "vand.vi v10, v10, 0x3 \n\t" // 32-47 + + "vand.vi v12, v5, 0x3 \n\t" // 0-15 + "vsrl.vi v13, v5, 2 \n\t" + "vsrl.vi v14, v5, 4 \n\t" + "vsrl.vi v15, v5, 6 \n\t" // 48-63 + "vand.vi v13, v13, 0x3 \n\t" // 16-31 + "vand.vi v14, v14, 0x3 \n\t" // 32-47 + + "vand.vi v16, v6, 0x3 \n\t" // 0-15 + "vsrl.vi v17, v6, 2 \n\t" + "vsrl.vi v18, v6, 4 \n\t" + "vsrl.vi v19, v6, 6 \n\t" // 48-63 + "vand.vi v17, v17, 0x3 \n\t" // 16-31 + "vand.vi v18, v18, 0x3 \n\t" // 32-47 + + "vand.vi v4, v7, 0x3 \n\t" // 0-15 + "vsrl.vi v5, v7, 2 \n\t" + "vsrl.vi v6, v7, 4 \n\t" + "vsrl.vi v7, v7, 6 \n\t" // 48-63 + "vand.vi v5, v5, 0x3 \n\t" // 16-31 + "vand.vi v6, v6, 0x3 \n\t" // 32-47 + + // i2 * i8 vmadot + "vsetvli t0, x0, e8, m1 \n\t" + "vmadotsu.hp v24, v2, v8, v0, 0, i8 \n\t" + "vmadotsu.hp v25, v2, v12, v0, 1, i8 \n\t" + "vmadotsu.hp v26, v2, v16, v0, 2, i8 \n\t" + "vmadotsu.hp v27, v2, v4, v0, 3, i8 \n\t" + + "vmadotsu.hp v24, v23, v9, v0, 4, i8 \n\t" + "vmadotsu.hp v25, v23, v13, v0, 5, i8\n\t" + "vmadotsu.hp v26, v23, v17, v0, 6, i8\n\t" + "vmadotsu.hp v27, v23, v5, v0, 7, i8 \n\t" + + "vmadotsu.hp v24, v3, v10, v1, 0, i8 \n\t" + "vmadotsu.hp v25, v3, v14, v1, 1, i8 \n\t" + "vmadotsu.hp v26, v3, v18, v1, 2, i8 \n\t" + "vmadotsu.hp v27, v3, v6, v1, 3, i8 \n\t" + + "vmadotsu.hp v24, v21, v11, v1, 4, i8\n\t" + "vmadotsu.hp v25, v21, v15, v1, 5, i8\n\t" + "vmadotsu.hp v26, v21, v19, v1, 6, i8\n\t" + "vmadotsu.hp v27, v21, v7, v1, 7, i8 \n\t" + + "addi t1, t1, -1 \n\t" + "bgtz t1, INNER_K_LOOP%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v2, v24, v25, 1 \n\t" + "vpack.vv v4, v26, v27, 1 \n\t" + "vpack.vv v6, v2, v4, 2 \n\t" // 0,1,2,3 + + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vxor.vv v24, v24, v24 \n\t" + // load zp B, 16x8x4@int4 + "vsetvli t0, x0, e8, m4 \n\t" + "vle8.v v0, (s3) \n\t" + "vsrl.vi v0, v0, 4 \n\t" // zp + + // 4x16@int16 + "vsetvli t0, x0, e16, m1 \n\t" // a sum + "vle16.v v12, (%[A]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnsrl.wi v10, v12, 0 \n\t" // low 8 + "vnsra.wi v11, v12, 8 \n\t" // high 8 + + // asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v18, v11, v0, i8 \n\t" + "vmadotsu v20, v11, v1, i8 \n\t" + "vmadotsu v22, v11, v2, i8 \n\t" + "vmadotsu v24, v11, v3, i8 \n\t" + "vsll.vi v18, v18, 8 \n\t" + "vsll.vi v20, v20, 8 \n\t" + "vsll.vi v22, v22, 8 \n\t" + "vsll.vi v24, v24, 8 \n\t" + "vmadotu v18, v10, v0, i8 \n\t" + "vmadotu v20, v10, v1, i8 \n\t" + "vmadotu v22, v10, v2, i8 \n\t" + "vmadotu v24, v10, v3, i8 \n\t" + + "vpack.vv v10, v18, v20, 2 \n\t" + "vpack.vv v12, v22, v24, 2 \n\t" + "vpack.vv v14, v10, v12, 3 \n\t" + "vpack.vv v16, v11, v13, 3 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t2, t2, 64 \n\t" + "vle16.v v20, (t2) \n\t" // zero16 + "vfwcvt.f.f.v v22, v20 \n\t" + + // mul 1/magic + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmul.vf v0, v6, %[SCALE_1] \n\t" + "vfwmul.vf v2, v7, %[SCALE_1] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v14, v14 \n\t" + "vfcvt.f.x.v v15, v15 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + + "addi %[B], t2, 64 \n\t" + "mv %[A], s6 \n\t" + + "vfmacc.vv v0, v14, v22 \n\t" // + mul zero16 + "vfmacc.vv v1, v15, v22 \n\t" + "vfmacc.vv v2, v16, v22 \n\t" + "vfmacc.vv v3, v17, v22 \n\t" + + "vfmacc.vf v28, fa0, v0 \n\t" // mul a scale + "vfmacc.vf v29, fa1, v1 \n\t" + "vfmacc.vf v30, fa2, v2 \n\t" + "vfmacc.vf v31, fa3, v3 \n\t" + + "addi s1, s1, -1 \n\t" + "bgtz s1, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "vse32.v v29, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v30, (t1) \n\t" + "add t1, t1, %[LDC] \n\t" + "vse32.v v31, (t1) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks), [LDC] "r"(ldc * 4), [SCALE] "f"(scale), [SCALE_1] "f"(scale_1) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i3k_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; //only support 32 in ASM + using blk_type = nrow_block_q3_k<NB_COLS>; + + const blk_type * b_base = reinterpret_cast<const blk_type *>(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride; + int64_t b_ncol_block_stride = sizeof(blk_type); + + // Constants used by q3_k scaling in HP branch: + // - k_q3k_scale_step: per-nibble scale factor (1/16). + // - k_a_scale_post_mul: A_scale needs an extra *16 at the end (pairs with 1/16 above). + const _Float16 k_q3k_scale_step = (_Float16) 0.0625f; // 1 / 16 + const float k_a_scale_post_mul = 16.0f; + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; +#if 0 + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1 32bit + // Asum int16 * 16 256bit + // A M1K256 int8 2048bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+32 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" // clear acc + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + + // ordinary vmadot: vle*10 vecIns*78 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // K0-15 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // load B qs chunk (128B per K16, 16 times => 2048B) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (s3) \n\t" + "addi s3, s3, 64 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast<int32_t>(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K16-31 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" // N0-N31 in v16 + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast<int32_t>(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v8, 6 \n\t" + "vsrl.vi v13, v9, 6 \n\t" + "vsrl.vi v14, v10, 6 \n\t" + "vsrl.vi v15, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast<int32_t>(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v1, v1, 2 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v1, v12, i8 \n\t" + "vmadot v18, v1, v13, i8 \n\t" + "vmadot v20, v1, v14, i8 \n\t" + "vmadot v22, v1, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v18, v2, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v19, v18, x0 \n\t" // int8 -> int16 + + // static_cast<int32_t>(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v30, v16, v19 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // load A scale (fp32) and advance A to next superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // t4 = next B blk base + "addi s3, s2, 4+32 \n\t" + + // load B scales16[32] (fp16) at end of qs region + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v2, (s6) \n\t" + + // pointer modify + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v30 \n\t" + "vfmul.vf v1, v24, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast<float>(qsum) * a_scale * b_scale; + "vfmacc.vv v31, v1, v26 \n\t" + + // next K-superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v30, v0, v0 \n\t" // clear acc of K256 + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "s2", "s3", "s4", "s5", "s6", "s7"); +#else + + __asm__ volatile( + // ========================= + // Kernel overview (M1 x N32) + // ========================= + // Process one output row (M=1) and 32 columns (N=32) per call. + // + // Loop structure: + // - Outer loop: K superblocks of size K=256 (k_blks times) + // - Each K256 superblock is broken into 4 x K64 + // - Each K64 is processed as 4 x K16 "sub-blocks" (via unpack+dot) + // + // Data layout (high level): + // A (q8k K=256, per superblock): + // [ fp32 a_scale ][ int16 a_sum[16] ][ int8 a_qs[256] ] + // B (nrow_block_q3_k<32>, per superblock): + // [ int8 scales[32*16] ][ hmask[1024] ][ qs[2048] ][ fp16 scales16[32] ] + // + // Registers/pointers: + // s2: pA (points at A superblock header; used to load fp32 a_scale) + // s3: pA_qs (points at A int8 data within the current superblock) + // s4: pB_scales (points at B int8 per-K16 scales) + // s5: pB_hmask (points at B sign mask area) + // s6: pB_qs (points at B 2-bit packed qs area) + // s8: pB_scales16 (points at B fp16 scales16[32] at the end of block) + // s7: pB_base (base pointer to current B block; used for block-to-block stride) + + // t2 = number of K256 superblocks + "mv t2, %[KBLKS] \n\t" + // t3 = number of K64 chunks per K256 superblock (256 / 64) + "li t3, 4 \n\t" + + // A pointers + "mv s2, %[pA] \n\t" // s2 = pA_superblock (a_scale at +0) + "addi s3, %[pA], 4+32 \n\t" // s3 = pA_qs (skip a_scale + a_sum[16]) + + // B pointers for nrow_block_q3_k<32> + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales[32*16]) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + // scales16 is at the end of the block: qs(2048) after hmask + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" // s8 = pB_scales16 (fp16 scales16[32]) + "mv s7, %[pB] \n\t" // s7 = pB_base (for next-block address calc) + + // v31: final FP32 accumulator for N=32 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ---- Preload B scales16[32] and build FP16 scale vector used by vmadot.hp ---- + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" // load fp16 scales16[32] + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" // broadcast/pack to match lanes + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" // v30 = scales16 * (1/16) + + // v24-v27: fp16 partial accumulators for a K64 chunk (vmadot.hp outputs) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + // HP vmadot: vle*10 vecIns*38 vmadot.hp*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" // loop over K256 superblocks + "K64_LPST%=: \n\t" // loop over 4 x K64 chunks + + // ------------------------------------------------------------ + // K0-15: load B scales + {hmask, qs} + A data; unpack and dot + // ------------------------------------------------------------ + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v2, (s4) \n\t" // B int8 scales for this K16 + "addi s4, s4, 128 \n\t" + + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" // B hmask for this K16 + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v3, (s3) \n\t" // A int8 data for this K16 + "addi s3, s3, 64 \n\t" + + // Convert B int8 scales to FP16 and apply scales16*(1/16) + "vsetvli t0, x0, e8, m1 \n\t" + "vfwcvt.f.x.v v28, v2 \n\t" // int8 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vv v1, v28, v30 \n\t" // v1: FP16 scale vector for vmadot.hp + "vfmul.vv v29, v29, v30 \n\t" + + // Unpack B 2-bit qs + hmask -> signed int8 in v12..v15 + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + // (Next K16 unpack path uses a fresh hmask load) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Prepare another group from packed qs (bit shifts) + apply sign from hmask + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 4 \n\t" + "vsll.vi v9, v5, 4 \n\t" + "vsll.vi v10, v6, 4 \n\t" + "vsll.vi v11, v7, 4 \n\t" + "vsrl.vi v16, v8, 6 \n\t" + "vsrl.vi v17, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v18, v10, 6 \n\t" + "vsrl.vi v19, v11, 6 \n\t" + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v16, v16, -4, v0.t \n\t" + + // A shift for the second dot within this K64 + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + // Dot products with FP16 scaling (accumulate into v24..v27) + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v12, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v13, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v14, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v15, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v16, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v17, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v18, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v19, v1, 7, i8 \n\t" + + // (K32-47 / K48-63 blocks continue unchanged...) + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vmv.v.v v1, v29 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v3, v3, 4 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v8, v4, 2 \n\t" + "vsll.vi v9, v5, 2 \n\t" + "vsll.vi v10, v6, 2 \n\t" + "vsll.vi v11, v7, 2 \n\t" + + "vsrl.vi v20, v8, 6 \n\t" + "vsrl.vi v21, v9, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v22, v10, 6 \n\t" + "vsrl.vi v23, v11, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v20, v20, -4, v0.t \n\t" + + // K48-63 + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v8, v4, 6 \n\t" + "vsrl.vi v9, v5, 6 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v10, v6, 6 \n\t" + "vsrl.vi v11, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v8, v8, -4, v0.t \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, mf2 \n\t" + "vslidedown.vi v2, v3, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot.hp v24, v3, v20, v1, 0, i8 \n\t" + "vmadot.hp v25, v3, v21, v1, 1, i8 \n\t" + "vmadot.hp v26, v3, v22, v1, 2, i8 \n\t" + "vmadot.hp v27, v3, v23, v1, 3, i8 \n\t" + "vmadot.hp v24, v2, v8, v1, 4, i8 \n\t" + "vmadot.hp v25, v2, v9, v1, 5, i8 \n\t" + "vmadot.hp v26, v2, v10, v1, 6, i8 \n\t" + "vmadot.hp v27, v2, v11, v1, 7, i8 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ---- End of K64 chunk: reduce fp16 accumulators -> fp32 and scale by A ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v24, v25, 1 \n\t" + "vpack.vv v14, v26, v27, 1 \n\t" + "vpack.vv v16, v12, v14, 2 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v26, v16 \n\t" // fp16 -> fp32 vector (qsum * b_scales) + + // Load A scale and advance A pointer to next K256 superblock + "flw f0, (s2) \n\t" + "addi s2, s2, 4+32+256 \n\t" + "add t4, s7, %[B_STR] \n\t" // next B block base + "addi s3, s2, 4+32 \n\t" // reset A data pointer for next block + + // Advance B pointers to next K256 superblock + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s8, s6, 1024 \n\t" + "addi s8, s8, 1024 \n\t" + "addi s7, t4, 0 \n\t" + "addi t2, t2, -1 \n\t" + + // Final per-block scaling: a_scale * 16.0f + "fmul.s f0, f0, %[a_post_mul] \n\t" + // acc += (qsum * b_scales) * (a_scale*16) + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vf v31, f0, v26 \n\t" + + "beqz t2, BLK_LPND%= \n\t" + + // Preload next block's scales16 and rebuild v30 for vmadot.hp + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v1, (s8) \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v26, v1, v1, 3 \n\t" + "vmv.v.v v17, v26 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vfmul.vf v30, v17, %[q3_step] \n\t" + + // Reset fp16 partial accumulators for next K64 loop(s) + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v25, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v27, v16, v16 \n\t" + + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v31, (%[pC]) \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [q3_step] "f"(k_q3k_scale_step), + [a_post_mul] "f"(k_a_scale_post_mul) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"); +#endif + } +} + +void gemm_kernel_i8i3k_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + using blk_type = nrow_block_q3_k<32>; + constexpr size_t NB_COLS = 32; //only support 32 in ASM + + const blk_type * b_base = reinterpret_cast<const blk_type *>(quant_b_data); + + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t b_ncol_block_stride = sizeof(blk_type); + + for (size_t ni = 0; ni < count_n; ni += NB_COLS, c_ptr += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + const blk_type * quant_b_blk_data = b_base + (ni / NB_COLS) * k_blks; + + //------------------------------------------------------------------------------ + // A format + // Ascale fp32 * 1* 4row 128bit + // Asum int16 * 16 4row 1024bit + // A M1K256 int8 4row 8192bit + //------------------------------------------------------------------------------ + // B format + // B_scl uint8*N32*16 4096bit + // B_Hmask N32K16*16 1bit 8192bit + // B_Qs N32K16*16 2bit 16384bit + // B scl16 fp16 * N32 512bit; + //------------------------------------------------------------------------------ + //bias always be nullptr + __asm__ volatile( + // t2 = k_blks (each is K256 superblock) + "mv t2, %[KBLKS] \n\t" + // t3 = 256/64 = 4 (K64 iterations per superblock) + "li t3, 4 \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 16+128 \n\t" // s3 = pAData, (pA+AScl+ASum) + + // B block layout for nrow_block_q3_k<32>: + // scales: 512B, hmask: 1024B, qs: 2048B, scales16: 64B + "addi s5, %[pB], 32*16 \n\t" // s5 = pB_hmask (skip scales) + "mv s4, %[pB] \n\t" // s4 = pB_scales + "addi s6, s5, 1024 \n\t" // s6 = pB_qs (skip hmask) + "mv s7, %[pB] \n\t" // s7 = pB_base + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // v24-v27: K256 temp accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "vxor.vv v28, v0, v0 \n\t" // v28-v31: final accumulator + "vxor.vv v29, v0, v0 \n\t" + "vxor.vv v30, v0, v0 \n\t" + "vxor.vv v31, v0, v0 \n\t" + + // ordinary vmadot: vle*13 vecIns*96 vmadot*16 + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + "K64_LPST%=: \n\t" + + // ========== K0-15: First K16 sub-block ========== + // Load B INT8 scale factors (32 cols × 16 K16 blocks) + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v8, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // Load B quantized data (32 cols × 16 elements × 2bit, stored in 4 groups) + "vle8.v v4, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v5, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v6, (s6) \n\t" + "addi s6, s6, 128 \n\t" + "vle8.v v7, (s6) \n\t" + "addi s6, s6, 128 \n\t" + + // Load B hmask (32 cols × 16bit sign mask) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // Load A data (4 rows × 16 elements × INT8) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v12, (s3) \n\t" + "addi s3, s3, 256 \n\t" // Jump to next row + "vle8.v v13, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v14, (s3) \n\t" + "addi s3, s3, 256 \n\t" + "vle8.v v15, (s3) \n\t" + "addi s3, s3, -768+64 \n\t" // Back to first row, advance 16 elements + + // Pack A data: merge 4 rows into 2 vectors + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v13, 1 \n\t" + "vpack.vv v18, v14, v15, 1 \n\t" + "vpack.vv v2, v16, v18, 2 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vand.vi v12, v4, 0x3 \n\t" + "vand.vi v13, v5, 0x3 \n\t" + "vand.vi v14, v6, 0x3 \n\t" + "vand.vi v15, v7, 0x3 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" // 4 rows × cols 0-7 + "vmadot v18, v2, v13, i8 \n\t" // 4 rows × cols 8-15 + "vmadot v20, v2, v14, i8 \n\t" // 4 rows × cols 16-23 + "vmadot v22, v2, v15, i8 \n\t" // 4 rows × cols 24-31 + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" // Merge cols 0-15 + "vpack.vv v14, v20, v22, 2 \n\t" // Merge cols 16-31 + "vpack.vv v16, v12, v14, 3 \n\t" // Inter-row results (INT16) + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // INT8 → INT16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // INT16 → INT32 + + // Accumulate to K256 accumulator: qsum * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" // Row 0 + "vmacc.vv v25, v17, v23 \n\t" // Row 1 + "vmacc.vv v26, v18, v23 \n\t" // Row 2 + "vmacc.vv v27, v19, v23 \n\t" + + // ========== K16-31, K32-47, K48-63: Similar processing ========== + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v2, v2, 8 \n\t" + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 4 \n\t" + "vsll.vi v13, v5, 4 \n\t" + "vsll.vi v14, v6, 4 \n\t" + "vsll.vi v15, v7, 4 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v2, v12, i8 \n\t" + "vmadot v18, v2, v13, i8 \n\t" + "vmadot v20, v2, v14, i8 \n\t" + "vmadot v22, v2, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast<int32_t>(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + //K32-47 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + + // unpack 2-bit qs + hmask -> signed values + "vsetvli t0, x0, e8, m1 \n\t" + "vsll.vi v12, v4, 2 \n\t" + "vsll.vi v13, v5, 2 \n\t" + "vsll.vi v14, v6, 2 \n\t" + "vsll.vi v15, v7, 2 \n\t" + "vnot.v v0, v0 \n\t" + + "vsrl.vi v12, v12, 6 \n\t" + "vsrl.vi v13, v13, 6 \n\t" + "vsrl.vi v14, v14, 6 \n\t" + "vsrl.vi v15, v15, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast<int32_t>(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + // K48-63 + // load B scales (32 bytes per K16, 16 times => 512B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v8, v8, 4 \n\t" + + // load B hmask chunk (64B per K16, 16 times => 1024B) + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s5) \n\t" + "addi s5, s5, 64 \n\t" + + // load A data (16 bytes per K16, 16 times => 256B) + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v3, v3, 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vnot.v v0, v0 \n\t" + "vsrl.vi v12, v4, 6 \n\t" + "vsrl.vi v13, v5, 6 \n\t" + "vsrl.vi v14, v6, 6 \n\t" + "vsrl.vi v15, v7, 6 \n\t" + + "vsetvli t0, x0, e8, m4 \n\t" + "vadd.vi v12, v12, -4, v0.t \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v12, i8 \n\t" + "vmadot v18, v3, v13, i8 \n\t" + "vmadot v20, v3, v14, i8 \n\t" + "vmadot v22, v3, v15, i8 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v12, v16, v18, 2 \n\t" + "vpack.vv v14, v20, v22, 2 \n\t" + "vpack.vv v16, v12, v14, 3 \n\t" // N0-N31 in v16 + "vpack.vv v18, v13, v15, 3 \n\t" + + // apply B int8 scales (-32 bias has been applyed) + "vsetvli t0, x0, e8, mf4 \n\t" + "vwadd.vx v21, v8, x0 \n\t" // int8 -> int16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v23, v21, x0 \n\t" // int8 -> int16 + + // static_cast<int32_t>(qsum) * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vmacc.vv v24, v16, v23 \n\t" + "vmacc.vv v25, v17, v23 \n\t" + "vmacc.vv v26, v18, v23 \n\t" + "vmacc.vv v27, v19, v23 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, K64_LPST%= \n\t" + "K64_LPND%=: \n\t" + + // ========== K256 superblock complete, apply scale factors ========== + // Load A's 4 row scale factors (FP32) + "flw f0, (s2) \n\t" + "flw f1, 4(s2) \n\t" + "flw f2, 8(s2) \n\t" + "flw f3, 12(s2) \n\t" + "add s2, s2, %[A_STR] \n\t" // Advance to next superblock + "add t4, s7, %[B_STR] \n\t" // t4 = next B block address + "addi s3, s2, (4+32)*4 \n\t" + + // Load B FP16 global scale factors (32 cols) + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (s6) \n\t" + + // Update B pointers to next block + "addi s5, t4, 32*16 \n\t" + "mv s4, t4 \n\t" + "addi s6, s5, 32*32 \n\t" + "addi s7, t4, 0 \n\t" + + // ========== Type conversion and final scaling ========== + // FP16 → FP32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v9, v8 \n\t" + + // INT32 → FP32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + + // Compute a_scale * b_scale (4 rows) + "vfmul.vf v12, v9, f0 \n\t" + "vfmul.vf v13, v9, f1 \n\t" + "vfmul.vf v14, v9, f2 \n\t" + "vfmul.vf v15, v9, f3 \n\t" + + // Final accumulation: result += qsum * a_scale * b_scale + "vsetvli t0, x0, e32, m1 \n\t" + "vfmacc.vv v28, v12, v24 \n\t" + "vfmacc.vv v29, v13, v25 \n\t" + "vfmacc.vv v30, v14, v26 \n\t" + "vfmacc.vv v31, v15, v27 \n\t" + + // Prepare for next K superblock + "addi t2, t2, -1 \n\t" + "vxor.vv v24, v0, v0 \n\t" // Clear K256 accumulator + "vxor.vv v25, v0, v0 \n\t" + "vxor.vv v26, v0, v0 \n\t" + "vxor.vv v27, v0, v0 \n\t" + "li t3, 4 \n\t" + "bgtz t2, BLK_LPST%= \n\t" + + "BLK_LPND%=: \n\t" + + // ========== Store results (4 rows × 32 cols) ========== + "mv t5, %[pC] \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v28, (%[pC]) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v29, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v30, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "vse32.v v31, (t5) \n\t" + "add t5, t5, %[LDC] \n\t" + "FUNC_END%=: \n\t" + + : + : [KBLKS] "r"(k_blks), [NBLKS] "r"(nb_real), [pA] "r"(quant_a_ptr), [pB] "r"(quant_b_blk_data), + [pC] "r"(c_ptr), [B_STR] "r"(b_ncol_block_stride), [A_STR] "r"(a_nrow_block_stride), [LDC] "r"(ldc * 4) + : "cc", "memory", "t0", "t2", "t3", "t4", "t5", "f0", "f1", "f2", "f3", "s2", "s3", "s4", "s5", "s6", "s7"); + } +} + +void gemm_kernel_i8i4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... +#if 0 + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast<float>(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*7 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v28, 8 \n\t" // Bzp u8 -> u16 + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vfcvt.f.x.v v16, v26 \n\t" // zp i16 -> fp16 + "vadd.vi v18, v16, 0 \n\t" + "vadd.vi v20, v16, 0 \n\t" + "vadd.vi v22, v16, 0 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + // mac result * b_scale; f16*f16->f32 + "vfwmul.vv v31, v30, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast<float>(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + +#endif + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // b data + n * k_blks * sizeof(uint8_t) + // b zp + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 1024bit + // 4VRF, N32K32, 4096bit + // Bscale fp16 * N32 512bit; + // Bzp uint8_t * N32 256bit; + // || scl*32..(fp16) | zp*32(uint8) | blk0 blk1 ... blk31 || scl*32..(fp16) ... + // || Element || Element ... + + //bias always be nullptr +#if 0 + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*21 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu v16, v10, v4, i4 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadotsu v18, v10, v5, i4 \n\t" // M0 N8 - N15 + "vmadotsu v20, v10, v6, i4 \n\t" // M0 N16 - N23 + "vmadotsu v22, v10, v7, i4 \n\t" // M0 N24 - N31 + + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4 \n\t" + + "vmadotu v16, v8, v4, i4 \n\t" + "vmadotu v18, v8, v5, i4 \n\t" + "vmadotu v20, v8, v6, i4 \n\t" + "vmadotu v22, v8, v7, i4 \n\t" + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v24, v28, t2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v24 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast<float>(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#else + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBdata, (pB+BScl+Bzp) + "mv s6, %[pC] \n\t" + + "vsll.vi v1, v0, 4 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + + // vmadot hp: vle*6 flw*1 vecIns*14 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+96 \n\t" // 1024bit + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v30, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v31, (s4) \n\t" // B zp 32Row*uint8 = 256bit + "addi s4, s4, 32+128*4 \n\t" + + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "lh t2, 4(s2) \n\t" // A sum of int16 + "addi s2, s2, 6+32 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v24, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v8, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v10, v24, v24, 3 \n\t" // hi4 of A + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadotsu.hp v16, v10, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v18, v10, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v10, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v22, v10, v7, v1, 0, i4 \n\t" + "vmadotu.hp v16, v8, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v18, v8, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v8, v6, v0, 0, i4 \n\t" + "vmadotu.hp v22, v8, v7, v0, 0, i4 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v31, x0 \n\t" // Bzp u8 -> u16 + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v24, v16, v18, 1 \n\t" + "vpack.vv v26, v20, v22, 1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vmul.vx v26, v28, t2 \n\t" // asum*zp i16*i16 + "vfwcvt.f.f.v v22, v30 \n\t" // b_scale fp16 -> fp32 + "vfcvt.f.x.v v18, v26 \n\t" // zp i16 -> fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vfwadd.vv v20, v18, v16 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // mac result * b_scale; f32*f32->f32 + "vfmul.vv v31, v22, v20 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast<float>(qsum * b_scale) * a_scale; + "vfmacc.vf v2, f0, v31 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); +#endif + } + } +} + +void gemm_kernel_i8i4_hp_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t k_subblks_per_superblk = 8; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * k_subblks_per_superblk + + (quant_b_zp ? NB_COLS * k_subblks_per_superblk * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" // init acc to zero + "mv t4, %[BK] \n\t" + "li t0, 0x4c00 \n\t" // 16 in fp16 + "fmv.h.x fa0, t0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + "li t5, 8 \n\t" + "addi t6, %[A], 288 \n\t" // point to blk scale + "flh ft1, (t6) \n\t" + "addi t6, %[A], 272 \n\t" // point to asum + + // init the acc fp16 + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v16, v18, v18 \n\t" + "vxor.vv v17, v18, v18 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v18, v18 \n\t" + + "INNER_BLK_LOOP%=: \n\t" + // load a sum and scale + "flh fa1, (t6) \n\t" + "addi t6, t6, 2 \n\t" + "flh ft0, (%[A]) \n\t" + "addi %[A], %[A], 2 \n\t" + // load A + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (%[A]) \n\t" // 1x32@i8 + "addi %[A], %[A], 32 \n\t" + + // load scale B and B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v8, (%[B]) \n\t" // b_scale fp16 + "addi %[B], %[B], 64 \n\t" + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vfmul.vf v8, v8, ft0 \n\t" // scale b * scale a + "vfmul.vf v9, v8, fa0 \n\t" + "vfmul.vf v10, v8, fa1 \n\t" // scale b * scale a * asm + "vfwmacc.vf v31, ft1, v10 \n\t" // asum * scale a * scale b * blk scale + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v0, v8, v9, 3 \n\t" + "vsrl.vi v28, v3, 4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vnpack4.vv v2, v3, v3, 3 \n\t" // lo4 of A + "vnpack4.vv v3, v28, v28, 3 \n\t" // hi4 of A + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v16, v3, v4, v0, 4, i4 \n\t" // high 4 + "vmadotsu.hp v17, v3, v5, v0, 5, i4 \n\t" + "vmadotsu.hp v18, v3, v6, v0, 6, i4 \n\t" + "vmadotsu.hp v19, v3, v7, v0, 7, i4 \n\t" + "vmadotu.hp v16, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v17, v2, v5, v0, 1, i4 \n\t" + "vmadotu.hp v18, v2, v6, v0, 2, i4 \n\t" + "vmadotu.hp v19, v2, v7, v0, 3, i4 \n\t" + + "addi t5, t5, -1 \n\t" + "bgtz t5, INNER_BLK_LOOP%= \n\t" + + "vpack.vv v8, v16, v17, 1 \n\t" + "vpack.vv v12, v18, v19, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "addi t4, t4, -1 \n\t" + "vfwmacc.vf v31, ft1, v20 \n\t" + //"vsetvli t0, x0, e32, m1 \n\t" + //"vfmul.vf v31, v31, ft1 \n\t" // blk scale + + // update A ptr + "addi %[A], t6, 2 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v31, (%[DST]) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", + "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", + "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "ft0", "ft1"); + } + } else { + // TODO: support quant_b_zp for i8i4 hp kernel + GGML_ABORT("gemm_kernel_i8i4_hp_m1 with quant_b_zp is not supported yet"); + } +} + +void gemm_kernel_i8i4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; +#if 0 + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vwmul.vx v10, v8, t1 \n\t" // 8*asum + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vrgather.vi v0, v10, 0 \n\t" + "vrgather.vi v1, v10, 1 \n\t" + "vrgather.vi v2, v10, 2 \n\t" + "vrgather.vi v3, v10, 3 \n\t" + + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v1 \n\t" + "vadd.vv v18, v18, v2 \n\t" + "vadd.vv v19, v19, v3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc*4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#else + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + "vsetivli t0, 4, e16, mf2 \n\t" + "vle16.v v8, (%[A]) \n\t" // asum + "addi %[A], %[A], 8 \n\t" + "vsll.vi v8, v8, 3 \n\t" // asum * 8 + "vfcvt.f.x.v v9, v8 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vrgather.vi v10, v9, 0 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v16, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v17, v16, 4 \n\t" + "vnpack4.vv v12, v16, v17, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v16, v10, v10,0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v20, v16, v16,0 \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vpack.vv v18, v20, v20, 0 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + "vfwmul.vv v18, v21, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); +#endif + } + } else { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "li t1, 8 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "mv t4, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2\n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + // load a sum + "lh s1, (%[A]) \n\t" + "lh s2, 2(%[A]) \n\t" + "lh s3, 4(%[A]) \n\t" + "lh s4, 6(%[A]) \n\t" + "addi %[A], %[A], 8 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" // A low u4 + "vupack.vv v2, v12, v12, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu v16, v3, v4, i4 \n\t" // high 4 + "vmadotsu v18, v3, v5, i4 \n\t" + "vmadotsu v20, v3, v6, i4 \n\t" + "vmadotsu v22, v3, v7, i4 \n\t" + "vsll.vi v16, v16, 4 \n\t" + "vsll.vi v18, v18, 4 \n\t" + "vsll.vi v20, v20, 4 \n\t" + "vsll.vi v22, v22, 4 \n\t" + "vmadotu v16, v2, v4, i4 \n\t" // low 4 + "vmadotu v18, v2, v5, i4 \n\t" + "vmadotu v20, v2, v6, i4 \n\t" + "vmadotu v22, v2, v7, i4 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v0, v10, s1 \n\t" + "vwmul.vx v2, v10, s2 \n\t" + "vwmul.vx v4, v10, s3 \n\t" + "vwmul.vx v6, v10, s4 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v16, v16, v0 \n\t" + "vadd.vv v17, v17, v2 \n\t" + "vadd.vv v18, v18, v4 \n\t" + "vadd.vv v19, v19, v6 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi t4, t4, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz t4, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST]\n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", + "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3", "s1", "s2", "s3", "s4"); + } + } +} + +void gemm_kernel_i8i4_hp_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_SUBBLKS_PER_SUPERBLK = 8; + constexpr size_t K_SUBBLK_LEN = 32; + + struct block_q4_0x32_layout { + _Float16 d[NB_COLS]; + uint8_t qs[16 * NB_COLS]; + }; + + GGML_ASSERT(blk_len == 256); + GGML_ASSERT(count_m >= 4); + + // Contract: + // - computes a 4-row x 32-col tile per inner invocation + // - A is q8 HP packed in m4 layout, one logical K256 block at a time + // - B is q4 HP packed in N32 tiles, optionally with a separate zp area + // - tail-N is currently not handled here; the caller must provide full N32 tiles + + const size_t b_superblk_stride = sizeof(block_q4_0x32_layout) * K_SUBBLKS_PER_SUPERBLK + + (quant_b_zp ? NB_COLS * K_SUBBLKS_PER_SUPERBLK * sizeof(uint8_t) : 0); + const size_t b_tile_stride = k_blks * b_superblk_stride; + const size_t a_nrow_block_stride = q8_hp_blk_size(blk_len, true, true) * 4; + const size_t a_subblk_stride = q8_hp_blk_size(K_SUBBLK_LEN, false, false) * 4; + + if (quant_b_zp != nullptr) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the with-zp path. + // + // A: M4 x K256 q8 HP block + // - split into 8 x K32 subblocks + // - each K32 subblock is 136B: + // 8B = 4 x fp16 row scales + // 128B = 4 x int8[32] row payloads + // - trailer after 8 subblocks is 72B: + // 4 rows x fp16[8] a_sum values, indexed as [row][ksi] + // 4 rows x fp16 scale_avg tail + // + // B: N32 x K256 q4 HP block with explicit zp area + // - each K32 subblock is 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload for 32 columns x 32 k-elements + // - zp is stored separately, not interleaved with the 576B payload block + // - one K256 superblock is laid out as: + // 8 x (scale + qs) blocks = 4608B + // 8 x zp[32] = 256B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // row1/row2/row3 are at +16/+32/+48 bytes + // - s5: current B (scale + qs) K32 subblock base + // - s6: current B zp[32] base for this ksi + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576, B_zp += 32 + // - per ki : skip the 72B A trailer and advance B to the next 4864B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + const _Float16 hp_scale_0125 = (_Float16) 0.125f; + + // VPR grouping used below: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : zp u8 / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "li t1, 4608 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + // load Bzp[32] for the current ksi from the dedicated zp area + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (s6) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + "vfwcvt.f.xu.v v10, v8 \n\t" // uint8 -> fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + "fmul.h ft1, ft1, %[HP0125] \n\t" + "fmul.h ft2, ft2, %[HP0125] \n\t" + "fmul.h ft3, ft3, %[HP0125] \n\t" + "fmul.h ft4, ft4, %[HP0125] \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> mf2 + "vfmul.vv v10, v10, v16 \n\t" // zp * ascale * bscale; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v10, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v10, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v10, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v10, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "addi s6, s6, 32 \n\t" // next zp[32] + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + "mv s5, s6 \n\t" // s6 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "add s6, s5, t1 \n\t" // 8 * 576B B(scale+qs), zp area starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), + [HP1] "f"(hp_scale_1), [HP0125] "f"(hp_scale_0125) + : "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s5", "s6", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", + "v8", "v10", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", + "v25", "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", + "memory"); + } + return; + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + const size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + if (nb_real != NB_COLS) { + break; + } + + uint8_t * b_tile_base = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_block = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + // Data layout summary for the no-zp path. + // + // A layout is identical to the with-zp branch. + // + // B: N32 x K256 q4 HP block without explicit zp storage + // - each K32 subblock is still 576B: + // 64B = fp16 scale[32] + // 512B = packed q4 payload + // - zp is implicit and treated as a constant value 8 in the kernel + // - one K256 superblock therefore contains only: + // 8 x (scale + qs) blocks = 4608B + // + // C: 4 rows x 32 fp32 outputs + // + // ASM pointer convention: + // - t6: current A K32 subblock base + // - t2: current A a_sum base for this ksi + // - s5: current B (scale + qs) K32 subblock base + // + // Loop progression: + // - per ksi: A += 136, a_sum += 2, B_data += 576 + // - per ki : skip the 72B A trailer and advance B to the next 4608B superblock + + const _Float16 hp_scale_16 = (_Float16) 16.0f; + const _Float16 hp_scale_1 = (_Float16) 1.0f; + + // VPR grouping used below matches the with-zp path: + // - v4-v7 : B q4 payload for N32 split as 4 x N8 groups + // - v8/v10 : implicit zp lane / widened fp16 + // - v12 : B fp16 scale[32] + // - v14-v15 : packed (Bscale * Ascale) for rows [0,1] / [2,3] + // - v16-v19 : temporary per-row scaled B scales + // - v28-v31 : final fp32 accumulators for rows 0..3 + + asm volatile( + "mv t5, %[BK] \n\t" + "mv t6, %[A] \n\t" + "mv s5, %[B] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "li t4, 8 \n\t" + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + + ".align 4 \n\t" + "_BLK_LPST%=: \n\t" + "flh fa1, 64(t2) \n\t" // a_scale_avg_row[0] + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v18, v30, v30 \n\t" + "vxor.vv v19, v31, v31 \n\t" + "vxor.vv v20, v30, v30 \n\t" + "vxor.vv v21, v31, v31 \n\t" + "_KsubBLK_LPST%=: \n\t" + // load first subblock scales for 4 rows + "flh fa0, 0(t6) \n\t" // ascale_fp16 + + // load B fp16 scales[32] + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (s5) \n\t" + + "fmul.h fa2, fa0, %[HP16] \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfmul.vf v16, v12, fa0 \n\t" // row0: Bscale * Ascale + "vfmul.vf v17, v12, fa2 \n\t" + + // load a_sum[row][ksi] from the trailer; t2 points to row0[ksi] + "flh ft1, 0(t2) \n\t" + "flh ft2, 16(t2) \n\t" + "flh ft3, 32(t2) \n\t" + "flh ft4, 48(t2) \n\t" + + // load A payload from current K32 subblock and B q4 payload from current 576B block + "addi t3, t6, 8 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (t3) \n\t" //A + "addi t3, s5, 64 \n\t" + "vl4r.v v4, (t3) \n\t" //B + + "vsetvli t0, x0, e8, m1 \n\t" + "vsrl.vi v1, v0, 4 \n\t" + "vnpack4.vv v12, v0, v1, 3 \n\t" + "vpack.vv v0, v17, v16, 3 \n\t" + "vupack.vv v2, v12, v12, 2 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" // mf2 -> m1 + "vfmul.vf v12, v16, ft1 \n\t" // zp(1:n)* abscale * asum_m0; fp16*fp16 + "vfmul.vf v13, v16, ft2 \n\t" // zp(1:n)* abscale * asum_m1; fp16*fp16 + "vfmul.vf v24, v16, ft3 \n\t" // zp(1:n)* abscale * asum_m2; fp16*fp16 + "vfmul.vf v25, v16, ft4 \n\t" // zp(1:n)* abscale * asum_m3; fp16*fp16 + + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwmacc.vf v28, fa1, v12 \n\t" + "vfwmacc.vf v29, fa1, v13 \n\t" + "vfwmacc.vf v30, fa1, v24 \n\t" + "vfwmacc.vf v31, fa1, v25 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v0, 0, i4 \n\t" //lo4;n0n7 + "vmadotsu.hp v19, v3, v5, v0, 1, i4 \n\t" //lo4;n8n15 + "vmadotsu.hp v20, v3, v6, v0, 2, i4 \n\t" //lo4;n16n23 + "vmadotsu.hp v21, v3, v7, v0, 3, i4 \n\t" //lo4;n24n31 + "vmadotu.hp v18, v2, v4, v0, 4, i4 \n\t" //hi4;n0n7 + "vmadotu.hp v19, v2, v5, v0, 5, i4 \n\t" //hi4;n8n15 + "vmadotu.hp v20, v2, v6, v0, 6, i4 \n\t" //hi4;n16n23 + "vmadotu.hp v21, v2, v7, v0, 7, i4 \n\t" //hi4;n24n31 + + "addi t4, t4, -1 \n\t" + + "addi t6, t6, 8+128 \n\t" // next A K32 subblock + "addi t2, t2, 2 \n\t" // next ksi entry in each a_sum row + "addi s5, s5, 64+512 \n\t" // next B (scale + qs) K32 block + "bgtz t4, _KsubBLK_LPST%= \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" //N32in1register + "vpack.vv v8, v18, v19, 1 \n\t" // 128(16*8)->256(16*16) + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v26, v8, v12, 2 \n\t" // 256(16*16)->512(16*32) + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwmacc.vf v28, fa1, v26 \n\t" // row0/1 accum += dot * packed scale + "vfwmacc.vf v30, fa1, v27 \n\t" + + "li t4, 8 \n\t" + "addi t5, t5, -1 \n\t" + "addi t6, t6, 72 \n\t" // skip A trailer after 8 subblocks and scale_avg tail + // s5 already points to next B superblock base + "addi t2, t6, 1088 \n\t" // 8 * 136B A K32 subblocks, a_sum trailer starts here + "bgtz t5, _BLK_LPST%= \n\t" + + "_BLK_LPND%=: \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_block), [B] "+r"(b_tile_base) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks), [HP16] "f"(hp_scale_16), [HP1] "f"(hp_scale_1) + : "t0", "t2", "t3", "t4", "t5", "t6", "s5", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v10", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v24", "v25", "v26", + "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "ft1", "ft2", "ft3", "ft4", "memory"); + } + return; + } +} + +void gemm_kernel_i8mxfp4_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4<NB_COLS>; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 1); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // MXFP4 no-zp: per column per k-block stride = scale_e8m0(1B) + qs(16B) + qh(4B) = 21B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // qh sign/high-bit mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // qs packed 4-bit magnitudes: n×k_blks×16 + n * k_blks * sizeof(uint8_t); // scale: n×k_blks×1 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (q8 block with per-block scale and stored sum field): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (u8×32) + // s5 = pB qh (sign/high-bit mask, 128B) + // s6 = pB qs (packed 4-bit magnitudes, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales u8 (loaded as bytes; later widened) + // v0 = qh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = qs unpack/pack temporaries (build signed vmadot lanes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32 \n\t" // s5 = pBh (pB + 32B scale) + "addi s6, %[pB], 32+128 \n\t" // s6 = pBs (pB + 32 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 672B = 32(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load qs (512B = 4 VRF) from s6, advance s6 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = qs N32K32 packed 4-bit magnitudes + "addi s6, s6, 128*4+128+32 \n\t" // s6 += 672 (512+128+32) + + // ---- load B scale (32B = 32×u8) from s4, advance s4 by 672 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_u8 × 32 + "addi s4, s4, 32+128*4+128 \n\t" // s4 += 672 (32+512+128) + + // ---- load qh (128B = 1 VRF) from s5, advance s5 by 672 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = qh N32K32 sign/high-bit packed + "addi s5, s5, 128+32+128*4 \n\t" // s5 += 672 (128+32+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + // ---- Decode packed MXFP4 payload into a vmadot-friendly signed-lane layout ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to 0 + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- int8 dot products over the decoded MXFP4 lane groups ---- + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v2, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v2, 1 \n\t" + "vadd.vi v28, v2, -1 \n\t" + "vsll.vi v28, v28, 23 \n\t" + "vsll.vv v28, v30, v2, v0.t \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast<float>(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8mxfp4_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t K_TILE = 32; + using blk_type = nrow_block_mxfp4<NB_COLS>; + + GGML_ASSERT(blk_len == K_TILE); + GGML_ASSERT(count_m == 4); + GGML_UNUSED(quant_b_zp); + + const size_t a_blk_stride = q8_blk_size(blk_len, true); + const size_t b_blk_stride = sizeof(blk_type); + const size_t b_tile_stride = k_blks * b_blk_stride; + + if (quant_b_zp == NULL) { + // MXFP4 block layout per K32/N32 tile: + // [scale_e8m0 x 32][qh sign/high-bit mask x 128B][qs packed 4-bit magnitudes x 512B] + // There is no explicit zp stream; qh is combined with qs to reconstruct signed MXFP4 values. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * b_tile_stride; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 32 \n\t" + "addi t2, %[B], 160 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "addi %[B], %[B], 672 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + + // Decode the packed MXFP4 payload once for the whole tile and expand it + // into a vmadot-friendly layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vrsub.vi v16, v16, 0, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "lui t1, 0x00200 \n\t" + "vmv.v.x v30, t1 \n\t" + // b_scale e8m0 -> fp32 + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v2, x0 \n\t" + "vsetvli t0, x0, e16, mf2 \n\t" + "vwadd.vx v26, v28, x0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vmsle.vi v0, v26, 1 \n\t" + "vadd.vi v24, v26, -1 \n\t" + "vsll.vi v18, v24, 23 \n\t" + "vsll.vv v18, v30, v26, v0.t \n\t" + + // Row 0: dot(A0, decoded MXFP4 lane groups), accumulate in int32 and + // then apply A/B scaling. + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i5_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + // ========================================================================= + // i8i5: 8-bit activation × 5-bit weight (4-bit low + 1-bit high mask) + // + // B layout per N32K32 k-block (no-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 191] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [192.. 703] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 704B per k-block stride + // + // B layout per N32K32 k-block (with-zp): + // [0 .. 63 ] : scale_fp16 × 32 (64B) + // [64 .. 95 ] : zp_uint8 × 32 (32B) + // [96 .. 223] : Bh i1-high-bit × 32N × 32K (128B = 1 VRF) + // [224.. 735] : Bs i4-low-nibble × 32N × 32K (512B = 4 VRF) + // Total: 736B per k-block stride + // + // Bh format per N8K32 sub-block (32B): + // K rows × N cols × 1bit packed as bytes (8 cols per byte, K groups of 4B) + // Byte k gives 8 mask bits for columns N7..N0 at k-th K-element. + // + // Computation: + // B5bit_signed = (Bs | (Bh << 4)) - zp + // dot(A, B5) = dot(A, Bs_u4) + 16*dot(A, Bh_u1) - zp*asum + // No-zp: implicit zp = 16 (unsigned [0..31] centered at 16) + // With-zp: explicit zp from data + // + // ========================================================================= + + if (quant_b_zp == NULL) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 no-zp: per column per k-block stride = fp16(2B) + i4(16B) + i1(4B) = 22B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * (blk_len / 8) + // Bh i1 mask: n×k_blks×4 + n * k_blks * blk_len / 2 + // Bs i4 data: n×k_blks×16 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32) + // s5 = pB Bh (i1 mask, 128B) + // s6 = pB Bs (i4 packed, 512B) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBh (pB + 64B scale) + "addi s6, %[pB], 32*2+128 \n\t" // s6 = pBs (pB + 64 + 128 = pB+192) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 704B = 64(scl) + 512(Bs) + 128(Bh) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+64 \n\t" // s6 += 704 (512+128+64) + + // ---- load B scale (64B = 32×fp16) from s4, advance s4 by 704 ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64+128*4+128 \n\t" // s4 += 704 (64+512+128) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 704 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+64+128*4 \n\t" // s5 += 704 (128+64+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + //// vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e32, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast<float>(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } else { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + // i8i5 with-zp: per column per k-block stride = fp16(2B)+zp(1B)+i4(16B)+i1(4B)=23B + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len / 2 + // Bs i4: n×k_blks×16 + n * k_blks * (blk_len / 8) + // Bh i1: n×k_blks×4 + n * k_blks * sizeof(uint8_t) + // zp: n×k_blks×1 + n * k_blks * sizeof(_Float16); // scale: n×k_blks×2 + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format (same as i8i4): + // || scl(fp32,4B) | asum(int16,2B) | data(int8,32B) || × k_blks + // + // Register map: + // t3 = k_blks loop counter t4 = nblks (tail) + // t2 = A asum (int16) << 4 f0 = A scale (fp32) + // s2 = pA (scale/asum) s3 = pA data + // s4 = pB scales (fp16×32); 每个 k-block 先 +64 指向 zp,再 +672 到下一个 block + // s5 = pB Bh (i1 mask, 128B) (offset +96) + // s6 = pB Bs (i4 packed, 512B) (offset +224) + // s7 = pC + // v3 = fp32 accumulator (N32) + // v2 = B scales fp16 (loaded as bytes; later widened) + // v0 = Bh mask bytes (also used as v0.t mask after load) + // v1 = A int8 (K32) / later reused to hold Bzp bytes + // v8..v15 / v16..v23 = Bs unpack/pack temporaries (build b5bit bytes) + // v24/v26/v28/v30 = int32 dot accumulators & packing temps + + __asm__ volatile( + "mv t3, %[BCK] \n\t" // t3 = k_blks + "mv t4, %[NBLKS] \n\t" // t4 = nblks (tail guard) + + // ---- pre-loop: init fp16 constants in e16 m1 context ---- + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.i v0, 1 \n\t" // v0 = int16(1) + "vfcvt.f.x.v v0, v0 \n\t" // v0 = 1.0_fp16 + "vxor.vv v3, v16, v16 \n\t" + + // ---- pointer setup ---- + "mv s2, %[pA] \n\t" // s2 = pA (scale, fp32) + "addi s3, %[pA], 4+2 \n\t" // s3 = pA data (skip scale+asum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*3 \n\t" // s5 = pBh (pB + 64B scale + 32B zp = pB+96) + "addi s6, %[pB], 32*3+128 \n\t" // s6 = pBs (pB + 96 + 128 = pB+224) + "mv s7, %[pC] \n\t" // s7 = pC + + // ===================================================================== + // K-block loop: each iteration processes one N32×K32 block + // Stride per k-block = 736B = 64(scale) + 32(zp) + 128(Bh) + 512(Bs) + // ===================================================================== + ".align 4 \n\t" + "BLK_LPST%=: \n\t" + + // ---- load Bs (512B = 4 VRF) from s6, advance s6 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v8, (s6) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s6, s6, 128*4+128+96 \n\t" // s6 += 736 (512+128+96) + + // ---- load B scale (64B = 32×fp16) from s4; then s4 points to zp[32] ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (s4) \n\t" // v2 = scale_fp16 × 32 + "addi s4, s4, 64 \n\t" // s4 += 64 (now points to zp) + + // ---- load Bh (128B = 1 VRF) from s5, advance s5 by 736 ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (s5) \n\t" // v0 = Bh N32K32 1-bit packed + "addi s5, s5, 128+96+128*4 \n\t" // s5 += 736 (128+96+512) + + // ---- load A data (32B = K32 int8) from s3 ---- + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v1, (s3) \n\t" // v1 = A M1K32 int8 + "addi s3, s3, 32+6 \n\t" // s3 += 38 (data + scl + asum) + + // ---- load A scale (fp32) and asum (int16) from s2 ---- + "flw f0, (s2) \n\t" // f0 = A scale (fp32) + "lh t2, 4(s2) \n\t" // t2 = A asum (int16) + "addi s2, s2, 6+32 \n\t" // s2 += 38 + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vsetvli t0, x0, e64, m1 \n\t" + "vslidedown.vi v16, v1, 2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v1, v8, i8 \n\t" // N0..7 + "vmadot v26, v1, v10, i8 \n\t" // N8..15 + "vmadot v28, v1, v12, i8 \n\t" // N16..23 + "vmadot v30, v1, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v16, v9, i8 \n\t" // N0..7 + "vmadot v26, v16, v11, i8 \n\t" // N8..15 + "vmadot v28, v16, v13, i8 \n\t" // N16..23 + "vmadot v30, v16, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v1, (s4) \n\t" // Bzp + "addi s4, s4, 32+128*4+128 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vwaddu.vx v28, v1, x0 \n\t" // uint8 -> uint16 + + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v2 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + // static_cast<float>(qsum) * a_scale * b_scale; + "vfmacc.vv v3, v30, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, BLK_LPST%= \n\t" + "BLK_LPND%=: \n\t" + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "vse32.v v3, (%[pC]) \n\t" + "FUNC_END%=: \n\t" + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6", "s7", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v16", "v17", "v18", "v19", + "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"); + } + } +} + +void gemm_kernel_i8i5_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + + GGML_UNUSED(count_m); + GGML_UNUSED(blk_len); + + // This kernel computes a 4x32 output tile. For each K32 block we decode the + // packed Q5 weights once and reuse the decoded vectors across the 4 A rows. + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0<NB_COLS>); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1<NB_COLS>); + + if (quant_b_zp) { + // Q5_1 block layout per K32/N32 tile: + // [scale_fp16 x 32][zp_u8 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> zp stream, t2 -> qh bitmask stream, s5 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 96 \n\t" + "addi s5, %[B], 224 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t2) \n\t" + "vl4r.v v8, (s5) \n\t" + "addi %[B], %[B], 736 \n\t" + + // Decode Q5 payload once for the whole tile: + // 1) split `qs` low/high nibbles, + // 2) repack into bytes, + // 3) use the `qh` mask to inject bit4 (+16) where needed, + // 4) expand into the vmadot-friendly layout reused by all 4 rows. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v3, (t1) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * zp, then scale by A/B scales. + // The widen/mul correction sequence intentionally matches the proven m1 Q5_1 path. + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vwaddu.vx v28, v3, x0 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vwmul.vx v12, v28, s1 \n\t" + "vwmul.vx v14, v28, s2 \n\t" + "vwmul.vx v20, v28, s3 \n\t" + "vwmul.vx v22, v28, s4 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v12 \n\t" + "vadd.vv v25, v25, v14 \n\t" + "vadd.vv v26, v26, v20 \n\t" + "vadd.vv v27, v27, v22 \n\t" + "vfcvt.f.x.v v12, v24 \n\t" + "vfcvt.f.x.v v14, v25 \n\t" + "vfcvt.f.x.v v20, v26 \n\t" + "vfcvt.f.x.v v22, v27 \n\t" + "vfmul.vv v12, v12, v18 \n\t" + "vfmul.vv v14, v14, v18 \n\t" + "vfmul.vv v20, v20, v18 \n\t" + "vfmul.vv v22, v22, v18 \n\t" + "vfmacc.vf v4, fa0, v12 \n\t" + "vfmacc.vf v5, fa1, v14 \n\t" + "vfmacc.vf v6, fa2, v20 \n\t" + "vfmacc.vf v7, fa3, v22 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "s5", "v0", "v1", + "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", + "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { + // Q5_0 block layout per K32/N32 tile: + // [scale_fp16 x 32][qh high-bit mask x 128B][qs low nibbles x 512B] + // There is no explicit zp stream; the implicit midpoint correction is +16. + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + uint8_t * a_data = (uint8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + size_t cnt = k_blks; + + asm volatile( + // v4-v7 are the fp32 accumulators for rows 0..3 of the current N32 tile. + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v4, v4, v4 \n\t" + "vxor.vv v5, v5, v5 \n\t" + "vxor.vv v6, v6, v6 \n\t" + "vxor.vv v7, v7, v7 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // Load the 4 A-row scales/sums for this K32 block and build row data pointers. + "flw fa0, 0(%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "lh s1, 16(%[A]) \n\t" + "lh s2, 18(%[A]) \n\t" + "lh s3, 20(%[A]) \n\t" + "lh s4, 22(%[A]) \n\t" + "addi t3, %[A], 24 \n\t" + "addi t4, t3, 32 \n\t" + "addi t5, t3, 64 \n\t" + "addi t6, t3, 96 \n\t" + "addi %[A], %[A], 152 \n\t" + + // B-side pointers: + // t1 -> qh bitmask stream, t2 -> qs low-nibble stream. + "addi t1, %[B], 64 \n\t" + "addi t2, %[B], 192 \n\t" + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v2, (%[B]) \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t1) \n\t" + "vl4r.v v8, (t2) \n\t" + "addi %[B], %[B], 704 \n\t" + + // Decode Q5 payload once for the whole tile and expand it into the vmadot layout. + "vand.vi v12, v8, 0xF \n\t" + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "li t2, 16 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t2, v0.t \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + // Convert per-column fp16 scales once; the same scale vector is shared by all 4 rows. + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v18, v2 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + + // Row 0: dot(A0, decoded_q5) + a_sum0 * 16 (implicit Q5_0 midpoint correction). + "vle8.v v1, (t3) \n\t" + "vsetvli t0, x0, e64, m1 \n\t" + "vupack.vv v16, v1, v2, 1 \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vxor.vv v26, v26, v26 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vmadot v24, v16, v8, i8 \n\t" + "vmadot v26, v16, v10, i8 \n\t" + "vmadot v28, v16, v12, i8 \n\t" + "vmadot v30, v16, v14, i8 \n\t" + "vmadot v24, v17, v9, i8 \n\t" + "vmadot v26, v17, v11, i8 \n\t" + "vmadot v28, v17, v13, i8 \n\t" + "vmadot v30, v17, v15, i8 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" + "slli s1, s1, 4 \n\t" + "vpack.vv v20, v28, v30, 2 \n\t" + "slli s2, s2, 4 \n\t" + "vpack.vv v24, v16, v20, 3 \n\t" + "slli s3, s3, 4 \n\t" + "vpack.vv v26, v17, v21, 3 \n\t" + "slli s4, s4, 4 \n\t" + "vadd.vx v24, v24, s1 \n\t" + "vadd.vx v25, v25, s2 \n\t" + "vadd.vx v26, v26, s3 \n\t" + "vadd.vx v27, v27, s4 \n\t" + "vfcvt.f.x.v v24, v24 \n\t" + "vfcvt.f.x.v v25, v25 \n\t" + "vfcvt.f.x.v v26, v26 \n\t" + "vfcvt.f.x.v v27, v27 \n\t" + "vfmul.vv v24, v24, v18 \n\t" + "vfmul.vv v25, v25, v18 \n\t" + "vfmul.vv v26, v26, v18 \n\t" + "vfmul.vv v27, v27, v18 \n\t" + "vfmacc.vf v4, fa0, v24 \n\t" + "vfmacc.vf v5, fa1, v25 \n\t" + "vfmacc.vf v6, fa2, v26 \n\t" + "vfmacc.vf v7, fa3, v27 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "bgtz %[BK], BLK_LOOP%= \n\t" + + // Tail-aware store for the final N tile (`nb_real` may be < 32). + "vsetvli t0, %[NBLKS], e32, m1 \n\t" + "add t1, %[LDC], %[DST] \n\t" + "vse32.v v4, (%[DST]) \n\t" + "vse32.v v5, (t1) \n\t" + "add t2, t1, %[LDC] \n\t" + "vse32.v v6, (t2) \n\t" + "add t3, t2, %[LDC] \n\t" + "vse32.v v7, (t3) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data), [BK] "+r"(cnt) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [NBLKS] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "s2", "s3", "s4", "v0", "v1", "v2", + "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", + "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", + "fa0", "fa1", "fa2", "fa3"); + } + } +} + +void gemm_kernel_i8i8_m1(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + for (size_t n = 0; n < count_n; n += 32) { + size_t nblks = (count_n - n) > 32 ? 32 : count_n - n; + uint8_t * QuantBDataPtr = (uint8_t *) quant_b_data + // + n * k_blks * blk_len + // b data + n * k_blks * sizeof(_Float16); // scale + float * CPtr = c_ptr + n; + size_t cnt = k_blks; + + // A format Version_1 (FP32 SCALE FOR Normal VMADOTins of IME2) + // A M1K32 int8 256bit + // Ascale fp32 * 1 32bit + // || scl*1(fp32) | Asum(int16) | blk0 || scl*1(fp32) | Asum(int16) | blk0 || ... + // || Element || Element || ... + // B format + // B N8K32 int4 2048bit + // 4VRF, N32K32, 8192bit + // Bscale fp16 * N32 512bit; + // || scl*32..(fp16) | blk0 blk1 ... blk31 || scl*32..(fp16) | blk0 blk1 ... blk31 || ... + // || Element || Element || ... + + //bias always be nullptr + __asm__ volatile( + + // t3 = k/32 + "mv t3, %[BCK] \n\t" + "mv t4, %[NBLKS] \n\t" + "mv s2, %[pA] \n\t" // s2 = pASCL + "addi s3, %[pA], 4+2 \n\t" // s3 = pAData, (pA+AScl+ASum) + "mv s4, %[pB] \n\t" // s4 = pBSCL + "addi s5, %[pB], 32*2 \n\t" // s5 = pBdata; + "mv s6, %[pC] \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" // clear acc + + // ordinary vmadot: vle*6 flw*1 vecIns*64 vmadot*8 + ".align 4 \n\t" + "_K_LPST%=: \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl4r.v v4, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4 \n\t" + "vl4r.v v8, (s5) \n\t" // B Data 4VRF * 8Row * 32 + "addi s5, s5, 128*4+64 \n\t" + + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v0, (s4) \n\t" // B Scale 4VRF*8Row*FP16 = 512bit + "addi s4, s4, 64+128*8 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v3, (s3) \n\t" // A Data M1*K32*int8 = 256bit + "addi s3, s3, 32+6 \n\t" + + "flw f0, (s2) \n\t" // A Scale fp32 + "addi s2, s2, 6+32 \n\t" // AScale + Asum(FP32+i16) + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v28, v8, v9, 1 \n\t" + "vupack.vv v30, v10, v11, 1 \n\t" + + "vslidedown.vi v4, v3, 4 \n\t" + + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + "vmadot v16, v3, v24, i8 \n\t" // M0 N0 - N7 INT32(256bit) + "vmadot v18, v3, v26, i8 \n\t" // M0 N8 - N15 + "vmadot v20, v3, v28, i8 \n\t" // M0 N16 - N23 + "vmadot v22, v3, v30, i8 \n\t" // M0 N24 - N31 + + "vmadot v16, v4, v25, i8 \n\t" + "vmadot v18, v4, v27, i8 \n\t" + "vmadot v20, v4, v29, i8 \n\t" + "vmadot v22, v4, v31, i8 \n\t" + + "vpack.vv v24, v16, v18, 2 \n\t" + "vpack.vv v26, v20, v22, 2 \n\t" + "vpack.vv v16, v24, v26, 3 \n\t" + + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v24, v0 \n\t" + // mac result i32 -> fp32 + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v16 \n\t" + // a_scale * b_scale; + "vfmul.vf v1, v24, f0 \n\t" + // static_cast<float>(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v1, v26 \n\t" + + "addi t3, t3, -1 \n\t" + "bgtz t3, _K_LPST%= \n\t" + "_K_LPND%=: \n\t" + + //----------------------------------------- + // STORE Equal 32N------------------------- + "_ST32%=: \n\t" + "vsetvli t0, t4, e32, m1 \n\t" + "vse32.v v2, (s6) \n\t" // M0 [N0 : N32]; FP32(1024bit) + + "_FUNC_END%=: \n\t" + + : + : [BCK] "r"(cnt), [NBLKS] "r"(nblks), [pA] "r"(quant_a_ptr), [pB] "r"(QuantBDataPtr), [pC] "r"(CPtr) + : "cc", "t0", "t3", "t4", "f0", "s2", "s3", "s4", "s5", "s6"); + } +} + +void gemm_kernel_i8i8_m4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + int64_t b_data_stride = k_blks * sizeof(ggml_fp16_t) + k_blks * blk_len; + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data = (int8_t *) quant_a_ptr; + float * dst_c = c_ptr + ni; + + asm volatile( + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vxor.vv v30, v30, v30 \n\t" + "vxor.vv v31, v31, v31 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A + "flw fa0, (%[A]) \n\t" + "flw fa1, 4(%[A]) \n\t" + "flw fa2, 8(%[A]) \n\t" + "flw fa3, 12(%[A]) \n\t" + "addi %[A], %[A], 16+8 \n\t" // Ascl+Asum; FP32*4+i16*4 + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vfwcvt.f.f.v v14, v12 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vl1r.v v0, (%[A]) \n\t" + "addi %[A], %[A], 128 \n\t" // 4*32@i8 + "vl4r.v v4, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + "vl4r.v v8, (%[B]) \n\t" // 32*32@i8 + "addi %[B], %[B], 512 \n\t" + + "vsetvli t0, zero, e32, m1 \n\t" + "vupack.vv v2, v0, v0, 1 \n\t" + + "vupack.vv v24, v4, v5, 1 \n\t" + "vupack.vv v26, v6, v7, 1 \n\t" + "vupack.vv v4, v8, v9, 1 \n\t" + "vupack.vv v6, v10, v11, 1 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v16, v16 \n\t" + "vxor.vv v20, v16, v16 \n\t" + "vxor.vv v22, v16, v16 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e32, m1 \n\t" + "vmadot v16, v2, v24, i8 \n\t" + "vmadot v18, v2, v26, i8 \n\t" + "vmadot v20, v2, v4, i8 \n\t" + "vmadot v22, v2, v6, i8 \n\t" + "vmadot v16, v3, v25, i8 \n\t" + "vmadot v18, v3, v27, i8 \n\t" + "vmadot v20, v3, v5, i8 \n\t" + "vmadot v22, v3, v7, i8 \n\t" + + "vpack.vv v0, v16, v18, 2 \n\t" + "vpack.vv v2, v20, v22, 2 \n\t" + "vpack.vv v16, v0, v2, 3 \n\t" + "vpack.vv v18, v1, v3, 3 \n\t" + + "vfcvt.f.x.v v16, v16 \n\t" + "vfcvt.f.x.v v17, v17 \n\t" + "vfcvt.f.x.v v18, v18 \n\t" + "vfcvt.f.x.v v19, v19 \n\t" + + // mul scale + "vfmul.vv v16, v16, v14 \n\t" + "vfmul.vv v17, v17, v14 \n\t" + "vfmul.vv v18, v18, v14 \n\t" + "vfmul.vv v19, v19, v14 \n\t" + + "addi %[BK], %[BK], -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + "vfmacc.vf v30, fa2, v18 \n\t" + "vfmacc.vf v31, fa3, v19 \n\t" + + "bgtz %[BK], BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "add t2, %[LDC], %[DST] \n\t" + "vse32.v v28, (%[DST]) \n\t" + "add t3, %[LDC], t2 \n\t" + "vse32.v v29, (t2) \n\t" + "add t2, %[LDC], t3 \n\t" + "vse32.v v30, (t3) \n\t" + "vse32.v v31, (t2) \n\t" + : [A] "+r"(a_data), [B] "+r"(b_data) + : [DST] "r"(dst_c), [LDC] "r"(ldc * 4), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", + "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", + "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +} + +void moe_m2_gemm_kernel_i8i4_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, + ldc); +#else + int64_t b_data_stride = + k_blks * (sizeof(ggml_fp16_t) + 16 * sizeof(int8_t) + (quant_b_zp != NULL ? sizeof(int8_t) : 0)); + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vmv.v.x v10, t1 \n\t" + "vmv.v.x v11, t2 \n\t" + + "vpack.vv v18, v10, v11, 1 \n\t" + "vsll.vi v18, v18, 3 \n\t" // mul 8 + "vfcvt.f.x.v v18, v18 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vor.vv v19, v18, v18 \n\t" + "vor.vv v20, v18, v18 \n\t" + "vor.vv v21, v18, v18 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + + "vfwmul.vv v16, v20, v14 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } + } else { +# if 0 + moe_gemm_kernel_i8i4_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +# else + for (size_t ni = 0; ni < count_n; ni += 32) { + uint8_t * b_data = (uint8_t *) quant_b_data + ni * b_data_stride; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "vxor.vv v29, v29, v29 \n\t" + "vmv.v.i v0, 1 \n\t" // init the scale + "vsll.vi v1, v0, 4 \n\t" + "vfcvt.f.x.v v0, v0 \n\t" + "vfcvt.f.x.v v1, v1 \n\t" + "mv t3, %[BK] \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // load scale A0 + "flw fa0, (%[A0]) \n\t" // A0 scale + "lh t1, 4(%[A0]) \n\t" // A0 asum + "addi %[A0], %[A0], 6 \n\t" + + // load scale B + "vsetvli t0, x0, e16, mf2 \n\t" + "vle16.v v12, (%[B]) \n\t" + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e16, m1 \n\t" + "vpack.vv v14, v12, v12, 3 \n\t" + + // load scale A1 + "flw fa1, (%[A1]) \n\t" // A1 scale + "lh t2, 4(%[A1]) \n\t" // A1 asum + "addi %[A1], %[A1], 6 \n\t" + + // load zp + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v8, (%[B]) \n\t" + "addi %[B], %[B], 32 \n\t" + "vwaddu.vx v10, v8, x0 \n\t" + + "vsetvli t0, x0, e8, mf4 \n\t" // A0 data + "vle8.v v16, (%[A0]) \n\t" + "addi %[A0], %[A0], 32 \n\t" // 1*32@i8 + "vle8.v v20, (%[A1]) \n\t" + "addi %[A1], %[A1], 32 \n\t" // 1*32@i8 + + "vl4r.v v4, (%[B]) \n\t" // 32*32@i4 + "addi %[B], %[B], 512 \n\t" + + "vsrl.vi v17, v16, 4 \n\t" + "vsrl.vi v21, v20, 4 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vnpack4.vv v2, v16, v20, 2 \n\t" // low u4 + "vnpack4.vv v3, v17, v21, 2 \n\t" // high s4 + + // init the accumu to asum * zp + "vsetvli t0, x0, e16, m1 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v19, v19, v19 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v21, v21, v21 \n\t" + + // i4 * i4 vmadot + "vsetvli t0, x0, e16, m1 \n\t" + "vmadotsu.hp v18, v3, v4, v1, 0, i4 \n\t" // high 4 + "vmadotsu.hp v19, v3, v5, v1, 0, i4 \n\t" + "vmadotsu.hp v20, v3, v6, v1, 0, i4 \n\t" + "vmadotsu.hp v21, v3, v7, v1, 0, i4 \n\t" + "vmadotu.hp v18, v2, v4, v0, 0, i4 \n\t" // low 4 + "vmadotu.hp v19, v2, v5, v0, 0, i4 \n\t" + "vmadotu.hp v20, v2, v6, v0, 0, i4 \n\t" + "vmadotu.hp v21, v2, v7, v0, 0, i4 \n\t" + + "vpack.vv v8, v18, v19, 1 \n\t" + "vpack.vv v12, v20, v21, 1 \n\t" + "vpack.vv v20, v8, v12, 2 \n\t" + // asum*zp + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v2, v10, t1 \n\t" + "vwmul.vx v4, v10, t2 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + + "vfcvt.f.x.v v2, v2 \n\t" + "vfcvt.f.x.v v4, v4 \n\t" + + "vsetvli t0, x0, e16, m1 \n\t" + "vfwcvt.f.f.v v16, v20 \n\t" + + "vfwcvt.f.f.v v18, v14 \n\t" + + // +asum*zp + "vsetvli t0, x0, e32, m1 \n\t" + "vfadd.vv v16, v16, v2 \n\t" + "vfadd.vv v17, v17, v4 \n\t" + "vfmul.vv v16, v16, v18 \n\t" + "vfmul.vv v17, v17, v18 \n\t" + + "addi t3, t3, -1 \n\t" + "vfmacc.vf v28, fa0, v16 \n\t" + "vfmacc.vf v29, fa1, v17 \n\t" + + "bgtz t3, BLK_LOOP%= \n\t" + + // save + "vsetvli t0, x0, e32, m1 \n\t" + "vse32.v v28, (%[DST0]) \n\t" + "vse32.v v29, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks) + : "t0", "t1", "t2", "t3", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", + "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", + "v26", "v27", "v28", "v29", "v30", "v31", "fa0", "fa1", "fa2", "fa3"); + } +# endif + } +#endif +} + +void moe_m2_gemm_kernel_i8i5_impl(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + constexpr size_t NB_COLS = 32; + constexpr size_t B_Q50_BLK_STRIDE = sizeof(nrow_block_q5_0<NB_COLS>); + constexpr size_t B_Q51_BLK_STRIDE = sizeof(nrow_block_q5_1<NB_COLS>); + + GGML_UNUSED(blk_len); + GGML_UNUSED(count_m); + GGML_UNUSED(ldc); + + if (quant_b_zp == NULL) { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q50_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/Bh/Bs and advance to the next q5_0 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 64 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (%[B]) \n\t" // v0 = Bh N32K32 1-bit packed + "addi %[B], %[B], 128 \n\t" + "vl4r.v v8, (%[B]) \n\t" // v8..v11 = Bs N32K32 i4 + "addi %[B], %[B], 512 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + "slli t2, t2, 4 \n\t" // a_sum * 16; + "slli t3, t3, 4 \n\t" + // [4*32]*2 + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vadd.vx v24, v24, t2 \n\t" + "vadd.vx v25, v25, t3 \n\t" + // b_scale fp16 -> fp32 + "vsetvli t0, x0, e16, mf2 \n\t" + "vfwcvt.f.f.v v28, v1 \n\t" + + // a_scale * b_scale; + "vsetvli t0, x0, e32, m1 \n\t" + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast<float>(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", + "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", + "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } else { + for (size_t ni = 0; ni < count_n; ni += NB_COLS) { + size_t nb_real = std::min<size_t>(NB_COLS, count_n - ni); + uint8_t * b_data = (uint8_t *) quant_b_data + (ni / NB_COLS) * k_blks * B_Q51_BLK_STRIDE; + int8_t * a_data0 = (int8_t *) quant_a_ptr[0]; + int8_t * a_data1 = (int8_t *) quant_a_ptr[1]; + float * dst_c0 = (float *) c_ptr[0] + ni; + float * dst_c1 = (float *) c_ptr[1] + ni; + + asm volatile( + "mv t4, %[BK] \n\t" + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v2, v0, v0 \n\t" + "vxor.vv v3, v0, v0 \n\t" + "addi t5, %[B], 64 \n\t" // t5 = zp (32B) + "addi t6, %[B], 96 \n\t" // t6 = qh (128B) + "addi s1, %[B], 224 \n\t" // s1 = qs (512B) + + ".align 4 \n\t" + "BLK_LOOP%=: \n\t" + // ---- load B scale/zp/Bh/Bs and advance to the next q5_1 k-block ---- + "vsetvli t0, x0, e8, mf2 \n\t" + "vle8.v v1, (%[B]) \n\t" // v1 = scale_fp16 × 32 + "addi %[B], %[B], 736 \n\t" + "vsetvli t0, x0, e8, m1 \n\t" + "vle8.v v0, (t6) \n\t" // v0 = Bh N32K32 1-bit packed + "addi t6, t6, 736 \n\t" + "vl4r.v v8, (s1) \n\t" // v8..v11 = Bs N32K32 i4 + "addi s1, s1, 736 \n\t" + + // ---- load A0/A1 header then payload, each block stride = 38B ---- + "flw f0, (%[A0]) \n\t" // f0 = A0 scale (fp32) + "lh t2, 4(%[A0]) \n\t" // t2 = A0 asum (int16) + "addi %[A0], %[A0], 6 \n\t" + "flw f1, (%[A1]) \n\t" // f1 = A1 scale (fp32) + "lh t3, 4(%[A1]) \n\t" // t3 = A1 asum (int16) + "addi %[A1], %[A1], 6 \n\t" + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (%[A0]) \n\t" // v4 = A0 M1K32 int8 + "addi %[A0], %[A0], 32 \n\t" + "vle8.v v5, (%[A1]) \n\t" // v5 = A1 M1K32 int8 + "addi %[A1], %[A1], 32 \n\t" + + //// ---- A nibble unpacking ---- + "vsetvli t0, x0, e8, m1 \n\t" + "vand.vi v12, v8, 0xF \n\t" //8bit(lo4) //[8*32] + "vand.vi v13, v9, 0xF \n\t" + "vand.vi v14, v10, 0xF \n\t" + "vand.vi v15, v11, 0xF \n\t" + "vsrl.vi v8, v8, 4 \n\t" //8bit(hi4) + "vsrl.vi v9, v9, 4 \n\t" + "vsrl.vi v10, v10, 4 \n\t" + "vsrl.vi v11, v11, 4 \n\t" + + // q5_1 uses explicit zp, so keep a_sum unshifted here. + // [4*32]*2 + "vpack.vv v16, v12, v8, 0 \n\t" + "vpack.vv v18, v13, v9, 0 \n\t" + "vpack.vv v20, v14, v10, 0 \n\t" + "vpack.vv v22, v15, v11, 0 \n\t" + + "li t1, 16 \n\t" + "vsetvli t0, x0, e8, m8 \n\t" + "vadd.vx v16, v16, t1, v0.t \n\t" + + // [4*32]*2 -> [8*16] + "vsetvli t0, x0, e8, m1 \n\t" + "vupack.vv v8, v16, v17, 1 \n\t" + "vupack.vv v10, v18, v19, 1 \n\t" + "vupack.vv v12, v20, v21, 1 \n\t" + "vupack.vv v14, v22, v23, 1 \n\t" + + "vpack.vv v6, v4, v5, 2 \n\t" + + // init the accumu to asum * zp + "vsetvli t0, x0, e32, m1 \n\t" + "vxor.vv v24, v16, v16 \n\t" + "vxor.vv v26, v16, v16 \n\t" + "vupack.vv v4, v6, v7, 1 \n\t" + "vxor.vv v28, v16, v16 \n\t" + "vxor.vv v30, v16, v16 \n\t" + + // ---- i8 main dot products ---- + // vmadot: A × unsigned Bh × 16 → fp16 accumulate + "vmadot v24, v4, v8, i8 \n\t" // N0..7 + "vmadot v26, v4, v10, i8 \n\t" // N8..15 + "vmadot v28, v4, v12, i8 \n\t" // N16..23 + "vmadot v30, v4, v14, i8 \n\t" // N24..31 + // vmadot: A × unsigned Bh × 1 → fp16 accumulate + "vmadot v24, v5, v9, i8 \n\t" // N0..7 + "vmadot v26, v5, v11, i8 \n\t" // N8..15 + "vmadot v28, v5, v13, i8 \n\t" // N16..23 + "vmadot v30, v5, v15, i8 \n\t" // N24..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vle8.v v4, (t5) \n\t" // v4 = Bzp N32 uint8 + "addi t5, t5, 736 \n\t" + + "vsetvli t0, x0, e8, m1 \n\t" + "vpack.vv v16, v24, v26, 2 \n\t" // v16 = N0..15 + "vpack.vv v18, v28, v30, 2 \n\t" // v18 = N16..31 + "vpack.vv v24, v16, v18, 3 \n\t" // v24 = N0..31 + + "vsetvli t0, x0, e8, mf4 \n\t" + "vwaddu.vx v28, v4, x0 \n\t" + + "vsetvli t0, x0, e16, mf2 \n\t" + "vwmul.vx v30, v28, t2 \n\t" + "vwmul.vx v31, v28, t3 \n\t" + + // b_scale fp16 -> fp32 + "vfwcvt.f.f.v v28, v1 \n\t" + + "vsetvli t0, x0, e32, m1 \n\t" + "vadd.vv v24, v24, v30 \n\t" + "vadd.vv v25, v25, v31 \n\t" + + // a_scale * b_scale; + "vfcvt.f.x.v v26, v24 \n\t" + "vfcvt.f.x.v v27, v25 \n\t" + "vfmul.vf v30, v28, f0 \n\t" + "vfmul.vf v31, v28, f1 \n\t" + // static_cast<float>(qsum) * a_scale * b_scale; + "vfmacc.vv v2, v30, v26 \n\t" + "vfmacc.vv v3, v31, v27 \n\t" + + "addi t4, t4, -1 \n\t" + "bgtz t4, BLK_LOOP%= \n\t" + + "vsetvli t0, %[NR], e32, m1 \n\t" + "vse32.v v2, (%[DST0]) \n\t" + "vse32.v v3, (%[DST1]) \n\t" + : [A0] "+r"(a_data0), [A1] "+r"(a_data1), [B] "+r"(b_data) + : [DST0] "r"(dst_c0), [DST1] "r"(dst_c1), [BK] "r"(k_blks), [NR] "r"(nb_real) + : "cc", "memory", "t0", "t1", "t2", "t3", "t4", "t5", "t6", "s1", "v0", "v1", "v2", "v3", "v4", "v5", + "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", + "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "f0", "f1"); + } + } +} + +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i2k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i2k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i2k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, + ldc); +#else + gemm_kernel_i8i2k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i3k_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m4(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i3k_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#else + gemm_kernel_i8i3k_m1(blk_len, quant_a_ptr, quant_b_data, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i4_hp_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i4_hp_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + moe_m2_gemm_kernel_i8i4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i8_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i8_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i8_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 1 + gemm_kernel_i8mxfp4_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8mxfp4_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + //moe_m2_gemm_kernel_i8mxfp4_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); + return 2; +} + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { + if (count_m >= 4) { +#if 0 + gemm_kernel_i8i5_mrow_ref<4, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m4(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 4; + } else { +#if 0 + gemm_kernel_i8i5_mrow_ref<1, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + gemm_kernel_i8i5_m1(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 1; + } +} + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc) { +#if 0 + moe_gemm_kernel_i8i5_mrow_ref<2, 32>(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, + k_blks, ldc); +#else + moe_m2_gemm_kernel_i8i5_impl(blk_len, quant_a_ptr, quant_b_data, quant_b_zp, c_ptr, count_m, count_n, k_blks, ldc); +#endif + return 2; +} + +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.cpp b/ggml/src/ggml-cpu/spacemit/ime_env.cpp new file mode 100644 index 00000000000..a13ba391da2 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.cpp @@ -0,0 +1,320 @@ +#include "ime_env.h" + +#include "ggml-impl.h" +#include "spine_mem_pool.h" + +#include <fcntl.h> +#include <unistd.h> + +#include <algorithm> +#include <array> +#include <cctype> +#include <fstream> +#include <string> +#include <thread> +#include <unordered_map> + +namespace ggml::cpu::riscv64_spacemit { +bool spine_core_info::get_spine_core_info(std::vector<spine_core_info> & result) { + static std::unordered_map<uint64_t, spine_core_arch_id> spine_march_mapping_ = { + {0x8000000058000001, spine_core_arch_id::core_arch_x60 }, + { 0x8000000041000001, spine_core_arch_id::core_arch_a60 }, + { 0x8000000058000002, spine_core_arch_id::core_arch_x100}, + { 0x8000000041000002, spine_core_arch_id::core_arch_a100}, + }; + + result.clear(); + std::ifstream file("/proc/cpuinfo"); + std::string line; + + std::vector<std::array<uint64_t, 2>> cpu_info_list; + + uint64_t current_processor = spine_invalid_core_id; + uint64_t current_marchid = 0; + bool has_processor = false; + bool has_marchid = false; + + if (!file.is_open()) { + return false; + } + + while (std::getline(file, line)) { + if (line.substr(0, 9) == "processor") { + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + current_processor = std::stoi(line.substr(colon_pos + 1)); + has_processor = true; + } + + has_marchid = false; + } else if (line.substr(0, 7) == "marchid") { + size_t colon_pos = line.find(':'); + if (colon_pos != std::string::npos) { + std::string marchid_str = line.substr(colon_pos + 1); + marchid_str.erase(std::remove_if(marchid_str.begin(), marchid_str.end(), isspace), marchid_str.end()); + current_marchid = std::stoull(marchid_str, nullptr, 16); + has_marchid = true; + } + } + } + + if (has_processor && has_marchid) { + cpu_info_list.push_back({ current_processor, current_marchid }); + } + + if (has_processor && has_marchid) { + for (auto & cpu_info : cpu_info_list) { + if (cpu_info[0] != spine_invalid_core_id && + spine_march_mapping_.find(cpu_info[1]) != spine_march_mapping_.end()) { + auto core_info = spine_core_info(); + core_info.core_id = cpu_info[0]; + core_info.arch_id = spine_core_arch_id(spine_march_mapping_[cpu_info[1]]); + + result.push_back(core_info); + } + } + } + + return has_processor && has_marchid; +} + +namespace { +uint16_t hex_string_to_u16(const std::string & hex_str) { + try { + size_t pos = 0; + if (hex_str.substr(0, 2) == "0x" || hex_str.substr(0, 2) == "0X") { + pos = 2; + } + unsigned long result = std::stoul(hex_str.substr(pos), nullptr, 16); + if (result > std::numeric_limits<uint16_t>::max()) { + throw std::out_of_range("Converted value is out of range for uint16_t"); + } + return static_cast<uint16_t>(result); + } catch (const std::invalid_argument & e) { + throw std::invalid_argument("Invalid hexadecimal string"); + } catch (const std::out_of_range & e) { + throw; + } +} + +const char * spine_mem_pool_backend_to_string(spine_mem_pool_backend backend) { + switch (backend) { + case spine_mem_pool_backend::none: + return "NONE"; + case spine_mem_pool_backend::posix_memalign: + return "POSIX"; + case spine_mem_pool_backend::transparent_hugepage: + return "HPAGE"; + case spine_mem_pool_backend::hugetlb_1g: + return "HPAGE1GB"; + } + + return "unknown"; +} + +spine_mem_pool_backend parse_mem_backend(const char * mem_backend_str) { + if (mem_backend_str == nullptr || mem_backend_str[0] == '\0') { + return spine_mem_pool_backend::transparent_hugepage; + } + + std::string value(mem_backend_str); + std::transform(value.begin(), value.end(), value.begin(), + [](unsigned char ch) { return static_cast<char>(std::tolower(ch)); }); + + if (value == "none") { + return spine_mem_pool_backend::none; + } + + if (value == "posix") { + return spine_mem_pool_backend::posix_memalign; + } + + if (value == "hpage") { + return spine_mem_pool_backend::transparent_hugepage; + } + + if (value == "hpage1gb") { + return spine_mem_pool_backend::hugetlb_1g; + } + + throw std::runtime_error("invalid SPACEMIT_MEM_BACKEND: " + value + ", expected NONE, POSIX, HPAGE or HPAGE1GB"); +} +} // namespace + +spine_env_info::spine_env_info() { + num_cores = static_cast<int>(std::thread::hardware_concurrency()); + spine_core_info::get_spine_core_info(core_info_list); + + // special for x60 K1 + if (core_info_list.size() == 8 && core_info_list[0].arch_id == spine_core_arch_id::core_arch_x60) { + for (int i = 0; i < 4; i++) { + core_info_list[i].arch_id = spine_core_arch_id::core_arch_a60; + } + } + + // special for qemu + if (core_info_list.size() == 0) { + char * spine_core_arch_str = getenv("SPACEMIT_CORE_ARCH"); + if (spine_core_arch_str != nullptr) { + auto arch_id = hex_string_to_u16(spine_core_arch_str); + for (int i = 0; i < num_cores; i++) { + auto core_info = spine_core_info(); + core_info.core_id = i; + core_info.arch_id = spine_core_arch_id{ arch_id }; + core_info_list.push_back(core_info); + } + } + } + + if (core_info_list.size() == 0) { + throw std::runtime_error( + "Failed to get SPACEMIT_CORE_ARCH from environment or failed to parse it from /proc/cpuinfo"); + } + + char * spine_perfer_core_arch_str = getenv("SPACEMIT_PERFER_CORE_ARCH"); + if (spine_perfer_core_arch_str != nullptr && spine_perfer_core_arch_str != "") { + perfer_core_arch_id = spine_core_arch_id{ hex_string_to_u16(spine_perfer_core_arch_str) }; + } + + char * spine_perfer_core_id_str = getenv("SPACEMIT_PERFER_CORE_ID"); + std::vector<int> perfer_core_id_vec; + if (spine_perfer_core_id_str != nullptr && spine_perfer_core_id_str != "") { + std::string perfer_core_id_str(spine_perfer_core_id_str); + size_t start = 0; + size_t end = 0; + while ((end = perfer_core_id_str.find(',', start)) != std::string::npos) { + std::string core_id_substr = perfer_core_id_str.substr(start, end - start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + start = end + 1; + } + std::string core_id_substr = perfer_core_id_str.substr(start); + perfer_core_id_vec.push_back(std::stoi(core_id_substr)); + } + + perfer_core_ids.reserve(num_cores); + if (perfer_core_arch_id == spine_core_arch_id::core_arch_none) { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + num_perfer_cores++; + perfer_core_arch_id = core_arch_id; + cpu_mask |= (1ULL << core_info.core_id); + perfer_core_ids.push_back(core_info.core_id); + } + } + } else { + for (auto & core_info : core_info_list) { + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + num_perfer_cores++; + cpu_mask |= (1ULL << core_info.core_id); + + auto core_arch_head = (uint16_t) (core_arch_id) >> 12; + if (core_arch_head == 0xA) { + perfer_core_ids.push_back(core_info.core_id); + } + } + } + if (num_perfer_cores == 0) { + GGML_ABORT("can not find core with arch id %x for SPACEMIT_PERFER_CORE_ARCH in core info list\n", + (uint16_t) perfer_core_arch_id); + } + } + + if (perfer_core_id_vec.size() > 0) { + perfer_core_ids.clear(); + cpu_mask = 0; + num_perfer_cores = 0; + for (int core_id : perfer_core_id_vec) { + if (core_id < 0 || core_id >= num_cores) { + GGML_ABORT("invalid core id in SPACEMIT_PERFER_CORE_ID: %d, should be between 0 and %d\n", core_id, + num_cores - 1); + } + auto core_info = core_info_list[core_id]; + auto core_arch_id = core_info.arch_id; + if (core_arch_id == perfer_core_arch_id) { + cpu_mask |= (1ULL << core_id); + perfer_core_ids.push_back(core_id); + } else { + GGML_ABORT( + "core id %d in SPACEMIT_PERFER_CORE_ID has arch id %x which does not match " + "SPACEMIT_PERFER_CORE_ARCH %x\n", + core_id, (uint16_t) core_arch_id, (uint16_t) perfer_core_arch_id); + } + } + std::string perfer_core_id_vec_str; + for (int core_id : perfer_core_id_vec) { + perfer_core_id_vec_str += std::to_string(core_id) + ","; + } + perfer_core_id_vec_str.pop_back(); + GGML_LOG_DEBUG("SPACEMIT_PERFER_CORE_ID is set, perferred core ids: %s\n", perfer_core_id_vec_str.c_str()); + num_perfer_cores = static_cast<int>(perfer_core_id_vec.size()); + } + + use_ime1 = perfer_core_arch_id == spine_core_arch_id::core_arch_a60 || + perfer_core_arch_id == spine_core_arch_id::core_arch_x100; + + use_ime2 = perfer_core_arch_id == spine_core_arch_id::core_arch_a100; + + mem_backend = parse_mem_backend(getenv("SPACEMIT_MEM_BACKEND")); + char * spine_disable_tcm_str = getenv("SPACEMIT_DISABLE_TCM"); + auto user_disable_tcm = spine_disable_tcm_str != nullptr && strcmp(spine_disable_tcm_str, "0") != 0; + + if (!user_disable_tcm) { + spine_mem_pool_tcm_info tcm_info; + if (spine_mem_pool_tcm_init(&tcm_info)) { + use_tcm = tcm_info.available; + tcm_blk_size = tcm_info.blk_size; + GGML_LOG_DEBUG("CPU_RISCV64_SPACEMIT: tcm is available, blk_size: %zu, blk_num: %zu, is_fake_tcm: %d\n", + tcm_info.blk_size, tcm_info.blk_num, tcm_info.is_fake_tcm); + + for (auto & core_info : core_info_list) { + auto core_arch_head = (uint16_t) (core_info.arch_id) >> 12; + if (core_arch_head != 0xA) { + aicpu_id_offset++; + } else { + break; + } + } + } + } + + GGML_LOG_DEBUG( + "CPU_RISCV64_SPACEMIT: num_cores: %d, num_perfer_cores: %d, perfer_core_arch_id: %x, exclude_main_thread: %d, " + "use_ime1: %d, use_ime2: %d, mem_backend: %s, cpu_mask: %lx, aicpu_id_offset: %d\n", + num_cores, num_perfer_cores, (uint16_t) perfer_core_arch_id, exclude_main_thread, use_ime1, use_ime2, + spine_mem_pool_backend_to_string(mem_backend), cpu_mask, aicpu_id_offset); + + const size_t init_barrier_size = sizeof(spine_barrier_t) * spine_init_barrier_count; + init_barrier = + static_cast<spine_barrier_t *>(spine_mem_pool_shared_mem_alloc(init_barrier_size, alignof(spine_barrier_t))); + if (init_barrier != nullptr) { + init_barrier_is_shared_mem = true; + } else { + GGML_LOG_WARN("CPU_RISCV64_SPACEMIT: failed to allocate init_barrier from shared mem, falling back to heap\n", + __func__); + init_barrier = new spine_barrier_t[spine_init_barrier_count]; + } + + spine_barrier_init(init_barrier, spine_init_barrier_count, 2); +} + +spine_env_info::~spine_env_info() { + if (init_barrier_is_shared_mem) { + spine_mem_pool_shared_mem_free(init_barrier); + } else { + delete[] init_barrier; + } + + init_barrier = nullptr; + init_barrier_is_shared_mem = false; +} + +spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_env.h b/ggml/src/ggml-cpu/spacemit/ime_env.h new file mode 100644 index 00000000000..a6ca06d26a4 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_env.h @@ -0,0 +1,55 @@ +#pragma once + +#include "spine_barrier.h" +#include "spine_mem_pool.h" + +#include <cstddef> +#include <cstdint> +#include <vector> + +namespace ggml::cpu::riscv64_spacemit { + +constexpr uint64_t spine_invalid_core_id = 0xFFFFFFFF; +constexpr size_t spine_init_barrier_count = 16; + +enum class spine_core_arch_id : uint16_t { + core_arch_none = 0, + core_arch_x60 = 0x503C, + core_arch_x100 = 0x5064, + core_arch_x200 = 0x50C8, + core_arch_a60 = 0xA03C, + core_arch_a100 = 0xA064, + core_arch_a200 = 0xA0C8, +}; + +struct spine_core_info { + uint64_t core_id{ spine_invalid_core_id }; + spine_core_arch_id arch_id{ spine_core_arch_id::core_arch_none }; + + static bool get_spine_core_info(std::vector<spine_core_info> & result); +}; + +struct spine_env_info { + std::vector<spine_core_info> core_info_list; + std::vector<int> perfer_core_ids; + int aicpu_id_offset{ 0 }; + int num_cores{ 0 }; + int num_perfer_cores{ 0 }; + spine_core_arch_id perfer_core_arch_id{ spine_core_arch_id::core_arch_none }; + bool exclude_main_thread{ false }; + bool use_ime2{ false }; + bool use_ime1{ false }; + bool use_tcm{ false }; + spine_mem_pool_backend mem_backend{ spine_mem_pool_backend::transparent_hugepage }; + uint64_t tcm_blk_size{ 0 }; + uint64_t cpu_mask{ 0 }; + spine_barrier_t * init_barrier{ nullptr }; + bool init_barrier_is_shared_mem{ false }; + + spine_env_info(); + ~spine_env_info(); +}; + +extern spine_env_info global_spine_env_info; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h index 75706341505..0a1fafffb25 100644 --- a/ggml/src/ggml-cpu/spacemit/ime_kernels.h +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -1,26 +1,189 @@ #pragma once +#include <cassert> #include <cstddef> +#include <functional> + +namespace spacemit_kernels { + +#define BLOCK_QNK_LEN 256 + +template <int N> struct nrow_block_q2_k { + // [4bit scale + 4bit zp] * N * 16 + uint8_t scales[N * BLOCK_QNK_LEN / 16]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; + uint16_t zeros16[N]; +}; + +template <int N> struct nrow_block_q3_k { + // [8bit scale] * N * 16 + int8_t scales[N * 16]; + // [b0, b1, b2, b3, b4, b5, b6, b7] ... [b248, b249, b250, b251, b252, b253, b254, b255] + uint8_t hmask[N * BLOCK_QNK_LEN / 8]; + // [b0, b16, b32, b48] [b1, b17, b33, b49] ... [b15, b31, b47, b63] + // [b64, b80, b96, b112] ...[b79, b95, b111, b127] + // [b128, b144, b160, b176] ...[b143, b159, b175, b191] + // [b192, b208, b224, b240] ...[b207, b223, b239, b255] + uint8_t qs[N * BLOCK_QNK_LEN / 4]; + uint16_t scales16[N]; +}; + +template <int N> struct nrow_block_mxfp4 { + uint8_t e[N]; + uint8_t qh[4 * N]; + uint8_t qs[16 * N]; +}; + +template <int N> struct __attribute__((packed)) nrow_block_q5_1 { + uint16_t scales16[N]; + uint8_t zp[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_1<1>) == sizeof(uint8_t) + 22, "wrong nrow_block_q5_1 block size/padding"); + +template <int N> struct __attribute__((packed)) nrow_block_q5_0 { + uint16_t scales16[N]; + // n0 [bh0, bh1, bh2, bh3, bh4, bh5, bh6, bh7] .... + uint8_t qh[4 * N]; + // n0 [b0, b1], [b2, b3] .... [b30, b31] + // n1 [b0, b1], [b2, b3] .... [b30, b31] + uint8_t qs[16 * N]; +}; + +static_assert(sizeof(nrow_block_q5_0<1>) == 22, "wrong nrow_block_q5_0 block size/padding"); + +using gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t *, const uint8_t *, const uint8_t *, float *, size_t, size_t, size_t, size_t)>; + +using moe_gemm_kernel_quantize_def = std::function< + size_t(size_t, const uint8_t **, const uint8_t *, const uint8_t *, float **, size_t, size_t, size_t, size_t)>; -namespace sqnbitgemm_spacemit_ime { namespace ime1 { -size_t gemm_kernel_i8i4(size_t blk_len, - const std::byte * quant_a_ptr, - const std::byte * quant_b_data, - const float * quant_b_scale, - const std::byte * quant_b_zp, - float * c_ptr, - size_t count_m, - size_t count_n, - size_t count_k, - size_t block_count_k, - size_t ldc, - const float * bias, - const size_t scale_stride); - -void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); - -void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); } // namespace ime1 -} // namespace sqnbitgemm_spacemit_ime + +namespace ime2 { +size_t gemm_kernel_i8i2k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i3k(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i4_hp(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i8(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8mxfp4(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t gemm_kernel_i8i5(size_t blk_len, + const uint8_t * quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); + +size_t moe_m2_gemm_kernel_i8i5(size_t blk_len, + const uint8_t ** quant_a_ptr, + const uint8_t * quant_b_data, + const uint8_t * quant_b_zp, + float ** c_ptr, + size_t count_m, + size_t count_n, + size_t k_blks, + size_t ldc); +} // namespace ime2 +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/repack.cpp b/ggml/src/ggml-cpu/spacemit/repack.cpp new file mode 100644 index 00000000000..3c879c4b7a0 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.cpp @@ -0,0 +1,1795 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP + +#include "repack.h" + +#include "ggml-common.h" +#include "ggml-cpu.h" +#include "ggml-impl.h" +#include "ime_kernels.h" + +#include <algorithm> +#include <cassert> +#include <cmath> +#include <cstring> + +// clang-format off +#if defined(__riscv) + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +#error "riscv v extension or v_intrinsic not enabled" +#else +#include <riscv_vector.h> +#endif + +#if !defined(__riscv_zfh) +#error "riscv zfh extension not enabled" +#endif + +#else +#error "riscv not enabled in this build" +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +// clang-format on + +template <int K> constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template <int K, int N> struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_0 blocks +}; + +template <int K, int N> struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0<K>() * N * K) / 8]; // quants for N qK_1 blocks +}; + +// control size +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), + "wrong block_with_zp<4,16> size/padding"); + +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + +static_assert(sizeof(block<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<4,32> size/padding"); +static_assert(sizeof(block_with_zp<4, 32>) == 32 * sizeof(ggml_half) + QK4_0 * 16 + 32 * sizeof(uint8_t), + "wrong block_with_zp<4,32> size/padding"); + +using block_q4_0x16 = block<4, 16>; +using block_q4_1x16 = block_with_zp<4, 16>; +using block_q8_0x16 = block<8, 16>; + +using block_q4_0x32 = block<4, 32>; +using block_q4_1x32 = block_with_zp<4, 32>; +using block_q8_0x32 = block<8, 32>; + +struct block_q4_0x32x256 { + block_q4_0x32 blocks[8]; // [f16 * 32 | i4 * 32 * 32] * 8 +}; + +struct block_q4_1x32x256 { + block_q4_0x32 blocks[8]; + uint8_t zps[32 * 8]; +}; + +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } + } + + return out; +} + +static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast<uint8_t>(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } + } + + return out; +} + +static int repack_q4_0_to_q4_0_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static inline void get_scale_min_k4(int j, + const uint8_t * GGML_RESTRICT q, + uint8_t * GGML_RESTRICT d, + uint8_t * GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +static int repack_q4_k_to_q4_1_16_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_q4_0x32 make_block_q4_0x32(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x32 out; + assert(QK4_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b17] ......... [b30 b31] + out.qs[i * QK4_0 / 2 + QK4_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q4_1x32 make_block_q4_1x32(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x32 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast<uint8_t>(mid); + } + + for (int i = 0; i < 32; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b1] ......... [b14 b15] + out.qs[i * QK4_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < 32; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[i * QK4_1 / 2 + QK4_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + return out; +} + +static block_q8_0x32 make_block_q8_0x32(block_q8_0 * in, unsigned int blck_size_interleave) { + block_q8_0x32 out; + GGML_ASSERT(QK8_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * QK8_0, in[i].qs, QK8_0); + } + + return out; +} + +static int repack_q2_k_to_q2_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q2_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const block_q2_K * src = (const block_q2_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q2_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q2_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint8_t qs_aux[256] = { 0 }; + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q2_K * src_block = &src[(b + i) * nblocks + x]; + + // scale for [16, N] + for (int j = 0; j < 16; j++) { + auto zp_aux = (dst->scales[j * nrows_interleaved + i]) & 0xF0; + + dst->scales[j * nrows_interleaved + i] = (src_block->scales[j] & 0x0F) | zp_aux; + } + + // zp for [N, 16] + for (int j = 0; j < 16; j++) { + auto scale_aux = (dst->scales[16 * i + j]) & 0x0F; + + dst->scales[16 * i + j] = (src_block->scales[j] & 0xF0) | scale_aux; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + dst->scales16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + dst->zeros16[i] = src_block->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin; + } + dst++; + } + } + + return 0; +} + +static int repack_q3_k_to_q3_k_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q3_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K == 256); + + constexpr int nrows_interleaved = 32; + + const uint32_t kmask1 = 0x03030303; + const uint32_t kmask2 = 0x0f0f0f0f; + + const block_q3_K * src = (const block_q3_K *) data; + + auto * dst = (spacemit_kernels::nrow_block_q3_k<32> *) t->data; + + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q3_K)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + uint32_t b_scale_aux[4] = { 0 }; + uint8_t qs_aux[256] = { 0 }; + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q3_K * src_block = &src[(b + i) * nblocks + x]; + + uint32_t * auxs = b_scale_aux; + int8_t * scale = (int8_t *) auxs; + memcpy(auxs, src_block->scales, 12); + + uint32_t tmp = auxs[2]; + auxs[2] = ((auxs[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); + auxs[3] = ((auxs[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); + auxs[0] = (auxs[0] & kmask2) | (((tmp >> 0) & kmask1) << 4); + auxs[1] = (auxs[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); + + for (int j = 0; j < 16; j++) { + dst->scales[j * nrows_interleaved + i] = scale[j] - 32; + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j] = (src_block->qs[j] >> (2 * k)) & 0x03; + } + } + + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 32; j++) { + qs_aux[k * 32 + j + 128] = (src_block->qs[j + 32] >> (2 * k)) & 0x03; + } + } + + // from nrows_interleaved * [2 * 32byte] + // to 4 * [nrows_interleaved * 16byte] + for (int k = 0; k < 4; k++) { + for (int j = 0; j < 16; j++) { + uint8_t qs0 = qs_aux[j + k * 64]; + uint8_t qs16 = qs_aux[j + 16 + k * 64]; + uint8_t qs32 = qs_aux[j + 32 + k * 64]; + uint8_t qs48 = qs_aux[j + 48 + k * 64]; + + dst->qs[(k * nrows_interleaved + i) * 16 + j] = + (qs0 & 0x03) | ((qs16 & 0x03) << 2) | ((qs32 & 0x03) << 4) | ((qs48 & 0x03) << 6); + } + } + + //memcpy(dst->hmask + i * 32, src_block->hmask, 32); + + // from nrows_interleaved * [32byte] + // to 16 * [nrows_interleaved * uint16_t] + uint16_t * dst_mask = ((uint16_t *) dst->hmask) + i; + for (int j = 0; j < 16; j++, dst_mask += nrows_interleaved) { + uint8_t b_shift = j / 2; + uint8_t * b_mask_col = (uint8_t *) (src_block->hmask + (j % 2) * 16); + // b0 - b15 + uint16_t msk_out_0 = 0; + + for (int k = 0; k < 8; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + for (int k = 8; k < 16; k++) { + msk_out_0 |= (uint16_t) ((b_mask_col[k] >> b_shift) & 0x01) << k; + } + + dst_mask[0] = msk_out_0; + } + + dst->scales16[i] = src_block->d; + } + + dst++; + } + } + + return 0; +} + +static int repack_q4_0_to_q4_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_0_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_0x32x256 * dst = (block_q4_0x32x256 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + dst->blocks[j] = make_block_q4_0x32(dst_tmp, interleave_block); + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_0_to_q4_1_256_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32x256 * dst = (block_q4_1x32x256 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + GGML_ASSERT(nblocks % 8 == 0); // for 256-block interleaving + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x += 8) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + j + i * nblocks]; + } + + block_q4_0x32 * dst_block = &dst->blocks[j]; + uint8_t * dst_zp = dst->zps + j * nrows_interleaved; + + for (int i = 0; i < nrows_interleaved; i++) { + float d = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + + dst_block->d[i] = GGML_FP32_TO_FP16(d); + dst_zp[i] = static_cast<uint8_t>(mid); + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + k] = + (dst_tmp[i].qs[k * 2] & 0x0F) | ((dst_tmp[i].qs[k * 2 + 1] & 0x0F) << 4); + } + } + + for (int i = 0; i < nrows_interleaved; i++) { + for (int k = 0; k < QK4_1 / 4; k++) { + dst_block->qs[i * QK4_1 / 2 + QK4_1 / 4 + k] = + ((dst_tmp[i].qs[k * 2] & 0xF0) >> 4) | (dst_tmp[i].qs[k * 2 + 1] & 0xF0); + } + } + } + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_0_to_q4_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack. +static int repack_q4_0_to_q4_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_0 / 2; // 16 + + block_q4_0x32 * dst = (block_q4_0x32 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + // d is at offset 0 of each block_q4_0, stride between rows = row_stride + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + // For each row i: + // src qs[16]: [b0|b16] [b1|b17] ... [b15|b31] (lo nibble = b_j, hi nibble = b_{j+16}) + // dst qs low 8B: (qs[2j] & 0x0F) | ((qs[2j+1] & 0x0F) << 4) for j=0..7 + // dst qs high 8B: ((qs[2j] >> 4)) | (qs[2j+1] & 0xF0) for j=0..7 + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); // qs[0], qs[2], ..., qs[14] + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); // qs[1], qs[3], ..., qs[15] + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q4_1_to_q4_1_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes nibble repack + zp computation. +static int repack_q4_1_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + constexpr int qs_bytes = QK4_1 / 2; // 16 + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q4_1); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q4_1 * col_src = src + x; + + // --- 1) Gather d and m, compute zp = clamp(nearbyint(-m/d), 0, 15) --- + // block_q4_1 layout: [d(f16), m(f16), qs[16]] + // d is at byte offset 0, m is at byte offset 2 from each block start + { + const uint8_t * dm_base = (const uint8_t *) &col_src->GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d; + ggml_half * d_dst = dst->d; + uint8_t * zp_dst = dst->zp; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + + // stride load d (f16) from each row + vuint16m1_t vd_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd_raw, vl); + + // stride load m (f16) from each row (offset +2 bytes from d) + vuint16m1_t vm_raw = + __riscv_vlse16_v_u16m1((const uint16_t *) (dm_base + 2 + offset * row_stride), row_stride, vl); + + // convert to f32 for zp computation: zp = nearbyint(-m / d) + vfloat16m1_t vd_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vd_raw); + vfloat16m1_t vm_f16 = __riscv_vreinterpret_v_u16m1_f16m1(vm_raw); + + // -m / d in f16 directly (SpaceMIT X60 supports f16 arithmetic) + vfloat16m1_t v_neg_m = __riscv_vfneg_v_f16m1(vm_f16, vl); + vfloat16m1_t v_ratio = __riscv_vfdiv_vv_f16m1(v_neg_m, vd_f16, vl); + + // Convert to f32 for nearbyint, then clamp + vfloat32m2_t v_ratio_f32 = __riscv_vfwcvt_f_f_v_f32m2(v_ratio, vl); + + // Use integer rounding: convert f32 -> int (rounds to nearest) + vint32m2_t v_zp_i32 = __riscv_vfcvt_x_f_v_i32m2(v_ratio_f32, vl); + + // clamp to [0, 15] + v_zp_i32 = __riscv_vmax_vx_i32m2(v_zp_i32, 0, vl); + v_zp_i32 = __riscv_vmin_vx_i32m2(v_zp_i32, 15, vl); + + // narrow i32 -> u8 + vint16m1_t v_zp_i16 = __riscv_vncvt_x_x_w_i16m1(v_zp_i32, vl); + vint8mf2_t v_zp_i8 = __riscv_vncvt_x_x_w_i8mf2(v_zp_i16, vl); + vuint8mf2_t v_zp_u8 = __riscv_vreinterpret_v_i8mf2_u8mf2(v_zp_i8); + __riscv_vse8_v_u8mf2(zp_dst + offset, v_zp_u8, vl); + + offset += vl; + remaining -= vl; + } + } + + // --- 2) Nibble repack qs for each of the 32 rows --- + { + const size_t vl8 = __riscv_vsetvl_e8m1(8); + for (int i = 0; i < 32; i++) { + const uint8_t * sq = col_src[i * nblocks].qs; + uint8_t * dq = dst->qs + i * qs_bytes; + + // stride-2 load to separate even/odd bytes + vuint8m1_t v_even = __riscv_vlse8_v_u8m1(sq, 2, vl8); + vuint8m1_t v_odd = __riscv_vlse8_v_u8m1(sq + 1, 2, vl8); + + // low nibble part: (even & 0x0F) | ((odd & 0x0F) << 4) + vuint8m1_t v_even_lo = __riscv_vand_vx_u8m1(v_even, 0x0F, vl8); + vuint8m1_t v_odd_lo = __riscv_vand_vx_u8m1(v_odd, 0x0F, vl8); + vuint8m1_t v_lo = __riscv_vor_vv_u8m1(v_even_lo, __riscv_vsll_vx_u8m1(v_odd_lo, 4, vl8), vl8); + + // high nibble part: (even >> 4) | (odd & 0xF0) + vuint8m1_t v_even_hi = __riscv_vsrl_vx_u8m1(v_even, 4, vl8); + vuint8m1_t v_odd_hi = __riscv_vand_vx_u8m1(v_odd, 0xF0, vl8); + vuint8m1_t v_hi = __riscv_vor_vv_u8m1(v_even_hi, v_odd_hi, vl8); + + __riscv_vse8_v_u8m1(dq, v_lo, vl8); + __riscv_vse8_v_u8m1(dq + 8, v_hi, vl8); + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_k_to_q4_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q4_1x32 * dst = (block_q4_1x32 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q6_k_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + block_q8_0 dst_tmp[32]; + int8_t aux8[QK4_1]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrow_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + int i = 0; + for (; i < nrow_real; i++) { + const uint8_t * q4 = src[x + i * nblocks].ql; + const uint8_t * qh = src[x + i * nblocks].qh; + const int8_t * scales = src[x + i * nblocks].scales; + float d = GGML_FP16_TO_FP32(src[x + i * nblocks].d); + + q4 += 64 * (bi / 4); + qh += 32 * (bi / 4); + int8_t * GGML_RESTRICT a = aux8; + + int8_t bi_idx = bi % 4; + + if (bi_idx == 0) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + } + } else if (bi_idx == 1) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + } + } else if (bi_idx == 2) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + } + } else if (bi_idx == 3) { + for (int l = 0; l < 32; ++l) { + a[l] = (int8_t) ((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + } + } + a = aux8; + + float a_max_abs = 0.0f; + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + for (int l = 0; l < 16; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_0)); + } + + for (int l = 16; l < 32; ++l) { + a_max_abs = std::max(a_max_abs, std::abs(a[l] * scale_1)); + } + + float reflect_scale = a_max_abs / ((1 << 7) - 1); + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + for (int l = 0; l < 16; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_0), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + for (int l = 16; l < 32; ++l) { + float a_temp = std::clamp(std::nearbyintf(a[l] * reflect_scale_1), -128.0f, 127.0f); + a[l] = (int8_t) (a_temp); + } + + dst_tmp[i].d = GGML_FP32_TO_FP16(reflect_scale); + + memcpy(dst_tmp[i].qs, a, 32 * sizeof(int8_t)); + } + + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q6_k_to_q8_0_32_bl +// Vectorizes the Q6_K dequant -> requant pipeline using RVV intrinsics. +// For each sub-block (bi), dequant 32 Q6_K values to int6 -> apply two sub-block scales -> +// find max abs -> compute reflect_scale -> requant to int8 -> gather d with stride load. +static int repack_q6_k_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q6_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q6_K * src = (const block_q6_K *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q6_K); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int bi = 0; bi < 8; bi++) { + // --- 1) Gather 32 d values with stride load --- + // We need to compute reflect_scale per row first, so gather d later. + // Process each row: dequant Q6_K sub-block -> requant to Q8_0 + for (int i = 0; i < nrows_interleaved; i++) { + const block_q6_K * src_blk = &src[x + i * nblocks]; + const uint8_t * q4 = src_blk->ql + 64 * (bi / 4); + const uint8_t * qh = src_blk->qh + 32 * (bi / 4); + const int8_t * scales = src_blk->scales; + float d = GGML_FP16_TO_FP32(src_blk->d); + + int8_t bi_idx = bi % 4; + + // --- Dequant 32 Q6_K values to int6 (range [-32, 31]) using RVV --- + // vl = 32 for e8m2 (VLEN=256) or loop for smaller VLEN + const size_t vl16 = __riscv_vsetvl_e8m1(16); + + vint8m1_t va_lo, va_hi; // 16 elements each + + if (bi_idx == 0) { + // a[l] = (q4[l] & 0xF) | (((qh[l] >> 0) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_lo, 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1(__riscv_vand_vx_u8m1(vqh_hi, 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 1) { + // a[l] = (q4[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vlo4_lo = __riscv_vand_vx_u8m1(vq4_lo, 0x0F, vl16); + vuint8m1_t vlo4_hi = __riscv_vand_vx_u8m1(vq4_hi, 0x0F, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 2, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 2, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vlo4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vlo4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else if (bi_idx == 2) { + // a[l] = (q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 16, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 4, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 4, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } else { // bi_idx == 3 + // a[l] = (q4[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4) - 32 + vuint8m1_t vq4_lo = __riscv_vle8_v_u8m1(q4 + 32, vl16); + vuint8m1_t vq4_hi = __riscv_vle8_v_u8m1(q4 + 48, vl16); + vuint8m1_t vqh_lo = __riscv_vle8_v_u8m1(qh, vl16); + vuint8m1_t vqh_hi = __riscv_vle8_v_u8m1(qh + 16, vl16); + + vuint8m1_t vhi4_lo = __riscv_vsrl_vx_u8m1(vq4_lo, 4, vl16); + vuint8m1_t vhi4_hi = __riscv_vsrl_vx_u8m1(vq4_hi, 4, vl16); + vuint8m1_t vh_lo = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_lo, 6, vl16), 0x03, vl16), 4, vl16); + vuint8m1_t vh_hi = __riscv_vsll_vx_u8m1( + __riscv_vand_vx_u8m1(__riscv_vsrl_vx_u8m1(vqh_hi, 6, vl16), 0x03, vl16), 4, vl16); + + vuint8m1_t vcomb_lo = __riscv_vor_vv_u8m1(vhi4_lo, vh_lo, vl16); + vuint8m1_t vcomb_hi = __riscv_vor_vv_u8m1(vhi4_hi, vh_hi, vl16); + + va_lo = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_lo), 32, vl16); + va_hi = __riscv_vsub_vx_i8m1(__riscv_vreinterpret_v_u8m1_i8m1(vcomb_hi), 32, vl16); + } + + // --- Widen to i16 for scaled abs computation --- + float scale_0 = scales[bi * 2 + 0] * d; + float scale_1 = scales[bi * 2 + 1] * d; + + // Widen i8 -> i16 -> f32 for abs*scale computation + vint16m2_t va_lo_w = __riscv_vsext_vf2_i16m2(va_lo, vl16); + vint16m2_t va_hi_w = __riscv_vsext_vf2_i16m2(va_hi, vl16); + + // Compute |a[l] * scale_0| for lo half, |a[l] * scale_1| for hi half + vfloat32m4_t vf_lo = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_lo_w, vl16), vl16); + vfloat32m4_t vf_hi = __riscv_vfcvt_f_x_v_f32m4(__riscv_vsext_vf2_i32m4(va_hi_w, vl16), vl16); + + vfloat32m4_t vabs_lo = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_lo, scale_0, vl16), vl16); + vfloat32m4_t vabs_hi = __riscv_vfabs_v_f32m4(__riscv_vfmul_vf_f32m4(vf_hi, scale_1, vl16), vl16); + + // Find max abs across both halves + vfloat32m4_t vabs_max = __riscv_vfmax_vv_f32m4(vabs_lo, vabs_hi, vl16); + + // Reduce to scalar max + vfloat32m1_t vzero = __riscv_vfmv_v_f_f32m1(0.0f, 1); + vfloat32m1_t vmax_red = __riscv_vfredmax_vs_f32m4_f32m1(vabs_max, vzero, vl16); + float a_max_abs = __riscv_vfmv_f_s_f32m1_f32(vmax_red); + + float reflect_scale = a_max_abs / 127.0f; + float reflect_scale_0 = scale_0 / reflect_scale; + float reflect_scale_1 = scale_1 / reflect_scale; + + // --- Requant: a[l] = clamp(nearbyint(a[l] * reflect_scale_x), -128, 127) --- + vfloat32m4_t vscaled_lo = __riscv_vfmul_vf_f32m4(vf_lo, reflect_scale_0, vl16); + vfloat32m4_t vscaled_hi = __riscv_vfmul_vf_f32m4(vf_hi, reflect_scale_1, vl16); + + // fcvt.x rounds to nearest (using current rounding mode) + vint32m4_t vi_lo = __riscv_vfcvt_x_f_v_i32m4(vscaled_lo, vl16); + vint32m4_t vi_hi = __riscv_vfcvt_x_f_v_i32m4(vscaled_hi, vl16); + + // Clamp to [-128, 127] + vi_lo = __riscv_vmax_vx_i32m4(vi_lo, -128, vl16); + vi_lo = __riscv_vmin_vx_i32m4(vi_lo, 127, vl16); + vi_hi = __riscv_vmax_vx_i32m4(vi_hi, -128, vl16); + vi_hi = __riscv_vmin_vx_i32m4(vi_hi, 127, vl16); + + // Narrow i32 -> i16 -> i8 + vint16m2_t vi16_lo = __riscv_vncvt_x_x_w_i16m2(vi_lo, vl16); + vint16m2_t vi16_hi = __riscv_vncvt_x_x_w_i16m2(vi_hi, vl16); + vint8m1_t vi8_lo = __riscv_vncvt_x_x_w_i8m1(vi16_lo, vl16); + vint8m1_t vi8_hi = __riscv_vncvt_x_x_w_i8m1(vi16_hi, vl16); + + // Store d and qs directly into dst block + dst->d[i] = GGML_FP32_TO_FP16(reflect_scale); + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + __riscv_vse8_v_i8m1(dq, vi8_lo, vl16); + __riscv_vse8_v_i8m1(dq + 16, vi8_hi, vl16); + } + dst++; + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q8_0_to_q8_0_32_bl_ref(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + block_q8_0 dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[0] % QK8_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + int64_t nrows_real = std::min((int64_t) nrow - b, (int64_t) nrows_interleaved); + for (int64_t x = 0; x < nblocks; x++) { + int i = 0; + for (; i < nrows_real; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + for (; i < nrows_interleaved; i++) { + memset(&dst_tmp[i], 0, sizeof(block_q8_0)); + } + *dst++ = make_block_q8_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +// RVV optimized version of repack_q8_0_to_q8_0_32_bl +// Eliminates the intermediate dst_tmp buffer and vectorizes scale gather + qs copy. +static int repack_q8_0_to_q8_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q8_0); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + block_q8_0x32 * dst = (block_q8_0x32 *) t->data; + const block_q8_0 * src = (const block_q8_0 *) data; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK8_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q8_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK8_0 != 0) { + return -1; + } + + const ptrdiff_t row_stride = (ptrdiff_t) nblocks * sizeof(block_q8_0); + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + const block_q8_0 * col_src = src + x; + + // --- 1) Gather 32 scale values (ggml_half d) with stride load --- + { + const uint8_t * d_base = (const uint8_t *) &col_src->d; + ggml_half * d_dst = dst->d; + size_t remaining = 32; + size_t offset = 0; + while (remaining > 0) { + size_t vl = __riscv_vsetvl_e16m1(remaining); + vuint16m1_t vd = + __riscv_vlse16_v_u16m1((const uint16_t *) (d_base + offset * row_stride), row_stride, vl); + __riscv_vse16_v_u16m1((uint16_t *) (d_dst + offset), vd, vl); + offset += vl; + remaining -= vl; + } + } + + // --- 2) Copy qs for each of the 32 rows (32 bytes per row) --- + { + for (int i = 0; i < 32; i++) { + const int8_t * sq = col_src[i * nblocks].qs; + int8_t * dq = (int8_t *) dst->qs + i * QK8_0; + + size_t len = QK8_0; + size_t idx = 0; + while (len > 0) { + size_t vl = __riscv_vsetvl_e8m2(len); + vint8m2_t vs = __riscv_vle8_v_i8m2(sq + idx, vl); + __riscv_vse8_v_i8m2(dq + idx, vs, vl); + idx += vl; + len -= vl; + } + } + } + + dst++; + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static void convert_mxfp4_to_5bit(const block_mxfp4 & src, spacemit_kernels::nrow_block_mxfp4<1> & dst) { + dst.e[0] = src.e; + + // Decode all 32 mxfp4 values to signed integers via kvalues_mxfp4 + int8_t vals[32]; + for (int j = 0; j < QK_MXFP4 / 2; j++) { + vals[j] = kvalues_mxfp4[src.qs[j] & 0xF]; + vals[j + QK_MXFP4 / 2] = kvalues_mxfp4[src.qs[j] >> 4]; + } + + // vals [b0, b1, b2, b3, ..., b30, b31] + // Pack abs into qs with reorder: [b0,b1]..[b14,b15]..[b30,b31] + for (int j = 0; j < QK_MXFP4 / 2; j++) { + uint8_t lo0 = static_cast<uint8_t>(std::abs(vals[j * 2])); + uint8_t lo1 = static_cast<uint8_t>(std::abs(vals[j * 2 + 1])); + dst.qs[j] = (lo0 & 0x0F) | ((lo1 & 0x0F) << 4); + } + + // Pack sign bits into qh[4] (32 bits total, 1 bit per weight) + // reorder: [0,1,2,...,15,16,17,...,31] after the qs reorder above + uint32_t sign_bits = 0; + for (int j = 0; j < 32; j++) { + if (vals[j] < 0) { + sign_bits |= (1u << j); + } + } + memcpy(dst.qh, &sign_bits, 4); +} + +static spacemit_kernels::nrow_block_mxfp4<32> make_block_mxfp4x32(spacemit_kernels::nrow_block_mxfp4<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_mxfp4<32> out; + GGML_ASSERT(QK_MXFP4 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.e[i] = in[i].e[0]; + } + + // qs: copy per-row 16 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qs + i * 16, in[i].qs, 16); + } + + // qh: copy per-row 4 bytes + for (int i = 0; i < 32; i++) { + memcpy(out.qh + i * 4, in[i].qh, 4); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 32); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_mxfp4<32> * dst = (spacemit_kernels::nrow_block_mxfp4<32> *) t->data; + const block_mxfp4 * src = (const block_mxfp4 *) data; + spacemit_kernels::nrow_block_mxfp4<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_MXFP4 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + convert_mxfp4_to_5bit(src[x + i * nblocks], dst_tmp[i]); + } + *dst++ = make_block_mxfp4x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static spacemit_kernels::nrow_block_q5_1<32> make_block_q5_1x32(spacemit_kernels::nrow_block_q5_1<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_1<32> out; + GGML_ASSERT(QK5_1 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + out.zp[i] = in[i].zp[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_1 / 4; j++) { + out.qs[i * QK5_1 / 2 + QK5_1 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static spacemit_kernels::nrow_block_q5_0<32> make_block_q5_0x32(spacemit_kernels::nrow_block_q5_0<1> * in, + unsigned int blck_size_interleave) { + spacemit_kernels::nrow_block_q5_0<32> out; + GGML_ASSERT(QK5_0 / blck_size_interleave == 1); + GGML_UNUSED(blck_size_interleave); + + for (int i = 0; i < 32; i++) { + out.scales16[i] = in[i].scales16[0]; + } + + // qs: low 4 bits, reorder from [b0,b16],[b1,b17]... to [b0,b1]...[b14,b15] and [b16,b17]...[b30,b31] + for (int i = 0; i < 32; i++) { + // low half [0..15] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + j] = (in[i].qs[j * 2] & 0x0F) | ((in[i].qs[j * 2 + 1] & 0x0F) << 4); + } + // high half [16..31] + for (int j = 0; j < QK5_0 / 4; j++) { + out.qs[i * QK5_0 / 2 + QK5_0 / 4 + j] = ((in[i].qs[j * 2] & 0xF0) >> 4) | (in[i].qs[j * 2 + 1] & 0xF0); + } + } + + // qh: 5th bit, copy directly + for (int i = 0; i < 32; i++) { + for (int j = 0; j < 4; j++) { + out.qh[i * 4 + j] = in[i].qh[j]; + } + } + + return out; +} + +static int repack_q5_0_to_q5_0_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_0); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_0<32> * dst = (spacemit_kernels::nrow_block_q5_0<32> *) t->data; + const block_q5_0 * src = (const block_q5_0 *) data; + spacemit_kernels::nrow_block_q5_0<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_0 & s = src[x + i * nblocks]; + + dst_tmp[i].scales16[0] = s.d; + memcpy(dst_tmp[i].qs, s.qs, sizeof(dst_tmp[i].qs)); + memcpy(dst_tmp[i].qh, s.qh, sizeof(dst_tmp[i].qh)); + } + *dst++ = make_block_q5_0x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_1_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_1); + GGML_ASSERT(interleave_block == 32); // unused + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_1 * src = (const block_q5_1 *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK5_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q5_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK5_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + const block_q5_1 & s = src[x + i * nblocks]; + + float d = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(s.GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + + if (d == 0.0f) { + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(std::fabs(m)); + dst_tmp[i].zp[0] = m < 0.0f ? 1 : 0; + memset(dst_tmp[i].qh, 0, sizeof(dst_tmp[i].qh)); + memset(dst_tmp[i].qs, m > 0.0f ? 0x11 : 0x00, sizeof(dst_tmp[i].qs)); + continue; + } + + float mid = std::nearbyintf(-m / d); + mid = std::min(31.0f, std::max(0.0f, mid)); + + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d); + dst_tmp[i].zp[0] = static_cast<uint8_t>(mid); + + // qs: copy low 4 bits directly (same nibble packing) + memcpy(dst_tmp[i].qs, s.qs, QK5_1 / 2); + + // qh: copy 5th bit directly + memcpy(dst_tmp[i].qh, s.qh, 4); + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +static int repack_q5_k_to_q5_1_32_bl(ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q5_K); + GGML_ASSERT(interleave_block == 32); + GGML_ASSERT(QK_K / QK5_1 == 8); + + constexpr int nrows_interleaved = 32; + + spacemit_kernels::nrow_block_q5_1<32> * dst = (spacemit_kernels::nrow_block_q5_1<32> *) t->data; + const block_q5_K * src = (const block_q5_K *) data; + spacemit_kernels::nrow_block_q5_1<1> dst_tmp[32]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + + float d1 = d * sc; + float m1 = min * m; + + float mid = std::nearbyintf(m1 / d1); + mid = std::min(31.0f, std::max(0.0f, mid)); + dst_tmp[i].scales16[0] = GGML_FP32_TO_FP16(d1); + dst_tmp[i].zp[0] = static_cast<uint8_t>(mid); + + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK5_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + + // Extract the 5th bit (qh) for this sub-block + // block_q5_K.qh[32]: for sub-block j, the 5th bit is at bit position j in qh[l] + // qs was reordered: dst_qs maps to src weights [0,16,1,17,...,15,31] + // So qh must follow the same reorder to stay aligned with qs + // dst qh[4] = 32 bits for 32 weights in the reordered layout: + // byte 0: weights 0..7 (from src_qh[0..7]) + // byte 1: weights 8..15 (from src_qh[8..15]) + // byte 2: weights 16..23 (from src_qh[16..23]) + // byte 3: weights 24..31 (from src_qh[24..31]) + const uint8_t * src_qh = src[x + i * nblocks].qh; + for (int bi = 0; bi < 4; bi++) { + uint8_t qh_byte = 0; + for (int k = 0; k < 8; k++) { + int src_idx = bi * 8 + k; + qh_byte |= ((src_qh[src_idx] >> j) & 1) << k; + } + dst_tmp[i].qh[bi] = qh_byte; + } + } + *dst++ = make_block_q5_1x32(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; +} + +namespace ggml::cpu::riscv64_spacemit { + +template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> int repack(ggml_tensor *, const void *, size_t); + +template <> int repack<block_q4_0, 32, 16>(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +} + +template <> int repack<block_q4_1, 32, 16>(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack<block_q4_K, 32, 16>(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack<block_q2_K, 256, 32>(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q2_k_to_q2_k_32_bl(t, 32, data, data_size); +} + +template <> int repack<block_q3_K, 256, 32>(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q3_k_to_q3_k_32_bl(t, 32, data, data_size); +} + +template <> int repack<block_q4_0, 32, 32>(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_0_to_q4_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_0_to_q4_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack<block_q4_0, 256, 32>(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_0_256_32_bl_ref(t, 32, data, data_size); +#else + //return repack_q4_0_to_q4_0_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack<block_q4_1, 32, 32>(ggml_tensor * t, const void * data, size_t data_size) { +#if 0 + return repack_q4_1_to_q4_1_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack<block_q4_1, 256, 32>(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q4_0_to_q4_1_256_32_bl_ref(t, 32, data, data_size); +#else + return repack_q4_1_to_q4_1_256_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack<block_q4_K, 32, 32>(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_32_bl(t, 32, data, data_size); +} + +template <> int repack<block_q6_K, 32, 32>(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q6_k_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q6_k_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack<block_q8_0, 32, 32>(ggml_tensor * t, const void * data, size_t data_size) { +#if 1 + return repack_q8_0_to_q8_0_32_bl_ref(t, 32, data, data_size); +#else + return repack_q8_0_to_q8_0_32_bl(t, 32, data, data_size); +#endif +} + +template <> int repack<block_mxfp4, 32, 32>(ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_32_bl(t, 32, data, data_size); +} + +template <> int repack<block_q5_0, 32, 32>(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_0_to_q5_0_32_bl(t, 32, data, data_size); +} + +template <> int repack<block_q5_1, 32, 32>(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_1_to_q5_1_32_bl(t, 32, data, data_size); +} + +template <> int repack<block_q5_K, 32, 32>(ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_k_to_q5_1_32_bl(t, 32, data, data_size); +} + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/repack.h b/ggml/src/ggml-cpu/spacemit/repack.h new file mode 100644 index 00000000000..950cbde7593 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/repack.h @@ -0,0 +1,14 @@ +#pragma once + +#include "ggml-common.h" +#include "ggml.h" + +#include <cstddef> +#include <cstdint> + +namespace ggml::cpu::riscv64_spacemit { + +template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS> +int repack(ggml_tensor * t, const void * data, size_t data_size); + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp new file mode 100644 index 00000000000..d2f89743622 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp @@ -0,0 +1,3178 @@ +#include "rvv_kernels.h" + +#include "common.h" +#include "ggml.h" +#include "ops.h" +#include "string.h" + +#include <algorithm> +#include <cmath> +#include <cstdint> +#include <stdexcept> + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +# error "riscv v extension or v_intrinsic not enabled" +#else +# include <riscv_vector.h> +#endif + +#if !defined(__riscv_zfh) +# error "riscv zfh extension not enabled" +#endif + +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Woverlength-strings" +# pragma GCC diagnostic ignored "-Wcast-qual" +# pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +namespace spacemit_kernels::rvv { + +namespace { + +auto align_up(size_t value, size_t alignment) { + return (value + alignment - 1) / alignment * alignment; +} + +static inline bool flash_attn_ext_supported_d_vlen1024_vf16(int64_t d) { + return d > 0 && d <= 128; +} + +static inline bool flash_attn_ext_supported_shape_vlen1024_vf16(int64_t DK, int64_t DV) { + return flash_attn_ext_supported_d_vlen1024_vf16(DK) && flash_attn_ext_supported_d_vlen1024_vf16(DV); +} + +static inline float reduce_sum_f32m4_vlen1024(vfloat32m4_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m4_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +static inline float reduce_sum_f32m2_vlen1024(vfloat32m2_t v, size_t vl) { + vfloat32m1_t s_v = __riscv_vfmv_v_f_f32m1(0.0f, 1); + s_v = __riscv_vfredusum_vs_f32m2_f32m1(v, s_v, vl); + return __riscv_vfmv_f_s_f32m1_f32(s_v); +} + +// Adapted from ggml_v_expf_m2 in vec.h. This is accurate enough for softmax. +static inline vfloat32m2_t rvv_expf_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t r = __riscv_vfmv_v_f_f32m2(0x1.8p23f, vl); + const vfloat32m2_t z = __riscv_vfmacc_vf_f32m2(r, 0x1.715476p+0f, x, vl); + const vfloat32m2_t n = __riscv_vfsub_vv_f32m2(z, r, vl); + const vfloat32m2_t b = + __riscv_vfnmsac_vf_f32m2(__riscv_vfnmsac_vf_f32m2(x, 0x1.62e4p-1f, n, vl), 0x1.7f7d1cp-20f, n, vl); + const vuint32m2_t e = __riscv_vsll_vx_u32m2(__riscv_vreinterpret_v_f32m2_u32m2(z), 23, vl); + const vfloat32m2_t k = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(e, 0x3f800000, vl)); + const vbool16_t c = __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 126.0f, vl); + const vfloat32m2_t u = __riscv_vfmul_vv_f32m2(b, b, vl); + const vfloat32m2_t j = __riscv_vfmacc_vv_f32m2( + __riscv_vfmul_vf_f32m2(b, 0x1.ffffecp-1f, vl), + __riscv_vfmacc_vv_f32m2( + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.fffdb6p-2f, vl), 0x1.555e66p-3f, b, vl), + __riscv_vfmacc_vf_f32m2(__riscv_vfmv_v_f_f32m2(0x1.573e2ep-5f, vl), 0x1.0e4020p-7f, b, vl), u, vl), + u, vl); + + if (!__riscv_vcpop_m_b16(c, vl)) { + return __riscv_vfmacc_vv_f32m2(k, j, k, vl); + } + + const vbool16_t dm = __riscv_vmfle_vf_f32m2_b16(n, 0.0f, vl); + const vuint32m2_t d = __riscv_vmerge_vxm_u32m2(__riscv_vmv_v_x_u32m2(0, vl), 0x82000000, dm, vl); + const vfloat32m2_t s1 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vadd_vx_u32m2(d, 0x7f000000, vl)); + const vfloat32m2_t s2 = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vsub_vv_u32m2(e, d, vl)); + const vfloat32m2_t r1 = + __riscv_vmerge_vvm_f32m2(__riscv_vfmacc_vv_f32m2(k, k, j, vl), + __riscv_vfmul_vv_f32m2(__riscv_vfmacc_vv_f32m2(s2, s2, j, vl), s1, vl), c, vl); + return __riscv_vmerge_vvm_f32m2(r1, __riscv_vfmul_vv_f32m2(s1, s1, vl), + __riscv_vmfgt_vf_f32m2_b16(__riscv_vfabs_v_f32m2(n, vl), 192.0f, vl), vl); +} + +static inline vfloat32m2_t rvv_tanh_approx_f32m2(vfloat32m2_t x, size_t vl) { + const vfloat32m2_t abs_x = __riscv_vfabs_v_f32m2(x, vl); + const vfloat32m2_t neg_2_abs = __riscv_vfmul_vf_f32m2(abs_x, -2.0f, vl); + const vfloat32m2_t exp_term = rvv_expf_approx_f32m2(neg_2_abs, vl); + const vfloat32m2_t numerator = __riscv_vfsub_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t denominator = __riscv_vfadd_vf_f32m2(exp_term, 1.0f, vl); + const vfloat32m2_t tanh_abs = __riscv_vfneg_v_f32m2(__riscv_vfdiv_vv_f32m2(numerator, denominator, vl), vl); + const vbool16_t neg_mask = __riscv_vmflt_vf_f32m2_b16(x, 0.0f, vl); + const vfloat32m2_t tanh_neg = __riscv_vfneg_v_f32m2(tanh_abs, vl); + return __riscv_vmerge_vvm_f32m2(tanh_abs, tanh_neg, neg_mask, vl); +} + +static void rvv_softcap_tanh_inplace_f32(float * dst, int64_t dst_stride, int64_t tile_rows, int64_t n, float softcap) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride) { + float * dst_row = dst; + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m2(remaining); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst_row, vl); + v = rvv_tanh_approx_f32m2(v, vl); + v = __riscv_vfmul_vf_f32m2(v, softcap, vl); + __riscv_vse32_v_f32m2(dst_row, v, vl); + dst_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_softmax_exp_inplace_f32(float * dst, int64_t n, float max_value) { + float row_sum = 0.0f; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t v = __riscv_vle32_v_f32m2(dst, vl); + v = __riscv_vfsub_vf_f32m2(v, max_value, vl); + v = rvv_expf_approx_f32m2(v, vl); + __riscv_vse32_v_f32m2(dst, v, vl); + row_sum += reduce_sum_f32m2_vlen1024(v, vl); + dst += vl; + n -= vl; + } + return row_sum; +} + +static inline float rvv_add_max_inplace_f32(float * dst, const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline float rvv_softcap_add_max_inplace_f32(float * dst, const float * src, int64_t n, float softcap) { + if (softcap == 0.0f) { + return rvv_add_max_inplace_f32(dst, src, n); + } + + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m2(n); + vfloat32m2_t vdst = __riscv_vle32_v_f32m2(dst, vl); + vfloat32m2_t vsrc = __riscv_vle32_v_f32m2(src, vl); + vdst = rvv_tanh_approx_f32m2(vdst, vl); + vdst = __riscv_vfmul_vf_f32m2(vdst, softcap, vl); + vdst = __riscv_vfadd_vv_f32m2(vdst, vsrc, vl); + __riscv_vse32_v_f32m2(dst, vdst, vl); + + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m2_f32m1(vdst, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + + dst += vl; + src += vl; + n -= vl; + } + return max_val; +} + +static inline void rvv_zero_f32(float * dst, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t z = __riscv_vfmv_v_f_f32m4(0.0f, vl); + __riscv_vse32_v_f32m4(dst, z, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_scale_f32(float * dst, float scale, int64_t n) { + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + vfloat32m4_t v = __riscv_vle32_v_f32m4(dst, vl); + v = __riscv_vfmul_vf_f32m4(v, scale, vl); + __riscv_vse32_v_f32m4(dst, v, vl); + dst += vl; + n -= vl; + } +} + +static inline void rvv_add_inplace_f32(float * dst, + int64_t dst_stride, + const float * src, + int64_t src_stride, + int64_t tile_rows, + int64_t n) { + for (int tq = 0; tq < tile_rows; ++tq, dst += dst_stride, src += src_stride) { + int64_t remaining = n; + float * dst_row = dst; + const float * src_row = src; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t vdst = __riscv_vle32_v_f32m4(dst_row, vl); + vfloat32m4_t vsrc = __riscv_vle32_v_f32m4(src_row, vl); + vdst = __riscv_vfadd_vv_f32m4(vdst, vsrc, vl); + __riscv_vse32_v_f32m4(dst_row, vdst, vl); + dst_row += vl; + src_row += vl; + remaining -= vl; + } + } +} + +static inline float rvv_max_f32(const float * src, int64_t n) { + float max_val = -INFINITY; + while (n > 0) { + const size_t vl = __riscv_vsetvl_e32m4(n); + const vfloat32m4_t v = __riscv_vle32_v_f32m4(src, vl); + vfloat32m1_t seed = __riscv_vfmv_v_f_f32m1(max_val, 1); + seed = __riscv_vfredmax_vs_f32m4_f32m1(v, seed, vl); + max_val = __riscv_vfmv_f_s_f32m1_f32(seed); + src += vl; + n -= vl; + } + return max_val; +} + +static void rvv_pack_f32_as_scaled_f16(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + _Float16 * dst_row_ptr = (_Float16 *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + const vfloat16m2_t v16 = __riscv_vfncvt_f_f_w_f16m2(v32, vl); + __riscv_vse16_v_f16m2(dst_row_ptr, v16, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f16_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const _Float16 * row_ptr = (const _Float16 *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e16m2(remaining); + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(row_ptr, vl); + vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale, vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static void rvv_pack_scaled_f32_as_f32(void * dst, + int64_t dst_row_stride, + const void * src, + int64_t src_row_stride, + int64_t tile_rows, + int64_t n, + float * scale) { + for (int tq = 0; tq < tile_rows; ++tq) { + const float * row_ptr = (const float *) ((const char *) src + tq * src_row_stride); + float * dst_row_ptr = (float *) ((char *) dst + tq * dst_row_stride); + int64_t remaining = n; + while (remaining > 0) { + const size_t vl = __riscv_vsetvl_e32m4(remaining); + vfloat32m4_t v32 = __riscv_vle32_v_f32m4(row_ptr, vl); + v32 = __riscv_vfmul_vf_f32m4(v32, scale[tq], vl); + __riscv_vse32_v_f32m4(dst_row_ptr, v32, vl); + dst_row_ptr += vl; + row_ptr += vl; + remaining -= vl; + } + } +} + +static inline void rvv_transposed_s32_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e32, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vle32.v v1, (s2) \n\t" + "sh2add s2, t0, s2 \n\t" + "vle32.v v2, (s3) \n\t" + "sh2add s3, t0, s3 \n\t" + "vle32.v v3, (s4) \n\t" + "sh2add s4, t0, s4 \n\t" + "vle32.v v4, (s5) \n\t" + "sh2add s5, t0, s5 \n\t" + "vle32.v v5, (s6) \n\t" + "sh2add s6, t0, s6 \n\t" + "vle32.v v6, (s7) \n\t" + "sh2add s7, t0, s7 \n\t" + "vle32.v v7, (s8) \n\t" + "sh2add s8, t0, s8 \n\t" + "vssseg8e32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 32 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e32, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (s1) \n\t" + "sh2add s1, t0, s1 \n\t" + "vsse32.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 4 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_transposed_s16_mn_to_nm(int8_t * dst, + int64_t n_dst_stride, + int8_t * src, + int64_t m_src_stride, + int64_t m, + int64_t n) { + int8_t * in = src; + int8_t * out = dst; + + __asm__ volatile( + "vsetvli t0, zero, e16, m1, tu, mu \n\t" + "mul t3, t0, %[os0] \n\t" + "srli t2, %[isz0], 3 \n\t" + "blez t2, M1%= \n\t" + + "LOOP_M8%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "add s2, %[src], %[is0] \n\t" + "add s3, s2, %[is0] \n\t" + "add s4, s3, %[is0] \n\t" + "add s5, s4, %[is0] \n\t" + "add s6, s5, %[is0] \n\t" + "add s7, s6, %[is0] \n\t" + "add s8, s7, %[is0] \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M8N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vle16.v v1, (s2) \n\t" + "sh1add s2, t0, s2 \n\t" + "vle16.v v2, (s3) \n\t" + "sh1add s3, t0, s3 \n\t" + "vle16.v v3, (s4) \n\t" + "sh1add s4, t0, s4 \n\t" + "vle16.v v4, (s5) \n\t" + "sh1add s5, t0, s5 \n\t" + "vle16.v v5, (s6) \n\t" + "sh1add s6, t0, s6 \n\t" + "vle16.v v6, (s7) \n\t" + "sh1add s7, t0, s7 \n\t" + "vle16.v v7, (s8) \n\t" + "sh1add s8, t0, s8 \n\t" + "vssseg8e16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M8N%= \n\t" + "sh3add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 16 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M8%= \n\t" + + "M1%=: \n\t" + "andi t2, %[isz0], 7 \n\t" + "blez t2, END%= \n\t" + + "LOOP_M1%=: \n\t" + "addi a1, %[dst], 0 \n\t" + "addi s1, %[src], 0 \n\t" + "addi t1, %[isz1], 0 \n\t" + + "LOOP_M1N%=: \n\t" + "vsetvli t0, t1, e16, m1, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle16.v v0, (s1) \n\t" + "sh1add s1, t0, s1 \n\t" + "vsse16.v v0, (a1), %[os0] \n\t" + "add a1, a1, t3 \n\t" + "bnez t1, LOOP_M1N%= \n\t" + "add %[src], %[is0], %[src] \n\t" + "addi %[dst], %[dst], 2 \n\t" + "addi t2, t2, -1 \n\t" + "bnez t2, LOOP_M1%= \n\t" + "END%=: \n\t" + + : [src] "+r"(in), [dst] "+r"(out), [isz0] "+r"(m) + : [isz1] "r"(n), [is0] "r"(m_src_stride), [os0] "r"(n_dst_stride) + : "cc", "t0", "t1", "t2", "t3", "s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "a1"); +} + +static inline void rvv_qk_dot_tile_f16_x1(float * dst, + const _Float16 * q_row, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc = __riscv_vfwmacc_vf_f32m2(acc, q_row[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst, acc, vl); +} + +static inline void rvv_qk_dot_tile_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const _Float16 * q0, + const _Float16 * q1, + const _Float16 * q2, + const _Float16 * q3, + const _Float16 * k_pack, + int64_t dk, + int64_t kv_tile) { + const size_t vl = __riscv_vsetvl_e16m1(kv_tile); + vfloat32m2_t acc0 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc1 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc2 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + vfloat32m2_t acc3 = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_pack + d * ggml_fa_tile_config::KV, vl); + acc0 = __riscv_vfwmacc_vf_f32m2(acc0, q0[d], k_vec, vl); + acc1 = __riscv_vfwmacc_vf_f32m2(acc1, q1[d], k_vec, vl); + acc2 = __riscv_vfwmacc_vf_f32m2(acc2, q2[d], k_vec, vl); + acc3 = __riscv_vfwmacc_vf_f32m2(acc3, q3[d], k_vec, vl); + } + + __riscv_vse32_v_f32m2(dst0, acc0, vl); + __riscv_vse32_v_f32m2(dst1, acc1, vl); + __riscv_vse32_v_f32m2(dst2, acc2, vl); + __riscv_vse32_v_f32m2(dst3, acc3, vl); +} + +static inline void rvv_pv_accumulate_f16_x1(float * dst, + const float * prob, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_pv_accumulate_f16_x4(float * dst0, + float * dst1, + float * dst2, + float * dst3, + const float * prob0, + const float * prob1, + const float * prob2, + const float * prob3, + const _Float16 * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e16m2(d_left); + vfloat32m4_t acc0 = __riscv_vle32_v_f32m4(dst0 + d_off, vl); + vfloat32m4_t acc1 = __riscv_vle32_v_f32m4(dst1 + d_off, vl); + vfloat32m4_t acc2 = __riscv_vle32_v_f32m4(dst2 + d_off, vl); + vfloat32m4_t acc3 = __riscv_vle32_v_f32m4(dst3 + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_pack + tk * dv + d_off, vl); + const vfloat32m4_t v32 = __riscv_vfwcvt_f_f_v_f32m4(v16, vl); + acc0 = __riscv_vfmacc_vf_f32m4(acc0, prob0[tk], v32, vl); + acc1 = __riscv_vfmacc_vf_f32m4(acc1, prob1[tk], v32, vl); + acc2 = __riscv_vfmacc_vf_f32m4(acc2, prob2[tk], v32, vl); + acc3 = __riscv_vfmacc_vf_f32m4(acc3, prob3[tk], v32, vl); + } + + __riscv_vse32_v_f32m4(dst0 + d_off, acc0, vl); + __riscv_vse32_v_f32m4(dst1 + d_off, acc1, vl); + __riscv_vse32_v_f32m4(dst2 + d_off, acc2, vl); + __riscv_vse32_v_f32m4(dst3 + d_off, acc3, vl); + d_left -= vl; + d_off += vl; + } +} + +static inline void rvv_qk_dot_tile(float * dst, + const float * q_row, + const float * k_pack, + int64_t dk, + int64_t kv_tile, + float scale) { + const size_t vl = __riscv_vsetvl_e32m4(kv_tile); + vfloat32m4_t acc = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + for (int64_t d = 0; d < dk; ++d) { + const vfloat32m4_t k_vec = __riscv_vle32_v_f32m4(k_pack + d * kv_tile, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, q_row[d] * scale, k_vec, vl); + } + + __riscv_vse32_v_f32m4(dst, acc, vl); +} + +static inline void rvv_pv_accumulate(float * dst, + const float * prob, + const float * v_pack, + int64_t kv_tile, + int64_t dv) { + int64_t d_left = dv; + int64_t d_off = 0; + + while (d_left > 0) { + const size_t vl = __riscv_vsetvl_e32m4(d_left); + vfloat32m4_t acc = __riscv_vle32_v_f32m4(dst + d_off, vl); + + for (int64_t tk = 0; tk < kv_tile; ++tk) { + const vfloat32m4_t v_vec = __riscv_vle32_v_f32m4(v_pack + tk * dv + d_off, vl); + acc = __riscv_vfmacc_vf_f32m4(acc, prob[tk], v_vec, vl); + } + + __riscv_vse32_v_f32m4(dst + d_off, acc, vl); + d_left -= vl; + d_off += vl; + } +} + +static void permute_transpose_impl(const ggml_tensor * src0, + ggml_tensor * dst, + int64_t batch, + int64_t m, + int64_t n, + int64_t batch_stride, + int64_t m_src_stride, + int64_t n_src_stride, + int64_t n_dst_stride, + int ith, + int nth) { + GGML_ASSERT(n_src_stride == sizeof(int32_t) || n_src_stride == sizeof(int16_t)); + + if (n_src_stride == sizeof(int32_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else if (n_src_stride == sizeof(int16_t)) { + for (int64_t bi = ith; bi < batch; bi += nth) { + rvv_transposed_s32_mn_to_nm((int8_t *) ((char *) dst->data + bi * batch_stride), n_dst_stride, + (int8_t *) ((char *) src0->data + bi * batch_stride), m_src_stride, m, n); + } + } else { + GGML_ABORT("not implemented"); + } +} + +template <size_t QLEN> +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow(float ** pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + float ** sinks, + float ** dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK, + void * tcm_buffer, + size_t tcm_buffer_size) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + float S[QLEN] = { 0.0f }; // sum + float M[QLEN] = { -INFINITY }; // maximum KQ value + + _Float16 * kq16_buffer = (_Float16 *) tcm_buffer; + _Float16 * qv_buffer = kq16_buffer + QLEN * DV; + const size_t qkv_temp_buffer_size = (QLEN * DV + QLEN * DK) * sizeof(_Float16); + char * kv_tile_buffer = (char *) (qv_buffer + QLEN * DK); + + { + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + for (int64_t i = 0; i < QLEN; ++i) { + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq[i], DK), DK); + __riscv_vse16_v_f16m2(qv_buffer + i * DK, Q_q_v, DK); + } + } + + const uintptr_t scratch_addr = reinterpret_cast<uintptr_t>(kv_tile_buffer); + const size_t scratch_size = tcm_buffer_size > qkv_temp_buffer_size ? tcm_buffer_size - qkv_temp_buffer_size : 0; + const uintptr_t kq_tile_addr = align_up(scratch_addr, alignof(float)); + const size_t scratch_prefix = kq_tile_addr - scratch_addr; + const size_t packed_tile_size = + QLEN * sizeof(float) + DK * sizeof(_Float16) + DV * sizeof(_Float16) + sizeof(float); + const int64_t max_ic_tile_step = ((int64_t) __riscv_vsetvlmax_e16m1()) & ~((int64_t) 7); + const int64_t max_fit_by_tcm = + scratch_size > scratch_prefix ? (int64_t) ((scratch_size - scratch_prefix) / packed_tile_size) : 0; + const int64_t ic_tile_step = std::min(max_ic_tile_step, max_fit_by_tcm) & ~((int64_t) 7); + + const uintptr_t k_tile_addr = kq_tile_addr + QLEN * ic_tile_step * sizeof(float); + const uintptr_t v_tile_addr = k_tile_addr + DK * ic_tile_step * sizeof(_Float16); + const uintptr_t mv_tile_addr = v_tile_addr + ic_tile_step * DV * sizeof(_Float16); + + if (ic_tile_step >= 8) { + float * kq_tile_buffer = reinterpret_cast<float *>(kq_tile_addr); + _Float16 * k_tile_pack = reinterpret_cast<_Float16 *>(k_tile_addr); + _Float16 * v_tile_pack = reinterpret_cast<_Float16 *>(v_tile_addr); + float * mv_tile_pack = reinterpret_cast<float *>(mv_tile_addr); + + const int64_t k_tile_byte_stride = ic_tile_step * (int64_t) sizeof(_Float16); + + int64_t ic_step = 0; + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + if (mv != -INFINITY) { + const _Float16 * k_data = (const _Float16 *) (k_data_row + ic * nbk1); + const _Float16 * v_data = (const _Float16 *) (v_data_row + ic * nbv1); + + const vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2(k_data, DK); + const vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2(v_data, DV); + __riscv_vsse16_v_f16m2(k_tile_pack + ic_step, k_tile_byte_stride, k_data_v, DK); + __riscv_vse16_v_f16m2(v_tile_pack + ic_step * DV, v_data_v, DV); + mv_tile_pack[ic_step] = mv; + ic_step++; + } + + if (ic_step > 0 && (ic_step == ic_tile_step || ic == (nek1 - 1))) { + if constexpr (QLEN == 4) { + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc2 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc3 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + qk_acc2 = __riscv_vfwmacc_vf_f32m2(qk_acc2, qv_buffer[2 * DK + d], k_vec, qk_vl); + qk_acc3 = __riscv_vfwmacc_vf_f32m2(qk_acc3, qv_buffer[3 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + qk_acc2 = __riscv_vfmul_vf_f32m2(qk_acc2, scale, qk_vl); + qk_acc3 = __riscv_vfmul_vf_f32m2(qk_acc3, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 2 * ic_tile_step, qk_acc2, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 3 * ic_tile_step, qk_acc3, qk_vl); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + + const size_t qk_vl = __riscv_vsetvl_e16m1(ic_step); + vfloat32m2_t qk_acc0 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + vfloat32m2_t qk_acc1 = __riscv_vfmv_v_f_f32m2(0.0f, qk_vl); + + for (int64_t d = 0; d < DK; ++d) { + const vfloat16m1_t k_vec = __riscv_vle16_v_f16m1(k_tile_pack + d * ic_tile_step, qk_vl); + qk_acc0 = __riscv_vfwmacc_vf_f32m2(qk_acc0, qv_buffer[0 * DK + d], k_vec, qk_vl); + qk_acc1 = __riscv_vfwmacc_vf_f32m2(qk_acc1, qv_buffer[1 * DK + d], k_vec, qk_vl); + } + + qk_acc0 = __riscv_vfmul_vf_f32m2(qk_acc0, scale, qk_vl); + qk_acc1 = __riscv_vfmul_vf_f32m2(qk_acc1, scale, qk_vl); + + __riscv_vse32_v_f32m2(kq_tile_buffer + 0 * ic_tile_step, qk_acc0, qk_vl); + __riscv_vse32_v_f32m2(kq_tile_buffer + 1 * ic_tile_step, qk_acc1, qk_vl); + } + + for (int i = 0; i < QLEN; ++i) { + float * row_ptr = kq_tile_buffer + i * ic_tile_step; + const float tile_max = + rvv_softcap_add_max_inplace_f32(row_ptr, mv_tile_pack, ic_step, logit_softcap); + + const float Mold = M[i]; + + if (tile_max > Mold) { + const float ms = expf(Mold - tile_max); + M[i] = tile_max; + S[i] *= ms; + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, (_Float16) ms, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + } + + S[i] += rvv_softmax_exp_inplace_f32(row_ptr, ic_step, M[i]); + } + + if constexpr (QLEN == 4) { + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + vfloat16m2_t pv_acc2 = __riscv_vle16_v_f16m2(kq16_buffer + 2 * DV, DV); + vfloat16m2_t pv_acc3 = __riscv_vle16_v_f16m2(kq16_buffer + 3 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + pv_acc2 = + __riscv_vfmacc_vf_f16m2(pv_acc2, (_Float16) kq_tile_buffer[2 * ic_tile_step + tk], v16, DV); + pv_acc3 = + __riscv_vfmacc_vf_f16m2(pv_acc3, (_Float16) kq_tile_buffer[3 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 2 * DV, pv_acc2, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 3 * DV, pv_acc3, DV); + } else { + static_assert(QLEN == 2, "unsupported QLEN"); + vfloat16m2_t pv_acc0 = __riscv_vle16_v_f16m2(kq16_buffer + 0 * DV, DV); + vfloat16m2_t pv_acc1 = __riscv_vle16_v_f16m2(kq16_buffer + 1 * DV, DV); + + for (int64_t tk = 0; tk < ic_step; ++tk) { + const vfloat16m2_t v16 = __riscv_vle16_v_f16m2(v_tile_pack + tk * DV, DV); + pv_acc0 = + __riscv_vfmacc_vf_f16m2(pv_acc0, (_Float16) kq_tile_buffer[0 * ic_tile_step + tk], v16, DV); + pv_acc1 = + __riscv_vfmacc_vf_f16m2(pv_acc1, (_Float16) kq_tile_buffer[1 * ic_tile_step + tk], v16, DV); + } + + __riscv_vse16_v_f16m2(kq16_buffer + 0 * DV, pv_acc0, DV); + __riscv_vse16_v_f16m2(kq16_buffer + 1 * DV, pv_acc1, DV); + } + + ic_step = 0; + } + } + } else { + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + + const char * k_data = k_data_row + ic * nbk1; + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t k_data_v; + vfloat16m2_t v_data_v; + + if (mv != -INFINITY) { + k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + } else { + continue; + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t Q_q_v = __riscv_vle16_v_f16m2(qv_buffer + i * DK, DK); + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + s = s * scale; + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + s += mv; + + const float Mold = M[i]; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + if (s > M[i]) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M[i] = s; + ms = expf(Mold - M[i]); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M[i]); + } + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + __riscv_vse16_v_f16m2(kq16_buffer + i * DV, VKQ16_v, DV); + S[i] = S[i] * ms + vs; // scale and increment sum with partial sum + } + } + } + + for (int i = 0; i < QLEN; ++i) { + vfloat16m2_t VKQ16_v = __riscv_vle16_v_f16m2(kq16_buffer + i * DV, DV); + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks[i]) { + const float s = *(sinks[i]); + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[i]) { + ms = expf(M[i] - s); + M[i] = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M[i]); + } + + S[i] = S[i] * ms + vs; + } + + // V /= S + const float S_inv = S[i] == 0.0f ? 0.0f : 1.0f / S[i]; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst[i], VKQ32_v, DV); + } +} + +static void flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1(const float * pq, + const char * k_data_row, + const char * v_data_row, + const ggml_fp16_t * mp, + const float * sinks, + float * dst, + float scale, + float logit_softcap, + float slope, + int64_t nek1, + int64_t nbk1, + int64_t nbv1, + int64_t DV, + int64_t DK) { + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + vfloat16m2_t VKQ16_v = __riscv_vfmv_v_f_f16m2(0.0f, DV); + + vfloat16m2_t Q_q_v = __riscv_vfncvt_f_f_w_f16m2(__riscv_vle32_v_f32m4(pq, DK), DK); + + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? slope * ((_Float16 *) mp)[ic] : 0.0f; + if (mv == -INFINITY) { + continue; + } + + const char * k_data = k_data_row + ic * nbk1; + + vfloat16m2_t k_data_v = __riscv_vle16_v_f16m2((_Float16 *) k_data, DK); + + vfloat32m4_t qk_acc_v = __riscv_vfwmul_vv_f32m4(k_data_v, Q_q_v, DK); + float s = reduce_sum_f32m4_vlen1024(qk_acc_v, DK); + + s = s * scale; // scale KQ value + + if (logit_softcap != 0.0f) { + s = logit_softcap * tanhf(s); + } + + s += mv; // apply mask + + const float Mold = M; + + float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value + float vs = 1.0f; // post-softmax KQ value, expf(s - M) + + const char * v_data = v_data_row + ic * nbv1; + + vfloat16m2_t v_data_v = __riscv_vle16_v_f16m2((_Float16 *) v_data, DV); + + if (s > M) { + // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + VKQ16_v = __riscv_vfmul_vf_f16m2(VKQ16_v, ms, DV); + } else { + // no new maximum, ms == 1.0f, vs != 1.0f + vs = expf(s - M); + } + + VKQ16_v = __riscv_vfmacc_vf_f16m2(VKQ16_v, vs, v_data_v, DV); + + S = S * ms + vs; // scale and increment sum with partial sum + } + + vfloat32m4_t VKQ32_v = __riscv_vfwcvt_f_f_v_f32m4(VKQ16_v, DV); + + // sinks + if (sinks) { + const float s = *sinks; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + M = s; + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, ms, DV); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + // V /= S + const float S_inv = S == 0.0f ? 0.0f : 1.0f / S; + + VKQ32_v = __riscv_vfmul_vf_f32m4(VKQ32_v, S_inv, DV); + + __riscv_vse32_v_f32m4(dst, VKQ32_v, DV); +} + +} // namespace + +void memcpy1d(void * dst, const void * src, int64_t size) { + size_t byte_size_all = size; + size_t vlen = __riscv_vlenb() * 8; + if (vlen == 256) { + // 1024 bytes + __asm__ volatile( + // + "srli t0, %[size], 10 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v8, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v16, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + "vle8.v v24, (%[s]) \n\t" + "addi %[s], %[s], 256 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v8, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v16, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + "vse8.v v24, (%[d]) \n\t" + "addi %[d], %[d], 256 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 1023 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1"); + } else if (vlen == 1024) { + // 2048 bytes + __asm__ volatile( + // + "srli t0, %[size], 11 \n\t" + "blez t0, memcpy_tail%= \n\t" + "vsetvli t1, x0, e8, m8, tu, mu \n\t" + "addi t2, %[s], 1024 \n\t" + "addi t3, %[d], 1024 \n\t" + "li t5, 2048 \n\t" + "memcpy_main_loop%=: \n\t" + "addi t0, t0, -1 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t5 \n\t" + "vle8.v v8, (t2) \n\t" + "add t2, t2, t5 \n\t" + // + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t5 \n\t" + "vse8.v v8, (t3) \n\t" + "add t3, t3, t5 \n\t" + // + "bnez t0, memcpy_main_loop%= \n\t" + "memcpy_tail%=: \n\t" + "andi t1, %[size], 2047 \n\t" + "blez t1, out%= \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m2, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + "out%=: \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t3", "t5"); + } else { + __asm__ volatile( + // + "add t1, %[size], zero \n\t" + "memcpy_tail_loop%=: \n\t" + "vsetvli t0, t1, e8, m8, tu, mu \n\t" + "sub t1, t1, t0 \n\t" + "vle8.v v0, (%[s]) \n\t" + "add %[s], %[s], t0 \n\t" + "vse8.v v0, (%[d]) \n\t" + "add %[d], %[d], t0 \n\t" + "bnez t1, memcpy_tail_loop%= \n\t" + : [s] "+r"(src), [d] "+r"(dst) + : [size] "r"(byte_size_all) + : "cc", "t0", "t1", "t2", "t4", "t3"); + } +} + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size) { + for (int64_t i = 0; i < tile_rows; ++i) { + memcpy1d((char *) dst + i * dst_stride, (const char *) src + i * src_stride, size); + } +} + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + // parallelize by q rows using ggml_vec_dot_f32 + + float scale = *((float *) dst->op_params + 0); + float max_bias = *((float *) dst->op_params + 1); + float logit_softcap = *((float *) dst->op_params + 2); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + const int KV_row_size = DK * sizeof(_Float16) + DV * sizeof(_Float16); + + int ith = params->ith; + int ir_step = 1; + for (int ir = ir0; ir < ir1; ir += ir_step) { + // q indices + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + const int iq3_1 = (ir + 1) / (neq2 * neq1); + const int iq2_1 = (ir + 1 - iq3_1 * neq2 * neq1) / neq1; + const int iq1_1 = (ir + 1 - iq3_1 * neq2 * neq1 - iq2_1 * neq1); + + const int iq3_2 = (ir + 2) / (neq2 * neq1); + const int iq2_2 = (ir + 2 - iq3_2 * neq2 * neq1) / neq1; + const int iq1_2 = (ir + 2 - iq3_2 * neq2 * neq1 - iq2_2 * neq1); + + const int iq3_3 = (ir + 3) / (neq2 * neq1); + const int iq2_3 = (ir + 3 - iq3_3 * neq2 * neq1) / neq1; + const int iq1_3 = (ir + 3 - iq3_3 * neq2 * neq1 - iq2_3 * neq1); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + const ggml_fp16_t * mp = + mask ? (ggml_fp16_t *) ((char *) mask->data + iq1 * mask->nb[1] + (iq2 % mask->ne[2]) * mask->nb[2] + + (iq3 % mask->ne[3]) * mask->nb[3]) : + NULL; + + const bool mp_equal_2 = iq1_1 == iq1 && (iq2 % mask->ne[2]) == (iq2_1 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_1 % mask->ne[3]); + + const bool mp_equal_4 = mp_equal_2 && iq1_2 == iq1 && (iq2 % mask->ne[2]) == (iq2_2 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_2 % mask->ne[3]) && iq1_3 == iq1 && + (iq2 % mask->ne[2]) == (iq2_3 % mask->ne[2]) && + (iq3 % mask->ne[3]) == (iq3_3 % mask->ne[3]); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + const int ik3_1 = iq3_1 / rk3; + const int ik2_1 = iq2_1 / rk2; + + const int ik3_2 = iq3_2 / rk3; + const int ik2_2 = iq2_2 / rk2; + + const int ik3_3 = iq3_3 / rk3; + const int ik2_3 = iq2_3 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const int iv3_1 = iq3_1 / rv3; + const int iv2_1 = iq2_1 / rv2; + + const int iv3_2 = iq3_2 / rv3; + const int iv2_2 = iq2_2 / rv2; + + const int iv3_3 = iq3_3 / rv3; + const int iv2_3 = iq2_3 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + + std::array<float *, 4> pq_buffer; + std::array<float *, 4> sinks_buffer; + std::array<float *, 4> dst_buffer; + + if (tcm_buffer != nullptr && 4 * KV_row_size < tcm_buffer_size && ir < (ir1 - 3) && mp_equal_4 && + ik3_3 == ik3 && ik2_3 == ik2 && iv3_3 == iv3 && iv2_3 == iv2 && ik3_2 == ik3 && ik2_2 == ik2 && + iv3_2 == iv3 && iv2_2 == iv2 && ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 4; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + pq_buffer[2] = (float *) ((char *) q->data + (iq1_2 * nbq1 + iq2_2 * nbq2 + iq3_2 * nbq3)); + pq_buffer[3] = (float *) ((char *) q->data + (iq1_3 * nbq1 + iq2_3 * nbq2 + iq3_3 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + sinks_buffer[2] = sinks ? ((float *) ((char *) sinks->data)) + iq2_2 : nullptr; + sinks_buffer[3] = sinks ? ((float *) ((char *) sinks->data)) + iq2_3 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + dst_buffer[2] = (float *) ((char *) dst->data + (iq3_2 * ne2 * ne1 + iq2_2 + iq1_2 * ne1) * nb1); + dst_buffer[3] = (float *) ((char *) dst->data + (iq3_3 * ne2 * ne1 + iq2_3 + iq1_3 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<4>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else if (tcm_buffer != nullptr && 2 * KV_row_size < tcm_buffer_size && ir < (ir1 - 1) && mp_equal_2 && + ik3_1 == ik3 && ik2_1 == ik2 && iv3_1 == iv3 && iv2_1 == iv2) { + ir_step = 2; + + pq_buffer[0] = (float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + pq_buffer[1] = (float *) ((char *) q->data + (iq1_1 * nbq1 + iq2_1 * nbq2 + iq3_1 * nbq3)); + + sinks_buffer[0] = sinks ? ((float *) ((char *) sinks->data)) + iq2 : nullptr; + sinks_buffer[1] = sinks ? ((float *) ((char *) sinks->data)) + iq2_1 : nullptr; + + dst_buffer[0] = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1); + dst_buffer[1] = (float *) ((char *) dst->data + (iq3_1 * ne2 * ne1 + iq2_1 + iq1_1 * ne1) * nb1); + + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_mrow<2>( // + pq_buffer.data(), // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks_buffer.data(), // + dst_buffer.data(), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK, tcm_buffer, tcm_buffer_size); + } else { + ir_step = 1; + flash_attn_ext_f16_one_chunk_inner_vlen1024_vf16_m1( // + pq, // + (const char *) k->data + (ik2 * nbk2 + ik3 * nbk3), // + (const char *) v->data + (iv2 * nbv2 + iv3 * nbv3), // + mp, // + sinks ? ((float *) ((char *) sinks->data)) + h : nullptr, // + (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + iq1 * ne1) * nb1), // + scale, logit_softcap, slope, nek1, nbk1, nbv1, DV, DK); + } + } +} + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size) { + const ggml_tensor * q = dst->src[0]; + const ggml_tensor * k = dst->src[1]; + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int64_t DK = nek0; + const int64_t DV = nev0; + const int64_t N = neq1; + + GGML_ASSERT(flash_attn_ext_supported_shape_vlen1024_vf16(DK, DV)); + + GGML_ASSERT(ne0 == DV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == DK); + GGML_ASSERT(nek0 == DK); + GGML_ASSERT(nev0 == DV); + + GGML_ASSERT(neq1 == N); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(k->type == v->type); + const ggml_type kv_type = k->type; + + // broadcast factors + const int64_t rk2 = neq2 / nek2; + const int64_t rk3 = neq3 / nek3; + + const int64_t rv2 = neq2 / nev2; + const int64_t rv3 = neq3 / nev3; + + float * param_list = (float *) dst->op_params; + float scale = param_list[0]; + float max_bias = param_list[1]; + float logit_softcap = param_list[2]; + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + int ith = params->ith; + + static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; + static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; + + // Per-thread scratch layout: + // Q_f32: Q_TILE_SZ * DK + // KQ: Q_TILE_SZ * KV_TILE_SZ + // mask32: Q_TILE_SZ * KV_TILE_SZ + // VKQ32: Q_TILE_SZ * DV + // V32: KV_TILE_SZ * DV + // K_f32: DK * KV_TILE_SZ (transposed K tile) + float * base = (float *) params->wdata + ith * (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + + KV_TILE_SZ * DV + KV_TILE_SZ * DK + CACHE_LINE_SIZE_F32); + const size_t base_size = + (Q_TILE_SZ * DK + 2 * Q_TILE_SZ * KV_TILE_SZ + Q_TILE_SZ * DV + KV_TILE_SZ * DV + KV_TILE_SZ * DK) * + sizeof(float) + + CACHE_LINE_SIZE_F32; + + if (base_size <= tcm_buffer_size && tcm_buffer != nullptr) { + base = (float *) tcm_buffer; + } + + float S_M_Buf[Q_TILE_SZ * 2]; // buffer to hold S, M, bias for one tile to reduce register pressure in main loop + float * S = S_M_Buf; + float * M = S_M_Buf + Q_TILE_SZ; + + int ir = ir0; + while (ir < ir1) { + // q indices for the start of this tile + const int iq3 = ir / (neq2 * neq1); + const int iq2 = (ir - iq3 * neq2 * neq1) / neq1; + const int iq1 = (ir - iq3 * neq2 * neq1 - iq2 * neq1); + + // Number of valid rows in this tile: + // - limited by tile size (Q_TILE_SZ) + // - limited by chunk boundary (ir1 - ir) + // - limited by head boundary (neq1 - iq1) to avoid crossing into next head + const int tile_rows = MIN(Q_TILE_SZ, MIN((int) (ir1 - ir), (int) (neq1 - iq1))); + GGML_ASSERT(tile_rows > 0); + + const uint32_t h = iq2; // head index + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + for (int i = 0; i < Q_TILE_SZ; ++i) { + S[i] = 0.; + M[i] = -INFINITY; + } + + float * Q_f32 = base; + float * KQ = (float *) ((char *) base + Q_TILE_SZ * DK * sizeof(float)); + float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; + float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; + float * V32 = VKQ32 + Q_TILE_SZ * DV; + float * K_f32 = V32 + KV_TILE_SZ * DV; + _Float16 * Q_f16 = (_Float16 *) Q_f32; + _Float16 * V_f16 = (_Float16 *) V32; + _Float16 * K_f16 = (_Float16 *) K_f32; + + rvv_zero_f32(VKQ32, Q_TILE_SZ * DV); + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + const float * pq = (const float *) ((char *) q->data + (iq1 * nbq1 + iq2 * nbq2 + iq3 * nbq3)); + if (kv_type == GGML_TYPE_F16) { + rvv_pack_f32_as_scaled_f16((uint8_t *) Q_f16, DK * sizeof(_Float16), (uint8_t *) pq, nbq1, tile_rows, DK, + scale); + } else { + memcpy2d(Q_f32, DK * sizeof(float), pq, nbq1, tile_rows, DK * sizeof(float)); + } + + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + const int kv_tile = (int) std::min((int64_t) KV_TILE_SZ, nek1 - ic); + + rvv_zero_f32(K_f32, DK * KV_TILE_SZ); + rvv_zero_f32(V32, KV_TILE_SZ * DV); + + // skip the tile entirely if all the masks are -inf + if (mask) { + bool can_skip = true; + const ggml_fp16_t * mp_row = + (const ggml_fp16_t *) ((const char *) mask->data + iq1 * mask->nb[1] + + (iq2 % mask->ne[2]) * mask->nb[2] + (iq3 % mask->ne[3]) * mask->nb[3]); + rvv_pack_scaled_f16_as_f32(mask32, KV_TILE_SZ * sizeof(float), mp_row + ic, mask->nb[1], tile_rows, + kv_tile, slope); + + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = 0; tk < kv_tile; tk++) { + if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { + can_skip = false; + } + } + // Pad remaining mask entries with -inf + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + + if (can_skip) { + continue; + } + } + + if (kv_type == GGML_TYPE_F16) { + rvv_transposed_s16_mn_to_nm((int8_t *) K_f16, KV_TILE_SZ * sizeof(_Float16), + (int8_t *) k->data + ic * nbk1 + ik2 * nbk2 + ik3 * nbk3, nbk1, kv_tile, + DK); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + rvv_qk_dot_tile_f16_x4(KQ + (tq + 0) * KV_TILE_SZ, KQ + (tq + 1) * KV_TILE_SZ, + KQ + (tq + 2) * KV_TILE_SZ, KQ + (tq + 3) * KV_TILE_SZ, + Q_f16 + (tq + 0) * DK, Q_f16 + (tq + 1) * DK, Q_f16 + (tq + 2) * DK, + Q_f16 + (tq + 3) * DK, K_f16, DK, kv_tile); + } + for (; tq < tile_rows; ++tq) { + rvv_qk_dot_tile_f16_x1(KQ + tq * KV_TILE_SZ, Q_f16 + tq * DK, K_f16, DK, kv_tile); + } + } else { + for (int tk = 0; tk < kv_tile; tk++) { + const char * k_data = (const char *) k->data + (ic + tk) * nbk1 + ik2 * nbk2 + ik3 * nbk3; + float * k_col = K_f32 + tk; + const float * k_src = (const float *) k_data; + for (int64_t dk = 0; dk < DK; ++dk) { + k_col[dk * KV_TILE_SZ] = k_src[dk]; + } + } + + for (int tq = 0; tq < tile_rows; ++tq) { + rvv_qk_dot_tile(KQ + tq * KV_TILE_SZ, Q_f32 + tq * DK, K_f32, DK, KV_TILE_SZ, scale); + } + } + + // Set padded KQ entries to -inf so softmax gives them zero weight + if (kv_tile < KV_TILE_SZ) { + for (int tq = 0; tq < tile_rows; tq++) { + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + KQ[tq * KV_TILE_SZ + tk] = -INFINITY; + } + } + } + + if (logit_softcap != 0.0f) { + rvv_softcap_tanh_inplace_f32(KQ, KV_TILE_SZ, tile_rows, KV_TILE_SZ, logit_softcap); + } + + if (mask) { + rvv_add_inplace_f32(KQ, KV_TILE_SZ, mask32, KV_TILE_SZ, tile_rows, KV_TILE_SZ); + } + + bool skip[Q_TILE_SZ] = {}; + + for (int tq = 0; tq < tile_rows; tq++) { + float * kq_row = KQ + tq * KV_TILE_SZ; + + const float tile_max = rvv_max_f32(kq_row, KV_TILE_SZ); + + if (tile_max == -INFINITY) { + skip[tq] = true; + continue; + } + + const float Mold = M[tq]; + const float Mnew = fmaxf(Mold, tile_max); + + if (Mnew > Mold) { + const float ms = expf(Mold - Mnew); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + S[tq] *= ms; + } + M[tq] = Mnew; + + S[tq] += rvv_softmax_exp_inplace_f32(kq_row, KV_TILE_SZ, Mnew); + } + + // Pack V as contiguous [KV_TILE_SZ][DV]. + if (kv_type == GGML_TYPE_F16) { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V_f16, DV * sizeof(_Float16), v_data, nbv1, kv_tile, DV * sizeof(_Float16)); + + int tq = 0; + for (; tq + 3 < tile_rows; tq += 4) { + if (skip[tq + 0] || skip[tq + 1] || skip[tq + 2] || skip[tq + 3]) { + for (int i = 0; i < 4; ++i) { + if (!skip[tq + i]) { + rvv_pv_accumulate_f16_x1(VKQ32 + (tq + i) * DV, KQ + (tq + i) * KV_TILE_SZ, V_f16, + KV_TILE_SZ, DV); + } + } + continue; + } + + rvv_pv_accumulate_f16_x4(VKQ32 + (tq + 0) * DV, VKQ32 + (tq + 1) * DV, VKQ32 + (tq + 2) * DV, + VKQ32 + (tq + 3) * DV, KQ + (tq + 0) * KV_TILE_SZ, + KQ + (tq + 1) * KV_TILE_SZ, KQ + (tq + 2) * KV_TILE_SZ, + KQ + (tq + 3) * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + for (; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate_f16_x1(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V_f16, KV_TILE_SZ, DV); + } + } + } else { + const char * v_data = (const char *) v->data + ic * nbv1 + iv2 * nbv2 + iv3 * nbv3; + memcpy2d(V32, DV * sizeof(float), v_data, nbv1, kv_tile, DV * sizeof(float)); + + for (int tq = 0; tq < tile_rows; ++tq) { + if (!skip[tq]) { + rvv_pv_accumulate(VKQ32 + tq * DV, KQ + tq * KV_TILE_SZ, V32, KV_TILE_SZ, DV); + } + } + } + } + + // sinks (apply only to valid rows in the tile) + if (sinks) { + const float s = ((float *) ((char *) sinks->data))[h]; + + for (int tq = 0; tq < tile_rows; tq++) { + float ms = 1.0f; + float vs = 1.0f; + + if (s > M[tq]) { + ms = expf(M[tq] - s); + rvv_scale_f32(VKQ32 + tq * DV, ms, DV); + } else { + vs = expf(s - M[tq]); + } + + float S_temp = S[tq] * ms + vs; + S[tq] = S_temp == 0.0f ? 0.0f : 1.0f / S_temp; + } + } else { + for (int tq = 0; tq < tile_rows; tq++) { + const float S_inv = S[tq] == 0.0f ? 0.0f : 1.0f / S[tq]; + S[tq] = S_inv; + } + } + + float * dst_ptr = (float *) ((char *) dst->data + (iq3 * ne2 * ne1 + iq2 + (iq1) *ne1) * nb1); + rvv_pack_scaled_f32_as_f32(dst_ptr, nb1 * ne1, VKQ32, DV * sizeof(float), tile_rows, DV, S); + + ir += tile_rows; + } +} + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template <size_t MB_ROWS> +void quantize_a_nrow_i8_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast<float *>(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast<int16_t *>(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast<int8_t *>(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk])); + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + int16_t a_sum = 0; + for (size_t bk = 0; bk < blk_len; bk++) { + const int8_t quantized = static_cast<int8_t>( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row] = -a_sum; + } + } +} + +template <size_t MB_ROWS> +void quantize_a_nrow_i8_hp_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + const size_t subblk_count = blk_len / k_subblk_len; + + GGML_ASSERT(blk_len == 256); + + float scale_temp[8] = { 0.0f }; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * MB_ROWS; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + + float scale_avg = 0.0f; + for (size_t kk = 0; kk < subblk_count; kk++) { + float max_abs_a = 0.0f; + for (size_t row = 0; row < MB_ROWS; row++) { + for (size_t bk = 0; bk < k_subblk_len; bk++) { + max_abs_a = std::max(max_abs_a, std::abs(a_ptr[row * count_k + k + bk + kk * k_subblk_len])); + } + } + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + float scale_factor = 1.0f / scale_avg; + + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * MB_ROWS); + scale_avg_ptr[0] = scale_avg; + + for (size_t kk = 0; kk < subblk_count; kk++) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast<int8_t *>(a_subblk_base + sizeof(_Float16) * MB_ROWS); + + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + const float rep_scale_a = 1.0f / scale_temp[kk]; + + for (size_t row = 0; row < MB_ROWS; row++) { + int16_t a_sum = 0; + for (size_t bk = 0; bk < k_subblk_len; bk++) { + const int8_t quantized = static_cast<int8_t>( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk + kk * k_subblk_len] * rep_scale_a), + -128.0f, 127.0f)); + quant_a_blk[row * k_subblk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * subblk_count + kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + } + } + } +} + +template <size_t MB_ROWS> +void quantize_a_nrow_i8k_ref(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + int64_t a_blk_stride = q8k_blk_size(256); + int64_t a_nrow_block_stride = a_blk_stride * MB_ROWS; + int64_t a_sum_size = 256 / 16; + + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast<float *>(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast<int16_t *>(quant_a_ptr + sizeof(float) * MB_ROWS); + int8_t * quant_a_blk = + reinterpret_cast<int8_t *>(quant_a_ptr + sizeof(float) * MB_ROWS + sizeof(int16_t) * a_sum_size * MB_ROWS); + + for (size_t row = 0; row < MB_ROWS; row++) { + float max_a = 0.0f; + float max_abs_a = 0.0f; + for (size_t bk = 0; bk < blk_len; bk++) { + float ax = std::abs(a_ptr[row * count_k + k + bk]); + if (ax > max_abs_a) { + max_abs_a = ax; + max_a = a_ptr[row * count_k + k + bk]; + } + } + + if (!max_abs_a) { + scale_a_ptr[row] = 0; + for (size_t bki = 0; bki < a_sum_size; bki++) { + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + quant_a_blk[row * blk_len + bk] = 0; + } + a_sum_ptr[row * a_sum_size + bki] = 0; + } + continue; + } + + float rep_scale_a = ((1 << 7) - 1) / max_abs_a; + scale_a_ptr[row] = 1 / rep_scale_a; + + for (size_t bki = 0; bki < a_sum_size; bki++) { + int16_t a_sum = 0; + for (size_t bk = bki * 16; bk < (bki + 1) * 16; bk++) { + const int8_t quantized = static_cast<int8_t>( + std::clamp(std::nearbyintf(a_ptr[row * count_k + k + bk] * rep_scale_a), -128.0f, 127.0f)); + quant_a_blk[row * blk_len + bk] = quantized; + a_sum += quantized; + } + a_sum_ptr[row * a_sum_size + bki] = -a_sum; + } + } + } +} + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast<float *>(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast<int16_t *>(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast<int8_t *>(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast<float *>(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast<int16_t *>(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = reinterpret_cast<int8_t *>(quant_a_ptr + sizeof(float) + sizeof(int16_t)); + + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[0] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } else { + quantize_a_nrow_i8_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 32); + int64_t a_blk_stride = q8_blk_size(blk_len, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast<float *>(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast<int16_t *>(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast<int8_t *>(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m1(blk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast<float *>(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast<int16_t *>(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = reinterpret_cast<int8_t *>(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * 4); + + for (size_t mi = 0; mi < 4; mi++) { + size_t vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_ptr + mi * count_k + k, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi] = -a_sum; + + __riscv_vse8_v_i8m1(quant_a_blk + mi * blk_len, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false); + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast<int8_t *>(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_blk_stride - sizeof(_Float16)); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + vfloat32m4_t v_a_abs = __riscv_vfabs_v_f32m4(v_a, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_a_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast<int8_t *>(a_subblk_base + sizeof(_Float16)); + const float * a_src_ptr = a_ptr + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a = __riscv_vle32_v_f32m4(a_src_ptr, vl); + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a_scale = __riscv_vfmul_vf_f32m4(v_a, rep_scale_a, vl); + vint16m2_t v_a_quant = __riscv_vfncvt_x_f_w_i16m2(v_a_scale, vl); + vint8m1_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[kk] = static_cast<_Float16>(-a_sum) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + constexpr size_t k_subblk_len = 32; + GGML_ASSERT(blk_len == 256); + + constexpr size_t subblk_count = 256 / k_subblk_len; + int64_t a_blk_stride = q8_hp_blk_size(blk_len, true, true); + int64_t a_nrow_block_stride = a_blk_stride * 4; + int64_t a_subblk_stride = q8_hp_blk_size(k_subblk_len, false, false) * 4; + size_t vlenb = __riscv_vlenb(); + float scale_temp[subblk_count] = { 0.0f }; + + if (vlenb == 128) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + vfloat32m1_t v_a0_abs = __riscv_vfabs_v_f32m1(v_a0, vl); + vfloat32m1_t v_a1_abs = __riscv_vfabs_v_f32m1(v_a1, vl); + vfloat32m1_t v_a2_abs = __riscv_vfabs_v_f32m1(v_a2, vl); + vfloat32m1_t v_a3_abs = __riscv_vfabs_v_f32m1(v_a3, vl); + + vfloat32m1_t v_max_abs = __riscv_vfmax_vv_f32m1(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast<int8_t *>(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m1(k_subblk_len); + vfloat32m1_t v_a0 = __riscv_vle32_v_f32m1(a_src_ptr0, vl); + vfloat32m1_t v_a1 = __riscv_vle32_v_f32m1(a_src_ptr1, vl); + vfloat32m1_t v_a2 = __riscv_vle32_v_f32m1(a_src_ptr2, vl); + vfloat32m1_t v_a3 = __riscv_vle32_v_f32m1(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m1_t v_a0_scale = __riscv_vfmul_vf_f32m1(v_a0, rep_scale_a, vl); + vfloat32m1_t v_a1_scale = __riscv_vfmul_vf_f32m1(v_a1, rep_scale_a, vl); + vfloat32m1_t v_a2_scale = __riscv_vfmul_vf_f32m1(v_a2, rep_scale_a, vl); + vfloat32m1_t v_a3_scale = __riscv_vfmul_vf_f32m1(v_a3, rep_scale_a, vl); + vint16mf2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a0_scale, vl); + vint16mf2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a1_scale, vl); + vint16mf2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a2_scale, vl); + vint16mf2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a3_scale, vl); + vint8mf4_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a0_quant, vl); + vint8mf4_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a1_quant, vl); + vint8mf4_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a2_quant, vl); + vint8mf4_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8mf4(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8mf4(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else if (vlenb == 32) { + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + _Float16 * a_sum_ptr = reinterpret_cast<_Float16 *>(quant_a_ptr + a_subblk_stride * subblk_count); + _Float16 * scale_avg_ptr = + reinterpret_cast<_Float16 *>(quant_a_ptr + a_nrow_block_stride - sizeof(_Float16) * 4); + float scale_avg = 0.0f; + + for (size_t kk = 0; kk < subblk_count; ++kk) { + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + vfloat32m4_t v_a0_abs = __riscv_vfabs_v_f32m4(v_a0, vl); + vfloat32m4_t v_a1_abs = __riscv_vfabs_v_f32m4(v_a1, vl); + vfloat32m4_t v_a2_abs = __riscv_vfabs_v_f32m4(v_a2, vl); + vfloat32m4_t v_a3_abs = __riscv_vfabs_v_f32m4(v_a3, vl); + + vfloat32m4_t v_max_abs = __riscv_vfmax_vv_f32m4(v_a0_abs, v_a1_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a2_abs, vl); + v_max_abs = __riscv_vfmax_vv_f32m4(v_max_abs, v_a3_abs, vl); + + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_a_max = __riscv_vfredmax_vs_f32m4_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_a_max); + + scale_temp[kk] = max_abs_a / ((1 << 7) - 1); + scale_avg += scale_temp[kk]; + } + + scale_avg /= subblk_count; + const float scale_factor = scale_avg ? 1.0f / scale_avg : 0.0f; + scale_avg_ptr[0] = static_cast<_Float16>(scale_avg); + + for (size_t kk = 0; kk < subblk_count; ++kk) { + uint8_t * a_subblk_base = quant_a_ptr + kk * a_subblk_stride; + _Float16 * scale_a_ptr = reinterpret_cast<_Float16 *>(a_subblk_base); + int8_t * quant_a_blk = reinterpret_cast<int8_t *>(a_subblk_base + sizeof(_Float16) * 4); + const float * a_src_ptr0 = a_ptr + 0 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr1 = a_ptr + 1 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr2 = a_ptr + 2 * count_k + k + kk * k_subblk_len; + const float * a_src_ptr3 = a_ptr + 3 * count_k + k + kk * k_subblk_len; + + size_t vl = __riscv_vsetvl_e32m4(k_subblk_len); + vfloat32m4_t v_a0 = __riscv_vle32_v_f32m4(a_src_ptr0, vl); + vfloat32m4_t v_a1 = __riscv_vle32_v_f32m4(a_src_ptr1, vl); + vfloat32m4_t v_a2 = __riscv_vle32_v_f32m4(a_src_ptr2, vl); + vfloat32m4_t v_a3 = __riscv_vle32_v_f32m4(a_src_ptr3, vl); + + float rep_scale_a = scale_temp[kk] ? 1.0f / scale_temp[kk] : 0.0f; + scale_a_ptr[0] = static_cast<_Float16>(scale_temp[kk] * scale_factor); + + vfloat32m4_t v_a0_scale = __riscv_vfmul_vf_f32m4(v_a0, rep_scale_a, vl); + vfloat32m4_t v_a1_scale = __riscv_vfmul_vf_f32m4(v_a1, rep_scale_a, vl); + vfloat32m4_t v_a2_scale = __riscv_vfmul_vf_f32m4(v_a2, rep_scale_a, vl); + vfloat32m4_t v_a3_scale = __riscv_vfmul_vf_f32m4(v_a3, rep_scale_a, vl); + vint16m2_t v_a0_quant = __riscv_vfncvt_x_f_w_i16m2(v_a0_scale, vl); + vint16m2_t v_a1_quant = __riscv_vfncvt_x_f_w_i16m2(v_a1_scale, vl); + vint16m2_t v_a2_quant = __riscv_vfncvt_x_f_w_i16m2(v_a2_scale, vl); + vint16m2_t v_a3_quant = __riscv_vfncvt_x_f_w_i16m2(v_a3_scale, vl); + vint8m1_t v_a0_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a0_quant, vl); + vint8m1_t v_a1_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a1_quant, vl); + vint8m1_t v_a2_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a2_quant, vl); + vint8m1_t v_a3_quant_i8 = __riscv_vncvt_x_x_w_i8m1(v_a3_quant, vl); + + vint16m1_t tmp_sum0 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum1 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum2 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t tmp_sum3 = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a0_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a0_quant_i8, tmp_sum0, vl); + vint16m1_t v_a1_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a1_quant_i8, tmp_sum1, vl); + vint16m1_t v_a2_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a2_quant_i8, tmp_sum2, vl); + vint16m1_t v_a3_sum = __riscv_vwredsum_vs_i8m1_i16m1(v_a3_quant_i8, tmp_sum3, vl); + + a_sum_ptr[0 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a0_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[1 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a1_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[2 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a2_sum)) * static_cast<_Float16>(8.0f); + a_sum_ptr[3 * subblk_count + kk] = + static_cast<_Float16>(-__riscv_vmv_x_s_i16m1_i16(v_a3_sum)) * static_cast<_Float16>(8.0f); + + __riscv_vse8_v_i8m1(quant_a_blk + 0 * k_subblk_len, v_a0_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 1 * k_subblk_len, v_a1_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 2 * k_subblk_len, v_a2_quant_i8, vl); + __riscv_vse8_v_i8m1(quant_a_blk + 3 * k_subblk_len, v_a3_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8_hp_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits, can process 32 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast<float *>(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast<int16_t *>(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast<int8_t *>(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else if (vlenb == 32) { + // vlen = 256 bits, can process 8 float32 elements with m1 + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_blk_stride) { + float * scale_a_ptr = reinterpret_cast<float *>(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast<int16_t *>(quant_a_ptr + sizeof(float)); + int8_t * quant_a_blk = + reinterpret_cast<int8_t *>(quant_a_ptr + sizeof(float) + sizeof(int16_t) * a_sum_size); + + // Find max absolute value across all 256 elements + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[0] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + bki * 16, v_a_quant_i8, vl); + } + } + } else { + quantize_a_nrow_i8k_ref<1>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr) { + GGML_ASSERT(blk_len == 256); + constexpr int64_t a_blk_stride = q8k_blk_size(256); + constexpr int64_t a_nrow_block_stride = a_blk_stride * 4; + constexpr int64_t a_sum_size = 256 / 16; + size_t vlenb = __riscv_vlenb(); + + if (vlenb == 128) { + // vlen = 1024 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast<float *>(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast<int16_t *>(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast<int8_t *>(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m1(16); + vfloat32m1_t v_max_abs = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_abs = __riscv_vfabs_v_f32m1(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m1(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m1_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m1_t v_a = __riscv_vle32_v_f32m1(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m1_t v_a_scale = __riscv_vfmul_vf_f32m1(v_a, rep_scale_a, vl); + vint16mf2_t v_a_quant = __riscv_vfncvt_x_f_w_i16mf2(v_a_scale, vl); + vint8mf4_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf4(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf4_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf4(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else if (vlenb == 32) { + // vlen = 256 bits + for (size_t k = 0; k < count_k; k += blk_len, quant_a_ptr += a_nrow_block_stride) { + float * scale_a_ptr = reinterpret_cast<float *>(quant_a_ptr); + int16_t * a_sum_ptr = reinterpret_cast<int16_t *>(quant_a_ptr + sizeof(float) * 4); + int8_t * quant_a_blk = + reinterpret_cast<int8_t *>(quant_a_ptr + sizeof(float) * 4 + sizeof(int16_t) * a_sum_size * 4); + + for (size_t mi = 0; mi < 4; mi++) { + // Find max absolute value across all 256 elements for this row + size_t vl = __riscv_vsetvl_e32m2(16); + vfloat32m2_t v_max_abs = __riscv_vfmv_v_f_f32m2(0.0f, vl); + + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_abs = __riscv_vfabs_v_f32m2(v_a, vl); + v_max_abs = __riscv_vfmax_vv_f32m2(v_a_abs, v_max_abs, vl); + } + vfloat32m1_t tmp = __riscv_vfmv_v_f_f32m1(0.0f, vl); + vfloat32m1_t v_local_max = __riscv_vfredmax_vs_f32m2_f32m1(v_max_abs, tmp, vl); + float max_abs_a = __riscv_vfmv_f_s_f32m1_f32(v_local_max); + + float scale_a = max_abs_a / ((1 << 7) - 1); + float rep_scale_a = scale_a ? 1.0f / scale_a : 0.0f; + scale_a_ptr[mi] = scale_a; + + // Quantize and compute sums for each 16-element group + for (size_t bki = 0; bki < a_sum_size; bki++) { + vfloat32m2_t v_a = __riscv_vle32_v_f32m2(a_ptr + mi * count_k + k + bki * 16, vl); + vfloat32m2_t v_a_scale = __riscv_vfmul_vf_f32m2(v_a, rep_scale_a, vl); + vint16m1_t v_a_quant = __riscv_vfncvt_x_f_w_i16m1(v_a_scale, vl); + vint8mf2_t v_a_quant_i8 = __riscv_vncvt_x_x_w_i8mf2(v_a_quant, vl); + + vint16m1_t tmp_sum = __riscv_vmv_v_x_i16m1(0, vl); + vint16m1_t v_a_sum = __riscv_vwredsum_vs_i8mf2_i16m1(v_a_quant_i8, tmp_sum, vl); + int16_t a_sum = __riscv_vmv_x_s_i16m1_i16(v_a_sum); + a_sum_ptr[mi * a_sum_size + bki] = -a_sum; + + __riscv_vse8_v_i8mf2(quant_a_blk + mi * blk_len + bki * 16, v_a_quant_i8, vl); + } + } + } + } else { + quantize_a_nrow_i8k_ref<4>(blk_len, a_ptr, count_k, quant_a_ptr); + } +} + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = src0->ne[2] * src0->ne[3]; + int64_t m = src0->ne[1]; + int64_t n = src0->ne[0]; + + int64_t batch_stride = src0->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = n_src_stride * m; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + const int ith = params->ith; + const int nth = params->nth; + + // [batch, m, n] -> [batch, n, m] + int64_t batch = dst->ne[2] * dst->ne[3]; + int64_t n = dst->ne[1]; + int64_t m = dst->ne[0]; + + int64_t batch_stride = dst->nb[2]; + int64_t m_src_stride = src0->nb[0]; + int64_t n_src_stride = src0->nb[1]; + int64_t n_dst_stride = dst->nb[1]; + + permute_transpose_impl(src0, dst, batch, m, n, batch_stride, m_src_stride, n_src_stride, n_dst_stride, ith, nth); +} + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon = *((float *) dst->op_params); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (char *) src0->data; + auto * output = (char *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + int64_t i03 = task_idx / (ne02 * ne01); + int64_t i02 = (task_idx - i03 * ne02 * ne01) / ne01; + int64_t i01 = (task_idx - i03 * ne02 * ne01 - i02 * ne01); + + auto * p_input = (float *) (input + i01 * nb01 + i02 * nb02 + i03 * nb03); + auto * p_output = (float *) (output + i01 * nb1 + i02 * nb2 + i03 * nb3); + auto * p_temp_output = p_output; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean /= hidden_size; + + vfloat32m1_t mean_square_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } +} + +template <ggml_op op_type, typename T> void forward_binary(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + + auto src0_rows = ggml_nrows(src0); + auto src1_rows = ggml_nrows(src1); + + int ith = params->ith; + int nth = params->nth; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb0 == sizeof(T)); + GGML_ASSERT(nb00 == sizeof(T)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + auto compute_func_vv = [&](int64_t blk_len, int64_t r, T * src0_ptr, T * src1_ptr, T * dst_ptr) { + int64_t idx = 0; + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v<T, float>) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfadd_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v<T, _Float16>) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfadd_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v<T, float>) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfsub_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v<T, _Float16>) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfsub_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v<T, float>) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfmul_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v<T, _Float16>) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfmul_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v<T, float>) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + idx + r, vl); + vfloat32m4_t rhs = __riscv_vle32_v_f32m4(src1_ptr + idx, vl); + vfloat32m4_t res = __riscv_vfdiv_vv_f32m4(lhs, rhs, vl); + __riscv_vse32_v_f32m4(dst_ptr + idx + r, res, vl); + } + } else if constexpr (std::is_same_v<T, _Float16>) { + for (size_t vl; blk_len > 0; blk_len -= vl, idx += vl) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + idx + r), vl); + vfloat16m4_t rhs = __riscv_vle16_v_f16m4((src1_ptr + idx), vl); + vfloat16m4_t res = __riscv_vfdiv_vv_f16m4(lhs, rhs, vl); + __riscv_vse16_v_f16m4((dst_ptr + idx + r), res, vl); + } + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + }; + + if (src0_rows == src1_rows && src0_rows == 1 && ne00 == ne10) { + int64_t task_per_thread = (ne00 + nth - 1) / nth; + int64_t task_begin = ith * task_per_thread; + int64_t task_end = std::min((ith + 1) * task_per_thread, ne00); + + T * dst_ptr = ((T *) dst->data) + task_begin; + T * src0_ptr = ((T *) src0->data) + task_begin; + T * src1_ptr = ((T *) src1->data) + task_begin; + + compute_func_vv(task_end - task_begin, 0, src0_ptr, src1_ptr, dst_ptr); + } else if (ne10 > 1) { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + // src1 is broadcastable across src0 and dst in i1, i2, i3 + for (int64_t r = 0; r < ne00; r += ne10) { + compute_func_vv(ne10, r, src0_ptr, src1_ptr, dst_ptr); + } + } + } else { + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir / (ne02 * ne01); + const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01; + const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + T * dst_ptr = (T *) ((char *) dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1); + T * src0_ptr = (T *) ((char *) src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01); + T * src1_ptr = (T *) ((char *) src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11); + + T rhs_scalar = src1_ptr[0]; + int64_t blk_len = ne00; + int64_t r = 0; + + for (size_t vl; blk_len > 0; blk_len -= vl, r += vl) { + if constexpr (op_type == GGML_OP_ADD) { + if constexpr (std::is_same_v<T, float>) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfadd_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v<T, _Float16>) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfadd_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_SUB) { + if constexpr (std::is_same_v<T, float>) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfsub_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v<T, _Float16>) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfsub_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_MUL) { + if constexpr (std::is_same_v<T, float>) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfmul_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v<T, _Float16>) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfmul_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else if constexpr (op_type == GGML_OP_DIV) { + if constexpr (std::is_same_v<T, float>) { + vl = __riscv_vsetvl_e32m4(blk_len); + vfloat32m4_t lhs = __riscv_vle32_v_f32m4(src0_ptr + r, vl); + vfloat32m4_t res = __riscv_vfdiv_vf_f32m4(lhs, rhs_scalar, vl); + __riscv_vse32_v_f32m4(dst_ptr + r, res, vl); + } else if constexpr (std::is_same_v<T, _Float16>) { + vl = __riscv_vsetvl_e16m4(blk_len); + vfloat16m4_t lhs = __riscv_vle16_v_f16m4((src0_ptr + r), vl); + vfloat16m4_t res = __riscv_vfdiv_vf_f16m4(lhs, rhs_scalar, vl); + __riscv_vse16_v_f16m4((dst_ptr + r), res, vl); + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template <typename T> void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT(ne0 == 1); + GGML_ASSERT(ne1 == ne01); + GGML_ASSERT(ne2 == ne02); + GGML_ASSERT(ne3 == ne03); + + int64_t n_task = ne01 * ne02 * ne03; + int64_t task_per_thread = (n_task + nth - 1) / nth; + int64_t ir_start = ith * task_per_thread; + int64_t ir_end = std::min(ir_start + task_per_thread, n_task); + + for (int64_t ir = ir_start; ir < ir_end; ir++) { + const int64_t i3 = ir / (ne02 * ne01); + const int64_t i2 = (ir - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (ir - i3 * ne02 * ne01 - i2 * ne01); + + T * src_row = (T *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03); + T * dst_row = (T *) ((char *) op->data + i1 * nb1 + i2 * nb2 + i3 * nb3); + + float row_sum = 0; + + if constexpr (std::is_same_v<T, float>) { + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const float * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t vec = __riscv_vle32_v_f32m4(p_data, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else if constexpr (std::is_same_v<T, _Float16>) { + size_t gvl = __riscv_vsetvlmax_e16m2(); + vfloat32m4_t acc_vec = __riscv_vfmv_v_f_f32m4(0.0f, gvl); + int64_t length = ne00; + const _Float16 * p_data = src_row; + + while (length > 0) { + size_t vl = __riscv_vsetvl_e16m2(length); + vfloat16m2_t vec_f16 = __riscv_vle16_v_f16m2(p_data, vl); + vfloat32m4_t vec_f32 = __riscv_vfwcvt_f_f_v_f32m4(vec_f16, vl); + acc_vec = __riscv_vfadd_vv_f32m4(acc_vec, vec_f32, vl); + p_data += vl; + length -= vl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.0f, gvl); + vfloat32m1_t sum_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(acc_vec, 0), + __riscv_vget_v_f32m4_f32m1(acc_vec, 1), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 2), gvl); + sum_v = __riscv_vfadd_vv_f32m1(sum_v, __riscv_vget_v_f32m4_f32m1(acc_vec, 3), gvl); + sum_v = __riscv_vfredusum_vs_f32m1_f32m1(sum_v, zero_v, gvl); + row_sum = __riscv_vfmv_f_s_f32m1_f32(sum_v); + } else { + GGML_ABORT("fatal error"); + } + + dst_row[0] = row_sum; + } +} + +template <typename T> void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + int64_t nrows = ggml_nrows(src0); + int64_t nrows_per_thread = (nrows + nth - 1) / nth; + int64_t ir_start = ith * nrows_per_thread; + int64_t ir_end = std::min(ir_start + nrows_per_thread, nrows); + + if (src0->ne[0] == 1) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + T src_scalar = src_row[0]; + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v<T, int32_t>) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vmv_v_x_i32m4(src_scalar, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v<T, int16_t>) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vmv_v_x_i16m4(src_scalar, vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else if (src0->ne[0] == dst->ne[0]) { + for (int64_t ir = ir_start; ir < ir_end; ir++) { + T * src_row = (T *) ((char *) src0->data + ir * src0->nb[1]); + T * dst_row = (T *) ((char *) dst->data + ir * dst->nb[1]); + + int64_t length = dst->ne[0]; + int64_t idx = 0; + size_t vl = 0; + + while (length > 0) { + if constexpr (std::is_same_v<T, int32_t>) { + vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_row + idx, vl); + __riscv_vse32_v_i32m4(dst_row + idx, vec, vl); + } else if constexpr (std::is_same_v<T, int16_t>) { + vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_row + idx), vl); + __riscv_vse16_v_i16m4((dst_row + idx), vec, vl); + } else { + GGML_ABORT("fatal error"); + } + idx += vl; + length -= vl; + } + } + } else { + GGML_ABORT("fatal error"); + } +} + +template <typename T> void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + + const int64_t total_batches = ne2 * ne3; + const int64_t batches_per_thread = (total_batches + nth - 1) / nth; + const int64_t batch_start = ith * batches_per_thread; + const int64_t batch_end = std::min(batch_start + batches_per_thread, total_batches); + + for (int64_t b = batch_start; b < batch_end; b++) { + const int64_t i3 = b / ne2; + const int64_t i2 = b % ne2; + + T * src_base = (T *) ((char *) src0->data + i2 * src0->nb[2] + i3 * src0->nb[3]); + T * dst_batch = (T *) ((char *) dst->data + i2 * dst->nb[2] + i3 * dst->nb[3]); + + for (int64_t i1 = 0; i1 < ne1; i1++) { + T * dst_ptr = (T *) ((char *) dst_batch + i1 * dst->nb[1]); + int64_t length = ne0; + int64_t idx = 0; + + while (length > 0) { + if constexpr (std::is_same_v<T, int32_t>) { + size_t vl = __riscv_vsetvl_e32m4(length); + vint32m4_t vec = __riscv_vle32_v_i32m4(src_base + idx, vl); + __riscv_vse32_v_i32m4(dst_ptr + idx, vec, vl); + idx += vl; + length -= vl; + } else if constexpr (std::is_same_v<T, int16_t>) { + size_t vl = __riscv_vsetvl_e16m4(length); + vint16m4_t vec = __riscv_vle16_v_i16m4((src_base + idx), vl); + __riscv_vse16_v_i16m4((dst_ptr + idx), vec, vl); + idx += vl; + length -= vl; + } else { + GGML_ABORT("fatal error"); + } + } + } + } +} + +template <typename T> void forward_get_rows(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + assert(ne0 == nc); + assert(ne02 == ne11); + assert(nb00 == sizeof(float)); + assert(ggml_nrows(op) == nr); + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + // rows per thread + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i / (ne11 * ne10); + const int64_t i11 = (i - i12 * ne11 * ne10) / ne10; + const int64_t i10 = (i - i12 * ne11 * ne10 - i11 * ne10); + const int64_t i01 = *(int32_t *) ((char *) src1->data + i10 * nb10 + i11 * nb11 + i12 * nb12); + + GGML_ASSERT(i01 >= 0 && i01 < ne01); + + memcpy1d(((char *) dst->data + i10 * nb1 + i11 * nb2 + i12 * nb3) + cr0 * sizeof(T), + ((char *) src0->data + i01 * nb01 + i11 * nb02 + i12 * nb03) + cr0 * sizeof(T), + (cr1 - cr0) * sizeof(T)); + } +} + +template <typename T> void forward_concat(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_ASSERT(ggml_type_size(src0->type) == sizeof(float)); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t dim = ggml_get_op_params_i32(dst, 0); + + GGML_ASSERT(dim == 0 && nb0 == sizeof(float) && nb1 == sizeof(float) * (ne00 + ne10)); + + const int64_t nr = ggml_nrows(dst); + const int64_t nc = ne0; + + const int ith = params->ith; + const int nth = params->nth; + + int rows_nth = nth; + int cols_nth = 1; + + if (nr == 1) { + rows_nth = 1; + cols_nth = nth; + } + + const int dr = (nr + rows_nth - 1) / rows_nth; + const int dc = (nc + cols_nth - 1) / cols_nth; + + int rows_ith = ith % rows_nth; + int cols_ith = ith % cols_nth; + + // row range for this thread + const int ir0 = dr * rows_ith; + const int ir1 = MIN(ir0 + dr, nr); + + const int cr0 = dc * cols_ith; + const int cr1 = MIN(cr0 + dc, nc); + + int64_t o[4] = { 0, 0, 0, 0 }; + o[dim] = src0->ne[dim]; + const float * x; + + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i3 = i / (ne02 * ne01); + const int64_t i2 = (i - i3 * ne02 * ne01) / ne01; + const int64_t i1 = (i - i3 * ne02 * ne01 - i2 * ne01); + + for (int i0 = cr0; i0 < cr1; i0++) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (const float *) ((const char *) src0->data + (i0) *nb00 + (i1) *nb01 + (i2) *nb02 + (i3) *nb03); + } else { + x = (const float *) ((const char *) src1->data + (i0 - o[0]) * nb10 + (i1 - o[1]) * nb11 + + (i2 - o[2]) * nb12 + (i3 - o[3]) * nb13); + } + + float * y = (float *) ((char *) dst->data + i0 * nb0 + i1 * nb1 + i2 * nb2 + i3 * nb3); + + *y = *x; + } + } +} + +template void forward_binary<GGML_OP_ADD, float>(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary<GGML_OP_SUB, float>(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary<GGML_OP_MUL, float>(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary<GGML_OP_DIV, float>(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary<GGML_OP_ADD, _Float16>(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary<GGML_OP_SUB, _Float16>(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary<GGML_OP_MUL, _Float16>(ggml_compute_params * params, ggml_tensor * op); +template void forward_binary<GGML_OP_DIV, _Float16>(ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows<float>(const ggml_compute_params * params, ggml_tensor * op); +template void forward_sum_rows<_Float16>(const ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows<int32_t>(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_nrows<int16_t>(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1<int32_t>(ggml_compute_params * params, ggml_tensor * op); +template void forward_repeat_dim1<int16_t>(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows<int32_t>(ggml_compute_params * params, ggml_tensor * op); +template void forward_get_rows<int16_t>(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat<int32_t>(ggml_compute_params * params, ggml_tensor * op); +template void forward_concat<int16_t>(ggml_compute_params * params, ggml_tensor * op); + +} // namespace spacemit_kernels::rvv diff --git a/ggml/src/ggml-cpu/spacemit/rvv_kernels.h b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h new file mode 100644 index 00000000000..edddf957c21 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/rvv_kernels.h @@ -0,0 +1,95 @@ +#pragma once + +#include "ggml-cpu-impl.h" + +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <functional> + +namespace spacemit_kernels { + +constexpr auto div_round_up(auto up, auto down) { + return (up + down - 1) / down; +} + +// Q8 Blk [f32] [s16] [int8 * blk_len] +// Q8 Blk N [f32 * N] [s16 * N] [int8 * blk_len * N] +constexpr size_t q8_blk_size(size_t blk_len, bool with_blk_sum = false) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + (with_blk_sum ? sizeof(int16_t) : 0); + return blk_size; +} + +// Q8 HP row block: K is split into K32 subblocks. +// Each subblock stores [f32 scale] [int8 * 32], with an optional fp16 sum trailer per subblock. +constexpr size_t q8_hp_blk_size(size_t blk_len, bool with_blk_sum = false, bool with_blk_scale = false) { + const size_t subblk_count = div_round_up(blk_len, size_t(32)); + const size_t blk_size = blk_len * sizeof(int8_t) + subblk_count * sizeof(_Float16) + + (with_blk_sum ? subblk_count * sizeof(_Float16) : 0) + + (with_blk_scale ? sizeof(_Float16) : 0); + return blk_size; +} + +// Q8K Blk [f32] [s16 * (blk_len / 16)] [int8 * blk_len] +// Q8K Blk N [f32 * N] [s16 * (blk_len / 16) * N] [int8 * blk_len * N] +constexpr size_t q8k_blk_size(size_t blk_len) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t) + sizeof(int16_t) * blk_len / 16; + return blk_size; +} + +using quantize_a_row_def = std::function<void(size_t, const float *, size_t, uint8_t *)>; + +namespace rvv { +void memcpy1d(void * dst, const void * src, int64_t size); + +void memcpy2d(void * dst, int64_t dst_stride, const void * src, int64_t src_stride, int64_t tile_rows, int64_t size); + +void forward_flash_attn_ext_f16_one_chunk_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_flash_attn_ext_f16_tiled_vlen1024_vf16(const ggml_compute_params * params, + ggml_tensor * dst, + int ir0, + int ir1, + void * tcm_buffer, + size_t tcm_buffer_size); + +void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op); + +void forward_cont_with_permute(ggml_compute_params * params, ggml_tensor * op); + +void forward_cpy_with_permute(ggml_compute_params * params, ggml_tensor * op); + +template <typename T> void forward_get_rows(ggml_compute_params * params, ggml_tensor * op); + +template <typename T> void forward_concat(ggml_compute_params * params, ggml_tensor * op); + +template <ggml_op op_type, typename T> void forward_binary(ggml_compute_params * params, ggml_tensor * op); + +template <typename T> void forward_sum_rows(const ggml_compute_params * params, ggml_tensor * op); + +template <typename T> void forward_repeat_nrows(ggml_compute_params * params, ggml_tensor * op); + +template <typename T> void forward_repeat_dim1(ggml_compute_params * params, ggml_tensor * op); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8_hp(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +void quantize_a_4row_i8k(size_t blk_len, const float * a_ptr, size_t count_k, uint8_t * quant_a_ptr); + +} // namespace rvv + +} // namespace spacemit_kernels diff --git a/ggml/src/ggml-cpu/spacemit/spine_barrier.h b/ggml/src/ggml-cpu/spacemit/spine_barrier.h new file mode 100644 index 00000000000..f897dad4b8a --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_barrier.h @@ -0,0 +1,34 @@ +#pragma once + +#include <atomic> +#include <cstdint> + +#define SPINE_CACHE_LINE 64 +#define SPINE_CACHE_ALIGN __attribute__((aligned(SPINE_CACHE_LINE))) + +struct spine_barrier_t { + SPINE_CACHE_ALIGN std::atomic<int64_t> pending_; + SPINE_CACHE_ALIGN std::atomic<int64_t> rounds_; + SPINE_CACHE_ALIGN int64_t total_; +}; + +inline void spine_barrier_wait(spine_barrier_t * b) { + auto cur_round = b->rounds_.load(std::memory_order_acquire); + auto cnt = --b->pending_; + if (cnt == 0) { + b->pending_.store(b->total_); + b->rounds_.store(cur_round + 1); + } else { + while (cur_round == b->rounds_.load(std::memory_order_relaxed)) { + __asm__ volatile("pause " ::: "memory"); + } + } +} + +inline void spine_barrier_init(spine_barrier_t * b, int num_barriers, uint64_t thread_count) { + for (int i = 0; i < num_barriers; i++) { + b[i].total_ = thread_count; + b[i].pending_.store(thread_count); + b[i].rounds_.store(0); + } +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp new file mode 100644 index 00000000000..1409423b145 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp @@ -0,0 +1,760 @@ +#include "spine_mem_pool.h" + +#include "common.h" +#include "ime_env.h" +#include "spine_tcm.h" + +#include <fcntl.h> +#include <sys/ioctl.h> +#include <sys/mman.h> +#include <unistd.h> + +#include <algorithm> +#include <cerrno> +#include <cstdint> +#include <cstdlib> +#include <limits> +#include <memory> +#include <mutex> +#include <unordered_map> +#include <vector> + +namespace ggml::cpu::riscv64_spacemit { +namespace { + +constexpr size_t SPINE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull * 1024ull; +constexpr size_t SPINE_SHARE_MEM_POOL_CHUNK_SIZE = 512ull * 1024ull; +constexpr size_t SPINE_MEM_POOL_1G_REGION_SIZE = 1ull << 30; +constexpr uint64_t HUGETLB_1G_FLAG_REQUIRE_PUD = 1ull << 0; +constexpr char SPINE_MEM_POOL_HUGETLB_1G_DEV[] = "/dev/hugetlb_1g"; +constexpr char SPINE_MEM_POOL_TCM_SYNC_MEM_DEV[] = "/dev/tcm_sync_mem"; + +struct hugetlb_1g_region { + uint64_t size{ 0 }; + uint64_t dma_addr{ 0 }; + uint64_t flags{ 0 }; + uint64_t reserved{ 0 }; +}; + +#define HUGETLB_1G_IOC_MAGIC 'M' +#define HUGETLB_1G_IOC_ALLOC _IOWR(HUGETLB_1G_IOC_MAGIC, 0x00, struct hugetlb_1g_region) +#define HUGETLB_1G_IOC_FREE _IO(HUGETLB_1G_IOC_MAGIC, 0x01) + +struct free_block { + size_t offset{ 0 }; + size_t size{ 0 }; +}; + +struct pool_chunk { + uint8_t * base{ nullptr }; + size_t size{ 0 }; + int fd{ -1 }; + std::vector<free_block> free_blocks; +}; + +struct pool_allocation { + void * chunk_base{ nullptr }; + size_t chunk_size{ 0 }; + void * base{ nullptr }; + size_t size{ 0 }; +}; + +bool is_power_of_two(size_t value) { + return value != 0 && (value & (value - 1)) == 0; +} + +bool align_up(size_t value, size_t alignment, size_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const size_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const size_t padding = alignment - remainder; + if (value > std::numeric_limits<size_t>::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +bool align_up_uintptr(uintptr_t value, size_t alignment, uintptr_t * aligned_value) { + if (aligned_value == nullptr || alignment == 0) { + return false; + } + + const uintptr_t remainder = value % alignment; + if (remainder == 0) { + *aligned_value = value; + return true; + } + + const uintptr_t padding = alignment - remainder; + if (value > std::numeric_limits<uintptr_t>::max() - padding) { + return false; + } + + *aligned_value = value + padding; + return true; +} + +class spine_mem_pool_manager { + public: + explicit spine_mem_pool_manager(size_t default_chunk_size) : default_chunk_size_(default_chunk_size) {} + + virtual ~spine_mem_pool_manager() = default; + + void * alloc(size_t size, size_t alignment) { + if (size == 0 || !is_power_of_two(alignment)) { + return nullptr; + } + + size_t aligned_size = 0; + if (!align_up(size, alignment, &aligned_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: align_up failed for size %zu alignment %zu\n", __func__, size, + alignment); + return nullptr; + } + + pool_allocation allocation; + + std::lock_guard<std::mutex> lock(mutex_); + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + if (!add_chunk_locked(aligned_size, alignment)) { + return nullptr; + } + + if (!try_alloc_locked(aligned_size, alignment, &allocation)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation retry failed for size %zu alignment %zu\n", + __func__, aligned_size, alignment); + return nullptr; + } + } + + try { + const auto [allocation_it, inserted] = allocations_.emplace(allocation.base, allocation); + if (!inserted) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: duplicate allocation key %p\n", __func__, allocation.base); + rollback_allocation_locked(allocation); + return nullptr; + } + } catch (const std::bad_alloc &) { + rollback_allocation_locked(allocation); + throw; + } + + return allocation.base; + } + + void free(void * base) { + if (base == nullptr) { + return; + } + + std::lock_guard<std::mutex> lock(mutex_); + + auto allocation_it = allocations_.find(base); + if (allocation_it == allocations_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown allocation %p\n", __func__, base); + return; + } + + pool_allocation allocation = allocation_it->second; + allocations_.erase(allocation_it); + + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: unknown chunk for allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast<uint8_t *>(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p out of chunk range %p..%p\n", __func__, + allocation.base, chunk_base, chunk_base + chunk_it->size); + return; + } + + const size_t offset = static_cast<size_t>(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: allocation %p size %zu exceeds chunk size %zu\n", __func__, + allocation.base, allocation.size, chunk_it->size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + protected: + void release_chunks() { + std::lock_guard<std::mutex> lock(mutex_); + + allocations_.clear(); + for (auto & chunk : chunks_) { + dealloc_chunk(&chunk); + } + chunks_.clear(); + } + + size_t default_chunk_size() const { return default_chunk_size_; } + + static void clear_chunk(pool_chunk * chunk) { + chunk->base = nullptr; + chunk->size = 0; + chunk->fd = -1; + chunk->free_blocks.clear(); + } + + virtual bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) = 0; + virtual void dealloc_chunk(pool_chunk * chunk) = 0; + + private: + struct alloc_candidate { + size_t chunk_index{ 0 }; + size_t block_index{ 0 }; + size_t aligned_offset{ 0 }; + uintptr_t address{ std::numeric_limits<uintptr_t>::max() }; + bool valid{ false }; + }; + + std::vector<pool_chunk>::iterator find_chunk_locked(const pool_allocation & allocation) { + return std::find_if(chunks_.begin(), chunks_.end(), [&](const pool_chunk & chunk) { + return chunk.base == allocation.chunk_base && chunk.size == allocation.chunk_size; + }); + } + + bool add_chunk_locked(size_t min_size, size_t alignment) { + pool_chunk chunk; + const size_t chunk_request = default_chunk_size_ == 0 ? min_size : std::max(min_size, default_chunk_size_); + void * hint_addr = nullptr; + + for (const auto & existing_chunk : chunks_) { + auto * chunk_end = existing_chunk.base + existing_chunk.size; + if (hint_addr == nullptr || chunk_end > hint_addr) { + hint_addr = chunk_end; + } + } + + if (!alloc_chunk(chunk_request, alignment, hint_addr, &chunk)) { + return false; + } + + if (chunk.base == nullptr || chunk.size < min_size) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: invalid chunk returned for request size %zu, chunk_base=%p chunk_size=%zu\n", + __func__, min_size, chunk.base, chunk.size); + dealloc_chunk(&chunk); + return false; + } + + try { + chunk.free_blocks.push_back({ 0, chunk.size }); + chunks_.push_back(std::move(chunk)); + } catch (const std::bad_alloc &) { + dealloc_chunk(&chunk); + throw; + } + + return true; + } + + void rollback_allocation_locked(const pool_allocation & allocation) { + auto chunk_it = find_chunk_locked(allocation); + if (chunk_it == chunks_.end()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, owning chunk not found\n", + __func__, allocation.base); + return; + } + + auto * chunk_base = chunk_it->base; + auto * alloc_base = static_cast<uint8_t *>(allocation.base); + if (alloc_base < chunk_base || alloc_base >= chunk_base + chunk_it->size) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p, chunk range is invalid\n", + __func__, allocation.base); + return; + } + + const size_t offset = static_cast<size_t>(alloc_base - chunk_base); + if (offset > chunk_it->size || allocation.size > chunk_it->size - offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to rollback allocation %p size %zu\n", __func__, + allocation.base, allocation.size); + return; + } + + insert_free_block_locked(*chunk_it, { offset, allocation.size }); + maybe_release_empty_chunk_locked(chunk_it); + } + + bool try_alloc_locked(size_t size, size_t alignment, pool_allocation * allocation) { + alloc_candidate best; + + for (size_t chunk_index = 0; chunk_index < chunks_.size(); ++chunk_index) { + const auto & chunk = chunks_[chunk_index]; + for (size_t block_index = 0; block_index < chunk.free_blocks.size(); ++block_index) { + const auto & block = chunk.free_blocks[block_index]; + + uintptr_t aligned_addr = 0; + const auto block_addr = reinterpret_cast<uintptr_t>(chunk.base + block.offset); + if (!align_up_uintptr(block_addr, alignment, &aligned_addr)) { + continue; + } + + if (aligned_addr < block_addr) { + continue; + } + + const size_t aligned_offset = block.offset + static_cast<size_t>(aligned_addr - block_addr); + const size_t padding = aligned_offset - block.offset; + if (padding > block.size || size > block.size - padding) { + continue; + } + + if (!best.valid || aligned_addr < best.address) { + best.chunk_index = chunk_index; + best.block_index = block_index; + best.aligned_offset = aligned_offset; + best.address = aligned_addr; + best.valid = true; + } + } + } + + if (!best.valid) { + return false; + } + + auto & chunk = chunks_[best.chunk_index]; + const free_block block = chunk.free_blocks[best.block_index]; + const size_t padding = best.aligned_offset - block.offset; + const size_t alloc_end = best.aligned_offset + size; + const size_t block_end = block.offset + block.size; + + chunk.free_blocks.erase(chunk.free_blocks.begin() + best.block_index); + auto insert_it = chunk.free_blocks.begin() + best.block_index; + if (padding != 0) { + insert_it = chunk.free_blocks.insert(insert_it, { block.offset, padding }); + ++insert_it; + } + if (alloc_end < block_end) { + chunk.free_blocks.insert(insert_it, { alloc_end, block_end - alloc_end }); + } + + allocation->chunk_base = chunk.base; + allocation->chunk_size = chunk.size; + allocation->base = chunk.base + best.aligned_offset; + allocation->size = size; + return true; + } + + void maybe_release_empty_chunk_locked(std::vector<pool_chunk>::iterator chunk_it) { + if (chunk_it->free_blocks.size() != 1) { + return; + } + + const auto & block = chunk_it->free_blocks.front(); + if (block.offset != 0 || block.size != chunk_it->size) { + return; + } + + dealloc_chunk(&*chunk_it); + chunks_.erase(chunk_it); + } + + void insert_free_block_locked(pool_chunk & chunk, free_block block) { + auto it = chunk.free_blocks.begin(); + while (it != chunk.free_blocks.end() && it->offset < block.offset) { + ++it; + } + + if (it != chunk.free_blocks.begin()) { + const auto & prev = *(it - 1); + if (prev.offset + prev.size > block.offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + } + + if (it != chunk.free_blocks.end() && block.offset + block.size > it->offset) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: overlapping next free block at offset %zu size %zu\n", __func__, + block.offset, block.size); + return; + } + + it = chunk.free_blocks.insert(it, block); + + if (it != chunk.free_blocks.begin()) { + auto prev = it - 1; + if (prev->offset + prev->size == it->offset) { + it->offset = prev->offset; + it->size += prev->size; + it = chunk.free_blocks.erase(prev); + } + } + + if (it + 1 != chunk.free_blocks.end() && it->offset + it->size == (it + 1)->offset) { + it->size += (it + 1)->size; + chunk.free_blocks.erase(it + 1); + } + } + + std::mutex mutex_; + std::vector<pool_chunk> chunks_; + std::unordered_map<void *, pool_allocation> allocations_; + size_t default_chunk_size_{ 0 }; +}; + +class spine_mem_pool_posix final : public spine_mem_pool_manager { + public: + spine_mem_pool_posix() : spine_mem_pool_manager(0) {} + + ~spine_mem_pool_posix() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) hint_addr; + + const size_t alloc_alignment = std::max(alignment, sizeof(void *)); + void * base = nullptr; + const int rc = posix_memalign(&base, alloc_alignment, min_size); + if (rc != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: posix_memalign failed for size %zu alignment %zu, rc=%d\n", + __func__, min_size, alloc_alignment, rc); + return false; + } + + chunk->base = static_cast<uint8_t *>(base); + chunk->size = min_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + std::free(chunk->base); + clear_chunk(chunk); + } +}; + +class spine_mem_pool_transparent_hugepage final : public spine_mem_pool_manager { + public: + spine_mem_pool_transparent_hugepage() : spine_mem_pool_manager(SPINE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_transparent_hugepage() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + size_t chunk_size = 0; + if (!align_up(min_size, default_chunk_size(), &chunk_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round chunk size for %zu\n", __func__, min_size); + return false; + } + + void * map_addr = mmap(hint_addr, chunk_size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for chunk size %zu, errno=%d\n", __func__, chunk_size, + errno); + return false; + } + + if (madvise(map_addr, chunk_size, MADV_HUGEPAGE) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: madvise(MADV_HUGEPAGE) failed for chunk size %zu, errno=%d\n", + __func__, chunk_size, errno); + munmap(map_addr, chunk_size); + return false; + } + + chunk->base = static_cast<uint8_t *>(map_addr); + chunk->size = chunk_size; + chunk->fd = -1; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for chunk %p size %zu, errno=%d\n", __func__, + chunk->base, chunk->size, errno); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_hugetlb_1g final : public spine_mem_pool_manager { + public: + spine_mem_pool_hugetlb_1g() : spine_mem_pool_manager(SPINE_MEM_POOL_1G_REGION_SIZE) {} + + ~spine_mem_pool_hugetlb_1g() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + (void) hint_addr; + + size_t region_size = 0; + if (!align_up(min_size, SPINE_MEM_POOL_1G_REGION_SIZE, ®ion_size)) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to round hugetlb_1g size for %zu\n", __func__, min_size); + return false; + } + + const int fd = open(SPINE_MEM_POOL_HUGETLB_1G_DEV, O_RDWR); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_HUGETLB_1G_DEV, errno); + return false; + } + + hugetlb_1g_region region; + region.size = region_size; + region.flags = HUGETLB_1G_FLAG_REQUIRE_PUD; + if (ioctl(fd, HUGETLB_1G_IOC_ALLOC, ®ion) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_ALLOC failed for size %zu, errno=%d\n", __func__, + region_size, errno); + close(fd); + return false; + } + + void * map_addr = mmap(nullptr, region.size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for hugetlb_1g size %llu, errno=%d\n", __func__, + static_cast<unsigned long long>(region.size), errno); + ioctl(fd, HUGETLB_1G_IOC_FREE); + close(fd); + return false; + } + + chunk->base = static_cast<uint8_t *>(map_addr); + chunk->size = region.size; + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for hugetlb_1g chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + if (ioctl(chunk->fd, HUGETLB_1G_IOC_FREE) < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: HUGETLB_1G_IOC_FREE failed for chunk %p, errno=%d\n", + __func__, chunk->base, errno); + } + + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +class spine_mem_pool_shared_mem final : public spine_mem_pool_manager { + public: + spine_mem_pool_shared_mem() : spine_mem_pool_manager(SPINE_SHARE_MEM_POOL_CHUNK_SIZE) {} + + ~spine_mem_pool_shared_mem() override { release_chunks(); } + + private: + bool alloc_chunk(size_t min_size, size_t alignment, void * hint_addr, pool_chunk * chunk) override { + (void) alignment; + + if (hint_addr != nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem does not support multiple active chunks\n", __func__); + return false; + } + + if (min_size > default_chunk_size()) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: shared_mem request %zu exceeds chunk size %zu\n", __func__, + min_size, default_chunk_size()); + return false; + } + + const int fd = open(SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, O_RDWR | O_SYNC); + if (fd < 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: open(%s) failed, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, errno); + return false; + } + + void * map_addr = mmap(nullptr, default_chunk_size(), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if (map_addr == MAP_FAILED) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: mmap failed for %s size %zu, errno=%d\n", __func__, + SPINE_MEM_POOL_TCM_SYNC_MEM_DEV, default_chunk_size(), errno); + close(fd); + return false; + } + + chunk->base = static_cast<uint8_t *>(map_addr); + chunk->size = default_chunk_size(); + chunk->fd = fd; + return true; + } + + void dealloc_chunk(pool_chunk * chunk) override { + if (chunk->base != nullptr && chunk->size != 0 && munmap(chunk->base, chunk->size) != 0) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: munmap failed for shared_mem chunk %p size %zu, errno=%d\n", + __func__, chunk->base, chunk->size, errno); + } + + if (chunk->fd >= 0) { + close(chunk->fd); + } + + clear_chunk(chunk); + } +}; + +spine_mem_pool_manager & get_spine_mem_pool_manager() { + static std::once_flag pool_once; + static std::unique_ptr<spine_mem_pool_manager> selected_pool; + static spine_mem_pool_backend selected_backend = spine_mem_pool_backend::none; + + spine_mem_pool_backend backend = global_spine_env_info.mem_backend; + if (backend == spine_mem_pool_backend::none) { + backend = spine_mem_pool_backend::transparent_hugepage; + } + + std::call_once(pool_once, [&]() { + selected_backend = backend; + + switch (selected_backend) { + case spine_mem_pool_backend::posix_memalign: + selected_pool = std::make_unique<spine_mem_pool_posix>(); + break; + case spine_mem_pool_backend::transparent_hugepage: + selected_pool = std::make_unique<spine_mem_pool_transparent_hugepage>(); + break; + case spine_mem_pool_backend::hugetlb_1g: + selected_pool = std::make_unique<spine_mem_pool_hugetlb_1g>(); + break; + case spine_mem_pool_backend::none: + selected_backend = spine_mem_pool_backend::transparent_hugepage; + selected_pool = std::make_unique<spine_mem_pool_transparent_hugepage>(); + break; + } + }); + + if (backend != selected_backend) { + GGML_LOG_ERROR( + "CPU_RISCV64_SPACEMIT: %s: mem pool backend is process-global and mutually exclusive, requested=%d but " + "selected=%d\n", + __func__, static_cast<int>(backend), static_cast<int>(selected_backend)); + } + + if (selected_pool) { + return *selected_pool; + } + + throw std::bad_alloc(); +} + +spine_mem_pool_manager & get_spine_mem_pool_shared_mem_manager() { + static std::once_flag shared_mem_pool_once; + static std::unique_ptr<spine_mem_pool_shared_mem> shared_mem_pool; + + std::call_once(shared_mem_pool_once, [&]() { shared_mem_pool = std::make_unique<spine_mem_pool_shared_mem>(); }); + + if (shared_mem_pool) { + return *shared_mem_pool; + } + + throw std::bad_alloc(); +} + +} // namespace + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept { + if (info == nullptr) { + return false; + } + + *info = {}; + + if (spine_tcm_open_handle(NULL) != 0 || !spine_tcm_is_available()) { + return false; + } + + spine_tcm_mem_info_t mem_info; + if (spine_tcm_mem_info(&mem_info) != 0) { + return false; + } + + info->available = true; + info->blk_size = mem_info.blk_size; + info->blk_num = mem_info.blk_num; + info->is_fake_tcm = mem_info.is_fake_tcm != 0; + return true; +} + +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept { + return spine_tcm_mem_get(cpu_id); +} + +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept { + return spine_tcm_mem_try_wait(cpu_id, 1000 * 1000); +} + +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept { + return spine_tcm_mem_release(cpu_id); +} + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating size %zu\n", __func__, size); + return nullptr; + } +} + +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept { + try { + return get_spine_mem_pool_shared_mem_manager().alloc(size, alignment); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while allocating shared memory size %zu\n", __func__, size); + return nullptr; + } +} + +void spine_mem_pool_free(void * base) noexcept { + try { + get_spine_mem_pool_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing allocation %p\n", __func__, base); + } +} + +void spine_mem_pool_shared_mem_free(void * base) noexcept { + try { + get_spine_mem_pool_shared_mem_manager().free(base); + } catch (const std::bad_alloc &) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: bad_alloc while freeing shared allocation %p\n", __func__, base); + } +} + +} // namespace ggml::cpu::riscv64_spacemit + +extern "C" { +void * ggml_backend_cpu_riscv64_spacemit_alloc_shared(size_t size, size_t alignment) { + void * result = ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_alloc(size, alignment); + if (result == nullptr) { + GGML_LOG_ERROR("CPU_RISCV64_SPACEMIT: %s: failed to allocate shared memory size %zu alignment %zu\n", __func__, + size, alignment); + } + return result; +} + +void ggml_backend_cpu_riscv64_spacemit_free_shared(void * ptr) { + ggml::cpu::riscv64_spacemit::spine_mem_pool_shared_mem_free(ptr); +} +} diff --git a/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h new file mode 100644 index 00000000000..8740d2c99ef --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h @@ -0,0 +1,32 @@ +#pragma once + +#include <cstddef> +#include <cstdint> + +namespace ggml::cpu::riscv64_spacemit { + +enum class spine_mem_pool_backend : uint8_t { + none, + posix_memalign, + transparent_hugepage, + hugetlb_1g, +}; + +struct spine_mem_pool_tcm_info { + bool available{ false }; + size_t blk_size{ 0 }; + size_t blk_num{ 0 }; + bool is_fake_tcm{ false }; +}; + +bool spine_mem_pool_tcm_init(spine_mem_pool_tcm_info * info) noexcept; +void * spine_mem_pool_tcm_mem_get(int cpu_id) noexcept; +void * spine_mem_pool_tcm_mem_wait(int cpu_id) noexcept; +int spine_mem_pool_tcm_mem_release(int cpu_id) noexcept; + +void * spine_mem_pool_alloc(size_t size, size_t alignment) noexcept; +void * spine_mem_pool_shared_mem_alloc(size_t size, size_t alignment) noexcept; +void spine_mem_pool_free(void * base) noexcept; +void spine_mem_pool_shared_mem_free(void * base) noexcept; + +} // namespace ggml::cpu::riscv64_spacemit diff --git a/ggml/src/ggml-cpu/spacemit/spine_tcm.h b/ggml/src/ggml-cpu/spacemit/spine_tcm.h new file mode 100644 index 00000000000..f300d7d5c04 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/spine_tcm.h @@ -0,0 +1,409 @@ +#ifndef SPINE_TCM_PUBLIC_H_ +#define SPINE_TCM_PUBLIC_H_ + +/* + * spine_tcm public API + * + * Usage: + * 1. Direct link mode + * Define SPINE_TCM_DIRECT_LINK and link against libspine_tcm.so. + * + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + * + * 2. Header-only loader mode + * Include this header without linking libspine_tcm.so. The loader first + * tries to reuse a process-global spine_tcm instance and falls back to + * dlopen("libspine_tcm.so") when needed. + * + * spine_tcm_open_handle(NULL); // optional pre-bind + * if (spine_tcm_is_available()) { + * void *buffer = spine_tcm_mem_get(0); + * spine_tcm_mem_free(0); + * } + */ + +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#if !defined(SPINE_TCM_BUILD_SHARED) && !defined(SPINE_TCM_DIRECT_LINK) +# include <dlfcn.h> +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +#if defined(_WIN32) +# if defined(SPINE_TCM_BUILD_SHARED) +# define SPINE_TCM_API __declspec(dllexport) +# else +# define SPINE_TCM_API __declspec(dllimport) +# endif +#else +# define SPINE_TCM_API __attribute__((visibility("default"))) +#endif + +typedef struct spine_tcm_mem_info { + size_t blk_size; + size_t blk_num; + int is_fake_tcm; +} spine_tcm_mem_info_t; + +typedef struct spine_tcm_block_info { + int id; + void * va; + size_t size; + uint64_t phys_addr; + uint64_t cpu_affinity_mask; + int owner_tid; + int is_acquired; +} spine_tcm_block_info_t; + +/* Shared-library runtime ABI exported by libspine_tcm.so. */ +SPINE_TCM_API const char * spine_tcm_runtime_version(void); +SPINE_TCM_API int spine_tcm_runtime_is_available(void); +SPINE_TCM_API int spine_tcm_runtime_layout_info(spine_tcm_mem_info_t * info); +SPINE_TCM_API int spine_tcm_runtime_mem_info(int id, spine_tcm_block_info_t * info); +SPINE_TCM_API void * spine_tcm_runtime_mem_get(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_free(int id); +SPINE_TCM_API void * spine_tcm_runtime_mem_try_wait(int id, size_t timeout_us); +SPINE_TCM_API int spine_tcm_runtime_mem_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_force_release(int id); +SPINE_TCM_API int spine_tcm_runtime_mem_query(int id); + +#if defined(SPINE_TCM_DIRECT_LINK) +/* Optional no-op in direct-link mode. */ +static inline int spine_tcm_open_handle(const char * so_path) { + (void) so_path; + return 0; +} + +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + return spine_tcm_runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + return spine_tcm_runtime_layout_info(info); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + return spine_tcm_runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + return spine_tcm_runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + return spine_tcm_runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + return spine_tcm_runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + return spine_tcm_runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + return spine_tcm_runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + return spine_tcm_runtime_mem_query(id); +} +#elif !defined(SPINE_TCM_BUILD_SHARED) +typedef struct spine_tcm_handle { + void * module_handle; + int use_global_scope; + int owns_module_handle; + const char * (*runtime_version)(void); + int (*runtime_is_available)(void); + int (*runtime_layout_info)(spine_tcm_mem_info_t * info); + int (*runtime_mem_info)(int id, spine_tcm_block_info_t * info); + void * (*runtime_mem_get)(int id); + int (*runtime_mem_free)(int id); + void * (*runtime_mem_try_wait)(int id, size_t over_time); + int (*runtime_mem_release)(int id); + int (*runtime_mem_force_release)(int id); + int (*runtime_mem_query)(int id); +} spine_tcm_handle_t; + +static inline spine_tcm_handle_t * spine_tcm_default_handle(void) { + static spine_tcm_handle_t handle = { 0 }; + return &handle; +} + +static inline void spine_tcm_handle_reset(spine_tcm_handle_t * handle) { + if (handle != NULL) { + memset(handle, 0, sizeof(*handle)); + } +} + +static inline int spine_tcm_handle_bind(spine_tcm_handle_t * handle) { + void * symbol_scope = handle->use_global_scope ? RTLD_DEFAULT : handle->module_handle; + + handle->runtime_version = (const char * (*) (void) ) dlsym(symbol_scope, "spine_tcm_runtime_version"); + handle->runtime_is_available = (int (*)(void)) dlsym(symbol_scope, "spine_tcm_runtime_is_available"); + handle->runtime_layout_info = + (int (*)(spine_tcm_mem_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_layout_info"); + handle->runtime_mem_info = + (int (*)(int, spine_tcm_block_info_t *)) dlsym(symbol_scope, "spine_tcm_runtime_mem_info"); + handle->runtime_mem_get = (void * (*) (int) ) dlsym(symbol_scope, "spine_tcm_runtime_mem_get"); + handle->runtime_mem_free = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_free"); + handle->runtime_mem_try_wait = (void * (*) (int, size_t)) dlsym(symbol_scope, "spine_tcm_runtime_mem_try_wait"); + handle->runtime_mem_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_release"); + handle->runtime_mem_force_release = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_force_release"); + handle->runtime_mem_query = (int (*)(int)) dlsym(symbol_scope, "spine_tcm_runtime_mem_query"); + + return handle->runtime_version != NULL && handle->runtime_is_available != NULL && + handle->runtime_layout_info != NULL && handle->runtime_mem_info != NULL && + handle->runtime_mem_get != NULL && handle->runtime_mem_free != NULL && + handle->runtime_mem_try_wait != NULL && handle->runtime_mem_release != NULL && + handle->runtime_mem_force_release != NULL && handle->runtime_mem_query != NULL ? + 0 : + -1; +} + +/* + * Try to bind against an already-loaded process-global spine_tcm instance. + * The shared library exports spine_tcm_runtime_marker only for this probe. + */ +static inline int spine_tcm_try_bind_global(spine_tcm_handle_t * handle) { + if (dlsym(RTLD_DEFAULT, "spine_tcm_runtime_marker") == NULL) { + return -1; + } + + handle->use_global_scope = 1; + return spine_tcm_handle_bind(handle); +} + +/* + * Optional pre-bind entry point. + * + * Behavior: + * - Reuses an already-loaded global spine_tcm instance when available. + * - Otherwise loads the shared library from so_path or the default soname. + * - Repeated calls are safe and return 0 after the first successful bind. + */ +static inline int spine_tcm_open_handle(const char * so_path) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + const char * library = (so_path != NULL && so_path[0] != '\0') ? so_path : "libspine_tcm.so"; + + if (resolved->module_handle != NULL || resolved->use_global_scope) { + return 0; + } + + if (spine_tcm_try_bind_global(resolved) == 0) { + return 0; + } + + spine_tcm_handle_reset(resolved); + + resolved->module_handle = dlopen(library, RTLD_LAZY | RTLD_GLOBAL); + resolved->owns_module_handle = resolved->module_handle != NULL ? 1 : 0; + + if (resolved->module_handle == NULL) { + spine_tcm_handle_reset(resolved); + return -1; + } + + if (spine_tcm_handle_bind(resolved) != 0) { + if (resolved->owns_module_handle) { + dlclose(resolved->module_handle); + } + spine_tcm_handle_reset(resolved); + return -1; + } + + return 0; +} + +/* Returns 1 when the runtime driver is available, otherwise 0. */ +static inline int spine_tcm_is_available(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_is_available == NULL) { + return 0; + } + + return resolved->runtime_is_available(); +} + +/* Returns runtime memory geometry and whether the current backend is fake TCM. */ +static inline int spine_tcm_mem_info(spine_tcm_mem_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_layout_info == NULL) { + return -1; + } + + return resolved->runtime_layout_info(info); +} + +static inline const char * spine_tcm_version(void) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_version == NULL) { + return "unknown"; + } + + return resolved->runtime_version(); +} + +/* Returns per-block runtime metadata for the given TCM id. */ +static inline int spine_tcm_block_info(int id, spine_tcm_block_info_t * info) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_info == NULL) { + return -1; + } + + return resolved->runtime_mem_info(id, info); +} + +/* Returns a cached buffer for the given TCM id, or NULL on failure. */ +static inline void * spine_tcm_mem_get(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_get == NULL) { + return NULL; + } + + return resolved->runtime_mem_get(id); +} + +/* Releases one reference acquired by spine_tcm_mem_get(id). */ +static inline int spine_tcm_mem_free(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_free == NULL) { + return -1; + } + + return resolved->runtime_mem_free(id); +} + +/* Waits for a TCM block handoff and returns the driver-owned buffer when available. */ +static inline void * spine_tcm_mem_try_wait(int id, size_t over_time) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + return NULL; + } + + if (resolved->runtime_mem_try_wait == NULL) { + return NULL; + } + + return resolved->runtime_mem_try_wait(id, over_time); +} + +/* Releases a buffer acquired by spine_tcm_mem_try_wait(id, over_time). */ +static inline int spine_tcm_mem_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_release == NULL) { + return -1; + } + + return resolved->runtime_mem_release(id); +} + +/* Forces a release for the given TCM id when the backend supports it. */ +static inline int spine_tcm_mem_force_release(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || + resolved->runtime_mem_force_release == NULL) { + return -1; + } + + return resolved->runtime_mem_force_release(id); +} + +/* Returns whether the given TCM id is currently acquired. */ +static inline int spine_tcm_mem_query(int id) { + spine_tcm_handle_t * resolved = spine_tcm_default_handle(); + + if (resolved->module_handle == NULL && !resolved->use_global_scope) { + (void) spine_tcm_open_handle(NULL); + } + + if ((resolved->module_handle == NULL && !resolved->use_global_scope) || resolved->runtime_mem_query == NULL) { + return -1; + } + + return resolved->runtime_mem_query(id); +} +#else +static inline const char * spine_tcm_version(void) { + return spine_tcm_runtime_version(); +} +#endif + +#define SPINE_TCM_VERSION (spine_tcm_version()) + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 1d9873ad0f2..1d8344436f0 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -111,7 +111,7 @@ template <float (*op)(float), typename src0_t, typename dst_t> static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst)); GGML_TENSOR_UNARY_OP_LOCALS diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 427e63245b0..67b6b05cac8 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -236,7 +236,24 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); sumf += __riscv_vfmv_f_s_f32m1_f32(redsum); +#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__) + const int np = (n & ~(GGML_BF16_STEP - 1)); + if (np > 0) { + GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO}; + for (; i < np; i += GGML_BF16_STEP) { + GGML_BF16_VEC vx0 = GGML_BF16_VEC_LOAD(x + i); + GGML_BF16_VEC vx1 = GGML_BF16_VEC_LOAD(x + i + 8); + GGML_BF16_VEC vy0 = GGML_BF16_VEC_LOAD(y + i); + GGML_BF16_VEC vy1 = GGML_BF16_VEC_LOAD(y + i + 8); + GGML_BF16_FMA_LO(sum[0], vx0, vy0); + GGML_BF16_FMA_HI(sum[1], vx0, vy0); + GGML_BF16_FMA_LO(sum[2], vx1, vy1); + GGML_BF16_FMA_HI(sum[3], vx1, vy1); + } + GGML_F32x4_REDUCE_4(sumf, sum[0], sum[1], sum[2], sum[3]); + } #endif + for (; i < n; ++i) { sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) * GGML_BF16_TO_FP32(y[i])); @@ -256,67 +273,51 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G #if defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; //get vector length - const int ggml_f16_epr = sve_register_length / 16; // running when 16 - const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - - const int np= (n & ~(ggml_f16_step - 1)); - svfloat16_t sum1 = svdup_n_f16(0.0f); - svfloat16_t sum2 = svdup_n_f16(0.0f); - svfloat16_t sum3 = svdup_n_f16(0.0f); - svfloat16_t sum4 = svdup_n_f16(0.0f); - - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; - for (int i = 0; i < np; i += ggml_f16_step) { - ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0); - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); - sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1); - - ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1); - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); - sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2); - - ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2); - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3); - - ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3); - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4); + const int ggml_f16_epr = svcnth(); + const int ggml_f16_step = 8 * ggml_f16_epr; + const int np = n - (n % ggml_f16_step); + const int np2 = n - (n % ggml_f16_epr); + + svfloat32_t sum1_lo = svdup_n_f32(0.0f); + svfloat32_t sum1_hi = svdup_n_f32(0.0f); + svfloat32_t sum2_lo = svdup_n_f32(0.0f); + svfloat32_t sum2_hi = svdup_n_f32(0.0f); + svfloat32_t sum3_lo = svdup_n_f32(0.0f); + svfloat32_t sum3_hi = svdup_n_f32(0.0f); + svfloat32_t sum4_lo = svdup_n_f32(0.0f); + svfloat32_t sum4_hi = svdup_n_f32(0.0f); - ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4); - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5); - - ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5); - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6); - - ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6); - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); - sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7); - - ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7); - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8); + for (int i = 0; i < np; i += ggml_f16_step) { + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0), GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0)); + ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1), GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1)); + ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2), GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2)); + ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3), GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3)); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4), GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4)); + ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5), GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5)); + ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6), GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6)); + ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7), GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7)); } - const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8 - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0); - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); - sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry); + for (int i = np; i < np2; i += ggml_f16_epr) { + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i, 0), GGML_F16x_VEC_LOAD(y + i, 0)); } if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svbool_t pg = svwhilelt_b16(np2, n); + const svfloat16_t rx = svld1_f16(pg, (const __fp16 *)(x + np2)); + const svfloat16_t ry = svld1_f16(pg, (const __fp16 *)(y + np2)); - sum1 = svmad_f16_x(pg, hx, hy, sum1); + ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, rx, ry); } - GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4); + + sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum2_lo); + sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum2_hi); + sum3_lo = svadd_f32_m(DEFAULT_PG32, sum3_lo, sum4_lo); + sum3_hi = svadd_f32_m(DEFAULT_PG32, sum3_hi, sum4_hi); + sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum3_lo); + sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum3_hi); + + sumf = ggml_sve_sum_f32x2(sum1_lo, sum1_hi); #elif defined(__riscv_v_intrinsic) #if defined(__riscv_zvfh) int vl = __riscv_vsetvlmax_e32m2(); diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 3198b33b509..5de9cb5b7e0 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -14,6 +14,35 @@ // floating point type used to accumulate sums typedef double ggml_float; +#if defined(__ARM_FEATURE_SVE) +inline static void ggml_sve_f16_fma_widened( + svfloat32_t * acc_lo, + svfloat32_t * acc_hi, + svfloat16_t x, + svfloat16_t y) { +#if defined(__ARM_FEATURE_SVE2) + *acc_lo = svmlalb_f32(*acc_lo, x, y); + *acc_hi = svmlalt_f32(*acc_hi, x, y); +#else + // Plain SVE fallback path if SVE2 instructions not available + svfloat16_t x_even = svtrn1_f16(x, x); + svfloat16_t x_odd = svtrn2_f16(x, x); + + svfloat16_t y_even = svtrn1_f16(y, y); + svfloat16_t y_odd = svtrn2_f16(y, y); + + svbool_t pg = svptrue_b32(); + + *acc_lo = svmla_f32_x(pg, *acc_lo, svcvt_f32_f16_x(pg, x_even), svcvt_f32_f16_x(pg, y_even)); + *acc_hi = svmla_f32_x(pg, *acc_hi, svcvt_f32_f16_x(pg, x_odd), svcvt_f32_f16_x(pg, y_odd)); +#endif +} + +inline static ggml_float ggml_sve_sum_f32x2(svfloat32_t sum_lo, svfloat32_t sum_hi) { + return (ggml_float) (svaddv_f32(svptrue_b32(), sum_lo) + svaddv_f32(svptrue_b32(), sum_hi)); +} +#endif + #define GGML_GELU_FP16 #define GGML_GELU_QUICK_FP16 @@ -122,173 +151,130 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG #if defined(GGML_SIMD) #if defined(__ARM_FEATURE_SVE) - const int sve_register_length = svcntb() * 8; - const int ggml_f16_epr = sve_register_length / 16; // running when 16 - const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers - - const int np = (n & ~(ggml_f16_step - 1)); - - svfloat16_t sum_00 = svdup_n_f16(0.0f); - svfloat16_t sum_01 = svdup_n_f16(0.0f); - svfloat16_t sum_02 = svdup_n_f16(0.0f); - svfloat16_t sum_03 = svdup_n_f16(0.0f); - - svfloat16_t sum_10 = svdup_n_f16(0.0f); - svfloat16_t sum_11 = svdup_n_f16(0.0f); - svfloat16_t sum_12 = svdup_n_f16(0.0f); - svfloat16_t sum_13 = svdup_n_f16(0.0f); + const int ggml_f16_epr = svcnth(); + const int ggml_f16_step = 2 * ggml_f16_epr; + int np = n - (n % ggml_f16_step); + int np2 = n - (n % ggml_f16_epr); - svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8; - svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8; + svfloat32_t sum_0_0_lo = svdup_n_f32(0.0f); + svfloat32_t sum_0_0_hi = svdup_n_f32(0.0f); + svfloat32_t sum_0_1_lo = svdup_n_f32(0.0f); + svfloat32_t sum_0_1_hi = svdup_n_f32(0.0f); + svfloat32_t sum_1_0_lo = svdup_n_f32(0.0f); + svfloat32_t sum_1_0_hi = svdup_n_f32(0.0f); + svfloat32_t sum_1_1_lo = svdup_n_f32(0.0f); + svfloat32_t sum_1_1_hi = svdup_n_f32(0.0f); for (int i = 0; i < np; i += ggml_f16_step) { - ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements - - ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements - sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1 - ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements - sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1); - - ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements - - ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements - sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2); - ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1); - sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2); - - ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2); - - ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2); - sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3); - ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); - sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3); - - ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); - - ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3); - sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4); - ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3); - sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4); - - ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4); - - ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4); - - sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5); - ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4); - sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5); - - ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5); - - ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5); - - sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6); - ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5); - sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6); + const svfloat16_t ay0 = GGML_F16x_VEC_LOAD(y + i, 0); + const svfloat16_t ax00 = GGML_F16x_VEC_LOAD(x[0] + i, 0); + const svfloat16_t ax01 = GGML_F16x_VEC_LOAD(x[1] + i, 0); - ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax00, ay0); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax01, ay0); - ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6); + const svfloat16_t ay1 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 0); + const svfloat16_t ax10 = GGML_F16x_VEC_LOAD(x[0] + i + 1 * ggml_f16_epr, 0); + const svfloat16_t ax11 = GGML_F16x_VEC_LOAD(x[1] + i + 1 * ggml_f16_epr, 0); - sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7); - ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6); - sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7); - - ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7); - - ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7); - - sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8); - ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7); - sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8); + ggml_sve_f16_fma_widened(&sum_0_1_lo, &sum_0_1_hi, ax10, ay1); + ggml_sve_f16_fma_widened(&sum_1_1_lo, &sum_1_1_hi, ax11, ay1); } - const int np2 = (n & ~(ggml_f16_epr - 1)); - for (int k = np; k < np2; k += ggml_f16_epr) { - svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0); + for (int i = np; i < np2; i += ggml_f16_epr) { + const svfloat16_t ry = GGML_F16x_VEC_LOAD(y + i, 0); + const svfloat16_t rx0 = GGML_F16x_VEC_LOAD(x[0] + i, 0); + const svfloat16_t rx1 = GGML_F16x_VEC_LOAD(x[1] + i, 0); - svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0); - sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry); - rx = GGML_F16x_VEC_LOAD(x[1] + k, 0); - sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, rx0, ry); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, rx1, ry); } if (np2 < n) { - svbool_t pg = svwhilelt_b16(np2, n); - svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); - svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); - svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svbool_t pg = svwhilelt_b16(np2, n); + const svfloat16_t ay = svld1_f16(pg, (const __fp16 *)(y + np2)); + const svfloat16_t ax0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2)); + const svfloat16_t ax1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2)); - sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00); - sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10); + ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax0, ay); + ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax1, ay); } - GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03); - GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13); - - #elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - size_t vl = __riscv_vsetvlmax_e32m4(); - - // initialize accumulators to all zeroes - vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); - - // calculate step size - const size_t epr = __riscv_vsetvlmax_e16m2(); - const size_t step = epr * 2; - const int np = (n & ~(step - 1)); - // unroll by 2 along the row dimension - for (int i = 0; i < np; i += step) { - vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr); - vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr); - vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr); - vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr); - vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr); - - vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr); - vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr); - vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr); - vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr); - vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr); - } - - vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl); - vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl); + svfloat32_t sum_0_lo = svadd_f32_x(DEFAULT_PG32, sum_0_0_lo, sum_0_1_lo); + svfloat32_t sum_0_hi = svadd_f32_x(DEFAULT_PG32, sum_0_0_hi, sum_0_1_hi); + svfloat32_t sum_1_lo = svadd_f32_x(DEFAULT_PG32, sum_1_0_lo, sum_1_1_lo); + svfloat32_t sum_1_hi = svadd_f32_x(DEFAULT_PG32, sum_1_0_hi, sum_1_1_hi); + sumf[0] = ggml_sve_sum_f32x2(sum_0_lo, sum_0_hi); + sumf[1] = ggml_sve_sum_f32x2(sum_1_lo, sum_1_hi); + np = n; + #elif defined(__riscv_v_intrinsic) + #if defined(__riscv_zvfh) + size_t vl = __riscv_vsetvlmax_e32m4(); + + // initialize accumulators to all zeroes + vfloat32m4_t vsum0_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum0_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_0 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + vfloat32m4_t vsum1_1 = __riscv_vfmv_v_f_f32m4(0.0f, vl); + + // calculate step size + const size_t epr = __riscv_vsetvlmax_e16m2(); + const size_t step = epr * 2; + int np = (n & ~(step - 1)); + + // unroll by 2 along the row dimension + for (int i = 0; i < np; i += step) { + vfloat16m2_t ay0 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), epr); + vfloat16m2_t ax0_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), epr); + vfloat16m2_t ax1_0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), epr); + vsum0_0 = __riscv_vfwmacc_vv_f32m4(vsum0_0, ax0_0, ay0, epr); + vsum1_0 = __riscv_vfwmacc_vv_f32m4(vsum1_0, ax1_0, ay0, epr); + + vfloat16m2_t ay1 = __riscv_vle16_v_f16m2((const _Float16 *)(y + i + epr), epr); + vfloat16m2_t ax0_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i + epr), epr); + vfloat16m2_t ax1_1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i + epr), epr); + vsum0_1 = __riscv_vfwmacc_vv_f32m4(vsum0_1, ax0_1, ay1, epr); + vsum1_1 = __riscv_vfwmacc_vv_f32m4(vsum1_1, ax1_1, ay1, epr); + } - // leftovers - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m2(n - i); - vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl); - vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl); - vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl); + vfloat32m4_t vsum0 = __riscv_vfadd_vv_f32m4(vsum0_0, vsum0_1, vl); + vfloat32m4_t vsum1 = __riscv_vfadd_vv_f32m4(vsum1_0, vsum1_1, vl); - vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl); - vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl); - } + // leftovers + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m2(n - i); + vfloat16m2_t ay = __riscv_vle16_v_f16m2((const _Float16 *)(y + i), vl); + vfloat16m2_t ax0 = __riscv_vle16_v_f16m2((const _Float16 *)(x[0] + i), vl); + vfloat16m2_t ax1 = __riscv_vle16_v_f16m2((const _Float16 *)(x[1] + i), vl); - // reduce - vl = __riscv_vsetvlmax_e32m2(); - vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0), - __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl); - vl = __riscv_vsetvlmax_e32m1(); - vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0), - __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl); - vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1( - acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); - - vl = __riscv_vsetvlmax_e32m2(); - vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0), - __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl); - vl = __riscv_vsetvlmax_e32m1(); - vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0), - __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl); - vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1( - acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); - sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0); - sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1); + vsum0 = __riscv_vfwmacc_vv_f32m4(vsum0, ax0, ay, vl); + vsum1 = __riscv_vfwmacc_vv_f32m4(vsum1, ax1, ay, vl); + } + // reduce + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc0_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum0, 0), + __riscv_vget_v_f32m4_f32m2(vsum0, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc0_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc0_0, 0), + __riscv_vget_v_f32m2_f32m1(acc0_0, 1), vl); + vfloat32m1_t redsum0 = __riscv_vfredusum_vs_f32m1_f32m1( + acc0_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + + vl = __riscv_vsetvlmax_e32m2(); + vfloat32m2_t acc1_0 = __riscv_vfadd_vv_f32m2(__riscv_vget_v_f32m4_f32m2(vsum1, 0), + __riscv_vget_v_f32m4_f32m2(vsum1, 1), vl); + vl = __riscv_vsetvlmax_e32m1(); + vfloat32m1_t acc1_1 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(acc1_0, 0), + __riscv_vget_v_f32m2_f32m1(acc1_0, 1), vl); + vfloat32m1_t redsum1 = __riscv_vfredusum_vs_f32m1_f32m1( + acc1_1, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); + sumf[0] = __riscv_vfmv_f_s_f32m1_f32(redsum0); + sumf[1] = __riscv_vfmv_f_s_f32m1_f32(redsum1); + np = n; + #else + const int np = 0; + #endif #else const int np = (n & ~(GGML_F16_STEP - 1)); @@ -313,21 +299,17 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { GGML_F16_VEC_REDUCE(sumf[k], sum[k]); } - - // leftovers - for (int i = np; i < n; ++i) { - for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { - sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); - } - } #endif #else - for (int i = 0; i < n; ++i) { + // scalar path + const int np = 0; +#endif + // scalar and leftovers + for (int i = np; i < n; ++i) { for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); } } -#endif for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { s[i] = (float)sumf[i]; @@ -532,40 +514,45 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, svst1_f16(pg, (__fp16 *)(y + np2), hy); } np = n; -#elif defined(__riscv_zvfh) // implies __riscv_v_intrinsic - const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); - const _Float16 scale = *(const _Float16*)(&s); - - // calculate step size - const int epr = __riscv_vsetvlmax_e16m4(); - const int step = epr * 2; - int np = (n & ~(step - 1)); - - // unroll by 2 - for (int i = 0; i < np; i += step) { - vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); - ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); - __asm__ __volatile__ ("" ::: "memory"); - - vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); - vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); - ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); - __asm__ __volatile__ ("" ::: "memory"); - } +#elif defined(__riscv_v_intrinsic) // implies __riscv_v_intrinsic + #if defined (__riscv_zvfh) + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); - // leftovers - int vl; - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m4(n - i); - vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); - ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); - } - np = n; + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + int np = (n & ~(step - 1)); + + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, epr); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ax1 = __riscv_vle16_v_f16m4((const _Float16*)x + i + epr, epr); + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmacc_vf_f16m4(ay1, scale, ax1, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } + + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ax0 = __riscv_vle16_v_f16m4((const _Float16*)x + i, vl); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmacc_vf_f16m4(ay0, scale, ax0, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } + np = n; + #else + // fall to scalar path + const int np = 0; + #endif #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -584,10 +571,11 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, } } #else + // scalar path const int np = 0; #endif - // leftovers + // scalar and leftovers for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v); } @@ -785,7 +773,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float const int ggml_f16_step = 2 * ggml_f16_epr; GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v); - const int np = (n & ~(ggml_f16_step - 1)); + int np = (n & ~(ggml_f16_step - 1)); svfloat16_t ay1, ay2; for (int i = 0; i < np; i += ggml_f16_step) { @@ -805,36 +793,43 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float svfloat16_t out = svmul_f16_m(pg, hy, vx); svst1_f16(pg, (__fp16 *)(y + np), out); } -#elif defined(__riscv_v_intrinsic) && defined(__riscv_zvfh) - const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); - const _Float16 scale = *(const _Float16*)(&s); - - // calculate step size - const int epr = __riscv_vsetvlmax_e16m4(); - const int step = epr * 2; - const int np = (n & ~(step - 1)); + np = n; +#elif defined(__riscv_v_intrinsic) + #if defined(__riscv_zvfh) + const ggml_fp16_t s = GGML_CPU_FP32_TO_FP16(v); + const _Float16 scale = *(const _Float16*)(&s); - // unroll by 2 - for (int i = 0; i < np; i += step) { - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); - ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); - __asm__ __volatile__ ("" ::: "memory"); + // calculate step size + const int epr = __riscv_vsetvlmax_e16m4(); + const int step = epr * 2; + int np = (n & ~(step - 1)); - vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); - ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr); - __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); - __asm__ __volatile__ ("" ::: "memory"); - } + // unroll by 2 + for (int i = 0; i < np; i += step) { + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, epr); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, epr); + __asm__ __volatile__ ("" ::: "memory"); + + vfloat16m4_t ay1 = __riscv_vle16_v_f16m4((const _Float16*)y + i + epr, epr); + ay1 = __riscv_vfmul_vf_f16m4(ay1, scale, epr); + __riscv_vse16_v_f16m4((_Float16*)y + i + epr, ay1, epr); + __asm__ __volatile__ ("" ::: "memory"); + } - // leftovers - int vl; - for (int i = np; i < n; i += vl) { - vl = __riscv_vsetvl_e16m4(n - i); - vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); - ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl); - __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); - } + // leftovers + int vl; + for (int i = np; i < n; i += vl) { + vl = __riscv_vsetvl_e16m4(n - i); + vfloat16m4_t ay0 = __riscv_vle16_v_f16m4((const _Float16*)y + i, vl); + ay0 = __riscv_vfmul_vf_f16m4(ay0, scale, vl); + __riscv_vse16_v_f16m4((_Float16*)y + i, ay0, vl); + } + np = n; + #else + // fall to scalar path + const int np = 0; + #endif #elif defined(GGML_SIMD) const int np = (n & ~(GGML_F16_STEP - 1)); @@ -850,17 +845,14 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); } } - - // leftovers - for (int i = np; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); - } #else - // scalar - for (int i = 0; i < n; ++i) { + // scalar path + const int np = 0; +#endif + // scalar and leftovers + for (int i = np; i < n; ++i) { y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); } -#endif } inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } @@ -1026,12 +1018,12 @@ inline static float ggml_gelu_quick_f32(float x) { return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); } -//inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { -// const uint16_t * i16 = (const uint16_t *) x; -// for (int i = 0; i < n; ++i) { -// y[i] = ggml_table_gelu_quick_f16[i16[i]]; -// } -//} +inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { + const uint16_t * i16 = (const uint16_t *) x; + for (int i = 0; i < n; ++i) { + y[i] = ggml_table_gelu_quick_f16[i16[i]]; + } +} #ifdef GGML_GELU_QUICK_FP16 inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { @@ -1050,13 +1042,6 @@ inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * } #endif -inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { - for (int i = 0; i < n; ++i) { - float v = GGML_CPU_FP16_TO_FP32(x[i]); - y[i] = GGML_CPU_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v)))); - } -} - // Sigmoid Linear Unit (SiLU) function inline static float ggml_silu_f32(float x) { return x/(1.0f + expf(-x)); diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index d313c1ac9af..d3953eee962 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND) # 80 == Ampere, asynchronous data loading, faster tensor core instructions # 86 == RTX 3000, needs CUDA v11.1 # 89 == RTX 4000, needs CUDA v11.8 + # 90 == Hopper H100/200, needs CUDA v11.8 # 120 == Blackwell, needs CUDA v12.8, FP4 tensor cores # # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run @@ -33,7 +34,7 @@ if (CUDAToolkit_FOUND) list(APPEND CMAKE_CUDA_ARCHITECTURES 75-virtual 80-virtual 86-real) if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8") - list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real) + list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real 90-virtual) endif() if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8") @@ -64,7 +65,7 @@ if (CUDAToolkit_FOUND) FetchContent_Declare( CCCL GIT_REPOSITORY https://github.com/nvidia/cccl.git - GIT_TAG v3.2.0-rc2 + GIT_TAG v3.2.0 GIT_SHALLOW TRUE ) @@ -116,12 +117,11 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) - file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_CUDA ${SRCS}) + list(APPEND GGML_SOURCES_CUDA + template-instances/fattn-vec-instance-f16-f16.cu + template-instances/fattn-vec-instance-q4_0-q4_0.cu + template-instances/fattn-vec-instance-q8_0-q8_0.cu + template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-cuda @@ -182,6 +182,16 @@ if (CUDAToolkit_FOUND) target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver) endif() + if (GGML_CUDA_NCCL) + find_package(NCCL) + if (NCCL_FOUND) + add_compile_definitions(GGML_USE_NCCL) + target_link_libraries(ggml-cuda PRIVATE NCCL::NCCL) + else() + message(STATUS "Warning: NCCL not found, performance for multiple CUDA GPUs will be suboptimal") + endif() + endif() + set(CUDA_CXX_FLAGS "") set(CUDA_FLAGS -use_fast_math -extended-lambda) diff --git a/ggml/src/ggml-cuda/allreduce.cu b/ggml/src/ggml-cuda/allreduce.cu new file mode 100644 index 00000000000..d56129a227e --- /dev/null +++ b/ggml/src/ggml-cuda/allreduce.cu @@ -0,0 +1,971 @@ +#include "allreduce.cuh" + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + +#include "convert.cuh" +#include "ggml-impl.h" + +#include <algorithm> +#include <cstdlib> +#include <cstring> +#include <limits> + +// --------------------------------------------------------------------------- +// CUDA AllReduce for tensor-parallel inference across two GPUs. +// +// Provides an in-place sum reduction over matching tensors on two CUDA +// devices in the same process. Used by the tensor-split path alongside +// NCCL; targets setups without NVLink, where data is exchanged between the +// GPUs by staging it through pinned host memory over PCIe. +// +// Two reduction strategies are selected per call by tensor size: +// +// * Chunked kernel path (small reductions): a single CUDA kernel both +// stages data through pinned host memory and performs the local sum. +// Cross-GPU synchronization happens *inside the kernel* (busy-wait on +// a host-memory flag), which keeps launch overhead low for the +// latency-sensitive token-generation case. +// +// * Copy-engine path (large reductions): the transfer is split into +// D2H + H2D cudaMemcpyAsync chunks driven by the GPU's copy engine, +// followed by a small device-side add kernel. Cross-GPU +// synchronization happens *outside the kernel*, via CUDA events +// between streams. This keeps the compute engine free while large +// transfers are in flight, which matters for prefill-sized tensors. +// Reductions larger than the per-call inner cap are processed by an +// outer chunker that issues sequential inner calls. +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// Cross-GPU signal mechanism +// +// One int per (slot, rank) pair in pinned host memory. Each AR call writes a +// strictly increasing token (= the AR call number) into its own arrival int. +// The peer spins until its read of the other's arrival int equals the token +// it expects for this call -- a mismatch means the peer hasn't arrived yet. +// Tokens never repeat over realistic call rates (32-bit int wraps in tens of +// days at thousands of ARs/sec), so arrival ints don't need to be reset +// between calls; we initialize once at pipeline init and let the values +// accumulate. +// +// There is exactly one writer (the owning GPU) and one reader (the peer), so +// we don't need atomics. A volatile store paired with __threadfence_system() +// provides the release ordering that makes the D2H writes visible system-wide +// before the arrival token is observed. +// +// atomicAdd_system() requires hostNativeAtomicSupported, which is unavailable +// on PCIe-attached consumer GPUs without NVLink, so the volatile path is the +// portable choice. +// --------------------------------------------------------------------------- + +static __device__ __forceinline__ void ggml_cuda_ar_signal_set(int * p, int token) { + *(volatile int *)p = token; +} +static __device__ __forceinline__ int ggml_cuda_ar_signal_get(const int * p) { + return *(const volatile int *)p; +} + +// Byte spacing between adjacent arrival ints. 64 bytes (one cache line) +// ensures each GPU/block's arrival slot lives on its own line, preventing +// false-sharing stalls on the polling GPU. +static constexpr size_t GGML_CUDA_AR_ARRIVAL_STRIDE = 64; + +// Number of blocks the chunked kernel launches with. Each block stripes a +// disjoint slice of the data and synchronizes through its own arrival-token +// slot so multiple SMs can pump PCIe stores in parallel. +static constexpr int GGML_CUDA_AR_KERNEL_BLOCKS = 8; + +// --------------------------------------------------------------------------- +// Chunked kernel AllReduce -- 2 GPUs, supports float, half, and bfloat16. +// +// Both GPUs run this kernel simultaneously on independent streams. sendbuf +// and recvbuf live in T_dst (the caller's tensor type); host_mine / host_other +// carry data in T_wire (the on-wire type, possibly narrower than T_dst -- e.g. +// T_dst=F32 with T_wire=BF16 halves the bytes pushed across PCIe). When +// T_dst == T_wire the casts below are no-ops. +// +// Each GPU runs three phases: +// +// Phase 1 (all threads): cast sendbuf (T_dst) -> T_wire and store as +// single-instruction-width vectors into host_mine. +// __threadfence_system() commits these writes to host +// memory. +// Phase 2 (thread 0): write token to arrival_mine; spin until +// arrival_other == token. +// Phase 3 (all threads): read T_wire vectors from host_other, cast +// each element to T_dst, and sum with the local +// sendbuf value (also rounded through T_wire so that +// both GPUs truncate identically -- this guarantees +// bit-equivalent results across the two devices). +// +// Multi-block: blocks stripe vectors across (gridDim.x * blockDim.x) global +// threads to keep multiple SMs issuing PCIe stores in parallel. Each block +// has its own arrival-token slot (offset by blockIdx.x * ARRIVAL_STRIDE); +// thread 0 of each block signals/spins on that slot independently of other +// blocks. Tail elements (the leftover < ELEMS_PER_VEC at the end) are +// handled only by block 0 to avoid cross-block writes to the same slots. +// --------------------------------------------------------------------------- +template <typename T_dst, typename T_wire> +static __global__ void ggml_cuda_ar_kernel( + const T_dst * sendbuf, + T_dst * recvbuf, + T_wire * __restrict__ host_mine, + const T_wire * __restrict__ host_other, + int count, + int * arrival_mine, + int * arrival_other, + int token) { + + // Vector unit for the wire type, sized to the arch's widest single-instruction + // copy (16 B on Volta+). Each phase-1 iter writes one vector to host memory; + // each phase-3 iter reads one and produces ELEMS_PER_VEC sums. + constexpr int ELEMS_PER_VEC = ggml_cuda_get_max_cpy_bytes() / sizeof(T_wire); + constexpr int ARRIVAL_INTS = (int)(GGML_CUDA_AR_ARRIVAL_STRIDE / sizeof(int)); + + const int tid = threadIdx.x; + const int nt = blockDim.x; + const int bid = blockIdx.x; + const int gtid = bid * nt + tid; + const int gnt = gridDim.x * nt; + const int count_vec = count / ELEMS_PER_VEC; + const int tail = count_vec * ELEMS_PER_VEC; + + // Phase 1: cast sendbuf (T_dst) -> host_mine (T_wire) and store as vectors. + { + for (int i = gtid; i < count_vec; i += gnt) { + const int off = i * ELEMS_PER_VEC; + T_wire wire[ELEMS_PER_VEC]; + #pragma unroll + for (int k = 0; k < ELEMS_PER_VEC; ++k) { + wire[k] = ggml_cuda_cast<T_wire>(sendbuf[off + k]); + } + ggml_cuda_memcpy_1<sizeof(wire)>(&host_mine[off], wire); + } + if (bid == 0 && tid < count - tail) { + host_mine[tail + tid] = ggml_cuda_cast<T_wire>(sendbuf[tail + tid]); + } + } + + // Commit this block's host writes before signalling. + __threadfence_system(); + __syncthreads(); + + // Phase 2: thread 0 of each block signals on its own arrival slot, then + // spins for the matching slot from peer. Per-block tokens mean blocks + // proceed independently -- no inter-block barrier needed. + if (tid == 0) { + int * my_slot = arrival_mine + bid * ARRIVAL_INTS; + const int * other_slot = arrival_other + bid * ARRIVAL_INTS; + + ggml_cuda_ar_signal_set(my_slot, token); + __threadfence_system(); // make our signal visible system-wide + + while (ggml_cuda_ar_signal_get(other_slot) != token) { +#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + __nanosleep(100); +#else + NO_DEVICE_CODE; +#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA + } + } + + __syncthreads(); + + // Acquire peer's host_other writes (this block's stripe of them). + __threadfence_system(); + + // Phase 3: read peer's T_wire vector, cast both sides through T_wire for + // bit-equivalence, sum in T_dst precision, and write back to recvbuf. + { + for (int i = gtid; i < count_vec; i += gnt) { + const int off = i * ELEMS_PER_VEC; + T_wire wire[ELEMS_PER_VEC]; + ggml_cuda_memcpy_1<sizeof(wire)>(wire, &host_other[off]); + #pragma unroll + for (int k = 0; k < ELEMS_PER_VEC; ++k) { + const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[off + k]); + recvbuf[off + k] = ggml_cuda_cast<T_dst>( + ggml_cuda_cast<float>(d_low) + ggml_cuda_cast<float>(wire[k])); + } + } + if (bid == 0 && tid < count - tail) { + const T_wire d_low = ggml_cuda_cast<T_wire>(sendbuf[tail + tid]); + recvbuf[tail + tid] = ggml_cuda_cast<T_dst>( + ggml_cuda_cast<float>(d_low) + + ggml_cuda_cast<float>(host_other[tail + tid])); + } + } +} + +// Combined load-convert-add kernel. The peer's contribution arrives as T_src +// (which may be a lower-precision type than T_dst when the BF16 round-trip is +// active). For bit-equivalence between the two GPUs, dst is first rounded +// through T_src's precision via ggml_cuda_cast -- peer already truncated its +// own value the same way before sending -- so both sides perform identical +// arithmetic. When T_dst == T_src the round-trip cast is a no-op. +template <typename T_dst, typename T_src> +static __global__ void ggml_cuda_ar_add_kernel( + T_dst * __restrict__ dst, + const T_src * __restrict__ src, + int count) { + const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const int nt = gridDim.x * blockDim.x; + for (int i = tid; i < count; i += nt) { + const T_src d_low = ggml_cuda_cast<T_src>(dst[i]); + dst[i] = ggml_cuda_cast<T_dst>( + ggml_cuda_cast<float>(d_low) + ggml_cuda_cast<float>(src[i])); + } +} + +// --------------------------------------------------------------------------- +// Pipeline structure +// --------------------------------------------------------------------------- + +// Number of slots in the event / arrival ring. Two slots is sufficient: +// lockstep guarantees the two GPUs are at most one AR (or chunk) apart, so +// slot[N%2] is always safe to reuse -- peer has already consumed slot[N%2] +// from AR N-2 by the time we get to AR N. acquire_slot's +// cudaEventSynchronize on ev.ker for both devices makes that consumption +// explicit before we overwrite host_buf[slot] for the new AR. +static constexpr int GGML_CUDA_AR_POOL_SIZE = 2; + +// Maximum chunk size (bytes per GPU) handled by one chunked kernel launch. +// Larger tensors are reduced by issuing multiple chunked launches. +static constexpr size_t GGML_CUDA_AR_MAX_BYTES = 1024 * 1024; // 1 MB + +// Copy-engine path: largest tensor accepted on this path; sets host_large / +// dev_tmp allocation size. +static constexpr size_t GGML_CUDA_AR_COPY_MAX_BYTES = 32 * 1024 * 1024; // 32 MB + +// AR wire size at which the copy-engine path takes over from the chunked- +// kernel path. Override via GGML_CUDA_AR_COPY_THRESHOLD. +static constexpr size_t GGML_CUDA_AR_COPY_THRESHOLD_DEFAULT = 1024 * 1024; // 1 MB +// Per-call CE chunk-size heuristic: chunk_bytes = clamp(nbytes / 4, MIN, MAX). +// The /4 keeps ~4 chunks in flight at any moment (good D2H/H2D overlap with +// the peer); the clamps cover the cases where nbytes/4 is too small (per- +// memcpy fixed cost dominates) or too large (chunk-level pipelining stalls). +// Env var GGML_CUDA_AR_COPY_CHUNK_BYTES can override with a fixed value. +static constexpr size_t GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MIN = 512 * 1024; // 512 KB +static constexpr size_t GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MAX = 2 * 1024 * 1024; // 2 MB +// Absolute floor that an env-var override is allowed to set; this caps the +// per-slot copy-event array. 256 KB -> up to 128 chunks per 32 MB tensor. +static constexpr size_t GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN = 256 * 1024; +static constexpr int GGML_CUDA_AR_COPY_MAX_CHUNKS = + static_cast<int>((GGML_CUDA_AR_COPY_MAX_BYTES + GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN - 1) / + GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN); + +struct ggml_cuda_ar_event_slot { + cudaEvent_t app = nullptr; // upstream computation complete + cudaEvent_t cpy[GGML_CUDA_AR_COPY_MAX_CHUNKS] = {}; // copy-engine D2H chunks complete + cudaEvent_t h2d = nullptr; // copy-engine H2Ds complete (handoff AR stream -> compute stream) + cudaEvent_t ker = nullptr; // AllReduce kernel complete +}; + +// Mapped pinned host allocation: cudaHostAlloc + cudaHostGetDevicePointer +// in one place, with the host handle preserved for cudaFreeHost. Used where +// the CPU never touches the buffer -- only the device reads/writes via the +// mapped device pointer. Required on systems where cudaDevAttrCanUseHost- +// PointerForRegisteredMem is 0 and the host pointer can't be used as a +// device pointer. +struct ggml_cuda_ar_host_mapping { + uint8_t * host = nullptr; // cudaFreeHost handle; also the H-side ptr for cudaMemcpyAsync + uint8_t * dev = nullptr; // device-side pointer for kernels / cudaMemset + + cudaError_t alloc(size_t bytes) { + cudaError_t rc = cudaHostAlloc(reinterpret_cast<void **>(&host), bytes, + cudaHostAllocPortable | cudaHostAllocMapped); + if (rc != cudaSuccess) { + host = nullptr; + return rc; + } + rc = cudaHostGetDevicePointer(reinterpret_cast<void **>(&dev), host, 0); + if (rc != cudaSuccess) { + cudaFreeHost(host); + host = nullptr; + dev = nullptr; + } + return rc; + } + + void free() { + if (host) { + cudaFreeHost(host); + host = nullptr; + dev = nullptr; + } + } +}; + +struct ggml_cuda_ar_pipeline { + int n_devices; + int devices[GGML_CUDA_MAX_DEVICES]; + size_t buf_bytes; // bytes per device in host_buf[] + size_t copy_bytes; // bytes per device in host_large[] / dev_tmp[] + size_t copy_threshold; + size_t copy_chunk_bytes; + size_t bf16_threshold; // tensors >= this size (bytes) are reduced via FP32->BF16 round-trip; 0 disables + uint64_t call_count; + + // Per-device resources. + ggml_cuda_ar_host_mapping host_buf[GGML_CUDA_MAX_DEVICES]; // pinned staging (chunked kernel) + ggml_cuda_ar_host_mapping host_large[GGML_CUDA_MAX_DEVICES]; // pinned staging (copy-engine) + char * dev_tmp[GGML_CUDA_MAX_DEVICES]; // device scratch for copy-engine path + cudaStream_t streams[GGML_CUDA_MAX_DEVICES]; // non-blocking + ggml_cuda_ar_event_slot ev_pool[GGML_CUDA_MAX_DEVICES][GGML_CUDA_AR_POOL_SIZE]; + + // Copy-engine: per-device "I finished reading my peer's host_large" + // event. Indexed by RECORDER device. Recorded same-device on streams[i] + // after stage 2's last H2D from host_large[peer]. Waited cross-device + // by peer's stage-1 stream before the next AR overwrites host_large[peer]. + cudaEvent_t host_large_read_done[GGML_CUDA_MAX_DEVICES]; + bool host_large_read_done_valid; + + // Copy-engine: per-device "my add_kernel is done with dev_tmp" event. + // Recorded on the compute stream after each add_kernel; the AR stream + // waits on it before the next copy_impl's H2D overwrites dev_tmp. Lets us + // single-buffer dev_tmp despite add_kernel running on a separate stream. + cudaEvent_t dev_tmp_kernel_done[GGML_CUDA_MAX_DEVICES]; + bool dev_tmp_kernel_done_valid; + + // Arrival ring: ARRIVAL_STRIDE bytes between adjacent ints. Mapped pinned + // memory; CPU never reads/writes -- only the kernel and cudaMemset. + // Use ggml_cuda_ar_arrival_ptr() to index. + ggml_cuda_ar_host_mapping arrival; +}; + +// Base pointer for the (slot, rank) per-block token block. The kernel adds +// blockIdx.x * (ARRIVAL_STRIDE/sizeof(int)) internally to land on its own slot. +static int * ggml_cuda_ar_arrival_ptr(const ggml_cuda_ar_pipeline * p, int slot, int rank) { + const size_t offset = ((size_t)slot * p->n_devices + rank) * + GGML_CUDA_AR_KERNEL_BLOCKS * GGML_CUDA_AR_ARRIVAL_STRIDE; + return reinterpret_cast<int *>(p->arrival.dev + offset); +} + +static uint64_t ggml_cuda_ar_env_u64(const char * name, uint64_t default_value) { + const char * value = getenv(name); + if (value == nullptr || value[0] == '\0') { + return default_value; + } + + char * end = nullptr; + const unsigned long long parsed = strtoull(value, &end, 10); + return end != value ? (uint64_t) parsed : default_value; +} + +struct ggml_cuda_ar_slot_info { + int slot; + int token; +}; + +static ggml_cuda_ar_slot_info ggml_cuda_ar_acquire_slot(ggml_cuda_ar_pipeline * p) { + const int slot = static_cast<int>(p->call_count % GGML_CUDA_AR_POOL_SIZE); + const bool pool_lapped = p->call_count >= GGML_CUDA_AR_POOL_SIZE; + p->call_count++; + + if (pool_lapped) { + for (int i = 0; i < p->n_devices; ++i) { + ggml_cuda_set_device(p->devices[i]); + CUDA_CHECK(cudaEventSynchronize(p->ev_pool[i][slot].ker)); + } + } + + return { slot, (int) p->call_count }; +} + +// Per-AR copy-engine chunk size: env-var override if set, else heuristic +// (clamp(nbytes/4, HEURISTIC_MIN, HEURISTIC_MAX)). +static size_t ggml_cuda_ar_chunk_bytes(const ggml_cuda_ar_pipeline * p, size_t nbytes) { + if (p->copy_chunk_bytes > 0) { + return p->copy_chunk_bytes; + } + return std::min(GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MAX, + std::max(GGML_CUDA_AR_COPY_CHUNK_BYTES_HEURISTIC_MIN, nbytes / 4)); +} + +static void ggml_cuda_ar_wait_for_compute( + ggml_cuda_ar_pipeline * p, ggml_backend_cuda_context * cuda_ctx, int rank, int slot) { + ggml_cuda_ar_event_slot & ev = p->ev_pool[rank][slot]; + CUDA_CHECK(cudaEventRecord(ev.app, cuda_ctx->stream())); + CUDA_CHECK(cudaStreamWaitEvent(p->streams[rank], ev.app)); +} + +// --------------------------------------------------------------------------- +// Init / free +// --------------------------------------------------------------------------- + +ggml_cuda_ar_pipeline * ggml_cuda_ar_pipeline_init(const int * devices, size_t n_devices) { + + if (n_devices != 2) { + GGML_LOG_DEBUG("%s: internal AllReduce only supports n_devices=2 (got %zu); " + "falling back\n", __func__, n_devices); + return nullptr; + } + + // The chunked kernel uses __nanosleep, which is sm70+ (Volta+). + for (size_t i = 0; i < n_devices; ++i) { + const int cc = ggml_cuda_info().devices[devices[i]].cc; + if (cc < GGML_CUDA_CC_VOLTA) { + GGML_LOG_DEBUG("%s: internal AllReduce requires compute capability >= %d " + "(device %d has cc=%d); falling back\n", + __func__, GGML_CUDA_CC_VOLTA, devices[i], cc); + return nullptr; + } + } + + auto * p = new ggml_cuda_ar_pipeline{}; + p->n_devices = n_devices; + p->copy_bytes = GGML_CUDA_AR_COPY_MAX_BYTES; + p->copy_threshold = ggml_cuda_ar_env_u64("GGML_CUDA_AR_COPY_THRESHOLD", GGML_CUDA_AR_COPY_THRESHOLD_DEFAULT); + // 0 = use the per-call heuristic (default). Non-zero env value forces a + // fixed chunk size for diagnostics, with a floor at COPY_CHUNK_BYTES_MIN. + p->copy_chunk_bytes = ggml_cuda_ar_env_u64("GGML_CUDA_AR_COPY_CHUNK_BYTES", 0); + if (p->copy_chunk_bytes > 0 && p->copy_chunk_bytes < GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN) { + GGML_LOG_WARN("%s: GGML_CUDA_AR_COPY_CHUNK_BYTES=%zu below minimum %zu; clamping\n", + __func__, p->copy_chunk_bytes, GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN); + p->copy_chunk_bytes = GGML_CUDA_AR_COPY_CHUNK_BYTES_MIN; + } + // Default 1: BF16 round-trip is always on for F32 inputs (any non-zero + // ne). Set GGML_CUDA_AR_BF16_THRESHOLD=0 to disable, or to a larger + // byte threshold to opt out for small tensors. + p->bf16_threshold = ggml_cuda_ar_env_u64("GGML_CUDA_AR_BF16_THRESHOLD", 1); + for (size_t i = 0; i < n_devices; ++i) { + p->devices[i] = devices[i]; + } + + // Per-device streams and event pools. + for (size_t i = 0; i < n_devices; ++i) { + ggml_cuda_set_device(p->devices[i]); + + cudaStream_t stream = nullptr; + if (cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaStreamCreateWithFlags failed for device %d\n", + __func__, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + p->streams[i] = stream; + + for (int s = 0; s < GGML_CUDA_AR_POOL_SIZE; ++s) { + bool ok = + cudaEventCreateWithFlags(&p->ev_pool[i][s].app, cudaEventDisableTiming) == cudaSuccess && + cudaEventCreateWithFlags(&p->ev_pool[i][s].h2d, cudaEventDisableTiming) == cudaSuccess && + cudaEventCreateWithFlags(&p->ev_pool[i][s].ker, cudaEventDisableTiming) == cudaSuccess; + for (int c = 0; ok && c < GGML_CUDA_AR_COPY_MAX_CHUNKS; ++c) { + ok = cudaEventCreateWithFlags(&p->ev_pool[i][s].cpy[c], cudaEventDisableTiming) == cudaSuccess; + } + if (!ok) { + GGML_LOG_ERROR("%s: cudaEventCreate failed for device %d slot %d\n", + __func__, p->devices[i], s); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + if (cudaEventCreateWithFlags(&p->host_large_read_done[i], cudaEventDisableTiming) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaEventCreate for host_large_read_done failed for device %d\n", + __func__, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + if (cudaEventCreateWithFlags(&p->dev_tmp_kernel_done[i], cudaEventDisableTiming) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaEventCreate for dev_tmp_kernel_done failed for device %d\n", + __func__, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + // Arrival ring: cache-line padded so each GPU's int is on its own line. + const size_t arrival_bytes = + (size_t)GGML_CUDA_AR_POOL_SIZE * n_devices * + GGML_CUDA_AR_KERNEL_BLOCKS * GGML_CUDA_AR_ARRIVAL_STRIDE; + if (p->arrival.alloc(arrival_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: alloc for arrival ring failed (%zu bytes)\n", + __func__, arrival_bytes); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + ggml_cuda_set_device(p->devices[0]); + if (cudaMemset(p->arrival.dev, 0, arrival_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaMemset for arrival ring failed (%zu bytes)\n", + __func__, arrival_bytes); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + + // Per-device pinned staging buffers -- POOL_SIZE-deep ring so the chunked- + // kernel can write the next slot's data while the peer is still reading + // the previous slot's. Indexed by (slot * buf_bytes) at the call site. + p->buf_bytes = GGML_CUDA_AR_MAX_BYTES; + const size_t host_buf_total = (size_t) GGML_CUDA_AR_POOL_SIZE * p->buf_bytes; + for (size_t i = 0; i < n_devices; ++i) { + if (p->host_buf[i].alloc(host_buf_total) != cudaSuccess) { + GGML_LOG_ERROR("%s: alloc for staging failed (%zu bytes)\n", + __func__, host_buf_total); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + // Copy-engine path: pinned host staging + device scratch, sized for the + // largest tensor we accept on this path (GGML_CUDA_AR_COPY_MAX_BYTES). + // dev_tmp is single-buffered; cross-AR safety is enforced by an explicit + // cross-stream wait in copy_impl on the prior AR's add_kernel-done event. + for (size_t i = 0; i < n_devices; ++i) { + ggml_cuda_set_device(p->devices[i]); + if (p->host_large[i].alloc(p->copy_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: alloc for large staging failed (%zu bytes)\n", + __func__, p->copy_bytes); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + if (cudaMalloc(reinterpret_cast<void **>(&p->dev_tmp[i]), p->copy_bytes) != cudaSuccess) { + GGML_LOG_ERROR("%s: cudaMalloc for copy scratch failed (%zu bytes) on device %d\n", + __func__, p->copy_bytes, p->devices[i]); + ggml_cuda_ar_pipeline_free(p); + return nullptr; + } + } + + GGML_LOG_INFO("%s: initialized AllReduce pipeline: %zu GPUs, " + "%zu KB chunked kernel staging + %zu MB copy-engine staging per GPU\n", + __func__, n_devices, p->buf_bytes >> 10, p->copy_bytes >> 20); + + return p; +} + +void ggml_cuda_ar_pipeline_free(ggml_cuda_ar_pipeline * p) { + if (!p) { + return; + } + + // Drain all in-flight kernels before tearing down resources. + for (int i = 0; i < p->n_devices; ++i) { + if (p->streams[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaStreamSynchronize(p->streams[i]); + } + } + + for (int i = 0; i < p->n_devices; ++i) { + p->host_buf[i].free(); + p->host_large[i].free(); + if (p->dev_tmp[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaFree(p->dev_tmp[i]); + } + ggml_cuda_set_device(p->devices[i]); + for (int s = 0; s < GGML_CUDA_AR_POOL_SIZE; ++s) { + if (p->ev_pool[i][s].app) { cudaEventDestroy(p->ev_pool[i][s].app); } + for (int c = 0; c < GGML_CUDA_AR_COPY_MAX_CHUNKS; ++c) { + if (p->ev_pool[i][s].cpy[c]) { cudaEventDestroy(p->ev_pool[i][s].cpy[c]); } + } + if (p->ev_pool[i][s].h2d) { cudaEventDestroy(p->ev_pool[i][s].h2d); } + if (p->ev_pool[i][s].ker) { cudaEventDestroy(p->ev_pool[i][s].ker); } + } + if (p->host_large_read_done[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaEventDestroy(p->host_large_read_done[i]); + } + if (p->dev_tmp_kernel_done[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaEventDestroy(p->dev_tmp_kernel_done[i]); + } + if (p->streams[i]) { + ggml_cuda_set_device(p->devices[i]); + cudaStreamDestroy(p->streams[i]); + } + } + p->arrival.free(); + delete p; +} + +// --------------------------------------------------------------------------- +// Dispatch +// --------------------------------------------------------------------------- + +// Asymmetric copy_impl: data sent over PCIe in T_src precision (one element of +// nbytes per ne element); accumulated locally into a T_dst buffer. When +// T_src == T_dst this is the original homogeneous reduction. When they differ +// (e.g. BF16 wire / F32 accumulator) the add kernel rounds dst through T_src +// for bit-equivalence between GPUs and we skip the otherwise-needed +// post-conversion entirely. +template <typename T_src, typename T_dst> +static bool ggml_cuda_ar_allreduce_copy_impl( + ggml_cuda_ar_pipeline * p, + ggml_backend_t * backends, + T_src * const src_buf[GGML_CUDA_MAX_DEVICES], + T_dst * const dst_buf[GGML_CUDA_MAX_DEVICES], + const bool compute[GGML_CUDA_MAX_DEVICES], + int64_t ne, + size_t nbytes) { + GGML_ASSERT(p->n_devices == 2); + GGML_ASSERT(nbytes <= p->copy_bytes); + GGML_ASSERT(ne <= std::numeric_limits<int>::max()); + + const size_t chunk_bytes = ggml_cuda_ar_chunk_bytes(p, nbytes); + GGML_ASSERT(chunk_bytes > 0); + + const int slot = ggml_cuda_ar_acquire_slot(p).slot; + const size_t copy_chunks = (nbytes + chunk_bytes - 1) / chunk_bytes; + GGML_ASSERT(copy_chunks <= GGML_CUDA_AR_COPY_MAX_CHUNKS); + + ggml_backend_cuda_context * cuda_ctx[2] = {}; + + // Stage 1: both GPUs copy their local contribution to pinned host memory. + for (int i = 0; i < 2; ++i) { + ggml_cuda_set_device(p->devices[i]); + cuda_ctx[i] = static_cast<ggml_backend_cuda_context *>(backends[i]->context); + GGML_ASSERT(cuda_ctx[i]->device == p->devices[i]); + + ggml_cuda_ar_wait_for_compute(p, cuda_ctx[i], i, slot); + + // Wait for peer's H2D from our host_large[i] (recorded in the + // previous AR's stage 2) to complete before we overwrite host_large[i]. + // host_large_read_done[peer] = peer finished reading host_large[i]. + // No-op on the first AR -- no prior record exists. + if (p->host_large_read_done_valid) { + const int peer = 1 - i; + CUDA_CHECK(cudaStreamWaitEvent(p->streams[i], p->host_large_read_done[peer])); + } + + if (!compute[i]) { + CUDA_CHECK(cudaMemsetAsync(src_buf[i], 0, nbytes, p->streams[i])); + } + + for (size_t c = 0; c < copy_chunks; ++c) { + const size_t offset = c * chunk_bytes; + const size_t this_bytes = (nbytes - offset) < chunk_bytes ? + (nbytes - offset) : chunk_bytes; + + CUDA_CHECK(cudaMemcpyAsync( + p->host_large[i].host + offset, reinterpret_cast<char *>(src_buf[i]) + offset, this_bytes, + cudaMemcpyDeviceToHost, p->streams[i])); + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].cpy[c], p->streams[i])); + } + } + + // Stage 2: each GPU waits for each peer D2H chunk, pulls that chunk back to + // local device scratch (dev_tmp), then performs one device-local add over + // the assembled peer tensor. The H2Ds run on the AR stream (copy engine) + // and the add_kernel runs on the caller's compute stream, so the AR stream + // stays pure-copy and avoids an in-stream copy->compute engine switch every + // AR. dev_tmp is single-buffered: the AR stream waits cross-stream on the + // prior AR's add_kernel-done event before overwriting it. + for (int i = 0; i < 2; ++i) { + const int peer = 1 - i; + ggml_cuda_set_device(p->devices[i]); + + // Wait for the previous AR's add_kernel (on the compute stream) to + // finish reading dev_tmp before our H2D overwrites it. No-op on the + // first copy_impl call. + if (p->dev_tmp_kernel_done_valid) { + CUDA_CHECK(cudaStreamWaitEvent(p->streams[i], p->dev_tmp_kernel_done[i])); + } + + for (size_t c = 0; c < copy_chunks; ++c) { + const size_t offset = c * chunk_bytes; + const size_t this_bytes = (nbytes - offset) < chunk_bytes ? + (nbytes - offset) : chunk_bytes; + + CUDA_CHECK(cudaStreamWaitEvent(p->streams[i], p->ev_pool[peer][slot].cpy[c])); + CUDA_CHECK(cudaMemcpyAsync( + p->dev_tmp[i] + offset, p->host_large[peer].host + offset, this_bytes, + cudaMemcpyHostToDevice, p->streams[i])); + } + + // Mark our reads of host_large[peer] complete so peer's next AR can + // safely overwrite it. + CUDA_CHECK(cudaEventRecord(p->host_large_read_done[i], p->streams[i])); + + // Hand off from AR stream (copy engine) to compute stream: compute + // stream waits for all H2Ds to finish, then runs the add_kernel. + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].h2d, p->streams[i])); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx[i]->stream(), p->ev_pool[i][slot].h2d)); + + const int block_size = 256; + int n_blocks = (int) ((ne + block_size - 1) / block_size); + if (n_blocks > 1024) { + n_blocks = 1024; + } + ggml_cuda_ar_add_kernel<T_dst, T_src><<<n_blocks, block_size, 0, cuda_ctx[i]->stream()>>>( + dst_buf[i], + reinterpret_cast<const T_src *>(p->dev_tmp[i]), + (int) ne); + CUDA_CHECK(cudaGetLastError()); + + // Record dev_tmp-released on the compute stream so the next copy_impl + // can wait for the kernel to finish before overwriting dev_tmp. Also + // record AR-done as ev.ker for acquire_slot's pool-wraparound sync. + CUDA_CHECK(cudaEventRecord(p->dev_tmp_kernel_done[i], cuda_ctx[i]->stream())); + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].ker, cuda_ctx[i]->stream())); + } + p->host_large_read_done_valid = true; + p->dev_tmp_kernel_done_valid = true; + + return true; +} + +// Outer-level chunker: copy_impl handles up to copy_bytes per call (limited by +// the host_large / dev_tmp allocation size). When the full AR exceeds that, +// slice the tensor into copy_bytes-sized pieces and call copy_impl repeatedly. +// Each slice goes through its own stage 1 -> stage 2 cycle and acquires its own +// slot, so cross-AR fences and pool wraparound work the same way as for any +// other sequence of small ARs. +template <typename T_src, typename T_dst> +static bool ggml_cuda_ar_allreduce_copy_outer( + ggml_cuda_ar_pipeline * p, + ggml_backend_t * backends, + T_src * const src_buf[GGML_CUDA_MAX_DEVICES], + T_dst * const dst_buf[GGML_CUDA_MAX_DEVICES], + const bool compute[GGML_CUDA_MAX_DEVICES], + int64_t ne) { + const int64_t outer_max_elems = (int64_t) (p->copy_bytes / sizeof(T_src)); + GGML_ASSERT(outer_max_elems > 0); + + bool ok = true; + for (int64_t outer_start = 0; outer_start < ne && ok; outer_start += outer_max_elems) { + const int64_t outer_ne = std::min(outer_max_elems, ne - outer_start); + const size_t outer_nbytes = (size_t) outer_ne * sizeof(T_src); + + T_src * src[GGML_CUDA_MAX_DEVICES] = {}; + T_dst * dst[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < p->n_devices; ++i) { + src[i] = src_buf[i] + outer_start; + dst[i] = dst_buf[i] + outer_start; + } + ok = ggml_cuda_ar_allreduce_copy_impl<T_src, T_dst>( + p, backends, src, dst, compute, outer_ne, outer_nbytes); + } + return ok; +} + +bool ggml_cuda_ar_allreduce( + ggml_cuda_ar_pipeline * p, + ggml_backend_t * backends, + ggml_tensor ** tensors) { + GGML_ASSERT(p != nullptr); + + const int n = p->n_devices; + GGML_ASSERT(n == 2); + + const ggml_type input_type = tensors[0]->type; + GGML_ASSERT(input_type == GGML_TYPE_F32 || input_type == GGML_TYPE_F16 || input_type == GGML_TYPE_BF16); + + const int64_t ne = ggml_nelements(tensors[0]); + GGML_ASSERT(ne > 0); + + const size_t input_nbytes = ggml_nbytes(tensors[0]); + + // BF16 round-trip: F32 inputs >= bf16_threshold are converted to BF16 for + // the reduction (chunked or copy-engine), halving on-wire bytes. Matches + // NCCL's behaviour. The pre-conversion zeroes inactive shards so the + // inner paths see them as already-prepared compute tensors. + const bool use_bf16 = + input_type == GGML_TYPE_F32 && + p->bf16_threshold > 0 && + input_nbytes >= p->bf16_threshold; + + const ggml_type kernel_type = use_bf16 ? GGML_TYPE_BF16 : input_type; + const size_t type_size = ggml_type_size(kernel_type); + GGML_ASSERT(p->buf_bytes >= type_size); + const size_t nbytes = (size_t) ne * type_size; + + bool compute_flag[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + compute_flag[i] = (tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) != 0; + } + + // Decide between copy-engine and chunked kernel paths based on the working + // type's actual byte count. No upper bound: copy_outer slices reductions + // larger than copy_bytes into copy_bytes-sized pieces. + const bool use_copy_engine = + p->copy_threshold > 0 && + nbytes >= p->copy_threshold; + + // BF16 inactive-shard zeroing: when use_bf16 is on, the combined kernel + // (chunked kernel path) and the combined add kernel (copy_engine path) + // both accumulate into the F32 tensor data directly, so an inactive + // shard's accumulator must start at zero. + if (use_bf16) { + for (int i = 0; i < n; ++i) { + if (!compute_flag[i]) { + auto * cuda_ctx = static_cast<ggml_backend_cuda_context *>(backends[i]->context); + GGML_ASSERT(cuda_ctx->device == p->devices[i]); + ggml_cuda_set_device(p->devices[i]); + CUDA_CHECK(cudaMemsetAsync(tensors[i]->data, 0, (size_t) ne * sizeof(float), cuda_ctx->stream())); + } + } + } + + // Pre-convert F32 -> BF16 into bf16_tmp ONLY for the copy_engine + use_bf16 + // path; the chunked kernel path's combined kernel does the conversion + // inline as it writes to host_buf. + ggml_cuda_pool_alloc<nv_bfloat16> bf16_tmp[GGML_CUDA_MAX_DEVICES]; + void * copy_src_ptr[GGML_CUDA_MAX_DEVICES] = {}; + + if (use_copy_engine && use_bf16) { + to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32); + for (int i = 0; i < n; ++i) { + auto * cuda_ctx = static_cast<ggml_backend_cuda_context *>(backends[i]->context); + GGML_ASSERT(cuda_ctx->device == p->devices[i]); + bf16_tmp[i].pool = &cuda_ctx->pool(); + bf16_tmp[i].alloc(ne); + ggml_cuda_set_device(p->devices[i]); + if (compute_flag[i]) { + to_bf16(tensors[i]->data, bf16_tmp[i].get(), ne, cuda_ctx->stream()); + CUDA_CHECK(cudaGetLastError()); + } else { + CUDA_CHECK(cudaMemsetAsync(bf16_tmp[i].get(), 0, nbytes, cuda_ctx->stream())); + } + copy_src_ptr[i] = bf16_tmp[i].get(); + } + } + + bool ok = true; + if (use_copy_engine) { + // After up-front BF16 conversion, the tmp buffers already hold the + // (possibly zeroed-for-inactive) data, so the inner path can treat + // every shard as compute. + bool inner_compute[GGML_CUDA_MAX_DEVICES]; + for (int i = 0; i < n; ++i) { + inner_compute[i] = use_bf16 ? true : compute_flag[i]; + } + + // Dispatch into copy_impl with explicit src/dst types. When use_bf16 + // is on, the wire type is BF16 (src = bf16_tmp) and the accumulator + // is F32 (dst = tensors[i]->data); the combined add kernel rounds dst + // through BF16 for bit-equivalence and writes F32 directly, so no + // post-conversion is needed. Otherwise src == dst (same native type). + if (use_bf16) { + GGML_ASSERT(kernel_type == GGML_TYPE_BF16); + nv_bfloat16 * src[GGML_CUDA_MAX_DEVICES] = {}; + float * dst[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + src[i] = static_cast<nv_bfloat16 *>(copy_src_ptr[i]); + dst[i] = static_cast<float *>(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer<nv_bfloat16, float>( + p, backends, src, dst, inner_compute, ne); + } else { + switch (kernel_type) { + case GGML_TYPE_F32: { + float * buf[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + buf[i] = static_cast<float *>(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer<float, float>( + p, backends, buf, buf, inner_compute, ne); + break; + } + case GGML_TYPE_BF16: { + nv_bfloat16 * buf[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + buf[i] = static_cast<nv_bfloat16 *>(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer<nv_bfloat16, nv_bfloat16>( + p, backends, buf, buf, inner_compute, ne); + break; + } + case GGML_TYPE_F16: { + half * buf[GGML_CUDA_MAX_DEVICES] = {}; + for (int i = 0; i < n; ++i) { + buf[i] = static_cast<half *>(tensors[i]->data); + } + ok = ggml_cuda_ar_allreduce_copy_outer<half, half>( + p, backends, buf, buf, inner_compute, ne); + break; + } + default: + GGML_ASSERT(false); + } + } + } else { + // host_buf carries T_wire-typed data; max_chunk_elems is the count that + // fits in one host_buf at the wire size. + const size_t max_chunk_elems = p->buf_bytes / type_size; + const size_t input_type_size = ggml_type_size(input_type); + + // Chunked kernel path runs entirely on the caller's compute stream: + // since AR is a barrier here, same-stream ordering subsumes any + // cross-stream event handshake that the copy-engine path needs, and + // skips the cross-stream scheduling overhead that was hurting the + // small-tensor (tg) latency on the AR-stream variant. Only ev.ker is + // still recorded at end-of-AR for acquire_slot's pool-wraparound check. + for (int64_t chunk_start = 0; chunk_start < ne; chunk_start += (int64_t) max_chunk_elems) { + const size_t remaining_elems = (size_t) (ne - chunk_start); + const size_t chunk_elems = remaining_elems < max_chunk_elems ? remaining_elems : max_chunk_elems; + const size_t chunk_dst_bytes = chunk_elems * input_type_size; + + const auto [slot, token] = ggml_cuda_ar_acquire_slot(p); + const bool last_chunk = chunk_start + (int64_t) chunk_elems == ne; + + for (int i = 0; i < n; ++i) { + const int peer = 1 - i; // valid for n == 2 only + ggml_cuda_set_device(p->devices[i]); + auto * cuda_ctx = static_cast<ggml_backend_cuda_context *>(backends[i]->context); + GGML_ASSERT(cuda_ctx->device == p->devices[i]); + cudaStream_t stream = cuda_ctx->stream(); + + char * data = static_cast<char *>(tensors[i]->data) + chunk_start * (int64_t) input_type_size; + + // Match NCCL/meta-backend semantics: inactive shards contribute + // zeros. On the BF16 path the F32 tensor data was already + // zeroed up-front (above), so per-chunk zeroing isn't needed. + if (!compute_flag[i] && !use_bf16) { + CUDA_CHECK(cudaMemsetAsync(data, 0, chunk_dst_bytes, stream)); + } + +#define LAUNCH_AR_KERNEL(T_dst, T_wire) \ + ggml_cuda_ar_kernel<T_dst, T_wire><<<dim3(GGML_CUDA_AR_KERNEL_BLOCKS), dim3(256), 0, stream>>>( \ + reinterpret_cast<const T_dst *>(data), \ + reinterpret_cast<T_dst *>(data), \ + reinterpret_cast<T_wire *>(p->host_buf[i].dev + (size_t) slot * p->buf_bytes), \ + reinterpret_cast<const T_wire *>(p->host_buf[peer].dev + (size_t) slot * p->buf_bytes), \ + static_cast<int>(chunk_elems), \ + ggml_cuda_ar_arrival_ptr(p, slot, i), \ + ggml_cuda_ar_arrival_ptr(p, slot, peer), \ + token) + + if (use_bf16) { + GGML_ASSERT(input_type == GGML_TYPE_F32); + LAUNCH_AR_KERNEL(float, nv_bfloat16); + } else { + switch (input_type) { + case GGML_TYPE_F32: LAUNCH_AR_KERNEL(float, float); break; + case GGML_TYPE_F16: LAUNCH_AR_KERNEL(half, half); break; + case GGML_TYPE_BF16: LAUNCH_AR_KERNEL(nv_bfloat16, nv_bfloat16); break; + default: GGML_ASSERT(false); + } + } + +#undef LAUNCH_AR_KERNEL + CUDA_CHECK(cudaGetLastError()); + + if (last_chunk) { + CUDA_CHECK(cudaEventRecord(p->ev_pool[i][slot].ker, stream)); + } + } + } + } + + return ok; +} + +#else // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) + +// HIP and MUSA lack the host-mapped pinned-memory APIs (cudaHostAllocPortable +// / cudaHostAllocMapped / cudaHostGetDevicePointer) and __nanosleep that this +// implementation relies on, so the internal AllReduce is a CUDA-only feature. +// The dispatcher in ggml-cuda.cu treats a nullptr pipeline as "init failed" +// and silently falls back to the meta backend's generic AllReduce. +ggml_cuda_ar_pipeline * ggml_cuda_ar_pipeline_init(const int *, size_t) { + return nullptr; +} +void ggml_cuda_ar_pipeline_free(ggml_cuda_ar_pipeline *) { +} +bool ggml_cuda_ar_allreduce(ggml_cuda_ar_pipeline *, ggml_backend_t *, ggml_tensor **) { + return false; +} + +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) diff --git a/ggml/src/ggml-cuda/allreduce.cuh b/ggml/src/ggml-cuda/allreduce.cuh new file mode 100644 index 00000000000..0f2c9518d5d --- /dev/null +++ b/ggml/src/ggml-cuda/allreduce.cuh @@ -0,0 +1,29 @@ +#pragma once + +#include "common.cuh" +#include "ggml-backend-impl.h" + +#include <cstddef> + +// Opaque pipeline context -- owns all pinned buffers, streams, and events. +struct ggml_cuda_ar_pipeline; + +// Allocate a pipeline for n_devices GPUs. +// devices[] holds the CUDA device IDs in rank order. +// Returns nullptr on allocation failure. +ggml_cuda_ar_pipeline * ggml_cuda_ar_pipeline_init( + const int * devices, size_t n_devices); + +// Release all resources owned by the pipeline. +void ggml_cuda_ar_pipeline_free(ggml_cuda_ar_pipeline * pipeline); + +// Execute an in-place AllReduce (sum) across tensors[0..n_devices-1]. +// tensors[i] must live on the device managed by backends[i] and be +// contiguous F32, F16, or BF16. +// Preconditions are checked by the CUDA comm dispatcher before calling this. +// Returns true once the reduction work has been enqueued successfully. +bool ggml_cuda_ar_allreduce( + ggml_cuda_ar_pipeline * pipeline, + ggml_backend_t * backends, + ggml_tensor ** tensors); + diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index 57c8a99a286..c4f08091e79 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -2,6 +2,10 @@ #ifdef GGML_CUDA_USE_CUB # include <cub/cub.cuh> +# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 1) +# define STRIDED_ITERATOR_AVAILABLE +# include <cuda/iterator> +# endif using namespace cub; #endif // GGML_CUDA_USE_CUB @@ -14,12 +18,14 @@ static __global__ void init_indices(int * indices, const int ncols, const int nr } } +#ifndef STRIDED_ITERATOR_AVAILABLE static __global__ void init_offsets(int * offsets, const int ncols, const int nrows) { const int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx <= nrows) { offsets[idx] = idx * ncols; } } +#endif // STRIDED_ITERATOR_AVAILABLE #ifdef GGML_CUDA_USE_CUB void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, @@ -31,42 +37,70 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, cudaStream_t stream) { ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows); - ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1); int * temp_indices = temp_indices_alloc.get(); float * temp_keys = temp_keys_alloc.get(); - int * d_offsets = offsets_alloc.get(); static const int block_size = 256; const dim3 grid_size((ncols + block_size - 1) / block_size, nrows); init_indices<<<grid_size, block_size, 0, stream>>>(temp_indices, ncols, nrows); - const dim3 offset_grid((nrows + block_size - 1) / block_size); - init_offsets<<<offset_grid, block_size, 0, stream>>>(d_offsets, ncols, nrows); - +#ifdef STRIDED_ITERATOR_AVAILABLE + auto offset_iterator = cuda::make_strided_iterator(cuda::make_counting_iterator(0), ncols); +#else + // offset_iterator needs to populate nrows + 1 elements, so we also have to ceildiv nrows + 1 by block_size + const int nrows_offset = nrows + 1; + ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows_offset); + int * offset_iterator = offsets_alloc.get(); + const dim3 offset_grid((nrows_offset + block_size - 1) / block_size); + init_offsets<<<offset_grid, block_size, 0, stream>>>(offset_iterator, ncols, nrows); +#endif CUDA_CHECK(cudaMemcpyAsync(temp_keys, x, ncols * nrows * sizeof(float), cudaMemcpyDeviceToDevice, stream)); size_t temp_storage_bytes = 0; + bool is_capturing = false; +#ifdef USE_CUDA_GRAPH + // Currently (confirmed for CCCL <= 3.2) DeviceSegmentedSort does not support stream capture, while DeviceSegmentedRadixSort does. + // See https://github.com/NVIDIA/cccl/issues/5661#issuecomment-3229037149 + // TODO: constrain this to the CCCL versions that have this issue once it's resolved in a future CCCL release. + cudaStreamCaptureStatus capture_status; + CUDA_CHECK(cudaStreamIsCapturing(stream, &capture_status)); + is_capturing = (capture_status != cudaStreamCaptureStatusNone); +#endif // USE_CUDA_GRAPH + if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + CUDA_CHECK(DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs( + nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols * nrows, nrows, // num items, num segments - d_offsets, d_offsets + 1, stream); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + offset_iterator, offset_iterator + 1, stream)); } } else { if (nrows == 1) { - DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, stream)); } } @@ -75,22 +109,33 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, if (order == GGML_SORT_ORDER_ASC) { if (nrows == 1) { - DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + CUDA_CHECK(DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + CUDA_CHECK(DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, offset_iterator, + offset_iterator + 1, stream)); } } else { if (nrows == 1) { - DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols, 0, sizeof(float) * 8, stream); + CUDA_CHECK(DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream)); + } else if (is_capturing) { + CUDA_CHECK(DeviceSegmentedRadixSort::SortPairsDescending( + d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, 0, sizeof(float) * 8, stream)); } else { - DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, - stream); + CUDA_CHECK(DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, + temp_keys, temp_indices, dst, ncols * nrows, nrows, + offset_iterator, offset_iterator + 1, stream)); } } } diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 0e6d777b1e6..c25f42b32bb 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -2,6 +2,9 @@ #include <cstdint> #include <utility> +template<typename T, size_t> +using type_for_index = T; + static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; GGML_UNUSED(a); @@ -39,16 +42,20 @@ static __global__ void k_bin_bcast(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + /*const int s0,*/ + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { + ggml_cuda_pdl_lc(); const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x; const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y); const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3); @@ -69,14 +76,15 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; + ggml_cuda_pdl_sync(); for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { const uint32_t i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } dst_row[i0] = (dst_t) result; @@ -101,13 +109,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + /*const int s0,*/ + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -135,11 +146,12 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const int i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + ggml_cuda_pdl_sync(); + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } dst_row[i0] = (dst_t) result; @@ -179,7 +191,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * cnb[3] *= cne[3]; }; - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { for (int i = 0; i < 4; i++) { if (nr[i] != 1) { break; @@ -221,7 +233,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * size_t nb12 = cnb1[2]; size_t nb13 = cnb1[3]; - size_t s0 = nb0 / sizeof(dst_t); + //size_t s0 = nb0 / sizeof(dst_t); size_t s1 = nb1 / sizeof(dst_t); size_t s2 = nb2 / sizeof(dst_t); size_t s3 = nb3 / sizeof(dst_t); @@ -251,10 +263,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s00 == 1); - GGML_ASSERT(s10 == 1); - const int block_size = 128; int64_t hne0 = std::max(ne0 / 2LL, 1LL); @@ -280,35 +288,24 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1); const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2); - if constexpr (sizeof...(I) > 0) { - k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>( + { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)block_num, block_size, 0, stream); + ggml_cuda_kernel_launch(k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params, src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); - } else { - k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t> - <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, - ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } } else { const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); - if constexpr (sizeof...(I) > 0) { - k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>( + { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(k_bin_bcast<bin_op, src0_t, src1_t, dst_t, type_for_index<const src1_t *, I>...>, launch_params, src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); - } else { - k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>( - src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } } } @@ -331,6 +328,7 @@ static __global__ void k_repeat_back( } T sum = 0; + ggml_cuda_pdl_sync(); for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) { for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { @@ -470,6 +468,36 @@ void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, } } +void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) { + GGML_ASSERT(2 <= n_fuse && n_fuse <= 8); + + switch (n_fuse) { + case 2: + ggml_cuda_op_fused_binbcast_impl<op_mul, 2>(ctx, dst); + break; + case 3: + ggml_cuda_op_fused_binbcast_impl<op_mul, 3>(ctx, dst); + break; + case 4: + ggml_cuda_op_fused_binbcast_impl<op_mul, 4>(ctx, dst); + break; + case 5: + ggml_cuda_op_fused_binbcast_impl<op_mul, 5>(ctx, dst); + break; + case 6: + ggml_cuda_op_fused_binbcast_impl<op_mul, 6>(ctx, dst); + break; + case 7: + ggml_cuda_op_fused_binbcast_impl<op_mul, 7>(ctx, dst); + break; + case 8: + ggml_cuda_op_fused_binbcast_impl<op_mul, 8>(ctx, dst); + break; + default: + GGML_ASSERT(false && "Unsupported n_fuse value"); + } +} + void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 62bc950111b..12624785b44 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -9,3 +9,4 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); +void ggml_cuda_op_fused_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 9516d8ec8f9..e6e50e04119 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -5,7 +5,9 @@ #include "ggml-cuda.h" #include <cstdint> +#include <cstdlib> #include <memory> +#include <mutex> #if defined(GGML_USE_HIP) #define GGML_COMMON_DECL_HIP @@ -27,6 +29,7 @@ #include <cstdio> #include <string> #include <unordered_map> +#include <utility> #include <vector> #if defined(GGML_USE_HIP) @@ -50,9 +53,11 @@ #define GGML_CUDA_CC_TURING 750 #define GGML_CUDA_CC_AMPERE 800 #define GGML_CUDA_CC_ADA_LOVELACE 890 +#define GGML_CUDA_CC_HOPPER 900 // While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see // https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms #define GGML_CUDA_CC_BLACKWELL 1200 +#define GGML_CUDA_CC_DGX_SPARK 1210 #define GGML_CUDA_CC_RUBIN 1300 #define GGML_CUDA_CC_OFFSET_AMD 0x1000000 #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000 @@ -64,8 +69,9 @@ #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a #define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers -#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing +#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x90a) // MI210 (gfx90a), minimum acc register renaming #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 +#define GGML_CUDA_CC_CDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x950) // MI350X/MI355X // RDNA removes MFMA, dp4a, xnack, acc registers, wave size is 32 #define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000 @@ -86,7 +92,8 @@ #define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) #define GGML_CUDA_CC_IS_CDNA1(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_CDNA2) #define GGML_CUDA_CC_IS_CDNA2(cc) (cc >= GGML_CUDA_CC_CDNA2 && cc < GGML_CUDA_CC_CDNA3) -#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_CDNA4) +#define GGML_CUDA_CC_IS_CDNA4(cc) (cc >= GGML_CUDA_CC_CDNA4 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads #define MUSART_HMASK 40300 // MUSA rc4.3, min. ver. for half2 -> uint mask comparisons @@ -104,6 +111,27 @@ # define GGML_CUDA_USE_CUB #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 +// PDL host-side support (cudaLaunchKernelEx) requires CUDART >= 11.8. +// However, this has been bugged in CTK < 12.3 for MSVC builds, see +// https://github.com/ggml-org/llama.cpp/pull/22522#discussion_r3302393293 +// __CUDA_ARCH__ is undefined in host passes; GPU arch check happens in device-side code. +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && \ + (CUDART_VERSION >= 12030 || (!(defined(_MSC_VER) && !defined(__clang__)) && CUDART_VERSION >= 11080)) +# define GGML_CUDA_USE_PDL +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && (CUDART_VERSION >= 12030 || (!(defined(_MSC_VER) && !defined(__clang__)) && CUDART_VERSION >= 11080)) + +static __device__ __forceinline__ void ggml_cuda_pdl_sync() { +#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + cudaGridDependencySynchronize(); +#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER +} + +static __device__ __forceinline__ void ggml_cuda_pdl_lc() { +#if defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + cudaTriggerProgrammaticLaunchCompletion(); +#endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER +} + #ifdef __CUDA_ARCH_LIST__ constexpr bool ggml_cuda_has_arch_impl(int) { return false; @@ -162,6 +190,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString) + #if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA) static const char * cublas_get_error_str(const cublasStatus_t err) { return cublasGetStatusString(err); @@ -185,6 +214,10 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str) +#ifdef GGML_USE_NCCL +#define NCCL_CHECK(err) CUDA_CHECK_GEN(err, ncclSuccess, ncclGetErrorString) +#endif // GGML_USE_NCCL + #if !defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM) static const char * cu_get_error_str(CUresult err) { const char * err_str; @@ -526,6 +559,86 @@ static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) { #endif // FP16_AVAILABLE } +enum class block_reduce_method { + MAX, + SUM, +}; + +template<block_reduce_method method_t, typename T> +struct block_reduce_policy; + +template <typename T, typename... Ts> +inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...); + +template<typename...> +inline constexpr bool ggml_cuda_dependent_false_v = false; + +template <typename T> struct block_reduce_policy<block_reduce_method::SUM, T> { + static __device__ T reduce(T val) { + if constexpr(is_any<T, float, float2, half2, int>) { + return warp_reduce_sum(val); + } else { + static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum"); + } + } + + static __device__ T sentinel() { + if constexpr (std::is_same_v<T, float>) { + return 0.0f; + } else if constexpr (std::is_same_v<T, float2>) { + return make_float2(0.0f, 0.0f); + } else if constexpr (std::is_same_v<T, half2>) { + return make_half2(0.0f, 0.0f); + } else if constexpr (std::is_same_v<T, int>) { + return 0; + } else { + static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce sum"); + } + } +}; + +template <typename T> struct block_reduce_policy<block_reduce_method::MAX, T> { + static __device__ T reduce(T val) { + if constexpr (is_any<T, float, half2>) { + return warp_reduce_max(val); + } else { + static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max"); + } + } + + static __device__ T sentinel() { + if constexpr (std::is_same_v<T, float>) { + return -INFINITY; + } else if constexpr (std::is_same_v<T, half2>) { + return make_half2(-INFINITY, -INFINITY); + } else { + static_assert(ggml_cuda_dependent_false_v<T>, "Unsupported type for block reduce max"); + } + } +}; + +template <block_reduce_method reduce_method_t, const unsigned int block_size_template = 0, typename T> +static __device__ T block_reduce(T val, T * shared_vals) { + val = block_reduce_policy<reduce_method_t, T>::reduce(val); + const unsigned int block_size = block_size_template == 0 ? blockDim.x : block_size_template; + if (block_size > WARP_SIZE) { + assert((block_size <= 1024) && (block_size % WARP_SIZE) == 0); + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + shared_vals[warp_id] = val; + } + __syncthreads(); + val = block_reduce_policy<reduce_method_t, T>::sentinel(); + if (lane_id < (static_cast<int>(block_size) / WARP_SIZE)) { + val = shared_vals[lane_id]; + } + return block_reduce_policy<reduce_method_t, T>::reduce(val); + } + + return val; +} + static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) { #ifdef FP16_AVAILABLE @@ -714,6 +827,47 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { #endif // CUDART_VERSION >= 12050 } +static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) { +#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 + // ROCm does not support fp8 in software on devices with fp8 hardware, + // but CDNA3 supports only e4m3_fnuz (no inf). + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. + const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast<const __hip_fp8_e4m3_fnuz *>(&bits); + return static_cast<float>(xf) / 2; +#else +#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP) + const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation. + const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits); + return static_cast<float>(xf) / 2; +#else + if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f + return 0.0f; + } + const int exp = (x >> 3) & 0xF; + const int man = x & 0x7; + float raw; + if (exp == 0) { + raw = ldexpf((float) man, -9); + } else { + raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7); + } + return static_cast<float>(raw / 2); +#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP) +#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000 +} + +static __device__ __forceinline__ uint8_t ggml_cuda_fp32_to_ue4m3(float x) { +#if defined(BLACKWELL_MMA_AVAILABLE) // This is used for NVFP4 subblock scale quantizations only + if (!(x > 0.0f)) { + return 0; + } + const __nv_fp8_e4m3 xf(x); + return xf.__x; +#else + NO_DEVICE_CODE; // Used only for NVFP4 Scales for Activations, only for Blackwell +#endif // defined(BLACKWELL_MMA_AVAILABLE) +} + __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) { const uint8_t sign_bit = (x < 0.0f) << 3; float ax = fabsf(x) * e; @@ -804,6 +958,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_F16> { static constexpr int qr = 1; }; +template<> +struct ggml_cuda_type_traits<GGML_TYPE_Q1_0> { + static constexpr int qk = QK1_0; + static constexpr int qr = QR1_0; + static constexpr int qi = QI1_0; +}; + template<> struct ggml_cuda_type_traits<GGML_TYPE_Q4_0> { static constexpr int qk = QK4_0; @@ -846,6 +1007,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> { static constexpr int qi = QI_MXFP4; }; +template<> +struct ggml_cuda_type_traits<GGML_TYPE_NVFP4> { + static constexpr int qk = QK_NVFP4; + static constexpr int qr = QR_NVFP4; + static constexpr int qi = QI_NVFP4; +}; + template<> struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> { static constexpr int qk = QK_K; @@ -1036,15 +1204,6 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_cuda_graph_node_properties { - void * node_address; - ggml_op node_op; - int64_t ne[GGML_MAX_DIMS]; - size_t nb[GGML_MAX_DIMS]; - void * src_address[GGML_MAX_SRC]; - int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; -}; - struct ggml_cuda_graph { #ifdef USE_CUDA_GRAPH ~ggml_cuda_graph() { @@ -1060,25 +1219,20 @@ struct ggml_cuda_graph { size_t num_nodes = 0; std::vector<cudaGraphNode_t> nodes; bool disable_due_to_gpu_arch = false; - bool disable_due_to_too_many_updates = false; - int number_consecutive_updates = 0; - std::vector<ggml_cuda_graph_node_properties> props; - - void record_update(bool use_graph, bool update_required) { - if (use_graph && update_required) { - number_consecutive_updates++; - } else { - number_consecutive_updates = 0; - } - if (number_consecutive_updates >= 4) { - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); - disable_due_to_too_many_updates = true; - } - } + bool warmup_complete = false; + uint64_t uid = 0; + int64_t last_used_time = 0; + struct node_properties { + ggml_tensor node; + void * node_src_data_ptrs[GGML_MAX_SRC]; + int64_t node_src_ne[GGML_MAX_SRC][GGML_MAX_DIMS]; + size_t node_src_nb[GGML_MAX_SRC][GGML_MAX_DIMS]; + }; + std::vector<node_properties> node_props; bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env); } #endif }; @@ -1242,10 +1396,60 @@ struct ggml_backend_cuda_context { cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } }; cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; - std::unique_ptr<ggml_cuda_graph> cuda_graph; - int curr_stream_no = 0; +#ifdef USE_CUDA_GRAPH + // Map from first_node_ptr to cuda_graph - allows multiple graphs per context + // when the computation is split across CPU/GPU (e.g., with --n-cpu-moe) + std::unordered_map<const void *, std::unique_ptr<ggml_cuda_graph>> cuda_graphs; + + int64_t last_graph_eviction_sweep = 0; + + ggml_cuda_graph * cuda_graph(const void * first_node_ptr) { + const int64_t time_now = ggml_time_us(); + + // sweep every 5s, evicting cuda graphs unused for >=10s + if (time_now - last_graph_eviction_sweep >= 5'000'000) { + last_graph_eviction_sweep = time_now; + for (auto it = cuda_graphs.begin(); it != cuda_graphs.end(); ) { + if (time_now - it->second->last_used_time >= 10'000'000) { + it = cuda_graphs.erase(it); + } else { + ++it; + } + } + } + + auto it = cuda_graphs.find(first_node_ptr); + if (it == cuda_graphs.end()) { + it = cuda_graphs.emplace(first_node_ptr, std::make_unique<ggml_cuda_graph>()).first; + } + it->second->last_used_time = time_now; + return it->second.get(); + } + + // Check if any CUDA graph is enabled for this context (used by kernels that need to know + // if graphs are in use without having access to the specific graph key) + bool any_cuda_graph_enabled() const { + for (const auto & [key, graph] : cuda_graphs) { + if (graph && graph->is_enabled()) { + return true; + } + } + return false; + } + + // Check if any CUDA graph has an instance for this context + bool any_cuda_graph_has_instance() const { + for (const auto & [key, graph] : cuda_graphs) { + if (graph && graph->instance != nullptr) { + return true; + } + } + return false; + } +#endif // USE_CUDA_GRAPH + explicit ggml_backend_cuda_context(int device) : device(device), name(GGML_CUDA_NAME + std::to_string(device)) { @@ -1309,3 +1513,129 @@ struct ggml_cuda_mm_fusion_args_device { const void * gate_bias = nullptr; ggml_glu_op glu_op; }; + +struct ggml_cuda_kernel_launch_params { + dim3 block_nums; + dim3 block_dims; + size_t shmem; + cudaStream_t stream; + + // size_t shmem + ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const size_t shmem_, const cudaStream_t stream_) + : block_nums(block_nums_), block_dims(block_dims_), shmem(shmem_), stream(stream_) {} + + // Some call sites pass ints instead of the required size_t. This 2nd constructor casts int->size_t to avoid these -Wnarrowing warnings. + ggml_cuda_kernel_launch_params(const dim3& block_nums_, const dim3& block_dims_, const int shmem_, const cudaStream_t stream_) + : block_nums(block_nums_), block_dims(block_dims_), shmem((size_t)shmem_), stream(stream_) {} +}; + +#if defined(GGML_CUDA_USE_PDL) +struct ggml_cuda_pdl_config { + cudaLaunchAttribute attr; + cudaLaunchConfig_t cfg; + + ggml_cuda_pdl_config(const ggml_cuda_kernel_launch_params & params) { + attr.id = cudaLaunchAttributeProgrammaticStreamSerialization; + attr.val.programmaticStreamSerializationAllowed = 1; + + cfg = {}; + cfg.gridDim = params.block_nums; + cfg.blockDim = params.block_dims; + cfg.dynamicSmemBytes = params.shmem; + cfg.stream = params.stream; + cfg.attrs = &attr; + cfg.numAttrs = 1; + } + + // Delete due to &attr + ggml_cuda_pdl_config(const ggml_cuda_pdl_config&) = delete; + ggml_cuda_pdl_config& operator=(const ggml_cuda_pdl_config&) = delete; + ggml_cuda_pdl_config& operator=(ggml_cuda_pdl_config&&) = delete; + +}; + +static bool ggml_cuda_kernel_can_use_pdl(const void * kernel) { + const int device = ggml_cuda_get_device(); + + struct cache_key { + int device; + const void * kernel; + + bool operator==(const cache_key & other) const { return device == other.device && kernel == other.kernel; } + }; + + struct cache_key_hash { + // MurmurHash3 mixing function for better hash distribution (vs. just std::hash which in some implementations simply returns the identity) + static size_t hash_mix(size_t x) { + std::uint64_t y = x; + const std::uint64_t m = 0xe9846af9b1a615d; + + y ^= y >> 32; + y *= m; + y ^= y >> 32; + y *= m; + y ^= y >> 28; + + return static_cast<size_t>(y); + } + + size_t operator()(const cache_key & key) const { + // Use a nonzero seed to avoid mapping all-zero keys to zero + size_t h = 42; + h = hash_mix(h + key.device); + h = hash_mix(h + reinterpret_cast<size_t>(key.kernel)); + return h; + } + }; + + static std::mutex cache_mutex; + static std::unordered_map<cache_key, bool, cache_key_hash> cache; + + const cache_key key = { device, kernel }; + std::lock_guard<std::mutex> lock(cache_mutex); + const auto it = cache.find(key); + if (it != cache.end()) { + return it->second; + } + + cudaFuncAttributes attr = {}; + CUDA_CHECK(cudaFuncGetAttributes(&attr, kernel)); + + // PDL device-side primitives are emitted only for PTX versions >= 90. + // We have to guard on a loaded kernel's PTX version so a kernel forward-JIT'ed + // from pre-Hopper PTX to a Hopper-or-newer GPU does not opt into PDL. + const bool can_use_pdl = attr.ptxVersion >= 90; + cache.emplace(key, can_use_pdl); + return can_use_pdl; +} + +#endif //defined(GGML_CUDA_USE_PDL) + +// PDL and __restrict__ need to be mutually exclusive, see https://github.com/ggml-org/llama.cpp/pull/24030 +# if (defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER) +# define GGML_CUDA_RESTRICT +# else +# define GGML_CUDA_RESTRICT __restrict__ +# endif // defined(GGML_CUDA_USE_PDL) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_HOPPER + +template<typename Kernel, typename... Args> +static __inline__ void ggml_cuda_kernel_launch(Kernel kernel, const ggml_cuda_kernel_launch_params & launch_params, Args&&... args) { +#if defined(GGML_CUDA_USE_PDL) + + static const bool env_pdl_enabled = []() { + const char * env = getenv("GGML_CUDA_PDL"); + return env == nullptr || std::atoi(env) != 0; + }(); + + if (env_pdl_enabled && ggml_cuda_kernel_can_use_pdl(reinterpret_cast<const void *>(kernel))) { + auto pdl_cfg = ggml_cuda_pdl_config(launch_params); + + CUDA_CHECK(cudaLaunchKernelEx(&pdl_cfg.cfg, kernel, std::forward<Args>(args)... )); + return; + } +#endif //defined(GGML_CUDA_USE_PDL) + + kernel<<<launch_params.block_nums, launch_params.block_dims, launch_params.shmem, launch_params.stream>>>(std::forward<Args>(args)... ); + CUDA_CHECK(cudaGetLastError()); +} + diff --git a/ggml/src/ggml-cuda/concat.cu b/ggml/src/ggml-cuda/concat.cu index e9ffd274b99..8d557092b2b 100644 --- a/ggml/src/ggml-cuda/concat.cu +++ b/ggml/src/ggml-cuda/concat.cu @@ -1,102 +1,88 @@ #include "concat.cuh" -// contiguous kernels -static __global__ void concat_f32_dim0(const float * x, const float * y, float * dst, const int ne0, const int ne00) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } +#include <stdint.h> - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (nidx < ne00) { // src0 - int offset_src = - nidx + - blockIdx.y * ne00 + - blockIdx.z * ne00 * gridDim.y; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - (nidx - ne00) + - blockIdx.y * (ne0 - ne00) + - blockIdx.z * (ne0 - ne00) * gridDim.y; - dst[offset_dst] = y[offset_src]; - } -} - -static __global__ void concat_f32_dim1(const float * x, const float * y, float * dst, const int ne0, const int ne01) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } +// contiguous kernels +template <typename T, int dim> +static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) concat_cont(const T * x, + const T * y, + T * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2) { + static_assert(dim >= 0 && dim <= 2, "dim must be in [0, 2]"); + + const int64_t n = ne0 * ne1 * ne2; + + ggml_cuda_pdl_sync(); + for (int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x; i < n; i += (int64_t) blockDim.x * gridDim.x) { + if constexpr (dim == 0) { + const int64_t row = i / ne0; + const int64_t i0 = i - row * ne0; + + if (i0 < ne00) { + dst[i] = x[row * ne00 + i0]; + } else { + dst[i] = y[row * (ne0 - ne00) + (i0 - ne00)]; + } + } else if constexpr (dim == 1) { + const int64_t dst_plane = ne0 * ne1; + const int64_t src0_plane = ne0 * ne01; + const int64_t src1_plane = dst_plane - src0_plane; + const int64_t i2 = i / dst_plane; + const int64_t i01 = i - i2 * dst_plane; + + if (i01 < src0_plane) { + dst[i] = x[i2 * src0_plane + i01]; + } else { + dst[i] = y[i2 * src1_plane + (i01 - src0_plane)]; + } + } else { + const int64_t src0_size = ne0 * ne1 * ne02; - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (blockIdx.y < (unsigned)ne01) { // src0 - int offset_src = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * ne01; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - nidx + - (blockIdx.y - ne01) * ne0 + - blockIdx.z * ne0 * (gridDim.y - ne01); - dst[offset_dst] = y[offset_src]; + if (i < src0_size) { + dst[i] = x[i]; + } else { + dst[i] = y[i - src0_size]; + } + } } } -static __global__ void concat_f32_dim2(const float * x, const float * y, float * dst, const int ne0, const int ne02) { - int nidx = threadIdx.x + blockIdx.x * blockDim.x; - if (nidx >= ne0) { - return; - } - - int offset_dst = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - - if (blockIdx.z < (unsigned)ne02) { // src0 - int offset_src = - nidx + - blockIdx.y * ne0 + - blockIdx.z * ne0 * gridDim.y; - dst[offset_dst] = x[offset_src]; - } else { - int offset_src = - nidx + - blockIdx.y * ne0 + - (blockIdx.z - ne02) * ne0 * gridDim.y; - dst[offset_dst] = y[offset_src]; - } -} +template <typename T> +static void concat_cont_cuda(const T * x, + const T * y, + T * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne0, + int64_t ne1, + int64_t ne2, + int dim, + cudaStream_t stream) { + const int64_t n = ne0 * ne1 * ne2; + const int num_blocks = (n + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; -static void concat_f32_cuda(const float * x, const float * y, float * dst, int ne00, int ne01, int ne02, int ne0, int ne1, int ne2, int dim, cudaStream_t stream) { - int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE; - dim3 gridDim(num_blocks, ne1, ne2); if (dim == 0) { - concat_f32_dim0<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne00); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(concat_cont<T, 0>, launch_params, x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } if (dim == 1) { - concat_f32_dim1<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne01); + concat_cont<T, 1><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); return; } - concat_f32_dim2<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02); + concat_cont<T, 2><<<num_blocks, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne00, ne01, ne02, ne0, ne1, ne2); } // non-contiguous kernel (slow) -template <int dim> +template <typename T, int dim> static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) - concat_f32_non_cont( + concat_non_cont( const char * src0, const char * src1, char * dst, @@ -123,61 +109,49 @@ static __global__ void __launch_bounds__(CUDA_CONCAT_BLOCK_SIZE) uint64_t nb0, uint64_t nb1, uint64_t nb2, - uint64_t nb3){ + uint64_t nb3) { static_assert(dim >= 0 && dim <= 3, "dim must be in [0, 3]"); const int64_t i3 = blockIdx.z; const int64_t i2 = blockIdx.y; const int64_t i1 = blockIdx.x; - const float * x; + const T * x; for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) { if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + x = (const T *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); } else { if constexpr (dim == 0) { - x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + i1 * nb11 + (i0 - ne00) * nb10); + x = (const T *)(src1 + i3*nb13 + i2*nb12 + i1*nb11 + (i0 - ne00)*nb10); } else if constexpr (dim == 1) { - x = (const float *) (src1 + i3 * nb13 + i2 * nb12 + (i1 - ne01) * nb11 + i0 * nb10); + x = (const T *)(src1 + i3*nb13 + i2*nb12 + (i1 - ne01)*nb11 + i0*nb10); } else if constexpr (dim == 2) { - x = (const float *) (src1 + i3 * nb13 + (i2 - ne02) * nb12 + i1 * nb11 + i0 * nb10); + x = (const T *)(src1 + i3*nb13 + (i2 - ne02)*nb12 + i1*nb11 + i0*nb10); } else if constexpr (dim == 3) { - x = (const float *) (src1 + (i3 - ne03) * nb13 + i2 * nb12 + i1 * nb11 + i0 * nb10); + x = (const T *)(src1 + (i3 - ne03)*nb13 + i2*nb12 + i1*nb11 + i0*nb10); } } - float * y = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + T * y = (T *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); *y = *x; } } - -void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; - - cudaStream_t stream = ctx.stream(); - - const int32_t dim = ((int32_t *) dst->op_params)[0]; - - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - +template <typename T> +static void concat_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, int dim, cudaStream_t stream) { if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - const float * src0_d = (const float *)src0->data; - const float * src1_d = (const float *)src1->data; - - float * dst_d = (float *)dst->data; + const T * src0_d = (const T *) src0->data; + const T * src1_d = (const T *) src1->data; + T * dst_d = (T *) dst->data; if (dim != 3) { - for (int i3 = 0; i3 < dst->ne[3]; i3++) { - concat_f32_cuda( - src0_d + i3 * (src0->nb[3] / 4), - src1_d + i3 * (src1->nb[3] / 4), - dst_d + i3 * ( dst->nb[3] / 4), + for (int64_t i3 = 0; i3 < dst->ne[3]; i3++) { + concat_cont_cuda( + src0_d + i3*(src0->nb[3] / sizeof(T)), + src1_d + i3*(src1->nb[3] / sizeof(T)), + dst_d + i3*( dst->nb[3] / sizeof(T)), src0->ne[0], src0->ne[1], src0->ne[2], dst->ne[0], dst->ne[1], dst->ne[2], dim, stream); } @@ -185,13 +159,13 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const size_t size0 = ggml_nbytes(src0); const size_t size1 = ggml_nbytes(src1); - CUDA_CHECK(cudaMemcpyAsync(dst_d, src0_d, size0, cudaMemcpyDeviceToDevice, stream)); - CUDA_CHECK(cudaMemcpyAsync(dst_d + size0/4, src1_d, size1, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *) dst->data, src0->data, size0, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync((char *) dst->data + size0, src1->data, size1, cudaMemcpyDeviceToDevice, stream)); } } else { dim3 grid_dim(dst->ne[1], dst->ne[2], dst->ne[3]); auto launch_kernel = [&](auto dim) { - concat_f32_non_cont<dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>( + concat_non_cont<T, dim><<<grid_dim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>( (const char *) src0->data, (const char *) src1->data, (char *) dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], @@ -219,3 +193,35 @@ void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { } } } + +void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + cudaStream_t stream = ctx.stream(); + + const int32_t dim = ((int32_t *) dst->op_params)[0]; + + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT(dst->type == src0->type); + GGML_ASSERT(!ggml_is_quantized(src0->type)); + GGML_ASSERT(ggml_blck_size(src0->type) == 1); + + switch (ggml_type_size(src0->type)) { + case 1: + concat_cuda<uint8_t>(src0, src1, dst, dim, stream); + break; + case 2: + concat_cuda<uint16_t>(src0, src1, dst, dim, stream); + break; + case 4: + concat_cuda<uint32_t>(src0, src1, dst, dim, stream); + break; + case 8: + concat_cuda<uint64_t>(src0, src1, dst, dim, stream); + break; + default: + GGML_ABORT("Unsupported type size: %zu", ggml_type_size(src0->type)); + break; + } +} diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cu b/ggml/src/ggml-cuda/conv2d-transpose.cu index 03224e404d3..6cbd6f879e6 100644 --- a/ggml/src/ggml-cuda/conv2d-transpose.cu +++ b/ggml/src/ggml-cuda/conv2d-transpose.cu @@ -1,12 +1,20 @@ -#include <algorithm> - #include "conv2d-transpose.cuh" -#include "ggml.h" - -__global__ void conv2d_transpose_kernel(const float * __restrict__ input, const half * __restrict__ kernel, - float * __restrict__ output, const int in_w, const int in_h, const int out_w, - const int out_h, const int kernel_w, const int kernel_h, const int stride, - const int c_in, const int c_out, const int batches) { +#include "convert.cuh" + +template <typename kernel_t> +static __global__ void conv2d_transpose_kernel(const float * __restrict__ input, + const kernel_t * __restrict__ kernel, + float * __restrict__ output, + const int in_w, + const int in_h, + const int out_w, + const int out_h, + const int kernel_w, + const int kernel_h, + const int stride, + const int c_in, + const int c_out, + const int batches) { const int global_idx = blockIdx.x * blockDim.x + threadIdx.x; const int total_elements = out_w * out_h * c_out * batches; @@ -26,24 +34,32 @@ __global__ void conv2d_transpose_kernel(const float * __restrict__ input, const for (int c_in_idx = 0; c_in_idx < c_in; c_in_idx++) { for (int kh = 0; kh < kernel_h; ++kh) { int in_y = out_y_idx - kh; - if (in_y < 0 || in_y % stride) continue; + if (in_y < 0 || in_y % stride) { + continue; + } in_y /= stride; - if (in_y >= in_h) continue; + if (in_y >= in_h) { + continue; + } for (int kw = 0; kw < kernel_w; ++kw) { int in_x = out_x_idx - kw; - if (in_x < 0 || in_x % stride) continue; + if (in_x < 0 || in_x % stride) { + continue; + } in_x /= stride; - if (in_x >= in_w) continue; + if (in_x >= in_w) { + continue; + } const int input_idx = (in_w * in_h * c_in) * n_idx + (in_w * in_h) * c_in_idx + (in_w) *in_y + in_x; const int kernel_idx = (kernel_h * kernel_w * c_out) * c_in_idx + (kernel_h * kernel_w) * c_idx + (kernel_w) *kh + kw; - float input_val = input[input_idx]; - half kern_val = kernel[kernel_idx]; + float input_val = input[input_idx]; + kernel_t kern_val = kernel[kernel_idx]; - accumulator += input_val * (float) kern_val; + accumulator += input_val * ggml_cuda_cast<float>(kern_val); } } } @@ -56,11 +72,12 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor const ggml_tensor * kernel = dst->src[0]; const ggml_tensor * input = dst->src[1]; - GGML_ASSERT(kernel->type == GGML_TYPE_F16 && input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + GGML_ASSERT(kernel->type == GGML_TYPE_F16 || kernel->type == GGML_TYPE_F32); + GGML_ASSERT(input->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); const float * input_data = (const float *) input->data; float * output_data = (float *) dst->data; - const half * kernel_data = (const half *) kernel->data; + const void * kernel_data = kernel->data; const int input_w = input->ne[0]; const int input_h = input->ne[1]; @@ -82,10 +99,17 @@ void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(ggml_is_contiguous(kernel)); GGML_ASSERT(ggml_is_contiguous(dst)); - const int total = (output_w * output_h * channels_out * batches); + const int total = output_w * output_h * channels_out * batches; const int blocks = (total + CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE - 1) / CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE; - conv2d_transpose_kernel<<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>( - input_data, kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, kernel_h, stride, - channels_in, channels_out, batches); + if (kernel->type == GGML_TYPE_F16) { + conv2d_transpose_kernel<half><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>( + input_data, (const half *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, + kernel_h, stride, channels_in, channels_out, batches); + + } else { + conv2d_transpose_kernel<float><<<blocks, CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE, 0, st>>>( + input_data, (const float *) kernel_data, output_data, input_w, input_h, output_w, output_h, kernel_w, + kernel_h, stride, channels_in, channels_out, batches); + } } diff --git a/ggml/src/ggml-cuda/conv2d-transpose.cuh b/ggml/src/ggml-cuda/conv2d-transpose.cuh index c9430b24850..72889c5f0fa 100644 --- a/ggml/src/ggml-cuda/conv2d-transpose.cuh +++ b/ggml/src/ggml-cuda/conv2d-transpose.cuh @@ -1,4 +1,5 @@ #include "common.cuh" #define CUDA_CONV2D_TRANSPOSE_BLOCK_SIZE 256 + void ggml_cuda_conv_2d_transpose_p0(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index ba3d4eeb880..61630a35a29 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -7,7 +7,8 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, - const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t ne00, const int64_t ne01, + const int64_t ne0203, const uint3 ne02, const int64_t s01, const int64_t s02, const int64_t s03) { const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x); @@ -15,24 +16,28 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ return; } - const int64_t i01 = blockIdx.y; - const int64_t i02 = blockIdx.z % ne02; - const int64_t i03 = blockIdx.z / ne02; + for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) { + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; - const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; - const int64_t ib = ibx0 + i00/qk; // block index - const int64_t iqs = (i00%qk)/qr; // quant index - const int64_t iybs = i00 - i00%qk; // y block start index - const int64_t y_offset = qr == 1 ? 1 : qk/2; + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; - // dequantize - float2 v; - dequantize_kernel(vx, ib, iqs, v); + // dequantize + float2 v; + dequantize_kernel(vx, ib, iqs, v); - const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; - y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x); - y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y); + const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x); + y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y); + } + } } template <bool need_check> @@ -485,9 +490,11 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> static void dequantize_block_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { - const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03); + const int64_t ne0203 = ne02*ne03; + const uint3 ne02_fdv = init_fastdiv_values(ne02); + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535)); dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>> - (vx, y, ne00, ne01, ne02, s01, s02, s03); + (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> @@ -610,9 +617,49 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y); } +template <typename dst_t> +static __global__ void dequantize_block_nvfp4( + const void * __restrict__ vx, + dst_t * __restrict__ yy, + const int64_t ne) { + const int64_t i = blockIdx.x; + const int tid = threadIdx.x; + + const int64_t base = i * QK_NVFP4; + if (base >= ne) { + return; + } + + const block_nvfp4 * x = (const block_nvfp4 *) vx; + const block_nvfp4 & xb = x[i]; + + const int sub = tid / (QK_NVFP4_SUB / 2); + const int j = tid % (QK_NVFP4_SUB / 2); + + const float d = ggml_cuda_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j]; + + const int64_t y0 = base + sub * QK_NVFP4_SUB + j; + const int64_t y1 = y0 + QK_NVFP4_SUB / 2; + + yy[y0] = ggml_cuda_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]); + yy[y1] = ggml_cuda_cast<dst_t>(d * kvalues_mxfp4[q >> 4]); +} + +template <typename dst_t> +static void dequantize_row_nvfp4_cuda( + const void * vx, + dst_t * y, + const int64_t k, + cudaStream_t stream) { + GGML_ASSERT(k % QK_NVFP4 == 0); + const int nb = k / QK_NVFP4; + dequantize_block_nvfp4<<<nb, 32, 0, stream>>>(vx, y, k); +} template <typename src_t, typename dst_t> static __global__ void convert_unary( - const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, + const int64_t ne0203, const uint3 ne02, const int64_t s01, const int64_t s02, const int64_t s03) { const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; @@ -620,24 +667,30 @@ static __global__ void convert_unary( return; } - const int64_t i01 = blockIdx.y; - const int64_t i02 = blockIdx.z % ne02; - const int64_t i03 = blockIdx.z / ne02; - const src_t * x = (const src_t *) vx; - const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; - const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; - y[iy] = ggml_cuda_cast<dst_t>(x[ix]); + for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) { + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; + + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast<dst_t>(x[ix]); + } + } } template <typename src_t, typename dst_t> static void convert_unary_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { - const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03); + const int64_t ne0203 = ne02*ne03; + const uint3 ne02_fdv = init_fastdiv_values(ne02); + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535)); convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>> - (vx, y, ne00, ne01, ne02, s01, s02, s03); + (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } template <typename src_t, typename dst_t> @@ -658,6 +711,8 @@ to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) { to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: + return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>; case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: @@ -701,6 +756,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_cuda; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_cuda; case GGML_TYPE_F32: return convert_unary_cont_cuda<float>; case GGML_TYPE_BF16: @@ -712,6 +769,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: + return dequantize_block_cont_cuda<QK1_0, QR1_0, dequantize_q1_0>; case GGML_TYPE_Q4_0: return dequantize_row_q4_0_cuda; case GGML_TYPE_Q4_1: @@ -752,6 +811,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_iq3_s_cuda; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_cuda; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_cuda; case GGML_TYPE_F16: return convert_unary_cont_cuda<half>; case GGML_TYPE_BF16: @@ -765,6 +826,8 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda<float>; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>; case GGML_TYPE_Q4_0: return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>; case GGML_TYPE_Q4_1: @@ -786,6 +849,8 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda<float, nv_bfloat16>; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>; case GGML_TYPE_Q4_0: return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>; case GGML_TYPE_Q4_1: @@ -807,6 +872,8 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F16: return convert_unary_cuda<half, float>; + case GGML_TYPE_Q1_0: + return dequantize_block_cuda<QK1_0, QR1_0, dequantize_q1_0>; case GGML_TYPE_Q4_0: return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>; case GGML_TYPE_Q4_1: diff --git a/ggml/src/ggml-cuda/convert.cuh b/ggml/src/ggml-cuda/convert.cuh index 09f9a33f909..f5d37c7b998 100644 --- a/ggml/src/ggml-cuda/convert.cuh +++ b/ggml/src/ggml-cuda/convert.cuh @@ -41,6 +41,16 @@ template<typename dst_t, typename src_t> return __bfloat162float(x); } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) { return __float22half2_rn(x); + } else if constexpr(std::is_same_v<src_t, nv_bfloat162> && std::is_same_v<dst_t, float2>) { +#ifdef GGML_USE_HIP + return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x))); +#else +#if __CUDA_ARCH__ >= 800 + return __bfloat1622float2(x); +#else + return make_float2(__bfloat162float(x.x), __bfloat162float(x.y)); +#endif // __CUDA_ARCH__ >= 800 +#endif // GGML_USE_HIP } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) { // bypass compile error on cuda 12.0.1 #ifdef GGML_USE_HIP diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index ee84303ef0e..121472ec228 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -16,6 +16,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13) { + ggml_cuda_pdl_lc(); const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -36,6 +37,7 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13; + ggml_cuda_pdl_sync(); cpy_1(cx + x_offset, cdst + dst_offset); } @@ -56,8 +58,10 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const const int tx = blockIdx.y * CUDA_CPY_TILE_DIM_2D + threadIdx.x; // transpose block offset const int ty = blockIdx.x * CUDA_CPY_TILE_DIM_2D + threadIdx.y; - __shared__ float tile[CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; + __shared__ float tile[2][CUDA_CPY_TILE_DIM_2D][CUDA_CPY_TILE_DIM_2D+1]; + int cur_tile_buf = 0; + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < CUDA_CPY_BLOCK_NM; ++i) { @@ -70,7 +74,7 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const if(x < ne01 && y + j < ne00){ const int row = threadIdx.y+j; const int col = threadIdx.x * sizeof(float)/sizeof(T); - T *tile2 = reinterpret_cast<T*>(tile[row]); + T *tile2 = reinterpret_cast<T*>(tile[cur_tile_buf][row]); tile2[col] = src[imat*n + (y+j)*ne01 + x]; } } @@ -81,10 +85,12 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const for (int j = 0; j < CUDA_CPY_TILE_DIM_2D; j += CUDA_CPY_BLOCK_ROWS) { if (ty + j < ne01 && tx < ne00) { const int col = (threadIdx.y+j)*sizeof(float)/sizeof(T); - const T *tile2 = reinterpret_cast<const T*>(tile[threadIdx.x]); + const T *tile2 = reinterpret_cast<const T*>(tile[cur_tile_buf][threadIdx.x]); dst[imat*n + (ty+j)*ne00 + tx] = tile2[col]; } } + + cur_tile_buf = (cur_tile_buf + 1) % 2; } GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, @@ -139,6 +145,7 @@ static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne, const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + ggml_cuda_pdl_sync(); cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -165,6 +172,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne, const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + ggml_cuda_pdl_sync(); cpy_blck(cx + x_offset, cdst + dst_offset); } @@ -179,6 +187,7 @@ static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const const src_t * x = (const src_t *) cx; dst_t * dst = (dst_t *) cdst; + ggml_cuda_pdl_sync(); dst[i] = ggml_cuda_cast<dst_t>(x[i]); } @@ -189,8 +198,8 @@ cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; GGML_ASSERT(num_blocks < UINT_MAX); - cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> - (cx, cdst, ne); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_contiguous<src_t, dst_t>, launch_params, cx, cdst, ne); } template<typename src_t, typename dst_t, bool transposed = false> @@ -220,13 +229,15 @@ static void ggml_cpy_scalar_cuda( GGML_ASSERT(grid_z < USHRT_MAX); dim3 dimGrid(grid_x, grid_y, grid_z); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); - cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>> - (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(dimGrid, dimBlock, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar_transpose<dst_t>, launch_params, + cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; GGML_ASSERT(num_blocks < UINT_MAX); - cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(cpy_scalar<cpy_1_scalar<src_t, dst_t>>, launch_params, + cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index e060fb29fdc..9ae1342fc0e 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -1,5 +1,27 @@ #include "common.cuh" +static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_q1_0 * x = (const block_q1_0 *) vx; + + const float d = x[ib].d; + + const int bit_index_0 = iqs; + const int bit_index_1 = iqs + 1; + + const int byte_index_0 = bit_index_0 / 8; + const int bit_offset_0 = bit_index_0 % 8; + + const int byte_index_1 = bit_index_1 / 8; + const int bit_offset_1 = bit_index_1 % 8; + + // Extract bits: 1 = +d, 0 = -d (branchless) + const int bit_0 = (x[ib].qs[byte_index_0] >> bit_offset_0) & 1; + const int bit_1 = (x[ib].qs[byte_index_1] >> bit_offset_1) & 1; + + v.x = (2*bit_0 - 1) * d; + v.y = (2*bit_1 - 1) * d; +} + static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q4_0 * x = (const block_q4_0 *) vx; diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 31446787287..8dfa51ad1e8 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -44,6 +44,46 @@ typedef void (* fattn_kernel_t)( typedef float (*vec_dot_KQ_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); +struct ggml_cuda_flash_attn_ext_f16_extra_data { + uintptr_t K; + uintptr_t V; + uintptr_t end; +}; + +static inline ggml_cuda_flash_attn_ext_f16_extra_data ggml_cuda_flash_attn_ext_get_f16_extra_data( + const ggml_tensor * dst, const bool need_f16_K, const bool need_f16_V) { + GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT); + + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); + + ggml_cuda_flash_attn_ext_f16_extra_data data = {}; + data.end = (uintptr_t) dst->data + ggml_nbytes(dst); + + if (need_f16_K && K->type != GGML_TYPE_F16) { + data.end = GGML_PAD(data.end, 128); + data.K = data.end; + data.end += ggml_nelements(K)*ggml_type_size(GGML_TYPE_F16); + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + data.V = data.K; + } else { + data.end = GGML_PAD(data.end, 128); + data.V = data.end; + data.end += ggml_nelements(V)*ggml_type_size(GGML_TYPE_F16); + } + } + + return data; +} + template <int D, int nthreads> static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { @@ -59,7 +99,7 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { - half2 tmp[cpy_ne]; + __align__(16) half2 tmp[cpy_ne]; ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { @@ -74,6 +114,37 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16( return sum; } +template <int D, int nthreads> +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_bf16( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) { + + const nv_bfloat162 * K_bf16 = (const nv_bfloat162 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + __align__(16) nv_bfloat162 tmp[cpy_ne]; + ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_bf16 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef V_DOT2_F32_F16_AVAILABLE + // FIXME replace macros in vector FA kernel with templating and use FP32 for BF16 + ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), __half22float2(((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1])); +#else + ggml_cuda_mad(sum, ggml_cuda_cast<float2>(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // V_DOT2_F32_F16_AVAILABLE + } + } + + return sum; +} + template<int D, int nthreads> static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q4_0( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { @@ -309,7 +380,7 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ ggml_cuda_memcpy_1<ne*sizeof(half)>(dst, (const half *) vx + i0); } else if constexpr (std::is_same_v<T, float>) { static_assert(ne % 2 == 0, "bad ne"); - half2 tmp[ne/2]; + __align__(16) half2 tmp[ne/2]; ggml_cuda_memcpy_1<ne*sizeof(half)>(tmp, (const half *) vx + i0); float2 * dst_f2 = (float2 *) dst; #pragma unroll @@ -321,6 +392,19 @@ static __device__ __forceinline__ void dequantize_V_f16(const void * __restrict_ } } +template <typename T, int ne> +static __device__ __forceinline__ void dequantize_V_bf16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + static_assert(std::is_same_v<T, float>, "BF16 V dequantization only supports float output"); + static_assert(ne % 2 == 0, "bad ne"); + __align__(16) nv_bfloat162 tmp[ne/2]; + ggml_cuda_memcpy_1<ne*sizeof(nv_bfloat16)>(tmp, (const nv_bfloat16 *) vx + i0); + float2 * dst_f2 = (float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = ggml_cuda_cast<float2>(tmp[l]); + } +} + template <typename T, int ne> static __device__ __forceinline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -547,6 +631,8 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_q5_1<D, nthreads>; } else if constexpr (type_K == GGML_TYPE_Q8_0) { return vec_dot_fattn_vec_KQ_q8_0<D, nthreads>; + } else if constexpr (type_K == GGML_TYPE_BF16) { + return vec_dot_fattn_vec_KQ_bf16<D, nthreads>; } else { static_assert(type_K == -1, "bad type"); return nullptr; @@ -567,6 +653,8 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_q5_1<T, ne>; } else if constexpr (type_V == GGML_TYPE_Q8_0) { return dequantize_V_q8_0<T, ne>; + } else if constexpr (type_V == GGML_TYPE_BF16) { + return dequantize_V_bf16<float, ne>; } else { static_assert(type_V == -1, "bad type"); return nullptr; @@ -588,6 +676,7 @@ static __global__ void flash_attn_mask_to_KV_max( if (tid < WARP_SIZE) { buf_iw[tid] = 1; } + ggml_cuda_pdl_sync(); __syncthreads(); int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; @@ -628,9 +717,102 @@ static __global__ void flash_attn_mask_to_KV_max( template<int D, int ncols1, int ncols2> // D == head size __launch_bounds__(D, 1) -static __global__ void flash_attn_stream_k_fixup( - float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11, - const int nbatch_fa) { +static __global__ void flash_attn_stream_k_fixup_uniform( + float * dst_ptr, + const float2 * dst_fixup_ptr, + const int ne01, const int ne02, + const int ne12, const int nblocks_stream_k, + const int gqa_ratio, + const int blocks_per_tile, + const uint3 fd_iter_j_z_ne12, + const uint3 fd_iter_j_z, + const uint3 fd_iter_j) { + constexpr int ncols = ncols1*ncols2; + ggml_cuda_pdl_lc(); + float * GGML_CUDA_RESTRICT dst = dst_ptr; + const float2 * GGML_CUDA_RESTRICT dst_fixup = dst_fixup_ptr; + + const int tile_idx = blockIdx.x; // One block per output tile. + const int j = blockIdx.y; + const int c = blockIdx.z; + const int jc = j*ncols2 + c; + const int tid = threadIdx.x; + + // nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks. + const int b_first = tile_idx * blocks_per_tile; + const int b_last = b_first + blocks_per_tile - 1; + + const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols); + + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12); + const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_j_z); + const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_j); + + const int sequence = dm0.x; + const int z_KV = dm1.x; + const int zt_gqa = dm2.x; + const int jt = dm2.y; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { + return; + } + + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + + ggml_cuda_pdl_sync(); + // Load the partial result that needs a fixup + float dst_val = *dst; + float max_val; + float rowsum; + { + const float2 tmp = dst_fixup[b_last*ncols + jc]; + max_val = tmp.x; + rowsum = tmp.y; + } + + // Combine with all previous blocks in this tile. + for (int bidx = b_last - 1; bidx >= b_first; --bidx) { + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc]; + + const float max_val_new = fmaxf(max_val, tmp.x); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val*rowsum + scale_add*tmp.y; + + max_val = max_val_new; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +// General fixup kernel for the case where the number of blocks per tile is not uniform across tiles +// (blocks_num.x not a multiple of ntiles_dst) +template <int D, int ncols1, int ncols2> // D == head size +__launch_bounds__(D, 1) +static __global__ void flash_attn_stream_k_fixup_general( + float * dst_ptr, + const float2 * dst_fixup_ptr, + const int ne01, const int ne02, + const int gqa_ratio, + const int total_work, + const uint3 fd_iter_k_j_z_ne12, + const uint3 fd_iter_k_j_z, + const uint3 fd_iter_k_j, + const uint3 fd_iter_k) { + float * GGML_CUDA_RESTRICT dst = dst_ptr; + const float2 * GGML_CUDA_RESTRICT dst_fixup = dst_fixup_ptr; constexpr int ncols = ncols1*ncols2; const int bidx0 = blockIdx.x; @@ -641,33 +823,40 @@ static __global__ void flash_attn_stream_k_fixup( const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols); - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; - - const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; - const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc0 = int64_t(bidx0 + 0)*total_work / gridDim.x; + const int kbc0_stop = int64_t(bidx0 + 1)*total_work / gridDim.x; const bool did_not_have_any_data = kbc0 == kbc0_stop; - const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; - const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + const bool wrote_beginning_of_tile = fastmodulo(kbc0, fd_iter_k) == 0; + const bool did_not_write_last = fastdiv(kbc0, fd_iter_k) == fastdiv(kbc0_stop, fd_iter_k) && fastmodulo(kbc0_stop, fd_iter_k) != 0; if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { return; } - const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2)); - const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); - const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const uint2 dm0 = fast_div_modulo(kbc0, fd_iter_k_j_z_ne12); + const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_k_j_z); + const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_k_j); + const uint2 dm3 = fast_div_modulo(dm2.y, fd_iter_k); + + const int sequence = dm0.x; + const int z_KV = dm1.x; + const int zt_gqa = dm2.x; + const int jt = dm3.x; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - if (jt*ncols1 + j >= ne01) { + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { return; } - dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid; + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; // Load the partial result that needs a fixup: float dst_val = 0.0f; float max_val = 0.0f; float rowsum = 0.0f; + ggml_cuda_pdl_sync(); { dst_val = *dst; @@ -678,10 +867,11 @@ static __global__ void flash_attn_stream_k_fixup( // Iterate over previous blocks and compute the combined results. // All CUDA blocks that get here must have a previous block that needs a fixup. + const int tile_kbc0 = fastdiv(kbc0, fd_iter_k); int bidx = bidx0 - 1; int kbc_stop = kbc0; while(true) { - const int kbc = int64_t(bidx)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + const int kbc = int64_t(bidx)*total_work / gridDim.x; if (kbc == kbc_stop) { // Did not have any data. bidx--; kbc_stop = kbc; @@ -707,7 +897,7 @@ static __global__ void flash_attn_stream_k_fixup( max_val = max_val_new; // If this block started in a previous tile we are done and don't need to combine additional partial results. - if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + if (fastmodulo(kbc, fd_iter_k) == 0 || fastdiv(kbc, fd_iter_k) < tile_kbc0) { break; } bidx--; @@ -721,10 +911,14 @@ static __global__ void flash_attn_stream_k_fixup( template<int D> // D == head size __launch_bounds__(D, 1) static __global__ void flash_attn_combine_results( - const float * __restrict__ VKQ_parts, - const float2 * __restrict__ VKQ_meta, - float * __restrict__ dst, + const float * VKQ_parts_ptr, + const float2 * VKQ_meta_ptr, + float * dst_ptr, const int parallel_blocks) { + ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT VKQ_parts = VKQ_parts_ptr; + const float2 * GGML_CUDA_RESTRICT VKQ_meta = VKQ_meta_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; // Dimension 0: threadIdx.x // Dimension 1: blockIdx.x // Dimension 2: blockIdx.y @@ -748,6 +942,7 @@ static __global__ void flash_attn_combine_results( __builtin_assume(tid < D); extern __shared__ float2 meta[]; + ggml_cuda_pdl_sync(); for (int i = tid; i < 2*parallel_blocks; i += D) { ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; } @@ -778,13 +973,11 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; - const bool is_mla = DV == 512; // TODO better parameterization - const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; - GGML_ASSERT(V || is_mla); + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); const ggml_tensor * mask = dst->src[3]; const ggml_tensor * sinks = dst->src[4]; @@ -794,9 +987,9 @@ void launch_fattn( GGML_ASSERT(Q->type == GGML_TYPE_F32); GGML_ASSERT(KQV->type == GGML_TYPE_F32); - GGML_ASSERT( Q->nb[0] == ggml_element_size(Q)); - GGML_ASSERT( K->nb[0] == ggml_element_size(K)); - GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V)); + GGML_ASSERT(Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT(K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(V->nb[0] == ggml_element_size(V)); GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); @@ -806,8 +999,9 @@ void launch_fattn( const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; - ggml_cuda_pool_alloc<half> K_f16(pool); - ggml_cuda_pool_alloc<half> V_f16(pool); + const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = + ggml_cuda_flash_attn_ext_get_f16_extra_data(KQV, need_f16_K, need_f16_V); + ggml_cuda_pool_alloc<int> KV_max(pool); ggml_cuda_pool_alloc<float> dst_tmp(pool); ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); @@ -817,19 +1011,20 @@ void launch_fattn( size_t nb12 = K->nb[2]; size_t nb13 = K->nb[3]; - const char * V_data = V ? (const char *) V->data : nullptr; - size_t nb21 = V ? V->nb[1] : nb11; - size_t nb22 = V ? V->nb[2] : nb12; - size_t nb23 = V ? V->nb[3] : nb13; + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; if (need_f16_K && K->type != GGML_TYPE_F16) { const size_t bs = ggml_blck_size(K->type); const size_t ts = ggml_type_size(K->type); - K_f16.alloc(ggml_nelements(K)); + GGML_ASSERT(f16_extra.K != 0); + half * K_f16 = (half *) f16_extra.K; if (ggml_is_contiguously_allocated(K)) { to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); - to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + to_fp16(K_data, K_f16, ggml_nelements(K), main_stream); nb11 = nb11*bs*sizeof(half)/ts; nb12 = nb12*bs*sizeof(half)/ts; @@ -840,45 +1035,55 @@ void launch_fattn( const int64_t s01 = nb11 / ts; const int64_t s02 = nb12 / ts; const int64_t s03 = nb13 / ts; - to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + to_fp16(K_data, K_f16, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); nb11 = K->ne[0] * sizeof(half); nb12 = K->ne[1] * nb11; nb13 = K->ne[2] * nb12; } - K_data = (char *) K_f16.ptr; + K_data = (char *) K_f16; } - if (V && need_f16_V && V->type != GGML_TYPE_F16) { - const size_t bs = ggml_blck_size(V->type); - const size_t ts = ggml_type_size(V->type); - - V_f16.alloc(ggml_nelements(V)); - if (ggml_is_contiguously_allocated(V)) { - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; - - nb21 = nb21*bs*sizeof(half)/ts; - nb22 = nb22*bs*sizeof(half)/ts; - nb23 = nb23*bs*sizeof(half)/ts; + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + V_data = K_data; + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; } else { - GGML_ASSERT(V->nb[0] == ts); - to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); - const int64_t s01 = nb21 / ts; - const int64_t s02 = nb22 / ts; - const int64_t s03 = nb23 / ts; - to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); - - nb21 = V->ne[0] * sizeof(half); - nb22 = V->ne[1] * nb21; - nb23 = V->ne[2] * nb22; + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + GGML_ASSERT(f16_extra.V != 0); + half * V_f16 = (half *) f16_extra.V; + if (ggml_is_contiguously_allocated(V)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16, ggml_nelements(V), main_stream); + V_data = (char *) V_f16; + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_fp16(V_data, V_f16, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(half); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16; } - V_data = (char *) V_f16.ptr; } - const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); - const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2); + const int ntiles_dst = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3]; // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or @@ -905,37 +1110,51 @@ void launch_fattn( GGML_ASSERT(max_blocks_per_sm > 0); int parallel_blocks = max_blocks_per_sm; + const int ntiles_KV = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length. + dim3 blocks_num; if (stream_k) { // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. const int max_blocks = max_blocks_per_sm*nsm; - const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks; - const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves); + const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks; + const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves); - const int nblocks_stream_k = max_blocks; + const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75; - const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75; - - blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.x = ntiles_dst; blocks_num.y = 1; blocks_num.z = 1; - if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + if(use_stream_k) { + const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst); + // Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks (avoids fixup). + // Only do this if the occupancy loss from rounding is acceptable. + const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst; + const int max_efficiency_loss_percent = 5; + const int efficiency_loss_percent = nblocks_stream_k_rounded > 0 + ? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw + : 100; + const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent + ? nblocks_stream_k_rounded + : nblocks_stream_k_raw; + + blocks_num.x = nblocks_stream_k; + } + + if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2))); } } else { - const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. - // parallel_blocks must not be larger than what the tensor size allows: - parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + parallel_blocks = std::min(parallel_blocks, ntiles_KV); // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. // Test whether parallel_blocks can be set to a higher value for better efficiency. const int blocks_per_wave = nsm * max_blocks_per_sm; int nwaves_best = 0; int efficiency_percent_best = 0; - for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { - const int nblocks_total = ntiles_total * parallel_blocks_test; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KV; ++parallel_blocks_test) { + const int nblocks_total = ntiles_dst * parallel_blocks_test; const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); @@ -953,7 +1172,7 @@ void launch_fattn( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3]; + blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); @@ -983,7 +1202,9 @@ void launch_fattn( const uint3 ne01 = init_fastdiv_values(Q->ne[1]); GGML_ASSERT(block_dim.x % warp_size == 0); - fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>( + + ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num, block_dim, nbytes_shared, main_stream); + ggml_cuda_kernel_launch(fattn_kernel, launch_params, (const char *) Q->data, K_data, V_data, @@ -1001,22 +1222,49 @@ void launch_fattn( CUDA_CHECK(cudaGetLastError()); if (stream_k) { - if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) { + // Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile. + const int nblocks_sk = (int)blocks_num.x; + const int bpt = nblocks_sk / ntiles_dst; + + const uint3 fd0 = init_fastdiv_values(ntiles_x * ntiles_z_gqa * K->ne[2]); + const uint3 fd1 = init_fastdiv_values(ntiles_x * ntiles_z_gqa); + const uint3 fd2 = init_fastdiv_values(ntiles_x); + + const dim3 block_dim_combine(DV, 1, 1); + const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2}; + + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream); + ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>, launch_params, + (float *) KQV->data, dst_tmp_meta.ptr, + Q->ne[1], Q->ne[2], K->ne[2], nblocks_sk, + gqa_ratio, bpt, fd0, fd1, fd2); + } else if (ntiles_dst % blocks_num.x != 0) { + // General fixup for the cases where nblocks_stream_k < ntiles_dst. + const int total_work = ntiles_KV * ntiles_dst; + + const uint3 fd_k_j_z_ne12 = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne[2]); + const uint3 fd_k_j_z = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa); + const uint3 fd_k_j = init_fastdiv_values(ntiles_KV * ntiles_x); + const uint3 fd_k = init_fastdiv_values(ntiles_KV); + const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2}; - flash_attn_stream_k_fixup<DV, ncols1, ncols2> - <<<blocks_num_combine, block_dim_combine, 0, main_stream>>> - ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, 0, main_stream); + ggml_cuda_kernel_launch(flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>, launch_params, + (float *) KQV->data, dst_tmp_meta.ptr, + Q->ne[1], Q->ne[2], gqa_ratio, total_work, + fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k); } } else if (parallel_blocks > 1) { const dim3 block_dim_combine(DV, 1, 1); const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2); - flash_attn_combine_results<DV> - <<<blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream>>> - (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks_num_combine, block_dim_combine, nbytes_shared_combine, main_stream); + ggml_cuda_kernel_launch(flash_attn_combine_results<DV>, launch_params, + dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data, parallel_blocks); } CUDA_CHECK(cudaGetLastError()); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 856291dc3ce..83478a02cb6 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -61,11 +61,24 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 4, 64, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 4, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 32, 96, 64, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); @@ -80,6 +93,14 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true); GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false); @@ -89,6 +110,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co } static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false); GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false); @@ -98,6 +124,110 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 64, 2, 32, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 64, 2, 32, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 64, 2, 32, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 64, 2, 32, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 64, 2, 32, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 64, 2, 32, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 64, 2, 32, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 64, 2, 32, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 2, 32, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 2, 32, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 64, 96, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 64, 96, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 32, 160, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 128, 2, 32, 128, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 128, 3, 64, 96, 64, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 128, 2, 32, 160, 128, 128, 1, true); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +} + +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 1, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 4, 64, 32, 32, 32, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 256, 1, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 512, 1, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 512, 1, 64, 128, 128, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 256, 1, 64, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 64, 160, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 64, 128, 128, 128, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 256, 1, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 256, 1, 64, 160, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 64, 160, 128, 128, 1, true); + + return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); +} + static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { if (ampere_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); @@ -105,6 +235,12 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c if (turing_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); } + if (amd_mfma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols); + } + if (amd_wmma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); + } GGML_ASSERT(volta_mma_available(cc)); return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); } @@ -114,8 +250,12 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); #elif defined(TURING_MMA_AVAILABLE) return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); +#elif defined(AMD_MFMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols); #elif defined(VOLTA_MMA_AVAILABLE) return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); +#elif defined(AMD_WMMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); #else GGML_UNUSED_VARS(DKQ, DV, ncols); return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false); @@ -186,6 +326,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg; } +static constexpr __device__ int get_cols_per_thread() { +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + return 1; // AMD has a single column per thread. +#else + return 2; // This is specifically KQ columns, Volta only has a single VKQ column. +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) +} + +static __host__ int get_cols_per_warp(const int cc) { + if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) { + return 16; + } else { + // Volta + return 32; + } +} + // ------------------------------------------------------------------------------------------------------------------ static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) { @@ -206,21 +363,23 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check> static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // K/V data is loaded with decreasing granularity for D for better memory bandwidth. - // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. + // The minimum granularity is 16 bytes. + constexpr int h2_per_chunk = 16/sizeof(half2); + const int chunks_per_row = D2 / h2_per_chunk; if constexpr (use_cp_async) { + static_assert(warp_size == 32, "bad warp_size"); static_assert(!oob_check, "OOB check not compatible with cp_async"); constexpr int preload = 64; - constexpr int h2_per_chunk = 16/sizeof(half2); - const int chunks_per_row = D2 / h2_per_chunk; const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); auto load = [&] __device__ (auto n) { - const int stride_k = WARP_SIZE >> n; - const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int stride_k = warp_size >> n; + const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { return; @@ -228,7 +387,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { break; @@ -236,7 +395,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); } @@ -250,12 +409,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( // 6: max 1*16= 16 bytes, 8 half ggml_cuda_unroll<6>{}(load); } else { - // TODO use ggml_cuda_memcpy_1 + const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}}; auto load = [&] __device__ (const int n) { - const int stride_k = WARP_SIZE >> n; - const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); - const int k0_stop = D2 - D2 % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + const int stride_k = 32 >> n; + const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); + const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { return; @@ -263,7 +422,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { break; @@ -271,17 +430,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); - tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); + ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4, + !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero); } } }; - // 1: max 32* 4=128 bytes, 64 half - // 2: max 16* 4= 64 bytes, 32 half - // 3: max 8* 4= 32 bytes, 16 half - // 4: max 4* 4= 16 bytes, 8 half - ggml_cuda_unroll<4>{}(load); + // 1: max 32*16=512 bytes, 256 half + // 2: max 16*16=256 bytes, 128 half + // 3: max 8*16=128 bytes, 64 half + // 4: max 4*16= 64 bytes, 32 half + // 5: max 2*16= 32 bytes, 16 half + // 6: max 1*16= 16 bytes, 8 half + ggml_cuda_unroll<6>{}(load); } } @@ -289,18 +451,19 @@ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_chec static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const half * const __restrict__ mask_h, half * const __restrict__ tile_mask, const int stride_mask, const int i_sup, const int j0, const uint3 ne01) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); if constexpr (use_cp_async) { - static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa"); + static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa"); static_assert(!oob_check, "OOB check incompatible with cp_async"); constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64; - constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int cols_per_warp = 8*warp_size/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll for (int j1 = 0; j1 < ncols1; j1 += stride_j) { - const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp); const int j_vram = fastmodulo(j0 + j_sram, ne01); if (j1 + stride_j > ncols1 && j_sram >= ncols1) { @@ -309,7 +472,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( const int i = 8 * (threadIdx.x % (nbatch_fa/8)); - cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i); + cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + int64_t(j_vram)*stride_mask + i); } } else if constexpr (oob_check) { #pragma unroll @@ -322,27 +485,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } #pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) { const int i = i0 + threadIdx.x; - tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); + tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[int64_t(j_vram)*stride_mask + i] : half(0.0f); } } - } else if constexpr (nbatch_fa < 2*WARP_SIZE) { - constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + } else if constexpr (nbatch_fa < 2*warp_size) { + constexpr int cols_per_warp = 2*warp_size/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; #pragma unroll for (int j1 = 0; j1 < ncols1; j1 += stride_j) { - const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp); const int j_vram = fastmodulo(j0 + j_sram, ne01); if (j1 + stride_j > ncols1 && j_sram >= ncols1) { break; } - const int i = threadIdx.x % (WARP_SIZE/cols_per_warp); + const int i = threadIdx.x % (warp_size/cols_per_warp); - ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); + ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + int64_t(j_vram)*stride_mask + 2*i); } } else { #pragma unroll @@ -355,17 +518,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } #pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) { + for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) { const int i = i0 + 2*threadIdx.x; - ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); + ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + int64_t(j_vram)*stride_mask + i); } } } } template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, - bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check, + bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check, typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ> static __device__ __forceinline__ void flash_attn_ext_f16_iter( const float2 * const __restrict__ Q_f2, @@ -393,33 +556,34 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; - constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int cols_per_thread = get_cols_per_thread(); + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols); constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols); constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols); constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2); - constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; - static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); - constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; const int k_VKQ_0 = kb0 * nbatch_fa; #if defined(TURING_MMA_AVAILABLE) T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; #else // Volta T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; #endif // defined(TURING_MMA_AVAILABLE) if constexpr (nstages > 1) { static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline"); - static_assert(!mla, "multi-stage loading not implemented for MLA"); + static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading"); static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading"); constexpr bool use_cp_async = true; cp_async_wait_all(); @@ -434,12 +598,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } + // For MLA K and V have the same data. + // Therefore, iterate over K in reverse and later re-use the data if possible. #pragma unroll - for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) { + for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) { const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2; - const int k0_diff = k0_stop - k0_start; if constexpr (nstages <= 1) { + const int k0_diff = k0_stop - k0_start; constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check> (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup); @@ -461,13 +627,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if constexpr (cols_per_warp == 8) { mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { - // Wide version of KQ_C is column-major => swap A and B. + // Wide version of KQ_C is column-major +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); +#else + // swap A and B for CUDA. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } } else { - static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented"); + constexpr int stride_tile_Q = DKQ/2 + 4; #pragma unroll for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) { load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q); @@ -479,8 +651,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_KQ K_A; load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K); - // Wide version of KQ_C is column-major => swap A and B. - mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); + if constexpr (cols_per_warp == 8) { + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); + } else { + // Wide version of KQ_C is column-major +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); +#else + // swap A and B for CUDA. + mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + } } } } @@ -532,7 +714,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { - KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = l % 2; +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); } } } @@ -542,7 +730,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = 16; offset >= 4; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size)); } } @@ -552,8 +740,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { - KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]); - KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = l % 2; +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]); + KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; } else { KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f; } @@ -564,6 +758,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) { const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J; + + // The mask is stored as 16 bit half values, loading them as 32 bit half2 values is preferred in terms of speed. + // However, this is not possible for RDNA3 where 2 consecutive l indices are not consecutive in the mask memory layout. +#ifdef RDNA3 +#pragma unroll + for (int l = 0; l < T_C_KQ::ne; ++l) { + const int i = i0 + T_C_KQ::get_j(l); + const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l)) / ncols2; + + KQ_C[i00/(np*T_C_KQ::J)].x[l] += __half2float(tile_mask[j*(nbatch_fa + 8) + i]); + } +#else #pragma unroll for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) { const int i = (i0 + T_C_KQ::get_j(l0)) / 2; @@ -573,6 +779,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x; KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y; } +#endif // RDNA3 } } @@ -584,8 +791,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else // Turing + Volta: - KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); + const int KQ_idx = (l/2) % 2; +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); } } } @@ -596,14 +808,22 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Values per KQ column are spread across 4 threads: constexpr int offset_first = 2; constexpr int offset_last = 1; -#else +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16). + constexpr int offset_first = 32; + constexpr int offset_last = 16; +#elif defined(AMD_WMMA_AVAILABLE) + // Values per KQ column are spread across 2 threads: + constexpr int offset_first = 16; + constexpr int offset_last = 16; +#else // Volta // Values per KQ column are spread across 2 threads: constexpr int offset_first = 2; constexpr int offset_last = 2; #endif // defined(TURING_MMA_AVAILABLE) #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size)); } } @@ -612,10 +832,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { - // Turing + Volta: if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { - KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]); - KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + constexpr int KQ_idx = 0; +#else + // Turing + Volta: + const int KQ_idx = (l/2) % 2; +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]); + KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; } else { KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f; } @@ -639,7 +864,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #if defined(TURING_MMA_AVAILABLE) if constexpr (cols_per_warp == 8) { - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]); #pragma unroll for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll @@ -660,6 +885,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { + static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type"); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale[0]; + } + } + } #else // Volta const half2 KQ_max_scale_h2 = make_half2( KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]); @@ -688,6 +933,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } if constexpr (nstages > 1) { + static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading"); // Preload K tile for next iteration: constexpr bool use_cp_async = true; cp_async_wait_all(); @@ -703,19 +949,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } - // For MLA K and V have the same data. - // Therefore, iterate over V in reverse and re-use the data if possible. - static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); - constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; - // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V: #pragma unroll - for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) { - const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0; - const int i0_diff = i0_stop - i0_start; + for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) { + static_assert(DV % (2*nbatch_V2) == 0, "bad loop size"); + const int i0_stop = i0_start + 2*nbatch_V2; if constexpr (nstages <= 1) { - if (i0_start < reusable_cutoff) { + const int i0_diff = i0_stop - i0_start; + if (!V_is_K_view || i0_stop > 2*nbatch_K2) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check> (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup); @@ -725,12 +967,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( __syncthreads(); } } - const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2; + const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; -#if defined(TURING_MMA_AVAILABLE) - constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) #pragma unroll - for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { + for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += T_A_VKQ::I) { static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size"); #pragma unroll for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) { @@ -739,10 +980,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); if constexpr (T_B_KQ::I == 8) { - mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]); } else { - // Wide version of VKQ_C is column-major => swap A and B. - mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); + // Wide version of VKQ_C is column-major. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]); +#else + // swap A and B for CUDA. + mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], B[k00/(np*T_A_VKQ::J)], A); +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } @@ -761,7 +1008,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); } } -#endif // defined(TURING_MMA_AVAILABLE) +#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. @@ -774,11 +1021,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } #if defined(TURING_MMA_AVAILABLE) -template<int ncols> struct mma_tile_sizes { +template<int DV, int ncols> struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major using T_C_KQ = tile<16, 16, float>; // column-major @@ -786,7 +1033,7 @@ template<int ncols> struct mma_tile_sizes { using T_B_VKQ = tile<16, 8, half2>; // column-major using T_C_VKQ = tile<16, 8, half2>; // column-major }; -template<> struct mma_tile_sizes<8> { +template<int DV> struct mma_tile_sizes<DV, 8> { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile< 8, 8, half2>; // column-major using T_C_KQ = tile<16, 8, float>; // row-major @@ -794,8 +1041,69 @@ template<> struct mma_tile_sizes<8> { using T_B_VKQ = tile< 8, 8, half2>; // column-major using T_C_VKQ = tile<16, 4, half2>; // row-major }; +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 +template<int DV, int ncols> struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR>; // column-major +}; +template<int ncols> struct mma_tile_sizes<80, ncols> { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major +}; +template<int ncols> struct mma_tile_sizes<112, ncols> { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major + using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major +}; +#else +template<int DV, int ncols> struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major + using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major + using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major + using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major + using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED>; // column-major +}; +template<int ncols> struct mma_tile_sizes<80, ncols> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +template<int ncols> struct mma_tile_sizes<112, ncols> { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) +template<int DV, int ncols> struct mma_tile_sizes { + using T_A_KQ = tile<16, 8, half2>; // row-major + using T_B_KQ = tile<16, 8, half2>; // column-major + using T_C_KQ = tile<16, 16, float>; // column-major + using T_A_VKQ = tile<16, 8, half2>; // row-major + using T_B_VKQ = tile<16, 8, half2>; // column-major + using T_C_VKQ = tile<16, 8, half2>; // column-major +}; #else // Volta -template<int ncols> struct mma_tile_sizes { +template<int DV, int ncols> struct mma_tile_sizes { using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major @@ -805,7 +1113,7 @@ template<int ncols> struct mma_tile_sizes { }; #endif // defined(TURING_MMA_AVAILABLE) -template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup> +template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup> static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float2 * const __restrict__ Q_f2, const half2 * const __restrict__ K_h2, @@ -819,6 +1127,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float logit_softcap, const uint3 ne01, const int ne02, + const int gqa_ratio, const int ne11, const int stride_Q1, const int stride_Q2, @@ -826,22 +1135,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int stride_V, const int stride_mask, const int jt, + const int zt_gqa, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; - using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ; - using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ; - using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ; - using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ; - using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ; - using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ; + using T_A_KQ = typename mma_tile_sizes<DV, ncols>::T_A_KQ; + using T_B_KQ = typename mma_tile_sizes<DV, ncols>::T_B_KQ; + using T_C_KQ = typename mma_tile_sizes<DV, ncols>::T_C_KQ; + using T_A_VKQ = typename mma_tile_sizes<DV, ncols>::T_A_VKQ; + using T_B_VKQ = typename mma_tile_sizes<DV, ncols>::T_B_VKQ; + using T_C_VKQ = typename mma_tile_sizes<DV, ncols>::T_C_VKQ; constexpr int cols_per_warp = T_B_KQ::I; - constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column. - constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column. + constexpr int cols_per_thread = get_cols_per_thread(); + constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column. constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols); constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols); constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols); @@ -859,8 +1170,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr int stride_tile_Q = DKQ/2 + 4; constexpr int stride_tile_K = nbatch_K2 + 4; - static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA"); - constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4; + constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4; constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V; extern __shared__ half2 tile_Q[]; @@ -871,6 +1181,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; #if defined(TURING_MMA_AVAILABLE) T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; +#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + T_C_VKQ VKQ_C[DV % 32 != 0 ? DV/T_C_VKQ::J : DV/(2*T_C_VKQ::J)]; +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #else // Volta T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #endif // defined(TURING_MMA_AVAILABLE) @@ -887,10 +1201,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) { + const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + const int stride_jc = warp_size / stride_k; if (k0_start == k0_stop) { continue; @@ -898,7 +1212,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { - const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { break; @@ -907,10 +1221,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j = jc / ncols2; const int c = jc % ncols2; - if (jt*ncols1 + j < int(ne01.z)) { + if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); @@ -918,7 +1232,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); } @@ -962,7 +1276,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check, + <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check, T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ> (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -971,7 +1285,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; const int k_VKQ_sup = ne11 - kb0*nbatch_fa; flash_attn_ext_f16_iter - <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check, + <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check, T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ> (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -982,7 +1296,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check, + <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check, T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ> (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -991,7 +1305,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = true; constexpr int k_VKQ_sup = nbatch_fa; flash_attn_ext_f16_iter - <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check, + <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check, T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ> (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap, ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, @@ -1010,6 +1324,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // The partial sums are spread across 8/4 threads. constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#elif defined(AMD_MFMA_AVAILABLE) + // The partial sums are spread across 4 threads (wavefront64, 16 cols). + constexpr int offset_first = 32; + constexpr int offset_last = 16; +#elif defined(AMD_WMMA_AVAILABLE) + // The partial sums are spread across 2 threads. + constexpr int offset_first = 16; + constexpr int offset_last = 16; #else // Volta // The partial sums are spread across 2 threads. constexpr int offset_first = 2; @@ -1019,19 +1341,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size); } } } // If attention sinks are used, potentially re-scale if KQ_max is small. - // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum + // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum // so it's being done unconditionally for every thread. if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { float KQ_max_scale[cols_per_thread]; #pragma unroll for (int col = 0; col < cols_per_thread; ++col) { - const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col); + const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col)); const float sink = sinks_f[jc % ncols2]; const float KQ_max_new = fmaxf(KQ_max[col], sink); @@ -1047,7 +1369,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #if defined(TURING_MMA_AVAILABLE) if constexpr (cols_per_warp == 8) { - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]); + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]); #pragma unroll for (int i = 0; i < DV/T_C_VKQ::I; ++i) { #pragma unroll @@ -1068,6 +1390,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } } +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) { + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); +#pragma unroll + for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale_h2; + } + } + } else { + static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type"); +#pragma unroll + for (int i = 0; i < DV/T_C_VKQ::J; ++i) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + VKQ_C[i].x[l] *= KQ_max_scale[0]; + } + } + } #else // Volta const int col = (threadIdx.x / 2) % 2; const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]); @@ -1119,6 +1461,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0); + const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]); + const bool thread_should_write = threadIdx.x / 16 < cols_per_thread; #else // Volta const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2); const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]); @@ -1149,14 +1495,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Warps with threadIdx.y % np != 0 must NOT return early. // All threads must return simultaneously to avoid race conditions with work on the next tile. - constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; + constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1; - const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; float2 meta[nmeta]; #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; + meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2]; } float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. @@ -1166,8 +1512,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset < WARP_SIZE) { - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + if (offset < warp_size) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size)); } } @@ -1184,8 +1530,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset < WARP_SIZE) { - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + if (offset < warp_size) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size); } } @@ -1194,19 +1540,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) { // Combined KQ max scale + rowsum. - meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); + meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); } } // Combined KQ max + rowsum. - static_assert(cols_per_warp <= WARP_SIZE); - if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + static_assert(cols_per_warp <= warp_size); + if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } - if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } @@ -1220,6 +1566,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) { if constexpr (cols_per_warp == 8) { + static_assert(std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>, "bad VKQ type"); const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data #pragma unroll for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) { @@ -1234,14 +1581,45 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } else { const int j0 = threadIdx.y*cols_per_warp; + if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) { + if constexpr (T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR) { #pragma unroll - for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { #pragma unroll - for (int l = 0; l < T_C_VKQ::ne; ++l) { - const int j = j0 + T_C_VKQ::get_i(l); - const int k = k1 + T_C_VKQ::get_j(l); + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = k1 + T_C_VKQ::get_j(l); + + tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + } + } + } else { + static_assert(T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR_SCRAMBLED, "bad T_C_VKQ data layout"); + using T_C_VKQ_us = tile<T_C_VKQ::I, T_C_VKQ::J, half2, DATA_LAYOUT_I_MAJOR>; // us == unscrambled +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) { + const T_C_VKQ_us VKQ_C_us = unscramble(VKQ_C[(k00 + k1)/T_C_VKQ::J]); +#pragma unroll + for (int l = 0; l < T_C_VKQ_us::ne; ++l) { + const int j = j0 + T_C_VKQ_us::get_i(l); + const int k = k1 + T_C_VKQ_us::get_j(l); - tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l]; + tile_Q[j*tile_stride + k] = VKQ_C_us.x[l]; + } + } + } + } else { + static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type"); + half * tile_Q_h = (half *) tile_Q; +#pragma unroll + for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J/2) { +#pragma unroll + for (int l = 0; l < T_C_VKQ::ne; ++l) { + const int j = j0 + T_C_VKQ::get_i(l); + const int k = 2*k1 + T_C_VKQ::get_j(l); + + tile_Q_h[j*(2*tile_stride) + k] = VKQ_C[(k00 + k1)/(T_C_VKQ::J/2)].x[l]; + } } } } @@ -1254,10 +1632,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) { + const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + const int stride_jc = warp_size / stride_k; if (k0_start == k0_stop) { continue; @@ -1265,7 +1643,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { - const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { break; @@ -1276,14 +1654,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int j_dst = jc_dst / ncols2; const int c_dst = jc_dst % ncols2; - if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) { + if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) { continue; } const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); float2 dstk_val = make_float2(0.0f, 0.0f); #pragma unroll @@ -1315,24 +1693,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #else GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup, - scale, slope, logit_softcap, ne01, ne02, + scale, slope, logit_softcap, ne01, ne02, gqa_ratio, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } -template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla> +template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view> __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -1346,13 +1724,33 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) + ggml_cuda_pdl_sync(); // TODO optimize placement +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: - if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { + if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) { NO_DEVICE_CODE; return; } + if (DKQ == 192 && ncols2 != 8 && ncols2 != 16) { + NO_DEVICE_CODE; + return; + } +#ifdef VOLTA_MMA_AVAILABLE + if (ncols1*ncols2 < 32) { + NO_DEVICE_CODE; + return; + } +#endif // VOLTA_MMA_AVAILABLE + #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING if (ncols1*ncols2 > 32) { NO_DEVICE_CODE; @@ -1360,12 +1758,25 @@ static __global__ void flash_attn_ext_f16( } #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING - static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV"); +#if defined(AMD_WMMA_AVAILABLE) + if (ncols1*ncols2 < 16 || ncols2 == 1 || DKQ > 128) { + NO_DEVICE_CODE; + return; + } +#endif // defined(AMD_WMMA_AVAILABLE) + +#if defined(AMD_MFMA_AVAILABLE) + if (ncols1*ncols2 < 16 || DKQ > 256) { + NO_DEVICE_CODE; + return; + } +#endif // defined(AMD_MFMA_AVAILABLE) + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); - constexpr int nwarps = nthreads / WARP_SIZE; + constexpr int nwarps = nthreads / warp_size; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. @@ -1374,14 +1785,15 @@ static __global__ void flash_attn_ext_f16( const int stride_K = nb11 / sizeof(half2); const int stride_mask = nb31 / sizeof(half); - const int stride_V = mla ? stride_K : nb21 / sizeof(half2); + const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2); - const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; - const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; // kbc == k block continuous, current index in continuous ijk space. - int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; - const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x; + int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; + const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x; // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined. // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup). @@ -1392,22 +1804,24 @@ static __global__ void flash_attn_ext_f16( int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc); while (kbc < kbc_stop && kb0_stop == iter_k) { - const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; - const int head0 = zt * ncols2; + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); const half * mask_h = ncols2 == 1 && !mask ? nullptr : (const half *) (mask + nb33*(sequence % ne33)); - float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); - const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; if (KV_max) { kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); @@ -1415,14 +1829,14 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. if (kb0_start == 0) { constexpr bool needs_fixup = false; // CUDA block is working on an entire tile. - flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup> + flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup> (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); } else { constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile. - flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup> + flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup> (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); } kbc += iter_k; @@ -1436,22 +1850,24 @@ static __global__ void flash_attn_ext_f16( return; } - const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2)); - const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2 - const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile. + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index. + const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; - const int head0 = zt * ncols2; + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. - const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0); - const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV); const half * mask_h = ncols2 == 1 && !mask ? nullptr : (const half *) (mask + nb33*(sequence % ne33)); - float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); + float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); - const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; + const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV); + const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr; - const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f; if (KV_max) { kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa); @@ -1459,11 +1875,11 @@ static __global__ void flash_attn_ext_f16( constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. constexpr bool needs_fixup = false; - flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup> + flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup> (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap, - ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); + ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop); #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, @@ -1473,7 +1889,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)) } template <int DKQ, int DV, int ncols1, int ncols2> @@ -1492,10 +1908,11 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc); const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); - const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32); - const int nwarps = nthreads / WARP_SIZE; + const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc)); + const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size; + const int nwarps = nthreads / warp_size_host; - constexpr bool mla = DKQ == 576; + constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2); const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2); @@ -1512,33 +1929,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); +#if defined(GGML_USE_HIP) + using fattn_kernel_ptr_t = const void*; +#else + using fattn_kernel_ptr_t = fattn_kernel_t; +#endif // defined(GGML_USE_HIP) fattn_kernel_t fattn_kernel; if (logit_softcap == 0.0f) { constexpr bool use_logit_softcap = false; - fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>; + fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>; -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); shared_memory_limit_raised[id] = true; } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // !defined(GGML_USE_MUSA) } else { constexpr bool use_logit_softcap = true; - fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>; + fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>; -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_MUSA) static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false}; if (!shared_memory_limit_raised[id]) { - CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); + CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total)); shared_memory_limit_raised[id] = true; } -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#endif // !defined(GGML_USE_MUSA) } launch_fattn<DV, ncols1, ncols2> - (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host); } @@ -1581,7 +2003,27 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64) DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64) +extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); +extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); + // The number of viable configurations for Deepseek is very limited: extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); + +// Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build: +extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); + +// For GLM 4.7 Flash +extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); +extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu index 3fcb09b7a2b..c8281497d14 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -34,10 +34,22 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); } break; + case 192: { + GGML_ASSERT(V->ne[0] == 128); + ggml_cuda_flash_attn_ext_tile_case<192, 128>(ctx, dst); + } break; case 256: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; + case 320: { + GGML_ASSERT(V->ne[0] == 256); + ggml_cuda_flash_attn_ext_tile_case<320, 256>(ctx, dst); + } break; + case 512: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst); + } break; case 576: { GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index 7c4d6fe67fe..0a099810e14 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -62,12 +62,26 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) return 0; @@ -116,12 +130,26 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 16, 256, 2, 32, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 32, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) return 0; @@ -177,12 +205,27 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 512, 1, 128, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) @@ -239,12 +282,27 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 4, 128, 6, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 8, 128, 6, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 16, 256, 5, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(192, 128, 32, 256, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(320, 256, 32, 256, 2, 128, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) @@ -343,7 +401,7 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; - const half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + const __align__(16) half2 zero[cpy_ne] = {{0.0f, 0.0f}}; ggml_cuda_memcpy_1<cpy_nb>( tile_KV + i*(J/2 + J_padding) + j, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); @@ -394,11 +452,11 @@ static __device__ __forceinline__ void flash_attn_tile_load_tile( const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; - half2 tmp_h2[cpy_ne/2]; + __align__(16) half2 tmp_h2[cpy_ne/2]; ggml_cuda_memcpy_1<sizeof(tmp_h2)>( tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); - float2 tmp_f2[cpy_ne/2]; + __align__(16) float2 tmp_f2[cpy_ne/2]; #pragma unroll for (int l = 0; l < cpy_ne/2; ++l) { tmp_f2[l] = __half22float2(tmp_h2[l]); @@ -445,14 +503,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter_KQ( static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { - half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; - half2 Q_k[cpw][cpy_ne]; + __align__(16) half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + __align__(16) half2 Q_k[cpw][cpy_ne]; #else static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); #pragma unroll for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { - float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; - float Q_k[cpw][cpy_ne]; + __align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + __align__(16) float Q_k[cpw][cpy_ne]; #endif // FAST_FP16_AVAILABLE #pragma unroll @@ -602,9 +660,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #pragma unroll for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { #ifdef FAST_FP16_AVAILABLE - half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; + __align__(16) half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; #else - float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; + __align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; #endif // FAST_FP16_AVAILABLE #pragma unroll @@ -664,8 +722,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #ifdef FAST_FP16_AVAILABLE #pragma unroll for (int k1 = 0; k1 < nbatch_V; k1 += np) { - half2 V_k[(DVp/2)/warp_size]; - half2 KQ_k[cpw]; + __align__(16) half2 V_k[(DVp/2)/warp_size]; + __align__(16) half2 KQ_k[cpw]; constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; #pragma unroll @@ -676,7 +734,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter( for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); - half tmp[KQ_cs]; + __align__(16) half tmp[KQ_cs]; ggml_cuda_memcpy_1<KQ_cs*sizeof(half)>( &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); #pragma unroll @@ -696,8 +754,8 @@ static __device__ __forceinline__ void flash_attn_tile_iter( #else #pragma unroll for (int k1 = 0; k1 < nbatch_V; k1 += np) { - float2 V_k[(DVp/2)/warp_size]; - float KQ_k[cpw]; + __align__(16) float2 V_k[(DVp/2)/warp_size]; + __align__(16) float KQ_k[cpw]; constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; #pragma unroll @@ -730,14 +788,14 @@ static __device__ __forceinline__ void flash_attn_tile_iter( template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap> // D == head size __launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) static __global__ void flash_attn_tile( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -752,6 +810,14 @@ static __global__ void flash_attn_tile( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: @@ -759,7 +825,7 @@ static __global__ void flash_attn_tile( #ifdef GGML_USE_WMMA_FATTN (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) || #endif // GGML_USE_WMMA_FATTN - (use_logit_softcap && !(DV == 128 || DV == 256)) + (use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512)) ) { GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, max_bias, m0, m1, n_head_log2, logit_softcap, @@ -821,12 +887,12 @@ static __global__ void flash_attn_tile( __shared__ half2 Q_tmp[ncols * DKQ/2]; __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV]; __shared__ half KQ[ncols * nbatch_fa]; - half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; + __align__(16) half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; #else __shared__ float Q_tmp[ncols * DKQ]; __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV]; __shared__ float KQ[ncols * nbatch_fa]; - float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; + __align__(16) float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; #endif // FAST_FP16_AVAILABLE float KQ_max[cpw]; @@ -836,6 +902,8 @@ static __global__ void flash_attn_tile( } float KQ_sum[cpw] = {0.0f}; + ggml_cuda_pdl_sync(); + // Load Q data, convert to FP16 if fast: #pragma unroll for (int jc0 = 0; jc0 < cpw; ++jc0) { @@ -849,7 +917,7 @@ static __global__ void flash_attn_tile( #pragma unroll for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { - float tmp_f[cpy_ne_D] = {0.0f}; + __align__(16) float tmp_f[cpy_ne_D] = {0.0f}; ggml_cuda_memcpy_1<sizeof(tmp_f)> (tmp_f, &Q_f[c*(nb02/sizeof(float)) + fastmodulo(col_Q_0 + j, ne01)*(nb01/sizeof(float)) + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); @@ -860,7 +928,7 @@ static __global__ void flash_attn_tile( } #ifdef FAST_FP16_AVAILABLE - half2 tmp_h2[cpy_ne_D/2]; + __align__(16) half2 tmp_h2[cpy_ne_D/2]; #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); @@ -959,7 +1027,7 @@ static __global__ void flash_attn_tile( constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; #pragma unroll for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { - half2 tmp[cpy_ne_D]; + __align__(16) half2 tmp[cpy_ne_D]; ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -970,7 +1038,7 @@ static __global__ void flash_attn_tile( constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; #pragma unroll for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { - float tmp[cpy_ne_D]; + __align__(16) float tmp[cpy_ne_D]; ggml_cuda_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]); #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { @@ -1033,7 +1101,7 @@ static __global__ void flash_attn_tile( constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; #pragma unroll for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { - float2 tmp[cpy_ne_D]; + __align__(16) float2 tmp[cpy_ne_D]; #pragma unroll for (int i1 = 0; i1 < cpy_ne_D; ++i1) { tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); @@ -1066,7 +1134,7 @@ static __global__ void flash_attn_tile( } } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, @@ -1090,7 +1158,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm constexpr size_t nbytes_shared = 0; #ifdef GGML_USE_HIP - if constexpr (DV <= 128) { + if constexpr (DKQ <= 128) { if (Q->ne[1] > 32/ncols2) { constexpr int cols_per_block = 64; const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; @@ -1104,7 +1172,7 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm #endif // GGML_USE_HIP #ifndef GGML_USE_HIP - if constexpr (DV <= 256) + if constexpr (DKQ <= 256) #endif // GGML_USE_HIP { if (Q->ne[1] > 16/ncols2) { @@ -1118,14 +1186,16 @@ static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggm } } - if (Q->ne[1] > 8/ncols2) { - constexpr int cols_per_block = 16; - const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; - const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); - fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>; - launch_fattn<DV, cols_per_block/ncols2, ncols2> - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); - return; + if constexpr (ncols2 <= 16) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile<DKQ, DV, cols_per_block/ncols2, ncols2, use_logit_softcap>; + launch_fattn<DV, cols_per_block/ncols2, ncols2> + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } } if constexpr (ncols2 <= 8) { @@ -1178,18 +1248,56 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; + // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases. + // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented. const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc); - const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; + const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; - if constexpr (DV == 512) { + if constexpr (DKQ == 320) { + // This branch is only used for Mistral Small 4 which has a GQA ratio of 32. + // On AMD, simply use that GQA ratio with 32 columns / block since we always have enough SRAM. + // On NVIDIA however, the tile kernel is only used for GPUs that can't use the mma kernel (Pascal and older). + // Therefore, use a GQA ratio of 16 with 16 columns / block to stay below 48 kiB of SRAM / block. +#ifdef GGML_USE_HIP + if (use_gqa_opt && gqa_ratio % 32 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 32, use_logit_softcap>(ctx, dst); + return; + } +#else + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst); + return; + } +#endif // GGML_USE_HIP + GGML_ABORT("flash-attn tile (320/256): expected GQA ratio multiple of 32"); + } + + if constexpr (DKQ == 576) { if (use_gqa_opt && gqa_ratio % 16 == 0) { launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst); return; } + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst); + return; + } } - if constexpr (DV <= 256) { + if constexpr (DKQ == 192) { + // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn) + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst); + return; + } + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst); + return; + } + GGML_ABORT("flash-attn tile (192/128): expected GQA ratio multiple of 8"); + } + + if constexpr (DKQ <= 512 && DKQ != 320 && DKQ != 192) { if (use_gqa_opt && gqa_ratio % 8 == 0) { launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst); return; @@ -1200,13 +1308,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { - launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst); + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst); return; } - - launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst); - return; } GGML_ABORT("fatal error"); } @@ -1240,5 +1350,8 @@ extern DECL_FATTN_TILE_CASE( 80, 80); extern DECL_FATTN_TILE_CASE( 96, 96); extern DECL_FATTN_TILE_CASE(112, 112); extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(192, 128); extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(320, 256); +extern DECL_FATTN_TILE_CASE(512, 512); extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 4d167b95a07..69dd9368624 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -10,7 +10,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { return 128; } -// Currenlty llvm with the amdgcn target dose not support unrolling loops +// Currently llvm with the amdgcn target does not support unrolling loops // that contain a break that can not be resolved at compile time. #ifdef __clang__ #pragma clang diagnostic push @@ -19,14 +19,14 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size __launch_bounds__(ggml_cuda_fattn_vec_get_nthreads_device(), 1) static __global__ void flash_attn_ext_vec( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -40,7 +40,16 @@ static __global__ void flash_attn_ext_vec( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { + ggml_cuda_pdl_lc(); #ifdef FLASH_ATTN_AVAILABLE + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { @@ -75,17 +84,17 @@ static __global__ void flash_attn_ext_vec( #endif // GGML_USE_HIP constexpr int nthreads = ggml_cuda_fattn_vec_get_nthreads_device(); - constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; - constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + constexpr int nthreads_KQ = (type_K == GGML_TYPE_F16 || type_K == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 128 / cpy_nb : nthreads_V_q; static_assert(WARP_SIZE % nthreads_KQ == 0, "bad nthreads_K"); static_assert(WARP_SIZE % nthreads_V == 0, "bad nthreads_V"); - constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_rows_per_thread = (type_V == GGML_TYPE_F16 || type_V == GGML_TYPE_BF16) ? 2*cpy_ne : 4; constexpr int V_cols_per_iter = WARP_SIZE / nthreads_V; constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>(); - constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16 && type_K != GGML_TYPE_BF16; #ifdef V_DOT2_F32_F16_AVAILABLE constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>(); #else @@ -132,10 +141,12 @@ static __global__ void flash_attn_ext_vec( #ifdef V_DOT2_F32_F16_AVAILABLE half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely. #else - float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. + __align__(16) float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. #endif // V_DOT2_F32_F16_AVAILABLE int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + + ggml_cuda_pdl_sync(); if constexpr (Q_q8_1) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { @@ -200,7 +211,7 @@ static __global__ void flash_attn_ext_vec( for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { const int i = i0 + (nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ)*cpy_ne; - float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; + __align__(16) float2 tmp[cpy_ne] = {{0.0f, 0.0f}}; if (ncols == 1 || ic0 + j < int(ne01.z)) { ggml_cuda_memcpy_1<cpy_nb>(tmp, &Q_j[i]); ggml_cuda_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); @@ -323,8 +334,18 @@ static __global__ void flash_attn_ext_vec( #pragma unroll for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { half2 tmp[V_rows_per_thread/2]; - dequantize_V(V + k*nb21, tmp, - 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + if constexpr (type_V == GGML_TYPE_BF16) { + float2 tmp_f[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp_f, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { + tmp[i_VKQ_1] = __float22half2_rn(tmp_f[i_VKQ_1]); + } + } else { + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_V)*V_rows_per_thread); + } #pragma unroll for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { #pragma unroll @@ -493,7 +514,7 @@ static __global__ void flash_attn_ext_vec( dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, @@ -563,6 +584,7 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_BF16); \ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) @@ -570,6 +592,7 @@ EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) @@ -577,6 +600,7 @@ EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_BF16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) @@ -584,3 +608,4 @@ EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_BF16) diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 8694fd06c7b..6850716fc0d 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -24,14 +24,14 @@ namespace wmma = rocwmma; template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap> __launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1) static __global__ void flash_attn_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, + const char * Q_ptr, + const char * K_ptr, + const char * V_ptr, + const char * mask_ptr, + const char * sinks_ptr, + const int * KV_max_ptr, + float * dst_ptr, + float2 * dst_meta_ptr, const float scale, const float max_bias, const float m0, @@ -46,6 +46,14 @@ static __global__ void flash_attn_ext_f16( const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN)) + const char * GGML_CUDA_RESTRICT Q = Q_ptr; + const char * GGML_CUDA_RESTRICT K = K_ptr; + const char * GGML_CUDA_RESTRICT V = V_ptr; + const char * GGML_CUDA_RESTRICT mask = mask_ptr; + const char * GGML_CUDA_RESTRICT sinks = sinks_ptr; + const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr; // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -63,11 +71,19 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 + typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K; + typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V; + typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b; + typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ; + typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ; +#else typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K; typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V; typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b; typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ; typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ; +#endif constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -78,6 +94,7 @@ static __global__ void flash_attn_ext_f16( constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); + ggml_cuda_pdl_sync(); const int sequence = blockIdx.z / ne02; const int head = blockIdx.z - sequence*ne02; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. @@ -126,6 +143,19 @@ static __global__ void flash_attn_ext_f16( __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; + +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 + const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h); + const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h); + _Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ); + _Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ); +#else + const half * K_h_f16 = K_h; + const half * V_h_f16 = V_h; + half * KQ_f16 = KQ; + half * VKQ_f16 = VKQ; +#endif + #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -160,7 +190,7 @@ static __global__ void flash_attn_ext_f16( for (int i0 = 0; i0 < D; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded); } } @@ -180,7 +210,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; - wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -310,7 +340,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, + KQ_f16 + j0*(kqar*kqs_padded) + k, kqar*kqs_padded); } } @@ -328,7 +358,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -344,7 +374,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, wmma::mem_col_major); } @@ -472,7 +502,7 @@ static __global__ void flash_attn_ext_f16( dst_meta[j_dst_unrolled] = dst_meta_val; } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index cd3bfd4051a..aaf711a618c 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -18,7 +18,7 @@ #if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 #define GGML_USE_WMMA_FATTN #elif defined(RDNA4) -#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" +#warning "rocwmma fattn is not supported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" #endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 #endif // defined(GGML_HIP_ROCWMMA_FATTN) diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 0155406665c..d6c501b1d7e 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -18,12 +18,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con } } - if (turing_mma_available(cc) && Q->ne[1] <= 16/ncols2) { - ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst); - return; + if constexpr (ncols2 <= 16) { + if (Q->ne[1] <= 16/ncols2) { + ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 16/ncols2, ncols2>(ctx, dst); + return; + } } - if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) { + if (Q->ne[1] <= 32/ncols2 || (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) || + (GGML_CUDA_CC_IS_AMD(cc) && DKQ > 256)) { ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst); return; } @@ -33,6 +36,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con template <int DKQ, int DV> static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -46,7 +50,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con // are put into the template specialization without GQA optimizations. bool use_gqa_opt = mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; for (const ggml_tensor * t : {Q, K, V, mask}) { - if (t == nullptr) { + if (t == nullptr || ggml_is_quantized(t->type)) { continue; } for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { @@ -60,25 +64,55 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - if (use_gqa_opt && gqa_ratio % 8 == 0) { + // On Volta the GQA optimizations aren't as impactful vs. minimizing wasted compute: + if (cc == GGML_CUDA_CC_VOLTA) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst); + return; + } + + if constexpr (DKQ <= 256) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst); + return; + } + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst); + return; + } else { + GGML_ABORT("fatal error"); + } + } + + if (use_gqa_opt && gqa_ratio > 4) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio % 4 == 0) { + if (use_gqa_opt && gqa_ratio > 2) { ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 4>(ctx, dst); return; } - if (use_gqa_opt && gqa_ratio % 2 == 0) { - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst); - return; - } + if constexpr (DKQ <= 256) { + if (use_gqa_opt && gqa_ratio > 1) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 2>(ctx, dst); + return; + } - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 1>(ctx, dst); + } else { + GGML_ABORT("fatal error"); + } } static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; @@ -106,10 +140,46 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(V->ne[0] == 128); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<128, 128>(ctx, dst); break; + case 192: { + // MiMo-V2.5 / V2.5-Pro / V2-Flash: gqa_ratio is 8 (SWA) or 16 (full attn) + GGML_ASSERT(V->ne[0] == 128); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 16>(ctx, dst); + } else { + GGML_ASSERT(gqa_ratio % 8 == 0); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 8>(ctx, dst); + } + } break; case 256: GGML_ASSERT(V->ne[0] == 256); ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst); break; + case 320: + // For Mistral Small 4, go straight to the ncols1 switch (ncols2=32-only build). + GGML_ASSERT(V->ne[0] == 256); + { + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + const bool use_gqa_opt = mask && max_bias == 0.0f; + GGML_ASSERT(use_gqa_opt); + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(gqa_ratio % 32 == 0); + + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<320, 256, 32>(ctx, dst); + } + break; + case 512: + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst); + break; case 576: { // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels. GGML_ASSERT(V->ne[0] == 512); @@ -121,8 +191,50 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; - GGML_ASSERT(gqa_ratio % 16 == 0); - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + if (gqa_ratio == 20) { // GLM 4.7 Flash + if (cc >= GGML_CUDA_CC_DGX_SPARK) { + if (Q->ne[1] <= 8) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_BLACKWELL) { + if (Q->ne[1] <= 4 && K->ne[1] >= 65536) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { + if (Q->ne[1] <= 4) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + if (cc >= GGML_CUDA_CC_TURING) { + if (Q->ne[1] <= 4) { + if (K->ne[1] <= 16384) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 32>(ctx, dst); + break; + } + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + break; + } + // Volta: + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } else if (gqa_ratio % 16 == 0) { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); + } } break; default: GGML_ABORT("fatal error"); @@ -157,6 +269,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) @@ -164,6 +277,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) @@ -171,6 +285,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q4_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) @@ -178,6 +293,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) @@ -185,6 +301,7 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q5_1) FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) @@ -192,10 +309,20 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_Q8_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_BF16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #else FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #endif // GGML_CUDA_FA_ALL_QUANTS GGML_ABORT("fatal error"); @@ -230,7 +357,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // The effective batch size for the kernel can be increased by gqa_ratio. // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, - const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + for (const ggml_tensor * t : {Q, K, V, mask}) { + if (t == nullptr || ggml_is_quantized(t->type)) { + continue; + } + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (t->nb[i] % 16 != 0) { + gqa_opt_applies = false; + break; + } + } + } const int cc = ggml_cuda_info().devices[device].cc; @@ -247,11 +385,35 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } break; + case 192: + if (V->ne[0] != 128 || !gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + if (gqa_ratio % 8 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; + case 320: + if (V->ne[0] != 256 || !gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + if (gqa_ratio % 32 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; + case 512: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + break; case 576: if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!gqa_opt_applies || gqa_ratio % 16 != 0) { + if (!gqa_opt_applies) { return BEST_FATTN_KERNEL_NONE; } break; @@ -277,6 +439,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #endif // GGML_CUDA_FA_ALL_QUANTS case GGML_TYPE_Q4_0: case GGML_TYPE_Q8_0: + case GGML_TYPE_BF16: break; default: return BEST_FATTN_KERNEL_NONE; @@ -287,7 +450,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + // 192 satisfies % 64 == 0 but has no vec instance (DKQ != DV); force it onto the MMA path. + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && Q->ne[0] != 192 && K->ne[1] % FATTN_KQ_STRIDE == 0; // If Turing tensor cores are available, use them: if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { @@ -314,12 +478,13 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_MMA_F16; } + const int ncols2_max = Q->ne[0] == 320 ? 32 : ((Q->ne[0] == 576 || Q->ne[0] == 192) ? 16 : 8); + int gqa_ratio_eff = 1; + while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { + gqa_ratio_eff *= 2; + } + if (volta_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { - int gqa_ratio_eff = 1; - const int ncols2_max = Q->ne[0] == 576 ? 16 : 8; - while (gqa_ratio % (2*gqa_ratio_eff) == 0 && gqa_ratio_eff < ncols2_max) { - gqa_ratio_eff *= 2; - } if (can_use_vector_kernel && Q->ne[1] * gqa_ratio_eff <= 2) { return BEST_FATTN_KERNEL_VEC; } @@ -330,13 +495,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 192 && Q->ne[0] != 512 && Q->ne[0] != 576) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } return BEST_FATTN_KERNEL_WMMA_F16; } + // AMD MFMA needs a certain minimum batch size to outscale the tile kernel for large head sizes. + if ((amd_mfma_available(cc) && Q->ne[0] <= 256) && Q->ne[0] != 40 && Q->ne[0] != 72) { + if ((Q->ne[0] <= 64 && Q->ne[1] * gqa_ratio_eff > 8)) { + return BEST_FATTN_KERNEL_MMA_F16; + } + if ((Q->ne[0] <= 128 && Q->ne[1] * gqa_ratio_eff > 16)) { + return BEST_FATTN_KERNEL_MMA_F16; + } + if ((Q->ne[0] <= 256 && Q->ne[1] * gqa_ratio_eff > 64)) { + return BEST_FATTN_KERNEL_MMA_F16; + } + } + + // AMD WMMA is always faster than the tile kernel if the full tile width of 16 can be utilized. + if ((amd_wmma_available(cc) && gqa_opt_applies && Q->ne[0] <= 128) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[1] * gqa_ratio_eff > 8) { + return BEST_FATTN_KERNEL_MMA_F16; + } + // If there are no tensor cores available, use the generic tile kernel: if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { @@ -354,6 +537,41 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_TILE; } +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst) { + GGML_ASSERT(dst->op == GGML_OP_FLASH_ATTN_EXT); + + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + GGML_ASSERT(K != nullptr); + GGML_ASSERT(V != nullptr); + + const best_fattn_kernel kernel = ggml_cuda_get_best_fattn_kernel(device, dst); + + bool need_f16_K = false; + bool need_f16_V = false; + + switch (kernel) { + case BEST_FATTN_KERNEL_TILE: + case BEST_FATTN_KERNEL_WMMA_F16: + case BEST_FATTN_KERNEL_MMA_F16: + need_f16_K = true; + need_f16_V = true; + break; + case BEST_FATTN_KERNEL_VEC: + need_f16_K = K->type == GGML_TYPE_F32; + need_f16_V = V->type == GGML_TYPE_F32; + break; + case BEST_FATTN_KERNEL_NONE: + break; + } + + const ggml_cuda_flash_attn_ext_f16_extra_data f16_extra = + ggml_cuda_flash_attn_ext_get_f16_extra_data(dst, need_f16_K, need_f16_V); + + return f16_extra.end - (uintptr_t) dst->data; +} + void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_set_device(ctx.device); switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { diff --git a/ggml/src/ggml-cuda/fattn.cuh b/ggml/src/ggml-cuda/fattn.cuh index 78705d59951..f9a7e15fbd6 100644 --- a/ggml/src/ggml-cuda/fattn.cuh +++ b/ggml/src/ggml-cuda/fattn.cuh @@ -3,3 +3,5 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst); + +size_t ggml_cuda_flash_attn_ext_get_alloc_size(int device, const ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fwht.cu b/ggml/src/ggml-cuda/fwht.cu new file mode 100644 index 00000000000..184dc254c72 --- /dev/null +++ b/ggml/src/ggml-cuda/fwht.cu @@ -0,0 +1,101 @@ +#include "common.cuh" +#include "fwht.cuh" + +template <int N> +__launch_bounds__(4*ggml_cuda_get_physical_warp_size(), 1) +__global__ void fwht_cuda(const float * src, float * dst, const int64_t n_rows, const float scale) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + const int64_t r = (int64_t) blockIdx.x * blockDim.y + threadIdx.y; + + if (r >= n_rows) { + return; + } + + src += r * N; + dst += r * N; + + static constexpr int el_w = N / warp_size; + float reg[el_w]; + const int lane = threadIdx.x; + + ggml_cuda_pdl_sync(); +#pragma unroll + for (int i = 0; i < el_w; ++i) { + reg[i] = src[i * warp_size + lane] * scale; + } + +#pragma unroll + for (int h = 1; h < warp_size; h *= 2) { +#pragma unroll + for (int j = 0; j < el_w; j++) { + const float val = reg[j]; + const float val2 = __shfl_xor_sync(0xFFFFFFFF, val, h, warp_size); + + reg[j] = (lane & h) == 0 ? val + val2 : val2 - val; + } + } + +#pragma unroll + for (int h = warp_size; h < N; h *= 2) { + const int step = h / warp_size; +#pragma unroll + for (int j = 0; j < el_w; j += 2 * step) { +#pragma unroll + for (int k = 0; k < step; k++) { + const float x = reg[j + k]; + const float y = reg[j + k + step]; + + reg[j + k] = x + y; + reg[j + k + step] = x - y; + } + } + } + +#pragma unroll + for (int i = 0; i < el_w; ++i) { + dst[i * warp_size + lane] = reg[i]; + } +} + +bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst) { + GGML_ASSERT(ggml_are_same_shape(src, dst)); + if (!ggml_is_contiguous(src) || !ggml_is_contiguous(dst)) { + return false; + } + const int n = src->ne[0]; + const int64_t rows = ggml_nrows(src); + + const float * src_d = (const float *) src->data; + float * dst_d = (float *) dst->data; + + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + const int rows_per_block = 4; + + const int64_t num_blocks = (rows + rows_per_block - 1) / rows_per_block; + + cudaStream_t stream = ctx.stream(); + dim3 grid_dims(num_blocks, 1, 1); + dim3 block_dims(warp_size, rows_per_block, 1); + const ggml_cuda_kernel_launch_params launch_params = + ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); + + const float scale = 1 / sqrtf(n); + + switch (n) { + case 64: + ggml_cuda_kernel_launch(fwht_cuda<64>, launch_params, src_d, dst_d, rows, scale); + return true; + case 128: + ggml_cuda_kernel_launch(fwht_cuda<128>, launch_params, src_d, dst_d, rows, scale); + return true; + case 256: + ggml_cuda_kernel_launch(fwht_cuda<256>, launch_params, src_d, dst_d, rows, scale); + return true; + case 512: + ggml_cuda_kernel_launch(fwht_cuda<512>, launch_params, src_d, dst_d, rows, scale); + return true; + default: + return false; + } +} diff --git a/ggml/src/ggml-cuda/fwht.cuh b/ggml/src/ggml-cuda/fwht.cuh new file mode 100644 index 00000000000..cf3df94cafa --- /dev/null +++ b/ggml/src/ggml-cuda/fwht.cuh @@ -0,0 +1,4 @@ +#include "common.cuh" + +// Returns whether the Fast Walsh-Hadamard transform could be used. +bool ggml_cuda_op_fwht(ggml_backend_cuda_context & ctx, const ggml_tensor * src, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/gated_delta_net.cu b/ggml/src/ggml-cuda/gated_delta_net.cu new file mode 100644 index 00000000000..a547360eb06 --- /dev/null +++ b/ggml/src/ggml-cuda/gated_delta_net.cu @@ -0,0 +1,312 @@ +#include "gated_delta_net.cuh" +#include "ggml-cuda/common.cuh" + +template <int S_v, bool KDA, bool keep_rs_t> +__global__ void __launch_bounds__((ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v) * 4, 2) +gated_delta_net_cuda(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + const uint3 neqk1_magic, + const uint3 rq3_magic, + float scale, + int K) { + const uint32_t h_idx = blockIdx.x; + const uint32_t sequence = blockIdx.y; + // each warp owns one column, using warp-level primitives to reduce across rows + const int lane = threadIdx.x; + const int col = blockIdx.z * blockDim.y + threadIdx.y; + + const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic); + const uint32_t iq3 = fastdiv(sequence, rq3_magic); + + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_data = dst; + float * state = dst + attn_score_elems; + + // input state holds s0 only: [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + state += state_out_offset; + curr_state += state_in_offset + col * S_v; + attn_data += (sequence * n_tokens * H + h_idx) * S_v; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size() < S_v ? ggml_cuda_get_physical_warp_size() : S_v; + static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); + constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; + float s_shard[rows_per_lane]; + // state is stored transposed: M[col][i] = S[i][col], row col is contiguous + + ggml_cuda_pdl_sync(); +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = curr_state[i]; + } + + for (int t = 0; t < n_tokens; t++) { + const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1; + + const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1; + const float * beta_t = beta + gb_offset; + const float * g_t = g + gb_offset * (KDA ? S_v : 1); + + const float beta_val = *beta_t; + + // Cache k and q in registers + float k_reg[rows_per_lane]; + float q_reg[rows_per_lane]; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + k_reg[r] = k_t[i]; + q_reg[r] = q_t[i]; + } + + if constexpr (!KDA) { + const float g_val = expf(*g_t); + + // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] + float kv_shard = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + kv_shard += s_shard[r] * k_reg[r]; + } + float kv_col = warp_reduce_sum<warp_size>(kv_shard); + + // delta[col] = (v[col] - g * kv[col]) * beta + float delta_col = (v_t[col] - g_val * kv_col) * beta_val; + + // fused: S[i][col] = g * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_partial = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + s_shard[r] = g_val * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; + } + + float attn_col = warp_reduce_sum<warp_size>(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } + } else { + // kv[col] = sum_i g[i] * S[i][col] * k[i] + float kv_shard = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += expf(g_t[i]) * s_shard[r] * k_reg[r]; + } + + float kv_col = warp_reduce_sum<warp_size>(kv_shard); + + // delta[col] = (v[col] - kv[col]) * beta + float delta_col = (v_t[col] - kv_col) * beta_val; + + // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_partial = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = expf(g_t[i]) * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; + } + + float attn_col = warp_reduce_sum<warp_size>(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } + } + + attn_data += S_v * H; + + if constexpr (keep_rs_t) { + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + const int target_slot = (int) n_tokens - 1 - t; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } + } + + if constexpr (!keep_rs_t) { +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } + } +} + +template <bool KDA, bool keep_rs_t> +static void launch_gated_delta_net( + const float * q_d, const float * k_d, const float * v_d, + const float * g_d, const float * b_d, const float * s_d, + float * dst_d, + int64_t S_v, int64_t H, int64_t n_tokens, int64_t n_seqs, + int64_t sq1, int64_t sq2, int64_t sq3, + int64_t sv1, int64_t sv2, int64_t sv3, + int64_t sb1, int64_t sb2, int64_t sb3, + int64_t neqk1, int64_t rq3, + float scale, int K, cudaStream_t stream) { + //TODO: Add chunked kernel for even faster pre-fill + const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; + const int num_warps = 4; + dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps); + dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1); + + const uint3 neqk1_magic = init_fastdiv_values(neqk1); + const uint3 rq3_magic = init_fastdiv_values(rq3); + + int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); + switch (S_v) { + case 16: + ggml_cuda_kernel_launch(gated_delta_net_cuda<16, KDA, keep_rs_t>, launch_params, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); + break; + case 32: + ggml_cuda_kernel_launch(gated_delta_net_cuda<32, KDA, keep_rs_t>, launch_params, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); + break; + case 64: { + ggml_cuda_kernel_launch(gated_delta_net_cuda<64, KDA, keep_rs_t>, launch_params, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); + break; + } + case 128: { + ggml_cuda_kernel_launch(gated_delta_net_cuda<128, KDA, keep_rs_t>, launch_params, + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, + n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); + break; + } + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const int64_t S_v = nev0; + const int64_t H = nev1; + const int64_t n_tokens = nev2; + const int64_t n_seqs = nev3; + + const bool kda = (src_g->ne[0] == S_v); + + GGML_ASSERT(neq1 == nek1); + const int64_t neqk1 = neq1; + + const int64_t rq3 = nev3 / neq3; + + const float * q_d = (const float *) src_q->data; + const float * k_d = (const float *) src_k->data; + const float * v_d = (const float *) src_v->data; + const float * g_d = (const float *) src_g->data; + const float * b_d = (const float *) src_beta->data; + + const float * s_d = (const float *) src_state->data; + float * dst_d = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_are_same_stride(src_q, src_k)); + GGML_ASSERT(src_g->ne[0] == 1 || kda); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + // strides in floats (beta strides used for both g and beta offset computation) + const int64_t sq1 = nbq1 / sizeof(float); + const int64_t sq2 = nbq2 / sizeof(float); + const int64_t sq3 = nbq3 / sizeof(float); + const int64_t sv1 = nbv1 / sizeof(float); + const int64_t sv2 = nbv2 / sizeof(float); + const int64_t sv3 = nbv3 / sizeof(float); + const int64_t sb1 = nbb1 / sizeof(float); + const int64_t sb2 = nbb2 / sizeof(float); + const int64_t sb3 = nbb3 / sizeof(float); + + const float scale = 1.0f / sqrtf((float) S_v); + + cudaStream_t stream = ctx.stream(); + + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int K = ggml_get_op_params_i32(dst, 0); + const bool keep_rs = K > 1; + + if (kda) { + if (keep_rs) { + launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } + } else { + if (keep_rs) { + launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } + } +} diff --git a/ggml/src/ggml-cuda/gated_delta_net.cuh b/ggml/src/ggml-cuda/gated_delta_net.cuh new file mode 100644 index 00000000000..7375e81c0c3 --- /dev/null +++ b/ggml/src/ggml-cuda/gated_delta_net.cuh @@ -0,0 +1,4 @@ +#include "common.cuh" +#include "ggml.h" + +void ggml_cuda_op_gated_delta_net(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/getrows.cu b/ggml/src/ggml-cuda/getrows.cu index 2fab33243dd..eb157b8baf2 100644 --- a/ggml/src/ggml-cuda/getrows.cu +++ b/ggml/src/ggml-cuda/getrows.cu @@ -6,17 +6,19 @@ template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> static __global__ void k_get_rows( const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + ggml_cuda_pdl_sync(); + for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. const int i10 = blockIdx.x; - const int i11 = z / ne12; // TODO fastdiv - const int i12 = z % ne12; + const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv); + const int i11 = dm.x; + const int i12 = dm.y; const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; @@ -40,19 +42,25 @@ static __global__ void k_get_rows( template<typename src0_t, typename dst_t> static __global__ void k_get_rows_float( - const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, + const src0_t * src0_ptr, const int32_t * src1_ptr, dst_t * dst_ptr, const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ - /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/ + /*const int64_t ne10,*/ const int64_t ne11, const uint3 ne12_fdv, /*const int64_t ne13,*/ /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { - for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) { + ggml_cuda_pdl_lc(); + const src0_t * GGML_CUDA_RESTRICT src0 = src0_ptr; + const int32_t * GGML_CUDA_RESTRICT src1 = src1_ptr; + dst_t * GGML_CUDA_RESTRICT dst = dst_ptr; + ggml_cuda_pdl_sync(); + for (int64_t z = blockIdx.z; z < ne11*(int64_t)ne12_fdv.z; z += gridDim.z) { for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. const int i10 = blockIdx.x; - const int i11 = z / ne12; // TODO fastdiv - const int i12 = z % ne12; + const uint2 dm = fast_div_modulo((uint32_t)z, ne12_fdv); + const int i11 = dm.x; + const int i12 = dm.y; if (i00 >= ne00) { return; @@ -81,6 +89,7 @@ static __global__ void k_get_rows_back_float( float sum = 0.0f; + ggml_cuda_pdl_sync(); for (int64_t i = 0; i < nrows_grad; ++i) { if (rows[i] != dst_row) { continue; @@ -115,10 +124,14 @@ static void get_rows_cuda_q( GGML_ASSERT(ne00 % 2 == 0); + GGML_ASSERT(ne12 > 0); + GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12); + const uint3 ne12_fdv = init_fastdiv_values(ne12); + k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>( src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10,*/ ne11, ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12_fdv, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); @@ -146,10 +159,15 @@ static void get_rows_cuda_float( const size_t s12 = nb12 / sizeof(int32_t); // const size_t s13 = nb13 / sizeof(int32_t); - k_get_rows_float<<<block_nums, block_dims, 0, stream>>>( + GGML_ASSERT(ne12 > 0); + GGML_ASSERT(ne11 <= std::numeric_limits<uint32_t>::max() / ne12); + const uint3 ne12_fdv = init_fastdiv_values(ne12); + + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{block_nums, block_dims, 0, stream}; + ggml_cuda_kernel_launch(k_get_rows_float<src0_t, dst_t>, launch_params, src0_d, src1_d, dst_d, ne00, /*ne01, ne02, ne03,*/ - /*ne10,*/ ne11, ne12, /*ne13,*/ + /*ne10,*/ ne11, ne12_fdv, /*ne13,*/ /* s0,*/ s1, s2, s3, /* nb00,*/ nb01, nb02, nb03, s10, s11, s12/*, s13*/); @@ -179,6 +197,10 @@ static void ggml_cuda_get_rows_switch_src0_type( get_rows_cuda_float((const nv_bfloat16 *) src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); break; + case GGML_TYPE_Q1_0: + get_rows_cuda_q<QK1_0, QR1_0, dequantize_q1_0>(src0_d, src1_d, dst_d, + ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); + break; case GGML_TYPE_Q4_0: get_rows_cuda_q<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_d, dst_d, ne00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb1, nb2, nb3, stream); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c3ee2ea0667..61041bdc16b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2,6 +2,7 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" +#include "ggml-cuda/allreduce.cuh" #include "ggml-cuda/common.cuh" #include "ggml-cuda/acc.cuh" #include "ggml-cuda/add-id.cuh" @@ -23,6 +24,7 @@ #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" +#include "ggml-cuda/fwht.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmf.cuh" @@ -39,6 +41,7 @@ #include "ggml-cuda/rope.cuh" #include "ggml-cuda/roll.cuh" #include "ggml-cuda/scale.cuh" +#include "ggml-cuda/snake.cuh" #include "ggml-cuda/softcap.cuh" #include "ggml-cuda/softmax.cuh" #include "ggml-cuda/ssm-conv.cuh" @@ -53,6 +56,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/gated_delta_net.cuh" #include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" @@ -70,20 +74,23 @@ #include <condition_variable> #include <cstddef> #include <cstdint> -#include <float.h> +#include <cfloat> #include <initializer_list> #include <limits> #include <map> #include <memory> #include <mutex> -#include <stdarg.h> -#include <stdio.h> -#include <stdlib.h> +#include <cstdarg> +#include <cstdio> +#include <cstdlib> #include <string> #include <vector> static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); +#define GGML_LOG_WARN_ONCE(str) \ + { static std::once_flag warn_flag; std::call_once(warn_flag, []() { GGML_LOG_WARN(str); }); } + [[noreturn]] void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg) { int id = -1; // in case cudaGetDevice fails @@ -122,7 +129,10 @@ static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) err = cudaMallocManaged(ptr, size); #if defined(GGML_USE_HIP) if (err == hipSuccess) { - CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device)); + // hipMemAdviseSetCoarseGrain is an optional performance hint; + // ignore errors (e.g. hipErrorInvalidValue on some APU/iGPU configs). + (void)cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device); + (void)hipGetLastError(); // clear any error } // fall back to cudaMalloc if not supported (e.g. on Windows) @@ -203,7 +213,14 @@ static ggml_cuda_device_info ggml_cuda_init() { GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); int64_t total_vram = 0; - GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); + for (int id = 0; id < info.device_count; ++id) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); + total_vram += prop.totalGlobalMem; + } + GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices (Total VRAM: %zu MiB):\n", + __func__, info.device_count, (size_t)(total_vram / (1024 * 1024))); + total_vram = 0; std::vector<std::pair<int, std::string>> turing_devices_without_mma; for (int id = 0; id < info.device_count; ++id) { @@ -241,6 +258,7 @@ static ggml_cuda_device_info ggml_cuda_init() { #else info.devices[id].supports_cooperative_launch = false; #endif // !(GGML_USE_MUSA) + #if defined(GGML_USE_HIP) info.devices[id].smpbo = prop.sharedMemPerBlock; @@ -255,22 +273,25 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].cc += prop.minor * 0x10; } } - GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d\n", + GGML_LOG_INFO(" Device %d: %s, %s (0x%x), VMM: %s, Wave Size: %d, VRAM: %zu MiB\n", id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff, - device_vmm ? "yes" : "no", prop.warpSize); + device_vmm ? "yes" : "no", prop.warpSize, + (size_t)(prop.totalGlobalMem / (1024 * 1024))); #elif defined(GGML_USE_MUSA) // FIXME: Ensure compatibility with varying warp sizes across different MUSA archs. info.devices[id].warp_size = 32; info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100; info.devices[id].cc += prop.minor * 0x10; - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", - id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", + (size_t)(prop.totalGlobalMem / (1024 * 1024))); #else info.devices[id].smpbo = prop.sharedMemPerBlockOptin; info.devices[id].cc = 100*prop.major + 10*prop.minor; - GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n", - id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no"); + GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, VRAM: %zu MiB\n", + id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", + (size_t)(prop.totalGlobalMem / (1024 * 1024))); std::string device_name(prop.name); if (device_name == "NVIDIA GeForce MX450") { turing_devices_without_mma.push_back({ id, device_name }); @@ -285,6 +306,7 @@ static ggml_cuda_device_info ggml_cuda_init() { // TODO: Check for future drivers the default scheduling strategy and // remove this call again when cudaDeviceScheduleSpin is default. if (prop.major == 12 && prop.minor == 1) { + CUDA_CHECK(cudaSetDevice(id)); CUDA_CHECK(cudaSetDeviceFlags(cudaDeviceScheduleSpin)); } @@ -308,6 +330,22 @@ static ggml_cuda_device_info ggml_cuda_init() { // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + if (getenv("GGML_CUDA_P2P") != nullptr) { + for (int id = 0; id < info.device_count; ++id) { + ggml_cuda_set_device(id); + for (int id_other = 0; id_other < info.device_count; ++id_other) { + if (id == id_other) { + continue; + } + int can_access_peer; + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); + if (can_access_peer) { + CUDA_CHECK(cudaDeviceEnablePeerAccess(id_other, 0)); + } + } + } + } + return info; } @@ -336,15 +374,21 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { } ~ggml_cuda_pool_leg() { + clear_pool(); + GGML_ASSERT(pool_size == 0); + } + + void clear_pool() { ggml_cuda_set_device(device); for (int i = 0; i < MAX_BUFFERS; ++i) { ggml_cuda_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { CUDA_CHECK(cudaFree(b.ptr)); pool_size -= b.size; + b.ptr = nullptr; + b.size = 0; } } - GGML_ASSERT(pool_size == 0); } void * alloc(size_t size, size_t * actual_size) override { @@ -389,7 +433,20 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool { size_t look_ahead_size = (size_t) (1.05 * size); look_ahead_size = 256 * ((look_ahead_size + 255)/256); ggml_cuda_set_device(device); - CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device)); + cudaError_t err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaErrorMemoryAllocation) { + (void)cudaGetLastError(); + const size_t cached_bytes = pool_size; + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: alloc of %.2f MiB failed, flushing %.2f MiB of cached buffers and retrying\n", + device, look_ahead_size/1024.0/1024.0, cached_bytes/1024.0/1024.0); + CUDA_CHECK(cudaDeviceSynchronize()); + clear_pool(); + err = ggml_cuda_device_malloc(&ptr, look_ahead_size, device); + if (err == cudaSuccess) { + GGML_LOG_DEBUG(GGML_CUDA_NAME " pool[%d]: retry succeeded\n", device); + } + } + CUDA_CHECK(err); *actual_size = look_ahead_size; pool_size += look_ahead_size; #ifdef DEBUG_CUDA_MALLOC @@ -565,6 +622,18 @@ ggml_backend_cuda_context::~ggml_backend_cuda_context() { // cuda buffer +struct ggml_backend_cuda_device_context { + int device; + std::string name; + std::string description; + std::string pci_bus_id; + int op_offload_min_batch_size; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + std::mutex device_mutex; + int active_count = 0; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +}; + struct ggml_backend_cuda_buffer_context { int device; void * dev_ptr = nullptr; @@ -582,6 +651,13 @@ struct ggml_backend_cuda_buffer_context { static void ggml_backend_cuda_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + delete ctx; } @@ -616,26 +692,46 @@ static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer } static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread)); + CUDA_CHECK(cudaMemsetAsync((char *) tensor->data + offset, value, size, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } static void ggml_backend_cuda_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemcpy2DAsync( + (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + +static void ggml_backend_cuda_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cudaStreamPerThread)); + CUDA_CHECK(cudaMemcpy2DAsync( + data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } @@ -675,6 +771,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, + /* .set_tensor_2d = */ ggml_backend_cuda_buffer_set_tensor_2d, + /* .get_tensor_2d = */ ggml_backend_cuda_buffer_get_tensor_2d, /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor, /* .clear = */ ggml_backend_cuda_buffer_clear, /* .reset = */ NULL, @@ -712,6 +810,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac ggml_backend_cuda_buffer_context * ctx = new ggml_backend_cuda_buffer_context(buft_ctx->device, dev_ptr); +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return ggml_backend_buffer_init(buft, ggml_backend_cuda_buffer_interface, ctx, size); } @@ -722,7 +826,11 @@ static size_t ggml_backend_cuda_buffer_type_get_alignment(ggml_backend_buffer_ty } static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { - size_t size = ggml_nbytes(tensor); + ggml_backend_cuda_buffer_type_context * buft_ctx = (ggml_backend_cuda_buffer_type_context *) buft->context; + + size_t size = tensor->op == GGML_OP_FLASH_ATTN_EXT + ? ggml_cuda_flash_attn_ext_get_alloc_size(buft_ctx->device, tensor) + : ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; if (ggml_is_quantized(tensor->type)) { @@ -733,8 +841,6 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t } return size; - - GGML_UNUSED(buft); } static const ggml_backend_buffer_type_i ggml_backend_cuda_buffer_type_interface = { @@ -987,6 +1093,8 @@ static const ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_cuda_split_buffer_clear, /* .reset = */ NULL, @@ -1063,6 +1171,295 @@ static const ggml_backend_buffer_type_i ggml_backend_cuda_split_buffer_type_inte /* .is_host = */ ggml_backend_cuda_split_buffer_type_is_host, }; +// Communication context for multi-GPU AllReduce during tensor parallelism. +// +// Created once per meta backend instance. Resources for the selected mode +// (NCCL communicators or the internal AllReduce pipeline) are initialised +// eagerly during comm_init so any init failure surfaces at startup rather +// than mid-run. +struct ggml_backend_cuda_comm_context { + using try_allreduce_fn = bool(*)(ggml_backend_cuda_comm_context *, struct ggml_tensor **); + + std::vector<ggml_backend_t> backends; + std::vector<int> dev_ids; + + // Set by the init chain (comm_init_{nccl, internal, none}) to one of + // try_allreduce_{nccl, internal, butterfly}. nccl needs `comms`, + // internal needs `ar_pipeline`, butterfly needs nothing. Per-call + // failures return false; the meta backend's generic implementation then + // handles that call. + try_allreduce_fn try_allreduce = nullptr; + + ggml_cuda_ar_pipeline * ar_pipeline = nullptr; + +#ifdef GGML_USE_NCCL + std::vector<ncclComm_t> comms; +#endif // GGML_USE_NCCL + + ~ggml_backend_cuda_comm_context() { +#ifdef GGML_USE_NCCL + for (ncclComm_t comm : comms) { + NCCL_CHECK(ncclCommDestroy(comm)); + } +#endif // GGML_USE_NCCL + ggml_cuda_ar_pipeline_free(ar_pipeline); + } +}; + +#ifdef GGML_USE_NCCL +// AllReduce via NCCL. Reduces as FP32 for small tensors and BF16 for large +// tensors (bandwidth-bound), then converts back to FP32. +static bool ggml_backend_cuda_comm_allreduce_nccl( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { + const int64_t ne = ggml_nelements(tensors[0]); + // FIXME the input of llm_graph_context::build_in_out_ids can produce a tensor with 0 elements if n_outputs == 0 + // This then causes a crash in this function + if (ne == 0) { + return true; + } + + const size_t n_backends = comm_ctx->backends.size(); + + for (size_t i = 0; i < n_backends; ++i) { + GGML_ASSERT(tensors[i] != nullptr); + GGML_ASSERT(ggml_nelements(tensors[i]) == ne); + GGML_ASSERT(ggml_is_contiguously_allocated(tensors[i])); + } + + // For small tensors, simply reduce them as FP32. + // The following heuristic for how "small" a tensor should be is based on RTX 4090s connected via 16x PCIe 4.0. + if ((n_backends <= 2 && ne < 32768) || (n_backends == 3 && ne < 131072) || (n_backends >= 4 && ne < 262144)) { + for (size_t i = 0; i < n_backends; ++i) { + if ((tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + ggml_cuda_set_device(cuda_ctx->device); + CUDA_CHECK(cudaMemsetAsync(tensors[i]->data, 0, ggml_nbytes(tensors[i]), cuda_ctx->stream())); + } + } + NCCL_CHECK(ncclGroupStart()); + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + NCCL_CHECK(ncclAllReduce(tensors[i]->data, tensors[i]->data, ne, ncclFloat, ncclSum, comm_ctx->comms[i], cuda_ctx->stream())); + } + NCCL_CHECK(ncclGroupEnd()); + return true; + } + + // For large tensors it's faster to compress them to BF16 for the reduction: + to_bf16_cuda_t to_bf16 = ggml_get_to_bf16_cuda(GGML_TYPE_F32); + to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(GGML_TYPE_BF16); + + ggml_cuda_pool_alloc<nv_bfloat16> tmp[GGML_CUDA_MAX_DEVICES]; + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + tmp[i].pool = &cuda_ctx->pool(); + tmp[i].alloc(ne); + + ggml_cuda_set_device(cuda_ctx->device); + if (tensors[i]->flags & GGML_TENSOR_FLAG_COMPUTE) { + to_bf16(tensors[i]->data, tmp[i].get(), ne, cuda_ctx->stream()); + } else { + CUDA_CHECK(cudaMemsetAsync(tmp[i].get(), 0, ne * sizeof(nv_bfloat16), cuda_ctx->stream())); + } + CUDA_CHECK(cudaGetLastError()); + } + + NCCL_CHECK(ncclGroupStart()); + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + NCCL_CHECK(ncclAllReduce(tmp[i].get(), tmp[i].get(), ne, ncclBfloat16, ncclSum, comm_ctx->comms[i], cuda_ctx->stream())); + } + NCCL_CHECK(ncclGroupEnd()); + + for (size_t i = 0; i < n_backends; ++i) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) comm_ctx->backends[i]->context; + + ggml_cuda_set_device(cuda_ctx->device); + to_fp32(tmp[i].get(), (float *) tensors[i]->data, ne, cuda_ctx->stream()); + CUDA_CHECK(cudaGetLastError()); + } + + return true; +} +#endif // GGML_USE_NCCL + +// Run the internal AR pipeline. Returns false on unsupported / failed input +// -- the caller decides whether to abort (env-forced) or fall back silently. +static bool ggml_backend_cuda_comm_allreduce_internal( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { + GGML_ASSERT(comm_ctx->ar_pipeline != nullptr); + + const size_t n_backends = comm_ctx->backends.size(); + GGML_ASSERT(n_backends == 2); + GGML_ASSERT(tensors[0] != nullptr); + + const int64_t ne = ggml_nelements(tensors[0]); + const ggml_type type = tensors[0]->type; + + if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16 && type != GGML_TYPE_BF16) { + GGML_LOG_DEBUG("%s: internal unsupported: type=%d\n", __func__, (int) type); + return false; + } + + if (ne == 0) { + return true; + } + + for (size_t i = 0; i < n_backends; ++i) { + if (tensors[i] == nullptr) { + GGML_LOG_ERROR("%s: internal failed: tensor[%zu] is null\n", __func__, i); + return false; + } + if (ggml_nelements(tensors[i]) != ne || tensors[i]->type != type) { + GGML_LOG_ERROR("%s: internal failed: tensor[%zu] ne=%" PRId64 " type=%d expected ne=%" PRId64 " type=%d\n", + __func__, i, ggml_nelements(tensors[i]), (int) tensors[i]->type, ne, (int) type); + return false; + } + if (!ggml_is_contiguously_allocated(tensors[i])) { + GGML_LOG_DEBUG("%s: internal unsupported: tensor[%zu] is not contiguously allocated: ne=%" PRId64 " nbytes=%zu packed=%zu type=%d\n", + __func__, i, ne, ggml_nbytes(tensors[i]), + (size_t) ne * ggml_type_size(type) / ggml_blck_size(type), (int) type); + return false; + } + if (((uintptr_t) tensors[i]->data & 0xF) != 0) { + GGML_LOG_DEBUG("%s: internal unsupported: tensor[%zu] data pointer is not 16-byte aligned: %p type=%d ne=%" PRId64 "\n", + __func__, i, tensors[i]->data, (int) type, ne); + return false; + } + GGML_ASSERT((ggml_nbytes(tensors[i]) & 0xF) == 0); + } + + return ggml_cuda_ar_allreduce(comm_ctx->ar_pipeline, comm_ctx->backends.data(), tensors); +} + +// --------------------------------------------------------------------------- +// Per-call dispatch -- three variants, one per backend. Each is set as +// comm_ctx->try_allreduce by the matching init step. Per-call failure +// returns false; the meta backend's generic implementation handles that call. +// --------------------------------------------------------------------------- + +#ifdef GGML_USE_NCCL +static bool ggml_backend_cuda_comm_try_allreduce_nccl( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { + return ggml_backend_cuda_comm_allreduce_nccl(comm_ctx, tensors); +} +#endif // GGML_USE_NCCL + +static bool ggml_backend_cuda_comm_try_allreduce_internal( + ggml_backend_cuda_comm_context * comm_ctx, struct ggml_tensor ** tensors) { + return ggml_backend_cuda_comm_allreduce_internal(comm_ctx, tensors); +} + +static bool ggml_backend_cuda_comm_try_allreduce_butterfly( + ggml_backend_cuda_comm_context *, struct ggml_tensor **) { + return false; +} + +static void ggml_backend_cuda_comm_free(void * comm_ctx_v) { + if (comm_ctx_v == nullptr) { + return; + } + delete static_cast<ggml_backend_cuda_comm_context *>(comm_ctx_v); +} + +// --------------------------------------------------------------------------- +// Init -- chained nccl -> internal -> none. Each step tries to bring up its +// resource; on failure it warns and recurses into the next step. +// --------------------------------------------------------------------------- +static void ggml_backend_cuda_comm_init_none(ggml_backend_cuda_comm_context * ret) { + ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_butterfly; +} + +static void ggml_backend_cuda_comm_init_internal(ggml_backend_cuda_comm_context * ret) { + ret->ar_pipeline = ggml_cuda_ar_pipeline_init(ret->dev_ids.data(), ret->dev_ids.size()); + if (ret->ar_pipeline) { + ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_internal; + return; + } + + // Clear sticky CUDA error from the failed init. + (void) cudaGetLastError(); + GGML_LOG_WARN("internal AllReduce init failed (n_devices != 2?); " + "falling back to meta-backend butterfly\n"); + ggml_backend_cuda_comm_init_none(ret); +} + +static void ggml_backend_cuda_comm_init_nccl(ggml_backend_cuda_comm_context * ret) { +#ifdef GGML_USE_NCCL + const size_t n = ret->dev_ids.size(); + ret->comms.resize(n); + ncclResult_t rc = ncclCommInitAll(ret->comms.data(), (int) n, ret->dev_ids.data()); + if (rc == ncclSuccess) { + ret->try_allreduce = ggml_backend_cuda_comm_try_allreduce_nccl; + return; + } + + ret->comms.clear(); + GGML_LOG_WARN("NCCL init failed (%s); falling back to internal AllReduce\n", + ncclGetErrorString(rc)); +#else // GGML_USE_NCCL +#ifndef GGML_USE_HIP + GGML_LOG_WARN("NCCL not compiled in; falling back to internal AllReduce. " + "Recompile with -DGGML_CUDA_NCCL=ON for best multi-GPU performance.\n"); +#endif // !GGML_USE_HIP +#endif // GGML_USE_NCCL + + ggml_backend_cuda_comm_init_internal(ret); +} + +// Top-level init. Picks one of the three init paths based on +// GGML_CUDA_ALLREDUCE (or the platform default) and lets the chain handle +// any fallback. Unrecognised env values warn and fall through to the +// platform default. +static void * ggml_backend_cuda_comm_init(ggml_backend_t * backends, size_t n_backends) { + for (size_t i = 0; i < n_backends; i++) { + if (!ggml_backend_is_cuda(backends[i])) { + return nullptr; + } + } + + auto * ret = new ggml_backend_cuda_comm_context; + ret->backends.assign(backends, backends + n_backends); + ret->dev_ids.reserve(n_backends); + for (size_t i = 0; i < n_backends; i++) { + ret->dev_ids.push_back(static_cast<ggml_backend_cuda_context *>(backends[i]->context)->device); + } + + const char * env = getenv("GGML_CUDA_ALLREDUCE"); + if (!env) { + // Platform default: Linux uses NCCL, otherwise (generally Windows) internal +#if defined(__linux__) + ggml_backend_cuda_comm_init_nccl(ret); +#else + ggml_backend_cuda_comm_init_internal(ret); +#endif // defined(__linux__) + } else { + std::string env_str(env); + if (env_str == "nccl") { + ggml_backend_cuda_comm_init_nccl(ret); + } else if (env_str == "internal") { + ggml_backend_cuda_comm_init_internal(ret); + } else if (env_str == "none") { + ggml_backend_cuda_comm_init_none(ret); + } else { + GGML_LOG_WARN("unknown GGML_CUDA_ALLREDUCE value: %s\n", env); + ggml_backend_cuda_comm_init_none(ret); + } + } + + return ret; +} + +// Top-level dispatch -- calls the function pointer chosen by comm_init. +// Returns false to let the meta-backend's butterfly run. +static bool ggml_backend_cuda_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) { + if (comm_ctx_v == nullptr) { + return false; + } + auto * comm_ctx = static_cast<ggml_backend_cuda_comm_context *>(comm_ctx_v); + return comm_ctx->try_allreduce(comm_ctx, tensors); +} + ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split) { static std::mutex mutex; std::lock_guard<std::mutex> lock(mutex); @@ -1118,6 +1515,12 @@ static bool ggml_backend_buft_is_cuda_host(ggml_backend_buffer_type_t buft) { } static void ggml_backend_cuda_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buffer->buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + CUDA_CHECK(cudaFreeHost(buffer->context)); } @@ -1126,6 +1529,8 @@ static void * ggml_cuda_host_malloc(size_t size) { return nullptr; } + ggml_cuda_set_device(0); // cudaMallocHost can create the implicit CUDA device context, make sure that this is consistently done on device 0. + void * ptr = nullptr; cudaError_t err = cudaMallocHost((void **) &ptr, size); if (err != cudaSuccess) { @@ -1151,6 +1556,12 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm buffer->buft = buft; buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) buft->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return buffer; } @@ -1224,6 +1635,34 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( } } +struct cublas_force_compute_type { + bool fp32 = false; + bool fp16 = false; +}; + +static const cublas_force_compute_type & ggml_cuda_cublas_get_force_compute_type() { + static const cublas_force_compute_type compute_type = [] { + cublas_force_compute_type result; + + const bool ggml_cuda_force_cublas_compute_32f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F") != nullptr; + const bool ggml_cuda_force_cublas_compute_16f_env = getenv("GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F") != nullptr; + + GGML_ASSERT(ggml_cuda_force_cublas_compute_16f_env == false || ggml_cuda_force_cublas_compute_32f_env == false); + + if (ggml_cuda_force_cublas_compute_32f_env) { + GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_32F\n"); + result.fp32 = true; + } else if (ggml_cuda_force_cublas_compute_16f_env) { + GGML_LOG_INFO("Detected GGML_CUDA_FORCE_CUBLAS_COMPUTE_16F\n"); + result.fp16 = true; + } + + return result; + }(); + + return compute_type; +} + static void ggml_cuda_op_mul_mat_cublas( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, @@ -1252,7 +1691,12 @@ static void ggml_cuda_op_mul_mat_cublas( const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) || (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); - const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + const bool use_fp16 = + src0->type != GGML_TYPE_NVFP4 && + (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + ggml_is_contiguous(src0) && + row_diff == src0->ne[1] && + dst->op_params[0] == GGML_PREC_DEFAULT; if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id)); @@ -1306,7 +1750,13 @@ static void ggml_cuda_op_mul_mat_cublas( CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream)); - if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type(); + + if (!force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc) + || GGML_CUDA_CC_IS_RDNA4(cc) + || cc == GGML_CUDA_CC_VOLTA + || force_compute_type.fp32)) + { const float alpha = 1.0f; const float beta = 0.0f; CUBLAS_CHECK( @@ -1370,64 +1820,6 @@ static void ggml_cuda_op_mul_mat_cublas( GGML_UNUSED_VARS(dst, src1_ddq_i, src1_padded_row_size); } -static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) { - static bool peer_access_enabled = false; - - const bool enable_peer_access = n_tokens <= GGML_CUDA_PEER_MAX_BATCH_SIZE; - - if (peer_access_enabled == enable_peer_access) { - return; - } - -#ifdef NDEBUG - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - ggml_cuda_set_device(id); - CUDA_CHECK(cudaDeviceSynchronize()); - } - - for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) { - ggml_cuda_set_device(id); - - for (int id_other = 0; id_other < ggml_backend_cuda_get_device_count(); ++id_other) { - if (id == id_other) { - continue; - } - if (id != main_device && id_other != main_device) { - continue; - } - - int can_access_peer; - CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access_peer, id, id_other)); - if (can_access_peer) { - if (enable_peer_access) { - cudaError_t err = cudaDeviceEnablePeerAccess(id_other, 0); - if (err != cudaErrorPeerAccessAlreadyEnabled) { - CUDA_CHECK(err); - } else { - // reset the error - (void)cudaGetLastError(); - } - } else { - cudaError_t err = cudaDeviceDisablePeerAccess(id_other); - if (err != cudaErrorPeerAccessNotEnabled) { - CUDA_CHECK(err); - } else { - // reset the error - (void)cudaGetLastError(); - } - } - } - } - } - - ggml_cuda_set_device(main_device); -#endif // NDEBUG - - peer_access_enabled = enable_peer_access; - - GGML_UNUSED(main_device); -} - static cudaError_t ggml_cuda_Memcpy2DPeerAsync( void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) { @@ -1905,10 +2297,23 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct cudaDataType_t cu_data_type_b = traits::data_type; const void * alpha = traits::get_alpha(); const void * beta = traits::get_beta(); - const float alpha_f32 = 1.0f; - const float beta_f32 = 0.0f; - if (dst->op_params[0] == GGML_PREC_DEFAULT) { + const auto & force_compute_type = ggml_cuda_cublas_get_force_compute_type(); + + int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + static constexpr bool is_src0_type_f16 = src0_type == GGML_TYPE_F16; + + // bf16 and fp32 are already being computed in fp32 (ensure it using static_assert), + // so checking necessity of forced fp32 only for fp16 src0_type + static_assert(is_src0_type_f16 || traits::compute_type == CUBLAS_COMPUTE_32F); + + const bool need_compute_32f = is_src0_type_f16 && !force_compute_type.fp16 && (GGML_CUDA_CC_IS_CDNA(cc) + || GGML_CUDA_CC_IS_RDNA4(cc) + || cc == GGML_CUDA_CC_VOLTA + || force_compute_type.fp32); + + if (dst->op_params[0] == GGML_PREC_DEFAULT && !need_compute_32f) { if constexpr (src0_type == GGML_TYPE_F32) { dst_t = (char *) dst_ddf; // Direct F32 output } else { @@ -1918,18 +2323,10 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct } } else { dst_t = (char *) dst_ddf; - cu_compute_type = CUBLAS_COMPUTE_32F; - cu_data_type = CUDA_R_32F; - alpha = &alpha_f32; - beta = &beta_f32; - } - - int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { - cu_compute_type = CUBLAS_COMPUTE_32F; - alpha = &alpha_f32; - beta = &beta_f32; + cu_compute_type = batched_mul_mat_traits<GGML_TYPE_F32>::compute_type; + cu_data_type = batched_mul_mat_traits<GGML_TYPE_F32>::data_type; + alpha = batched_mul_mat_traits<GGML_TYPE_F32>::get_alpha(); + beta = batched_mul_mat_traits<GGML_TYPE_F32>::get_beta(); } GGML_ASSERT(ne12 % ne02 == 0); @@ -2214,6 +2611,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); + use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } } else { @@ -2222,6 +2620,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1], /*n_experts=*/0); use_mul_mat_f = use_mul_mat_f && ggml_cuda_should_use_mmf(src0->type, cc, warp_size, src0->ne, src0->nb, src1->ne[1], /*mul_mat_id=*/false); use_mul_mat_vec_f = use_mul_mat_vec_f && ggml_cuda_should_use_mmvf(src0->type, cc, src0->ne, src0->nb, src1->ne[1]); + use_mul_mat_vec_q = use_mul_mat_vec_q && ggml_cuda_should_use_mmvq(src0->type, cc, src1->ne[1]); any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_hardware_available(cc); } @@ -2239,6 +2638,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor bool use_batched_cublas_bf16 = src0->type == GGML_TYPE_BF16 && bf16_mma_hardware_available(cc); bool use_batched_cublas_f32 = src0->type == GGML_TYPE_F32; + const int32_t hint = ggml_get_op_params_i32(dst, 1); + if (hint == GGML_HINT_SRC0_IS_HADAMARD && !split && ggml_cuda_op_fwht(ctx, src1, dst)) { + return; + } + if (!split && use_mul_mat_vec_f) { // the custom F16 vector kernel can be used over batched cuBLAS GEMM // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention) @@ -2277,14 +2681,22 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + // [TAG_MUL_MAT_ID_CUDA_GRAPHS] if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - if (ne2 == 1) { + static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); + if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc); + if (ne2 <= mmvq_mmid_max) { + ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); + return; + } } else { - ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + if (GGML_CUDA_CC_IS_AMD(cc)) { + ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); + return; + } } - return; } if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) { @@ -2298,6 +2710,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } + // note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization + // TODO: add asserts to verify this. should work with CUDA, HIP, etc. cudaStream_t stream = ctx.stream(); GGML_ASSERT(nb12 % nb11 == 0); @@ -2413,11 +2827,6 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct ggml_tensor * dst) { - // why is this here instead of mul_mat? - if (dst->src[0] != nullptr && ggml_backend_buft_is_cuda_split(dst->src[0]->buffer->buft)) { - ggml_cuda_set_peer_access(dst->src[1]->ne[1], ctx.device); - } - switch (dst->op) { case GGML_OP_ARGMAX: ggml_cuda_argmax(ctx, dst); @@ -2723,6 +3132,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_cuda_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_GATED_DELTA_NET: + ggml_cuda_op_gated_delta_net(ctx, dst); + break; case GGML_OP_RWKV_WKV7: ggml_cuda_op_rwkv_wkv7(ctx, dst); break; @@ -2767,26 +3179,54 @@ static const char * ggml_backend_cuda_get_name(ggml_backend_t backend) { static void ggml_backend_cuda_free(ggml_backend_t backend) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) backend->device->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count--; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + delete cuda_ctx; delete backend; } static void ggml_backend_cuda_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); + CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cuda_ctx->stream())); } static void ggml_backend_cuda_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpyAsync(data, (const char *) tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_set_tensor_2d_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); + + CUDA_CHECK(cudaMemcpy2DAsync( + (char *) tensor->data + offset, stride_tensor, data, stride_data, size, n_copies, cudaMemcpyHostToDevice, cuda_ctx->stream())); +} + +static void ggml_backend_cuda_get_tensor_2d_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, + size_t offset, size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; GGML_ASSERT(buf->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) && "unsupported buffer type"); - CUDA_CHECK(cudaMemcpyAsync(data, (const char *)tensor->data + offset, size, cudaMemcpyDeviceToHost, cuda_ctx->stream())); + CUDA_CHECK(cudaMemcpy2DAsync( + data, stride_data, (const char *) tensor->data + offset, stride_tensor, size, n_copies, cudaMemcpyDeviceToHost, cuda_ctx->stream())); } static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { @@ -2797,21 +3237,21 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ return false; } - if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!ggml_backend_buffer_is_cuda(buf_src) || !ggml_backend_buffer_is_cuda(buf_dst)) { return false; } // device -> device copy - ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *)backend_src->context; - ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *)backend_dst->context; + ggml_backend_cuda_context * cuda_ctx_src = (ggml_backend_cuda_context *) backend_src->context; + ggml_backend_cuda_context * cuda_ctx_dst = (ggml_backend_cuda_context *) backend_dst->context; - ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; - ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; + ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *) buf_src->context; + ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *) buf_dst->context; if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); -#endif +#endif // NDEBUG return false; } @@ -2824,7 +3264,7 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ return false; #else CUDA_CHECK(cudaMemcpyPeerAsync(dst->data, cuda_ctx_dst->device, src->data, cuda_ctx_src->device, ggml_nbytes(dst), cuda_ctx_src->stream())); -#endif +#endif // GGML_CUDA_NO_PEER_COPY } // record event on src stream after the copy @@ -2858,14 +3298,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { bool use_cuda_graph = true; // Loop over nodes in GGML graph to obtain info needed for CUDA graph - const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; - const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; - const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; - const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; - const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; - const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out"; - const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d"; - for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2880,31 +3312,19 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { #endif } - if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { - use_cuda_graph = false; // This node type is not supported by CUDA graph capture + // [TAG_MUL_MAT_ID_CUDA_GRAPHS] + if (node->op == GGML_OP_MUL_MAT_ID) { + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc); + if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) { + // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs + // TODO: figure out a way to enable for larger batch sizes, without hurting performance + // ref: https://github.com/ggml-org/llama.cpp/pull/18958 + use_cuda_graph = false; #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); -#endif - } - - if (node->op == GGML_OP_ADD && - node->src[1] && node->src[1]->ne[1] > 1 && - (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && - (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && - strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && - strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && - strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { - // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation - // by means of matching node names. See - // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and - // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, - // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. - use_cuda_graph = false; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif + } } if (!use_cuda_graph) { @@ -2915,105 +3335,62 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { return use_cuda_graph; } -static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { - props->node_address = node->data; - props->node_op = node->op; - for (int i = 0; i < GGML_MAX_DIMS; i++) { - props->ne[i] = node->ne[i]; - props->nb[i] = node->nb[i]; - } - for (int i = 0; i < GGML_MAX_SRC; i++) { - props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; - } - memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); +static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) { + return cgraph->nodes[0]; } -static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { - if (node->data != props->node_address && - node->op != GGML_OP_VIEW) { - return false; - } - - if (node->op != props->node_op) { - return false; - } +static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { + bool res = false; - for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != props->ne[i]) { - return false; - } - if (node->nb[i] != props->nb[i]) { - return false; - } - } + const void * graph_key = ggml_cuda_graph_get_key(cgraph); + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); - for (int i = 0; i < GGML_MAX_SRC; i++) { - if (node->src[i] && - node->src[i]->data != props->src_address[i] && - node->op != GGML_OP_VIEW - ) { - return false; - } - } - - if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && - memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + if (cgraph->uid != 0 && + cgraph->uid == graph->uid) { + GGML_LOG_DEBUG("CUDA Graph id %zu reused\n", cgraph->uid); + GGML_ASSERT((int)graph->node_props.size() == cgraph->n_nodes); return false; } - return true; -} - -static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { - - bool res = false; - - if (cuda_ctx->cuda_graph->instance == nullptr) { - res = true; - } + graph->uid = cgraph->uid; // Check if the graph size has changed - if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { + if ((int)graph->node_props.size() != cgraph->n_nodes) { res = true; - cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); + graph->node_props.resize(cgraph->n_nodes); } - // Loop over nodes in GGML graph to determine if CUDA graph update is required - // and store properties to allow this comparison for the next token for (int i = 0; i < cgraph->n_nodes; i++) { - bool props_match = true; - if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]); - } - if (!props_match) { - res = true; + ggml_cuda_graph::node_properties prop = {}; + memcpy(&prop.node, cgraph->nodes[i], sizeof(ggml_tensor)); + + for (int j = 0; j < GGML_MAX_SRC; ++j) { + if (cgraph->nodes[i]->src[j]) { + prop.node_src_data_ptrs[j] = cgraph->nodes[i]->src[j]->data; + memcpy(prop.node_src_ne[j], cgraph->nodes[i]->src[j]->ne, sizeof(prop.node_src_ne[j])); + memcpy(prop.node_src_nb[j], cgraph->nodes[i]->src[j]->nb, sizeof(prop.node_src_nb[j])); + } } - ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]); - } - for (int i = 0; i < cgraph->n_leafs; i++) { - bool props_match= true; - if (!res) { - props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]); - } - if (!props_match) { + if (res || memcmp(&graph->node_props[i], &prop, sizeof(prop)) != 0) { + graph->node_props[i] = prop; res = true; } - ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); } return res; } -static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) { +static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info); + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &result_info); #else cudaGraphNode_t errorNode; cudaGraphExecUpdateResult result_info; - cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info); + cudaError_t stat = cudaGraphExecUpdate(graph->instance, graph->graph, &errorNode, &result_info); #endif // CUDART_VERSION >= 12000 if (stat == cudaErrorGraphExecUpdateFailure) { @@ -3024,14 +3401,14 @@ static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_c // The pre-existing graph exec cannot be updated due to violated constraints // so instead clear error and re-instantiate (void)cudaGetLastError(); - CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance)); - cuda_ctx->cuda_graph->instance = nullptr; - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + CUDA_CHECK(cudaGraphExecDestroy(graph->instance)); + graph->instance = nullptr; + CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); } else { GGML_ASSERT(stat == cudaSuccess); } } -#endif +#endif // USE_CUDA_GRAPH static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, const ggml_tensor * view, @@ -3067,63 +3444,231 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope, return true; } -static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) { -#ifndef NDEBUG - const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY); - GGML_ASSERT(unary_ops.size() == num_unary); -#endif +static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) { + args.sigmoid = false; + args.softmax = false; + args.delayed_softmax = false; + args.prob_bias = false; + args.norm = false; - //TODO: remove special case once ggml_can_fuse can handle empty nodes - std::initializer_list<enum ggml_op> topk_moe_ops = - ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false); - std::initializer_list<enum ggml_op> topk_moe_ops_with_norm = - ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false); - std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax = - ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true); + const int n_nodes = cgraph->n_nodes; + ggml_tensor ** nodes = cgraph->nodes; - const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1, - const std::initializer_list<enum ggml_op> & list2) { - return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); - }; + if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) { + args.softmax = true; + } - if (is_equal(topk_moe_ops_with_norm, ops) && - ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { - ggml_tensor * softmax = cgraph->nodes[node_idx]; - ggml_tensor * weights = cgraph->nodes[node_idx + 9]; - ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; - ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; - int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + if (nodes[node_idx]->op == GGML_OP_UNARY) { + if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) { + return false; + } + args.sigmoid = true; + } - if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { - return true; + if (nodes[node_idx]->op == GGML_OP_ARGSORT) { + args.delayed_softmax = true; + } + + node_idx++; + + if (args.sigmoid || args.softmax) { + // SOFTMAX -> RESHAPE + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + ggml_tensor * probs_reshaped = nodes[node_idx]; + node_idx++; + + if (node_idx >= n_nodes) { + return false; + } + + // src of bias add is the unreshaped probs (-2 instead of -1) + if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) { + args.prob_bias = true; + node_idx++; + } + // RESHAPE/ADD -> ARGSORT + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) { + return false; + } + + if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) { + return false; + } + + node_idx++; + + // ARGSORT-> VIEW + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) { + return false; + } + + // GET_ROWS + if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + } else if (args.delayed_softmax) { + if (node_idx - 2 < 0) { + return false; + } + ggml_tensor * probs_reshaped = nodes[node_idx - 2]; + + // VIEW->ARGSORT + if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW || + nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; + + // GET_ROWS + if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] || + nodes[node_idx]->src[0] != probs_reshaped) { + return false; + } + node_idx++; + + static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; + + for (const ggml_op op : remaining_ops) { + if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + return false; + } + node_idx++; } } - if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) { - ggml_tensor * softmax = cgraph->nodes[node_idx]; - ggml_tensor * weights = cgraph->nodes[node_idx + 4]; - ggml_tensor * get_rows = cgraph->nodes[node_idx + 4]; - ggml_tensor * argsort = cgraph->nodes[node_idx + 2]; - int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + // At this point we can check for norm + scale. Everything is now at least valid till the norm + if (node_idx >= n_nodes) { + return true; + } + + if (nodes[node_idx]->op == GGML_OP_RESHAPE) { + //check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE + static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP }; + + args.norm = true; + for (const ggml_op op : norm_ops) { + if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) { + node_idx++; + } else { + args.norm = false; + return true; + } + } + + // DIV <- CLAMP, RESHAPE + if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] || + nodes[node_idx]->src[0] != nodes[node_idx - 3]) { + args.norm = false; + return true; + } + node_idx++; - if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { + if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) { + args.norm = false; return true; } + + node_idx++; } - if (is_equal(topk_moe_ops_delayed_softmax, ops) && - ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { - ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; - ggml_tensor * weights = cgraph->nodes[node_idx + 5]; - ggml_tensor * get_rows = cgraph->nodes[node_idx + 2]; - ggml_tensor * argsort = cgraph->nodes[node_idx + 0]; - int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0]; + if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) { + args.scale = true; + } - if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) { + return true; +} + +// returns whether the write (out) nodes overwrite the read nodes in operation +static bool ggml_cuda_check_fusion_memory_ranges(const ggml_cgraph * cgraph, + const int node_idx, + const int node_count, + const int * out_nodes, + const int out_count, + const bool is_topk_moe = false) { + auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) { + const int64_t a_start = (int64_t) a->data; + const int64_t a_end = a_start + ggml_backend_buft_get_alloc_size(a->buffer->buft, a); + + const int64_t b_start = (int64_t) b->data; + const int64_t b_end = b_start + ggml_backend_buft_get_alloc_size(b->buffer->buft, b); + + if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) { return true; } + + return false; + }; + + bool is_ok = true; + // exception for topk-moe, as each row is read entirely before writing + if (ggml_nrows(cgraph->nodes[node_idx]) == 1 && is_topk_moe) { + return true; + } + + for (int i = 0; i < out_count; ++i) { + const ggml_tensor * dst = cgraph->nodes[out_nodes[i]]; + + for (int j = node_idx; j < node_idx + node_count; ++j) { + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[j]->src[src_idx]; + + if (!src || src->op == GGML_OP_NONE) { + continue; + } + + if (nodes_overlap(dst, src)) { + bool found = false; + + for (int k = node_idx; k < j; ++k) { + if (cgraph->nodes[k] == src) { + found = true; + break; + } + } + + if (!found) { + is_ok = false; + break; + } + } + } + } } + return is_ok; +} + + +static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, + int node_idx, + std::initializer_list<enum ggml_op> ops, + std::initializer_list<enum ggml_unary_op> unary_ops) { +#ifndef NDEBUG + const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY); + GGML_ASSERT(unary_ops.size() == num_unary); +#endif + + const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1, + const std::initializer_list<enum ggml_op> & list2) { + return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); + }; + std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU }; std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU }; @@ -3139,7 +3684,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const ggml_tensor * glu = cgraph->nodes[node_idx + 4]; if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu, ffn_up_bias, ffn_gate_bias)) { - return true; + int out_nodes[] = { node_idx + 4 }; + return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1); } } @@ -3150,7 +3696,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; if (ggml_cuda_should_fuse_mul_mat(ffn_up, ffn_gate, glu)) { - return true; + int out_nodes[] = { node_idx + 2 }; + return ggml_cuda_check_fusion_memory_ranges(cgraph, node_idx, (int)ops.size(), out_nodes, 1); } } @@ -3200,7 +3747,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } - //rms_norm kernel assumes contigous rows + //rms_norm kernel assumes contiguous rows if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { return false; } @@ -3212,6 +3759,98 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return true; } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY + && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { + const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; + const ggml_tensor * silu = cgraph->nodes[node_idx+1]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } + + if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + + return true; + } + + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD + && ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { + const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; + const ggml_tensor * add = cgraph->nodes[node_idx+1]; + const ggml_tensor * silu = cgraph->nodes[node_idx+2]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } + + if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + + // ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias. + const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0]; + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) { + return false; + } + + return true; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL + && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) { + const ggml_tensor * unary = cgraph->nodes[node_idx]; + const ggml_tensor * mul = cgraph->nodes[node_idx+1]; + + if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) { + return false; + } + + if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) { + return false; + } + + if (unary->type != mul->type) { + return false; + } + + const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0]; + if (other->type != unary->type) { + return false; + } + if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) { + return false; + } + + return true; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_SQR + && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_RELU) { + const ggml_tensor * unary = cgraph->nodes[node_idx]; + const ggml_tensor * sqr = cgraph->nodes[node_idx+1]; + + if (ggml_get_unary_op(unary) != GGML_UNARY_OP_RELU) { + return false; + } + + if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) { + return false; + } + + if (unary->type != sqr->type) { + return false; + } + + if (!ggml_is_contiguous(unary->src[0])) { + return false; + } + + return true; + } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) { const ggml_tensor *scale = cgraph->nodes[node_idx]; @@ -3236,7 +3875,407 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } -static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) { +// try and fuse nodes and return the number of nodes to skip +static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, int i) { + + static bool disable_fusion = getenv("GGML_CUDA_DISABLE_FUSION") != nullptr && std::atoi(getenv("GGML_CUDA_DISABLE_FUSION")); + if (disable_fusion) { + return 0; + } + + ggml_tensor * node = cgraph->nodes[i]; + + //topk-moe + if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX || + cgraph->nodes[i]->op == GGML_OP_ARGSORT) { + ggml_cuda_topk_moe_args args; + const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args); + std::vector<ggml_op> ops; + + if (can_fuse) { + const ggml_tensor * logits = node->src[0]; + ggml_tensor * weights = nullptr; + ggml_tensor * ids = nullptr; + const ggml_tensor * bias = nullptr; + const ggml_tensor * clamp = nullptr; + const ggml_tensor * scale = nullptr; + + if (!args.delayed_softmax) { + ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX; + int out_nodes[2]; // nodes which can't be elided + + if (args.prob_bias) { + bias = cgraph->nodes[i + 2]->src[1]; + ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT, GGML_OP_VIEW, + GGML_OP_GET_ROWS }); + out_nodes[0] = i + 4; + ids = cgraph->nodes[i + 4]; + } else { + ops.insert(ops.end(), + { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS }); + out_nodes[0] = i + 3; + ids = cgraph->nodes[i + 3]; + } + + if (args.norm) { + ops.insert(ops.end(), + { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_RESHAPE }); + clamp = cgraph->nodes[i + ops.size() - 3]; + } + if (args.scale) { + ops.insert(ops.end(), { GGML_OP_SCALE }); + scale = cgraph->nodes[i + ops.size() - 1]; + } + + weights = cgraph->nodes[i + ops.size() - 1]; + out_nodes[1] = i + ops.size() - 1; + + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + return ops.size() - 1; + } + } else if (!args.norm && !args.prob_bias) { + //special case gpt-oss, no norm, no bias. + ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, + GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }); + weights = cgraph->nodes[i + 5]; + ids = cgraph->nodes[i + 1]; + const ggml_tensor * softmax = cgraph->nodes[i + 4]; + + int out_nodes[2] = { i + 1, i + 5 }; + if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && + ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2, /*is_topk_moe=*/true)) { + ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); + return ops.size() - 1; + } + } + } + } + + //RoPE + view + set-rows + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { + ggml_tensor * rope = cgraph->nodes[i]; + ggml_tensor * set_rows = cgraph->nodes[i + 2]; + + ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows); + return 2; + } + + // Snake activation: y = x + sin(a*x)^2 * inv_b + // Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add + if (ggml_can_fuse_subgraph(cgraph, i, + { GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD }, + { i + 4 })) { + const ggml_tensor * mul0 = cgraph->nodes[i]; + const ggml_tensor * sqr = cgraph->nodes[i + 2]; + const ggml_tensor * mul1 = cgraph->nodes[i + 3]; + ggml_tensor * add = cgraph->nodes[i + 4]; + + // x carries the full activation shape, a is the broadcast operand + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + // mul1 reads sqr and inv_b in either operand order + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + + // closure check: the trailing add must read the same x as the leading mul + const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; + + // Kernel iterates over total = T * C, so x and add must be 2D and + // a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled. + const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) && + (add->ne[2] == 1 && add->ne[3] == 1) && + (a->ne[2] == 1 && a->ne[3] == 1); + const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1]; + + // x must be in the supported whitelist and every operand / intermediate + // result must share x's type, since launch_snake casts a / inv_b as + // float and templates the kernel on a single T. Mixed precision chains + // fall back to the naive path. + const ggml_tensor * sin1 = cgraph->nodes[i + 1]; + const bool types_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16) && + (a->type == x->type) && (inv_b->type == x->type) && + (mul0->type == x->type) && (sin1->type == x->type) && + (sqr->type == x->type) && (mul1->type == x->type) && + (add->type == x->type); + + if (types_ok && shape_ok && dim_ok && x_in_add == x) { + ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add); + return 4; + } + } + + // multi-(add or mul) + if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) { + int n_fuse = 0; + ggml_op ops[8]; + std::fill(ops, ops + 8, node->op); + + for (; n_fuse <= 6; ++n_fuse) { + if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { + break; + } + if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { + break; + } + if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { + break; + } + } + + n_fuse++; + + if (n_fuse > 1) { + ggml_tensor fused_node; + memcpy(&fused_node, node, sizeof(ggml_tensor)); + for (int j = 0; j < n_fuse - 1; ++j) { + fused_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + } + fused_node.data = cgraph->nodes[i + n_fuse - 1]->data; + if (node->op == GGML_OP_ADD) { + ggml_cuda_op_fused_add(*cuda_ctx, &fused_node, n_fuse); + } else { + ggml_cuda_op_fused_mul(*cuda_ctx, &fused_node, n_fuse); + } + return n_fuse - 1; + } + } + + bool fused_mul_mat_vec = false; + int fused_node_count = 0; + + // gate + glu + up + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 4]; + ggml_tensor * gate_bias_n = glu->src[0]; + ggml_tensor * up_bias_n = glu->src[1]; + + //we don't assume the order for {gate, up}. Instead infer it from the bias tensor + ggml_tensor * gate_n = nullptr; + ggml_tensor * up_n = nullptr; + + if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { + gate_n = cgraph->nodes[i]; + up_n = cgraph->nodes[i + 2]; + } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { + gate_n = cgraph->nodes[i + 2]; + up_n = cgraph->nodes[i]; + } else { + continue; + } + + auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { + if (op_bias == GGML_OP_ADD) { + if (bias_node->src[0] == mul_node) { + return bias_node->src[1]; + } + if (bias_node->src[1] == mul_node) { + return bias_node->src[0]; + } + return (ggml_tensor *) nullptr; + } + GGML_ASSERT(op_bias == GGML_OP_ADD_ID); + GGML_ASSERT(bias_node->src[0] == mul_node); + return bias_node->src[1]; + }; + + ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); + ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); + + if (!up_bias_tensor || !gate_bias_tensor) { + continue; + } + + // we don't support repeating adds + if (bias_op == GGML_OP_ADD && (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || + !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { + continue; + } + + const ggml_tensor * src0 = up_n->src[0]; + const ggml_tensor * src1 = up_n->src[1]; + const ggml_tensor * ids = up_n->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate_n->src[0]; + fusion_data.x_bias = up_bias_tensor; + fusion_data.gate_bias = gate_bias_tensor; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 5; + break; + } + } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { + ggml_tensor * glu = cgraph->nodes[i + 2]; + ggml_tensor * gate = glu->src[0]; + ggml_tensor * up = glu->src[1]; + + bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) || + (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); + + if (!ok) { + continue; + } + + const ggml_tensor * src0 = up->src[0]; + const ggml_tensor * src1 = up->src[1]; + const ggml_tensor * ids = up->src[2]; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.gate = gate->src[0]; + fusion_data.glu_op = ggml_get_glu_op(glu); + + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 3; + break; + } + } + } + + if (fused_mul_mat_vec) { + return fused_node_count - 1; + } + + fused_mul_mat_vec = false; + fused_node_count = 0; + + // gate + add + glu + up + add + for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { + const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; + + if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { + continue; + } + + ggml_tensor * mm_node = cgraph->nodes[i]; + ggml_tensor * bias_node = cgraph->nodes[i + 1]; + + ggml_tensor * bias_tensor = nullptr; + if (bias_op == GGML_OP_ADD) { + if (bias_node->src[0] == mm_node) { + bias_tensor = bias_node->src[1]; + } else if (bias_node->src[1] == mm_node) { + bias_tensor = bias_node->src[0]; + } else { + continue; + } + } else { + if (bias_node->src[0] != mm_node) { + continue; + } + bias_tensor = bias_node->src[1]; + } + + const ggml_tensor * src0 = mm_node->src[0]; + const ggml_tensor * src1 = mm_node->src[1]; + const ggml_tensor * ids = mm_node->src[2]; + + if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { + continue; + } + + if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { + continue; + } + + ggml_cuda_mm_fusion_args_host fusion_data{}; + fusion_data.x_bias = bias_tensor; + + if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { + ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + + if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { + ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); + fused_mul_mat_vec = true; + fused_node_count = 2; + break; + } + } + + if (fused_mul_mat_vec) { + return fused_node_count - 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) { + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); + return 2; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); + return 2; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) { + ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_SQR }, { GGML_UNARY_OP_RELU })) { + ggml_cuda_op_relu_sqr(*cuda_ctx, node, cgraph->nodes[i + 1]); + return 1; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { + ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i + 2], node); + return 2; + } + + return 0; +} + +static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { bool graph_evaluated_or_captured = false; // flag used to determine whether it is an integrated_gpu @@ -3378,288 +4417,15 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } - // start of fusion operations - static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); - if (!disable_fusion) { - - if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { - ggml_tensor * weights = cgraph->nodes[i + 9]; - ggml_tensor * selected_experts = cgraph->nodes[i + 3]; - ggml_tensor * clamp = cgraph->nodes[i + 7]; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true, - /*delayed softmax*/ false, clamp); - i += 9; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { - ggml_tensor * weights = cgraph->nodes[i + 4]; - ggml_tensor * selected_experts = cgraph->nodes[i + 3]; - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false, - /*delayed softmax*/ false); - i += 4; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, - ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) { - ggml_tensor * weights = cgraph->nodes[i + 5]; - ggml_tensor * ids = cgraph->nodes[i + 1]; - - ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false, - /*delayed_softmax*/ true); - i += 5; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { - ggml_tensor * rope = cgraph->nodes[i]; - ggml_tensor * set_rows = cgraph->nodes[i + 2]; - - ggml_cuda_op_rope_fused(*cuda_ctx, rope, set_rows); - i += 2; - continue; - } - - if (node->op == GGML_OP_ADD) { - int n_fuse = 0; - ggml_op ops[8]; - std::fill(ops, ops + 8, GGML_OP_ADD); - - for (; n_fuse <= 6; ++n_fuse){ - if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) { - break; - } - if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) { - break; - } - if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) { - break; - } - } - - n_fuse++; - - if (n_fuse > 1) { - for (int j = 0; j < n_fuse - 1; ++j) { - node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; - } - cgraph->nodes[i + n_fuse - 1]->data = node->data; - ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); - i += n_fuse - 1; - - continue; - } - } - - bool fused_mul_mat_vec = false; - int fused_node_count = 0; - - for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { - const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; - - if (ggml_cuda_can_fuse(cgraph, i, { op, bias_op, op, bias_op, GGML_OP_GLU }, {})) { - ggml_tensor * glu = cgraph->nodes[i + 4]; - ggml_tensor * gate_bias_n = glu->src[0]; - ggml_tensor * up_bias_n = glu->src[1]; - - //we don't assume the order for {gate, up}. Instead infer it from the bias tensor - ggml_tensor * gate_n = nullptr; - ggml_tensor * up_n = nullptr; - - if (gate_bias_n->src[0] == cgraph->nodes[i] || gate_bias_n->src[1] == cgraph->nodes[i]) { - gate_n = cgraph->nodes[i]; - up_n = cgraph->nodes[i + 2]; - } else if (gate_bias_n->src[0] == cgraph->nodes[i + 2] || gate_bias_n->src[1] == cgraph->nodes[i + 2]) { - gate_n = cgraph->nodes[i + 2]; - up_n = cgraph->nodes[i]; - } else { - continue; - } - - auto get_bias_tensor = [](const ggml_tensor * bias_node, const ggml_tensor * mul_node, ggml_op op_bias) { - if (op_bias == GGML_OP_ADD) { - if (bias_node->src[0] == mul_node) { - return bias_node->src[1]; - } - if (bias_node->src[1] == mul_node) { - return bias_node->src[0]; - } - return (ggml_tensor *) nullptr; - } - GGML_ASSERT(op_bias == GGML_OP_ADD_ID); - GGML_ASSERT(bias_node->src[0] == mul_node); - return bias_node->src[1]; - }; - - ggml_tensor * up_bias_tensor = get_bias_tensor(up_bias_n, up_n, bias_op); - ggml_tensor * gate_bias_tensor = get_bias_tensor(gate_bias_n, gate_n, bias_op); - - if (!up_bias_tensor || !gate_bias_tensor) { - continue; - } - - // we don't support repeating adds - if (bias_op == GGML_OP_ADD && - (!ggml_are_same_shape(gate_bias_n->src[0], gate_bias_n->src[1]) || - !ggml_are_same_shape(up_bias_n->src[0], up_bias_n->src[1]))) { - continue; - } - - const ggml_tensor * src0 = up_n->src[0]; - const ggml_tensor * src1 = up_n->src[1]; - const ggml_tensor * ids = up_n->src[2]; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(up_n)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate_n->src[0]; - fusion_data.x_bias = up_bias_tensor; - fusion_data.gate_bias = gate_bias_tensor; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 5; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(up_n)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate_n->src[0]; - fusion_data.x_bias = up_bias_tensor; - fusion_data.gate_bias = gate_bias_tensor; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 5; - break; - } - } else if (ggml_cuda_can_fuse(cgraph, i, { op, op, GGML_OP_GLU }, {})) { - ggml_tensor * glu = cgraph->nodes[i + 2]; - ggml_tensor * gate = glu->src[0]; - ggml_tensor * up = glu->src[1]; - - bool ok = (gate == cgraph->nodes[i] && up == cgraph->nodes[i + 1]) - || (gate == cgraph->nodes[i + 1] && up == cgraph->nodes[i]); - - if (!ok) continue; - - const ggml_tensor * src0 = up->src[0]; - const ggml_tensor * src1 = up->src[1]; - const ggml_tensor * ids = up->src[2]; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(up)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate->src[0]; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 3; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(up)) { - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.gate = gate->src[0]; - fusion_data.glu_op = ggml_get_glu_op(glu); - - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, glu, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 3; - break; - } - } - } - - if (fused_mul_mat_vec) { - i += fused_node_count - 1; - continue; - } - - fused_mul_mat_vec = false; - fused_node_count = 0; - - for (ggml_op op : { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT_ID }) { - const ggml_op bias_op = op == GGML_OP_MUL_MAT ? GGML_OP_ADD : GGML_OP_ADD_ID; - - if (!ggml_can_fuse(cgraph, i, { op, bias_op })) { - continue; - } - - ggml_tensor * mm_node = cgraph->nodes[i]; - ggml_tensor * bias_node = cgraph->nodes[i + 1]; - - ggml_tensor * bias_tensor = nullptr; - if (bias_op == GGML_OP_ADD) { - if (bias_node->src[0] == mm_node) { - bias_tensor = bias_node->src[1]; - } else if (bias_node->src[1] == mm_node) { - bias_tensor = bias_node->src[0]; - } else { - continue; - } - } else { - if (bias_node->src[0] != mm_node) { - continue; - } - bias_tensor = bias_node->src[1]; - } - - const ggml_tensor * src0 = mm_node->src[0]; - const ggml_tensor * src1 = mm_node->src[1]; - const ggml_tensor * ids = mm_node->src[2]; - - if (bias_op == GGML_OP_ADD_ID && bias_node->src[2] != ids) { - continue; - } - - if (bias_op == GGML_OP_ADD && !ggml_are_same_shape(bias_node->src[0], bias_node->src[1])) { - continue; - } - - ggml_cuda_mm_fusion_args_host fusion_data{}; - fusion_data.x_bias = bias_tensor; - - if (ggml_cuda_should_fuse_mul_mat_vec_f(mm_node)) { - ggml_cuda_mul_mat_vec_f(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 2; - break; - } - - if (ggml_cuda_should_fuse_mul_mat_vec_q(mm_node)) { - ggml_cuda_mul_mat_vec_q(*cuda_ctx, src0, src1, ids, bias_node, &fusion_data); - fused_mul_mat_vec = true; - fused_node_count = 2; - break; - } - } - - if (fused_mul_mat_vec) { - i += fused_node_count - 1; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { - ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); - i += 2; - continue; - } - - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { - ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } + int nodes_to_skip = ggml_cuda_try_fuse(cuda_ctx, cgraph, i); - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { - i += 2; - ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); - continue; - } + if (nodes_to_skip != 0) { + i += nodes_to_skip; + continue; } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); @@ -3687,13 +4453,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud } #ifdef USE_CUDA_GRAPH + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture - if (cuda_ctx->cuda_graph->graph != nullptr) { - CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph)); - cuda_ctx->cuda_graph->graph = nullptr; + if (graph->graph != nullptr) { + CUDA_CHECK(cudaGraphDestroy(graph->graph)); + graph->graph = nullptr; } - CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph)); + CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &graph->graph)); graph_evaluated_or_captured = true; // CUDA graph has been captured std::lock_guard<std::mutex> lock(ggml_cuda_lock); @@ -3706,41 +4473,38 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud } if (use_cuda_graph) { - if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph. - CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (graph->instance == nullptr) { // Create executable graph from captured graph. + CUDA_CHECK(cudaGraphInstantiate(&graph->instance, graph->graph, NULL, NULL, 0)); } if (cuda_graph_update_required) { // Update graph executable - ggml_cuda_graph_update_executable(cuda_ctx); + ggml_cuda_graph_update_executable(cuda_ctx, graph_key); } // Launch graph - CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); + CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream())); #else + GGML_UNUSED(graph_key); graph_evaluated_or_captured = true; #endif // USE_CUDA_GRAPH } } -static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) { - #ifdef USE_CUDA_GRAPH +static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) { + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); - if (cuda_ctx->cuda_graph == nullptr) { - cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); - } - - if (cuda_ctx->cuda_graph->graph == nullptr) { + if (graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { - cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); + if (!graph->disable_due_to_gpu_arch) { + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); + } + graph->disable_due_to_gpu_arch = true; } } - return cuda_ctx->cuda_graph->is_enabled(); -#else - GGML_UNUSED(cuda_ctx); - return false; -#endif // USE_CUDA_GRAPH + return graph->is_enabled(); } +#endif // USE_CUDA_GRAPH static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; @@ -3749,15 +4513,40 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, bool use_cuda_graph = false; bool cuda_graph_update_required = false; + const void * graph_key = nullptr; #ifdef USE_CUDA_GRAPH - use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); - - if (cuda_ctx->cuda_graph->is_enabled()) { - cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); - use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); - - cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required); + graph_key = ggml_cuda_graph_get_key(cgraph); + + ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); + + ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); + if (graph->is_enabled()) { + const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph); + if (graph_compatible) { + const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph); + + if (!graph->warmup_complete) { + // Warmup: need at least 2 calls with no property change on the 2nd call + if (!properties_changed) { + graph->warmup_complete = true; + GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__); + use_cuda_graph = true; + cuda_graph_update_required = true; + } + // else: properties changed or first call - execute directly (use_cuda_graph stays false) + } else { + // Post-warmup: normal CUDA graph operation + if (properties_changed) { + // Properties changed - reset warmup, execute directly until stable again + graph->warmup_complete = false; + GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__); + } else { + use_cuda_graph = true; + cuda_graph_update_required = graph->instance == nullptr; + } + } + } } #endif // USE_CUDA_GRAPH @@ -3771,7 +4560,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } - ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required); + ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required, graph_key); return GGML_STATUS_SUCCESS; } @@ -3804,7 +4593,14 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; - const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); +#ifdef USE_CUDA_GRAPH + const void * graph_key = ggml_cuda_graph_get_key(cgraph); + const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); +#else + const bool use_cuda_graph = false; + GGML_UNUSED(cuda_ctx); + GGML_UNUSED(cgraph); +#endif static bool enable_graph_optimization = [] { const char * env = getenv("GGML_CUDA_GRAPH_OPT"); @@ -4043,6 +4839,8 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .free = */ ggml_backend_cuda_free, /* .set_tensor_async = */ ggml_backend_cuda_set_tensor_async, /* .get_tensor_async = */ ggml_backend_cuda_get_tensor_async, + /* .set_tensor_2d_async = */ ggml_backend_cuda_set_tensor_2d_async, + /* .get_tensor_2d_async = */ ggml_backend_cuda_get_tensor_2d_async, /* .cpy_tensor_async = */ ggml_backend_cuda_cpy_tensor_async, /* .synchronize = */ ggml_backend_cuda_synchronize, /* .graph_plan_create = */ NULL, @@ -4118,14 +4916,6 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) { // backend device -struct ggml_backend_cuda_device_context { - int device; - std::string name; - std::string description; - std::string pci_bus_id; - int op_offload_min_batch_size; -}; - static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; return ctx->name.c_str(); @@ -4214,6 +5004,11 @@ static bool ggml_backend_cuda_get_available_uma_memory(long * available_memory_k static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + std::lock_guard<std::mutex> lock(ctx->device_mutex); +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_cuda_set_device(ctx->device); CUDA_CHECK(cudaMemGetInfo(free, total)); @@ -4240,11 +5035,24 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * } #endif // defined(__linux__) +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + // If no backends or buffers are active, the cudaMemGetInfo call above lazily created a CUDA + // context that permanently consumes VRAM. Reset the device to free it. + if (ctx->active_count == 0) { + CUDA_CHECK(cudaDeviceReset()); + } +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) } static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend_dev_t dev) { - GGML_UNUSED(dev); - return GGML_BACKEND_DEVICE_TYPE_GPU; + ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *) dev->context; + + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, ctx->device)); + + return prop.integrated + ? GGML_BACKEND_DEVICE_TYPE_IGPU + : GGML_BACKEND_DEVICE_TYPE_GPU; } static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { @@ -4335,6 +5143,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: + // TODO: should become: + //return ggml_is_contiguous_rows(op->src[0]); return ggml_is_contiguous(op->src[0]); default: return false; @@ -4391,12 +5201,14 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g switch (a->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -4427,6 +5239,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_F32: case GGML_TYPE_BF16: case GGML_TYPE_I32: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -4532,7 +5345,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; - return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + ggml_type src1_type = op->src[1]->type; + return src0_type == src1_type && + src0_type == op->type && + !ggml_is_quantized(src0_type) && + ggml_blck_size(src0_type) == 1 && + (ggml_type_size(src0_type) == 1 || + ggml_type_size(src0_type) == 2 || + ggml_type_size(src0_type) == 4 || + ggml_type_size(src0_type) == 8); } break; case GGML_OP_CONV_TRANSPOSE_1D: { @@ -4551,19 +5372,15 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_L2_NORM: return true; case GGML_OP_RMS_NORM_BACK: - return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0; + return ggml_is_contiguous(op->src[0]); break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: - case GGML_OP_ADD: case GGML_OP_ADD_ID: case GGML_OP_ADD1: - case GGML_OP_SUB: - case GGML_OP_MUL: - case GGML_OP_DIV: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_SQRT: @@ -4572,6 +5389,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CLAMP: case GGML_OP_LOG: return true; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); case GGML_OP_SSM_SCAN: { if (op->src[3]->ne[0] == 1) { // Mamba2 @@ -4613,8 +5437,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: - case GGML_OP_ACC: return true; + case GGML_OP_ACC: + // TODO: extend support like so: + //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]); + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); case GGML_OP_SUM: return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_TOP_K: @@ -4627,8 +5454,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: - case GGML_OP_PAD: return ggml_is_contiguous(op->src[0]); + case GGML_OP_PAD: + return true; case GGML_OP_UPSCALE: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_ARANGE: @@ -4638,6 +5466,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GATED_LINEAR_ATTN: case GGML_OP_RWKV_WKV7: return true; + case GGML_OP_GATED_DELTA_NET: + //TODO: enable once MUSA compiler is solved https://github.com/ggml-org/llama.cpp/pull/19504#issuecomment-4018634327 +#ifdef GGML_USE_MUSA + return false; +#else + return true; +#endif // GGML_USE_MUSA case GGML_OP_FLASH_ATTN_EXT: return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op); case GGML_OP_CROSS_ENTROPY_LOSS: @@ -4816,6 +5651,15 @@ static ggml_backend_feature * ggml_backend_cuda_get_features(ggml_backend_reg_t static void * ggml_backend_cuda_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { GGML_UNUSED(reg); + if (strcmp(name, "ggml_backend_comm_init") == 0) { + return (void *)ggml_backend_cuda_comm_init; + } + if (strcmp(name, "ggml_backend_comm_free") == 0) { + return (void *)ggml_backend_cuda_comm_free; + } + if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) { + return (void *)ggml_backend_cuda_comm_allreduce_tensor; + } if (strcmp(name, "ggml_backend_split_buffer_type") == 0) { return (void *)ggml_backend_cuda_split_buffer_type; } @@ -4859,9 +5703,12 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; - char pci_bus_id[16] = {}; - snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); + char pci_bus_id[32] = {}; + CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i)); dev_ctx->pci_bus_id = pci_bus_id; + for (char & c : dev_ctx->pci_bus_id) { + c = std::tolower(c); + } dev_ctx->op_offload_min_batch_size = min_batch_size; ggml_backend_dev_t dev = new ggml_backend_device { @@ -4897,13 +5744,21 @@ ggml_backend_t ggml_backend_cuda_init(int device) { return nullptr; } + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device); + ggml_backend_t cuda_backend = new ggml_backend { /* .guid = */ ggml_backend_cuda_guid(), /* .iface = */ ggml_backend_cuda_interface, - /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cuda_reg(), device), + /* .device = */ dev, /* .context = */ ctx, }; +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; + std::lock_guard<std::mutex> lock(dev_ctx->device_mutex); + dev_ctx->active_count++; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + return cuda_backend; } diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 56dc0545742..28c79ab462e 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -1,5 +1,6 @@ #include "im2col.cuh" +#define MAX_GRIDDIM_Y 65535 #define MAX_GRIDDIM_Z 65535 template <typename T> @@ -18,22 +19,23 @@ static __global__ void im2col_kernel( const int64_t ikh = rem / KW; const int64_t ikw = rem - ikh * KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OH; - const int64_t ioh = iz - in * OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t offset_dst = - ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; - dst[offset_dst] = x[offset_src + iih * IW + iiw]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } } } @@ -51,7 +53,7 @@ static void im2col_cuda(const float * x, T* dst, const int64_t num_blocks = (IC_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; const int64_t N_OH = N * OH; const int64_t KH_KW = KW*KH; - dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OH, MAX_GRIDDIM_Z)); im2col_kernel<<<block_nums, MIN(IC_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(x, dst, IC, IW, IH, OH, OW, KW, KH, IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, s0, s1, p0, p1, d0, d1); @@ -136,23 +138,24 @@ static __global__ void im2col_3d_kernel( const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; const int64_t ikw = i % KW; - const int64_t iow = blockIdx.y; - for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) { - const int64_t in = iz / OD_OH; - const int64_t iod = (iz - in*OD_OH) / OH; - const int64_t ioh = iz % OH; + for (int64_t iow = blockIdx.y; iow < OW; iow += MAX_GRIDDIM_Y) { + for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; - const int64_t iiw = iow * s0 + ikw * d0 - p0; - const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t iid = iod * s2 + ikd * d2 - p2; + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; - const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; - if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { - dst[offset_dst] = 0.0f; - } else { - const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); - dst[offset_dst] = src[offset_src]; + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } } } } @@ -178,7 +181,7 @@ static void im2col_3d_cuda(const float * src, T* dst, const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE; - dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + dim3 block_nums(num_blocks, MIN(OW, MAX_GRIDDIM_Y), MIN(N_OD_OH, MAX_GRIDDIM_Z)); im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 60542fc19dd..a8f6046e46d 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -31,14 +31,15 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { #endif // USE_CUDA_GRAPH if ((nrows == 1) && #ifdef USE_CUDA_GRAPH - // CUDA_GRAPHS_DISABLED - ((ncols > 65536) && - ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->is_enabled())) || - // CUDA_GRAPHS ENABLED - ((ncols > 32768) && - !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->is_enabled()))) { + // Determine if CUDA graphs are effectively disabled for this context + // (no graph instance exists and we're not capturing, OR graphs are explicitly enabled) + (((ncols > 65536) && + (((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) || + ctx.any_cuda_graph_enabled())) || + // CUDA graphs are enabled - use lower threshold + ((ncols > 32768) && + !(((!ctx.any_cuda_graph_has_instance()) && (iscapturing == cudaStreamCaptureStatusNone)) || + ctx.any_cuda_graph_enabled())))) { #else (ncols > 65536)) { #endif // USE_CUDA_GRAPH @@ -66,9 +67,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // See discussion in: https://github.com/ggml-org/llama.cpp/pull/15132 if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); - reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/true>, launch_params, src0_d, dst_d, ncols); } else { const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32</*norm=*/true><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/true>, launch_params, src0_d, dst_d, ncols); } } diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index df9eed71172..8d7c69dc3e8 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -80,23 +80,19 @@ namespace ggml_cuda_mma { DATA_LAYOUT_J_MAJOR = 10, // Matrix C for CDNA and RDNA4, int and float matrix C for RDNA3. DATA_LAYOUT_I_MAJOR_MIRRORED = 20, // Volta, matrix A&B for RDNA3. DATA_LAYOUT_J_MAJOR_MIRRORED = 30, + DATA_LAYOUT_I_MAJOR_SCRAMBLED = 40, // Scrambled matrix C for faster transposition (RDNA4/CDNA), convert to float to unscramble. }; // Implemented mma combinations are: // - (I_MAJOR, I_MAJOR) -> I_MAJOR // - (I_MAJOR, I_MAJOR_MIRRORED) -> I_MAJOR // - (I_MAJOR, J_MAJOR_MIRRORED) -> I_MAJOR - static constexpr bool is_i_major(const data_layout dl) { - return dl == DATA_LAYOUT_I_MAJOR || - dl == DATA_LAYOUT_I_MAJOR_MIRRORED; - } - static constexpr __device__ data_layout get_input_data_layout() { -#if defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) return DATA_LAYOUT_I_MAJOR_MIRRORED; #else return DATA_LAYOUT_I_MAJOR; -#endif // defined(RDNA3) || __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(RDNA3) || defined(VOLTA_MMA_AVAILABLE) } template <int I_, int J_, typename T, data_layout ds_=DATA_LAYOUT_I_MAJOR> @@ -113,7 +109,6 @@ namespace ggml_cuda_mma { T x[ne] = {0}; static constexpr __device__ bool supported() { - if (I == 64 && J == 2) return true; if (I == 16 && J == 8) return true; if (I == 32 && J == 4) return true; if (I == 16 && J == 16) return true; @@ -122,7 +117,7 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_i(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + if constexpr (I == 16 && J == 4) { return threadIdx.x % 16; } else if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; @@ -139,8 +134,8 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> - return (2 * ((threadIdx.x / 16) % 2) + l); + if constexpr (I == 16 && J == 4) { + return threadIdx.x / 16; } else if constexpr (I == 16 && J == 8) { return 2 * (threadIdx.x / 16) + l; } else if constexpr (I == 32 && J == 4) { @@ -154,7 +149,7 @@ namespace ggml_cuda_mma { return -1; } } -#elif __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#elif defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / 32; T x[ne] = {0}; @@ -206,10 +201,16 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 16) { - // matrix C #if defined(RDNA3) - return 2 * l + (threadIdx.x / 16); + if constexpr (std::is_same_v<T, float> || std::is_same_v<T, int>) { + // matrix C + return 2 * l + (threadIdx.x / 16); + } else { + // matrix A&B + return l; + } #else + // matrix C is the transposed matrix A&B on RDNA4 return ne * (threadIdx.x / 16) + l; #endif // defined(RDNA3) } else if constexpr (I == 16 && J == 8) { @@ -277,7 +278,7 @@ namespace ggml_cuda_mma { static constexpr int J = J_; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) static constexpr int ne = I * J / WARP_SIZE; half2 x[ne] = {{0.0f, 0.0f}}; @@ -312,13 +313,19 @@ namespace ggml_cuda_mma { half2 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { - if (I == 16 && J == 8) return true; + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; return false; } static __device__ __forceinline__ int get_i(const int l) { if constexpr (I == 16 && J == 8) { return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -327,7 +334,51 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ int get_j(const int l) { if constexpr (I == 16 && J == 8) { - return 4 * (threadIdx.x / 16) + l; + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 16 && J == 16) { +#ifdef RDNA3 + return l*2 + (threadIdx.x / 16); +#else + return (threadIdx.x / 16) * ne + l; +#endif // RDNA3 + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x / 16) * (ne/2) + l % (ne/2); + } else { + NO_DEVICE_CODE; + return -1; + } + } +#elif defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = I * J / 64; + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 8) return true; + if (I == 16 && J == 16) return true; + if (I == 32 && J == 8) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x % 16) * 2 + l / (ne/2); + } else { + NO_DEVICE_CODE; + return -1; + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 16 && J == 8) { + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 16 && J == 16) { + return (threadIdx.x / 16) * ne + l; + } else if constexpr (I == 32 && J == 8) { + return (threadIdx.x / 16) * (ne/2) + l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -375,7 +426,7 @@ namespace ggml_cuda_mma { return -1; } } -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) }; template <int I_, int J_> @@ -385,7 +436,22 @@ namespace ggml_cuda_mma { static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; #if defined(AMD_WMMA_AVAILABLE) - static constexpr int ne = I * J / 32; + static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne; + nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported(); + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l); + } + + static __device__ __forceinline__ int get_j(const int l) { + return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l); + } +#elif defined(AMD_MFMA_AVAILABLE) + static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne; nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; static constexpr __device__ bool supported() { @@ -475,12 +541,15 @@ namespace ggml_cuda_mma { if (I == 16 && J == 16) return true; if (I == 16 && J == 8) return true; if (I == 16 && J == 4) return true; + if (I == 32 && J == 8) return true; return false; } - static __device__ __forceinline__ int get_i(const int /*l*/) { - if constexpr (supported()) { + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 16) { return threadIdx.x % 16; + } else if constexpr (I == 32) { + return (threadIdx.x % 16) * 2 + l / (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -488,8 +557,10 @@ namespace ggml_cuda_mma { } static __device__ __forceinline__ int get_j(const int l) { - if constexpr (supported()) { + if constexpr (I == 16) { return l; + } else if constexpr (I == 32) { + return l % (ne/2); } else { NO_DEVICE_CODE; return -1; @@ -603,6 +674,40 @@ namespace ggml_cuda_mma { } }; + template <int I_, int J_> + struct tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> { + static constexpr int I = I_; + static constexpr int J = J_; + static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR_SCRAMBLED; + + static constexpr int ne = I * J / ggml_cuda_get_physical_warp_size(); + half2 x[ne] = {{0.0f, 0.0f}}; + + static constexpr __device__ bool supported() { + if (I == 16 && J == 16) return true; + return false; + } + + static __device__ __forceinline__ int get_i(const int l) { + return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l); + } + }; + + static __device__ __forceinline__ tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> unscramble(const tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & t) { +#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> ret; +#pragma unroll + for (int l0 = 0; l0 < t.ne/2; ++l0) { + ret.x[2*l0 + 0] = __lows2half2(t.x[l0], t.x[l0 + t.ne/2]); + ret.x[2*l0 + 1] = __highs2half2(t.x[l0], t.x[l0 + t.ne/2]); + } + return ret; +#else + NO_DEVICE_CODE; + GGML_UNUSED(t); +#endif // defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + } + #if defined(TURING_MMA_AVAILABLE) template <int I, int J> static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) { @@ -621,6 +726,36 @@ namespace ggml_cuda_mma { return ret; } +#elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + static __device__ __forceinline__ tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> get_half2( + const tile<16, 16, float, DATA_LAYOUT_I_MAJOR> & tile_float) { + tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> ret; +#pragma unroll + for (int l = 0; l < tile_float.ne; ++l) { + float tmp[2]; + int i = threadIdx.x / 16; + tmp[i] = tile_float.x[l]; + i ^= 1; + tmp[i] = __shfl_xor_sync(0xFFFFFFFF, tile_float.x[l], 16, WARP_SIZE); + ret.x[l] = make_half2(tmp[0], tmp[1]); + } + return ret; + } +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + template <int I, int J> + static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) { + tile<I, J/2, half2> ret; +#pragma unroll + for (int l0 = 0; l0 < tile_float.ne; l0 += 2) { + ret.x[l0/2] = make_half2(tile_float.x[l0 + 0], tile_float.x[l0 + 1]); + } + return ret; + } + + static __device__ __forceinline__ tile<8, 8, half2> get_transposed(const tile<16, 4, half2> & t) { + NO_DEVICE_CODE; + return tile<8, 8, half2>{}; + } #else // Volta template <int I, int J> static __device__ __forceinline__ tile<I, J/2, half2> get_half2(const tile<I, J, float> & tile_float) { @@ -641,42 +776,10 @@ namespace ggml_cuda_mma { template <int I, int J, typename T, data_layout dl> static __device__ __forceinline__ void load_generic(tile<I, J, T, dl> & t, const T * __restrict__ xs0, const int stride) { -#if defined(AMD_MFMA_AVAILABLE) - if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } else { - ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } -#elif defined(AMD_WMMA_AVAILABLE) - // All wmma layout has contiguous data when i-major. - if constexpr (is_i_major(dl)) { - // the data must be aligned to 16 bytes when bigger than ggml_cuda_get_max_cpy_bytes() - constexpr int aligned_copy_bytes = ggml_cuda_get_max_cpy_bytes(); - if constexpr (sizeof(t.x) > aligned_copy_bytes) { - static_assert(sizeof(t.x) % aligned_copy_bytes == 0, "bad type size"); - constexpr int aligned_copy_count = sizeof(t.x)/aligned_copy_bytes; -#pragma unroll - for (int i = 0; i < aligned_copy_count; ++i) { - ggml_cuda_memcpy_1<aligned_copy_bytes>(t.x + t.ne/aligned_copy_count*i, xs0 + t.get_i(0) * stride + t.get_j(t.ne/aligned_copy_count*i)); - } - } else { - ggml_cuda_memcpy_1<sizeof(t.x)>(t.x, xs0 + t.get_i(0) * stride + t.get_j(0)); - } - } else { -#pragma unroll - for (int l = 0; l < t.ne; ++l) { - t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; - } - } -#else #pragma unroll for (int l = 0; l < t.ne; ++l) { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } -#endif // defined(AMD_MFMA_AVAILABLE) } template <typename T> @@ -689,26 +792,37 @@ namespace ggml_cuda_mma { : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); #else - load_generic(t, xs0, stride); + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } - template <typename T> + template <typename T, data_layout dl> static __device__ __forceinline__ void load_ldmatrix( - tile<16, 4, T> & t, const T * __restrict__ xs0, const int stride) { + tile<16, 4, T, dl> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride; asm volatile("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" : "=r"(xi[0]), "=r"(xi[1]) : "l"(xs)); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<8>(t.x + 2, xs0 + t.get_i(0)*stride + 2); +#else + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 4, "bad ne"); + ggml_cuda_memcpy_1<4>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA #endif // TURING_MMA_AVAILABLE } @@ -721,19 +835,26 @@ namespace ggml_cuda_mma { asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[1]), "=r"(xi[2]), "=r"(xi[3]) : "l"(xs)); -#else -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA -#if 1 - // TODO: more generic handling - static_assert(sizeof(T) == 4, "bad type size"); +#elif defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 0, xs0 + t.get_i(0)*stride + 0); ggml_cuda_memcpy_1<4*sizeof(T)>(t.x + 4, xs0 + t.get_i(4)*stride + 4); +#elif defined(AMD_WMMA_AVAILABLE) +#ifdef RDNA3 + static_assert(dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + static_assert(sizeof(t.x) == 32, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x + 0, xs0 + t.get_i(0)*stride + 0); + ggml_cuda_memcpy_1<16>(t.x + 4, xs0 + t.get_i(0)*stride + 4); #else - load_generic(t, xs0, stride); -#endif // 1 + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + static_assert(sizeof(t.x) == 16, "bad ne"); + ggml_cuda_memcpy_1<16>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); +#endif // RDNA3 +#elif defined(AMD_MFMA_AVAILABLE) + static_assert(sizeof(t.x) == 8, "bad ne"); + ggml_cuda_memcpy_1<8>(t.x, xs0 + t.get_i(0)*stride + t.get_j(0)); #else - load_generic(t, xs0, stride); -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA + GGML_UNUSED_VARS(t, xs0, stride); + NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } @@ -752,23 +873,44 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void load_ldmatrix( tile<32, 4, half2> & t, const half2 * __restrict__ xs0, const int stride) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) ggml_cuda_memcpy_1<4*sizeof(half2)>(t.x, xs0 + t.get_i(0)*stride); #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } - template <typename T> + template <int I, typename T, data_layout dl> static __device__ __forceinline__ void load_ldmatrix_trans( - tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { + tile<I, 8, T, dl> & t, const T * __restrict__ xs0, const int stride) { #ifdef TURING_MMA_AVAILABLE - int * xi = (int * ) t.x; + static_assert(I == 16, "bad tile width"); + static_assert(dl == DATA_LAYOUT_I_MAJOR, "bad data layout"); + int * xi = (int *) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" : "=r"(xi[0]), "=r"(xi[2]), "=r"(xi[1]), "=r"(xi[3]) : "l"(xs)); +#elif defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + static_assert(dl == DATA_LAYOUT_I_MAJOR || dl == DATA_LAYOUT_I_MAJOR_MIRRORED, "bad data layout"); + if constexpr (I == 32) { +#pragma unroll + for (int l0 = 0; l0 < t.ne/2; ++l0) { + const half2 tmp0 = xs0[(2*t.get_j(l0) + 0)*stride + t.get_i(l0)/2]; + const half2 tmp1 = xs0[(2*t.get_j(l0) + 1)*stride + t.get_i(l0)/2]; + + t.x[l0] = __lows2half2(tmp0, tmp1); + t.x[l0 + t.ne/2] = __highs2half2(tmp0, tmp1); + } + } else { + half * xh = (half *) t.x; +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + xh[2*l + 0] = ((const half *) xs0)[(2*t.get_j(l) + 0)*(2*stride) + t.get_i(l)]; + xh[2*l + 1] = ((const half *) xs0)[(2*t.get_j(l) + 1)*(2*stride) + t.get_i(l)]; + } + } #else GGML_UNUSED_VARS(t, xs0, stride); NO_DEVICE_CODE; @@ -878,12 +1020,65 @@ namespace ggml_cuda_mma { : "+r"(Dxi[2]), "+r"(Dxi[3]) : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3])); #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE +#elif defined(AMD_WMMA_AVAILABLE) +#if defined(RDNA4) + using halfx8_t = __attribute__((ext_vector_type(8))) _Float16; + halfx8_t& acc_frag = reinterpret_cast<halfx8_t&>(D.x[0]); + const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]); + const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(RDNA4) +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA: FP16 input, FP32 accumulate, convert back to half2. + using halfx4_t = __attribute__((ext_vector_type(4))) _Float16; + using floatx4_t = __attribute__((ext_vector_type(4))) float; + + // Convert existing half2 accumulator to float for MFMA: + floatx4_t acc_f32; + { + const halfx4_t acc_h = reinterpret_cast<const halfx4_t&>(D.x[0]); +#pragma unroll + for (int i = 0; i < 4; ++i) { + acc_f32[i] = (float)acc_h[i]; + } + } + + const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]); + const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]); + acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0); + + // Convert back to half2: + { + halfx4_t result_h; +#pragma unroll + for (int i = 0; i < 4; ++i) { + result_h[i] = (_Float16)acc_f32[i]; + } + reinterpret_cast<halfx4_t&>(D.x[0]) = result_h; + } #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // TURING_MMA_AVAILABLE } + static __device__ __forceinline__ void mma( + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR> & A, + const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR> & B) { +#if defined(AMD_MFMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) + tile<16, 8, half2> * D16 = (tile<16, 8, half2> *) &D; + const tile<16, 8, half2> * A16 = (const tile<16, 8, half2> *) &A; + mma(D16[0], A16[0], B); + mma(D16[1], A16[1], B); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) && defined(RDNA4) + } + template <data_layout dl_ab, data_layout dl_d> static __device__ __forceinline__ void mma( tile<16, 8, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<8, 8, float, dl_ab> & B) { @@ -900,25 +1095,62 @@ namespace ggml_cuda_mma { #endif // AMPERE_MMA_AVAILABLE } - static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D, - const tile<16, 8, int> & A, - const tile<8, 8, int> & B, - uint32_t a_scale, - uint32_t b_scale) { + template <data_layout dl_ab, data_layout dl_d> + static __device__ __forceinline__ void mma( + tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) { +#ifdef AMD_MFMA_AVAILABLE + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]); +#if defined(CDNA3) + using floatx2_t = __attribute__((ext_vector_type(2))) float; + const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]); + const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0); +#elif defined(CDNA4) || defined(CDNA2) || defined(CDNA1) + // CDNA4 (gfx950) does not support xf32 MFMA, use f32 path like CDNA2/CDNA1 +#pragma unroll + for (int i = 0; i < 2; ++i) { + acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0); + } +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(CDNA3) +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } + + template <ggml_type type> + static __device__ __forceinline__ void mma_block_scaled_fp4(tile<16, 8, float> & D, + const tile<16, 8, int> & A, + const tile<8, 8, int> & B, + uint32_t a_scale, + uint32_t b_scale) { #ifdef BLACKWELL_MMA_AVAILABLE const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; float * Dxi = (float *) D.x; - asm volatile( - "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " - "%10, {0, 0}, %11, {0, 0};" - : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) - : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + if constexpr (type == GGML_TYPE_MXFP4) { + asm volatile( + "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + } else { + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, " + "%10, {0, 0}, %11, {0, 0};" + : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3]) + : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale)); + } #else GGML_UNUSED_VARS(D, A, B, a_scale, b_scale); -#endif // BLACKWELL_MMA_AVAILABLE +#endif // BLACKWELL_MMA_AVAILABLE } static __device__ __forceinline__ void mma( @@ -1009,6 +1241,13 @@ namespace ggml_cuda_mma { GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // RDNA4 +#elif defined(AMD_MFMA_AVAILABLE) + using halfx4_t = __attribute__((ext_vector_type(4))) _Float16; + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]); + const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]); + const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0); #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -1036,11 +1275,31 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // RDNA4 +#endif // defined(RDNA4) +#elif defined(AMD_MFMA_AVAILABLE) + using floatx4_t = __attribute__((ext_vector_type(4))) float; + floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]); +#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) + using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16; + const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]); + const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0); +#elif defined(CDNA1) +#pragma unroll + for (int i = 0; i < 2; ++i) { + using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16; + const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]); + const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]); + acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0); + } #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // AMPERE_MMA_AVAILABLE +#endif // defined(CDNA3) || defined(CDNA2) +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // defined(AMD_WMMA_AVAILABLE) } template <data_layout dl_d, data_layout dl_ab> @@ -1049,74 +1308,28 @@ namespace ggml_cuda_mma { #if defined(AMD_MFMA_AVAILABLE) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * acc = (int32x4_t *) D.x; -#if defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA) - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); -#endif // defined(CDNA3) - +#if defined(CDNA4) || defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], B.x[1], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) #elif defined(AMD_WMMA_AVAILABLE) - using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; - #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); - + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[1], true, b_vec[1], acc[0], true); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - true - ); - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[1], - true, - b_vec[1], - acc[0], - true - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], true); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[1], true, b_vec[1], acc[0], true); #endif // RDNA4 - #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; @@ -1128,21 +1341,12 @@ namespace ggml_cuda_mma { #if defined(AMD_MFMA_AVAILABLE) using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; int32x16_t * acc = (int32x16_t *) D.x; -#if defined(CDNA3) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], - ((int64_t *) B.x)[0], - acc[0], - 0, 0, 0); -#elif defined(CDNA2) || defined(CDNA) - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], - B.x[0], - acc[0], - 0, 0, 0); - acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], - B.x[1], - acc[0], - 0, 0, 0); -#endif // defined(CDNA3) +#if defined(CDNA4) || defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], ((int64_t *) B.x)[0], acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], B.x[0], acc[0], 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], B.x[1], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) #else GGML_UNUSED_VARS(D, A, B); @@ -1161,7 +1365,7 @@ namespace ggml_cuda_mma { static __device__ __forceinline__ void mma( tile<32, 8, float> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1176,12 +1380,12 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) } static __device__ __forceinline__ void mma( tile<32, 4, half2> & D, const tile<32, 4, half2> & A, const tile<8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED> & B) { -#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA +#if defined(VOLTA_MMA_AVAILABLE) const int * Axi = (const int *) A.x; const int * Bxi = (const int *) B.x; int * Dxi = (int *) D.x; @@ -1196,41 +1400,51 @@ namespace ggml_cuda_mma { #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; -#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA +#endif // defined(VOLTA_MMA_AVAILABLE) + } + + static __device__ __forceinline__ void mma( + tile<16, 16, half2, DATA_LAYOUT_I_MAJOR> & D, const tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & A, + const tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED> & B) { +#if defined(AMD_WMMA_AVAILABLE) && defined(RDNA3) + using halfx16_t = __attribute__((ext_vector_type(16))) _Float16; + halfx16_t * xD = (halfx16_t *) D.x; + const halfx16_t * xA = (const halfx16_t *) A.x; + const halfx16_t * xB = (const halfx16_t *) B.x; + xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[0], xB[0], xD[0], /*opsel =*/ 0); + xD[0] = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(xA[1], xB[0], xD[0], /*opsel =*/ 1); +#else + GGML_UNUSED_VARS(D, A, B); + NO_DEVICE_CODE; +#endif // TURING_MMA_AVAILABLE } template <data_layout dl_d, data_layout dl_ab> static __device__ __forceinline__ void mma( tile<16, 16, int, dl_d> & D, const tile<16, 4, int, dl_ab> & A, const tile<16, 4, int, dl_ab> & B) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA4) || defined(CDNA3) + const int64_t xA = uint32_t(A.x[0]); + const int64_t xB = uint32_t(B.x[0]); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(xA, xB, acc[0], 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA1) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], B.x[0], acc[0], 0, 0, 0); +#endif // defined(CDNA4) || defined(CDNA3) +#elif defined(AMD_WMMA_AVAILABLE) using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int; int32x8_t * acc = (int32x8_t *) D.x; #if defined(RDNA4) using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int; int32x2_t * a_vec = (int32x2_t *) A.x; int32x2_t * b_vec = (int32x2_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a_vec[0], true, b_vec[0], acc[0], false); #elif defined(RDNA3) using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; int32x4_t * a_vec = (int32x4_t *) A.x; int32x4_t * b_vec = (int32x4_t *) B.x; - - acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( - true, - a_vec[0], - true, - b_vec[0], - acc[0], - false - ); + acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a_vec[0], true, b_vec[0], acc[0], false); #endif // RDNA4 #else GGML_UNUSED(D); diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu index 6643f243b12..aad4c34aa66 100644 --- a/ggml/src/ggml-cuda/mmf.cu +++ b/ggml/src/ggml-cuda/mmf.cu @@ -2,6 +2,13 @@ #include "mmf.cuh" #include "mmid.cuh" +static __forceinline__ int mmf_get_rows_per_block(const int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return MMF_ROWS_PER_BLOCK_CDNA; + } else { + return MMF_ROWS_PER_BLOCK; + } +} void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); @@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr ids_info_ptr = &ids_info; } + const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; + const int rows_per_block = mmf_get_rows_per_block(cc); + switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; constexpr int vals_per_T = 1; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + mul_mat_f_switch_rows_per_block<float>( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; constexpr int vals_per_T = 2; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + mul_mat_f_switch_rows_per_block<half2>( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; constexpr int vals_per_T = 2; - mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + mul_mat_f_switch_rows_per_block<nv_bfloat162>( + rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); } break; @@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const return false; } } - if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { + if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) { + return false; + } + + if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) { return false; } @@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const } else { if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) { return false; + } else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) { + //TODO: truse CDNA2 as CDNA1, tune the perf when CDNA2 is available. + return false; + } else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) { + return false; } else if (src1_ncols > 16) { return false; } @@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const switch (type) { case GGML_TYPE_F32: - return ampere_mma_available(cc); + return ampere_mma_available(cc) || amd_mfma_available(cc); case GGML_TYPE_F16: - return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); + return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc); case GGML_TYPE_BF16: - return ampere_mma_available(cc) || amd_wmma_available(cc); + return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc); default: return false; } diff --git a/ggml/src/ggml-cuda/mmf.cuh b/ggml/src/ggml-cuda/mmf.cuh index e36730948ff..d55cc1ec7b5 100644 --- a/ggml/src/ggml-cuda/mmf.cuh +++ b/ggml/src/ggml-cuda/mmf.cuh @@ -7,6 +7,31 @@ using namespace ggml_cuda_mma; #define MMF_ROWS_PER_BLOCK 32 +#define MMF_ROWS_PER_BLOCK_CDNA 64 + +static __forceinline__ int64_t mmf_get_max_block_size(int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return 512; + } else { + return 256; + } +} + +static __forceinline__ int mmf_get_padding(int cc) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return 2; + } else { + return 4; + } +} + +static constexpr __device__ int mmf_get_padding() { +#if defined(AMD_MFMA_AVAILABLE) + return 2; +#else + return 4; +#endif // defined(AMD_MFMA_AVAILABLE) +} struct mmf_ids_data { const int32_t * ids_src_compact = nullptr; @@ -29,23 +54,25 @@ static __global__ void mul_mat_f( const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added -#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) - // Special case for tf32, just dummy mma layout as wmma doesn't support it. - constexpr bool is_tf32 = std::is_same_v<T, float>; - constexpr int tile_B_I = is_tf32 ? 8 : 16; - constexpr int tile_C_J = is_tf32 ? 8 : 16; - constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); - typedef tile<16, 8, T, ab_layout> tile_A; - typedef tile<tile_B_I, 8, T, ab_layout> tile_B; - typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; + if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, get_input_data_layout()> tile_A; + typedef tile<16, 8, T, get_input_data_layout()> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#elif defined(AMD_MFMA_AVAILABLE) + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; #else #ifdef VOLTA_MMA_AVAILABLE - if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else { + if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; #else + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<16, 8, T> tile_A; typedef tile<8, 8, T> tile_B; typedef tile<16, 8, float> tile_C; @@ -57,14 +84,14 @@ static __global__ void mul_mat_f( } constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; + constexpr int tile_k_padded = warp_size + mmf_get_padding(); constexpr int ntA = rows_per_block / tile_A::I; constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; const int row0 = blockIdx.x * rows_per_block; int expert_idx = 0; - int col_base = 0; + [[maybe_unused]] int col_base = 0; const int channel_dst = has_ids ? 0 : blockIdx.y; @@ -95,12 +122,12 @@ static __global__ void mul_mat_f( ids += col_offset * stride_row_id; } - const float2 * y2 = (const float2 *) y; + [[maybe_unused]] const float2 * y2 = (const float2 *) y; extern __shared__ char data_mmv[]; char * shmem_base = data_mmv; - int * slot_map = (int *) shmem_base; + [[maybe_unused]] int * slot_map = (int *) shmem_base; char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base; tile_C C[ntA][ntB]; @@ -198,7 +225,7 @@ static __global__ void mul_mat_f( } float * buf_iw = (float *) compute_base; - constexpr int kiw = nwarps*rows_per_block + 4; + constexpr int kiw = nwarps*rows_per_block + mmf_get_padding(); if (nwarps > 1) { __syncthreads(); @@ -228,27 +255,34 @@ static __global__ void mul_mat_f( return; } - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); + float sum[rows_per_block/warp_size] = {0.0f}; + static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size."); #pragma unroll for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; +#pragma unroll + for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) { + const int i = i0 + i1*warp_size + threadIdx.x; - sum += buf_iw[j*kiw + i]; + sum[i1] += buf_iw[j*kiw + i]; + } } if constexpr (!has_ids) { - dst[j*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } else { const int slot = (j < cols_per_block) ? slot_map[j] : -1; if (slot >= 0 && (col_base + j) < ncols_dst_total) { - dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } } } -#ifdef VOLTA_MMA_AVAILABLE } -#endif //VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, @@ -256,7 +290,7 @@ static __global__ void mul_mat_f( channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); NO_DEVICE_CODE; -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } //This kernel is for larger batch sizes of mul_mat_id @@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids( const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const uint3 sis1_fd, const uint3 nch_fd) { // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added -#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE) - // Special case for tf32, just dummy mma layout as wmma doesn't support it. - constexpr bool is_tf32 = std::is_same_v<T, float>; - constexpr int tile_B_I = is_tf32 ? 8 : 16; - constexpr int tile_C_J = is_tf32 ? 8 : 16; - constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); - typedef tile<16, 8, T, ab_layout> tile_A; - typedef tile<tile_B_I, 8, T, ab_layout> tile_B; - typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; + if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, get_input_data_layout()> tile_A; + typedef tile<16, 8, T, get_input_data_layout()> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; +#elif defined(AMD_MFMA_AVAILABLE) + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else { + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A; + typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B; + typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C; #else #ifdef VOLTA_MMA_AVAILABLE - if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else { + if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; #else + if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else { typedef tile<16, 8, T> tile_A; typedef tile<8, 8, T> tile_B; typedef tile<16, 8, float> tile_C; @@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids( constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; + constexpr int tile_k_padded = warp_size + mmf_get_padding(); constexpr int ntA = rows_per_block / tile_A::I; constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; @@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids( } float * buf_iw = (float *) compute_base; - constexpr int kiw = nwarps*rows_per_block + 4; + constexpr int kiw = nwarps*rows_per_block + mmf_get_padding(); if (nwarps > 1) { __syncthreads(); @@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids( return; } - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); + float sum[rows_per_block/warp_size] = {0.0f}; + static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size."); #pragma unroll for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; +#pragma unroll + for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) { + const int i = i0 + i1*warp_size + threadIdx.x; - sum += buf_iw[j*kiw + i]; + sum[i1] += buf_iw[j * kiw + i]; + } } const int global_j = col_base + j; @@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids( const int token = (int) qrm.x; if (token < ncols_dst_total) { const int slot = (int) qrm.y; - dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum; +#pragma unroll + for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) { + dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0]; + } } } } -#ifdef VOLTA_MMA_AVAILABLE } -#endif // VOLTA_MMA_AVAILABLE #else GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); NO_DEVICE_CODE; -#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } -template<typename T, int cols_per_block, int nwarps> +template<typename T, int rows_per_block, int cols_per_block, int nwarps> static inline void mul_mat_f_switch_ids( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst, @@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids( const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst); - mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>> + mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>> (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, @@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids( dim3 block_nums_ids = block_nums; block_nums_ids.y *= col_tiles; - mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>> + mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>> (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } else { - mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>> + mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>> (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); } } -template <typename T, int cols_per_block> +template <typename T, int rows_per_block, int cols_per_block> void mul_mat_f_cuda( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, @@ -605,7 +645,7 @@ void mul_mat_f_cuda( int64_t nwarps_best = 1; int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); - int64_t max_block_size = 256; + int64_t max_block_size = mmf_get_max_block_size(cc); for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); if (niter < niter_best) { @@ -614,10 +654,9 @@ void mul_mat_f_cuda( } } - constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; - const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; - const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4; + const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4; const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; @@ -628,56 +667,56 @@ void mul_mat_f_cuda( switch (nwarps_best) { case 1: { - mul_mat_f_switch_ids<T, cols_per_block, 1>( + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 2: { - mul_mat_f_switch_ids<T, cols_per_block, 2>( + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 3: { - mul_mat_f_switch_ids<T, cols_per_block, 3>( + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 4: { - mul_mat_f_switch_ids<T, cols_per_block, 4>( + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 5: { - mul_mat_f_switch_ids<T, cols_per_block, 5>( + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 6: { - mul_mat_f_switch_ids<T, cols_per_block, 6>( + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 7: { - mul_mat_f_switch_ids<T, cols_per_block, 7>( + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, ids_data); } break; case 8: { - mul_mat_f_switch_ids<T, cols_per_block, 8>( + mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>( x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, @@ -691,7 +730,7 @@ void mul_mat_f_cuda( GGML_UNUSED_VARS(nchannels_y); } -template <typename T> +template <typename T, int rows_per_block> static void mul_mat_f_switch_cols_per_block( const T * x, const float * y, const int32_t * ids, float * dst, const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, @@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block( switch (ncols_case) { case 1: { - mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 2: { - mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 3: { - mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 4: { - mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 5: { - mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 6: { - mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 7: { - mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 8: { - mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 9: { - mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 10: { - mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 11: { - mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 12: { - mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 13: { - mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 14: { - mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 15: { - mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; case 16: { - mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); } break; @@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block( } } -#define DECL_MMF_CASE_HELPER(T, ncols_dst) \ - template void mul_mat_f_cuda<T, ncols_dst>( \ +template <typename T> +static void mul_mat_f_switch_rows_per_block( + const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream, const mmf_ids_data * ids_data) { + switch (rows_per_block) { + case MMF_ROWS_PER_BLOCK: { + mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>( + x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + case MMF_ROWS_PER_BLOCK_CDNA: { + mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>( + x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); + } break; + default: + GGML_ABORT("unsupported rows_per_block: %i", rows_per_block); + } +} + +#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \ + template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \ const T * x, const float * y, const int32_t * ids, float * dst, \ const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ const int64_t stride_col_id, const int64_t stride_row_id, \ @@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block( const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ cudaStream_t stream, const mmf_ids_data * ids_data); -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#if !defined(GGML_USE_MUSA) #define DECL_MMF_CASE_EXTERN(ncols_dst) \ - extern DECL_MMF_CASE_HELPER(float, ncols_dst) \ - extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \ - extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) #define DECL_MMF_CASE(ncols_dst) \ - DECL_MMF_CASE_HELPER(float, ncols_dst) \ - DECL_MMF_CASE_HELPER(half2, ncols_dst) \ - DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \ + DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) DECL_MMF_CASE_EXTERN(1); DECL_MMF_CASE_EXTERN(2); diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 9a69f41d159..e1add5e0331 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -5,6 +5,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { switch (args.type_x) { + case GGML_TYPE_Q1_0: + mul_mat_q_case<GGML_TYPE_Q1_0>(ctx, args, stream); + break; case GGML_TYPE_Q4_0: mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream); break; @@ -23,6 +26,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con case GGML_TYPE_MXFP4: mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream); break; + case GGML_TYPE_NVFP4: + mul_mat_q_case<GGML_TYPE_NVFP4>(ctx, args, stream); + break; case GGML_TYPE_Q2_K: mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream); break; @@ -116,7 +122,7 @@ void ggml_cuda_mul_mat_q( || GGML_CUDA_CC_IS_CDNA(cc); // TODO: tighter pool buffer size vs q8 path - const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4; + const bool use_native_fp4 = blackwell_mma_available(cc) && (src0->type == GGML_TYPE_MXFP4 || src0->type == GGML_TYPE_NVFP4); if (!ids) { const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + @@ -127,9 +133,9 @@ void ggml_cuda_mul_mat_q( const int64_t s11 = src1->nb[1] / ts_src1; const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - if (use_native_mxfp4) { + if (use_native_fp4) { static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1)); - quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, + quantize_mmq_fp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream); } else { @@ -140,10 +146,8 @@ void ggml_cuda_mul_mat_q( } // Stride depends on quantization format - const int64_t s12 = use_native_mxfp4 ? - ne11 * ne10_padded * sizeof(block_fp4_mmq) / - (8 * QK_MXFP4 * sizeof(int)) // block_fp4_mmq holds 256 values (8 blocks of 32) - : + const int64_t s12 = use_native_fp4 ? + ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : // block_fp4_mmq holds 256 values ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; @@ -192,8 +196,8 @@ void ggml_cuda_mul_mat_q( const int64_t s12 = src1->nb[2] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1; - if (use_native_mxfp4) { - quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, + if (use_native_fp4) { + quantize_mmq_fp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream); } else { quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13, @@ -202,8 +206,9 @@ void ggml_cuda_mul_mat_q( CUDA_CHECK(cudaGetLastError()); } - const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) : - ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); + static_assert(QK_K == 8 * QK_MXFP4, "QK_K needs to be 8 * QK_MXFP4"); + const int64_t s12 = use_native_fp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (QK_K * sizeof(int)) : + ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int)); const int64_t s13 = ne12*s12; // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid. @@ -267,12 +272,14 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t bool mmq_supported; switch (type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -362,5 +369,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t } return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; - } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index a382e6a6979..edf546d8f1e 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -10,9 +10,9 @@ using namespace ggml_cuda_mma; #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available. -#define MMQ_ITER_K 256 -#define MMQ_ITER_K_MXFP4_FP4 512 -#define MMQ_NWARPS 8 +#define MMQ_ITER_K 256 +#define MMQ_ITER_K_FP4 512 +#define MMQ_NWARPS 8 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride); typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00); @@ -46,9 +46,12 @@ struct block_q8_1_mmq { int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each }; +// this struct is used for fp4 data types (currently only used for Blackwell) +// mxfp4 has block size 32, each int32 of d4 contains 2 e8m0 scales in the lower 16 bits +// nvfp4 has block size 16, each int32 of d4 contains 4 ue4m3 scales struct block_fp4_mmq { - uint32_t d4[4]; // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc. - int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values + uint32_t d4[4]; + int8_t qs[4 * 32]; // 256 FP4 values packed as 4-bit pairs (2 per byte) }; static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size"); @@ -57,6 +60,8 @@ static_assert(sizeof(block_fp4_mmq) == sizeof(block_q8_1_mmq), "Unexpected b static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { switch (type_x) { + case GGML_TYPE_Q1_0: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: return MMQ_Q8_1_DS_LAYOUT_DS4; @@ -68,6 +73,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_MXFP4: return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_NVFP4: + return MMQ_Q8_1_DS_LAYOUT_D4; case GGML_TYPE_Q2_K: return MMQ_Q8_1_DS_LAYOUT_D2S6; case GGML_TYPE_Q3_K: @@ -100,7 +107,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return (amd_mfma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : + return (turing_mma_available(cc) || amd_wmma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? #ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -110,9 +117,9 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) return 128; -#else // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) +#else // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) #if defined(GGML_USE_HIP) return 64; @@ -139,10 +146,11 @@ static int get_mmq_y_host(const int cc) { static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) { #if defined(BLACKWELL_MMA_AVAILABLE) - return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K; -#else - return MMQ_ITER_K; +if (type == GGML_TYPE_NVFP4 || type == GGML_TYPE_MXFP4) { + return MMQ_ITER_K_FP4; +} #endif // defined(BLACKWELL_MMA_AVAILABLE) + return MMQ_ITER_K; } static constexpr __device__ int get_mmq_y_device() { @@ -183,12 +191,14 @@ static constexpr __device__ int get_mmq_y_device() { static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { switch (type) { + case GGML_TYPE_Q1_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q4_0: return MMQ_DP4A_TXS_Q4_0; case GGML_TYPE_Q4_1: return MMQ_DP4A_TXS_Q4_1; case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1; case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0; case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1; + case GGML_TYPE_NVFP4: return MMQ_DP4A_TXS_Q8_0_16; case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K; case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K; case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K; @@ -206,12 +216,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } } -#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) -#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) -#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_FP4 (2*MMQ_TILE_NE_K + 8 + 4) // MXFP4 and NVFP4 Blackwell +#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) // NVFP4 Generic +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); @@ -220,9 +231,12 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_FP4 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4"); +static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding."); + static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_0: return MMQ_MMA_TILE_X_K_Q8_0; case GGML_TYPE_Q4_1: return MMQ_MMA_TILE_X_K_Q8_1; case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0; @@ -230,6 +244,11 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0; // tile sizes are the same for Q8_1 and FP4 for blackwell case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1; +#if defined(BLACKWELL_MMA_AVAILABLE) + case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_FP4; +#else + case GGML_TYPE_NVFP4: return MMQ_MMA_TILE_X_K_NVFP4; +#endif // defined(BLACKWELL_MMA_AVAILABLE) case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K; case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K; case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1; @@ -295,6 +314,87 @@ static constexpr __device__ int mmq_get_nwarps_device() { // ------------------------------------------------------------ +template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q1_0( + const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int blocks_per_iter = MMQ_ITER_K / QK1_0; + constexpr int threads_per_row = blocks_per_iter * QI1_0; + constexpr int nrows = warp_size / threads_per_row; + constexpr int scale_entries_per_block = QK1_0 / QK8_1; + constexpr int scale_entries_per_row = blocks_per_iter * scale_entries_per_block; + + const int txi = threadIdx.x % threads_per_row; + const int kbx = txi / QI1_0; + const int kqsx = txi % QI1_0; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; + + if (need_check) { + i = min(i, i_max); + } + + const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + kbx; + const int qs_offset = 4*kqsx; + const int qs0 = bxi->qs[qs_offset + 0] | (bxi->qs[qs_offset + 1] << 8) | + (bxi->qs[qs_offset + 2] << 16) | (bxi->qs[qs_offset + 3] << 24); + + int unpacked_bytes[8]; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (qs0 >> shift) & 0x0F; + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + unpacked_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + + const int dst_offset = kbx*(scale_entries_per_block*QI8_0) + kqsx*QI8_0; +#pragma unroll + for (int j = 0; j < 8; ++j) { +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + dst_offset + j] = unpacked_bytes[j]; +#else + x_qs[i*(2*MMQ_TILE_NE_K + 1) + dst_offset + j] = unpacked_bytes[j]; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + } + + const int ksx = threadIdx.x % scale_entries_per_row; + const int scale_block = ksx / scale_entries_per_block; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { + int i = i0 + threadIdx.y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q1_0 * bxi = (const block_q1_0 *) x + kbx0 + i*stride + scale_block; + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + ksx] = bxi->d; +#else + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + ksx] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } +} + template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { constexpr int nwarps = mmq_get_nwarps_device(); @@ -379,17 +479,25 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_0_Q8_1_MMQ]; -#pragma unroll - for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + constexpr int mcpy_int = max_cpy / sizeof(int); + static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ"); + + int tmp0[4], tmp1[4]; + + #pragma unroll + for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) { + ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] ); + ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_0 + l0 * mcpy_int]); } + u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; + u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3]; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ> (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -482,17 +590,25 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); int u[2*VDR_Q4_1_Q8_1_MMQ]; -#pragma unroll - for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j*MMQ_TILE_Y_K + kyqs + l]; - u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; + constexpr int max_cpy = ggml_cuda_get_max_cpy_bytes(); + constexpr int mcpy_int = max_cpy / sizeof(int); + static_assert(VDR_Q4_0_Q8_1_MMQ == 4, "bad VDR_Q4_0_Q8_1_MMQ"); + + int tmp0[4], tmp1[4]; + + #pragma unroll + for (int l0 = 0; l0 < 4 / mcpy_int; ++l0) { + ggml_cuda_memcpy_1<max_cpy>(tmp0 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + l0 * mcpy_int] ); + ggml_cuda_memcpy_1<max_cpy>(tmp1 + l0 * mcpy_int, &y_qs[j*MMQ_TILE_Y_K + kyqs + QI4_1 + l0 * mcpy_int]); } + u[0]=tmp0[0]; u[2]=tmp0[1]; u[4]=tmp0[2]; u[6]=tmp0[3]; + u[1]=tmp1[0]; u[3]=tmp1[1]; u[5]=tmp1[2]; u[7]=tmp1[3]; + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl<VDR_Q4_1_Q8_1_MMQ> (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -826,6 +942,187 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr } } +#ifdef BLACKWELL_MMA_AVAILABLE +template <int mmq_y, bool need_check> +static __device__ __forceinline__ void load_tiles_nvfp4_nvfp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kbx0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int iter_k = get_iter_k(GGML_TYPE_NVFP4); + constexpr int threads_per_row = iter_k / QK_NVFP4; // each thread processes 1 block + constexpr int rows_per_warp = warp_size / threads_per_row; + + uint32_t * x_u32 = (uint32_t *) x_tile; + + const int txi = threadIdx.x; + const int kbx = txi % threads_per_row; + const int row_in_warp = txi / threads_per_row; + + const block_nvfp4 * bxi_base = (const block_nvfp4 *) x + kbx0 + kbx; + uint32_t * x_u32_scale = x_u32 + 64 + kbx; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_nvfp4 * bxi = bxi_base + i * stride; + const int row_base = i * MMQ_MMA_TILE_X_K_FP4; + const int q_base = row_base + 8 * kbx; + + const uint32_t * src_qs = reinterpret_cast<const uint32_t *>(bxi->qs); + +#pragma unroll + for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) { + x_u32[q_base + 2 * sub + 0] = src_qs[2 * sub + 0]; + x_u32[q_base + 2 * sub + 1] = src_qs[2 * sub + 1]; + } + + x_u32_scale[row_base] = get_int_b4(bxi->d, 0); + } +} + +// Shared MMA kernel for MXFP4 and NVFP4 on Blackwell. +// Both quantizations encode values as e2m1 (FP4) and produce one uint32 scale per +// m16n8k64 MMA call; only the PTX kind (scale_vec::2X ue8m0 vs scale_vec::4X ue4m3) +// and the per-type stride constant differ. +template <int mmq_x, int mmq_y, ggml_type type> +static __device__ __forceinline__ void vec_dot_fp4_fp4_mma(const int * __restrict__ x, + const int * __restrict__ y, + float * __restrict__ sum, + const int k00) { + static_assert(type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4, + "vec_dot_fp4_fp4_mma: type must be MXFP4 or NVFP4"); + + typedef tile<16, 8, int> tile_A; + typedef tile<8, 8, int> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int stride = MMQ_MMA_TILE_X_K_FP4; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp / tile_C::I; + constexpr int nfrags = MMQ_TILE_NE_K / tile_A::J; + + y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); + const int * y_qs = (const int *) y + 4; + const uint32_t * y_sc = (const uint32_t *) y; + + // 2 threads per quad supply the packed scale register to the block_scale MMA, + // see https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling + const int tidx_A = threadIdx.x / 4 + (threadIdx.x % 2) * 8; + const int tidx_B = threadIdx.x / 4; + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + tile_A A[ntx][nfrags]; + uint32_t scaleA[ntx][nfrags]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + const int k0 = k00 + frag * tile_A::J; + load_ldmatrix(A[n][frag], x_qs + (i0 + n * tile_A::I) * stride + k0, stride); + scaleA[n][frag] = x_sc[(i0 + n * tile_A::I + tidx_A) * stride + k0 / tile_A::J]; + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { + tile_B B[nfrags]; + uint32_t scaleB[nfrags]; + +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + const int k0 = frag * tile_B::J; + load_generic(B[frag], y_qs + j0 * MMQ_TILE_Y_K + k0, MMQ_TILE_Y_K); + scaleB[frag] = y_sc[(j0 + tidx_B) * MMQ_TILE_Y_K + frag]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int frag = 0; frag < nfrags; ++frag) { + tile_C C = {}; + mma_block_scaled_fp4<type>(C, A[n][frag], B[frag], scaleA[n][frag], scaleB[frag]); +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; + } + } + } + } +} +#endif // BLACKWELL_MMA_AVAILABLE + + +template <int mmq_y, bool need_check> +static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x, + int * __restrict__ x_tile, + const int kb0, + const int i_max, + const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4; + constexpr int rows_per_warp = warp_size / threads_per_row; + const int kbx = threadIdx.x % threads_per_row; + const int row_in_warp = threadIdx.x / threads_per_row; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) { + int i = i0 + threadIdx.y * rows_per_warp + row_in_warp; + + if constexpr (need_check) { + i = min(i, i_max); + } + + const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx; + const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs); + const int kqs = 16 * kbx; + const int ksc = 4 * kbx; + +#pragma unroll + for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) { + const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4); + const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4); + +#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y; + x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y; + x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]); +#else + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y; + x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y; + x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) + } + } +} + template <int mmq_x, int mmq_y> static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -887,13 +1184,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); float dB; const int j = j0 + tile_C::get_j(0); @@ -996,77 +1293,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -template <int mmq_x, int mmq_y> -static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x, - const int * __restrict__ y, - float * __restrict__ sum, - const int k00) { - typedef tile<16, 8, int> tile_A; - typedef tile<8, 8, int> tile_B; - typedef tile<16, 8, float> tile_C; // Output is float for native scaled MMA - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp / tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K); - - // Match layout from load_tiles_mxfp4_fp4 - const int * x_qs = (const int *) x; - const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K); - const int * y_qs = (const int *) y + 4; - const uint32_t * y_sc = (const uint32_t *) y; - - // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4 - tile_A A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; - uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)]; - - // Block scale - // Each thread has to point to a 4 byte scale value - // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { - const int k0 = k00 + k01; - - load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0, - MMQ_MMA_TILE_X_K_FP4); - - // based on block-scaling document, 2 threads in each quad need to supply to the scale value - const int tidx = threadIdx.x / 4 + (threadIdx.x % 2) * 8; - scaleA[n][k01 / (2 * QI_MXFP4)] = - *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4)); - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) { -#pragma unroll - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) { - tile_B B; - uint32_t scaleB; // 2xN scales - - load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K); - - scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - - mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB); -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l]; - } - } - } - } -} template <int mmq_x, int mmq_y> static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( @@ -1128,13 +1354,13 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); @@ -1229,7 +1455,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } -// Used for Q3_K, IQ2_S, and IQ2_XS +// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS template <int mmq_x, int mmq_y> static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { @@ -1268,57 +1494,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( template <int mmq_x, int mmq_y> static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1343,13 +1519,13 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -1575,74 +1751,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( template <int mmq_x, int mmq_y> static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; - const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 - : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y - : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); - - tile_C Cm; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tile_A A1; - A1.x[0] = 0x01010101; - A1.x[1] = 0x01010101; - mma(Cm, A1, B[0]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C Cd; - mma(Cd, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); - float tmp = Cd.x[l]*dm.x; - if (k01 >= MMQ_TILE_NE_K * 3/4) { - tmp -= Cm.x[l]*dm.y; - } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; - sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -1667,13 +1776,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y; @@ -2406,59 +2515,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( template <int mmq_x, int mmq_y> static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#if defined(AMD_MFMA_AVAILABLE) - constexpr data_layout input_layout = get_input_data_layout(); - typedef tile<16, 8, int, input_layout> tile_A; - typedef tile<16, 8, int, input_layout> tile_B; - typedef tile<16, 16, int, DATA_LAYOUT_J_MAJOR> tile_C; - typedef tile<64, 2, int, input_layout> tile_load; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = granularity; - constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; - const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * rows_per_warp; - - for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { - const int k0 = k00 + k01; - - tile_A A[ntx]; -#pragma unroll - for (int n = 0; n < ntx; ++n) { - load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { - tile_B B[1]; - load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - - const int j = j0 + tile_C::get_j(0); - const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - tile_C C; - mma(C, A[n], B[0]); - -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = i0 + n*tile_C::I + tile_C::get_i(l); - const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); - sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; - } - } - } - } -#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles +#if defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) constexpr data_layout input_layout = get_input_data_layout(); typedef tile<16, 4, int, input_layout> tile_A; typedef tile<16, 4, int, input_layout> tile_B; @@ -2484,13 +2541,13 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( tile_A A[ntx]; #pragma unroll for (int n = 0; n < ntx; ++n) { - load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { tile_B B; - load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + load_ldmatrix(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); const int j = j0 + tile_C::get_j(0); const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; @@ -2715,14 +2772,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XXS; ++l) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]]; + const uint32_t signs = unpack_ksigns(aux32 >> (7 * l)); - const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; @@ -2733,12 +2790,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } - const int ls = aux32 >> 28; + const int ls = aux32 >> 27 | 1; // (scale * 2 + 1) const float d = bxi->d; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4 #else - x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2776,11 +2833,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XS; ++l) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF]; + const uint32_t signs = unpack_ksigns(q2[l] >> 9); - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; @@ -2904,11 +2964,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR3_XXS; ++l) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + const uint32_t signs = unpack_ksigns(aux32 >> (7*l)); - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; @@ -3203,6 +3265,14 @@ static __device__ __forceinline__ void mmq_write_back_mma( template <int mmq_x, int mmq_y, bool need_check, ggml_type type> struct mmq_type_traits; +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q1_0> { + static constexpr int vdr = VDR_Q1_0_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q1_0<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>; +}; + template <int mmq_x, int mmq_y, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q4_0> { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; @@ -3248,7 +3318,7 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> { static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ; #ifdef BLACKWELL_MMA_AVAILABLE static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4_fp4<mmq_y, need_check>; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_MXFP4>; #else static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>; static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>; @@ -3256,6 +3326,19 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>; }; +template <int mmq_x, int mmq_y, bool need_check> +struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> { + static constexpr int vdr = VDR_NVFP4_Q8_1_MMQ; +#ifdef BLACKWELL_MMA_AVAILABLE + static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4_nvfp4<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_fp4_fp4_mma<mmq_x, mmq_y, GGML_TYPE_NVFP4>; +#else + static constexpr load_tiles_mmq_t load_tiles = load_tiles_nvfp4<mmq_y, need_check>; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>; +#endif // BLACKWELL_MMA_AVAILABLE + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>; +}; + template <int mmq_x, int mmq_y, bool need_check> struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; @@ -3387,7 +3470,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( #if defined(BLACKWELL_MMA_AVAILABLE) // FP4 tile stores 8 blocks - constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1; + constexpr int ne_block = (type == GGML_TYPE_MXFP4 || type == GGML_TYPE_NVFP4) ? QK_K : 4 * QK8_1; #else constexpr int ne_block = 4 * QK8_1; #endif // defined(BLACKWELL_MMA_AVAILABLE) @@ -3459,10 +3542,10 @@ template <ggml_type type, int mmq_x, bool need_check> static __global__ void mul_mat_q( const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup, - const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, - const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const int ncols_max) { + const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst, + const uint3 channel_ratio, const uint3 nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const uint3 sample_ratio, const uint3 nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const uint3 ntx) { // Skip unused template specializations for faster compilation: if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) { @@ -3476,8 +3559,7 @@ static __global__ void mul_mat_q( constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int mmq_y = get_mmq_y_device(); - const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x - const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y + const uint32_t nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y // Initialize the ids for writing back data with just the index. // For regular matrix multiplications this is never changed. @@ -3498,8 +3580,9 @@ static __global__ void mul_mat_q( // On non-CDNA AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: #if (defined(GGML_USE_HIP) && !defined(CDNA)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { - const int wt = blockIdx.z / nchannels_y; - const int zt = blockIdx.z - wt*nchannels_y; + const uint2 tmp2 = fast_div_modulo(blockIdx.z, nchannels_y); + const int wt = tmp2.x; + const int zt = tmp2.y; const int jt = blockIdx.y; const int it = blockIdx.x; @@ -3542,40 +3625,40 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; mul_mat_q_process_tile<type, mmq_x, need_check, fixup> (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, - tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); + tile_x_max_i, tile_y_max_j, 0, blocks_per_ne00.z); return; } -#endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA +#endif // (defined(GGML_USE_HIP) && !defined(CDNA4) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA - constexpr int ITER_K = get_iter_k(type); - - const int64_t blocks_per_ne00 = ncols_x / qk; - constexpr int blocks_per_iter = ITER_K / qk; + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int kbc = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + int kbc_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; - kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter; + kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter; + kbc_stop -= fastmodulo(kbc_stop, blocks_per_ne00) % blocks_per_iter; // kb0 == k index when doing the matrix multiplication for an output tile. - int kb0_start = kbc % blocks_per_ne00; - int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc); - while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) { - int tmp = kbc; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int kb0_start = fastmodulo(kbc, blocks_per_ne00); + int kb0_stop = min(blocks_per_ne00.z, uint32_t(kb0_start + kbc_stop - kbc)); + while (kbc < kbc_stop && kb0_stop == int(blocks_per_ne00.z)) { + int tmp = fastdiv(kbc, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; // Defaults for regular matrix multiplication: int col_low = 0; @@ -3593,11 +3676,11 @@ static __global__ void mul_mat_q( offset_dst = 0; if (jt*mmq_x >= col_diff) { - kbc += blocks_per_ne00; - kbc -= kbc % blocks_per_ne00; + kbc += blocks_per_ne00.z; + kbc -= fastmodulo(kbc, blocks_per_ne00); kb0_start = 0; - kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc)); continue; } @@ -3622,32 +3705,34 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. mul_mat_q_process_tile<type, mmq_x, need_check, fixup> (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); - kbc += blocks_per_ne00; - kbc -= kbc % blocks_per_ne00; + kbc += blocks_per_ne00.z; + kbc -= fastmodulo(kbc, blocks_per_ne00); kb0_start = 0; - kb0_stop = min(blocks_per_ne00, kbc_stop - kbc); + kb0_stop = min(blocks_per_ne00.z, uint32_t(kbc_stop - kbc)); } if (kbc >= kbc_stop) { return; } - int tmp = kbc; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int tmp = fastdiv(kbc, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; // Defaults for regular matrix multiplication: int col_low = 0; @@ -3689,7 +3774,7 @@ static __global__ void mul_mat_q( const int tile_x_max_i = nrows_x - it*mmq_y - 1; const int tile_y_max_j = col_diff - jt*mmq_x - 1; - const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; + const int offset_x = fastdiv(wt, sample_ratio)*stride_sample_x + fastdiv(zt, channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. mul_mat_q_process_tile<type, mmq_x, need_check, fixup> @@ -3697,40 +3782,38 @@ static __global__ void mul_mat_q( tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } - template <ggml_type type, int mmq_x, bool need_check> +__launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device()/2, 1) static __global__ void mul_mat_q_stream_k_fixup( - const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, - const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, - const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, - const int ncols_max) { - constexpr int mmq_y = get_mmq_y_device(); - constexpr int qk = ggml_cuda_type_traits<type>::qk; - constexpr int ITER_K = get_iter_k(type); - - constexpr int blocks_per_iter = ITER_K / qk; - const int64_t blocks_per_ne00 = ncols_x / qk; + const int32_t * __restrict__ ids_dst, const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, + float * __restrict__ tmp_last_tile, const uint3 blocks_per_ne00, const int nrows_x, const int ncols_dst, + const int stride_col_dst, const uint3 nchannels_y, const int stride_channel_dst, const uint3 nsamples_y, + const int stride_sample_dst, const uint3 ntx) { + constexpr int mmq_y = get_mmq_y_device(); + constexpr int qk = ggml_cuda_type_traits<type>::qk; + constexpr int ITER_K = get_iter_k(type); + constexpr int blocks_per_iter = ITER_K / qk; - constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int nwarps = mmq_get_nwarps_device()/2; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; + float sum[mmq_x / nwarps] = {0.0f}; + const int i = blockIdx.y*warp_size + threadIdx.x; - const int ntx = (ncols_max + mmq_x - 1) / mmq_x; - const int nty = (nrows_x + mmq_y - 1) / mmq_y; + const int nty = (nrows_x + mmq_y - 1) / mmq_y; const int bidx0 = blockIdx.x; // kbc == k block continuous, current index in continuous ijk space. - int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; + int kbc0 = int64_t(blockIdx.x) *(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + int kbc0_stop = int64_t(blockIdx.x + 1)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; - kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter; - kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter; + kbc0 -= fastmodulo(kbc0, blocks_per_ne00) % blocks_per_iter; + kbc0_stop -= fastmodulo(kbc0_stop, blocks_per_ne00) % blocks_per_iter; const bool did_not_have_any_data = kbc0 == kbc0_stop; - const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0; - const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0; + const bool wrote_beginning_of_tile = fastmodulo(kbc0, blocks_per_ne00) == 0; + const bool did_not_write_last = fastdiv(kbc0, blocks_per_ne00) == fastdiv(kbc0_stop, blocks_per_ne00) && fastmodulo(kbc0_stop, blocks_per_ne00) != 0; if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { return; } @@ -3739,11 +3822,11 @@ static __global__ void mul_mat_q_stream_k_fixup( // Iterate over previous blocks and sum up partial sums written to fixup buffer. // All CUDA blocks that get here must have a previous block that needs a fixup. - int64_t bidx = bidx0 - 1; - int64_t kbc_stop = kbc0; + int bidx = bidx0 - 1; + int kbc_stop = kbc0; while(true) { - int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x; - kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; + int kbc = int64_t(bidx)*(nsamples_y.z*nchannels_y.z*ntx.z*nty*blocks_per_ne00.z) / gridDim.x; + kbc -= fastmodulo(kbc, blocks_per_ne00) % blocks_per_iter; if (kbc == kbc_stop) { // Did not have any data. bidx--; @@ -3753,20 +3836,16 @@ static __global__ void mul_mat_q_stream_k_fixup( any_fixup = true; + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; - } + sum[j0/nwarps] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; } // If this block started in a previous tile we are done and don't need to combine additional partial results. - if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) { + if (fastmodulo(kbc, blocks_per_ne00) == 0 || fastdiv(kbc, blocks_per_ne00) < fastdiv(kbc0, blocks_per_ne00)) { break; } bidx--; @@ -3777,14 +3856,16 @@ static __global__ void mul_mat_q_stream_k_fixup( return; } - int tmp = kbc0; - const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00); - const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00); - tmp -= wt * (nchannels_y*ntx*blocks_per_ne00); - const int zt = tmp / (ntx*blocks_per_ne00); - tmp -= zt * (ntx*blocks_per_ne00); - const int jt = tmp / blocks_per_ne00; + int tmp = fastdiv(kbc0, blocks_per_ne00); + uint2 tmp2 = fast_div_modulo(tmp, ntx); + const int jt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nchannels_y); + const int zt = tmp2.y; + tmp = tmp2.x; + tmp2 = fast_div_modulo(tmp, nsamples_y); + const int wt = tmp2.y; + const int it = tmp2.x; if (!ids_dst) { const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y; @@ -3792,6 +3873,9 @@ static __global__ void mul_mat_q_stream_k_fixup( const int i_max = nrows_x - it*mmq_y - 1; const int j_max = ncols_dst - jt*mmq_x - 1; + if (need_check && i > i_max) { + return; + } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -3801,16 +3885,7 @@ static __global__ void mul_mat_q_stream_k_fixup( return; } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } - - dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } + dst[j*stride_col_dst + i] += sum[j0/nwarps]; } return; } @@ -3830,6 +3905,9 @@ static __global__ void mul_mat_q_stream_k_fixup( const int i_max = nrows_x - it*mmq_y - 1; const int j_max = col_diff - jt*mmq_x - 1; + if (need_check && i > i_max) { + return; + } #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { @@ -3839,16 +3917,7 @@ static __global__ void mul_mat_q_stream_k_fixup( return; } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - if (need_check && i > i_max) { - continue; - } - - dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; - } + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[j0/nwarps]; } } @@ -3896,29 +3965,44 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int channel_ratio = args.nchannels_y / args.nchannels_x; const int sample_ratio = args.nsamples_y / args.nsamples_x; + const uint3 blocks_per_ne00_fd = init_fastdiv_values(args.ncols_x / ggml_cuda_type_traits<type>::qk); + const uint3 ntx_fd = init_fastdiv_values(ntx); + const uint3 nchannels_y_fd = init_fastdiv_values(args.nchannels_y); + const uint3 nsamples_y_fd = init_fastdiv_values(args.nsamples_y); + const uint3 channel_ratio_fd = init_fastdiv_values(channel_ratio); + const uint3 sample_ratio_fd = init_fastdiv_values(sample_ratio); + if (!args.use_stream_k) { if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); } else { constexpr bool need_check = true; mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); } return; } - const dim3 block_nums_stream_k(nsm, 1, 1); - const bool fixup_needed = ntx*nty*ntzw % nsm != 0; + // For the stream-k kernel it is possible to run it with tiling by setting the number of CUDA blocks equal to the number of tiles. + // This is worthwhile if the efficiency of tiling is high and skipping the fixup kernel is more important. + const int ntiles_dst = ntx * nty * ntzw; + const int tiles_nwaves = (ntiles_dst + nsm - 1) / nsm; + const int tiles_efficiency_percent = 100 * ntiles_dst / (nsm*tiles_nwaves); + const dim3 block_nums_stream_k(GGML_CUDA_CC_IS_NVIDIA(cc) && tiles_efficiency_percent >= 90 ? ntiles_dst : nsm, 1, 1); + + GGML_ASSERT(ntiles_dst * blocks_per_ne00_fd.z < (1 << 30)); // Assert that variable kbc will not overflow. + + const bool fixup_needed = ntiles_dst % block_nums_stream_k.x != 0; ggml_cuda_pool & pool = ctx.pool(id); ggml_cuda_pool_alloc<float> tmp_fixup(pool); @@ -3926,40 +4010,45 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y); } + const dim3 block_nums_fixup(block_nums_stream_k.x, mmq_y/warp_size, 1); + const dim3 block_dims_fixup(block_dims.x, block_dims.y/2, block_dims.z); + if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); if (!fixup_needed) { return; } - mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>> - (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, - args.ncols_max); + CUDA_CHECK(cudaGetLastError()); + mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, + args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst, + ntx_fd); } else { constexpr bool need_check = true; mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, - args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, - channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, - sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, - args.ncols_max); + blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, + channel_ratio_fd, nchannels_y_fd, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, + sample_ratio_fd, nsamples_y_fd, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst, + ntx_fd); if (!fixup_needed) { return; } - mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>> - (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, - args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst, - args.ncols_max); + CUDA_CHECK(cudaGetLastError()); + mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_fixup, block_dims_fixup, 0, stream>>> + (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, blocks_per_ne00_fd, args.nrows_x, args.ncols_dst, + args.nrows_dst, nchannels_y_fd, args.stride_channel_dst, nsamples_y_fd, args.stride_sample_dst, + ntx_fd); } } @@ -4057,6 +4146,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0); extern DECL_MMQ_CASE(GGML_TYPE_Q5_1); extern DECL_MMQ_CASE(GGML_TYPE_Q8_0); extern DECL_MMQ_CASE(GGML_TYPE_MXFP4); +extern DECL_MMQ_CASE(GGML_TYPE_NVFP4); extern DECL_MMQ_CASE(GGML_TYPE_Q2_K); extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); @@ -4083,3 +4173,4 @@ void ggml_cuda_op_mul_mat_q( const int64_t src1_padded_row_size, cudaStream_t stream); bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t n_experts); + diff --git a/ggml/src/ggml-cuda/mmvf.cu b/ggml/src/ggml-cuda/mmvf.cu index 32948e4d7a1..d7dbc8b9928 100644 --- a/ggml/src/ggml-cuda/mmvf.cu +++ b/ggml/src/ggml-cuda/mmvf.cu @@ -4,26 +4,53 @@ #include "mmvf.cuh" #include "convert.cuh" -template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false> +template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false> static __global__ void mul_mat_vec_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, - const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, + const T * x_ptr, const float * y_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr, + const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { + const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, + const int ids_stride) { + const T * GGML_CUDA_RESTRICT x = x_ptr; + const float * GGML_CUDA_RESTRICT y = y_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int row = blockIdx.x; + // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens) const int channel_dst = blockIdx.y; - const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio); - const int channel_y = ids ? channel_dst % nchannels_y : channel_dst; - const int sample_dst = blockIdx.z; + const int tid = threadIdx.x; + + int token_idx; + int channel_x; + int channel_y; + int sample_dst; + + ggml_cuda_pdl_sync(); + if constexpr (is_multi_token_id) { + // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case + token_idx = blockIdx.z; + channel_x = ids[channel_dst + token_idx * ids_stride]; + channel_y = fastmodulo(channel_dst, nchannels_y); + sample_dst = 0; + } else { + token_idx = ids ? blockIdx.z : 0; + channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio); + channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst; + sample_dst = ids ? 0 : blockIdx.z; + } + const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); const int sample_y = sample_dst; - const int tid = threadIdx.x; constexpr int warp_size = ggml_cuda_get_physical_warp_size(); x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; + if constexpr (is_multi_token_id) { + y += token_idx*stride_col_y2*2; + dst += token_idx*stride_col_dst; + } bool use_gate = false; bool use_bias = false; @@ -56,6 +83,7 @@ static __global__ void mul_mat_vec_f( if (use_gate) { gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; } + if constexpr (has_fusion) { const int channel_bias = ids ? channel_x : channel_dst; if (use_bias) { @@ -70,7 +98,7 @@ static __global__ void mul_mat_vec_f( extern __shared__ char data_mmv[]; float * buf_iw = (float *) data_mmv; - float * buf_iw_gate = nullptr; + [[maybe_unused]] float * buf_iw_gate = nullptr; if constexpr (has_fusion) { buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float)); } @@ -98,7 +126,7 @@ static __global__ void mul_mat_vec_f( if constexpr (std::is_same_v<T, float>) { const float2 * x2 = (const float2 *) x; - const float2 * gate_x2 = nullptr; + [[maybe_unused]] const float2 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const float2 *) gate_x; @@ -130,7 +158,7 @@ static __global__ void mul_mat_vec_f( } } else if constexpr (std::is_same_v<T, half>) { const half2 * x2 = (const half2 *) x; - const half2 * gate_x2 = nullptr; + [[maybe_unused]] const half2 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const half2 *) gate_x; @@ -241,7 +269,7 @@ static __global__ void mul_mat_vec_f( } #else const nv_bfloat162 * x2 = (const nv_bfloat162 *) x; - const nv_bfloat162 * gate_x2 = nullptr; + [[maybe_unused]] const nv_bfloat162 * gate_x2 = nullptr; if constexpr (has_fusion) { if (use_gate) { gate_x2 = (const nv_bfloat162 *) gate_x; @@ -249,7 +277,7 @@ static __global__ void mul_mat_vec_f( } for (int col2 = tid; col2 < ncols2; col2 += block_size) { const nv_bfloat162 tmpx = x2[col2]; - nv_bfloat162 tmpx_gate; + [[maybe_unused]] nv_bfloat162 tmpx_gate; if constexpr (has_fusion) { if (use_gate) { tmpx_gate = gate_x2[col2]; @@ -274,6 +302,7 @@ static __global__ void mul_mat_vec_f( static_assert(std::is_same_v<T, void>, "unsupported type"); } + ggml_cuda_pdl_lc(); #pragma unroll for (int j = 0; j < ncols_dst; ++j) { sumf[j] = warp_reduce_sum<warp_size>(sumf[j]); @@ -349,36 +378,38 @@ static __global__ void mul_mat_vec_f( } } -template<typename T, typename type_acc, int ncols_dst, int block_size> +template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false> static void mul_mat_vec_f_switch_fusion( const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, - const int64_t ncols, const int64_t nrows, + const int64_t ncols, const uint3 nchannels_y, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) { + const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) { + + const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, nbytes_shared, stream}; const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>> - (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + ggml_cuda_kernel_launch(mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id>, launch_params, + x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; } } GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>> - (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + ggml_cuda_kernel_launch(mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id>, launch_params, + x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); } -template <typename T, typename type_acc, int ncols_dst> +template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false> void launch_mul_mat_vec_f_cuda( const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const int64_t ncols, const int64_t nrows, @@ -386,12 +417,13 @@ void launch_mul_mat_vec_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) { GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(stride_row % 2 == 0); GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); @@ -415,56 +447,56 @@ void launch_mul_mat_vec_f_cuda( const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0); - const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); + const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens); const dim3 block_dims(block_size_best, 1, 1); switch (block_size_best) { case 32: { - mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32> - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 64: { - mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64> - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 96: { - mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96> - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 128: { - mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128> - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 160: { - mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160> - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 192: { - mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192> - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 224: { - mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224> - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; case 256: { - mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256> - (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, + mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id> + (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); + sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream); } break; default: { GGML_ABORT("fatal error"); @@ -480,55 +512,88 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { + const int64_t ids_stride, cudaStream_t stream) { + + const bool has_ids = ids != nullptr; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path only - single-token goes through regular path below + constexpr int c_ncols_dst = 1; + launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + ncols_dst, ids_stride, stream); + return; + } + + if (has_ids) { + // Single-token MUL_MAT_ID path + constexpr int c_ncols_dst = 1; + launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst> + (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + ncols_dst, ids_stride, stream); + return; + } + switch (ncols_dst) { case 1: launch_mul_mat_vec_f_cuda<T, type_acc, 1> (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 2: launch_mul_mat_vec_f_cuda<T, type_acc, 2> (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 3: launch_mul_mat_vec_f_cuda<T, type_acc, 3> (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 4: launch_mul_mat_vec_f_cuda<T, type_acc, 4> (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 5: launch_mul_mat_vec_f_cuda<T, type_acc, 5> (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 6: launch_mul_mat_vec_f_cuda<T, type_acc, 6> (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 7: launch_mul_mat_vec_f_cuda<T, type_acc, 7> (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; case 8: launch_mul_mat_vec_f_cuda<T, type_acc, 8> (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, + nsamples_dst, ids_stride, stream); break; default: GGML_ABORT("fatal error"); @@ -544,21 +609,21 @@ static void mul_mat_vec_f_cuda( const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - enum ggml_prec prec, cudaStream_t stream) { + const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) { if constexpr(std::is_same_v<T, half>) { if (prec == GGML_PREC_DEFAULT) { mul_mat_vec_f_cuda_switch_ncols_dst<T, half> (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); return; } } mul_mat_vec_f_cuda_switch_ncols_dst<T, float> (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, - stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); } void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, @@ -573,7 +638,7 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const size_t ts_src1 = ggml_type_size(src1->type); const size_t ts_dst = ggml_type_size(dst->type); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE); GGML_ASSERT(ne13 == ne3); GGML_ASSERT( nb00 == ts_src0); @@ -626,29 +691,31 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t nchannels_y = ids ? ne11 : ne12; const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; - GGML_ASSERT(!ids || ncols_dst == 1); + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; - mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, + mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, prec, ctx.stream()); + ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); @@ -695,19 +762,19 @@ void ggml_cuda_op_mul_mat_vec_f( const float * src0_d = (const float *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; case GGML_TYPE_F16: { const half * src0_d = (const half *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; case GGML_TYPE_BF16: { const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); diff --git a/ggml/src/ggml-cuda/mmvf.cuh b/ggml/src/ggml-cuda/mmvf.cuh index a09fbdc7202..a50f7c02180 100644 --- a/ggml/src/ggml-cuda/mmvf.cuh +++ b/ggml/src/ggml-cuda/mmvf.cuh @@ -1,5 +1,7 @@ #include "common.cuh" +#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels. + void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu index d671551c171..fe44a58da91 100644 --- a/ggml/src/ggml-cuda/mmvq.cu +++ b/ggml/src/ggml-cuda/mmvq.cu @@ -9,12 +9,14 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return vec_dot_q1_0_q8_1; case GGML_TYPE_Q4_0: return vec_dot_q4_0_q8_1; case GGML_TYPE_Q4_1: return vec_dot_q4_1_q8_1; case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1; case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1; case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1; case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1; + case GGML_TYPE_NVFP4: return vec_dot_nvfp4_q8_1; case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1; case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1; case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1; @@ -33,14 +35,16 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) } } -static constexpr __device__ int get_vdr_mmvq(ggml_type type) { +static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: return VDR_Q1_0_Q8_1_MMVQ; case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ; case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ; case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ; case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ; case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ; case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ; + case GGML_TYPE_NVFP4: return VDR_NVFP4_Q8_1_MMVQ; case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ; case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ; case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ; @@ -59,31 +63,290 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) { enum mmvq_parameter_table_id { MMVQ_PARAMETERS_GENERIC = 0, + MMVQ_PARAMETERS_TURING, MMVQ_PARAMETERS_GCN, - MMVQ_PARAMETERS_RDNA2 + MMVQ_PARAMETERS_RDNA2, + MMVQ_PARAMETERS_RDNA3_0, + MMVQ_PARAMETERS_RDNA4 }; static constexpr __device__ mmvq_parameter_table_id get_device_table_id() { -#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4) +#if defined(RDNA4) + return MMVQ_PARAMETERS_RDNA4; +#elif defined(RDNA3_0) + return MMVQ_PARAMETERS_RDNA3_0; +#elif defined(RDNA2) || defined(RDNA3_5) return MMVQ_PARAMETERS_RDNA2; #elif defined(GCN) || defined(CDNA) return MMVQ_PARAMETERS_GCN; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING && __CUDA_ARCH__ < GGML_CUDA_CC_AMPERE + return MMVQ_PARAMETERS_TURING; #else return MMVQ_PARAMETERS_GENERIC; #endif } static __host__ mmvq_parameter_table_id get_device_table_id(int cc) { - if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return MMVQ_PARAMETERS_RDNA4; + } + if (GGML_CUDA_CC_IS_RDNA3_0(cc)) { + return MMVQ_PARAMETERS_RDNA3_0; + } + if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3_5(cc)) { return MMVQ_PARAMETERS_RDNA2; } if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) { return MMVQ_PARAMETERS_GCN; } + if (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING && ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_AMPERE) { + return MMVQ_PARAMETERS_TURING; + } return MMVQ_PARAMETERS_GENERIC; } -static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) { +// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID. +// Returns a value <= MMVQ_MAX_BATCH_SIZE. Default is MMVQ_MAX_BATCH_SIZE. +// Check https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 6; + case GGML_TYPE_IQ1_M: return 6; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 5; + case GGML_TYPE_IQ2_XXS: return 5; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 5; + case GGML_TYPE_MXFP4: return 4; + case GGML_TYPE_NVFP4: return 4; + case GGML_TYPE_Q2_K: return 4; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 6; + case GGML_TYPE_Q4_1: return 6; + case GGML_TYPE_Q4_K: return 5; + case GGML_TYPE_Q5_0: return 6; + case GGML_TYPE_Q5_1: return 6; + case GGML_TYPE_Q5_K: return 5; + case GGML_TYPE_Q6_K: return 4; + case GGML_TYPE_Q8_0: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 7; + case GGML_TYPE_IQ3_S: return 6; + case GGML_TYPE_IQ3_XXS: return 7; + case GGML_TYPE_MXFP4: return 7; + case GGML_TYPE_NVFP4: return 8; + case GGML_TYPE_Q2_K: return 7; + case GGML_TYPE_Q3_K: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 5; + case GGML_TYPE_IQ1_M: return 5; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 4; + case GGML_TYPE_Q2_K: return 4; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 5; + case GGML_TYPE_Q4_1: return 5; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_K: return 4; + case GGML_TYPE_Q6_K: return 4; + case GGML_TYPE_Q8_0: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 5; + case GGML_TYPE_IQ2_XS: return 5; + case GGML_TYPE_IQ2_XXS: return 5; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_Q2_K: return 7; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_K: return 5; + case GGML_TYPE_Q5_K: return 6; + case GGML_TYPE_Q6_K: return 5; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 6; + case GGML_TYPE_IQ1_M: return 6; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 6; + case GGML_TYPE_IQ4_XS: return 6; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_K: return 4; + case GGML_TYPE_Q6_K: return 4; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) { + switch (type) { + case GGML_TYPE_IQ1_S: return 7; + case GGML_TYPE_IQ1_M: return 7; + case GGML_TYPE_IQ2_S: return 4; + case GGML_TYPE_IQ2_XS: return 4; + case GGML_TYPE_IQ2_XXS: return 4; + case GGML_TYPE_IQ3_S: return 4; + case GGML_TYPE_IQ3_XXS: return 4; + case GGML_TYPE_IQ4_NL: return 7; + case GGML_TYPE_IQ4_XS: return 5; + case GGML_TYPE_MXFP4: return 5; + case GGML_TYPE_NVFP4: return 5; + case GGML_TYPE_Q3_K: return 4; + case GGML_TYPE_Q4_0: return 7; + case GGML_TYPE_Q4_1: return 7; + case GGML_TYPE_Q4_K: return 4; + case GGML_TYPE_Q5_0: return 7; + case GGML_TYPE_Q5_1: return 7; + case GGML_TYPE_Q5_K: return 5; + case GGML_TYPE_Q6_K: return 5; + case GGML_TYPE_Q8_0: return 7; + default: return MMVQ_MAX_BATCH_SIZE; + } +} + +// Host function: returns the max batch size for the current arch+type at runtime. +int get_mmvq_mmid_max_batch(ggml_type type, int cc) { + // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID. + if (GGML_CUDA_CC_IS_NVIDIA(cc)) { + if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) { + return MMVQ_MAX_BATCH_SIZE; + } + if (cc >= GGML_CUDA_CC_TURING) { + return get_mmvq_mmid_max_batch_turing_plus(type); + } + return get_mmvq_mmid_max_batch_pascal_older(type); + } + + // AMD + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA4(cc)) { + return get_mmvq_mmid_max_batch_rdna4(type); + } + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + return get_mmvq_mmid_max_batch_rdna3(type); + } + if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) { + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); + } + if (GGML_CUDA_CC_IS_CDNA(cc)) { + return get_mmvq_mmid_max_batch_cdna(type); + } + if (GGML_CUDA_CC_IS_GCN(cc)) { + return get_mmvq_mmid_max_batch_gcn(type); + } + } + return MMVQ_MAX_BATCH_SIZE; +} + +bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11) { + if (GGML_CUDA_CC_IS_CDNA(cc)) { + if (GGML_CUDA_CC_IS_CDNA1(cc)) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + return ne11 <= 7; + case GGML_TYPE_Q5_1: + return ne11 <= 7; + case GGML_TYPE_Q8_0: + return ne11 <= 6; + case GGML_TYPE_Q2_K: + return ne11 <= 4; + case GGML_TYPE_Q3_K: + return ne11 <= 3; + case GGML_TYPE_Q4_K: + return ne11 <= 2; + case GGML_TYPE_Q5_K: + return ne11 <= 3; + case GGML_TYPE_Q6_K: + return ne11 <= 4; + case GGML_TYPE_IQ1_S: + return ne11 <= 5; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + return ne11 <= 6; + default: + return ne11 <= MMVQ_MAX_BATCH_SIZE; + } + } + switch (type) { // tuned for CDNA2 + case GGML_TYPE_Q2_K: + return ne11 <= 5; + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return ne11 <= 3; + case GGML_TYPE_Q6_K: + return ne11 <= 5; + default: + return ne11 <= MMVQ_MAX_BATCH_SIZE; + } + } + return ne11 <= MMVQ_MAX_BATCH_SIZE; +} + +// Device constexpr: returns the max batch size for the current arch+type at compile time. +template <ggml_type type> +static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() { +#if defined(RDNA4) + return get_mmvq_mmid_max_batch_rdna4(type); +#elif defined(RDNA3) + return get_mmvq_mmid_max_batch_rdna3(type); +#elif defined(RDNA2) || defined(RDNA1) + return get_mmvq_mmid_max_batch_rdna1_rdna2(type); +#elif defined(CDNA) + return get_mmvq_mmid_max_batch_cdna(type); +#elif defined(GCN) + return get_mmvq_mmid_max_batch_gcn(type); +#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE) + return MMVQ_MAX_BATCH_SIZE; +#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING + return get_mmvq_mmid_max_batch_turing_plus(type); +#else + return get_mmvq_mmid_max_batch_pascal_older(type); +#endif +} + +static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) { if (table_id == MMVQ_PARAMETERS_GENERIC) { switch (ncols_dst) { case 1: @@ -114,14 +377,86 @@ static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_paramet return 1; } } + if (table_id == MMVQ_PARAMETERS_RDNA4) { + // nwarps=8 benefits types with simple vec_dot on RDNA4 (ncols_dst=1). + // Types with complex vec_dot (Q3_K, IQ2_*, IQ3_*) regress due to register + // pressure and lookup table contention at higher thread counts. + if (ncols_dst == 1) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + return 8; + default: + return 1; + } + } + return 1; + } + if (table_id == MMVQ_PARAMETERS_RDNA3_0) { + // RDNA3 (W7900): stricter whitelist than RDNA4. + // Q2_K / Q5_K / IQ4_XS regress in full quant sweeps. + if (ncols_dst == 1) { + switch (type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + return 8; + case GGML_TYPE_Q6_K: + return 2; + case GGML_TYPE_IQ4_NL: + return 8; + default: + return 1; + } + } + return 1; + } + if (table_id == MMVQ_PARAMETERS_TURING) { + if (ncols_dst == 1) { + switch (type) { + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + return 2; + default: + return 4; + } + } + switch (ncols_dst) { + case 2: + case 3: + case 4: + return 4; + case 5: + case 6: + case 7: + case 8: + return 2; + default: + return 1; + } + } return 1; } -static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) { - if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) { +static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) { + if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN || table_id == MMVQ_PARAMETERS_TURING) { switch (ncols_dst) { case 1: - return 1; + return small_k ? nwarps : 1; case 2: case 3: case 4: @@ -137,22 +472,26 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int return 1; } -// tell the compiler to use as many registers as it wants, see nwarps definition below -template <ggml_type type, int ncols_dst, bool has_fusion> -__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) +template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false> +__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) static __global__ void mul_mat_vec_q( - const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, + const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr, const ggml_cuda_mm_fusion_args_device fusion, float * dst_ptr, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, - const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { + const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, + const uint32_t ids_stride) { + const void * GGML_CUDA_RESTRICT vx = vx_ptr; + const void * GGML_CUDA_RESTRICT vy = vy_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qi = ggml_cuda_type_traits<type>::qi; constexpr int vdr = get_vdr_mmvq(type); constexpr mmvq_parameter_table_id table_id = get_device_table_id(); - constexpr int nwarps = calc_nwarps(ncols_dst, table_id); - constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id); + constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id); + constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); @@ -162,18 +501,24 @@ static __global__ void mul_mat_vec_q( const int blocks_per_row_x = ncols_x / qk; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; - // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. const uint32_t channel_dst = blockIdx.y; - const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); - const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; - const uint32_t sample_dst = blockIdx.z; + + uint32_t channel_x; + uint32_t channel_y; + uint32_t sample_dst; + + ggml_cuda_pdl_sync(); + channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio); + channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; + sample_dst = blockIdx.z; + const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_y = sample_dst; bool use_gate = false; bool use_bias = false; bool use_gate_bias = false; - const void * vgate = nullptr; + [[maybe_unused]] const void * vgate = nullptr; const float * x_bias = nullptr; const float * gate_bias = nullptr; ggml_glu_op active_glu; @@ -188,11 +533,11 @@ static __global__ void mul_mat_vec_q( active_glu = fusion.glu_op; } - const uint32_t channel_bias = ids ? channel_x : channel_dst; - float x_biases[ncols_dst] = { 0.0f }; - float gate_biases[ncols_dst] = { 0.0f }; + [[maybe_unused]] float x_biases[ncols_dst] = { 0.0f }; + [[maybe_unused]] float gate_biases[ncols_dst] = { 0.0f }; if constexpr (has_fusion) { + const uint32_t channel_bias = ids ? channel_x : channel_dst; if (use_bias) { x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; // 1. Hide latency by prefetching bias and gate here @@ -247,12 +592,7 @@ static __global__ void mul_mat_vec_q( } __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; - __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; - if constexpr (!has_fusion) { - (void) tmp_shared_gate; - } else if (!use_gate) { - (void) tmp_shared_gate; - } + [[maybe_unused]] __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size]; if (threadIdx.y > 0) { #pragma unroll @@ -334,41 +674,139 @@ static __global__ void mul_mat_vec_q( } } +// Dedicated MoE multi-token kernel. +// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst) +// Block: (warp_size, ncols_dst) - each warp handles one token independently. +// No shared memory reduction needed since each warp works alone. +template <ggml_type type, int c_rows_per_block> +__launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1) +static __global__ void mul_mat_vec_q_moe( + const void * vx_ptr, const void * vy_ptr, const int32_t * ids_ptr, + float * dst_ptr, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, + const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, + const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, + const uint32_t ncols_dst, const uint32_t ids_stride) { + const void * GGML_CUDA_RESTRICT vx = vx_ptr; + const void * GGML_CUDA_RESTRICT vy = vy_ptr; + const int32_t * GGML_CUDA_RESTRICT ids = ids_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; + + constexpr int qk = ggml_cuda_type_traits<type>::qk; + constexpr int qi = ggml_cuda_type_traits<type>::qi; + constexpr int vdr = get_vdr_mmvq(type); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); + + const uint32_t token_idx = threadIdx.y; + const int row0 = c_rows_per_block*blockIdx.x; + const int blocks_per_row_x = ncols_x / qk; + constexpr int blocks_per_iter = vdr * warp_size / qi; + + const uint32_t channel_dst = blockIdx.y; + + if (token_idx >= ncols_dst) { + return; + } + + ggml_cuda_pdl_sync(); + const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride]; + const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y); + + const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y; + const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x; + + // partial sum for each thread + float tmp[c_rows_per_block] = {0.0f}; + + for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { + const int kby = kbx * (qk/QK8_1); + const int kqs = vdr * (threadIdx.x % (qi/vdr)); + +#pragma unroll + for (int i = 0; i < c_rows_per_block; ++i) { + tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs); + } + } + + ggml_cuda_pdl_lc(); + + // Warp-level reduction only - no shared memory needed +#pragma unroll + for (int i = 0; i < c_rows_per_block; ++i) { + tmp[i] = warp_reduce_sum<warp_size>(tmp[i]); + } + + // Write results + if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) { + dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x]; + } +} + +template<ggml_type type> static std::pair<dim3, dim3> calc_launch_params( - const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y, - const int warp_size, const mmvq_parameter_table_id table_id) { - const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); - const dim3 block_nums(nblocks, nchannels_y, nsamples_y); - const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); + const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens, + const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) { + const int nwarps = calc_nwarps(type, ncols_dst, table_id); + const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps); + const int64_t nblocks = (nrows_x + rpb - 1) / rpb; + const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens); + const dim3 block_dims(warp_size, nwarps, 1); return {block_nums, block_dims}; } -template<ggml_type type, int c_ncols_dst> +template<ggml_type type, int c_ncols_dst, bool small_k = false> static void mul_mat_vec_q_switch_fusion( const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, - const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) { + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, + const uint32_t ids_stride, cudaStream_t stream) { const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; if constexpr (c_ncols_dst == 1) { if (has_fusion) { - mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>> - (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream); + ggml_cuda_kernel_launch(mul_mat_vec_q<type, c_ncols_dst, true, small_k>, launch_params, + vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); return; } } GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); - mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>> - (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, nbytes_shared, stream); + ggml_cuda_kernel_launch(mul_mat_vec_q<type, c_ncols_dst, false, small_k>, launch_params, + vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride); +} + +template <ggml_type type> +static void mul_mat_vec_q_moe_launch( + const void * vx, const void * vy, const int32_t * ids, float * dst, + const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x, + const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst, + const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, + const uint32_t ncols_dst, const uint32_t ids_stride, + const int warp_size, const int nchannels_dst, cudaStream_t stream) { + + constexpr int rows_per_block = 2; // 2 gives best perf based on tuning + const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block; + const dim3 block_nums(nblocks_rows, nchannels_dst); + const dim3 block_dims(warp_size, ncols_dst); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + + ggml_cuda_kernel_launch(mul_mat_vec_q_moe<type, rows_per_block>, launch_params, + vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x, + stride_row_x, stride_col_y, stride_col_dst, + stride_channel_x, stride_channel_y, stride_channel_dst, + ncols_dst, ids_stride); } template <ggml_type type> @@ -379,7 +817,7 @@ static void mul_mat_vec_q_switch_ncols_dst( const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { + const int ids_stride, cudaStream_t stream) { GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); @@ -389,76 +827,144 @@ static void mul_mat_vec_q_switch_ncols_dst( const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const int device = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[device].cc; const int warp_size = ggml_cuda_info().devices[device].warp_size; - const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); + const mmvq_parameter_table_id table_id = get_device_table_id(cc); const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; + const bool has_ids = ids != nullptr; + + const auto should_use_small_k = [&](int c_ncols_dst) { + // When K is small, increase rows_per_block to match nwarps so each warp has more work to do + // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle. + constexpr int qk = ggml_cuda_type_traits<type>::qk; + constexpr int qi = ggml_cuda_type_traits<type>::qi; + constexpr int vdr = get_vdr_mmvq(type); + const int blocks_per_row_x = ncols_x / qk; + const int blocks_per_iter_1warp = vdr * warp_size / qi; + const int nwarps = calc_nwarps(type, c_ncols_dst, table_id); + bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp; + + constexpr std::array<ggml_type, 2> iq_slow_turing = { + GGML_TYPE_IQ3_XXS, + GGML_TYPE_IQ3_S, + }; + constexpr std::array<ggml_type, 8> iq_slow_other = { + GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, + GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, + }; + constexpr std::array<ggml_type, 3> slow_pascal = { + GGML_TYPE_IQ3_S, + GGML_TYPE_Q2_K, + GGML_TYPE_Q3_K, + }; + + const bool is_nvidia_turing_plus = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING; + const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA; + + if (is_nvidia_turing_plus) { + if (ncols_dst == 1 && + std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) { + use = false; + } + } else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) || + (is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) || + GGML_CUDA_CC_IS_RDNA(cc)) { + use = false; + } + + return use; + }; + + if (has_ids && ncols_dst > 1) { + // Multi-token MUL_MAT_ID path - dedicated MoE kernel + mul_mat_vec_q_moe_launch<type>( + vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x, + stride_row_x, stride_col_y, stride_col_dst, + stride_channel_x, stride_channel_y, stride_channel_dst, + ncols_dst, ids_stride, warp_size, nchannels_dst, stream); + return; + } - GGML_ASSERT(!ids || ncols_dst == 1); switch (ncols_dst) { case 1: { constexpr int c_ncols_dst = 1; - std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); - mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, - channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + + bool use_small_k = should_use_small_k(c_ncols_dst); + + if (use_small_k) { + std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, + nsamples_dst, warp_size, table_id, true); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, + stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, + stream); + } else { + std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, + nsamples_dst, warp_size, table_id); + mul_mat_vec_q_switch_fusion<type, c_ncols_dst>( + vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, + channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, + stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride, + stream); + } } break; case 2: { constexpr int c_ncols_dst = 2; - std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 3: { constexpr int c_ncols_dst = 3; - std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 4: { constexpr int c_ncols_dst = 4; - std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 5: { constexpr int c_ncols_dst = 5; - std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 6: { constexpr int c_ncols_dst = 6; - std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 7: { constexpr int c_ncols_dst = 7; - std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; case 8: { constexpr int c_ncols_dst = 8; - std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); + std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, - dims.first, dims.second, 0, stream); + dims.first, dims.second, 0, ids_stride, stream); } break; default: GGML_ABORT("fatal error"); @@ -474,127 +980,139 @@ static void mul_mat_vec_q_switch_type( const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, - cudaStream_t stream) { + const int ids_stride, cudaStream_t stream) { switch (type_x) { + case GGML_TYPE_Q1_0: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q1_0> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; case GGML_TYPE_Q4_0: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q4_1: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_0: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_1: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q8_0: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_MXFP4: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); + break; + case GGML_TYPE_NVFP4: + mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_NVFP4> + (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q2_K: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q3_K: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q4_K: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q5_K: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_Q6_K: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_XXS: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_XS: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ2_S: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ3_XXS: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ1_S: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ1_M: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ4_NL: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ4_XS: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; case GGML_TYPE_IQ3_S: mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S> (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream); break; default: GGML_ABORT("fatal error"); @@ -622,7 +1140,7 @@ void ggml_cuda_mul_mat_vec_q( GGML_ASSERT( nb0 == ts_dst); GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); - GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE); const float * src1_d = (const float *) src1->data; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; @@ -693,11 +1211,13 @@ void ggml_cuda_mul_mat_vec_q( const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_y = ids ? s11 : s12; + const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + mul_mat_vec_q_switch_type( src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, - ne03, ne3, s03, s13, s3, stream); + ne03, ne3, s03, s13, s3, ids_stride, stream); } void ggml_cuda_op_mul_mat_vec_q( @@ -726,7 +1246,7 @@ void ggml_cuda_op_mul_mat_vec_q( ggml_cuda_mm_fusion_args_device fusion_local{}; mul_mat_vec_q_switch_type( src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream); GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); } diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 4bb10cfaec2..5605bf7a4e6 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -2,6 +2,12 @@ #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. +bool ggml_cuda_should_use_mmvq(enum ggml_type type, int cc, int64_t ne11); + +// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID, +// based on the quantization type and GPU architecture (compute capability). +int get_mmvq_mmid_max_batch(ggml_type type, int cc); + void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 4f153c5718e..09d9f3a7d62 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -18,6 +18,7 @@ static __global__ void norm_f32( float2 mean_var = make_float2(0.0f, 0.0f); + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; mean_var.x += xi; @@ -25,19 +26,8 @@ static __global__ void norm_f32( } // sum up partial sums - mean_var = warp_reduce_sum(mean_var); - if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); - __shared__ float2 s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = mean_var; - } - __syncthreads(); - mean_var = s_sum[lane_id]; - mean_var = warp_reduce_sum(mean_var); - } + extern __shared__ float2 s_sum2[]; + mean_var = block_reduce<block_reduce_method::SUM, block_size>(mean_var, s_sum2); const float mean = mean_var.x / ncols; const float var = mean_var.y / ncols - mean * mean; @@ -57,23 +47,13 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int j = start; j < end; j += block_size) { tmp += x[j]; } - tmp = warp_reduce_sum(tmp); - if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = s_sum[lane_id]; - tmp = warp_reduce_sum(tmp); - } + extern __shared__ float s_sum[]; + tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum); const float mean = tmp / group_size; tmp = 0.0f; @@ -84,18 +64,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr tmp += xi * xi; } - tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = s_sum[lane_id]; - tmp = warp_reduce_sum(tmp); - } + tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum); const float variance = tmp / group_size; const float scale = rsqrtf(variance + eps); @@ -128,6 +97,7 @@ static __global__ void rms_norm_f32(const float * x, const uint3 add_nrows_packed = make_uint3(0, 0, 0), const uint3 add_nchannels_packed = make_uint3(0, 0, 0), const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) { + ggml_cuda_pdl_lc(); const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -157,28 +127,15 @@ static __global__ void rms_norm_f32(const float * x, float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; tmp += xi * xi; } // sum up partial sums - tmp = warp_reduce_sum(tmp); - if constexpr (block_size > WARP_SIZE) { - static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size"); - __shared__ float s_sum[32]; - const int warp_id = tid / WARP_SIZE; - const int lane_id = tid % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = 0.0f; - if (lane_id < (block_size / WARP_SIZE)) { - tmp = s_sum[lane_id]; - } - tmp = warp_reduce_sum(tmp); - } + extern __shared__ float s_sum[]; + tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum); const float mean = tmp / ncols; const float scale = rsqrtf(mean + eps); @@ -210,6 +167,7 @@ static __global__ void rms_norm_back_f32( float sum_xx = 0.0f; // sum for squares of x, equivalent to forward pass float sum_xg = 0.0f; // sum for x * gradient, needed because RMS norm mixes inputs + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xfi = xf[col]; sum_xx += xfi * xfi; @@ -300,25 +258,16 @@ static __global__ void l2_norm_f32( float tmp = 0.0f; // partial sum for thread in warp + ggml_cuda_pdl_sync(); for (int col = tid; col < ncols; col += block_size) { const float xi = x[col]; tmp += xi * xi; } // sum up partial sums - tmp = warp_reduce_sum(tmp); - if constexpr (block_size > WARP_SIZE) { - static_assert(block_size == 1024, "unexpected block_size"); - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - __syncthreads(); - tmp = s_sum[lane_id]; - tmp = warp_reduce_sum(tmp); - } + extern __shared__ float s_sum[]; + tmp = block_reduce<block_reduce_method::SUM, block_size>(tmp, s_sum); + ggml_cuda_pdl_lc(); // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html const float scale = rsqrtf(fmaxf(tmp, eps * eps)); @@ -337,7 +286,7 @@ static void norm_f32_cuda( norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + norm_f32<1024><<<blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float2): 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } @@ -348,7 +297,7 @@ static void group_norm_f32_cuda( group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps); } else { const dim3 block_dims(1024, 1, 1); - group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps); + group_norm_f32<1024><<<num_groups, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream>>>(x, dst, group_size, ne_elements, eps); } } @@ -358,10 +307,19 @@ static void rms_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = {blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, false>, launch_params, + x, dst, ncols, stride_row, stride_channel, stride_sample, eps, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, false>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } } @@ -404,14 +362,20 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, - mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed); + mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, + // underlying cudaLaunchKernelEx does not support default params + nullptr, 0, 0, 0, make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0), make_uint3(0, 0, 0)); } } else { const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols); @@ -425,14 +389,16 @@ static void rms_norm_mul_f32_cuda(const float * x, const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples); if (ncols < 1024) { const dim3 block_dims(256, 1, 1); - rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims,block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<256, true, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, add_nchannels_packed, add_nsamples_packed); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(rms_norm_f32<1024, true, true>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed, @@ -457,10 +423,12 @@ static void l2_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, 0, stream}; + ggml_cuda_kernel_launch(l2_norm_f32<WARP_SIZE>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } else { const dim3 block_dims(1024, 1, 1); - l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params{blocks_num, block_dims, block_dims.x > WARP_SIZE ? 32 * sizeof(float): 0, stream}; + ggml_cuda_kernel_launch(l2_norm_f32<1024>, launch_params, x, dst, ncols, stride_row, stride_channel, stride_sample, eps); } } diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu index c9b2b699c6a..499903d09b1 100644 --- a/ggml/src/ggml-cuda/out-prod.cu +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -54,15 +54,31 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int64_t dps2 = ne2 / ne02; const int64_t dps3 = ne3 / ne03; - // TODO batched matrix multiplication - for (int64_t i3 = 0; i3 < ne3; ++i3) { - for (int64_t i2 = 0; i2 < ne2; ++i2) { + if (dps2 == 1 && ne2 > 1) { + // src0 has uniform stride s02 along dim 2; batch the inner loop with a strided GEMM + GGML_ASSERT(ne2 <= std::numeric_limits<int>::max()); + const int batch_count = (int) ne2; + for (int64_t i3 = 0; i3 < ne3; ++i3) { CUBLAS_CHECK( - cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + cublasSgemmStridedBatched(handle, CUBLAS_OP_N, src1_cublas_op, ne0, ne1, ne01, - &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, - src1_d + i3 *s13 + i2 *s12, ldb, - &beta, dst_d + i3 *s3 + i2 *s2, ldc)); + &alpha, src0_d + (i3/dps3)*s03, lda, s02, + src1_d + i3 *s13, ldb, s12, + &beta, dst_d + i3 *s3, ldc, s2, + batch_count)); + } + } else { + // Fallback: ne2 == 1 (no batching benefit) or dps2 > 1 (src0 broadcast along dim 2 + // with non-uniform stride; would need cublasSgemmBatched with pointer arrays). + for (int64_t i3 = 0; i3 < ne3; ++i3) { + for (int64_t i2 = 0; i2 < ne2; ++i2) { + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda, + src1_d + i3 *s13 + i2 *s12, ldb, + &beta, dst_d + i3 *s3 + i2 *s2, ldc)); + } } } } diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu index 660c192e48a..31cd00f7781 100644 --- a/ggml/src/ggml-cuda/pad.cu +++ b/ggml/src/ggml-cuda/pad.cu @@ -7,7 +7,7 @@ __device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) { return (coord + size) % size; } -static __global__ void pad_f32(const float * src, float * dst, +static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, @@ -34,11 +34,8 @@ static __global__ void pad_f32(const float * src, float * dst, const int64_t i01 = i1 - lp1; const int64_t i02 = i2 - lp2; const int64_t i03 = i3 - lp3; - const int64_t ne02 = ne2 - lp2 - rp2; - const int64_t ne01 = ne1 - lp1 - rp1; - const int64_t ne00 = ne0 - lp0 - rp0; - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } else { @@ -57,21 +54,21 @@ static __global__ void pad_f32(const float * src, float * dst, const int64_t i02 = wrap_around(i2 - lp2, ne02); const int64_t i03 = wrap_around(i3 - lp3, ne03); - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } } -static void pad_f32_cuda(const float * src, float * dst, +static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, const bool circular, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; dim3 gridDim(num_blocks, ne1, ne2 * ne3); - pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, + pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3, circular); } @@ -82,9 +79,10 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); + GGML_TENSOR_UNARY_OP_LOCALS; + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); const int32_t lp0 = ((const int32_t *) (dst->op_params))[0]; const int32_t rp0 = ((const int32_t *) (dst->op_params))[1]; @@ -96,7 +94,12 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int32_t rp3 = ((const int32_t *) (dst->op_params))[7]; const int32_t circular = ((const int32_t *) (dst->op_params))[8]; - pad_f32_cuda(src0_d, dst_d, + const size_t s00 = nb00 / ggml_type_size(src0->type); + const size_t s01 = nb01 / ggml_type_size(src0->type); + const size_t s02 = nb02 / ggml_type_size(src0->type); + const size_t s03 = nb03 / ggml_type_size(src0->type); + + pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (bool) circular, stream); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index a8c68e44b16..39a500a1704 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -3,9 +3,12 @@ __launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) static __global__ void quantize_q8_1( - const float * __restrict__ x, void * __restrict__ vy, + const float * x_ptr, void * vy_ptr, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, const int64_t ne0, const uint32_t ne1, const uint3 ne2) { + ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT x = x_ptr; + void * GGML_CUDA_RESTRICT vy = vy_ptr; const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i0 >= ne0) { @@ -28,6 +31,7 @@ static __global__ void quantize_q8_1( const int64_t ib = i_cont / QK8_1; // block index const int64_t iqs = i_cont % QK8_1; // quant index + ggml_cuda_pdl_sync(); const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f; float amax = fabsf(xi); float sum = xi; @@ -70,6 +74,102 @@ __device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) { return static_cast<uint8_t>(biased); } + +static __global__ void quantize_mmq_nvfp4( + const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2) { +#if defined(BLACKWELL_MMA_AVAILABLE) + + const int64_t i0_base = ((int64_t) blockDim.x * blockIdx.y + threadIdx.x) * QK_NVFP4_SUB; + if (i0_base >= ne0) { + return; + } + + const int64_t i1 = blockIdx.x; + const int64_t i2 = blockIdx.z % ne2; + const int64_t i3 = blockIdx.z / ne2; + const int64_t i01 = ids ? ids[i1] : i1; + const int64_t k_block = i0_base / QK_K; + const int64_t blocks_per_col = (ne0 + QK_K - 1) / QK_K; + if (k_block >= blocks_per_col) { + return; + } + + const int64_t ib = blockIdx.z * ((int64_t) blocks_per_col * ne1) + k_block * ne1 + blockIdx.x; + block_fp4_mmq * y = (block_fp4_mmq *) vy; + block_fp4_mmq * yb = y + ib; + + const int sub = (i0_base % QK_K) / QK_NVFP4_SUB; + + float vals_raw[QK_NVFP4_SUB]; + float amax_raw = 0.0f; + const int64_t base_idx = i3 * s03 + i2 * s02 + i01 * s01; +#pragma unroll + for (int k = 0; k < QK_NVFP4_SUB; k++) { + const int64_t i00 = i0_base + k; + if (i00 < ne00) { + const float v = x[base_idx + i00]; + vals_raw[k] = v; + amax_raw = fmaxf(amax_raw, fabsf(v)); + } else { + vals_raw[k] = 0.0f; + } + } + + static constexpr int test_offsets[5] = { 0, -1, 1, -2, 2}; + const int first_fp8_code = (int) ggml_cuda_fp32_to_ue4m3(amax_raw / 6.0f); + + float best_err = FLT_MAX; + uint8_t fp8_code = 0; + float subblock_scale = 0.0f; + +#pragma unroll // Check +/- 2 to find best code to reduce NVFP4 activation loss. Negligible overhead on Blackwell. + for (int i = 0; i < 5; i++) { + const int test_code = first_fp8_code + test_offsets[i]; + if (test_code < 0 || test_code > 0x7e) { + continue; + } + const uint8_t code = (uint8_t) test_code; + const float test_scale = ggml_cuda_ue4m3_to_fp32(code); + const float test_inv_scale = test_scale > 0.0f ? 0.5f / test_scale : 0.0f; + float cur_err = 0.0f; +#pragma unroll + for (int k = 0; k < QK_NVFP4_SUB; ++k) { + const float v = vals_raw[k]; + const uint8_t q = ggml_cuda_float_to_fp4_e2m1(v, test_inv_scale); + const float err_diff = fabsf(v) - fabsf(kvalues_mxfp4[q & 0x7]) * test_scale; + cur_err = fmaf(err_diff, err_diff, cur_err); + } + + if (cur_err < best_err) { + best_err = cur_err; + fp8_code = test_code; + subblock_scale = test_scale; + } + } + + const float inv_scale = subblock_scale > 0.0f ? 0.5f / subblock_scale : 0.0f; + uint32_t q0 = 0; + uint32_t q1 = 0; +#pragma unroll // this is faster than the previous __nv_fp4x4_e2m1 + for (int k = 0; k < QK_NVFP4_SUB / 4; ++k) { + q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 0], inv_scale) << (8 * k); + q0 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 8], inv_scale) << (8 * k + 4); + q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 4], inv_scale) << (8 * k); + q1 |= (uint32_t) ggml_cuda_float_to_fp4_e2m1(vals_raw[k + 12], inv_scale) << (8 * k + 4); + } + + uint32_t * yqs = reinterpret_cast<uint32_t *>(yb->qs); + yqs[2 * sub + 0] = q0; + yqs[2 * sub + 1] = q1; + reinterpret_cast<uint8_t *>(yb->d4)[sub] = fp8_code; +#else + NO_DEVICE_CODE; // This is for Blackwell NVFP4 activations only. +#endif // defined(BLACKWELL_MMA_AVAILABLE) + +} + // quantize values in the format mxfp4 is stored which is interleaved nibbles // i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31 static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, @@ -100,6 +200,7 @@ static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x, const int64_t i2 = blockIdx.z % ne2; const int64_t i3 = blockIdx.z / ne2; + ggml_cuda_pdl_sync(); const int64_t i01 = ids ? ids[i1] : i1; const int64_t i02 = i2; const int64_t i03 = i3; @@ -192,6 +293,7 @@ static __global__ void quantize_mmq_q8_1( const int64_t i3 = blockIdx.z / ne2; const int64_t i00 = i0; + ggml_cuda_pdl_sync(); const int64_t i01 = ids ? ids[i1] : i1; const int64_t i02 = i2; const int64_t i03 = i3; @@ -235,7 +337,7 @@ static __global__ void quantize_mmq_q8_1( q.z = roundf(xi.z*d_inv); q.w = roundf(xi.w*d_inv); - // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth: + // Write back 4 int8 values as a single 32 bit value for better memory bandwidth: char4 * yqs4 = (char4 *) y[ib].qs; yqs4[iqs/4] = q; @@ -282,7 +384,8 @@ void quantize_row_q8_1_cuda( const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); - quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, block_size, 0, stream); + ggml_cuda_kernel_launch(quantize_q8_1, launch_params, x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); GGML_UNUSED(type_src0); } @@ -316,28 +419,32 @@ void quantize_mmq_q8_1_cuda( } } -void quantize_mmq_mxfp4_cuda(const float * x, - const int32_t * ids, - void * vy, - [[maybe_unused]] const ggml_type type_src0, - const int64_t ne00, - const int64_t s01, - const int64_t s02, - const int64_t s03, - const int64_t ne0, - const int64_t ne1, - const int64_t ne2, - const int64_t ne3, - cudaStream_t stream) { - GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); - - constexpr int nwarps = 8; - constexpr int vals_per_warp = 2 * QK_MXFP4; - constexpr int vals_per_block = nwarps * vals_per_warp; - - const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; - const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); - const dim3 block_size(WARP_SIZE, nwarps, 1); - - quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); +void quantize_mmq_fp4_cuda( + const float * x, const int32_t * ids, void * vy, const ggml_type type_src0, + const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) { + GGML_ASSERT(type_src0 == GGML_TYPE_MXFP4 || type_src0 == GGML_TYPE_NVFP4); + GGML_ASSERT(ne0 > 0); + + if (type_src0 == GGML_TYPE_NVFP4) { + GGML_ASSERT(ne00 % QK_NVFP4 == 0); + constexpr int nvfp4_block_size = 128; + const int64_t block_num_y = (ne0 + QK_NVFP4_SUB * nvfp4_block_size - 1) / (QK_NVFP4_SUB * nvfp4_block_size); + const dim3 block_size(nvfp4_block_size, 1, 1); + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + quantize_mmq_nvfp4<<<num_blocks, block_size, 0, stream>>>( + x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + } else { + GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0); + + constexpr int nwarps = 8; + constexpr int vals_per_warp = 2 * QK_MXFP4; + constexpr int vals_per_block = nwarps * vals_per_warp; + + const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block; + const dim3 num_blocks(ne1, block_num_y, ne2 * ne3); + const dim3 block_size(WARP_SIZE, nwarps, 1); + + quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2); + } } diff --git a/ggml/src/ggml-cuda/quantize.cuh b/ggml/src/ggml-cuda/quantize.cuh index 6a91df63578..768a3ae6de6 100644 --- a/ggml/src/ggml-cuda/quantize.cuh +++ b/ggml/src/ggml-cuda/quantize.cuh @@ -26,7 +26,7 @@ void quantize_mmq_q8_1_cuda( ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream); -void quantize_mmq_mxfp4_cuda(const float * x, +void quantize_mmq_fp4_cuda(const float * x, const int32_t * ids, void * vy, ggml_type type_src0, diff --git a/ggml/src/ggml-cuda/reduce_rows.cuh b/ggml/src/ggml-cuda/reduce_rows.cuh index 6bcae9e52fb..968c47aa20a 100644 --- a/ggml/src/ggml-cuda/reduce_rows.cuh +++ b/ggml/src/ggml-cuda/reduce_rows.cuh @@ -2,7 +2,9 @@ // Row reduction kernel template - compute sum (norm=false) or mean (norm=true) template <bool norm> -static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __restrict__ dst, const int ncols) { +static __global__ void reduce_rows_f32(const float * x_ptr, float * dst_ptr, const int ncols) { + const float * GGML_CUDA_RESTRICT x = x_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int row = blockIdx.x; const int col = threadIdx.x; @@ -10,6 +12,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r const int num_unroll = 8; float temp[num_unroll]; float sum_temp[num_unroll] = { 0.0f }; + + ggml_cuda_pdl_sync(); for (int i = col; i < ncols;) { for (int j = 0; j < num_unroll; ++j) { if (i < ncols) { @@ -28,22 +32,8 @@ static __global__ void reduce_rows_f32(const float * __restrict__ x, float * __r } // sum up partial sums - sum = warp_reduce_sum(sum); - if (blockDim.x > WARP_SIZE) { - assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); - __shared__ float s_sum[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = sum; - } - __syncthreads(); - sum = 0.0f; - if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) { - sum = s_sum[lane_id]; - } - sum = warp_reduce_sum(sum); - } + __shared__ float shared_vals[32]; + sum = block_reduce<block_reduce_method::SUM>(sum, shared_vals); if (col != 0) { return; diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 88ed79111a1..e20a5cb6bed 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -43,10 +43,15 @@ static __device__ void rope_yarn( template <bool forward, bool has_ff, typename T, typename D> static __global__ void rope_norm(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int32_t * pos, const float freq_scale, @@ -59,23 +64,23 @@ static __global__ void rope_norm(const T * x, const int set_rows_stride) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - - int idst = row_dst * ne0 + i0; - const int ix = channel_x*s2 + row_x*s1 + i0; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * s1 + i0; + idst += row_indices[i2] * set_rows_stride; } const auto & store_coaelsced = [&](float x0, float x1) { @@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -110,10 +115,15 @@ static __global__ void rope_norm(const T * x, template <bool forward, bool has_ff, typename T, typename D> static __global__ void rope_neox(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int32_t * pos, const float freq_scale, @@ -124,25 +134,28 @@ static __global__ void rope_neox(const T * x, const float * freq_factors, const int64_t * row_indices, const int set_rows_stride) { + ggml_cuda_pdl_lc(); const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - int idst = row_dst * ne0 + i0 / 2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0 / 2; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * s1 + i0 / 2; + idst += row_indices[i2] * set_rows_stride; } if (i0 >= n_dims) { @@ -152,7 +165,7 @@ static __global__ void rope_neox(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -168,25 +181,44 @@ static __global__ void rope_neox(const T * x, dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta); } -template<bool forward, bool has_ff, typename T> -static __global__ void rope_multi( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, - const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { - const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - - if (i0 >= ne0) { +template <bool forward, bool has_ff, typename T> +static __global__ void rope_multi(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope) { + const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y); + + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; dst[idst + i0/2 + 1] = x[ix + i0/2 + 1]; @@ -200,27 +232,24 @@ static __global__ void rope_multi( float theta_base = 0.0; if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); } else { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); } } else { if (sector < sections.v[0]) { - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); } } @@ -238,37 +267,54 @@ static __global__ void rope_multi( dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } -template<bool forward, bool has_ff, typename T> -static __global__ void rope_vision( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, - const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * freq_factors, const mrope_sections sections) { +template <bool forward, bool has_ff, typename T> +static __global__ void rope_vision(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + ggml_cuda_pdl_sync(); const int sect_dims = sections.v[0] + sections.v[1]; - const int sec_w = sections.v[1] + sections.v[0]; - const int sector = (i0 / 2) % sect_dims; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; if (sector < sections.v[0]) { const int p = sector; - theta_base = pos[channel_x]*powf(theta_scale, p); - } - else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2] * powf(theta_scale, p); + } else if (sector >= sections.v[0] && sector < sec_w) { const int p = sector - sections.v[0]; - theta_base = pos[channel_x + ne2]*powf(theta_scale, p); + theta_base = pos[i2 + ne02] * powf(theta_scale, p); } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -288,10 +334,15 @@ static __global__ void rope_vision( template <bool forward, typename T, typename D> static void rope_norm_cuda(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int nr, const int32_t * pos, @@ -304,31 +355,36 @@ static void rope_norm_cuda(const T * x, const int64_t * row_indices, const int set_rows_stride, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } template <bool forward, typename T, typename D> static void rope_neox_cuda(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int nr, const int32_t * pos, @@ -341,55 +397,95 @@ static void rope_neox_cuda(const T * x, const int64_t * row_indices, const int set_rows_stride, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); + const ggml_cuda_kernel_launch_params launch_params = {block_nums, block_dims, 0, stream}; if (freq_factors == nullptr) { - rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + ggml_cuda_kernel_launch(rope_neox<forward, false, T, D>, launch_params, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { - rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + ggml_cuda_kernel_launch(rope_neox<forward, true, T, D>, launch_params, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } -template<bool forward, typename T> -static void rope_multi_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); +template <bool forward, typename T> +static void rope_multi_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { - rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(rope_multi<forward, false, T>, launch_params, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { - rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(rope_multi<forward, true, T>, launch_params, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } } -template<bool forward, typename T> -static void rope_vision_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); +template <bool forward, typename T> +static void rope_vision_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq) // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE); @@ -398,11 +494,11 @@ static void rope_vision_cuda( if (freq_factors == nullptr) { rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } else { rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } } @@ -445,6 +541,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + const size_t s03 = src0->nb[3] / ggml_type_size(src0->type); + + const size_t s1 = dst->nb[1] / ggml_type_size(dst->type); + const size_t s2 = dst->nb[2] / ggml_type_size(dst->type); + const size_t s3 = dst->nb[3] / ggml_type_size(dst->type); //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -495,57 +596,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, // compute if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } } else if (is_mrope && !is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_multi_cuda<forward>( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_multi_cuda<forward>( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else { GGML_ABORT("fatal error"); } } else if (is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_vision_cuda<forward>( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_vision_cuda<forward>( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-cuda/scale.cu b/ggml/src/ggml-cuda/scale.cu index 0ddeff6a175..7b2e59a4383 100644 --- a/ggml/src/ggml-cuda/scale.cu +++ b/ggml/src/ggml-cuda/scale.cu @@ -3,9 +3,11 @@ #define MAX_GRIDDIM_X 0x7FFFFFFF static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) { + ggml_cuda_pdl_lc(); int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x; int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x; + ggml_cuda_pdl_sync(); for (int64_t i = tid; i < nelements; i += stride) { dst[i] = scale * x[i] + bias; } @@ -13,7 +15,8 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) { const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; - scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(scale_f32, launch_params, x, dst, scale, bias, nelements); } void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 631de7e8fa5..3b4f004c946 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -53,6 +53,7 @@ static __global__ void k_set_rows_quant(const float * __restrict__ src0, const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; + ggml_cuda_pdl_sync(); const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); const float * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; @@ -110,9 +111,9 @@ static void set_rows_cuda_quant( } template <typename src_t, typename idx_t, typename dst_t> -static __global__ void k_set_rows(const src_t * __restrict__ src0, - const idx_t * __restrict__ src1, - dst_t * __restrict__ dst, +static __global__ void k_set_rows(const src_t * src0_ptr, + const idx_t * src1_ptr, + dst_t * dst_ptr, const int64_t ne_total, const int64_t ne10, const int64_t ne11, @@ -132,6 +133,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0, const uint3 ne02, const uint3 ne11_fd, const uint3 ne12_fd) { + const src_t * GGML_CUDA_RESTRICT src0 = src0_ptr; + const idx_t * GGML_CUDA_RESTRICT src1 = src1_ptr; + dst_t * GGML_CUDA_RESTRICT dst = dst_ptr; const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; if (i >= ne_total) { @@ -157,7 +161,9 @@ static __global__ void k_set_rows(const src_t * __restrict__ src0, const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd); const int64_t i10 = i01; + ggml_cuda_pdl_sync(); const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + ggml_cuda_pdl_lc(); const src_t * src0_row = src0 + i01*s01 + i02*s02 + i03*s03; dst_t * dst_row_ptr = dst + dst_row*s1 + i02*s2 + i03*s3; @@ -203,9 +209,11 @@ static void set_rows_cuda( const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11); const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12); - k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, - s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, - ne11_fd, ne12_fd); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_size, block_size, 0, stream); + ggml_cuda_kernel_launch(k_set_rows<src_t, idx_t, dst_t>, launch_params, + src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, + s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd, + ne11_fd, ne12_fd); } } diff --git a/ggml/src/ggml-cuda/snake.cu b/ggml/src/ggml-cuda/snake.cu new file mode 100644 index 00000000000..384638c1f47 --- /dev/null +++ b/ggml/src/ggml-cuda/snake.cu @@ -0,0 +1,72 @@ +#include "snake.cuh" +#include "convert.cuh" + +// Fused Snake activation: y = x + sin^2(a * x) * inv_b +// x: [T, C] (T contiguous), a: [1, C], inv_b: [1, C] +// Supports F32, F16, BF16 data with F32 compute. + +template <typename T> +static __global__ void snake_kernel( + const T * __restrict__ x, + const float * __restrict__ a, + const float * __restrict__ inv_b, + T * __restrict__ dst, + const int total, + const uint3 T_len_fastdiv) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= total) return; + + const int c = (int) fastdiv((uint32_t) idx, T_len_fastdiv); + + const float xi = ggml_cuda_cast<float>(x[idx]); + const float s = sinf(a[c] * xi); + dst[idx] = ggml_cuda_cast<T>(xi + s * s * inv_b[c]); +} + +// Internal launcher with explicit x/a/inv_b/dst tensors. +// Shared by the public op (reads dst->src) and the fusion path (explicit args). +static void launch_snake(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst) { + const float * a_d = (const float *)a->data; + const float * inv_b_d = (const float *)inv_b->data; + + const int T = (int)x->ne[0]; + const int C = (int)x->ne[1]; + const int total = T * C; + const uint3 T_len_fastdiv = init_fastdiv_values((uint64_t) T); + + const int block_size = 256; + const int grid_size = (total + block_size - 1) / block_size; + + cudaStream_t stream = ctx.stream(); + + switch (x->type) { + case GGML_TYPE_F32: { + snake_kernel<<<grid_size, block_size, 0, stream>>>( + (const float *)x->data, a_d, inv_b_d, (float *)dst->data, total, T_len_fastdiv); + } break; + case GGML_TYPE_F16: { + snake_kernel<<<grid_size, block_size, 0, stream>>>( + (const half *)x->data, a_d, inv_b_d, (half *)dst->data, total, T_len_fastdiv); + } break; + case GGML_TYPE_BF16: { + snake_kernel<<<grid_size, block_size, 0, stream>>>( + (const nv_bfloat16 *)x->data, a_d, inv_b_d, (nv_bfloat16 *)dst->data, total, T_len_fastdiv); + } break; + default: + GGML_ABORT("snake: unsupported type"); + } +} + +// Fusion entry: caller supplies x/a/inv_b explicitly from the matched +// mul -> sin -> sqr -> mul -> add pattern. The dst is the trailing add output. +void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst) { + launch_snake(ctx, x, a, inv_b, dst); +} diff --git a/ggml/src/ggml-cuda/snake.cuh b/ggml/src/ggml-cuda/snake.cuh new file mode 100644 index 00000000000..7f6f1cb3b41 --- /dev/null +++ b/ggml/src/ggml-cuda/snake.cuh @@ -0,0 +1,8 @@ +#include "common.cuh" + +// Fusion entry point. Caller supplies x/a/inv_b explicitly. +void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx, + const ggml_tensor * x, + const ggml_tensor * a, + const ggml_tensor * inv_b, + ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/softcap.cu b/ggml/src/ggml-cuda/softcap.cu index 40dfe45d65c..9f0fa1051cf 100644 --- a/ggml/src/ggml-cuda/softcap.cu +++ b/ggml/src/ggml-cuda/softcap.cu @@ -1,18 +1,21 @@ #include "softcap.cuh" static __global__ void softcap_f32(const float * x, float * dst, const float scale, const float softcap, const int k) { + ggml_cuda_pdl_lc(); const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } + ggml_cuda_pdl_sync(); dst[i] = tanhf(scale * x[i]) * softcap; } static void softcap_f32_cuda(const float * x, float * dst, const float scale, const float softcap, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_SOFTCAP_BLOCK_SIZE - 1) / CUDA_SOFTCAP_BLOCK_SIZE; - softcap_f32<<<num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream>>>(x, dst, scale, softcap, k); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(num_blocks, CUDA_SOFTCAP_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(softcap_f32, launch_params, x, dst, scale, softcap, k); } // fused GGML_OP_SCALE + GGML_UNARY_OP_TANH + GGML_OP_SCALE diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index 1ae84ebf630..285c0e9543a 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -46,7 +46,7 @@ struct soft_max_params { }; // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. -// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here. #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wpass-failed" @@ -75,9 +75,6 @@ static __global__ void soft_max_f32( const int block_size = block_size_template == 0 ? blockDim.x : block_size_template; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1); extern __shared__ float data_soft_max_f32[]; @@ -102,21 +99,7 @@ static __global__ void soft_max_f32( } // find the max value in the block - max_val = warp_reduce_max(max_val); - if (block_size > WARP_SIZE) { - if (warp_id == 0) { - buf_iw[lane_id] = -INFINITY; - } - __syncthreads(); - - if (lane_id == 0) { - buf_iw[warp_id] = max_val; - } - __syncthreads(); - - max_val = buf_iw[lane_id]; - max_val = warp_reduce_max(max_val); - } + max_val = block_reduce<block_reduce_method::MAX, block_size_template>(max_val, buf_iw); float tmp = 0.0f; // partial sum @@ -134,22 +117,7 @@ static __global__ void soft_max_f32( } // find the sum of exps in the block - tmp = warp_reduce_sum(tmp); - if (block_size > WARP_SIZE) { - __syncthreads(); - if (warp_id == 0) { - buf_iw[lane_id] = 0.0f; - } - __syncthreads(); - - if (lane_id == 0) { - buf_iw[warp_id] = tmp; - } - __syncthreads(); - - tmp = buf_iw[lane_id]; - tmp = warp_reduce_sum(tmp); - } + tmp = block_reduce<block_reduce_method::SUM, block_size_template>(tmp, buf_iw); if (sinks) { tmp += expf(sinks[i02] - max_val); @@ -169,50 +137,6 @@ static __global__ void soft_max_f32( } } - -// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated -static __device__ float two_stage_warp_reduce_max(float val) { - val = warp_reduce_max(val); - if (blockDim.x > WARP_SIZE) { - assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); - __shared__ float local_vals[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - local_vals[warp_id] = val; - } - __syncthreads(); - val = -INFINITY; - if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) { - val = local_vals[lane_id]; - } - return warp_reduce_max(val); - } else { - return val; - } -} - -static __device__ float two_stage_warp_reduce_sum(float val) { - val = warp_reduce_sum(val); - if (blockDim.x > WARP_SIZE) { - assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); - __shared__ float local_vals[32]; - const int warp_id = threadIdx.x / WARP_SIZE; - const int lane_id = threadIdx.x % WARP_SIZE; - if (lane_id == 0) { - local_vals[warp_id] = val; - } - __syncthreads(); - val = 0.0f; - if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) { - val = local_vals[lane_id]; - } - return warp_reduce_sum(val); - } else { - return val; - } -} - // TODO: Template to allow keeping ncols in registers if they fit static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, float * __restrict__ dst, @@ -230,6 +154,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY }; float local_max = -INFINITY; const int step_size = gridDim.x * blockDim.x; + __shared__ float shared_vals[32]; // Compute thread-local max for (int col = col_start; col < p.ncols;) { @@ -246,7 +171,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } // Compute CTA-level max - local_max = two_stage_warp_reduce_max(local_max); + local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals); // Store CTA-level max to GMEM if (tid == 0) { @@ -261,7 +186,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } else { local_max = -INFINITY; } - local_max = two_stage_warp_reduce_max(local_max); + local_max = block_reduce<block_reduce_method::MAX>(local_max, shared_vals); // Compute softmax dividends, accumulate divisor float tmp_expf = 0.0f; @@ -284,7 +209,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } // Reduce divisor within CTA - tmp_expf = two_stage_warp_reduce_sum(tmp_expf); + tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals); // Store CTA-level sum to GMEM if (tid == 0) { @@ -298,7 +223,7 @@ static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __ } else { tmp_expf = 0.0f; } - tmp_expf = two_stage_warp_reduce_sum(tmp_expf); + tmp_expf = block_reduce<block_reduce_method::SUM>(tmp_expf, shared_vals); // Divide dividend by global sum + store data for (int col = col_start; col < p.ncols;) { diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu index 177ffc268f1..07ca33f513b 100644 --- a/ggml/src/ggml-cuda/solve_tri.cu +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -83,7 +83,7 @@ static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx, // ====================== // When ncols_template == 0 the bounds for the loops in this function are not // known and can't be unrolled. As we want to keep pragma unroll for all other -// cases we supress the clang transformation warning here. +// cases we suppress the clang transformation warning here. #ifdef __clang__ # pragma clang diagnostic push # pragma clang diagnostic ignored "-Wpass-failed" diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 6d5ea704c65..1463169cf78 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -1,10 +1,18 @@ +#include "common.cuh" #include "ssm-conv.cuh" +#include "unary.cuh" -template <size_t split_d_inner, size_t d_conv> -static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, +template <bool apply_silu, size_t split_d_inner, size_t d_conv> +static __global__ void ssm_conv_f32(const float * src0_ptr, const float * src1_ptr, + const float * bias_ptr, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, - float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, + float * dst_ptr, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { + ggml_cuda_pdl_lc(); + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT bias = bias_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; GGML_UNUSED(src0_nb0); const int tid = threadIdx.x; const int bidx = blockIdx.x; @@ -21,11 +29,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float float x[d_conv] = { 0.0f }; float w[d_conv] = { 0.0f }; + ggml_cuda_pdl_sync(); #pragma unroll for (size_t j = 0; j < d_conv; j++) { w[j] = w_block[tid * stride_w + j]; } + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; + for (int64_t i = 0; i < n_t; i++) { float sumf = 0.0f; @@ -41,12 +52,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } - y_block[i * stride_y + tid] = sumf; + sumf += b; + y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } -template <size_t split_d_inner, size_t d_conv, int64_t split_n_t> +template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t> static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -65,37 +78,53 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const int stride_w = src1_nb1 / sizeof(float); const int stride_y = dst_nb1 / sizeof(float); - float x[d_conv] = { 0.0f }; - float w[d_conv] = { 0.0f }; + const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t); + const int n_cols = d_conv - 1 + split_n_t; + + extern __shared__ float smem[]; + + constexpr int load_cols = d_conv - 1 + split_n_t; + constexpr int total_elems = split_d_inner * load_cols; + int row = tid / load_cols; + int col = tid % load_cols; +#pragma unroll + for (int idx = 0; idx < total_elems; idx += split_d_inner) { + if (row < (int)split_d_inner) { + smem[row * n_cols + col] = x_block[row * stride_x + col]; + } + + col += split_d_inner; + row += col / load_cols; + col = col % load_cols; + if (idx >= total_elems - tid - split_d_inner) { + break; + } + } + __syncthreads(); + // Load weights into registers (done once, small) + float w[d_conv] = { 0.0f }; #pragma unroll for (size_t j = 0; j < d_conv; j++) { w[j] = w_block[tid * stride_w + j]; } -#pragma unroll - for (int64_t i = 0; i < split_n_t; i++) { - if (bidz * split_n_t + i < n_t) { - float sumf = 0.0f; - - if (i == 0) { - for (size_t j = 0; j < d_conv; j++) { - x[j] = x_block[tid * stride_x + j]; - } - } else { - x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1]; - } + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; + // Compute from shared memory + for (int64_t i = 0; i < local_n_t; i++) { + float sumf = 0.0f; #pragma unroll - for (size_t j = 0; j < d_conv; j++) { - sumf += x[(i + j) % d_conv] * w[j]; - } - y_block[i * stride_y + tid] = sumf; + for (size_t j = 0; j < d_conv; j++) { + sumf += smem[tid * n_cols + i + j] * w[j]; } + sumf += b; + y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } -static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, +template <bool apply_silu> +static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, const int64_t n_s, cudaStream_t stream) { @@ -106,45 +135,72 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, - dst, dst_nb0, dst_nb1, dst_nb2, n_t); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_conv_f32<apply_silu, threads, kNC>, launch_params, src0, src1, bias, src0_nb0, src0_nb1, + src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); - ssm_conv_long_token_f32<threads, kNC, split_n_t><<<blocks, threads, 0, stream>>>( - src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); + ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>( + src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } }; switch (nc) { - case 3: launch_kernel(std::integral_constant<int, 3>{}); break; - case 4: launch_kernel(std::integral_constant<int, 4>{}); break; - case 9: launch_kernel(std::integral_constant<int, 9>{}); break; - default: GGML_ABORT("Only support kernel sizes 3, 4, 9 right now."); + case 3: launch_kernel(std::integral_constant<int, 3 >{}); break; + case 4: launch_kernel(std::integral_constant<int, 4 >{}); break; + case 5: launch_kernel(std::integral_constant<int, 5 >{}); break; + case 9: launch_kernel(std::integral_constant<int, 9 >{}); break; + case 15: launch_kernel(std::integral_constant<int, 15>{}); break; + default: GGML_ABORT("Only support kernel sizes 3, 4, 5, 9, 15 right now."); } } -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) { const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + const bool fuse_bias = bias_add_node != nullptr; + const bool fuse_silu = silu_dst != nullptr; + + // bias always comes with silu. + GGML_ASSERT(!fuse_bias || fuse_silu); + + // The bias (when fused) is the non-conv operand of the ADD node. + const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr; + + // When fusing, write to silu_dst (the node downstream references). + const struct ggml_tensor * out = fuse_silu ? silu_dst : dst; const int64_t nc = src1->ne[0]; // d_conv const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = dst->ne[1]; // tokens per sequence - const int64_t n_s = dst->ne[2]; // number of sequences in the batch + const int64_t n_t = out->ne[1]; // tokens per sequence + const int64_t n_s = out->ne[2]; // number of sequences in the batch - GGML_ASSERT(dst->ne[0] == nr); + GGML_ASSERT(out->ne[0] == nr); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; - float * dst_d = (float *) dst->data; + const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr; + float * dst_d = (float *) out->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1], - dst->nb[2], nc, nr, n_t, n_s, stream); + GGML_ASSERT(out->type == GGML_TYPE_F32); + if (fuse_bias) { + GGML_ASSERT(bias->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(bias)); + GGML_ASSERT(ggml_nelements(bias) == nr); + } + + if (fuse_silu) { + ssm_conv_f32_cuda<true>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + out->nb[2], nc, nr, n_t, n_s, stream); + } else { + ssm_conv_f32_cuda<false>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + out->nb[2], nc, nr, n_t, n_s, stream); + } } diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh index 8e6c1f00bfa..8514ca84920 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cuh +++ b/ggml/src/ggml-cuda/ssm-conv.cuh @@ -1,3 +1,3 @@ #include "common.cuh" -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr); diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index c1d4e2bc8df..3022249c77d 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -17,15 +17,24 @@ using namespace cub; #endif // __clang__ template <size_t splitD, size_t N, size_t L_template> __global__ void __launch_bounds__(splitD, 1) - ssm_scan_f32(const float *__restrict__ src0, const float *__restrict__ src1, const float *__restrict__ src2, - const float *__restrict__ src3, const float *__restrict__ src4, const float *__restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, + ssm_scan_f32(const float * src0_ptr, const float * src1_ptr, const float * src2_ptr, + const float * src3_ptr, const float * src4_ptr, const float * src5_ptr, + const int32_t * src6_ptr, float * dst_ptr, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t d_inner, const int64_t L_param) { + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT src2 = src2_ptr; + const float * GGML_CUDA_RESTRICT src3 = src3_ptr; + const float * GGML_CUDA_RESTRICT src4 = src4_ptr; + const float * GGML_CUDA_RESTRICT src5 = src5_ptr; + const int32_t * GGML_CUDA_RESTRICT src6 = src6_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const size_t L = L_template == 0 ? L_param : L_template; + ggml_cuda_pdl_sync(); const float *s0_block = (const float *)((const char *)src0 + src6[blockIdx.x] * src0_nb3 + blockIdx.y * splitD * src0_nb2); const float *x_block = (const float *)((const char *)src1 + (blockIdx.x * src1_nb3) + blockIdx.y * splitD * sizeof(float)); const float *dt_block = (const float *)((const char *)src2 + (blockIdx.x * src2_nb2) + blockIdx.y * splitD * sizeof(float)); @@ -58,6 +67,7 @@ __global__ void __launch_bounds__(splitD, 1) __shared__ CubTempStorage cub_temp_storage; BlockLoad(cub_temp_storage.load_temp).Load(A_block, regA); + __syncthreads(); BlockLoad(cub_temp_storage.load_temp).Load(s0_block, regs0); #else const int stride_s0 = src0_nb2 / sizeof(float); @@ -96,6 +106,7 @@ __global__ void __launch_bounds__(splitD, 1) regs0[n] = state; } y_block[i * stride_y + threadIdx.x] = sumf; + __syncthreads(); } #ifdef USE_CUB @@ -117,13 +128,21 @@ __global__ void __launch_bounds__(splitD, 1) template <int c_factor, int d_state> __global__ void __launch_bounds__(d_state, 1) ssm_scan_f32_group( - const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, - const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, - const int32_t * __restrict__ src6, float * __restrict__ dst, + const float * src0_ptr, const float * src1_ptr, const float * src2_ptr, + const float * src3_ptr, const float * src4_ptr, const float * src5_ptr, + const int32_t * src6_ptr, float * dst_ptr, const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { + const float * GGML_CUDA_RESTRICT src0 = src0_ptr; + const float * GGML_CUDA_RESTRICT src1 = src1_ptr; + const float * GGML_CUDA_RESTRICT src2 = src2_ptr; + const float * GGML_CUDA_RESTRICT src3 = src3_ptr; + const float * GGML_CUDA_RESTRICT src4 = src4_ptr; + const float * GGML_CUDA_RESTRICT src5 = src5_ptr; + const int32_t * GGML_CUDA_RESTRICT src6 = src6_ptr; + float * GGML_CUDA_RESTRICT dst = dst_ptr; const int warp = threadIdx.x / WARP_SIZE; const int lane = threadIdx.x % WARP_SIZE; @@ -135,6 +154,7 @@ __global__ void __launch_bounds__(d_state, 1) const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); + ggml_cuda_pdl_sync(); // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); @@ -206,7 +226,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa constexpr int num_warps = threads/WARP_SIZE; const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); - ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_scan_f32_group<128/WARP_SIZE, 128>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -215,7 +236,8 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa constexpr int num_warps = threads/WARP_SIZE; const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); - ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>( + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); + ggml_cuda_kernel_launch(ssm_scan_f32_group<256/WARP_SIZE, 256>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -229,60 +251,60 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa GGML_ASSERT(head_dim == 1); GGML_ASSERT(n_group == 1); const dim3 blocks(n_seq, (n_head + threads - 1) / threads, 1); - const int smem_size = (threads * (d_state + 1) * 2) * sizeof(float); if (d_state == 16) { + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(blocks, threads, 0, stream); switch (n_tok) { case 1: - ssm_scan_f32<threads, 16, 1><<<blocks, threads, smem_size, stream>>>( + ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 1>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 2: - ssm_scan_f32<threads, 16, 2><<<blocks, threads, smem_size, stream>>>( + ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 2>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 3: - ssm_scan_f32<threads, 16, 3><<<blocks, threads, smem_size, stream>>>( + ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 3>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 4: - ssm_scan_f32<threads, 16, 4><<<blocks, threads, smem_size, stream>>>( + ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 4>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 5: - ssm_scan_f32<threads, 16, 5><<<blocks, threads, smem_size, stream>>>( + ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 5>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 6: - ssm_scan_f32<threads, 16, 6><<<blocks, threads, smem_size, stream>>>( + ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 6>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 7: - ssm_scan_f32<threads, 16, 7><<<blocks, threads, smem_size, stream>>>( + ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 7>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; case 8: - ssm_scan_f32<threads, 16, 8><<<blocks, threads, smem_size, stream>>>( + ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 8>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); break; default: - ssm_scan_f32<threads, 16, 0><<<blocks, threads, smem_size, stream>>>( + ggml_cuda_kernel_launch(ssm_scan_f32<threads, 16, 0>, launch_params, src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, n_tok); diff --git a/ggml/src/ggml-cuda/sumrows.cu b/ggml/src/ggml-cuda/sumrows.cu index 4025771aadb..0003658ca95 100644 --- a/ggml/src/ggml-cuda/sumrows.cu +++ b/ggml/src/ggml-cuda/sumrows.cu @@ -7,10 +7,12 @@ void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int const dim3 block_nums(nrows, 1, 1); if ((nrows / nsm) < 2) { const dim3 block_dims(512, 1, 1); - reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, x, dst, ncols); } else { const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, x, dst, ncols); } } @@ -34,10 +36,12 @@ void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if ((nrows / nsm) < 2) { // Increase num threads to 512 for small nrows to better hide the latency const dim3 block_dims(512, 1, 1); - reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, src0_d, dst_d, ncols); } else { // Enough active SMs to hide latency, use smaller blocks to allow better scheduling const dim3 block_dims(ncols < 1024 ? 32 : 128, 1, 1); - reduce_rows_f32</*norm=*/false><<<block_nums, block_dims, 0, stream>>>(src0_d, dst_d, ncols); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); + ggml_cuda_kernel_launch(reduce_rows_f32</*norm=*/false>, launch_params, src0_d, dst_d, ncols); } } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu index fb26abeb0da..b2661b93162 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 1, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu new file mode 100644 index 00000000000..8fc3b17976e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu @@ -0,0 +1,6 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32); +DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu index dc16829021f..6ae77bec895 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu @@ -7,4 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 1, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 1, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu index 2074e954a32..d2415bfa957 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu @@ -8,3 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu index f011a208cd2..fd41e71b142 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 2, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu new file mode 100644 index 00000000000..abd2b21ce04 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu @@ -0,0 +1,6 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-mma-f16.cuh" + +DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32); +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu index 24c64cf000f..8eec1d74e29 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu @@ -8,3 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu index 163b1d939e4..9f4bef11a44 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu @@ -7,4 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 2, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 2, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu index f5fd0e2369c..cc41fa52f13 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu @@ -2,4 +2,5 @@ #include "../fattn-mma-f16.cuh" +DECL_FATTN_MMA_F16_CASE(192, 128, 4, 16); DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu index 1ada657f194..3475dfea08a 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu @@ -8,3 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu index bad296b4141..859bea5c525 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu @@ -7,4 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 4, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 4, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu index 86d4ffae27c..684cd25ce0d 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu @@ -8,3 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4); +DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4); +DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu index 680a13ca6de..c975ce6b9b7 100644 --- a/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +++ b/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu @@ -7,4 +7,6 @@ DECL_FATTN_MMA_F16_CASE(80, 80, 8, 8); DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8); DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8); DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8); +DECL_FATTN_MMA_F16_CASE(192, 128, 8, 8); DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8); +DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu new file mode 100644 index 00000000000..b571cca0df2 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(192, 128); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu new file mode 100644 index 00000000000..c91f508079d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(320, 256); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu new file mode 100644 index 00000000000..7c61d8d2ecd --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(512, 512); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu new file mode 100644 index 00000000000..3a2fa99b05b --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu new file mode 100644 index 00000000000..60f0f6f7952 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu new file mode 100644 index 00000000000..489e05f08c3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu new file mode 100644 index 00000000000..6fa3c26d309 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu new file mode 100644 index 00000000000..421027fb29d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu new file mode 100644 index 00000000000..abbc9434802 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu new file mode 100644 index 00000000000..d641f859d81 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_BF16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_BF16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu new file mode 100644 index 00000000000..d1071dc2438 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu new file mode 100644 index 00000000000..8afda314238 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu new file mode 100644 index 00000000000..506864ac18d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu new file mode 100644 index 00000000000..0bbda8371e6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu new file mode 100644 index 00000000000..79be24daf9e --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu new file mode 100644 index 00000000000..45636e5e70c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu @@ -0,0 +1,7 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_BF16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_BF16); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index a5602da02bb..af05a9eff71 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -3,9 +3,12 @@ from glob import glob import os -HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576] +HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 192, 256, 320, 512, 576] -TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] +# DKQ -> DV override for asymmetric head dims. +HEAD_SIZES_V_OVERRIDE = {576: 512, 320: 256, 192: 128} + +TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"] SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. @@ -32,10 +35,11 @@ SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size_kq}, {head_size_v}, {ncols1}, {ncols2});\n" TYPES_MMQ = [ + "GGML_TYPE_Q1_0", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", - "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4" + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4" ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. @@ -61,7 +65,7 @@ def get_short_name(long_quant_name): os.remove(filename) for head_size_kq in HEAD_SIZES_KQ: - head_size_v = head_size_kq if head_size_kq != 576 else 512 + head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq) with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) @@ -71,7 +75,7 @@ def get_short_name(long_quant_name): f.write(SOURCE_FATTN_VEC.format(type_k=type_k, type_v=type_v)) for ncols in [8, 16, 32, 64]: - for ncols2 in [1, 2, 4, 8, 16]: + for ncols2 in [1, 2, 4, 8, 16, 32]: if ncols2 > ncols: continue ncols1 = ncols // ncols2 @@ -83,11 +87,18 @@ def get_short_name(long_quant_name): continue if head_size_kq == 72: continue - if head_size_kq != 576 and ncols2 == 16: + # Skip compilation of unused ncols2 values for niche head sizes: + if head_size_kq == 192 and ncols2 not in (8, 16): # MiMo-V2.5 + continue + if head_size_kq == 320 and ncols2 != 32: # Mistral Small 4 + continue + if head_size_kq == 512 and ncols2 not in (4, 8): # Gemma 4 + continue + if head_size_kq == 576 and ncols2 not in (4, 16, 32): # Deepseek, GLM 4.7 Flash continue - if head_size_kq == 576 and ncols2 != 16: + if head_size_kq not in (192, 320, 576) and ncols2 in (16, 32): continue - head_size_v = head_size_kq if head_size_kq != 576 else 512 + head_size_v = HEAD_SIZES_V_OVERRIDE.get(head_size_kq, head_size_kq) f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v)) for type in TYPES_MMQ: diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu new file mode 100644 index 00000000000..2cb140d35a3 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_NVFP4); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu new file mode 100644 index 00000000000..f0686b0d0d8 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_Q1_0); diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu index 318ac38691e..db1d39e2dc7 100644 --- a/ggml/src/ggml-cuda/top-k.cu +++ b/ggml/src/ggml-cuda/top-k.cu @@ -4,8 +4,8 @@ #ifdef GGML_CUDA_USE_CUB # include <cub/cub.cuh> # if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) -# include <cuda/iterator> # define CUB_TOP_K_AVAILABLE +# include <cuda/iterator> using namespace cub; # endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 #endif // GGML_CUDA_USE_CUB @@ -26,14 +26,14 @@ static void top_k_cub(ggml_cuda_pool & pool, auto indexes_in = cuda::make_counting_iterator(0); size_t temp_storage_bytes = 0; - DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, - env); + CUDA_CHECK(DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, + env)); ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes); void * d_temp_storage = temp_storage_alloc.get(); - DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, - ncols, k, env); + CUDA_CHECK(DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, + ncols, k, env)); } #elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 48e569efa0d..c4253bfa43b 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -5,6 +5,13 @@ #include <cmath> #include <initializer_list> +// Kernel config struct - passed by value to CUDA kernel +struct topk_moe_config { + bool use_sigmoid; + bool with_norm; + bool delayed_softmax; +}; + // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. template <int experts_per_thread, bool use_limit> __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { @@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in } } +template <int experts_per_thread, bool use_limit> +__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + const int idx = lane + i * WARP_SIZE; + const bool active = !use_limit || (idx < limit); + vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY; + } +} + /* This kernel does the following: 1. optionally softmax over the logits per token [n_experts, n_tokens] @@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models */ -template <int n_experts, bool with_norm, bool delayed_softmax = false> -__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, - float * weights, - int32_t * ids, - const int n_rows, - const int n_expert_used, - const float clamp_val) { +template <int n_experts, bool has_bias> +__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, + float * weights, + int32_t * ids, + float * bias, + const int n_rows, + const int n_expert_used, + const float clamp_val, + const float scale_val, + const topk_moe_config config) { const int row = blockIdx.x * blockDim.y + threadIdx.y; if (row >= n_rows) { return; @@ -79,14 +99,54 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * float wt[experts_per_thread]; + // Initialize all slots to -INFINITY +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + wt[i] = -INFINITY; + } + + ggml_cuda_pdl_sync(); #pragma unroll for (int i = 0; i < n_experts; i += WARP_SIZE) { const int expert = i + threadIdx.x; wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY; } - if constexpr (!delayed_softmax) { - softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x); + if (!config.delayed_softmax) { + if (config.use_sigmoid) { + sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x); + } else { + softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x); + } + } + + // Sanitize NaN to -FLT_MAX so the iterative argmax produces unique expert IDs. + // NaN comparisons always return false, which would cause the same expert to be + // selected repeatedly. -FLT_MAX compares normally and is still excluded by the + // -INFINITY sentinel used after each selection round. + // More relevant for the cuBLAS path. See https://github.com/ggml-org/llama.cpp/issues/19659 +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + if (__isnanf(wt[i])) { + wt[i] = -FLT_MAX; + } + } + + // selection_wt is only needed when bias is present (selection uses wt + bias) + // when no bias, we use wt directly for both selection and weight values + [[maybe_unused]] float selection_wt[has_bias ? experts_per_thread : 1]; + + if constexpr (has_bias) { +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + selection_wt[i] = -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_experts; i += WARP_SIZE) { + const int expert = i + threadIdx.x; + selection_wt[i / WARP_SIZE] = + (n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY; + } } //at this point, each thread holds either a portion of the softmax distribution @@ -102,26 +162,61 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * output_weights[i] = 0.f; } + ggml_cuda_pdl_lc(); for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; int max_expert = threadIdx.x; + if constexpr (has_bias) { + float max_val_s = selection_wt[0]; + #pragma unroll - for (int i = 1; i < experts_per_thread; i++) { - const int expert = threadIdx.x + i * WARP_SIZE; - if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { - max_val = wt[i]; - max_expert = expert; + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) { + max_val = wt[i]; + max_val_s = selection_wt[i]; + max_expert = expert; + } } - } #pragma unroll - for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { - const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); - const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); - if (val > max_val || (val == max_val && expert < max_expert)) { - max_val = val; - max_expert = expert; + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) { + max_val = val; + max_val_s = val_s; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + selection_wt[max_expert / WARP_SIZE] = -INFINITY; + } + } else { +#pragma unroll + for (int i = 1; i < experts_per_thread; i++) { + const int expert = threadIdx.x + i * WARP_SIZE; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { + max_val = wt[i]; + max_expert = expert; + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { + const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); + const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); + if (val > max_val || (val == max_val && expert < max_expert)) { + max_val = val; + max_expert = expert; + } + } + + if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { + wt[max_expert / WARP_SIZE] = -INFINITY; } } @@ -130,16 +225,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * } if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { - wt[max_expert / WARP_SIZE] = -INFINITY; - ids[k] = max_expert; - if constexpr (with_norm) { + if (config.with_norm) { wt_sum += max_val; } } } - if constexpr (with_norm) { + if (config.with_norm) { wt_sum = warp_reduce_sum(wt_sum); wt_sum = max(wt_sum, clamp_val); const float inv_sum = 1.0f / wt_sum; @@ -149,7 +242,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * } } - if constexpr (delayed_softmax) { + if (config.delayed_softmax) { softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x); } @@ -157,70 +250,75 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * for (int i = 0; i < experts_per_thread; i++) { const int idx = i * WARP_SIZE + threadIdx.x; if (idx < n_expert_used) { - weights[idx] = output_weights[i]; + weights[idx] = output_weights[i] * scale_val; } } - - if (!with_norm) { - GGML_UNUSED(clamp_val); - } } -template <bool with_norm, bool delayed_softmax = false> +template<bool has_bias> static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, const float * logits, float * weights, int32_t * ids, + float * bias, const int n_rows, const int n_expert, const int n_expert_used, - const float clamp_val) { - static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization"); + const float clamp_val, + const float scale_val, + const topk_moe_config config) { + GGML_ASSERT(!(config.with_norm && config.delayed_softmax) && + "delayed softmax is not supported with weight normalization"); const int rows_per_block = 4; dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1); cudaStream_t stream = ctx.stream(); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(grid_dims, block_dims, 0, stream); switch (n_expert) { case 1: - topk_moe_cuda<1, with_norm, delayed_softmax> - <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + ggml_cuda_kernel_launch(topk_moe_cuda<1, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 2: - topk_moe_cuda<2, with_norm, delayed_softmax> - <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + ggml_cuda_kernel_launch(topk_moe_cuda<2, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 4: - topk_moe_cuda<4, with_norm, delayed_softmax> - <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + ggml_cuda_kernel_launch(topk_moe_cuda<4, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 8: - topk_moe_cuda<8, with_norm, delayed_softmax> - <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + ggml_cuda_kernel_launch(topk_moe_cuda<8, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 16: - topk_moe_cuda<16, with_norm, delayed_softmax> - <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + ggml_cuda_kernel_launch(topk_moe_cuda<16, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 32: - topk_moe_cuda<32, with_norm, delayed_softmax> - <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + ggml_cuda_kernel_launch(topk_moe_cuda<32, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 64: - topk_moe_cuda<64, with_norm, delayed_softmax> - <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + ggml_cuda_kernel_launch(topk_moe_cuda<64, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 128: - topk_moe_cuda<128, with_norm, delayed_softmax> - <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + ggml_cuda_kernel_launch(topk_moe_cuda<128, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 256: - topk_moe_cuda<256, with_norm, delayed_softmax> - <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + ggml_cuda_kernel_launch(topk_moe_cuda<256, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; case 512: - topk_moe_cuda<512, with_norm, delayed_softmax> - <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); + ggml_cuda_kernel_launch(topk_moe_cuda<512, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); + break; + case 576: + ggml_cuda_kernel_launch(topk_moe_cuda<576, has_bias>, launch_params, + logits, weights, ids, bias, n_rows, n_expert_used, clamp_val, scale_val, config); break; default: GGML_ASSERT(false && "fatal error"); @@ -228,13 +326,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, } } -void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, - const ggml_tensor * logits, - ggml_tensor * weights, - ggml_tensor * ids, - const bool with_norm, - const bool delayed_softmax, - ggml_tensor * clamp) { +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids, + const ggml_tensor * clamp, + const ggml_tensor * scale, + const ggml_tensor * bias, + const ggml_cuda_topk_moe_args & args) { GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(ids->type == GGML_TYPE_I32); @@ -245,107 +344,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, const float * logits_d = (const float *) logits->data; float * weights_d = (float *) weights->data; int32_t * ids_d = (int32_t *) ids->data; + float * bias_d = bias ? (float *) bias->data : nullptr; + + float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f; GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); const int n_expert_used = weights->ne[1]; + const bool with_norm = clamp != nullptr; + float clamp_val = -INFINITY; - if (with_norm) { - if (clamp) { - clamp_val = ggml_get_op_params_f32(clamp, 0); - } - launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val); + if (clamp) { + clamp_val = ggml_get_op_params_f32(clamp, 0); + } + + topk_moe_config config; + config.use_sigmoid = args.sigmoid; + config.with_norm = with_norm; + config.delayed_softmax = args.delayed_softmax; + + if (bias) { + launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val, + scale_val, config); } else { - GGML_ASSERT(clamp == nullptr); - if (delayed_softmax) { - launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, - clamp_val); - } else { - launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, - clamp_val); - } + launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val, + scale_val, config); } } -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op, const ggml_tensor * weights, - const ggml_tensor * get_rows, - const ggml_tensor * argsort, - const ggml_tensor * clamp, - int n_expert) { - ggml_tensor * probs = get_rows->src[0]; - if (probs->op != GGML_OP_RESHAPE) { + const ggml_tensor * logits, + const ggml_tensor * ids) { + const int n_expert = ids->nb[1] / ids->nb[0]; + if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) { return false; } - probs = probs->src[0]; - ggml_tensor * selection_probs = argsort->src[0]; - if (probs != selection_probs) { + if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) { return false; } - float scale = 1.0f; - float max_bias = 0.0f; - - memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); - memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); - - if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) { - return false; - } + if (gating_op->op == GGML_OP_SOFT_MAX) { + const ggml_tensor * softmax = gating_op; + float scale = 1.0f; + float max_bias = 0.0f; - if (scale != 1.0f || max_bias != 0.0f) { - return false; - } + memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float)); - // don't fuse when masks or sinks are present - if (softmax->src[1] || softmax->src[2]) { - return false; - } + if (!ggml_is_contiguous(softmax->src[0])) { + return false; + } - // n_expert must be a power of 2 - if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) { - return false; - } + if (scale != 1.0f || max_bias != 0.0f) { + return false; + } - if (clamp) { - if (clamp->op != GGML_OP_CLAMP) { + // don't fuse when masks or sinks are present + if (softmax->src[1] || softmax->src[2]) { return false; } - float max_val = ggml_get_op_params_f32(clamp, 1); + } else if (gating_op->op == GGML_OP_UNARY) { + ggml_unary_op op = ggml_get_unary_op(gating_op); - if (max_val != INFINITY) { + if (op != GGML_UNARY_OP_SIGMOID) { return false; } } - return true; } - -std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) { - static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, - GGML_OP_RESHAPE }; - - static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, - GGML_OP_VIEW, GGML_OP_GET_ROWS }; - - static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW, - GGML_OP_GET_ROWS, GGML_OP_RESHAPE, - GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; - - GGML_ASSERT(!norm || !delayed_softmax); - - if (delayed_softmax) { - return delayed_softmax_ops; - } - - if (norm) { - return norm_ops; - } - - return no_norm_ops; -} diff --git a/ggml/src/ggml-cuda/topk-moe.cuh b/ggml/src/ggml-cuda/topk-moe.cuh index 6b6c13c5870..243dc2f1c41 100644 --- a/ggml/src/ggml-cuda/topk-moe.cuh +++ b/ggml/src/ggml-cuda/topk-moe.cuh @@ -3,19 +3,25 @@ #include <initializer_list> -void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, - const ggml_tensor * logits, - ggml_tensor * weights, - ggml_tensor * ids, - const bool with_norm, - const bool delayed_softmax = false, - ggml_tensor * weight_clamp = nullptr); +struct ggml_cuda_topk_moe_args { + bool sigmoid{}; + bool softmax{}; + bool delayed_softmax{}; + bool prob_bias{}; + bool norm{}; + bool scale{}; +}; -bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, - const ggml_tensor * weights, - const ggml_tensor * get_rows, - const ggml_tensor * argsort, - const ggml_tensor * clamp, - int n_expert); +void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, + const ggml_tensor * logits, + ggml_tensor * weights, + ggml_tensor * ids, + const ggml_tensor * clamp, + const ggml_tensor * scale, + const ggml_tensor * bias, + const ggml_cuda_topk_moe_args & args); -std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false); +bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op, + const ggml_tensor * weights, + const ggml_tensor * logits, + const ggml_tensor * ids); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index d4866067a4f..4cb805fa601 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -65,6 +65,11 @@ static __device__ __forceinline__ float op_sqr(float x) { return x * x; } +static __device__ __forceinline__ float op_relu_sqr(float x) { + const float r = fmaxf(x, 0.0f); + return r * r; +} + static __device__ __forceinline__ float op_sqrt(float x) { return sqrtf(x); } @@ -111,19 +116,22 @@ static __device__ __forceinline__ float op_trunc(float x) { template <float (*op)(float), typename T> static __global__ void unary_op_kernel(const T * x, T * dst, const int k) { + ggml_cuda_pdl_lc(); const int i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= k) { return; } + ggml_cuda_pdl_sync(); dst[i] = (T)op((float)x[i]); } template <float (*op)(float), typename T> static void unary_cuda(const T * x, T * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_NEG_BLOCK_SIZE - 1) / CUDA_NEG_BLOCK_SIZE; - unary_op_kernel<op><<<num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream>>>(x, dst, k); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_NEG_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(unary_op_kernel<op, T>, launch_params, x, dst, k); } template <float (*op)(float)> @@ -253,6 +261,7 @@ void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { template <float (*op)(float), typename T> static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) { + ggml_cuda_pdl_lc(); const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x; if (i >= k) { @@ -263,13 +272,15 @@ static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t j0 = (i / n) * o0 + (i % n); const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n); + ggml_cuda_pdl_sync(); dst[i] = (T)(op((float)x[j0]) * (float)g[j1]); } template <float (*op)(float), typename T> static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) { const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE; - unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1); + const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params((dim3)num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream); + ggml_cuda_kernel_launch(unary_gated_op_kernel<op, T>, launch_params, x, g, dst, k, n, o0, o1); } template <float (*op)(float)> @@ -560,3 +571,76 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream); } } + +/* fused unary + mul */ + +template <float (*op)(float)> +static void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) { + // unary_node: UNARY op applied to unary_node->src[0] + // mul_node: MUL(a, b) where one of a/b is unary_node + // Output goes to mul_node->data + + const ggml_tensor * unary_src = unary_node->src[0]; // input to the unary op + const ggml_tensor * other_src = (mul_node->src[0] == unary_node) ? mul_node->src[1] : mul_node->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(unary_src)); + GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src)); + GGML_ASSERT(ggml_is_contiguous_1(other_src)); + GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src)); + GGML_ASSERT(ggml_are_same_shape(unary_src, other_src)); + + GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16); + GGML_ASSERT(unary_src->type == other_src->type); + GGML_ASSERT(unary_src->type == mul_node->type); + + cudaStream_t stream = ctx.stream(); + + const int64_t k = ggml_nelements(mul_node); + const int64_t nc = unary_src->ne[0]; + const int64_t unary_stride = unary_src->nb[1]; + const int64_t other_stride = other_src->nb[1]; + + if (unary_src->type == GGML_TYPE_F16) { + unary_gated_cuda<op>((const half *) unary_src->data, (const half *) other_src->data, + (half *) mul_node->data, k, nc, + unary_stride / sizeof(half), other_stride / sizeof(half), stream); + } else { + unary_gated_cuda<op>((const float *) unary_src->data, (const float *) other_src->data, + (float *) mul_node->data, k, nc, + unary_stride / sizeof(float), other_stride / sizeof(float), stream); + } +} + +void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) { + switch (ggml_get_unary_op(unary_node)) { + case GGML_UNARY_OP_SILU: + ggml_cuda_op_unary_mul_impl<op_silu>(ctx, unary_node, mul_node); + break; + case GGML_UNARY_OP_SIGMOID: + ggml_cuda_op_unary_mul_impl<op_sigmoid>(ctx, unary_node, mul_node); + break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cuda_op_unary_mul_impl<op_softplus>(ctx, unary_node, mul_node); + break; + default: + GGML_ABORT("Unsupported unary op for fused unary+mul"); + } +} + +/* fused relu + sqr */ + +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node) { + const ggml_tensor * src = relu_node->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16); + GGML_ASSERT(src->type == sqr_node->type); + + const int k = ggml_nelements(src); + if (src->type == GGML_TYPE_F16) { + unary_cuda<op_relu_sqr>((const half *)src->data, (half *)sqr_node->data, k, stream); + } else { + unary_cuda<op_relu_sqr>((const float *)src->data, (float *)sqr_node->data, k, stream); + } +} diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 609046e5694..81ed873ecc3 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -89,6 +89,10 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node); + +void ggml_cuda_op_relu_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * relu_node, ggml_tensor * sqr_node); + __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { return x / (1.0f + expf(-x)); } diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 6baab1176ff..d1741cc8d7b 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -94,9 +94,21 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con #endif } +static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) { + // v is a 7 bit int, with the 8th sign being encodable as popcnt + // with xor we can "correct" the bit instead of having to mask + const uint32_t p = __popc(v) & 1; + const uint32_t s = v ^ p << 7; + // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors + return s * 0x01010101; +} + // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q +#define VDR_Q1_0_Q8_1_MMVQ 1 // Process one 32-element chunk at a time for parallelism +#define VDR_Q1_0_Q8_1_MMQ 4 // Q1_0 has 128 bits (4 ints) per block + #define VDR_Q4_0_Q8_1_MMVQ 2 #define VDR_Q4_0_Q8_1_MMQ 4 @@ -313,6 +325,38 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1( return d * sumi; } +#define VDR_NVFP4_Q8_1_MMVQ 4 +#define VDR_NVFP4_Q8_1_MMQ 8 + +static __device__ __forceinline__ float vec_dot_nvfp4_q8_1( + const void * __restrict__ vbq, + const block_q8_1 * __restrict__ bq8_1, + const int32_t & kbx, + const int32_t & iqs) { + + const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq + kbx; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) { + const int32_t iqs0 = iqs + 2*i; + const int32_t iqs1 = iqs0 + 1; + const int32_t is = iqs0 >> 1; + const int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4); + const int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4); + const block_q8_1 * bq8 = bq8_1 + (is >> 1); + const int32_t i8 = ((is & 1) << 2); + + int sumi = ggml_cuda_dp4a(v0.x, get_int_b4(bq8->qs, i8 + 0), 0); + sumi = ggml_cuda_dp4a(v0.y, get_int_b4(bq8->qs, i8 + 2), sumi); + sumi = ggml_cuda_dp4a(v1.x, get_int_b4(bq8->qs, i8 + 1), sumi); + sumi = ggml_cuda_dp4a(v1.y, get_int_b4(bq8->qs, i8 + 3), sumi); + + const float d = ggml_cuda_ue4m3_to_fp32(bq4->d[is]) * __low2float(bq8->ds); + sum += d * float(sumi); + } + + return sum; +} #define VDR_Q2_K_Q8_1_MMVQ 1 #define VDR_Q2_K_Q8_1_MMQ 4 @@ -628,6 +672,51 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( return d6 * sumf_d; } +static __device__ __forceinline__ float vec_dot_q1_0_q8_1( + const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { + + const block_q1_0 * bq1_0 = (const block_q1_0 *) vbq + kbx; + + // Q1_0: 128 elements with ONE scale + // Q8_1: 32 elements per block with individual scales + // iqs selects which of the 4 chunks of 32 elements to process (0-3) + + const float d1 = bq1_0->d; + + // Process only the chunk specified by iqs + const block_q8_1 * bq8_1_chunk = bq8_1 + iqs; + + // Load 32 bits (4 bytes) for this chunk from Q1_0 + const int offset = iqs * 4; + const int v = bq1_0->qs[offset + 0] | (bq1_0->qs[offset + 1] << 8) | + (bq1_0->qs[offset + 2] << 16) | (bq1_0->qs[offset + 3] << 24); + + // Unpack 32 bits into 32 signed values (-1 or +1) + int vi_bytes[8]; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int shift = j * 4; + const int bits4 = (v >> shift) & 0x0F; + const int b0 = (bits4 & 0x01) ? 1 : -1; + const int b1 = (bits4 & 0x02) ? 1 : -1; + const int b2 = (bits4 & 0x04) ? 1 : -1; + const int b3 = (bits4 & 0x08) ? 1 : -1; + vi_bytes[j] = (b0 & 0xFF) | ((b1 & 0xFF) << 8) | ((b2 & 0xFF) << 16) | ((b3 & 0xFF) << 24); + } + + // Compute dot product for this 32-element chunk + int sumi = 0; +#pragma unroll + for (int j = 0; j < 8; ++j) { + const int u = get_int_b4(bq8_1_chunk->qs, j); + sumi = ggml_cuda_dp4a(vi_bytes[j], u, sumi); + } + + // Apply Q1_0's single scale and this chunk's Q8_1 scale + const float d8 = __low2float(bq8_1_chunk->ds); + return d1 * d8 * sumi; +} + static __device__ __forceinline__ float vec_dot_q4_0_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) { @@ -905,22 +994,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( int sumi = 0; #pragma unroll for (int k0 = 0; k0 < 8; k0 += 2) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F]; + const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]]; + const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2)); - const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0); sumi = ggml_cuda_dp4a(grid0, u0, sumi); - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1); const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1); sumi = ggml_cuda_dp4a(grid1, u1, sumi); } - const int ls = aux32 >> 28; - sumi = (ls*sumi + sumi/2)/4; + const int ls = aux32 >> 27 | 1; // (scale * 2 + 1) + sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4 const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); return d * sumi; } @@ -942,13 +1031,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( int sumi1 = 0; #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9)); - - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF]; + const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); if (l0 < 4) { @@ -1028,13 +1119,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]); + const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2)); - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F)); - - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); sumi = ggml_cuda_dp4a(grid_l, u0, sumi); diff --git a/ggml/src/ggml-cuda/vendors/cuda.h b/ggml/src/ggml-cuda/vendors/cuda.h index ba032cfab4b..323c9801934 100644 --- a/ggml/src/ggml-cuda/vendors/cuda.h +++ b/ggml/src/ggml-cuda/vendors/cuda.h @@ -6,9 +6,14 @@ #include <cuda_bf16.h> #include <cuda_fp16.h> -#if CUDART_VERSION >= 12050 +#ifdef GGML_USE_NCCL +#include <nccl.h> +#endif // GGML_USE_NCCL + +#if CUDART_VERSION >= 11080 #include <cuda_fp8.h> -#endif // CUDART_VERSION >= 12050 +#define FP8_AVAILABLE +#endif // CUDART_VERSION >= 11080 #if CUDART_VERSION >= 12080 #include <cuda_fp4.h> diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 016b04e5a0c..a6115cd80dc 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -10,6 +10,11 @@ #include <rocwmma/rocwmma-version.hpp> #endif // defined(GGML_HIP_ROCWMMA_FATTN) +#ifdef GGML_USE_NCCL +#include <rccl/rccl.h> +#endif // GGML_USE_NCCL + + #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT #define CUBLAS_OP_N HIPBLAS_OP_N @@ -43,6 +48,7 @@ #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS #define cublasSetStream hipblasSetStream #define cublasSgemm hipblasSgemm +#define cublasSgemmStridedBatched hipblasSgemmStridedBatched #define cublasStatus_t hipblasStatus_t #define cublasOperation_t hipblasOperation_t #define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch @@ -50,9 +56,11 @@ #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceGetAttribute hipDeviceGetAttribute +#define cudaDeviceGetPCIBusId hipDeviceGetPCIBusId #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t +#define cudaErrorMemoryAllocation hipErrorOutOfMemory #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags hipEventCreateWithFlags @@ -138,6 +146,8 @@ #define cudaStream_t hipStream_t #define cudaSuccess hipSuccess #define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor +#define cudaFuncSetAttribute hipFuncSetAttribute +#define cudaFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize #define __trap() do { abort(); __builtin_unreachable(); } while(0) #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED @@ -181,6 +191,10 @@ #define GCN #endif // defined(GCN5) || defined(GCN4) +#if defined(__gfx950__) +#define CDNA4 +#endif // defined(__gfx950__) + #if defined(__gfx942__) #define CDNA3 #endif // defined(__gfx942__) @@ -193,9 +207,9 @@ #define CDNA1 #endif // defined(__gfx908__) -#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#if defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #define CDNA // For the entire family -#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1) +#endif // defined(CDNA4) || defined(CDNA3) || defined(CDNA2) || defined(CDNA1) #if defined(__GFX12__) #define RDNA4 @@ -205,6 +219,14 @@ #define RDNA3 #endif // defined(__GFX11__) +#if defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) +#define RDNA3_5 +#endif // defined(__gfx1150__) || defined(__gfx1151__) || defined(__gfx1152__) || defined(__gfx1153__) + +#if defined(RDNA3) && !defined(RDNA3_5) +#define RDNA3_0 +#endif // defined(RDNA3) && !defined(RDNA3_5) + #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \ defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__) #define RDNA2 @@ -225,6 +247,12 @@ typedef __hip_bfloat16 nv_bfloat16; typedef __hip_bfloat162 nv_bfloat162; +#if HIP_VERSION >= 60200000 +#include <hip/hip_fp8.h> +typedef __hip_fp8_e4m3 __nv_fp8_e4m3; +#define FP8_AVAILABLE +#endif // HIP_VERSION >= 60200000 + typedef int8_t int8x4_t __attribute__((ext_vector_type(4))); typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4))); static __device__ __forceinline__ int __vsubss4(const int a, const int b) { diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 1abb8acfd4b..99e8fa3703e 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -32,6 +32,7 @@ #define cublasSetMathMode mublasSetMathMode #define cublasSetStream mublasSetStream #define cublasSgemm mublasSgemm +#define cublasSgemmStridedBatched mublasSgemmStridedBatched #define cublasStatus_t mublasStatus_t #define cublasOperation_t mublasOperation_t #define cublasGetStatusString mublasGetStatusString @@ -39,9 +40,11 @@ #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess +#define cudaDeviceGetPCIBusId musaDeviceGetPCIBusId #define cudaDeviceProp musaDeviceProp #define cudaDeviceSynchronize musaDeviceSynchronize #define cudaError_t musaError_t +#define cudaErrorMemoryAllocation musaErrorMemoryAllocation #define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled #define cudaEventCreateWithFlags musaEventCreateWithFlags diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index d58e2878237..b82bae0c103 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -1,7 +1,30 @@ +file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT) +file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) + +if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}") + message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT point to the correct Hexagon SDK installation.") +endif() + +if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") + message("Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json") + file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH) + string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path") + message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}") + set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}") + file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) + if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}") + message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.") + endif() +endif() + +message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels") + include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include(ExternalProject) -option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) +option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF) +set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate") set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") add_library(htp_iface OBJECT @@ -25,56 +48,71 @@ else() target_link_options(htp_iface PUBLIC -ldl) endif() -link_custom_library(htp_iface cdsprpc) -link_custom_library(htp_iface rpcmem) - set(TARGET_NAME ggml-hexagon) ggml_add_backend_library(${TARGET_NAME} - ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h) + ggml-hexagon.cpp + htp-drv.cpp + htp-drv.h + libdl.h + ../../include/ggml-hexagon.h) target_link_libraries(${TARGET_NAME} PRIVATE htp_iface) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR}) -# Build HTP bits -set(HTP_CMAKE_ARGS - -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake - -DCMAKE_BUILD_TYPE=Release - -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} - -DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT} - -DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT} - -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG} - -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) - -ExternalProject_Add(htp-v68 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68") - -ExternalProject_Add(htp-v69 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69") - -ExternalProject_Add(htp-v73 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73") - -ExternalProject_Add(htp-v75 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75") - -ExternalProject_Add(htp-v79 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79") - -ExternalProject_Add(htp-v81 - SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON - CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81") +# Build HTP skels +set(HTP_SKELS) +function(build_htp_skel V) + ExternalProject_Add(htp-${V} + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON + BUILD_BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so + CMAKE_ARGS + -DCMAKE_BUILD_TYPE=Release + -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake + -DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} + -DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT} + -DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT} + -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG} + -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE} + -DDSP_VERSION=${V} + -DPREBUILT_LIB_DIR="toolv19_${V}") + list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so) + set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE) +endfunction() + +build_htp_skel(v68) +build_htp_skel(v69) +build_htp_skel(v73) +build_htp_skel(v75) +build_htp_skel(v79) +build_htp_skel(v81) # Install Hexagon skels required at runtime -install(FILES - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so - ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so - TYPE LIB) +install(FILES ${HTP_SKELS} TYPE LIB) + +if (CMAKE_SYSTEM_NAME MATCHES Windows AND GGML_HEXAGON_HTP_CERT) + file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/arm64" WINSDK_BIN0_ARM64) + file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/x86" WINSDK_BIN0_X86) + file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/arm64" WINSDK_BIN1_ARM64) + file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/x86" WINSDK_BIN1_X86) + + set(WINSDK_PATHS ${WINSDK_BIN0_ARM64} ${WINSDK_BIN0_X86} ${WINSDK_BIN1_ARM64} ${WINSDK_BIN1_X86}) + + find_program(INF2CAT NAMES inf2cat.exe PATHS ${WINSDK_PATHS} REQUIRED) + find_program(SIGNTOOL NAMES signtool.exe PATHS ${WINSDK_PATHS} REQUIRED) + + message(STATUS "hexagon: using ${GGML_HEXAGON_HTP_CERT} to sign libggml-htp skels") + + set(LIBGGML_HTP_CAT ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp.cat) + add_custom_target(libggml-htp-cat + BYPRODUCTS ${LIBGGML_HTP_CAT} + DEPENDS libggml-htp.inf ${HTP_SKELS} + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/libggml-htp.inf ${CMAKE_CURRENT_BINARY_DIR} + COMMAND ${INF2CAT} /driver:${CMAKE_CURRENT_BINARY_DIR} /os:10_25H2_ARM64 + COMMAND ${SIGNTOOL} sign /fd sha256 /f ${GGML_HEXAGON_HTP_CERT} ${LIBGGML_HTP_CAT} + COMMENT "generating and signing libggml-htp.cat file" + VERBATIM + ) + + add_dependencies(${TARGET_NAME} libggml-htp-cat) + install(FILES ${LIBGGML_HTP_CAT} TYPE LIB) +endif() diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 365a24b4965..49bd7e4331a 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -7,16 +7,20 @@ #include <atomic> #include <chrono> -#include <cstddef> #include <mutex> +#include <thread> +#include <cstddef> #include <stdexcept> #include <string> +#include <sstream> +#include <iomanip> +#include <unordered_set> +#include <unordered_map> +#include <regex> +#include <queue> #ifdef _WIN32 # include <sal.h> -# ifndef _WINDOWS -# define _WINDOWS -# endif #else # include <semaphore.h> # include <unistd.h> @@ -25,8 +29,6 @@ #pragma clang diagnostic ignored "-Wnested-anon-types" #pragma clang diagnostic ignored "-Wgnu-anonymous-struct" -#include "htp-utils.h" - #include <AEEStdErr.h> #include <dspqueue.h> #include <rpcmem.h> @@ -37,22 +39,38 @@ #include "ggml-hexagon.h" #include "ggml-impl.h" #include "ggml-quants.h" -#include "op-desc.h" -#include "htp-msg.h" +#include "htp-opnode.h" +#include "htp-ops.h" #include "htp_iface.h" - -static size_t opt_ndev = 1; -static size_t opt_nhvx = 0; // use all -static int opt_arch = 0; // autodetect -static int opt_etm = 0; -static int opt_verbose = 0; -static int opt_profile = 0; -static int opt_hostbuf = 1; -static int opt_experimental = 0; +#include "htp-drv.h" + +using intvec = std::vector<int>; +using uintvec = std::vector<unsigned int>; +using u32vec = std::vector<uint32_t>; + +static int opt_arch = 0; // autodetect +static size_t opt_ndev = 1; +static size_t opt_nhvx = 0; // use all +static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static size_t opt_vmem = HTP_OP_MAX_VMEM_DEFAULT; // max available va space for buffer mappings +static size_t opt_mbuf = 1ul * 1024 * 1024 * 1024; // max buffer size +static int opt_etm = 0; +static int opt_verbose = 0; +static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) +static int opt_hostbuf = 1; // hostbuf ON by default + +// Default PMU events, if profiling with PMU (mode=2) is enabled +// See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html +// https://docs.qualcomm.com/doc/80-N2040-61/topic/hvx-pmu-events.html +static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C }; // Enable all stages by default -static int opt_opmask = HTP_OPMASK_QUEUE | HTP_OPMASK_QUANTIZE | HTP_OPMASK_COMPUTE; -static int opt_opsync = 0; // synchronous ops +static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; +static int opt_opbatch = 1024; // max number of ops in a batch +static int opt_opqueue = 16; // max number of pending batches +static int opt_oppoll = 0; // polling for batch completions + +static std::regex* opt_opfilter = NULL; // regex of ops to not claim #define HEX_VERBOSE(...) \ if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__) @@ -84,47 +102,45 @@ static const char * status_to_str(uint32_t status) { // ** debug helpers -static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const ggml_tensor * op, const uint32_t req_flags) { +static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const htp_opnode & node, const uint32_t req_flags) { if (!opt_verbose) return; - op_desc desc(op); + htp_opformat fmt(node); GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, req_flags); + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, req_flags); } static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { if (!opt_verbose) return; - op_desc desc(op); - GGML_LOG_DEBUG("ggml-hex: %s supports-op %s : %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, supp ? "yes" : "no"); + htp_opformat fmt(htp_opformat(htp_opnode{const_cast<ggml_tensor*>(op), {}, HTP_OP_INVALID})); + GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), + ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no"); } -static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const ggml_tensor * op, - uint32_t op_usec, uint32_t op_cycles, uint32_t op_pkts, uint64_t call_usec) { +static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_opnode & node, + uint32_t op_usec, uint32_t op_cycles, const uint32_t pmu[]) { if (!opt_profile) return; - op_desc desc(op); - GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : %s : op-usec %u op-cycles %u op-pkts %u (%f) call-usec %llu\n", sess_name.c_str(), - ggml_op_name(op->op), desc.names, desc.dims, desc.types, desc.strides, desc.buffs, - op_usec, op_cycles, op_pkts, (float) op_cycles / op_pkts, (unsigned long long) call_usec); + char pmu_str[256] = ""; + if (opt_profile > 1) { + static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); + sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", + pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); + } + + htp_opformat fmt(node); + GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u%s\n", sess_name.c_str(), + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pmu_str); } // ** backend sessions -struct ggml_hexagon_session { - ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); - ~ggml_hexagon_session() noexcept(true); - - void allocate(int dev_id) noexcept(false); - void release() noexcept(true); - - void enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync = false); - void flush(); - - ggml_backend_buffer_type buffer_type = {}; - ggml_backend_buffer_type repack_buffer_type = {}; +struct ggml_hexagon_opbatch; +struct ggml_hexagon_opqueue; +struct htp_opnode; +struct ggml_hexagon_session { std::string name; remote_handle64 handle; dspqueue_t queue; @@ -136,87 +152,28 @@ struct ggml_hexagon_session { bool valid_handle; bool valid_queue; bool valid_iface; - std::atomic<int> op_pending; - uint32_t prof_usecs; - uint32_t prof_cycles; - uint32_t prof_pkts; -}; - -void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { - // Bump pending flag (cleared in the session::flush once we get the responce) - this->op_pending++; // atomic inc - - int err = dspqueue_write(this->queue, - 0, // flags - the framework will autoset this - n_bufs, // number of buffers - bufs, // buffer references - sizeof(req), - (const uint8_t *) &req, // Message - 1000000 // Timeout - ); - - if (err != 0) { - GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->name.c_str(), (unsigned) err); - } - - if (sync) { - flush(); - } -} - -// Flush HTP response queue i.e wait for all outstanding requests to complete -void ggml_hexagon_session::flush() { - dspqueue_t q = this->queue; - - // Repeatedly read packets from the queue until it's empty. We don't - // necessarily get a separate callback for each packet, and new packets - // may arrive while we're processing the previous one. - - while (this->op_pending) { - struct htp_general_rsp rsp; - uint32_t rsp_size; - uint32_t flags; - - struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - uint32_t n_bufs; - // Read response packet from queue - int err = dspqueue_read(q, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(rsp), // Max message length - &rsp_size, // Message length - (uint8_t *) &rsp, - 1000000); // Timeout + std::atomic<int> op_pending; + ggml_hexagon_opbatch* op_batch; + ggml_hexagon_opqueue* op_queue; - if (err == AEE_EEXPIRED) { - // TODO: might need to bail out if the HTP is stuck on something - continue; - } + ggml_backend_buffer_type buffer_type = {}; + ggml_backend_buffer_type repack_buffer_type = {}; - if (err != 0) { - GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err); - } + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); + ~ggml_hexagon_session() noexcept(true); - // Basic sanity checks - if (rsp_size != sizeof(rsp)) { - GGML_ABORT("ggml-hex: dspcall : bad response (size)\n"); - } + const char* c_name() const { return name.c_str(); } - if (rsp.status != HTP_STATUS_OK) { - GGML_LOG_ERROR("ggml-hex: dspcall : dsp-rsp: %s\n", status_to_str(rsp.status)); - // TODO: handle errors - } + void allocate(int dev_id) noexcept(false); + void release() noexcept(true); - // TODO: update profiling implementation, currently only works for opt_opsync mode - this->prof_usecs = rsp.prof_usecs; - this->prof_cycles = rsp.prof_cycles; - this->prof_pkts = rsp.prof_pkts; + void enqueue_op(const htp_opnode & node); + void flush(bool all = true); - this->op_pending--; // atomic dec - } -} + void flush_pending(bool all = false); + void flush_batch(); +}; // ** backend buffers @@ -230,88 +187,94 @@ struct ggml_backend_hexagon_buffer_type_context { std::string name; }; -struct ggml_backend_hexagon_buffer_context { - bool mmap_to(ggml_hexagon_session * s) { - HEX_VERBOSE("ggml-hex: %s mmaping buffer: base %p domain-id %d session-id %d size %zu fd %d repack %d\n", - s->name.c_str(), (void *) this->base, s->domain_id, s->session_id, this->size, this->fd, - (int) this->repack); +struct ggml_hexagon_shared_buffer { + ggml_hexagon_session * sess; + uint8_t * base; + size_t size; + int fd; + bool mapped; + bool pinned; + + void mmap() { + fastrpc_map_flags flags = this->pinned ? FASTRPC_MAP_FD : FASTRPC_MAP_FD_DELAYED; - int err = fastrpc_mmap(s->domain_id, this->fd, (void *) this->base, 0, this->size, FASTRPC_MAP_FD); + int err = fastrpc_mmap(sess->domain_id, this->fd, (void *) this->base, 0, this->size, flags); if (err != 0) { - GGML_LOG_ERROR("ggml-hex: buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", - s->domain_id, this->size, this->fd, (unsigned) err); - return false; + GGML_LOG_ERROR("ggml-hex: %s buffer mapping failed : domain_id %d size %zu fd %d error 0x%08x\n", sess->c_name(), + sess->domain_id, this->size, this->fd, (unsigned) err); + throw std::runtime_error("ggml-hex: fastrpc_mmap failed (see log for details)"); } - return true; - } + HEX_VERBOSE("ggml-hex: %s mapped buffer: base %p size %zu fd %d pinned %u\n", + sess->c_name(), (void *) this->base, this->size, this->fd, pinned); - bool mmap() { - if (this->mapped) { - return true; - } - if (!mmap_to(this->sess)) { - return false; - } this->mapped = true; - return true; } - void munmap() { - if (!this->mapped) { - return; + void unmap() { + if (!this->mapped) return; + + if (!this->pinned) { + // HTP might still hold a reference, tell it drop it + htp_iface_munmap(sess->handle, this->fd); } - fastrpc_munmap(this->sess->domain_id, this->fd, this->base, this->size); + fastrpc_munmap(sess->domain_id, this->fd, (void *) this->base, this->size); + + HEX_VERBOSE("ggml-hex: %s unmapped buffer: base %p size %zu fd %d\n", sess->c_name(), + (void *) this->base, size, this->fd); + this->mapped = false; + this->fd = -1; } - ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) { - size += 4 * 1024; // extra page for padding - - if (rpcmem_alloc2) { - this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); - } else { - GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str()); - this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size); - } + void alloc(size_t size) { + if (this->base) return; + this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, size); if (!this->base) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->c_name(), size); throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)"); } this->fd = rpcmem_to_fd(this->base); if (this->fd < 0) { - GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->name.c_str(), (void *) this->base); - rpcmem_free(this->base); - this->base = NULL; + GGML_LOG_ERROR("ggml-hex: %s failed to get FD for buffer %p\n", sess->c_name(), (void *) this->base); throw std::runtime_error("ggml-hex: rpcmem_to_fd failed (see log for details)"); } + this->size = size; + + HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d pinned %d\n", sess->c_name(), + (void *) this->base, this->size, this->fd, (int) pinned); + mmap(); + } + + void free() { + if (!this->base) return; + + unmap(); + rpcmem_free(this->base); + + HEX_VERBOSE("ggml-hex: %s freed buffer: base %p size %zu fd %d\n", sess->c_name(), + (void *) this->base, size, this->fd); - HEX_VERBOSE("ggml-hex: %s allocated buffer: base %p size %zu fd %d repack %d\n", sess->name.c_str(), - (void *) this->base, size, this->fd, (int) repack); + this->base = NULL; + } + ggml_hexagon_shared_buffer(ggml_hexagon_session * sess, size_t size, bool pinned = false) { this->sess = sess; - this->size = size; + this->size = 0; + this->base = nullptr; + this->fd = -1; this->mapped = false; - this->repack = repack; - } + this->pinned = pinned; - ~ggml_backend_hexagon_buffer_context() { - munmap(); - if (this->base) { - rpcmem_free(this->base); - this->base = NULL; - } + alloc(size); } - ggml_hexagon_session * sess; // primary session - uint8_t * base; - size_t size; - int fd; - bool mapped; // mmap is done - bool repack; // repacked buffer + ~ggml_hexagon_shared_buffer() { + free(); + } }; static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_buffer_t buffer) { @@ -319,30 +282,26 @@ static ggml_hexagon_session * ggml_backend_hexagon_buffer_get_sess(ggml_backend_ } static void ggml_backend_hexagon_buffer_free_buffer(ggml_backend_buffer_t buffer) { - auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context); - delete ctx; + auto sbuf = static_cast<ggml_hexagon_shared_buffer *>(buffer->context); + delete sbuf; } static void * ggml_backend_hexagon_buffer_get_base(ggml_backend_buffer_t buffer) { - auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context); - return ctx->base; + auto sbuf = static_cast<ggml_hexagon_shared_buffer *>(buffer->context); + return sbuf->base; } static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - auto ctx = static_cast<ggml_backend_hexagon_buffer_context *>(buffer->context); - auto sess = ctx->sess; + auto sbuf = static_cast<ggml_hexagon_shared_buffer *>(buffer->context); + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d repack %d\n", sess->name.c_str(), - tensor->name, (void *) ctx->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage, - (int) ctx->repack); + HEX_VERBOSE("ggml-hex: %s init-tensor %s : base %p data %p nbytes %zu usage %d\n", sess->c_name(), + tensor->name, (void *) sbuf->base, tensor->data, ggml_nbytes(tensor), (int) buffer->usage); if (tensor->view_src != NULL && tensor->view_offs == 0) { - ; // nothing to do for the view - } else { - if (!ctx->mapped) { - ctx->mmap(); - } + return GGML_STATUS_SUCCESS; // nothing to do for the view } + return GGML_STATUS_SUCCESS; } @@ -412,6 +371,7 @@ static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { static const int qk = QK_Q4_0x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int dblk_size = 8 * 2; // 8x __fp16 const int qblk_size = qk / 2; // int4 @@ -445,15 +405,17 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { unpack_q4_0_quants(qs, &x[i * 8 + 6], 6); unpack_q4_0_quants(qs, &x[i * 8 + 7], 7); + bool partial = (nloe && i == nb-1); + uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - q[j] = (qs[j + 128] << 4) | qs[j]; + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; } } // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Repack the scales ggml_half * d = (ggml_half *) (y_d + i * dblk_size); @@ -467,7 +429,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { d[7] = x[i * 8 + 7].d; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q4x4x2(y, i, k); } @@ -477,6 +439,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { static const int qk = QK_Q4_0x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int dblk_size = 8 * 2; // 8x __fp16 const int qblk_size = qk / 2; // int4 @@ -485,7 +448,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { const uint8_t * y_q = y + 0; // quants first const uint8_t * y_d = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q4x4x2(y, i, k); } @@ -495,10 +458,17 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { for (int i = 0; i < nb; i++) { uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + bool partial = (nloe && i == nb-1); + const uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - qs[j] = q[j] & 0xf; - qs[j + 128] = q[j] >> 4; + if (partial) { + qs[j*2+0] = q[j] & 0xf; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0xf; + qs[j+128] = q[j] >> 4; + } } pack_q4_0_quants(&x[i * 8 + 0], qs, 0); @@ -513,7 +483,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); @@ -562,7 +532,7 @@ static void init_row_q4x4x2(block_q4_0 * x, int64_t k) { // Init the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales x[i * 8 + 0].d = 0; @@ -582,7 +552,7 @@ static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -643,7 +613,7 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -692,6 +662,239 @@ static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) ggml_aligned_free(buf_rp, row_size_rp); } +static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); + qs[bi * qk + i + 0] = x0; + qs[bi * qk + i + qk / 2] = x1; + } +} + +static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi) { + static const int qk = QK4_1; + + for (unsigned int i = 0; i < qk / 2; ++i) { + const uint8_t x0 = qs[bi * qk + i + 0]; + const uint8_t x1 = qs[bi * qk + i + qk / 2]; + x->qs[i] = x0 | (x1 << 4); + } +} + +static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes + const int qblk_size = qk / 2; // int4 = 128 bytes + const int qrow_size = k / 2; // int4 (not padded to blocks) + + uint8_t * y_q = y + 0; // quants first + uint8_t * y_d = y + qrow_size; // then scales/offsets + + // Repack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + unpack_q4_1_quants(qs, &x[i * 8 + 0], 0); + unpack_q4_1_quants(qs, &x[i * 8 + 1], 1); + unpack_q4_1_quants(qs, &x[i * 8 + 2], 2); + unpack_q4_1_quants(qs, &x[i * 8 + 3], 3); + unpack_q4_1_quants(qs, &x[i * 8 + 4], 4); + unpack_q4_1_quants(qs, &x[i * 8 + 5], 5); + unpack_q4_1_quants(qs, &x[i * 8 + 6], 6); + unpack_q4_1_quants(qs, &x[i * 8 + 7], 7); + + bool partial = (nloe && i == nb-1); + + uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; + } + } + + // Repack the scales and offsets + for (int i = 0; i < nb; i++) { + ggml_half * d_m = (ggml_half *) (y_d + i * dblk_size); + for (int j = 0; j < 8; j++) { + d_m[j * 2 + 0] = x[i * 8 + j].d; + d_m[j * 2 + 1] = x[i * 8 + j].m; + } + } +} + +static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers + + const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes + const int qblk_size = qk / 2; // int4 = 128 bytes + const int qrow_size = k / 2; // int4 (not padded to blocks) + + const uint8_t * y_q = y + 0; // quants first + const uint8_t * y_d = y + qrow_size; // then scales/offsets + + // Unpack the quants + for (int i = 0; i < nb; i++) { + uint8_t qs[QK_Q4_0x4x2]; + bool partial = (nloe && i == nb-1); + + const uint8_t * q = y_q + (i * qblk_size); + for (int j = 0; j < qk / 2; j++) { + if (partial) { + qs[j*2+0] = q[j] & 0x0F; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0x0F; + qs[j+128] = q[j] >> 4; + } + } + + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + } + + // Unpack the scales and offsets + for (int i = 0; i < nb; i++) { + const ggml_half * d_m = (const ggml_half *) (y_d + i * dblk_size); + for (int j = 0; j < 8; j++) { + x[i * 8 + j].d = d_m[j * 2 + 0]; + x[i * 8 + j].m = d_m[j * 2 + 1]; + } + } +} + +static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { + static const int qk = QK_Q4_0x4x2; + const int nb = (k + qk - 1) / qk; // number of blocks (padded) + + uint8_t qs[QK_Q4_0x4x2]; // unpacked quants + memset(qs, 0, sizeof(qs)); + + for (int i = 0; i < nb; i++) { + pack_q4_1_quants(&x[i * 8 + 0], qs, 0); + pack_q4_1_quants(&x[i * 8 + 1], qs, 1); + pack_q4_1_quants(&x[i * 8 + 2], qs, 2); + pack_q4_1_quants(&x[i * 8 + 3], qs, 3); + pack_q4_1_quants(&x[i * 8 + 4], qs, 4); + pack_q4_1_quants(&x[i * 8 + 5], qs, 5); + pack_q4_1_quants(&x[i * 8 + 6], qs, 6); + pack_q4_1_quants(&x[i * 8 + 7], qs, 7); + } + + for (int i = 0; i < nb; i++) { + for (int j = 0; j < 8; j++) { + x[i * 8 + j].d = 0; + x[i * 8 + j].m = 0; + } + } +} + +static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + memcpy(buf_pd, src, row_size); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) data + (i * row_size); + uint8_t * dst = (uint8_t *) t->data + (i * row_size); + + init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); + memcpy(buf_pd, src, n_rem_bytes); + repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); + memcpy(dst, buf_rp, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + +static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { + int64_t nrows = ggml_nrows(t); + + size_t row_size = ggml_row_size(t->type, t->ne[0]); + size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + + const size_t total_tensor_size = (size_t)nrows * row_size; + const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + + const int64_t n_full_rows = n_bytes_to_copy / row_size; + const size_t n_rem_bytes = n_bytes_to_copy % row_size; + + void * buf_pd = ggml_aligned_malloc(row_size_pd); + GGML_ASSERT(buf_pd != NULL); + + void * buf_rp = ggml_aligned_malloc(row_size_rp); + GGML_ASSERT(buf_rp != NULL); + + HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, + t->ne[0], nrows, row_size); + + memset(buf_rp, 0, row_size_rp); // clear-out padded buffer to make sure the tail is all zeros + + for (int64_t i = 0; i < n_full_rows; i++) { + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + memcpy(buf_rp, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, row_size); + } + + if (n_rem_bytes > 0) { + const int64_t i = n_full_rows; + const uint8_t * src = (const uint8_t *) t->data + (i * row_size); + uint8_t * dst = (uint8_t *) data + (i * row_size); + + // We still need to read and unpack the entire source row because quantization is block-based. + memcpy(buf_rp, src, row_size); + unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); + memcpy(dst, buf_pd, n_rem_bytes); + } + + ggml_aligned_free(buf_pd, row_size_pd); + ggml_aligned_free(buf_rp, row_size_rp); +} + // ======== Q8x4x2 ==================== static void dump_block_q8_0(const block_q8_0 * b, int i) { HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2], @@ -780,7 +983,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Repack the scales ggml_half * d = (ggml_half *) (y_d + i * dblk_size); @@ -794,7 +997,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { d[7] = x[i * 8 + 7].d; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q8x4x2(y, i, k); } @@ -812,7 +1015,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { const uint8_t * y_q = y + 0; // quants first const uint8_t * y_d = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_q8x4x2(y, i, k); } @@ -839,7 +1042,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); @@ -888,7 +1091,7 @@ static void init_row_q8x4x2(block_q8_0 * x, int64_t k) { // Init the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales x[i * 8 + 0].d = 0; @@ -908,7 +1111,7 @@ static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -969,7 +1172,7 @@ static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -1088,6 +1291,7 @@ static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) { static const int qk = QK_MXFP4x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int eblk_size = 8 * 1; // 8x E8M0 const int qblk_size = qk / 2; // int4 @@ -1122,15 +1326,17 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6); unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7); + bool partial = (nloe && i == nb-1); + uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - q[j] = (qs[j + 128] << 4) | qs[j]; + q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; } } // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Repack the scales uint8_t * e = (uint8_t *) (y_e + i * eblk_size); @@ -1144,7 +1350,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) e[7] = x[i * 8 + 7].e; } - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_mxfp4x4x2(y, i, k); } @@ -1154,6 +1360,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) { static const int qk = QK_MXFP4x4x2; const int nb = (k + qk - 1) / qk; // number of blocks (padded) + const int nloe = k % qk; // leftovers const int eblk_size = 8 * 1; // 8x E8M0 const int qblk_size = qk / 2; // int4 @@ -1162,7 +1369,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) const uint8_t * y_q = y + 0; // quants first const uint8_t * y_e = y + qrow_size; // then scales - if (opt_verbose > 1) { + if (opt_verbose > 2) { for (int i = 0; i < nb; i++) { dump_packed_block_mxfp4x4x2(y, i, k); } @@ -1172,10 +1379,17 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) for (int i = 0; i < nb; i++) { uint8_t qs[QK_MXFP4x4x2]; // unpacked quants + bool partial = (nloe && i == nb-1); + const uint8_t * q = y_q + (i * qblk_size); for (int j = 0; j < qk / 2; j++) { - qs[j] = q[j] & 0xf; - qs[j + 128] = q[j] >> 4; + if (partial) { + qs[j*2+0] = q[j] & 0xf; + qs[j*2+1] = q[j] >> 4; + } else { + qs[j+000] = q[j] & 0xf; + qs[j+128] = q[j] >> 4; + } } pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); @@ -1190,7 +1404,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size); @@ -1239,7 +1453,7 @@ static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) { // Init the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales x[i * 8 + 0].e = 0; @@ -1259,7 +1473,7 @@ static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t si size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to read more data than is available in the source buffer 'data' // or write more than the tensor can hold. @@ -1320,7 +1534,7 @@ static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t si size_t row_size = ggml_row_size(t->type, t->ne[0]); size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size * 2; // extra space for tmp pad (if any) + size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) // Ensure we don't try to copy more data than the tensor actually contains. const size_t total_tensor_size = (size_t)nrows * row_size; @@ -1374,11 +1588,10 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, const void * data, size_t offset, size_t size) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data, - offset, size); + HEX_VERBOSE("ggml-hex: %s set-tensor %s : data %p offset %zu size %zu\n", sess->c_name(), tensor->name, data, offset, size); switch (tensor->type) { case GGML_TYPE_Q4_0: @@ -1387,10 +1600,23 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, repack_q4_0_q4x4x2(tensor, data, size); break; - case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_1: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q8_0_q8x4x2(tensor, data, size); + repack_q4_1_q4x4x2(tensor, data, size); + break; + + case GGML_TYPE_Q8_0: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q8_0_q8x4x2(tensor, data, size); + break; + + case GGML_TYPE_IQ4_NL: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + // IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16]) + repack_q4_0_q4x4x2(tensor, data, size); break; case GGML_TYPE_MXFP4: @@ -1410,11 +1636,10 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, void * data, size_t offset, size_t size) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; - HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->name.c_str(), tensor->name, data, - offset, size); + HEX_VERBOSE("ggml-hex: %s get-tensor %s : data %p offset %zu size %zu\n", sess->c_name(), tensor->name, data, offset, size); switch (tensor->type) { case GGML_TYPE_Q4_0: @@ -1423,12 +1648,24 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, repack_q4x4x2_q4_0(data, tensor, size); break; + case GGML_TYPE_Q4_1: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_1(data, tensor, size); + break; + case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); repack_q8x4x2_q8_0(data, tensor, size); break; + case GGML_TYPE_IQ4_NL: + GGML_ASSERT(offset == 0); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); + repack_q4x4x2_q4_0(data, tensor, size); + break; + case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); @@ -1452,10 +1689,10 @@ static bool ggml_backend_hexagon_buffer_cpy_tensor(ggml_backend_buffer_t bu } static void ggml_backend_hexagon_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - auto ctx = (ggml_backend_hexagon_buffer_context *) buffer->context; - auto sess = ctx->sess; - HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->name.c_str(), (void *) ctx->base, ctx->size); - memset(ctx->base, value, ctx->size); + auto sbuf = (ggml_hexagon_shared_buffer *) buffer->context; + auto sess = sbuf->sess; + HEX_VERBOSE("ggml-hex: %s clear-buff base %p size %zu\n", sess->c_name(), (void *) sbuf->base, sbuf->size); + memset(sbuf->base, value, sbuf->size); } static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = { @@ -1465,6 +1702,8 @@ static ggml_backend_buffer_i ggml_backend_hexagon_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_hexagon_buffer_set_tensor, /* .get_tensor = */ ggml_backend_hexagon_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_hexagon_buffer_cpy_tensor, /* .clear = */ ggml_backend_hexagon_buffer_clear, /* .reset = */ NULL, @@ -1480,10 +1719,11 @@ static ggml_backend_buffer_t ggml_backend_hexagon_buffer_type_alloc_buffer( ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess; try { - ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, false /*repack*/); - return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); + size += 4 * 1024; // guard page + ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); + return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context (host): %s\n", sess->c_name(), exc.what()); return nullptr; } } @@ -1492,10 +1732,11 @@ static ggml_backend_buffer_t ggml_backend_hexagon_repack_buffer_type_alloc_buffe ggml_backend_buffer_type_t buffer_type, size_t size) { auto sess = static_cast<ggml_backend_hexagon_buffer_type_context *>(buffer_type->context)->sess; try { - ggml_backend_hexagon_buffer_context * ctx = new ggml_backend_hexagon_buffer_context(sess, size, true /*repack*/); - return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, ctx, size); + size += 4 * 1024; // guard page + ggml_hexagon_shared_buffer * sbuf = new ggml_hexagon_shared_buffer(sess, size); + return ggml_backend_buffer_init(buffer_type, ggml_backend_hexagon_buffer_interface, sbuf, size); } catch (const std::exception & exc) { - GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context: %s\n", sess->name.c_str(), exc.what()); + GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer context (repack): %s\n", sess->c_name(), exc.what()); return nullptr; } } @@ -1510,7 +1751,7 @@ static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffe } static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - return 1 * 1024 * 1024 * 1024; // 1GB per buffer + return opt_mbuf; // typically 1GB per buffer GGML_UNUSED(buffer_type); } @@ -1542,6 +1783,448 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, }; +struct ggml_hexagon_opbatch { + ggml_hexagon_session* sess; + + std::vector<htp_opnode> ops; // htp_opnode of ops + + std::vector<htp_buf_desc> h_bufs; // htp buffer descriptors + std::vector<htp_tensor> h_tens; // htp tensor descriptors + std::vector<htp_op_desc> h_ops; // htp op descriptors + + std::unordered_map<int, int> b_map; // buffer fd to index + std::unordered_map<const ggml_tensor*, int> t_map; // tensor ptr to index + std::unordered_multimap<void*, int> d_map; // tensor data to index + + unsigned int n_bufs; // num buffers in the batch + unsigned int n_tens; // num tensors ... + unsigned int n_ops; // num ops ... + size_t b_vmem; // sum of all buffer sizes + + unsigned int n_bufs_max; + unsigned int n_tens_max; + unsigned int n_ops_max; + size_t b_vmem_max; + + void reset() { + n_bufs = 0; + n_tens = 0; + n_ops = 0; + b_vmem = 0; + + b_map.clear(); + t_map.clear(); + d_map.clear(); + } + + ggml_hexagon_opbatch(ggml_hexagon_session *sess, size_t batch_size, size_t max_vmem) { + this->sess = sess; + + n_bufs_max = HTP_OP_MAX_BUFS; + n_ops_max = batch_size; + n_tens_max = n_ops_max + n_ops_max * HTP_OP_MAX_INPUTS; + + b_vmem_max = max_vmem; + + ops.resize(n_ops_max); + + h_bufs.resize(n_bufs_max); + h_tens.resize(n_tens_max); + h_ops.resize(n_ops_max); + + b_map.reserve(n_bufs_max); + t_map.reserve(n_tens_max); + d_map.reserve(n_tens_max); + + GGML_LOG_INFO("ggml-hex: %s op batching: n-bufs %u n-tensors %u n-ops %u vmem %zu\n", + sess->c_name(), n_bufs_max, n_tens_max, n_ops_max, b_vmem_max); + + reset(); + } + + bool empty() const { return n_ops == 0; } + + // add buffer and return its index + int add_buffer(ggml_hexagon_shared_buffer * sbuf) { + // Lookup by fd + auto it = b_map.find(sbuf->fd); + if (it != b_map.end()) { return it->second; } + + // Add new buffer to the batch + int bi = n_bufs++; + GGML_ASSERT(n_bufs < HTP_OP_MAX_BUFS); + + b_map.insert({sbuf->fd, bi}); + + htp_buf_desc &b = h_bufs[bi]; + b.base = (uint64_t) sbuf->base; + b.fd = sbuf->fd; + b.size = sbuf->size; + + b_vmem += b.size; + + HEX_VERBOSE("ggml-hex: add-buffer #%u : fd %d base %p size %zu : vmem %zu\n", bi, b.fd, (void*) sbuf->base, (size_t) b.size, b_vmem); + + return bi; + } + + bool same_shape(const htp_tensor * h, const ggml_tensor * t) const { + return (h->ne[0] == t->ne[0]) && (h->ne[1] == t->ne[1]) && (h->ne[2] == t->ne[2]) && (h->ne[3] == t->ne[3]) && + (h->nb[0] == t->nb[0]) && (h->nb[1] == t->nb[1]) && (h->nb[2] == t->nb[2]) && (h->nb[3] == t->nb[3]); + } + + // add tensor and return its index + int add_tensor(const ggml_tensor * t) { + auto sbuf = static_cast<ggml_hexagon_shared_buffer *>(t->buffer->context); + + // First lookup by tensor data + auto range = d_map.equal_range(t->data); + for (auto it = range.first; it != range.second; ++it) { + htp_tensor * h = &h_tens[it->second]; + if (same_shape(h, t)) { return it->second; } + } + + // Lookup by tensor ptr + auto it = t_map.find(t); + if (it != t_map.end()) { return it->second; } + + // Add new tensor to the batch + int ti = n_tens++; + GGML_ASSERT(n_tens <= n_tens_max); + + t_map.insert({t, ti}); + d_map.insert({t->data, ti}); + + uint64_t t_offset = (uint8_t *) t->data - sbuf->base; + size_t t_size = ggml_nbytes(t); + + htp_tensor &h = h_tens[ti]; + h.bi = add_buffer(sbuf); + h.data = t_offset; + h.size = t_size; + h.type = t->type; + h.ne[0] = t->ne[0]; h.ne[1] = t->ne[1]; h.ne[2] = t->ne[2]; h.ne[3] = t->ne[3]; + h.nb[0] = t->nb[0]; h.nb[1] = t->nb[1]; h.nb[2] = t->nb[2]; h.nb[3] = t->nb[3]; + + h.flags = 0; + if (ggml_backend_buffer_get_usage(t->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { + h.flags |= HTP_TENSOR_COMPUTE; + } + + HEX_VERBOSE("ggml-hex: add-tensor #%u %s : bi %d data %p offset %zu size %zu flags 0x%x : %zu:%zu:%zu:%zu\n", + ti, t->name, h.bi, (void*) t->data, (size_t) t_offset, t_size, h.flags, + (size_t) t->ne[0], (size_t) t->ne[1], (size_t) t->ne[2], (size_t) t->ne[3]); + + return ti; + } + + bool fit_op(const htp_opnode & node) const { + if (n_ops >= n_ops_max ) return false; + + // check how much extras we will need + size_t extra_bufs = 0; + size_t extra_vmem = 0; + size_t extra_tens = 0; + + auto fit_tensor = [&](const ggml_tensor *t) { + if (!t) return; + if (!t_map.count(t)) { + extra_tens++; + + auto sbuf = static_cast<ggml_hexagon_shared_buffer *>(t->buffer->context); + if (!b_map.count(sbuf->fd)) { + extra_vmem += sbuf->size; + extra_bufs += 1; + } + } + }; + + for (const auto * src : node.get_inputs()) { + fit_tensor(src); + } + fit_tensor(node.dst()); + + if ((extra_bufs + n_bufs) > n_bufs_max) return false; + if ((extra_tens + n_tens) > n_tens_max) return false; + if ((extra_vmem + b_vmem) > b_vmem_max) return false; + + return true; + } + + // assumes that fit_op() was called first and returned true + void add_op(const htp_opnode & node) { + // Add new op + + unsigned int n = n_ops++; + GGML_ASSERT(n_ops <= n_ops_max); + + ops[n] = node; + + htp_op_desc &o = h_ops[n]; + memcpy(&o.params, &node.node->op_params, sizeof(node.node->op_params)); + o.opcode = node.opcode; + o.flags = 0; + + if (!(opt_opstage & HTP_OPSTAGE_COMPUTE)) { + o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + } + + ggml_hexagon_dump_op_exec(sess->c_name(), node, o.flags); + + auto inputs = node.get_inputs(); + for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { + o.src[i] = (i < inputs.size() && inputs[i]) ? add_tensor(inputs[i]) : 0xffff; + } + o.dst = add_tensor(node.dst()); + } +}; + +struct ggml_hexagon_opqueue { + // Shared buffer for storing batches + ggml_hexagon_shared_buffer *shm_buf; + size_t shm_blk_size; + + using opvec = std::vector<htp_opnode>; + + std::queue<unsigned int> done; // completed batch ids + std::vector<opvec> op_cache; // per batch op cache + std::vector<uint64_t> start_usec; // per batch start time + + ggml_hexagon_opqueue(ggml_hexagon_session *sess, size_t batch_size, size_t depth) { + size_t n_bufs = HTP_OP_MAX_BUFS; + size_t n_ops = batch_size; + size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; + + shm_blk_size = sizeof(htp_buf_desc) * n_bufs + + sizeof(htp_tensor) * n_tensors + + sizeof(htp_op_desc) * n_ops + + sizeof(htp_prof_desc) * n_ops; + + shm_buf = new ggml_hexagon_shared_buffer(sess, shm_blk_size * depth, true /* pinned */); + + op_cache.resize(depth); + start_usec.resize(depth, 0); + + // init done queue + for (unsigned int i = 0; i < depth; i++) { done.push(i); } + + if (opt_verbose) { + GGML_LOG_INFO("ggml-hex: %s allocated op-queue : batch-size %zu depth %zu shm-size %zu shm-block-size %zu\n", + sess->c_name(), batch_size, depth, shm_buf->size, shm_blk_size); + } + } + + ~ggml_hexagon_opqueue() { + delete shm_buf; + } + + // push new batch + bool push(htp_opbatch_req& req, dspqueue_buffer& dbuf, ggml_hexagon_opbatch* op_batch) { + static_assert(sizeof(htp_opbatch_req) % 8 == 0, "sizeof(htp_opbatch_req) must be multiple of 8"); + static_assert(sizeof(htp_opbatch_rsp) % 8 == 0, "sizeof(htp_opbatch_rsp) must be multiple of 8"); + static_assert(sizeof(htp_buf_desc) % 8 == 0, "sizeof(htp_buf_desc) must be multiple of 8"); + static_assert(sizeof(htp_tensor) % 8 == 0, "sizeof(htp_tensor) must be multiple of 8"); + static_assert(sizeof(htp_op_desc) % 8 == 0, "sizeof(htp_op_desc) must be multiple of 8"); + static_assert(sizeof(htp_prof_desc) % 8 == 0, "sizeof(htp_prof_desc) must be multiple of 8"); + + if (done.empty()) { return false; } + + req.id = done.front(); done.pop(); // batch id + req.n_bufs = op_batch->n_bufs; + req.n_tensors = op_batch->n_tens; + req.n_ops = op_batch->n_ops; + + op_cache[req.id] = op_batch->ops; + start_usec[req.id] = ggml_time_us(); + + const size_t b_size = sizeof(htp_buf_desc) * req.n_bufs; + const size_t t_size = sizeof(htp_tensor) * req.n_tensors; + const size_t o_size = sizeof(htp_op_desc) * req.n_ops; + const size_t p_size = sizeof(htp_prof_desc) * req.n_ops; + + dbuf.ptr = shm_buf->base + (req.id * shm_blk_size); + dbuf.fd = shm_buf->fd; + dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; + dbuf.offset = (uint8_t*) dbuf.ptr - (uint8_t*) shm_buf->base; + dbuf.size = b_size + t_size + o_size + p_size; + + GGML_ASSERT(dbuf.size <= shm_blk_size); + + uint8_t * m_ptr = (uint8_t*) dbuf.ptr; + uint8_t * b_ptr = m_ptr; m_ptr += b_size; + uint8_t * t_ptr = m_ptr; m_ptr += t_size; + uint8_t * o_ptr = m_ptr; + + memcpy(b_ptr, (void *) op_batch->h_bufs.data(), b_size); + memcpy(t_ptr, (void *) op_batch->h_tens.data(), t_size); + memcpy(o_ptr, (void *) op_batch->h_ops.data(), o_size); + + HEX_VERBOSE("ggml-hex: %s op-queue push batch #%u : n-bufs %u n-tensors %u n-ops %u vmem %zu : b-size %zu t-size %zu o-size %zu m-size %zu\n", + shm_buf->sess->c_name(), req.id, req.n_bufs, req.n_tensors, req.n_ops, op_batch->b_vmem, + b_size, t_size, o_size, (size_t) dbuf.size); + + op_batch->reset(); + + if (opt_verbose > 1) { + htp_buf_desc *b = (htp_buf_desc*) b_ptr; + for (unsigned int i=0; i < req.n_bufs; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-buf #%u : fd %d base %p size %zu\n", shm_buf->sess->c_name(), i, + b[i].fd, (void *) b[i].base, (size_t) b[i].size); + } + htp_tensor *t = (htp_tensor*) t_ptr; + for (unsigned int i=0; i < req.n_tensors; i++) { + GGML_LOG_DEBUG("ggml-hex: %s htp-tensor #%u : bi %u offset %u size %u : %zu:%zu:%zu:%zu\n", + shm_buf->sess->c_name(), i, t[i].bi, t[i].data, t[i].size, + (size_t) t[i].ne[0], (size_t) t[i].ne[1], (size_t) t[i].ne[2], (size_t) t[i].ne[3]); + } + } + + return true; + } + + void pop(htp_opbatch_rsp rsp, dspqueue_buffer dbuf) { + GGML_ASSERT(rsp.id < op_cache.size()); + + done.push(rsp.id); + + const size_t b_size = sizeof(htp_buf_desc) * rsp.n_bufs; + const size_t t_size = sizeof(htp_tensor) * rsp.n_tensors; + const size_t o_size = sizeof(htp_op_desc) * rsp.n_ops; + const size_t p_size = sizeof(htp_prof_desc) * rsp.n_ops; + + const size_t m_size = b_size + t_size + o_size + p_size; + GGML_ASSERT(m_size <= shm_blk_size); + + HEX_VERBOSE("ggml-hex: %s op-queue pop batch #%u : n-bufs %u n-tensors %u n-ops %u : m-size %zu b-size %zu t-size %zu o-size %zu\n", + shm_buf->sess->c_name(), rsp.id, rsp.n_bufs, rsp.n_tensors, rsp.n_ops, + (size_t) dbuf.size, b_size, t_size, o_size); + + uint8_t * m_ptr = (uint8_t*) dbuf.ptr; + uint8_t * p_ptr = m_ptr + (b_size + t_size + o_size); + + if (opt_profile && rsp.n_ops > 0) { + auto & ops = op_cache[rsp.id]; + + uint64_t batch_usec = ggml_time_us() - start_usec[rsp.id]; + uint32_t htp_usec = 0; + + GGML_ASSERT(rsp.n_ops <= ops.size()); + + const htp_prof_desc * pd = (const htp_prof_desc *) p_ptr; + for (uint32_t i = 0; i < rsp.n_ops; i++) { + htp_usec += pd[i].usecs; + ggml_hexagon_dump_op_prof(shm_buf->sess->name, ops[i], pd[i].usecs, pd[i].cycles, pd[i].pmu); + } + + GGML_LOG_DEBUG("ggml-hex: %s profile-batch n-ops %u batch-dur-usec %lld htp-ops-usec %u\n", + shm_buf->sess->c_name(), rsp.n_ops, (long long) batch_usec, htp_usec); + } + } +}; + +// Flush HTP response queue i.e wait for all outstanding requests to complete +void ggml_hexagon_session::flush_pending(bool all) { + while (this->op_pending) { + struct htp_opbatch_rsp rsp; + uint32_t rsp_size; + uint32_t flags; + + struct dspqueue_buffer dbuf; + uint32_t n_dbufs; + + // Read response packet from queue + const uint32_t timeo = opt_oppoll ? 0 : DSPQUEUE_TIMEOUT; + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, timeo); + if (err == AEE_EEXPIRED) { + continue; + } + + if (err != 0) { + GGML_ABORT("ggml-hex: dspqueue_read failed: 0x%08x\n", (unsigned) err); + } + + // Basic sanity checks + if (rsp_size != sizeof(rsp) || n_dbufs != 1) { + GGML_ABORT("ggml-hex: %s dspcall : bad response : size %u dspbufs %u\n", this->c_name(), rsp_size, n_dbufs); + } + + if (rsp.status != HTP_STATUS_OK) { + GGML_LOG_ERROR("ggml-hex: %s dspcall : dsp-rsp: %s\n", this->c_name(), status_to_str(rsp.status)); + // TODO: handle errors + } + + op_queue->pop(rsp, dbuf); + + this->op_pending--; // atomic dec + + if (!all) break; + } +} + +void ggml_hexagon_session::flush_batch() { + if (op_batch->empty()) { return; } + + htp_opbatch_req req {}; + dspqueue_buffer dbuf{}; + + if (!op_queue->push(req, dbuf, op_batch)) { + flush_pending(false); + op_queue->push(req, dbuf, op_batch); + } + + // Bump pending flag (cleared in the session::flush once we get the response) + this->op_pending++; // atomic inc + + HEX_VERBOSE("ggml-hex: %s queue-opbatch: %p size %u\n", this->c_name(), dbuf.ptr, dbuf.size); + + int err = dspqueue_write(this->queue, 0, 1, &dbuf, sizeof(req), (const uint8_t*) &req, DSPQUEUE_TIMEOUT); + if (err != 0) { + GGML_ABORT("ggml-hex: %s dspqueue_write failed: 0x%08x\n", this->c_name(), (unsigned) err); + } +} + +void ggml_hexagon_session::enqueue_op(const htp_opnode & node) { + if (!op_batch->fit_op(node)) { + flush_batch(); + } + op_batch->add_op(node); +} + +// Flush HTP response queue i.e wait for all outstanding requests to complete +void ggml_hexagon_session::flush(bool all) { + flush_batch(); + flush_pending(all); +} + +static size_t ggml_hexagon_measure_max_vmem(ggml_hexagon_session *sess) { + // Allocate a bunch pinned buffers till failure. + // This is kind of expensive but handy for figuring out exactly how much we can mmap on a specific device. + // Typically we're going to allocate all/most of these buffers anyway for the model weights. + + std::vector<ggml_hexagon_shared_buffer *> sbufs; + + const size_t MiB = 1024 * 1024; + const size_t GiB = MiB * 1024; + + size_t vmem = 0; + size_t step = 256u * MiB; + + try { + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, GiB, true)); vmem += GiB; + + while (1) { + sbufs.push_back(new ggml_hexagon_shared_buffer(sess, step, true)); + vmem += step; + } + } catch (...) { } + + for (auto b : sbufs) { delete b; } + + return vmem - step; // backoff to account for overhead from internal mappings +} + void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_session = false; this->valid_handle = false; @@ -1554,11 +2237,8 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->name = std::string("HTP") + std::to_string(dev_id); this->op_pending = 0; - this->prof_usecs = 0; - this->prof_cycles = 0; - this->prof_pkts = 0; - GGML_LOG_INFO("ggml-hex: allocating new session: %s\n", this->name.c_str()); + GGML_LOG_DEBUG("ggml-hex: %s allocating new session\n", this->name.c_str()); domain * my_domain = get_domain(this->domain_id); if (my_domain == NULL) { @@ -1634,9 +2314,6 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_handle = true; - GGML_LOG_INFO("ggml-hex: new session: %s : session-id %d domain-id %d uri %s handle 0x%lx\n", this->name.c_str(), - this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); - // Enable FastRPC QoS mode { struct remote_rpc_control_latency l; @@ -1648,11 +2325,17 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } } + GGML_LOG_INFO("ggml-hex: %s new session : session-id %d domain-id %d uri %s handle 0x%lx\n", this->c_name(), + this->session_id, this->domain_id, session_uri, (unsigned long) this->handle); + + const size_t req_q_size = (sizeof(htp_opbatch_req) * opt_opqueue * 2) + 1024; + const size_t rsp_q_size = (sizeof(htp_opbatch_rsp) * opt_opqueue * 2) + 1024; + // Now let's setup the DSP queue err = dspqueue_create(this->domain_id, 0, // Flags - 128 * 1024, // Request queue size (in bytes) - 64 * 1024, // Response queue size (in bytes) + req_q_size, // Request queue size (in bytes) + rsp_q_size, // Response queue size (in bytes) nullptr, // Read packet callback (we handle reads explicitly) nullptr, // Error callback (we handle errors during reads) (void *) this, // Callback context @@ -1672,18 +2355,36 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { } if (opt_etm) { - err = htp_iface_enable_etm(this->handle); + err = htp_iface_etm(this->handle, 1); if (err != 0) { GGML_LOG_ERROR("ggml-hex: failed to enable ETM tracing: 0x%08x\n", (unsigned) err); } } - // Start the DSP-side service. We need to pass the queue ID to the - // DSP in a FastRPC call; the DSP side will import the queue and start - // listening for packets in a callback. - err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx); + if (opt_profile) { + htp_iface_pmu_conf pmu_conf{}; + std::copy(opt_pmu_evt.begin(), opt_pmu_evt.end(), pmu_conf.events); + + err = htp_iface_profiler(this->handle, opt_profile, &pmu_conf); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to enable profiling: 0x%08x\n", (unsigned) err); + } + } + + // Allocate buffers and state for op batching + this->op_queue = new ggml_hexagon_opqueue(this, opt_opbatch, opt_opqueue); + + if (!opt_vmem) { + opt_vmem = ggml_hexagon_measure_max_vmem(this); + GGML_LOG_INFO("ggml-hex: %s measured max vmem %zu\n", this->c_name(), opt_vmem); + } + + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, opt_vmem); + + // Start dspqueue/opbatch processing + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx, opt_vmem); if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to start session: 0x%08x\n", (unsigned) err); + GGML_LOG_ERROR("ggml-hex: %s failed to start session: 0x%08x\n", this->c_name(), (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); } this->valid_iface = true; @@ -1694,21 +2395,32 @@ void ggml_hexagon_session::release() noexcept(true) { int err; - // Stop the DSP-side service and close the queue if (this->valid_iface) { + // Stop dspqueue/opbatch processing err = htp_iface_stop(this->handle); if (err != 0) { GGML_ABORT("ggml-hex: htp_iface_stop failed: 0x%08x\n", (unsigned) err); } } + delete this->op_batch; + delete this->op_queue; + if (opt_etm) { - err = htp_iface_disable_etm(this->handle); + err = htp_iface_etm(this->handle, 0); if (err != 0) { GGML_LOG_ERROR("ggml-hex: warn : failed to disable ETM tracing: 0x%08x\n", (unsigned) err); } } + if (opt_profile) { + htp_iface_pmu_conf pmu_conf{}; + err = htp_iface_profiler(this->handle, 0, &pmu_conf); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: warn : failed to disable profiling: 0x%08x\n", (unsigned) err); + } + } + if (this->valid_queue) { err = dspqueue_close(queue); if (err != 0) { @@ -1725,6 +2437,9 @@ ggml_hexagon_session::ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) n buffer_type.device = dev; repack_buffer_type.device = dev; + op_batch = nullptr; + op_queue = nullptr; + try { allocate(dev_id); @@ -1753,24 +2468,10 @@ static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) } static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) { - return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; -} - -static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) { - if (x->ne[0] != y->ne[0]) { - return false; - } - if (x->ne[1] != y->ne[1]) { - return false; - } - if (x->ne[2] != y->ne[2]) { - return false; + if (!opt_hostbuf) { + return ggml_backend_buffer_is_hexagon(b); } - if (x->ne[3] != y->ne[3]) { - return false; - } - - return true; + return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; } static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { @@ -1801,44 +2502,64 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return false; } - return opt_experimental; -} + if (dst->ne[3] != 1) { + return false; + } -static bool hex_supported_src0_type(ggml_type t) { - return t == GGML_TYPE_F32; + return true; } -static bool hex_supported_src1_type(ggml_type t) { - return t == GGML_TYPE_F32; -} +static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * q = op->src[0]; + const struct ggml_tensor * k = op->src[1]; + const struct ggml_tensor * v = op->src[2]; + const struct ggml_tensor * g = op->src[3]; + const struct ggml_tensor * beta = op->src[4]; + const struct ggml_tensor * state = op->src[5]; + const struct ggml_tensor * dst = op; -static bool hex_supported_src2_type(ggml_type t) { - return t == GGML_TYPE_F32; -} + if (!q || !k || !v || !g || !beta || !state) { + return false; + } -static bool hex_supported_src1_type2(ggml_type t) { - return t == GGML_TYPE_F16; -} + if (q->type != GGML_TYPE_F32 || k->type != GGML_TYPE_F32 || v->type != GGML_TYPE_F32 || + g->type != GGML_TYPE_F32 || beta->type != GGML_TYPE_F32 || state->type != GGML_TYPE_F32 || + dst->type != GGML_TYPE_F32) { + return false; + } -static bool hex_supported_src1_type3(ggml_type t) { - return t == GGML_TYPE_I32; -} + if (!ggml_is_contiguous_rows(q) || !ggml_is_contiguous_rows(k) || !ggml_is_contiguous_rows(v) || + !ggml_is_contiguous(g) || !ggml_is_contiguous(beta) || !ggml_is_contiguous(state) || + !ggml_is_contiguous(dst)) { + return false; + } -static bool hex_supported_dst_type(ggml_type t) { - return t == GGML_TYPE_F32; -} + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + const int64_t K = ggml_get_op_params_i32(op, 0); -static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) { - // TODO: support broadcast for ne[2 and 3] - if (x->ne[0] != y->ne[0]) { + if (S_v <= 0 || S_v > 128 || H <= 0 || n_tokens <= 0 || n_seqs <= 0) { + return false; + } + if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] <= 0 || k->ne[1] <= 0 || + q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] <= 0 || k->ne[3] <= 0 || + (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { return false; } - if (x->ne[2] != y->ne[2]) { + if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { return false; } - if (x->ne[3] != y->ne[3]) { + // state holds s0 only [S_v, S_v, H, n_seqs]; K is op param 0. + if (ggml_nelements(state) != S_v * S_v * H * n_seqs) { return false; } + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { + return false; + } + + GGML_UNUSED(sess); return true; } @@ -1856,18 +2577,20 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: if (src0->ne[0] % 32) { return false; } - if (src0->ne[1] > 16 * 1024) { + if (ggml_nrows(src0) > 16 * 1024) { return false; // typically the lm-head which would be too large for VTCM } - if ((src1->ne[2] != 1 || src1->ne[3] != 1)) { - return false; + if (ggml_nrows(src1) > 1024 || src1->ne[2] != 1 || src1->ne[3] != 1) { + return false; // no huge batches or broadcasting (for now) } // src0 (weights) must be repacked @@ -1881,6 +2604,30 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } + if (ggml_nrows(src1) > 1024) { + return false; // no huge batches (for now) + } + break; + + case GGML_TYPE_F32: + if (src1->type != GGML_TYPE_F32) { + return false; + } + if (src0->nb[1] < src0->nb[0]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F32 src0 not supported\n"); + return false; + } + if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { + GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); + return false; + } + if (ggml_nrows(src1) > 1024) { + return false; // no huge batches (for now) + } break; default: @@ -1902,7 +2649,9 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: if ((src0->ne[0] % 32)) { return false; @@ -1926,24 +2675,30 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { - return false; - } - if (!hex_supported_src1_type(src1->type)) { - return false; + if (src0->type == GGML_TYPE_F32) { + if (src1->type != GGML_TYPE_F32) { + return false; + } + if (dst->type != GGML_TYPE_F32) { + return false; + } } - if (!hex_supported_dst_type(dst->type)) { - return false; + else if (src0->type == GGML_TYPE_F16) { + if (src1->type != GGML_TYPE_F16) { + return false; + } + if (dst->type != GGML_TYPE_F16) { + return false; + } } - if (!hex_supported_dims2(src0, dst)) { + else { return false; } - if (!ggml_can_repeat(src1, src0)) { + + if (!ggml_are_same_shape(src0, dst)) { return false; } - - // TODO: add support for non-contigiuos tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) { return false; } @@ -1955,16 +2710,16 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_src1_type(src1->type)) { + if (src1->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, dst)) { + if (!ggml_are_same_shape(src0, dst)) { return false; } @@ -1980,13 +2735,32 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { + return false; + } + if (dst->type != GGML_TYPE_F32) { + return false; + } + if (!ggml_are_same_shape(src0, dst)) { + return false; + } + + // dst must be contiguous; src0 may be non-contiguous + if (!ggml_is_contiguous(dst)) { return false; } - if (!hex_supported_dst_type(dst->type)) { + + return true; +} + +static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, dst)) { + if (dst->type != GGML_TYPE_F32) { return false; } @@ -2004,10 +2778,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } @@ -2016,10 +2790,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session } if (src1) { - if (!hex_supported_src1_type(src1->type)) { + if (src1->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, src1)) { + if (!ggml_are_same_shape(src0, src1)) { return false; } if (!ggml_is_contiguous(src1)) { @@ -2040,15 +2814,15 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s return false; // FIXME: add support for sinks } - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } if (src1) { - if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) { + if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) { return false; } if (src0->ne[0] != src1->ne[0]) { @@ -2075,6 +2849,23 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s } } + // Reject non-HVX-aligned sizes when ne[0] > HVX_F32_LANES + // The HVX softmax implementation has issues with tail handling for larger non-aligned sizes + // Small sizes (ne[0] <= 32) work correctly with tail-only processing + const int64_t ne0 = src0->ne[0]; + if (ne0 > 32 && (ne0 & (32 - 1)) != 0) { + return false; + } + + // HVX vector size constraints for softmax + #define SOFTMAX_MAX_ROW_SIZE 131072 // 128K elements max for numerical precision + + // Reject very large row sizes to avoid numerical precision issues + // Softmax accumulation over many elements can lead to precision loss + if (ne0 > SOFTMAX_MAX_ROW_SIZE) { + return false; + } + return true; } @@ -2118,12 +2909,32 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * dst = op; // indices + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (dst->type != GGML_TYPE_I32) { + return false; + } + + if (src0->ne[0] > (16*1024)) { + // reject tensors with huge rows for now + return false; + } + + return true; +} + static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const int32_t * op_params = &op->op_params[0]; int mode = op_params[2]; - if ((mode & GGML_ROPE_TYPE_MROPE) || (mode & GGML_ROPE_TYPE_VISION)) { + if (mode == GGML_ROPE_TYPE_VISION) { return false; } if (mode & 1) { @@ -2135,17 +2946,17 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess const struct ggml_tensor * src2 = op->src[2]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; // FIXME: add support for GGML_TYPE_F16 for src0 } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_src1_type3(src1->type)) { + if (src1->type != GGML_TYPE_I32) { return false; } if (src2) { - if (!hex_supported_src2_type(src2->type)) { + if (src2->type != GGML_TYPE_F32) { return false; } int n_dims = op_params[1]; @@ -2168,277 +2979,147 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess return true; } -enum dspqbuf_type { - DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0, - DSPQBUF_TYPE_CPU_WRITE_DSP_READ, - DSPQBUF_TYPE_CONSTANT, -}; - -static void dspqbuf_dump(dspqueue_buffer * d, const struct ggml_tensor * t, dspqbuf_type type) { - if (opt_verbose < 2) return; - - auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context); - auto sess = buf->sess; - - GGML_LOG_DEBUG("ggml-hex: %s dspqbuf : %s base-addr %p base-size %zu data %p offset %u size %u\n", sess->name.c_str(), - t->name, (void *) buf->base, buf->size, (void *) d->ptr, (unsigned int) d->offset, - (unsigned int) d->size); -} - -// Init hexagon tensor from GGML tensor and Hexagon buffer -static void htp_req_tensor_init(htp_tensor * h, const ggml_tensor * t) { - h->data = 0; // updated by the receiver - h->type = t->type; - h->ne[0] = t->ne[0]; - h->ne[1] = t->ne[1]; - h->ne[2] = t->ne[2]; - h->ne[3] = t->ne[3]; - h->nb[0] = t->nb[0]; - h->nb[1] = t->nb[1]; - h->nb[2] = t->nb[2]; - h->nb[3] = t->nb[3]; -} +static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * dst = op; -static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_tensor * t, dspqbuf_type type) { - if (!t) { - return 0; + // Only support FP32 for now + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; } - auto buf = static_cast<ggml_backend_hexagon_buffer_context *>(t->buffer->context); + // Check IO tensor shapes and dims + if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) { + return false; // src0 should be effectively 3D + } - memset(d, 0, sizeof(*d)); - d->fd = buf->fd; - d->ptr = t->data; - d->offset = (uint8_t *) t->data - buf->base; - d->size = ggml_nbytes(t); + const int d_conv = src1->ne[0]; + const int d_inner = src0->ne[1]; + const int n_t = dst->ne[1]; + const int n_s = dst->ne[2]; - if (!d->size) { - // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty - d->size = 64; + if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) { + return false; } - - switch (type) { - case DSPQBUF_TYPE_DSP_WRITE_CPU_READ: - // Flush CPU - d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER; - break; - case DSPQBUF_TYPE_CPU_WRITE_DSP_READ: - // Flush CPU, Invalidate DSP - d->flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; - break; - default: - // Constant buffer, no cache maintenance - d->flags = 0; - break; + if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) { + return false; } - - htp_req_tensor_init(h, t); - - dspqbuf_dump(d, t, type); - - return 1; -} - -typedef size_t (*htp_req_init_func_t)(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * op); - -template <htp_req_init_func_t _init_req_func> -static inline void ggml_hexagon_dispatch_op(ggml_hexagon_session *sess, const struct ggml_tensor * op, uint32_t flags) { - uint64_t t = ggml_time_us(); - - // Construct HTP request - htp_general_req req; - memset(&req, 0, sizeof(req)); - - req.flags = flags; - if (!(opt_opmask & HTP_OPMASK_QUANTIZE)) { - req.flags |= HTP_OPFLAGS_SKIP_QUANTIZE; + if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) { + return false; } - if (!(opt_opmask & HTP_OPMASK_COMPUTE)) { - req.flags |= HTP_OPFLAGS_SKIP_COMPUTE; + if (src0->nb[0] != sizeof(float) || src1->nb[0] != sizeof(float) || dst->nb[0] != sizeof(float)) { + return false; } - - ggml_hexagon_dump_op_exec(sess->name, op, req.flags); - - if ((opt_opmask & HTP_OPMASK_QUEUE)) { - dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - size_t n_bufs = _init_req_func(&req, bufs, op); - sess->enqueue(req, bufs, n_bufs, opt_opsync); + if (src0->nb[1] != src0->ne[0] * sizeof(float) || src1->nb[1] != src1->ne[0] * sizeof(float)) { + return false; } - t = ggml_time_us() - t; - - ggml_hexagon_dump_op_prof(sess->name, op, sess->prof_usecs, sess->prof_cycles, sess->prof_pkts, t); + return true; } -template <bool _is_src0_constant> -static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - switch (t->op) { - case GGML_OP_MUL_MAT: - req->op = HTP_OP_MUL_MAT; - break; - case GGML_OP_MUL: - req->op = HTP_OP_MUL; - break; - case GGML_OP_ADD: - req->op = HTP_OP_ADD; - break; - case GGML_OP_SUB: - req->op = HTP_OP_SUB; - break; - default: - GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op); - break; - } - - // src0: Weights (mulmat) or First Operand (binary op). - // If constant (e.g. weights), no cache management is needed. - // src1: Input Activations (mulmat) or Second Operand (binary op). +static bool ggml_hexagon_supported_pad(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } - return n_bufs; + GGML_UNUSED(sess); + return true; } -static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_GET_ROWS; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} +static bool ggml_hexagon_supported_cumsum(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; -template <bool _is_src0_constant> -static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - switch (t->op) { - case GGML_OP_MUL_MAT_ID: - req->op = HTP_OP_MUL_MAT_ID; - break; - case GGML_OP_ADD_ID: - req->op = HTP_OP_ADD_ID; - break; - default: - GGML_ABORT("ggml-hex: unsupported op: %d\n", t->op); + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; } - // src0: Weights (mulmat) or Input Activations (other op). - // If constant, no cache management is needed. - // src1: Input Activations (mulmat) or Second Operand (binary op). - // src2: Expert IDs (mulmat) or Activated Experts (other op). - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], _is_src0_constant ? DSPQBUF_TYPE_CONSTANT : DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + return false; + } - return n_bufs; + GGML_UNUSED(sess); + return true; } -static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - req->op = HTP_OP_SET_ROWS; - - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); +static bool ggml_hexagon_supported_diag(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; - return n_bufs; -} + // diag only supports F32 currently + if (src0->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } -static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + // Input must have ne[1] == 1 (vector input) + if (src0->ne[1] != 1) { + return false; + } - bool supported = false; + // Output must be square in first two dimensions + if (dst->ne[0] != dst->ne[1] || dst->ne[0] != src0->ne[0]) { + return false; + } - switch (t->op) { - case GGML_OP_RMS_NORM: - req->op = HTP_OP_RMS_NORM; - supported = true; - break; + GGML_UNUSED(sess); + return true; +} - case GGML_OP_SCALE: - req->op = HTP_OP_SCALE; - supported = true; - break; +static bool ggml_hexagon_supported_solve_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // A + const struct ggml_tensor * src1 = op->src[1]; // B + const struct ggml_tensor * dst = op; // X - case GGML_OP_UNARY: - if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { - req->op = HTP_OP_UNARY_SILU; - supported = true; - } else if (ggml_get_unary_op(t) == GGML_UNARY_OP_GELU) { - req->op = HTP_OP_UNARY_GELU; - supported = true; - } - break; + if (!src0 || !src1) { + return false; + } - case GGML_OP_GLU: - if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU) { - req->op = HTP_OP_GLU_SWIGLU; - supported = true; - } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) { - req->op = HTP_OP_GLU_SWIGLU_OAI; - supported = true; - } - break; + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } - case GGML_OP_SOFT_MAX: - req->op = HTP_OP_SOFTMAX; - supported = true; - break; + if (src0->ne[0] != src0->ne[1]) { + return false; + } - default: - break; + if (src0->ne[1] != src1->ne[1]) { + return false; } - if (!supported) { - GGML_ABORT("ggml-hex: unary : unsupported op: %d\n", t->op); + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return false; } - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || dst->ne[3] != src1->ne[3]) { + return false; + } - return n_bufs; + GGML_UNUSED(sess); + return true; } -static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_ROPE; +static bool ggml_hexagon_supported_tri(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); - - return n_bufs; -} + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; -static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { - memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); - req->op = HTP_OP_FLASH_ATTN_EXT; + if (src0->type != GGML_TYPE_F32) { return false; } + if (dst->type != GGML_TYPE_F32) { return false; } + if (!ggml_are_same_shape(src0, dst)) { return false; } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { return false; } - size_t n_bufs = 0; - n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); - n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + return true; - return n_bufs; + GGML_UNUSED(sess); } static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast<ggml_hexagon_session *>(backend->context); - return sess->name.c_str(); + return sess->c_name(); } static void ggml_backend_hexagon_free(ggml_backend_t backend) { @@ -2447,118 +3128,118 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) { delete backend; } -static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { - return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type) && ggml_is_quantized(op1->src[1]->type)); -} +static htp_op_code op_remap_to_htp(const ggml_tensor * t) { + switch (t->op) { + case GGML_OP_FLASH_ATTN_EXT: return HTP_OP_FLASH_ATTN_EXT; + case GGML_OP_MUL_MAT: return HTP_OP_MUL_MAT; + case GGML_OP_MUL_MAT_ID: return HTP_OP_MUL_MAT_ID; + case GGML_OP_MUL: return HTP_OP_MUL; + case GGML_OP_ADD: return HTP_OP_ADD; + case GGML_OP_ADD_ID: return HTP_OP_ADD_ID; + case GGML_OP_SUB: return HTP_OP_SUB; + case GGML_OP_DIV: return HTP_OP_DIV; + case GGML_OP_CPY: return HTP_OP_CPY; + case GGML_OP_CONT: return HTP_OP_CPY; + case GGML_OP_GET_ROWS: return HTP_OP_GET_ROWS; + case GGML_OP_SET_ROWS: return HTP_OP_SET_ROWS; + case GGML_OP_SUM_ROWS: return HTP_OP_SUM_ROWS; + case GGML_OP_ARGSORT: return HTP_OP_ARGSORT; + case GGML_OP_NORM: return HTP_OP_NORM; + case GGML_OP_L2_NORM: return HTP_OP_L2_NORM; + case GGML_OP_RMS_NORM: return HTP_OP_RMS_NORM; + case GGML_OP_CONCAT: return HTP_OP_CONCAT; + case GGML_OP_SCALE: return HTP_OP_SCALE; + case GGML_OP_SQR: return HTP_OP_SQR; + case GGML_OP_SQRT: return HTP_OP_SQRT; + case GGML_OP_SOFT_MAX: return HTP_OP_SOFTMAX; + case GGML_OP_SSM_CONV: return HTP_OP_SSM_CONV; + case GGML_OP_GATED_DELTA_NET: return HTP_OP_GATED_DELTA_NET; + case GGML_OP_ROPE: return HTP_OP_ROPE; + case GGML_OP_REPEAT: return HTP_OP_REPEAT; + case GGML_OP_CUMSUM: return HTP_OP_CUMSUM; + case GGML_OP_FILL: return HTP_OP_FILL; + case GGML_OP_DIAG: return HTP_OP_DIAG; + case GGML_OP_SOLVE_TRI: return HTP_OP_SOLVE_TRI; + case GGML_OP_TRI: return HTP_OP_TRI; + case GGML_OP_PAD: return HTP_OP_PAD; -static inline bool is_compute_op(ggml_tensor *node) -{ - return !(ggml_op_is_empty(node->op) || ggml_is_empty(node)); -} + case GGML_OP_UNARY: + switch (ggml_get_unary_op(t)) { + case GGML_UNARY_OP_SILU: return HTP_OP_UNARY_SILU; + case GGML_UNARY_OP_GELU: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_GELU_QUICK: return HTP_OP_UNARY_GELU; + case GGML_UNARY_OP_SIGMOID: return HTP_OP_UNARY_SIGMOID; + case GGML_UNARY_OP_NEG: return HTP_OP_UNARY_NEG; + case GGML_UNARY_OP_EXP: return HTP_OP_UNARY_EXP; + case GGML_UNARY_OP_SOFTPLUS: return HTP_OP_UNARY_SOFTPLUS; + case GGML_UNARY_OP_TANH: return HTP_OP_UNARY_TANH; + default: + break; + } + break; -// scan the graph and figure out last compute op index -static inline int last_compute_op(ggml_cgraph * graph) { - int last = 0; - for (int i = 0; i < graph->n_nodes; ++i) { - if (is_compute_op(graph->nodes[i])) { - last = i; - } + case GGML_OP_GLU: + switch (ggml_get_glu_op(t)) { + case GGML_GLU_OP_SWIGLU: return HTP_OP_GLU_SWIGLU; + case GGML_GLU_OP_SWIGLU_OAI: return HTP_OP_GLU_SWIGLU_OAI; + case GGML_GLU_OP_GEGLU: return HTP_OP_GLU_GEGLU; + default: break; + } + break; + + default: + GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(t)); } + return HTP_OP_INVALID; +} - return last; +static inline bool op_is_compute(ggml_tensor *node) +{ + return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE); } static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) { auto sess = static_cast<ggml_hexagon_session *>(backend->context); - HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->name.c_str(), graph->n_nodes); - - const int last = last_compute_op(graph); + HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->c_name(), graph->n_nodes); - const struct ggml_tensor * prev_quant_op = nullptr; // prev executed op with quantizer + std::vector<htp_opnode> nodes; + nodes.reserve(graph->n_nodes); + // Fusion for (int i = 0; i < graph->n_nodes; ++i) { - ggml_tensor * node = graph->nodes[i]; - - if (!is_compute_op(node)) { + ggml_tensor * n = graph->nodes[i]; + if (!op_is_compute(n)) { continue; } - uint32_t flags = 0; + ggml_tensor * next_node = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; - // skip quantizer if src1 is reused - if (op_reuse_src1(node, prev_quant_op)) { - flags |= HTP_OPFLAGS_SKIP_QUANTIZE; - } + htp_opnode node = { + /*.node =*/ n, + /*.fused =*/ {}, + /*.opcode =*/ HTP_OP_INVALID + }; - // ask for early notification for the last Op - if (i == last) { - flags |= HTP_OPFLAGS_EARLY_WAKEUP; + if (n->op == GGML_OP_RMS_NORM && next_node) { + if (next_node->op == GGML_OP_MUL && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + node.add_fused(next_node); + node.opcode = HTP_OP_RMS_NORM_MUL; + i++; // skip the fused MUL node + } } - switch (node->op) { - case GGML_OP_MUL_MAT: - if (ggml_is_quantized(node->src[0]->type)) { - ggml_hexagon_dispatch_op<init_binary_req<true>>(sess, node, flags); - } else { - ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags); - } - prev_quant_op = node; - break; - case GGML_OP_MUL_MAT_ID: - if (ggml_is_quantized(node->src[0]->type)) { - ggml_hexagon_dispatch_op<init_binary_id_req<true>>(sess, node, flags); - } else { - ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags); - } - prev_quant_op = node; - break; - case GGML_OP_MUL: - case GGML_OP_ADD: - case GGML_OP_SUB: - ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags); - break; - case GGML_OP_ADD_ID: - ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags); - break; - case GGML_OP_RMS_NORM: - case GGML_OP_SCALE: - ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags); - break; - case GGML_OP_UNARY: - if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) || - (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) { - ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags); - } - break; - case GGML_OP_GLU: - if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || - (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) { - ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags); - } - break; - case GGML_OP_SOFT_MAX: - ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags); - break; - - case GGML_OP_ROPE: - ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags); - break; - - case GGML_OP_FLASH_ATTN_EXT: - ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags); - break; - - case GGML_OP_SET_ROWS: - ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags); - break; + if (node.opcode == HTP_OP_INVALID) { + node.opcode = op_remap_to_htp(n); + } - case GGML_OP_GET_ROWS: - ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags); - break; + nodes.push_back(std::move(node)); + } - default: - GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); + // Queue and execute + if (opt_opstage & HTP_OPSTAGE_QUEUE) { + for (const auto & node : nodes) { + sess->enqueue_op(node); } } @@ -2571,57 +3252,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg static void ggml_backend_hexagon_synchronize(ggml_backend_t backend) { auto sess = static_cast<ggml_hexagon_session *>(backend->context); - HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->name.c_str()); + HEX_VERBOSE("ggml-hex: %s synchronize\n", sess->c_name()); // Wait until all pending ops complete sess->flush(); } -struct node_info { - ggml_tensor * node; - - std::vector<ggml_tensor *> fused; - - ggml_op op() const { - return node->op; - } - - const ggml_tensor * dst() const { - return fused.empty() ? node : fused.back(); - } - - const ggml_tensor * src0() const { - return node->src[0]; - } - - const ggml_tensor * src1() const { - return node->src[1]; - } - - bool is_empty() const { - return ggml_op_is_empty(node->op); - } - - void add_fused(ggml_tensor * t) { - fused.push_back(t); - } - - bool stackable() const { - switch (this->op()) { - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - return ggml_is_quantized(this->src0()->type); - default: - return false; - } - } - - bool same_input(const node_info& n) const { - return n.src1() == this->src1(); - } -}; - -static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<node_info> & nodes) { +static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<htp_opnode> & nodes) { const int n = nodes.size(); std::vector<int> res; @@ -2632,7 +3269,7 @@ static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<no // The main goal here is to stack the MUL_MAT ops with the same src1 input. // This allows use to reuse dynamically quantized src1 in VTCM. - // TODO: the current version might do incorrect reodering in cases where quantized src0 + // TODO: the current version might do incorrect reordering in cases where quantized src0 // input is an output of another Op. for (int i0 = 0; i0 < n; i0++) { @@ -2649,7 +3286,7 @@ static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<no } // that many nodes forward to search for stackable nodes that can reuse VTCM - constexpr int N_FORWARD = 8; + constexpr int N_FORWARD = 16; for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) { if (used[i1]) { @@ -2675,14 +3312,14 @@ static void ggml_backend_hexagon_graph_optimize(ggml_backend_t backend, ggml_cgr enum ggml_op ops[MAX_FUSE]; - std::vector<node_info> nodes; + std::vector<htp_opnode> nodes; nodes.reserve(gf->n_nodes); // fuse nodes: // we don't want to make reorders that break fusing, so we first pack all fusable tensors // and perform the reorder over the fused nodes. after the reorder is done, we unfuse for (int i = 0; i < n; i++) { - node_info node = { + htp_opnode node = { /*.node =*/gf->nodes[i], /*.fused =*/{}, }; @@ -2749,6 +3386,8 @@ static struct ggml_backend_i hexagon_backend_i = { /* .free = */ ggml_backend_hexagon_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_hexagon_synchronize, /* .graph_plan_create = */ NULL, @@ -2788,7 +3427,7 @@ static ggml_backend_t ggml_backend_hexagon_device_init(ggml_backend_dev_t dev, c static const char * ggml_backend_hexagon_device_get_name(ggml_backend_dev_t dev) { auto sess = static_cast<ggml_hexagon_session *>(dev->context); - return sess->name.c_str(); + return sess->c_name(); GGML_UNUSED(dev); } @@ -2799,8 +3438,7 @@ static const char * ggml_backend_hexagon_device_get_description(ggml_backend_dev } static void ggml_backend_hexagon_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // ~2GB per session for now - *free = 2ULL * 1024 * 1024 * 1024; + *free = 0; *total = *free; GGML_UNUSED(dev); @@ -2858,9 +3496,98 @@ static bool ggml_hexagon_supported_buffers(ggml_hexagon_session *sess, const str return true; } +static bool ggml_hexagon_supported_cpy(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // for now we can do f32 -> f16 and f16 -> f32 (without reshaping) + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + if ( dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) return false; + + const bool sametype = (src0->type == dst->type); + const bool transposed = ggml_is_transposed(src0) || ggml_is_transposed(dst); + const bool sameshape = !transposed && ggml_are_same_shape(src0, dst); + + // can handle any shape and any same-type (pretty slow if reshaping is required) + if (sametype) return true; + + // cannot handle re-shaping and type conversion at the same time + if (!sameshape) return false; + + return true; +} + +static bool ggml_hexagon_supported_cont(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + GGML_UNUSED(sess); + const struct ggml_tensor * src0 = op->src[0]; + + // CONT is same-type only, supports f32 and f16 + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + + return true; +} + +static bool ggml_hexagon_supported_repeat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + GGML_UNUSED(sess); + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + // Support f32 and f16 + if (src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) return false; + + // src and dst must be the same type + if (src0->type != dst->type) return false; + + // dst dims must be multiples of src dims + if (dst->ne[0] % src0->ne[0] != 0) return false; + if (dst->ne[1] % src0->ne[1] != 0) return false; + if (dst->ne[2] % src0->ne[2] != 0) return false; + if (dst->ne[3] % src0->ne[3] != 0) return false; + + // require contiguous tensors (no transposition) + if (ggml_is_transposed(src0) || ggml_is_transposed(dst)) return false; + + return true; +} + +static bool ggml_hexagon_supported_concat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + int dim = ((const int32_t *) op->op_params)[0]; + if (dim < 0 || dim >= GGML_MAX_DIMS) { + return false; + } + + for (int i = 0; i < GGML_MAX_SRC; ++i) { + const struct ggml_tensor * src = op->src[i]; + if (!src) { + continue; + } + if (src->type != GGML_TYPE_F32 && src->type != GGML_TYPE_I32 && src->type != GGML_TYPE_F16) { + return false; + } + } + + return true; +} + +static bool ggml_hexagon_supported_fill(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * dst = op; + + if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) { + return false; + } + + GGML_UNUSED(sess); + return true; +} + static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { auto sess = static_cast<ggml_hexagon_session *>(dev->context); + // reject ops that match the filter + if (opt_opfilter && std::regex_match(ggml_op_desc(op), *opt_opfilter)) { + return false; + } + // all srcs & dsts must be mapped to the same session if (!ggml_hexagon_supported_buffers(sess, op)) { ggml_hexagon_dump_op_supp(sess->name, op, false); @@ -2877,6 +3604,13 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = true; break; + case GGML_OP_MUL: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_DIV: + supp = ggml_hexagon_supported_binary(sess, op); + break; + case GGML_OP_MUL_MAT: supp = ggml_hexagon_supported_mul_mat(sess, op); break; @@ -2885,41 +3619,61 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_mul_mat_id(sess, op); break; - case GGML_OP_MUL: - case GGML_OP_ADD: - case GGML_OP_SUB: - supp = ggml_hexagon_supported_binary(sess, op); - break; - case GGML_OP_ADD_ID: supp = ggml_hexagon_supported_add_id(sess, op); break; + case GGML_OP_NORM: + case GGML_OP_L2_NORM: case GGML_OP_RMS_NORM: case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + supp = ggml_hexagon_supported_unary(sess, op); + break; + + case GGML_OP_SUM_ROWS: + supp = ggml_hexagon_supported_sum_rows(sess, op); + break; + case GGML_OP_SOFT_MAX: supp = ggml_hexagon_supported_softmax(sess, op); break; case GGML_OP_UNARY: - { - const auto unary_op = ggml_get_unary_op(op); - if (unary_op == GGML_UNARY_OP_SILU || unary_op == GGML_UNARY_OP_GELU) { + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_TANH: + supp = ggml_hexagon_supported_unary(sess, op); + break; + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_QUICK: supp = ggml_hexagon_supported_activations(sess, op); - } - break; + break; + default: + break; } + break; + case GGML_OP_GLU: - { - const auto glu_op = ggml_get_glu_op(op); - if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) { + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU: supp = ggml_hexagon_supported_activations(sess, op); - } - break; + break; + default: + break; } + break; + case GGML_OP_ROPE: supp = ggml_hexagon_supported_rope(sess, op); break; @@ -2936,6 +3690,58 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_get_rows(sess, op); break; + case GGML_OP_CPY: + supp = ggml_hexagon_supported_cpy(sess, op); + break; + + case GGML_OP_CONT: + supp = ggml_hexagon_supported_cont(sess, op); + break; + + case GGML_OP_REPEAT: + supp = ggml_hexagon_supported_repeat(sess, op); + break; + + case GGML_OP_ARGSORT: + supp = ggml_hexagon_supported_argsort(sess, op); + break; + + case GGML_OP_SSM_CONV: + supp = ggml_hexagon_supported_ssm_conv(sess, op); + break; + + case GGML_OP_GATED_DELTA_NET: + supp = ggml_hexagon_supported_gated_delta_net(sess, op); + break; + + case GGML_OP_CUMSUM: + supp = ggml_hexagon_supported_cumsum(sess, op); + break; + + case GGML_OP_CONCAT: + supp = ggml_hexagon_supported_concat(sess, op); + break; + + case GGML_OP_FILL: + supp = ggml_hexagon_supported_fill(sess, op); + break; + + case GGML_OP_DIAG: + supp = ggml_hexagon_supported_diag(sess, op); + break; + + case GGML_OP_SOLVE_TRI: + supp = ggml_hexagon_supported_solve_tri(sess, op); + break; + + case GGML_OP_TRI: + supp = ggml_hexagon_supported_tri(sess, op); + break; + + case GGML_OP_PAD: + supp = ggml_hexagon_supported_pad(sess, op); + break; + default: break; } @@ -3002,19 +3808,6 @@ struct ggml_hexagon_registry { ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) { GGML_LOG_INFO("ggml-hex: Hexagon backend (experimental) : allocating new registry : ndev %zu\n", opt_ndev); - if (!opt_arch) { - int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); - if (err != 0) { - GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); - opt_arch = 73; - } - } - - if (opt_arch < 75) { - opt_ndev = 1; - GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); - } - GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch); // Create devices / sessions @@ -3061,7 +3854,7 @@ static ggml_backend_dev_t ggml_backend_hexagon_reg_get_device(ggml_backend_reg_t } static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, const char * name) { - if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) { + if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0 && opt_hostbuf) { ggml_backend_dev_get_extra_bufts_t fct = ggml_backend_hexagon_device_get_extra_buffers_type; return (void *) fct; } @@ -3069,56 +3862,117 @@ static void * ggml_backend_hexagon_get_proc_address(ggml_backend_reg_t reg, cons return NULL; } +template<typename T> std::vector<T> str_to_vec(const char* str) { + std::stringstream ss(str); + std::vector<T> v; + std::string t; + + while (std::getline(ss, t, ',')) { + v.push_back(std::stoul(t, nullptr, 0)); + } + + return v; +} + +template<typename T, int BASE=10> std::string vec_to_str(std::vector<T> v) { + std::stringstream ss; + ss << std::setbase(BASE) << std::showbase; + for (auto i : v) { ss << i << ','; } + auto str = ss.str(); str.pop_back(); // drop last comma + return str; +} + static void ggml_hexagon_init(ggml_backend_reg * reg) { // Basic sanity checks to make sure definitions match static_assert((unsigned int) HTP_TYPE_Q4_0 == (unsigned int) GGML_TYPE_Q4_0, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_Q4_1 == (unsigned int) GGML_TYPE_Q4_1, + "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_Q8_0 == (unsigned int) GGML_TYPE_Q8_0, "please update hexagon_type to match ggml_type"); static_assert((unsigned int) HTP_TYPE_MXFP4 == (unsigned int) GGML_TYPE_MXFP4, "please update hexagon_type to match ggml_type"); + static_assert((unsigned int) HTP_TYPE_IQ4_NL == (unsigned int) GGML_TYPE_IQ4_NL, + "please update hexagon_type to match ggml_type"); + + const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); + const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); + const char * str_opstage = getenv("GGML_HEXAGON_OPSTAGE"); + const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); + const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); + const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL"); + const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); + const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); + const char * str_etm = getenv("GGML_HEXAGON_ETM"); + const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); + const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); + const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); + const char * str_arch = getenv("GGML_HEXAGON_ARCH"); + const char * str_vmem = getenv("GGML_HEXAGON_VMEM"); + const char * str_mbuf = getenv("GGML_HEXAGON_MBUF"); + + // Init Arch first since it affects other defaults + if (!str_arch) { + int err = get_hex_arch_ver(CDSP_DOMAIN_ID, &opt_arch); + if (err != 0) { + GGML_LOG_ERROR("ggml-hex: failed to query HTP version (err %d) defaulting to v73\n", err); + opt_arch = 73; + } + } else { + if (str_arch[0] == 'v' || str_arch[0] == 'V') { + str_arch++; + } + opt_arch = strtoul(str_arch, NULL, 0); + } - const char * str_verbose = getenv("GGML_HEXAGON_VERBOSE"); - const char * str_hostbuf = getenv("GGML_HEXAGON_HOSTBUF"); + size_t MiB = 1024 * 1024; - opt_verbose = str_verbose ? atoi(str_verbose) : 0; - opt_profile = getenv("GGML_HEXAGON_PROFILE") != nullptr; - opt_etm = getenv("GGML_HEXAGON_ETM") != nullptr; - opt_experimental = getenv("GGML_HEXAGON_EXPERIMENTAL") != nullptr; + // Update vmem default + opt_vmem = opt_arch >= 75 ? HTP_OP_MAX_VMEM_DEFAULT : 3000 * MiB; - const char * str_opmask = getenv("GGML_HEXAGON_OPMASK"); - if (str_opmask != nullptr) { - opt_opmask = strtoul(str_opmask, NULL, 0); - } - opt_opsync = getenv("GGML_HEXAGON_OPSYNC") != nullptr; + auto RE_ICASE = std::regex_constants::icase; - const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); - if (str_ndev) { - opt_ndev = strtoul(str_ndev, NULL, 0); - if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { - opt_ndev = GGML_HEXAGON_MAX_SESSIONS; - } - } + opt_opfilter = str_opfilter ? new std::regex(str_opfilter, RE_ICASE) : NULL; + opt_verbose = str_verbose ? atoi(str_verbose) : 0; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; + opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; + opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; + opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; + opt_profile = str_profile ? atoi(str_profile) : 0; + opt_etm = str_etm ? atoi(str_etm) : 0; + opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; + opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; + opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; + opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf; + opt_vmem = str_vmem ? strtoul(str_vmem, NULL, 0) * MiB : opt_vmem; - const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); - if (str_nhvx) { - opt_nhvx = strtoul(str_nhvx, NULL, 0); + if (opt_ndev > GGML_HEXAGON_MAX_SESSIONS) { + opt_ndev = GGML_HEXAGON_MAX_SESSIONS; } - const char * str_arch = getenv("GGML_HEXAGON_ARCH"); - if (str_arch) { - if (str_arch[0] == 'v') { - str_arch++; - } - opt_arch = strtoul(str_arch, NULL, 0); +#if defined(__ANDROID__) + if (opt_arch < 75) { + opt_ndev = 1; + GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); } +#endif - opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1; + if (str_profile) { + opt_pmu_evt = [&]() -> std::vector<uint32_t> { + auto v = str_to_vec<uint32_t>(str_profile); + switch (v.size()) { + case 1: opt_profile = v[0]; return opt_pmu_evt; // mode with default pmu events + case 8: opt_profile = 2; return v; // mode with custom pmu events + default: opt_profile = 0; return {}; // garbage input + }}(); + if (opt_profile == 1) opt_pmu_evt = {}; + GGML_LOG_INFO("ggml-hex: Profiling mode %u : pmu-evt [ %s ]\n", opt_profile, + vec_to_str<uint32_t, 16>(opt_pmu_evt).c_str()); + } reg->context = new ggml_hexagon_registry(reg); - - HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req), - sizeof(struct htp_general_rsp)); } static const struct ggml_backend_reg_i ggml_backend_hexagon_reg_i = { @@ -3139,6 +3993,11 @@ ggml_backend_reg_t ggml_backend_hexagon_reg(void) { static std::mutex mutex; std::lock_guard<std::mutex> lock(mutex); if (!initialized) { + auto nErr = htpdrv_init(); + if (nErr != AEE_SUCCESS) { + return NULL; + } + ggml_hexagon_init(®); } diff --git a/ggml/src/ggml-hexagon/htp-drv.cpp b/ggml/src/ggml-hexagon/htp-drv.cpp new file mode 100644 index 00000000000..4c376b5fc91 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-drv.cpp @@ -0,0 +1,418 @@ +// sample drv interface + +#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" +#pragma clang diagnostic ignored "-Wmissing-prototypes" +#pragma clang diagnostic ignored "-Wsign-compare" + +#include <filesystem> +#include <set> +#include <sstream> +#include <string> +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include <windows.h> +# include <winevt.h> +#else +# include <dlfcn.h> +# include <unistd.h> +#endif +#include "ggml-impl.h" +#include "htp-drv.h" +#include "libdl.h" + +#include <domain.h> + +// +// Driver API types +// + +typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size); +typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size); +typedef void (*rpcmem_free_pfn_t)(void * po); +typedef int (*rpcmem_to_fd_pfn_t)(void * po); + +typedef AEEResult (*dspqueue_create_pfn_t)(int domain, + uint32_t flags, + uint32_t req_queue_size, + uint32_t resp_queue_size, + dspqueue_callback_t packet_callback, + dspqueue_callback_t error_callback, + void * callback_context, + dspqueue_t * queue); +typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue); +typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id); +typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags, + uint32_t num_buffers, + struct dspqueue_buffer *buffers, + uint32_t message_length, + const uint8_t *message, + uint32_t timeout_us); +typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags, + uint32_t max_buffers, uint32_t *num_buffers, + struct dspqueue_buffer *buffers, + uint32_t max_message_length, + uint32_t *message_length, uint8_t *message, + uint32_t timeout_us); + +typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags); +typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length); + +typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph); +typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra); +typedef int (*remote_handle64_close_pfn_t)(remote_handle h); +typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen); +typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen); +typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen); + +// +// Driver API pfns +// + +rpcmem_alloc_pfn_t rpcmem_alloc_pfn = nullptr; +rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr; +rpcmem_free_pfn_t rpcmem_free_pfn = nullptr; +rpcmem_to_fd_pfn_t rpcmem_to_fd_pfn = nullptr; + +fastrpc_mmap_pfn_t fastrpc_mmap_pfn = nullptr; +fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr; + +dspqueue_create_pfn_t dspqueue_create_pfn = nullptr; +dspqueue_close_pfn_t dspqueue_close_pfn = nullptr; +dspqueue_export_pfn_t dspqueue_export_pfn = nullptr; +dspqueue_write_pfn_t dspqueue_write_pfn = nullptr; +dspqueue_read_pfn_t dspqueue_read_pfn = nullptr; + +remote_handle64_open_pfn_t remote_handle64_open_pfn = nullptr; +remote_handle64_invoke_pfn_t remote_handle64_invoke_pfn = nullptr; +remote_handle64_close_pfn_t remote_handle64_close_pfn = nullptr; +remote_handle_control_pfn_t remote_handle_control_pfn = nullptr; +remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr; +remote_session_control_pfn_t remote_session_control_pfn = nullptr; + +// +// Driver API +// + +void * rpcmem_alloc(int heapid, uint32_t flags, int size) { + return rpcmem_alloc_pfn(heapid, flags, size); +} + +void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) { + if (rpcmem_alloc2_pfn) { + return rpcmem_alloc2_pfn(heapid, flags, size); + } else { + GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n"); + return rpcmem_alloc_pfn(heapid, flags, size); + } +} + +void rpcmem_free(void * po) { + return rpcmem_free_pfn(po); +} + +int rpcmem_to_fd(void * po) { + return rpcmem_to_fd_pfn(po); +} + +HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) { + return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags); +} + +HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) { + return fastrpc_munmap_pfn(domain, fd, addr, length); +} + +AEEResult dspqueue_create(int domain, + uint32_t flags, + uint32_t req_queue_size, + uint32_t resp_queue_size, + dspqueue_callback_t packet_callback, + dspqueue_callback_t error_callback, + void * callback_context, + dspqueue_t * queue) { + return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback, + callback_context, queue); +} + +AEEResult dspqueue_close(dspqueue_t queue) { + return dspqueue_close_pfn(queue); +} + +AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) { + return dspqueue_export_pfn(queue, queue_id); +} + +AEEResult dspqueue_write(dspqueue_t queue, + uint32_t flags, + uint32_t num_buffers, + struct dspqueue_buffer * buffers, + uint32_t message_length, + const uint8_t * message, + uint32_t timeout_us) { + return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us); +} + +AEEResult dspqueue_read(dspqueue_t queue, + uint32_t * flags, + uint32_t max_buffers, + uint32_t * num_buffers, + struct dspqueue_buffer * buffers, + uint32_t max_message_length, + uint32_t * message_length, + uint8_t * message, + uint32_t timeout_us) { + return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length, + message, timeout_us); +} + +HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) { + return remote_handle64_open_pfn(name, ph); +} + +HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) { + return remote_handle64_invoke_pfn(h, dwScalars, pra); +} + +HTPDRV_API int remote_handle64_close(remote_handle64 h) { + return remote_handle64_close_pfn(h); +} + +HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) { + return remote_handle_control_pfn(req, data, datalen); +} + +HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) { + return remote_handle64_control_pfn(h, req, data, datalen); +} + +HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) { + return remote_session_control_pfn(req, data, datalen); +} + +#ifdef _WIN32 + +static std::string wstr_to_str(std::wstring_view wstr) { + std::string result; + if (wstr.empty()) { + return result; + } + auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, + wstr.data(), (int) wstr.size(), + nullptr, 0, nullptr, nullptr); + if (bytes_needed == 0) { + GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError()); + throw std::runtime_error("Invalid wstring input"); + } + + result.resize(bytes_needed, '\0'); + int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, + wstr.data(), (int) wstr.size(), + result.data(), bytes_needed, + nullptr, nullptr); + if (bytes_written == 0) { + GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError()); + throw std::runtime_error("Wstring conversion failed"); + } + return result; +} + +static std::string get_driver_path() { + std::wstring serviceName = L"qcnspmcdm"; + std::string result; + + // Get a handle to the SCM database. + SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ); + if (nullptr == schSCManager) { + GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError()); + return result; + } + + // Get a handle to the service. + SC_HANDLE schService = OpenServiceW(schSCManager, // SCM database + serviceName.c_str(), // name of service + SERVICE_QUERY_CONFIG); // need query config access + + if (nullptr == schService) { + GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError()); + CloseServiceHandle(schSCManager); + return result; + } + + // Store the size of buffer used as an output. + DWORD bufferSize; + if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) && + (GetLastError() != ERROR_INSUFFICIENT_BUFFER)) { + GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError()); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return result; + } + // Get the configuration of the service. + LPQUERY_SERVICE_CONFIGW serviceConfig = + static_cast<LPQUERY_SERVICE_CONFIGW>(LocalAlloc(LMEM_FIXED, bufferSize)); + if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) { + fprintf(stderr, "ggml-hex: Failed to query service config. Error: %lu\n", GetLastError()); + LocalFree(serviceConfig); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + return result; + } + + // Read the driver file path get its parent directory + std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName); + driverPath = driverPath.substr(0, driverPath.find_last_of(L"\\")); + + // Clean up resources + LocalFree(serviceConfig); + CloseServiceHandle(schService); + CloseServiceHandle(schSCManager); + + // Driver path would contain invalid path string, like: + // \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37 + // "\SystemRoot" should be replace with a correct one (e.g. C:\Windows) + const std::wstring systemRootPlaceholder = L"\\SystemRoot"; + if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) { + GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n"); + return result; + } + + // Replace \SystemRoot with an absolute path from system ENV windir + const std::wstring systemRootEnv = L"windir"; + + // Query the number of wide characters this variable requires + DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0); + if (numWords == 0) { + GGML_LOG_ERROR("ggml-hex: Failed get systemRoot environment variable\n"); + return result; + } + + // Query the actual system root name from environment variable + std::vector<wchar_t> systemRoot(numWords + 1); + numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1); + if (numWords == 0) { + GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n"); + return result; + } + driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data())); + + return wstr_to_str(driverPath); +} + +#endif + +using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>; + +int htpdrv_init() { + static dl_handle_ptr lib_cdsp_rpc_handle = nullptr; + static bool initialized = false; +#ifdef _WIN32 + std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll"; +#else + std::string drv_path = "libcdsprpc.so"; +#endif + if (initialized) { + GGML_LOG_INFO("ggml-hex: Driver already loaded\n"); + return AEE_SUCCESS; + } + GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str()); + + fs::path path{ drv_path.c_str() }; + dl_handle_ptr handle { dl_load_library(path) }; + if (!handle) { + GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error()); + return AEE_EUNABLETOLOAD; + } + +#define dlsym(drv, type, pfn, symbol, ignore) \ + do { \ + pfn = (type) dl_get_sym(drv, #symbol); \ + if (!ignore && nullptr == pfn) { \ + GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol); \ + return AEE_EUNABLETOLOAD; \ + } \ + } while (0) + + dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false); + dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true); + dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false); + dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false); + dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false); + dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false); + dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false); + dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false); + dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false); + dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false); + dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false); + dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false); + dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false); + dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false); + dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false); + dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false); + dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false); + + lib_cdsp_rpc_handle = std::move(handle); + initialized = true; + + return AEE_SUCCESS; +} + +domain * get_domain(int domain_id) { + int i = 0; + int size = sizeof(supported_domains) / sizeof(domain); + + for (i = 0; i < size; i++) { + if (supported_domains[i].id == domain_id) { + return &supported_domains[i]; + } + } + + return NULL; +} + +int get_hex_arch_ver(int domain, int * arch) { + if (!remote_handle_control_pfn) { + GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n"); + return AEE_EUNSUPPORTEDAPI; + } + + struct remote_dsp_capability arch_ver; + arch_ver.domain = (uint32_t) domain; + arch_ver.attribute_ID = ARCH_VER; + arch_ver.capability = (uint32_t) 0; + + int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver)); + if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) { + GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n"); + return AEE_EUNSUPPORTEDAPI; + } + + if (err != AEE_SUCCESS) { + GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err); + return err; + } + + switch (arch_ver.capability & 0xff) { + case 0x68: + *arch = 68; + return 0; + case 0x69: + *arch = 69; + return 0; + case 0x73: + *arch = 73; + return 0; + case 0x75: + *arch = 75; + return 0; + case 0x79: + *arch = 79; + return 0; + case 0x81: + *arch = 81; + return 0; + } + return -1; +} diff --git a/ggml/src/ggml-hexagon/htp-drv.h b/ggml/src/ggml-hexagon/htp-drv.h new file mode 100644 index 00000000000..6eba7ba17d8 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-drv.h @@ -0,0 +1,121 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +#ifdef _WIN32 +# pragma clang diagnostic ignored "-Wignored-attributes" +#endif + +#include <AEEStdErr.h> +#include <rpcmem.h> +#include <remote.h> +#include <dspqueue.h> + +#if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BACKEND_BUILD +# define HTPDRV_API __declspec(dllexport) extern +# else +# define HTPDRV_API __declspec(dllimport) extern +# endif +#else +# define HTPDRV_API __attribute__ ((visibility ("default"))) extern +#endif + +/* Offset to differentiate HLOS and Hexagon error codes. + Stores the value of AEE_EOFFSET for Hexagon. */ +#ifndef DSP_OFFSET +# define DSP_OFFSET 0x80000400 +#endif + +/* Errno for connection reset by peer. */ +#ifndef ECONNRESET +# ifdef __hexagon__ +# define ECONNRESET 104 +# endif +#endif + +/* Abstraction of different OS specific sleep APIs. + SLEEP accepts input in seconds. */ +#ifndef SLEEP +# ifdef __hexagon__ +# define SLEEP(x) \ + { /* Do nothing for simulator. */ \ + } +# else +# ifdef _WIN32 +# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */ +# else +# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */ +# endif +# endif +#endif + +/* Include windows specific header files. */ +#ifdef _WIN32 +# include <windows.h> +# include <sysinfoapi.h> +# define _CRT_SECURE_NO_WARNINGS 1 +# define _WINSOCK_DEPRECATED_NO_WARNINGS 1 +#endif + +/* Includes and defines for all HLOS except windows */ +#if !defined(__hexagon__) && !defined(_WIN32) +# include "unistd.h" + +# include <sys/time.h> +#endif + +/* Includes and defines for Hexagon and all HLOS except Windows. */ +#if !defined(_WIN32) +/* Weak reference to remote symbol for compilation. */ +# pragma weak remote_session_control +# pragma weak remote_handle_control +# pragma weak remote_handle64_control +# pragma weak fastrpc_mmap +# pragma weak fastrpc_munmap +# pragma weak rpcmem_alloc2 +#endif + +#if !defined(_WIN32) +# pragma weak remote_system_request +#endif + +#ifdef _WIN32 +# define DSPQUEUE_TIMEOUT DSPQUEUE_TIMEOUT_NONE +#else +# define DSPQUEUE_TIMEOUT 1000000 +#endif + +/** + * htpdrv_init API: driver interface entry point + * + * @return Return AEE error codes as defined in Hexagon SDK. + */ +HTPDRV_API int htpdrv_init(void); + +/** + * get_domain API: get domain struct from domain value. + * + * @param[in] domain value of a domain + * @return Returns domain struct of the domain if it is supported or else + * returns NULL. + * + */ +HTPDRV_API domain * get_domain(int domain_id); + +/** + * get_hex_arch_ver API: query the Hexagon processor architecture version information + * + * @param[in] domain_id value of a domain + * @param[out] Arch version (73, 75, ...) + * @return 0 if query is successful. + * non-zero if error, return value points to the error. + * + */ +HTPDRV_API int get_hex_arch_ver(int domain, int * arch); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h new file mode 100644 index 00000000000..52c727c6206 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -0,0 +1,272 @@ +#ifndef HTP_OPNODE_H +#define HTP_OPNODE_H + +#define GGML_COMMON_IMPL_CPP +#include "ggml-backend-impl.h" +#include "ggml-common.h" + +#include <string> +#include <vector> +#include <stdio.h> +#include "htp-ops.h" + +struct htp_opnode { + ggml_tensor * node = nullptr; + + std::vector<ggml_tensor *> fused; + + htp_op_code opcode = HTP_OP_INVALID; + + ggml_op op() const { + return node->op; + } + + const ggml_tensor * dst() const { + return fused.empty() ? node : fused.back(); + } + + const ggml_tensor * src0() const { + return node->src[0]; + } + + const ggml_tensor * src1() const { + return node->src[1]; + } + + bool is_empty() const { + return ggml_op_is_empty(node->op); + } + + void add_fused(ggml_tensor * t) { + fused.push_back(t); + } + + bool stackable() const { + switch (this->op()) { + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + return ggml_is_quantized(this->src0()->type); + default: + return false; + } + } + + bool same_input(const htp_opnode& n) const { + return n.src1() == this->src1(); + } + + std::vector<const ggml_tensor *> get_inputs() const { + if (fused.empty()) { + int last_non_null = -1; + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i]) { + last_non_null = i; + } + } + std::vector<const ggml_tensor *> inputs(last_non_null + 1, nullptr); + for (int i = 0; i <= last_non_null; i++) { + inputs[i] = node->src[i]; + } + return inputs; + } + + std::vector<const ggml_tensor *> inputs(GGML_MAX_SRC, nullptr); + std::vector<const ggml_tensor *> outputs; + outputs.push_back(node); + for (const auto * f : fused) { + outputs.push_back(f); + } + + auto contains = [&](const std::vector<const ggml_tensor *> & vec, const ggml_tensor * t) { + for (const auto * x : vec) { + if (x == t) return true; + } + return false; + }; + + int count = 0; + auto add_input = [&](const ggml_tensor * t) { + if (t && !contains(outputs, t) && !contains(inputs, t)) { + if (count < (int)inputs.size()) { + inputs[count++] = t; + } else { + inputs.push_back(t); + } + } + }; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (node->src[i]) { + add_input(node->src[i]); + } + } + for (const auto * f : fused) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (f->src[i]) { + add_input(f->src[i]); + } + } + } + + inputs.resize(count); + return inputs; + } + + std::string op_name() const { + if (fused.empty()) { + return ggml_op_desc(node); + } + std::string name = ggml_op_desc(node); + for (const auto * f : fused) { + name += "+"; + name += ggml_op_desc(f); + } + return name; + } +}; + +struct htp_opformat { + char strides[64 * GGML_MAX_SRC]; + char dims[64 * GGML_MAX_SRC]; + char types[16 * GGML_MAX_SRC]; + char buffs[64 * GGML_MAX_SRC]; + char names[64 * GGML_MAX_SRC]; + + int format_tensor_dims(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); + } else { + return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); + } + } + + void format_op_dims(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += format_tensor_dims(p, inputs[0]); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += format_tensor_dims(p, inputs[i]); + } + + p += sprintf(p, " -> "); + } + + char self[64]; + format_tensor_dims(self, node.dst()); + p += sprintf(p, "%s", self); + } + + int format_tensor_strides(char * str, const struct ggml_tensor * t) { + if (!t) { + return sprintf(str, "NONE"); + } + const char * c = ggml_is_contiguous(t) ? "" : "!"; + + if (t->ne[2] == 1 && t->ne[3] == 1) { + return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); + } else { + return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); + } + } + + void format_op_strides(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += format_tensor_strides(p, inputs[0]); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += format_tensor_strides(p, inputs[i]); + } + + p += sprintf(p, " -> "); + } + + char self[64]; + format_tensor_strides(self, node.dst()); + p += sprintf(p, "%s", self); + } + + void format_op_types(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", ggml_type_name(node.dst()->type)); + } + + const char * tensor_buff_name(const struct ggml_tensor * t) { + if (t && t->buffer) { + return ggml_backend_buffer_name(t->buffer); + } + return "NONE"; + } + + void format_op_buffs(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", tensor_buff_name(inputs[0])); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", tensor_buff_name(inputs[i])); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", tensor_buff_name(node.dst())); + } + + void format_op_names(char * str, const htp_opnode & node) { + char * p = str; + auto inputs = node.get_inputs(); + + if (!inputs.empty()) { + p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE"); + + for (size_t i = 1; i < inputs.size(); i++) { + p += sprintf(p, " x "); + p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE"); + } + + p += sprintf(p, " -> "); + } + + p += sprintf(p, "%s", node.dst()->name); + } + + void format(const htp_opnode & node) { + format_op_dims(dims, node); + format_op_strides(strides, node); + format_op_types(types, node); + format_op_buffs(buffs, node); + format_op_names(names, node); + } + + htp_opformat() {} + htp_opformat(const htp_opnode & node) { format(node); } +}; + +#endif // HTP_OPNODE_H diff --git a/ggml/src/ggml-hexagon/htp-utils.c b/ggml/src/ggml-hexagon/htp-utils.c deleted file mode 100644 index 3f335bf71c0..00000000000 --- a/ggml/src/ggml-hexagon/htp-utils.c +++ /dev/null @@ -1,454 +0,0 @@ - -#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" -#pragma clang diagnostic ignored "-Wmissing-prototypes" -#pragma clang diagnostic ignored "-Wsign-compare" - -#define GGML_COMMON_IMPL_C -#include "ggml-backend-impl.h" -#include "ggml-common.h" -#include "ggml-hexagon.h" -#include "ggml-impl.h" - -#include "htp-utils.h" - -#include <domain.h> -#include <remote.h> -#include <stdbool.h> -#include <stdint.h> -#include <stdio.h> -#include <stdlib.h> -#include <string.h> - -domain * get_domain(int domain_id) { - int i = 0; - int size = sizeof(supported_domains) / sizeof(domain); - - for (i = 0; i < size; i++) { - if (supported_domains[i].id == domain_id) { - return &supported_domains[i]; - } - } - - return NULL; -} - -bool is_valid_domain_id(int domain_id, int compute_only) { - int i = 0; - int size = sizeof(supported_domains) / sizeof(domain); - - if (compute_only) { - return is_CDSP(domain_id); - } - - for (i = 0; i < size; i++) { - if (supported_domains[i].id == domain_id) { - return true; - } - } - - return false; -} - -int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) { - int nErr = AEE_SUCCESS; - int ss_info = 0; - if (domain_type != NULL) { - if (strcmp(domain_type, "LPASS") == 0) { - ss_info = FASTRPC_LPASS; - } else if (strcmp(domain_type, "HPASS") == 0) { - ss_info = FASTRPC_HPASS; - } else { - ss_info = FASTRPC_NSP; - } - } - system_req_payload req = { 0 }; - req.id = FASTRPC_GET_DOMAINS; - req.sys.domains = NULL; - fastrpc_domain * domain = NULL; - if (ss_info != 0) { - req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info); - } else { - req.sys.flags = 0; - } -#ifdef _WIN32 - nErr = AEE_EUNSUPPORTED; - goto bail; -#endif - if (remote_system_request) { - nErr = remote_system_request(&req); - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr); - goto bail; - } - // Allocate memory for domain-info array - req.sys.max_domains = req.sys.num_domains; - if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) { - nErr = AEE_ENOMEMORY; - GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains"); - goto bail; - } - - nErr = remote_system_request(&req); - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr); - goto bail; - } - - for (int i = 0; i < req.sys.num_domains; i++) { - // Verify that only requested type domains were returned - domain = &req.sys.domains[i]; - if (domain->type != ss_info && domain_type != NULL) { - nErr = -1; - GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n"); - goto bail; - } - } - *domains_info = req.sys.domains; - *num_domains = req.sys.num_domains; - } else { - nErr = AEE_EUNSUPPORTED; - goto bail; - } -bail: - if (nErr && !req.sys.domains) { - free(req.sys.domains); - } - return nErr; -} - -int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) { - int err = 0; - remote_rpc_effective_domain_id_t sess = { 0 }; - - sess.domain_name = domain_name; - sess.domain_name_len = strlen(domain_name); - sess.session_id = session_id; - - err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess)); - if (err) { - GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name, - session_id); - return err; - } - - *effec_domain_id = sess.effective_domain_id; - return err; -} - -int get_dsp_support(int * domain) { - int nErr = AEE_SUCCESS; - *domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID - - if (remote_handle_control) { - struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 }; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - goto bail; - } - - if (dsp_capability_domain.capability == 0) { - dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support. - dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT; - dsp_capability_domain.capability = 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, - sizeof(struct remote_dsp_capability)); - if (dsp_capability_domain.capability) { - *domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID - } - } - - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} - -int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) { - int nErr = AEE_SUCCESS; - *capability = 0; - - if (attr == VTCM_PAGE || attr == VTCM_COUNT) { - } else { - nErr = AEE_EBADPARM; - GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n"); - goto bail; - } - if (remote_handle_control) { - if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for VTCM information - * Since the ADSP does not have a dedicated VTCM, we expect the output to be 0 - */ - struct remote_dsp_capability dsp_capability_vtcm_dsp; - dsp_capability_vtcm_dsp.domain = (uint32_t) domain; - dsp_capability_vtcm_dsp.attribute_ID = attr; - dsp_capability_vtcm_dsp.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (nErr == AEE_SUCCESS) { - *capability = dsp_capability_vtcm_dsp.capability; - } else { - GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("Unsupported domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} - -bool is_unsignedpd_supported(int domain_id) { - int nErr = AEE_SUCCESS; - if (remote_handle_control) { - struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 }; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n"); - return false; - } - if (nErr) { - GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr); - return false; - } - if (dsp_capability_domain.capability == 1) { - return true; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n"); - return false; - } - return false; -} - -bool get_unsignedpd_support(void) { - return is_unsignedpd_supported(CDSP_DOMAIN_ID); -} - -bool is_async_fastrpc_supported(int domain) { - int nErr = AEE_SUCCESS; - if (remote_handle_control) { - if (domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for ASYNC_FASTRPC_SUPPORT information - * Async fastrpc is supported only on CDSP - */ - struct remote_dsp_capability dsp_capability_async_support; - dsp_capability_async_support.domain = (uint32_t) domain; - dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT; - dsp_capability_async_support.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (dsp_capability_async_support.capability == 1) { - return true; - } - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return false; -} - -bool is_status_notification_supported(int domain) { - int nErr = AEE_SUCCESS; - - if (remote_handle_control) { - /* - * Query the DSP for STATUS_NOTIFICATION_SUPPORT information - * DSP User PD status notification Support - */ - struct remote_dsp_capability dsp_capability_status_notification_support; - dsp_capability_status_notification_support.domain = (uint32_t) domain; - dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT; - dsp_capability_status_notification_support.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (dsp_capability_status_notification_support.capability == 1) { - return true; - } - if (nErr != AEE_SUCCESS) { - GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return false; -} - -int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) { - int nErr = AEE_SUCCESS; - *capability = 0; - - if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) { - nErr = AEE_EBADPARM; - GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n"); - goto bail; - } - if (remote_handle_control) { - if (domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for HMX SUPPORT information - * HMX is supported on CDSP only - */ - struct remote_dsp_capability dsp_capability_hmx_dsp; - dsp_capability_hmx_dsp.domain = (uint32_t) domain; - dsp_capability_hmx_dsp.attribute_ID = attr; - dsp_capability_hmx_dsp.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (nErr == AEE_SUCCESS) { - *capability = dsp_capability_hmx_dsp.capability; - } else { - GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} - -int get_hex_arch_ver(int domain, int * arch) { - if (!remote_handle_control) { - GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n"); - return AEE_EUNSUPPORTEDAPI; - } - - struct remote_dsp_capability arch_ver; - arch_ver.domain = (uint32_t) domain; - arch_ver.attribute_ID = ARCH_VER; - arch_ver.capability = (uint32_t) 0; - - int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver)); - if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) { - GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n"); - return AEE_EUNSUPPORTEDAPI; - } - - if (err != AEE_SUCCESS) { - GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err); - return err; - } - - switch (arch_ver.capability & 0xff) { - case 0x68: - *arch = 68; - return 0; - case 0x69: - *arch = 69; - return 0; - case 0x73: - *arch = 73; - return 0; - case 0x75: - *arch = 75; - return 0; - case 0x79: - *arch = 79; - return 0; - case 0x81: - *arch = 81; - return 0; - } - return -1; -} - -int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) { - int nErr = AEE_SUCCESS; - *capability = 0; - - if (remote_handle_control) { - if (domain == CDSP_DOMAIN_ID) { - /* - * Query the DSP for HVX SUPPORT information - * HVX is supported on CDSP only - */ - struct remote_dsp_capability dsp_capability_hvx_dsp; - dsp_capability_hvx_dsp.domain = (uint32_t) domain; - dsp_capability_hvx_dsp.attribute_ID = attr; - dsp_capability_hvx_dsp.capability = (uint32_t) 0; - nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp, - sizeof(struct remote_dsp_capability)); - if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) { - GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n"); - GGML_LOG_ERROR("Running the usecase without checking the capability\n"); - nErr = AEE_SUCCESS; - goto bail; - } else if (nErr == AEE_SUCCESS) { - *capability = dsp_capability_hvx_dsp.capability; - } else { - GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTED; - GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain); - goto bail; - } - } else { - nErr = AEE_EUNSUPPORTEDAPI; - GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n"); - } - -bail: - return nErr; -} diff --git a/ggml/src/ggml-hexagon/htp-utils.h b/ggml/src/ggml-hexagon/htp-utils.h deleted file mode 100644 index 7bbae3a0b73..00000000000 --- a/ggml/src/ggml-hexagon/htp-utils.h +++ /dev/null @@ -1,221 +0,0 @@ -#ifndef HTP_UTILS_H -#define HTP_UTILS_H - -#ifdef __cplusplus -extern "C" { -#endif - -#include <AEEStdErr.h> -#include <inttypes.h> -#include <remote.h> -#include <rpcmem.h> -#include <stdbool.h> - -/* Offset to differentiate HLOS and Hexagon error codes. - Stores the value of AEE_EOFFSET for Hexagon. */ -#ifndef DSP_OFFSET -# define DSP_OFFSET 0x80000400 -#endif - -/* Errno for connection reset by peer. */ -#ifndef ECONNRESET -# ifdef __hexagon__ -# define ECONNRESET 104 -# endif -#endif - -/* Abstraction of different OS specific sleep APIs. - SLEEP accepts input in seconds. */ -#ifndef SLEEP -# ifdef __hexagon__ -# define SLEEP(x) \ - { /* Do nothing for simulator. */ \ - } -# else -# ifdef _WINDOWS -# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */ -# else -# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */ -# endif -# endif -#endif - -/* Include windows specific header files. */ -#ifdef _WINDOWS -# include <sysinfoapi.h> -# include <windows.h> -# define _CRT_SECURE_NO_WARNINGS 1 -# define _WINSOCK_DEPRECATED_NO_WARNINGS 1 -/* Including this file for custom implementation of getopt function. */ -# include "getopt_custom.h" -#endif - -/* Includes and defines for all HLOS except windows */ -#if !defined(__hexagon__) && !defined(_WINDOWS) -# include "unistd.h" - -# include <sys/time.h> -#endif - -/* Includes and defines for Hexagon and all HLOS except Windows. */ -#if !defined(_WINDOWS) -/* Weak reference to remote symbol for compilation. */ -# pragma weak remote_session_control -# pragma weak remote_handle_control -# pragma weak remote_handle64_control -# pragma weak fastrpc_mmap -# pragma weak fastrpc_munmap -# pragma weak rpcmem_alloc2 -#endif - -#if !defined(_WINDOWS) -# pragma weak remote_system_request -#endif -/** - * Wrapper for FastRPC Capability API: query DSP support. - * - * @param[out] domain pointer to supported domain. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - */ -int get_dsp_support(int * domain); - -/** - * Wrapper for FastRPC Capability API: query VTCM information. - * - * @param[in] domain value of domain in the queried. - * @param[out] capability capability value of the attribute queried. - * @param[in] attr value of the attribute to the queried. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - */ -int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr); - -/** - * Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain. - * - * @return true if unsigned pd is supported. - * false if unsigned pd is not supported, capability query failed. - */ - -bool get_unsignedpd_support(void); - -/** - * Wrapper for FastRPC Capability API: query unsigned pd support. - * - * @param[in] domain value of domain in the queried. - * @return true if unsigned pd is supported. - * false if unsigned pd is not supported, capability query failed. - */ - -bool is_unsignedpd_supported(int domain_id); - -/** - * is_valid_domain_id API: query a domain id is valid. - * - * @param[in] domain value of domain in the queried. - * @param[in] compute_only value of domain is only compared with CDSP domains supported by the target when enabled. - * @return true if value of domain is valid. - * false if value of domain is not valid. - */ - -bool is_valid_domain_id(int domain_id, int compute_only); - -/** - * get_domain API: get domain struct from domain value. - * - * @param[in] domain value of a domain - * @return Returns domain struct of the domain if it is supported or else - * returns NULL. - * - */ - -domain * get_domain(int domain_id); - -/** - * get_domains_info API: get information for all the domains available on the device - * - * @param[in] domain_type pointer to domain type - * @param[in] num_domains pointer to number of domains - * @param[in] domains_info pointer to save discovered domains information. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - * It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application. - * - */ - -int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info); - -/** - * get_effective_domain_id API: get effective domain id for given session id - * - * @param[in] domain_name pointer to domain name - * @param[in] session_id - * @param[in] effec_domain_id pointer to save obtained effective domain id. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ - -int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id); - -/** - * is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not - * - * @param[in] domain_id value of a domain - * @return Returns true or false stating support of Async FastRPC - * - */ - -bool is_async_fastrpc_supported(int domain_id); - -/** - * is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information - * - * @param[in] domain_id value of a domain - * @return Returns true or false stating status notification support information - * - */ -bool is_status_notification_supported(int domain_id); - -/** - * get_hmx_support_info API: query the DSP for HMX SUPPORT information - * - * @param[in] domain_id value of a domain - * @param[out] capability capability value of the attribute queried. - * @param[in] attr value of the attribute to the queried. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ -int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr); - -/** - * get_hex_arch_ver API: query the Hexagon processor architecture version information - * - * @param[in] domain_id value of a domain - * @param[out] Arch version (73, 75, ...) - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ -int get_hex_arch_ver(int domain, int * arch); - -/** - * get_hvx_support_info API: query the DSP for HVX SUPPORT information - * - * @param[in] domain_id value of a domain - * @param[out] capability capability value of the attribute queried. - * @param[in] attr value of the attribute to the queried. - * @return 0 if query is successful. - * non-zero if error, return value points to the error. - * - */ -int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr); - -#ifdef __cplusplus -} -#endif - -#endif //DSP_CAPABILITIES_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 6a34a215fa4..f4b44fe1a65 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include_directories( ${HEXAGON_SDK_ROOT}/incs ${HEXAGON_SDK_ROOT}/incs/stddef + ${CMAKE_CURRENT_SOURCE_DIR}/../../../include ${CMAKE_CURRENT_SOURCE_DIR}/../.. ${CMAKE_CURRENT_SOURCE_DIR}/.. ${CMAKE_CURRENT_SOURCE_DIR} @@ -17,28 +18,67 @@ add_library(${HTP_LIB} SHARED main.c htp_iface_skel.c worker-pool.c - htp-dma.c - hvx-sigmoid.c - hvx-inverse.c - hvx-exp.c - hvx-utils.c + hex-dma.c +) + +target_compile_definitions(${HTP_LIB} PRIVATE + $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1> + $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,FARF_HIGH=1,> + FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) + +if (GGML_HEXAGON_FA_EXP2_HF) + message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") + target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) +endif() + +# HMX acceleration: available on v73+ architectures +set(HTP_HMX_VERSIONS v73 v75 v79 v81) +list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) + +if (_hmx_idx GREATER_EQUAL 0) + target_sources(${HTP_LIB} PRIVATE + hmx-matmul-ops.c + hmx-flash-attn-ops.c + hmx-queue.c + ) + + # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) + set_source_files_properties( + hmx-flash-attn-ops.c + hmx-matmul-ops.c + hmx-queue.c + PROPERTIES COMPILE_OPTIONS "-mhmx" + ) + + target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) +endif() + +build_idl(htp_iface.idl ${HTP_LIB}) + +target_sources(${HTP_LIB} PRIVATE matmul-ops.c binary-ops.c unary-ops.c + sum-rows-ops.c softmax-ops.c act-ops.c rope-ops.c flash-attn-ops.c set-rows-ops.c get-rows-ops.c + cpy-ops.c + repeat-ops.c + argsort-ops.c + ssm-conv.c + cumsum-ops.c + fill-ops.c + concat-ops.c + diag-ops.c + solve-tri-ops.c + gated-delta-net-ops.c + pad-ops.c ) -target_compile_definitions(${HTP_LIB} PRIVATE - $<IF:$<BOOL:${HEXAGON_HTP_DEBUG}>,HTP_DEBUG=1,NDEBUG=1> - FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) - -build_idl(htp_iface.idl ${HTP_LIB}) - set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) install(TARGETS ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 88bd2ddc435..6416d2dfbc3 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -2,101 +2,92 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include <HAP_farf.h> -#include <HAP_mem.h> #include <HAP_perf.h> -#include <HAP_ps.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> + #include <math.h> -#include <qurt_thread.h> #include <string.h> +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" - -#define htp_act_preamble3 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; +#include "htp-ops.h" -#define htp_act_preamble2 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_act_preamble \ + const struct htp_tensor * src0 = actx->octx->src[0]; \ + const struct htp_tensor * src1 = actx->octx->src[1]; \ + const struct htp_tensor * dst = actx->octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne10 = src1 ? src1->ne[0] : 0; \ + const uint32_t ne11 = src1 ? src1->ne[1] : 0; \ + const uint32_t ne12 = src1 ? src1->ne[2] : 0; \ + const uint32_t ne13 = src1 ? src1->ne[3] : 0; \ + \ + const uint32_t nb10 = src1 ? src1->nb[0] : 0; \ + const uint32_t nb11 = src1 ? src1->nb[1] : 0; \ + const uint32_t nb12 = src1 ? src1->nb[2] : 0; \ + const uint32_t nb13 = src1 ? src1->nb[3] : 0; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { - htp_act_preamble3; +struct htp_act_context { + struct htp_ops_context * octx; + + // Precomputed values + const uint8_t * data_src0; + const uint8_t * data_src1; + uint8_t * data_dst; + + size_t src0_row_size; + size_t src1_row_size; + size_t dst_row_size; - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; + size_t src0_row_size_aligned; + size_t src1_row_size_aligned; + size_t dst_row_size_aligned; + size_t src0_spad_half_size; + size_t src1_spad_half_size; + size_t dst_spad_half_size; + uint32_t block; + uint32_t src0_nrows; + uint32_t src0_nrows_per_thread; + int nc; +}; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows +static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + htp_act_preamble; + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; + + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -108,43 +99,34 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; - - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } + const int nc = actx->nc; - const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -175,9 +157,9 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); //swiglu(x) = x1 * sigmoid(x0) - hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, nc); - hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, - (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, nc); + hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, + (const uint8_t *) src1_spad_ptr, nc); } dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, @@ -203,27 +185,19 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { - htp_act_preamble3; +static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -233,45 +207,36 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; + const int nc = actx->nc; - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } - - const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least " "%zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } - const float alpha = ((const float *) (op_params))[2]; - const float limit = ((const float *) (op_params))[3]; + const float alpha = ((const float *) (actx->octx->op_params))[2]; + const float limit = ((const float *) (actx->octx->op_params))[3]; + + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { @@ -304,18 +269,18 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // x (src0_spad_data) = std::min(src0_p[k], limit); - hvx_min_scalar_f32((const uint8_t *) src0_spad_ptr, limit, (uint8_t *) src0_spad_ptr, nc); + hvx_min_scalar_f32((uint8_t *) src0_spad_ptr, (const uint8_t *) src0_spad_ptr, limit, nc); // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit); - hvx_clamp_scalar_f32((const uint8_t *) src1_spad_ptr, -limit, limit, (uint8_t *) src1_spad_ptr, nc); + hvx_clamp_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, -limit, limit, nc); // y (src1_spad_data) = y1 + 1.f - hvx_add_scalar_f32((const uint8_t *) src1_spad_ptr, 1.0, (uint8_t *) src1_spad_ptr, nc); + hvx_add_scalar_f32((uint8_t *) src1_spad_ptr, (const uint8_t *) src1_spad_ptr, 1.0, nc); // x1 (dst_spad_data) = alpha * (x) - hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, alpha, (uint8_t *) dst_spad_ptr, nc); + hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, alpha, nc); // x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1)) - hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, nc); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // out = x * sigmoid(alpha * x) * (y + 1.f) - hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, - (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc); + hvx_mul_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, + (const uint8_t *) src1_spad_ptr, nc); } dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, @@ -342,26 +307,20 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, } -static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { - htp_act_preamble2; +static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; - const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); + const size_t src0_row_size = actx->src0_row_size; + const size_t dst_row_size = actx->dst_row_size; + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -371,25 +330,29 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * data_src0 = (const uint8_t *) src0->data; - uint8_t * data_dst = (uint8_t *) dst->data; + const uint8_t * data_src0 = actx->data_src0; + uint8_t * data_dst = actx->data_dst; + + // nc/ne0 matches. + const int ne0_val = actx->nc; // == dst->ne[0] - uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; // In gelu = x*sigmoid(x*1.702) - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -415,9 +378,9 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // gelu = x * sigmoid(1.702 * x) // current implementation - hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, (float) 1.702, (uint8_t *) dst_spad_ptr, ne0); - hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); - hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); + hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -442,34 +405,21 @@ static void unary_gelu_fp32_per_thread(const struct htp_tensor * src0, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_gelu_fp32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - unary_gelu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - - -static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { - htp_act_preamble2; +static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + htp_act_preamble; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; - const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); + const size_t src0_row_size = actx->src0_row_size; + const size_t dst_row_size = actx->dst_row_size; + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -479,24 +429,27 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * data_src0 = (const uint8_t *) src0->data; - uint8_t * data_dst = (uint8_t *) dst->data; + const uint8_t * data_src0 = actx->data_src0; + uint8_t * data_dst = actx->data_dst; - uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + const int ne0_val = actx->nc; // == dst->ne[0] - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; + + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -522,8 +475,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // silu = x * sigmoid(x) - hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); - hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -548,30 +501,130 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} +static const float GELU_COEF_A = 0.044715f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; -static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} +static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + htp_act_preamble; -static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; -static int execute_op_activations_fp32(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; + + const int nc = actx->nc; + + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; + + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); + + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; + + const int BLOCK = actx->block; + if (BLOCK == 0) { + FARF(ERROR, + "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); + return; + } + + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)), + src1_row_size_aligned, src1_row_size, block_size); + } + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float))); + const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float))); + uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float))); + + // geglu tanh implementation + // geglu(x, g) = gelu(x) * g + // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))) + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A + hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI + hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res) + hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res + 0.5f + hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g + } + + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, + dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)), + src1_row_size_aligned, src1_row_size, pref_block_size); + } + } - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + dma_queue_flush(dma_queue); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static int execute_op_activations_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; if (((src0->ne[0] * SIZEOF_FP32) != src0->nb[1]) || ((dst->ne[0] * SIZEOF_FP32) != dst->nb[1])) { FARF(ERROR, "Non-contiguous tensors are not supported at this time \n"); @@ -583,51 +636,51 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) { switch (octx->op) { case HTP_OP_UNARY_SILU: - act_op_func = unary_silu_fp32; + act_op_func = (worker_callback_t)unary_silu_f32_per_thread; op_type = "silu-f32"; break; case HTP_OP_GLU_SWIGLU: - act_op_func = glu_swiglu_fp32; + act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread; op_type = "swiglu-f32"; break; case HTP_OP_GLU_SWIGLU_OAI: - act_op_func = glu_swiglu_oai_fp32; + act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread; op_type = "swiglu-oai-f32"; break; case HTP_OP_UNARY_GELU: - act_op_func = unary_gelu_fp32; + act_op_func = (worker_callback_t)unary_gelu_f32_per_thread; op_type = "gelu-f32"; break; + + case HTP_OP_GLU_GEGLU: + act_op_func = (worker_callback_t)glu_geglu_f32_per_thread; + op_type = "geglu-f32"; + break; default: FARF(ERROR, "Unsupported activations Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; } - const uint32_t n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); size_t src0_row_size = src0->nb[1]; - size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used + size_t src1_row_size = src1 ? src1->nb[1] : src0->nb[1]; size_t dst_row_size = dst->nb[1]; - const bool src1_valid = src1->ne[0]; - if (!src1_valid) { - src1_row_size = src0_row_size; - } + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); - const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); // VTCM scratchpads for all tensors // N rows per thread, padded to HVX vector size - size_t spad_size_per_row = (src0_row_size_aligned + src1_row_size_aligned) + dst_row_size_aligned; size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads* spad_size_per_row); // Make sure the reserved vtcm size is sufficient - if(vtcm_row_per_thread ==0){ + if (vtcm_row_per_thread == 0) { FARF(ERROR, "act-%s : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", op_type, octx->ctx->vtcm_size, spad_size_per_row * n_threads); return HTP_STATUS_VTCM_TOO_SMALL; @@ -645,7 +698,11 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - if (src1->ne[0]) { + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->dst_spad.src = NULL; + + if (src1) { FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, @@ -656,21 +713,64 @@ static int execute_op_activations_fp32(struct htp_ops_context * octx) { octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs); + if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + return HTP_STATUS_OK; } - return err; + // Prepare context + struct htp_act_context actx; + actx.octx = octx; + + actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + + actx.src0_row_size = src0_row_size; + actx.src1_row_size = src1_row_size; + actx.dst_row_size = dst_row_size; + + actx.src0_row_size_aligned = src0_row_size_aligned; + actx.src1_row_size_aligned = src1_row_size_aligned; + actx.dst_row_size_aligned = dst_row_size_aligned; + + actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2; + actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2; + actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2; + + actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned; + actx.src0_nrows = src0_nrows; + + actx.nc = dst->ne[0]; + + // Pointers and GLU logic + const uint8_t * data_src0 = (const uint8_t *) src0->data; + const uint8_t * data_src1 = src1 ? (const uint8_t *) src1->data : NULL; + + if (!src1 && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) { + const int32_t swapped = octx->op_params[1]; + data_src1 = data_src0; + actx.src1_row_size = actx.src0_row_size; + + size_t nc_in_bytes = actx.nc * SIZEOF_FP32; + if (swapped) { + data_src0 += nc_in_bytes; + } else { + data_src1 += nc_in_bytes; + } + } + + actx.data_src0 = data_src0; + actx.data_src1 = data_src1; + actx.data_dst = (uint8_t *) dst->data; + + worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads); + return HTP_STATUS_OK; } int op_activations(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: - err = execute_op_activations_fp32(octx); + err = execute_op_activations_f32(octx); break; default: diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c new file mode 100644 index 00000000000..73af38a35ab --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -0,0 +1,294 @@ +#include <string.h> +#include <stdlib.h> +#include <math.h> +#include <HAP_farf.h> +#include <HAP_perf.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "ggml.h" + +#include "hvx-utils.h" +#include "hex-dma.h" + +#include "htp-ctx.h" +#include "htp-ops.h" +#include "htp-ops.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +struct htp_argsort_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; +}; + +static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y) +{ + const HVX_Vector one = Q6_V_vsplat_R(1); + const HVX_Vector zero = Q6_V_vzero(); + + HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y); + HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero); + HVX_Vector sum = hvx_vec_reduce_sum_i32(matches); + return hvx_vec_get_i32(sum) == 32; +} + +// Sorts values and mirrors swaps to indices. +static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + while (i <= j) { + // Vectorized scan for i + while (i <= j) { + // Check if we have at least one full vector + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(pivot_vec, vals_vec)) { + // If all elements are < pivot, we can skip this whole block + i += 32; + continue; + } + } + + // Scalar fallback / cleanup + if (values[i] < pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j + while (i <= j) { + if (j - 32 >= i) { + // Load 32 elements ending at j. + // Since we want `values[j] > pivot`, let's load from j-31 to j. + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(vals_vec, pivot_vec)) { + j -= 32; + continue; + } + } + + if (values[j] > pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_asc(values, indices, left, j); + if (i < right) quicksort_values_indices_asc(values, indices, i, right); +} + +static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + + while (i <= j) { + // Vectorized scan for i (values[i] > pivot) + while (i <= j) { + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(vals_vec, pivot_vec)) { + i += 32; + continue; + } + } + + if (values[i] > pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j (values[j] < pivot) + while (i <= j) { + if (j - 32 >= i) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(pivot_vec, vals_vec)) { + j -= 32; + continue; + } + } + + if (values[j] < pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_desc(values, indices, left, j); + if (i < right) quicksort_values_indices_desc(values, indices, i, right); +} + +// LUT for ramp initialization of argsort output (first 32 members) +int32_t argosrt_ramp_lut[32] __attribute__((aligned(VLEN))) = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 +}; + +static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { + struct htp_argsort_context * actx = (struct htp_argsort_context *)data; + struct htp_ops_context * octx = actx->octx; + + // Unpack context + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + // Scratchpad memory + uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i; + + // Dimensions + uint32_t ne00 = src0->ne[0]; + uint32_t ne01 = src0->ne[1]; + uint32_t ne02 = src0->ne[2]; + uint32_t ne03 = src0->ne[3]; + + uint32_t nb01 = src0->nb[1]; + //uint32_t nb02 = src0->nb[2]; + //uint32_t nb03 = src0->nb[3]; + + uint32_t nb1 = dst->nb[1]; + //uint32_t nb2 = dst->nb[2]; + //uint32_t nb3 = dst->nb[3]; + + // Sort order + enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0]; + + // Rows to process + uint32_t total_rows = ne01 * ne02 * ne03; + uint32_t rows_per_thread = actx->nrows_per_thread; + uint32_t start_row = rows_per_thread * i; + uint32_t end_row = MIN(start_row + rows_per_thread, total_rows); + + // Scratchpad layout: + // We need space for one row of float data (values) and one row of int32 indices. + // values: ne00 * sizeof(float) + // indices: ne00 * sizeof(int32_t) + // Padded to 128 bytes. + + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t num_vec_ind_values = hmx_ceil_div(ne00, VLEN/(sizeof(int32_t))); + float * values_buf = (float *) spad; + int32_t * indices_buf = (int32_t *) (spad + values_size); + HVX_Vector * indices_buf_vec = (HVX_Vector *) (spad + values_size); + const HVX_Vector ind_init_vec = *(HVX_Vector *)argosrt_ramp_lut; + const HVX_Vector ind_diff_vec = Q6_V_vsplat_R(32); + + for (uint32_t r = start_row; r < end_row; r++) { + uint32_t src_offset = r * nb01; + uint32_t dst_offset = r * nb1; + + uint8_t * src_ptr = (uint8_t *) src0->data + src_offset; + uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset; + + hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1); + hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00); + + // Initialize indices - Start with values 0..31, add 32 for additional vec iterations + HVX_Vector curr_ind_vec = ind_init_vec; + for (uint32_t j_vec = 0; j_vec < num_vec_ind_values; j_vec++) { + indices_buf_vec[j_vec] = curr_ind_vec; + curr_ind_vec = Q6_Vw_vadd_VwVw(curr_ind_vec, ind_diff_vec); + } + + // Sort values and mirror swaps to indices + if (order == GGML_SORT_ORDER_ASC) { + quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1); + } else { + quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1); + } + + // Copy indices back to DDR + hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00); + } +} + +int op_argsort(struct htp_ops_context * octx) { + // Check supported types + if (octx->src[0]->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t total_rows = octx->src[0]->ne[1] * octx->src[0]->ne[2] * octx->src[0]->ne[3]; + const uint32_t n_threads = MIN(total_rows, octx->n_threads); + + // Allocate scratchpad + // We need 1 row of float + 1 row of int32 per thread. + uint32_t ne00 = octx->src[0]->ne[0]; + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128); + size_t spad_per_thread = values_size + indices_size; + + // Make sure we round up to 256 for alignment requirements + spad_per_thread = hex_round_up(spad_per_thread, 256); + + size_t total_spad_size = spad_per_thread * n_threads; + + if (octx->ctx->vtcm_size < total_spad_size) { + FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.size = total_spad_size; + octx->src0_spad.size_per_thread = spad_per_thread; + octx->src0_spad.src = NULL; + + FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", + octx->src[0]->ne[0], octx->src[0]->ne[1], octx->src[0]->ne[2], octx->src[0]->ne[3], + octx->dst->ne[0], octx->dst->ne[1], octx->dst->ne[2], octx->dst->ne[3], + octx->src[0]->data, octx->dst->data); + + struct htp_argsort_context actx; + actx.octx = octx; + actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads; + + // Run jobs + worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index 8ed7f67d9c8..52013ad0fec 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -2,42 +2,51 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif - #include <HAP_farf.h> -#include <HAP_mem.h> #include <HAP_perf.h> -#include <HAP_ps.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> + #include <math.h> -#include <qurt_thread.h> #include <string.h> +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" +#include "htp-ops.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +// Context for binary operations +struct htp_binary_context { + struct htp_ops_context * octx; -typedef void (*hvx_elemwise_f32_func)(const uint8_t * src0, - const uint8_t * src1, - uint8_t * data_dst, - const int num_elems); + struct fastdiv_values src0_dim1_div; // ne01 + struct fastdiv_values src0_dim2_div; // ne02 + struct fastdiv_values src0_dim12_div;// ne03 -static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 }; -static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f32_opt, hvx_sub_f32_opt }; + struct fastdiv_values src1_dim1_div; // ne11 + struct fastdiv_values src1_dim2_div; // ne12 + struct fastdiv_values src1_dim3_div; // ne13 -#define htp_binary_preamble \ - const struct htp_tensor * src0 = &octx->src0; \ - const struct htp_tensor * src1 = &octx->src1; \ - const struct htp_tensor * src2 = &octx->src2; \ - struct htp_tensor * dst = &octx->dst; \ + uint32_t block_max; + uint32_t nrows_per_thread; + size_t src0_row_size_aligned; + size_t src1_row_size_aligned; + size_t dst_row_size_aligned; + + bool split_at_ne01; + bool split_at_ne02; +}; + +#define htp_binary_preamble \ + const struct htp_tensor * src0 = octx->src[0]; \ + const struct htp_tensor * src1 = octx->src[1]; \ + const struct htp_tensor * dst = octx->dst; \ \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -49,312 +58,815 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_opt, hvx_add_f const uint32_t ne12 = src1->ne[2]; \ const uint32_t ne13 = src1->ne[3]; \ \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ const uint32_t nb01 = src0->nb[1]; \ const uint32_t nb02 = src0->nb[2]; \ const uint32_t nb03 = src0->nb[3]; \ \ - const uint32_t nb10 = src1->nb[0]; \ const uint32_t nb11 = src1->nb[1]; \ const uint32_t nb12 = src1->nb[2]; \ const uint32_t nb13 = src1->nb[3]; \ \ - const uint32_t nb0 = dst->nb[0]; \ const uint32_t nb1 = dst->nb[1]; \ const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; \ - \ - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const uint32_t nb3 = dst->nb[3]; -static void binary_job_f32_per_thread(struct htp_ops_context * octx, - uint8_t * spad_data, - uint32_t nth, - uint32_t ith, - enum htp_op op) { - htp_binary_preamble; +static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, uint32_t ne01, uint32_t ne02) { + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->src0_dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->src0_dim1_div); + i01 = rem - i02 * ne01; - const size_t src0_row_size = nb01; - const size_t src1_row_size = nb11; - const size_t dst_row_size = nb1; + uint32_t rows_left = end_row - ir; + uint32_t block_limit = rows_left; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows + if (bctx->split_at_ne01) { + block_limit = MIN(block_limit, ne01 - i01); + } + if (bctx->split_at_ne02) { + uint32_t rows_in_plane = (ne02 * ne01) - rem; + block_limit = MIN(block_limit, rows_in_plane); + } - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + return MIN(bctx->block_max, block_limit); +} - // no work for this thread - if (src0_start_row >= src0_end_row) { - return; +// Macro for scalar op switch +#define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + default: break; \ + } \ } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - int is_aligned = 1; - int opt_path = 0; - if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) || - (0 == htp_is_aligned((void *) dst->data, VLEN))) { - FARF(HIGH, "binary-f32: unaligned addresses in elementwise op, possibly slower execution\n"); - is_aligned = 0; +// Macro for vector op switch (All Aligned) +#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; + +// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned) +#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } - hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op]; +// Macro for vector op switch (All Unaligned - generic loop used in element repeat) +#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } - uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size); +// 1. Scalar src1 (ne10 == 1) +static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; - const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size); - uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size); + const uint32_t src0_type = octx->src[0]->type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + FARF(HIGH, "binary-scalar: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + // Preamble + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->src0_dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + // Main loop + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); - const uint32_t ne02_ne01 = ne02 * ne01; + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - const uint32_t i03 = fastdiv(ir, &octx->src0_div21); - const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1); - const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->src0_dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->src0_dim1_div); + i01 = rem - i02 * ne01; - const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3); - const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2); - const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1); + // src1 indices (broadcast/repeat) + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div); - const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size; + uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint32_t s1_stride = (ne11 == 1) ? 0 : nb11; - if (ir + 1 < src0_end_row) { - htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size); - if (src1_row_size == src0_row_size) { - htp_l2fetch(src1_ptr, 1, src1_row_size, src1_row_size); - } + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00); + src1_ptr += s1_stride; } - const uint32_t nr0 = ne00 / ne10; - if (nr0 > 1) { - if ((1 == is_aligned) && (nr0 == ne00)) { - hvx_bcast_fp32_a(spad_data_th, *(float *) src1_ptr, nr0); - } else { - for (uint32_t r = 0; r < nr0; r++) { - memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11); - } - } - func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, (uint8_t *) dst_ptr, ne00); - } else { - func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00); + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->src0_dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); + ir_prefetch += next_block_size; } + ir += current_block_size; + } + dma_queue_flush(q); +} - src0_ptr += src0_row_size; - dst_ptr += dst_row_size; +// 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast +static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t src0_type = octx->src[0]->type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + FARF(HIGH, "binary-same-shape: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t src1_spad_half = octx->src1_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->src0_dim1_div); + i01 = rem - i02 * ne01; + + uint32_t i13 = (ne13 == 1) ? 0 : i03; + uint32_t i12 = (ne12 == 1) ? 0 : i02; + uint32_t i11 = (ne11 == 1) ? 0 : i01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * src1_curr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_curr), bctx->src1_row_size_aligned, nb11, row_size_bytes, current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; } - t2 = HAP_perf_get_qtimer_count(); + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); + } - FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, - ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->src0_dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->src0_dim1_div); + i01 = rem - i02 * ne01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->src0_dim1_div); + p01 = prem - p02 * ne01; + + uint32_t p13 = (ne13 == 1) ? 0 : p03; + uint32_t p12 = (ne12 == 1) ? 0 : p02; + uint32_t p11 = (ne11 == 1) ? 0 : p01; + + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; + + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, nb11, row_size_bytes, next_block_size); + + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); } -static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, - uint8_t * spad_data, - uint32_t nth, - uint32_t ith, - hvx_elemwise_f32_func func_HVX) { +// 3. Row Broadcast (ne11 == 1, ne12 == 1, single row src1) +static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const size_t src0_row_size = nb01; - const size_t src1_row_size = nb11; - const size_t dst_row_size = nb1; + const uint32_t src0_type = octx->src[0]->type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + FARF(HIGH, "binary-row-bcast: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - // no work for this thread - if (src0_start_row >= src0_end_row) { - return; - } + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + void * s1_ptr = (void *) src1_spad_base; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; - if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) || - (0 == htp_is_aligned((void *) dst->data, VLEN))) { - FARF(HIGH, "add-id-f32: unaligned addresses, possibly slower execution\n"); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; - - const uint32_t ne02_ne01 = ne02 * ne01; - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - // src0 indices - const uint32_t i03 = fastdiv(ir, &octx->src0_div21); - const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1); - const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); - - // src1 indices - const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]); - assert(i11 >= 0 && i11 < ne11); - - float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1); - const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01); - const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11); - - if (ir + 1 < src0_end_row) { - htp_l2fetch(src0_ptr + ne00, 1, src0_row_size, src0_row_size); - if (src1_row_size == src0_row_size) { - htp_l2fetch(src1_ptr + ne10, 1, src1_row_size, src1_row_size); - } + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); } - const uint32_t nr0 = ne00 / ne10; - if (nr0 > 1) { - for (uint32_t r = 0; r < nr0; r++) { - memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10); - } - func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) spad_data, (uint8_t *) dst_ptr, ne00); - } else { - func_HVX((const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, (uint8_t *) dst_ptr, ne00); + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); + ir_prefetch += next_block_size; } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 4. Vector Complex (ne10 == ne00, complex broadcast) +static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t src0_type = octx->src[0]->type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + FARF(HIGH, "binary-complex: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; } - t2 = HAP_perf_get_qtimer_count(); + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div); - FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], - src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], - dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + // Read src1 from DDR (unaligned) + COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00); + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); } -static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; +// 5. Element Repeat (ne10 != ne00) +static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; - switch (octx->op) { - case HTP_OP_MUL: - case HTP_OP_ADD: - case HTP_OP_SUB: - binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op); - break; + const uint32_t src0_type = octx->src[0]->type; + const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); + const uint32_t row_size_bytes = ne00 * elem_size_bytes;; + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + FARF(HIGH, "binary-repeat: %d/%d (%u:%u) row-size %u (%u)", ith, nth, start_row, end_row, nb01, bctx->dst_row_size_aligned); + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } - case HTP_OP_ADD_ID: - binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32); - break; + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div); + + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + // Repeat src1 row + for (uint32_t c = 0; c < ne00; c += ne10) { + uint32_t len = MIN(ne10, ne00 - c); + // Use UUU for speed and simplicity + COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len); + } + } - default: - FARF(ERROR, "Unknown Binary Op %u", octx->op); - break; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; } + dma_queue_flush(q); } -static int execute_op_binary_f32(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; +// 6. ADD_ID (src1 gathered via src2 indices) +static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; + const uint32_t ne11 = src1->ne[1]; // for bounds check + + const uint32_t nb01 = src0->nb[1]; + const uint32_t nb02 = src0->nb[2]; + const uint32_t nb03 = src0->nb[3]; + const uint32_t nb11 = src1->nb[1]; // src1 row stride + + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t rem = ir_prefetch - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - worker_callback_t binary_op_func; - const char * op_type = NULL; + uint32_t i03 = fastdiv(ir, &bctx->src0_dim12_div); + uint32_t rem = ir - i03 * (ne02 * ne01); + uint32_t i02 = fastdiv(rem, &bctx->src0_dim1_div); + uint32_t i01 = rem - i02 * ne01; - switch (octx->op) { - case HTP_OP_MUL: - binary_op_func = binary_job_dispatcher_f32; - op_type = "mul-f32"; - break; + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; // linear within block since we split at ne01 - case HTP_OP_ADD: - binary_op_func = binary_job_dispatcher_f32; - op_type = "add-f32"; - break; + const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]); - case HTP_OP_SUB: - binary_op_func = binary_job_dispatcher_f32; - op_type = "sub-f32"; - break; + uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11; + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; - case HTP_OP_ADD_ID: - binary_op_func = binary_job_dispatcher_f32; - op_type = "add-id-f32"; - break; + hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00); + } - default: - FARF(ERROR, "Unsupported binary-Op %u\n", octx->op); - return HTP_STATUS_NO_SUPPORT; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03 = fastdiv(ir_prefetch, &bctx->src0_dim12_div); + uint32_t prem = ir_prefetch - p03 * (ne02 * ne01); + uint32_t p02 = fastdiv(prem, &bctx->src0_dim1_div); + uint32_t p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; } + dma_queue_flush(q); +} + +static int execute_op_binary(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; - const int n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); + + // Use packed row sizes for VTCM allocation + const uint32_t src0_type = octx->src[0]->type; + const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); + const size_t src0_row_size = src0->ne[0] * elem_size; + const size_t src1_row_size = src1->ne[0] * elem_size; + const size_t dst_row_size = dst->ne[0] * elem_size; + + size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + bool is_add_id = (octx->op == HTP_OP_ADD_ID); + bool is_scalar = !is_add_id && (src1->ne[0] == 1); + + bool is_transposed = (src0->nb[1] < src0_row_size || src1->nb[1] < src1_row_size || dst->nb[1] < dst_row_size); + + bool is_same_shape = !is_add_id && !is_scalar && !is_transposed && + (src1->ne[0] == src0->ne[0] && src0->ne[0] % VLEN == 0) && + (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && + (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && + (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); + + bool is_row_bcast = is_same_shape && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); + bool is_complex = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] == src0->ne[0]); + bool is_repeat = !is_add_id && !is_scalar && !is_same_shape && (src1->ne[0] != src0->ne[0]); + + size_t spad_row_total; + if (is_same_shape) { + spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned); + } else { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } - const size_t src0_row_size = src0->nb[1]; - const size_t src1_row_size = src1->nb[1]; - const size_t dst_row_size = dst->nb[1]; + size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total); - // VTCM scratchpads for all tensors - octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads; + // Adjust for static src1 in row_bcast case + if (is_row_bcast) { + size_t needed_static = src1_row_size_aligned; + if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL; + size_t avail = octx->ctx->vtcm_size - needed_static; + rows_per_buffer = avail / (n_threads * spad_row_total); + } - size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + if (rows_per_buffer < 1) { + FARF(ERROR, "binary: VTCM too small\n"); + return HTP_STATUS_VTCM_TOO_SMALL; + } - FARF(HIGH, - "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, - octx->dst_spad.size); + octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned; + octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned; - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, - octx->ctx->vtcm_size, spad_size); - return HTP_STATUS_VTCM_TOO_SMALL; + if (is_add_id || is_scalar || is_complex || is_repeat || is_row_bcast) { + octx->src1_spad.size_per_thread = 0; + } else { + octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + if (is_row_bcast) { + octx->src1_spad.size = src1_row_size_aligned; + } else { + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); + if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) { + return HTP_STATUS_VTCM_TOO_SMALL; + } - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; - octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]); - octx->src0_div3 = init_fastdiv_values(src0->ne[3]); - octx->src0_div2 = init_fastdiv_values(src0->ne[2]); - octx->src0_div1 = init_fastdiv_values(src0->ne[1]); + if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + return HTP_STATUS_OK; + } - octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]); - octx->src1_div3 = init_fastdiv_values(src1->ne[3]); - octx->src1_div2 = init_fastdiv_values(src1->ne[2]); - octx->src1_div1 = init_fastdiv_values(src1->ne[1]); + dma_queue * q = octx->ctx->dma[0]; + if (is_row_bcast) { + dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1); + } - worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs); + struct htp_binary_context bctx; + bctx.octx = octx; + bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + bctx.block_max = rows_per_buffer; + bctx.src0_row_size_aligned = src0_row_size_aligned; + bctx.src1_row_size_aligned = src1_row_size_aligned; + bctx.dst_row_size_aligned = dst_row_size_aligned; + + bctx.src0_dim1_div = init_fastdiv_values(src0->ne[1]); + bctx.src0_dim2_div = init_fastdiv_values(src0->ne[2]); + bctx.src0_dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); + + bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); + bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); + bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); + + bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]); + bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); + + bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]); + bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); + + bctx.split_at_ne01 = (src0->ne[2] > 1) && ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); + bctx.split_at_ne02 = (src0->ne[3] > 1) && ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); + + worker_callback_t worker_func; + if (is_add_id) worker_func = binary_job_add_id; + else if (is_scalar) worker_func = binary_job_scalar; + else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; + else if (is_same_shape) worker_func = binary_job_vector_same_shape; + else if (is_complex) worker_func = binary_job_vector_complex; + else worker_func = binary_job_element_repeat; + + if (is_row_bcast) { + dma_queue_pop(q); } - return err; + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads); + + return HTP_STATUS_OK; } int op_binary(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - switch (octx->src0.type) { - case HTP_TYPE_F32: - err = execute_op_binary_f32(octx); - break; + // Does not support permutations of src1 + const struct htp_tensor * src1 = octx->src[1]; + if (src1->nb[1] < src1->nb[0]) { + return HTP_STATUS_NO_SUPPORT; + } - default: - err = HTP_STATUS_NO_SUPPORT; - break; + const uint32_t src0_type = octx->src[0]->type; + if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) { + return execute_op_binary(octx); } - return err; + return HTP_STATUS_NO_SUPPORT; } + diff --git a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake index 7fa236e328f..ed5c198468c 100644 --- a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +++ b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake @@ -138,15 +138,15 @@ set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,") set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,") #Compiler Options -set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") +set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") -set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") -set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O3") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") +set(CMAKE_CXX_FLAGS_RELEASE "${COMMON_FLAGS} -O2") set(CMAKE_C_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") -set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O3 -g") -set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O3") +set(CMAKE_C_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") +set(CMAKE_C_FLAGS_RELEASE "${COMMON_FLAGS} -O2") set(CMAKE_ASM_FLAGS_DEBUG "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_DEBUG}") set(CMAKE_ASM_FLAGS_RELEASE "${COMMON_FLAGS} ${CMAKE_CXX_FLAGS_RELEASE}") diff --git a/ggml/src/ggml-hexagon/htp/concat-ops.c b/ggml/src/ggml-hexagon/htp/concat-ops.c new file mode 100644 index 00000000000..f2a381313c5 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/concat-ops.c @@ -0,0 +1,277 @@ +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hexagon_types.h" +#include "hexagon_protos.h" +#include "hvx_hexagon_protos.h" +#include "hex-dma.h" +#include "vtcm-utils.h" +#include "hvx-utils.h" +#include "hex-fastdiv.h" +#include <string.h> + +struct htp_concat_context { + struct htp_ops_context * octx; + uint32_t dim; + uint32_t nrows_per_thread; + struct fastdiv_values div_ne0; + struct fastdiv_values div_ne1; + struct fastdiv_values div_ne2; +}; + +static void concat_2d_f32_transposed(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t src0_ne0 = src0->ne[0]; + const uint32_t src1_ne0 = src1->ne[0]; + const uint32_t ne1 = dst->ne[1]; + + const uint32_t start_i = ith * cctx->nrows_per_thread; + const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1; + if (start_i >= end_i) return; + + dma_queue * q = octx->ctx->dma[ith]; + + uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + const uint32_t block_i = 32; + const uint32_t spad1_stride = block_i * sizeof(float); + + int32_t offsets[32] __attribute__((aligned(128))); + for(int k=0; k<32; k++) { + offsets[k] = k * spad1_stride; + } + HVX_Vector vv = *(HVX_Vector*)offsets; + const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 32); + const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(float), VLEN); + uint32_t mu = src1_ne0_padded * spad1_stride; + + for (uint32_t i = start_i; i < end_i; i += block_i) { + uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i; + + uint32_t src1_width_bytes = current_block_i * sizeof(float); + uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1]; + dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0); + + uint32_t src0_row_bytes = src0_ne0 * sizeof(float); + uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1]; + dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i); + + dma_queue_pop(q); // src1 + + HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride); + + for (uint32_t j = 0; j < src1_ne0_padded; j += 32) { + #pragma unroll(4) + for (uint32_t ii = 0; ii < current_block_i; ii++) { + size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(float)); + Q6_vgather_ARMVw(&vtcm_tmp[ii], rt, mu, vv); + uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(float); + hvx_vmemu(dst_ptr) = vtcm_tmp[ii]; + } + } + + dma_queue_pop(q); // src0 + + uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1]; + dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(float), current_block_i); + + dma_queue_pop(q); + } +} + +static void concat_2d_f16_transposed(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t src0_ne0 = src0->ne[0]; + const uint32_t src1_ne0 = src1->ne[0]; + const uint32_t ne1 = dst->ne[1]; + + const uint32_t start_i = ith * cctx->nrows_per_thread; + const uint32_t end_i = (start_i + cctx->nrows_per_thread < ne1) ? (start_i + cctx->nrows_per_thread) : ne1; + if (start_i >= end_i) return; + + dma_queue * q = octx->ctx->dma[ith]; + + uint8_t * spad0_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad1_base = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + const uint32_t block_i = 64; + const uint32_t spad1_stride = block_i * sizeof(__fp16); + + int16_t offsets[64] __attribute__((aligned(128))); + for(int k=0; k<64; k++) { + offsets[k] = k * spad1_stride; + } + HVX_Vector vv = *(HVX_Vector*)offsets; + const uint32_t src1_ne0_padded = hex_round_up(src1_ne0, 64); + const uint32_t spad0_row_bytes = hex_round_up((src0_ne0 + src1_ne0_padded) * sizeof(__fp16), VLEN); + uint32_t mu = src1_ne0_padded * spad1_stride; + + for (uint32_t i = start_i; i < end_i; i += block_i) { + uint32_t current_block_i = (end_i - i < block_i) ? (end_i - i) : block_i; + + uint32_t src1_width_bytes = current_block_i * sizeof(__fp16); + uint8_t * src1_ptr = (uint8_t *)src1->data + i * src1->nb[1]; + dma_queue_push(q, dma_make_ptr(spad1_base, src1_ptr), spad1_stride, src1->nb[0], src1_width_bytes, src1_ne0); + + uint32_t src0_row_bytes = src0_ne0 * sizeof(__fp16); + uint8_t * src0_ptr = (uint8_t *)src0->data + i * src0->nb[1]; + dma_queue_push(q, dma_make_ptr(spad0_base, src0_ptr), spad0_row_bytes, src0->nb[1], src0_row_bytes, current_block_i); + + dma_queue_pop(q); // src1 + + HVX_Vector * vtcm_tmp = (HVX_Vector *)(spad1_base + src1_ne0_padded * spad1_stride); + + for (uint32_t j = 0; j < src1_ne0_padded; j += 64) { + #pragma unroll(4) + for (uint32_t ii = 0; ii < current_block_i; ii++) { + size_t rt = (size_t)(spad1_base + j * spad1_stride + ii * sizeof(__fp16)); + Q6_vgather_ARMVh(&vtcm_tmp[ii], rt, mu, vv); + uint8_t * dst_ptr = spad0_base + ii * spad0_row_bytes + (src0_ne0 + j) * sizeof(__fp16); + hvx_vmemu(dst_ptr) = vtcm_tmp[ii]; + } + } + + dma_queue_pop(q); // src0 + + uint8_t * dst_ptr = (uint8_t *)dst->data + i * dst->nb[1]; + dma_queue_push(q, dma_make_ptr(dst_ptr, spad0_base), dst->nb[1], spad0_row_bytes, (src0_ne0 + src1_ne0) * sizeof(__fp16), current_block_i); + + dma_queue_pop(q); + } +} + +static void concat_generic(unsigned int nth, unsigned int ith, void * data) { + struct htp_concat_context * cctx = (struct htp_concat_context *) data; + struct htp_ops_context * octx = cctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + const int dim = cctx->dim; + const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2; + + const uint32_t ne[4] = {dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]}; + const uint32_t total_elements = ne[0] * ne[1] * ne[2] * ne[3]; + const uint32_t chunk_size = (total_elements + nth - 1) / nth; + + const uint32_t start_idx = MIN(ith * chunk_size, total_elements); + const uint32_t end_idx = MIN(start_idx + chunk_size, total_elements); + + // Naive scalar element-wise copy + for (uint32_t idx = start_idx; idx < end_idx; idx++) { + uint32_t idx_div_ne0 = fastdiv(idx, &cctx->div_ne0); + uint32_t i0 = idx - idx_div_ne0 * ne[0]; + + uint32_t idx_div_ne01 = fastdiv(idx_div_ne0, &cctx->div_ne1); + uint32_t i1 = idx_div_ne0 - idx_div_ne01 * ne[1]; + + uint32_t idx_div_ne012 = fastdiv(idx_div_ne01, &cctx->div_ne2); + uint32_t i2 = idx_div_ne01 - idx_div_ne012 * ne[2]; + uint32_t i3 = idx_div_ne012; + + uint8_t * dst_ptr = (uint8_t *)dst->data + i3 * dst->nb[3] + i2 * dst->nb[2] + i1 * dst->nb[1] + i0 * dst->nb[0]; + + uint32_t idx_dim = 0; + if (dim == 0) idx_dim = i0; + else if (dim == 1) idx_dim = i1; + else if (dim == 2) idx_dim = i2; + else if (dim == 3) idx_dim = i3; + + const struct htp_tensor * src = (idx_dim < src0->ne[dim]) ? src0 : src1; + + uint32_t s0 = i0; + uint32_t s1 = i1; + uint32_t s2 = i2; + uint32_t s3 = i3; + + if (dim == 0 && src == src1) s0 -= src0->ne[0]; + if (dim == 1 && src == src1) s1 -= src0->ne[1]; + if (dim == 2 && src == src1) s2 -= src0->ne[2]; + if (dim == 3 && src == src1) s3 -= src0->ne[3]; + + uint8_t * src_ptr = (uint8_t *)src->data + s3 * src->nb[3] + s2 * src->nb[2] + s1 * src->nb[1] + s0 * src->nb[0]; + + if (type_size == 4) { + *(float*)dst_ptr = *(float*)src_ptr; + } else { + *(__fp16*)dst_ptr = *(__fp16*)src_ptr; + } + } +} + +int op_concat(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; + + int dim = octx->op_params[0]; + + bool is_2d = dst->ne[2] == 1 && dst->ne[3] == 1; + + const uint32_t type_size = (dst->type == HTP_TYPE_F32 || dst->type == HTP_TYPE_I32) ? 4 : 2; + bool is_src1_transposed = (src1->nb[0] > src1->nb[1]); + bool is_src0_transposed = (src0->nb[0] > src0->nb[1]); + + uint32_t n_threads = octx->n_threads; + struct htp_concat_context cctx; + cctx.octx = octx; + cctx.dim = dim; + cctx.div_ne0 = init_fastdiv_values(dst->ne[0]); + cctx.div_ne1 = init_fastdiv_values(dst->ne[1]); + cctx.div_ne2 = init_fastdiv_values(dst->ne[2]); + + void (*worker_func)(unsigned int, unsigned int, void *) = concat_generic; + + if (dim == 0 && is_2d && is_src1_transposed && !is_src0_transposed) { + n_threads = MIN(dst->ne[1], n_threads); + if (n_threads < 1) { + n_threads = 1; + } + uint32_t block_i = (type_size == 4) ? 32 : 64; + + cctx.nrows_per_thread = hmx_ceil_div(dst->ne[1], n_threads); + + // Allocate VTCM + uint32_t spad1_stride = block_i * type_size; + + uint32_t src1_ne0_padded = hex_round_up(src1->ne[0], block_i); + uint32_t spad0_row_bytes = hex_round_up((src0->ne[0] + src1_ne0_padded) * type_size, VLEN); + + octx->src0_spad.size_per_thread = block_i * spad0_row_bytes; + octx->src1_spad.size_per_thread = src1_ne0_padded * spad1_stride + block_i * VLEN; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + + if (octx->src0_spad.size + octx->src1_spad.size > octx->ctx->vtcm_size) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + + if (type_size == 4) { + worker_func = concat_2d_f32_transposed; + } else { + worker_func = concat_2d_f16_transposed; + } + } + + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &cctx, n_threads); + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c new file mode 100644 index 00000000000..ae507effa51 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -0,0 +1,295 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <math.h> +#include <string.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "htp-ops.h" +#include "hvx-utils.h" + +struct htp_copy_context { + struct htp_ops_context * octx; + + uint32_t src0_type_size; + uint32_t src0_block_size; + + uint32_t dst_type_size; + uint32_t dst_block_size; + + uint32_t src0_blocks_per_row; + uint32_t dst_blocks_per_row; + + uint32_t src0_nrows_per_thread; +}; + +#define cpy_preamble \ + const struct htp_tensor *src0 = octx->src[0]; \ + const struct htp_tensor *dst = octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const uint32_t nr = ne01; + +#define DEFINE_CPY_SAMESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \ +static void cpy_thread_##NAME##_sameshape(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_copy_context * ct = (struct htp_copy_context *) data; \ + struct htp_ops_context * octx = ct->octx; \ + cpy_preamble; \ + const uint32_t dr = ct->src0_nrows_per_thread; \ + const uint32_t ir0 = dr * ith; \ + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \ + if (ir0 >= nr) return; \ + for (uint32_t i03 = 0; i03 < ne03; i03++) { \ + for (uint32_t i02 = 0; i02 < ne02; i02++) { \ + _Pragma("unroll(4)") \ + for (uint32_t i01 = ir0; i01 < ir1; i01++) { \ + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; \ + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; \ + hex_l2fetch(src0_ptr, ne00 * ELEM_SIZE, nb01, 2); \ + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \ + } \ + } \ + } \ +} + +DEFINE_CPY_SAMESHAPE(f32, float, 4) +DEFINE_CPY_SAMESHAPE(f16, __fp16, 2) + +#define DEFINE_CPY_RESHAPE(NAME, ELEM_TYPE, ELEM_SIZE) \ +static void cpy_thread_##NAME##_reshape(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_copy_context * ct = (struct htp_copy_context *) data; \ + struct htp_ops_context * octx = ct->octx; \ + cpy_preamble; \ + const uint32_t dr = ct->src0_nrows_per_thread; \ + const uint32_t ir0 = dr * ith; \ + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; \ + if (ir0 >= nr) return; \ + const bool src0_contig = (nb00 == ELEM_SIZE) && \ + (nb01 == ne00 * nb00) && \ + (nb02 == ne01 * nb01) && \ + (nb03 == ne02 * nb02); \ + const bool dst_contig = (nb0 == ELEM_SIZE) && \ + (nb1 == ne0 * nb0) && \ + (nb2 == ne1 * nb1) && \ + (nb3 == ne2 * nb2); \ + if (src0_contig && dst_contig) { \ + for (int64_t i03 = 0; i03 < ne03; i03++) { \ + for (int64_t i02 = 0; i02 < ne02; i02++) { \ + uint8_t * src_ptr = (uint8_t *) src0->data + i03*nb03 + i02*nb02 + ir0*nb01; \ + uint32_t flat = ((i03*ne02 + i02)*ne01 + ir0) * ne00; \ + uint8_t * dst_ptr = (uint8_t *) dst->data + flat * ELEM_SIZE; \ + hvx_copy_uu(dst_ptr, src_ptr, (ir1 - ir0) * ne00, ELEM_SIZE); \ + } \ + } \ + return; \ + } \ + const bool reshape_flat_fast = (ne03 == 1 && ne2 == 1 && ne3 == 1) && \ + (ne0 == ne00 * ne01) && (ne1 == ne02) && \ + (nb00 == ELEM_SIZE) && (nb0 == ELEM_SIZE); \ + if (reshape_flat_fast) { \ + for (uint32_t i02 = 0; i02 < ne02; i02++) { \ + for (uint32_t i01 = ir0; i01 < ir1; i01++) { \ + uint8_t * src0_ptr = (uint8_t *) src0->data + i01 * nb01 + i02 * nb02; \ + uint8_t * dst_ptr = (uint8_t *) dst->data + i01 * ne00 * ELEM_SIZE + i02 * nb1; \ + hvx_copy_uu(dst_ptr, src0_ptr, ne00, ELEM_SIZE); \ + } \ + } \ + return; \ + } \ + int64_t k10 = 0; \ + int64_t i11 = 0; \ + int64_t i12 = 0; \ + int64_t i13 = 0; \ + const int64_t nk00 = ct->src0_blocks_per_row; \ + const int64_t nk0 = ct->dst_blocks_per_row; \ + for (int64_t i03 = 0; i03 < ne03; i03++) { \ + for (int64_t i02 = 0; i02 < ne02; i02++) { \ + k10 += nk00 * ir0; \ + while (k10 >= nk0) { \ + k10 -= nk0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + for (int64_t i01 = ir0; i01 < ir1; i01++) { \ + for (int64_t k00 = 0; k00 < nk00; k00++) { \ + const char * src0_ptr = ((char *) src0->data + k00*nb00 + i01*nb01 + i02*nb02 + i03*nb03); \ + char * dst_ptr = ((char *) dst->data + k10*nb0 + i11*nb1 + i12*nb2 + i13*nb3); \ + memcpy(dst_ptr, src0_ptr, ELEM_SIZE); \ + if (++k10 == nk0) { \ + k10 = 0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + } \ + } \ + k10 += nk00 * (ne01 - ir1); \ + while (k10 >= nk0) { \ + k10 -= nk0; \ + if (++i11 == ne1) { \ + i11 = 0; \ + if (++i12 == ne2) { \ + i12 = 0; \ + if (++i13 == ne3) { \ + i13 = 0; \ + } \ + } \ + } \ + } \ + } \ + } \ +} + +DEFINE_CPY_RESHAPE(f32, float, 4) +DEFINE_CPY_RESHAPE(f16, __fp16, 2) + +static void cpy_thread_f16_f32_sameshape(unsigned int nth, unsigned int ith, void * data) { + struct htp_copy_context * ct = (struct htp_copy_context *) data; + struct htp_ops_context * octx = ct->octx; + cpy_preamble; + + // parallelize by src0 rows + const uint32_t dr = ct->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + if (ir0 >= nr) return; + + // copy by rows + for (uint32_t i03 = 0; i03 < ne03; i03++) { + for (uint32_t i02 = 0; i02 < ne02; i02++) { + #pragma unroll(2) + for (uint32_t i01 = ir0; i01 < ir1; i01++) { + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + hex_l2fetch(src0_ptr, ne00 * sizeof(float), nb01, 2); + hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00); + } + } + } +} + +static void cpy_thread_f32_f16_sameshape(unsigned int nth, unsigned int ith, void * data) { + struct htp_copy_context * ct = (struct htp_copy_context *) data; + struct htp_ops_context * octx = ct->octx; + cpy_preamble; + + // parallelize by src0 rows + const uint32_t dr = ct->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr) < nr ? (ir0 + dr) : nr; + if (ir0 >= nr) return; + + // copy by rows + for (uint32_t i03 = 0; i03 < ne03; i03++) { + for (uint32_t i02 = 0; i02 < ne02; i02++) { + #pragma unroll(2) + for (uint32_t i01 = ir0; i01 < ir1; i01++) { + uint8_t* dst_ptr = (uint8_t*) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + uint8_t* src0_ptr = (uint8_t*) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + hex_l2fetch(src0_ptr, ne00 * sizeof(__fp16), nb01, 2); + hvx_copy_f32_f16_uu(dst_ptr, src0_ptr, ne00); + } + } + } +} + +int op_cpy(struct htp_ops_context * octx) { + cpy_preamble; + + const uint32_t n_threads = MIN(nr, octx->n_threads); + + struct htp_copy_context ct; + ct.octx = octx; + + switch (src0->type) { + case HTP_TYPE_F32: ct.src0_type_size = 4; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break; + case HTP_TYPE_F16: ct.src0_type_size = 2; ct.src0_block_size = 1; ct.src0_blocks_per_row = ne00 / 1; break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + switch (dst->type) { + case HTP_TYPE_F32: ct.dst_type_size = 4; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break; + case HTP_TYPE_F16: ct.dst_type_size = 2; ct.dst_block_size = 1; ct.dst_blocks_per_row = ne0 / 1; break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const bool sametype = (src0->type == dst->type); + const bool transposed = (nb00 > nb01) || (nb0 > nb1); + const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3); + + ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; + + worker_callback_t copy_fun; + + if (sametype && sameshape) { + if (src0->type == HTP_TYPE_F32) { + copy_fun = cpy_thread_f32_sameshape; + } else { + copy_fun = cpy_thread_f16_sameshape; + } + } else if (sameshape) { + /**/ if (dst->type == HTP_TYPE_F16 && src0->type == HTP_TYPE_F32) + copy_fun = cpy_thread_f16_f32_sameshape; + else if (dst->type == HTP_TYPE_F32 && src0->type == HTP_TYPE_F16) + copy_fun = cpy_thread_f32_f16_sameshape; + else + return HTP_STATUS_NO_SUPPORT; + } else if (sametype) { + if (src0->type == HTP_TYPE_F32) { + copy_fun = cpy_thread_f32_reshape; + } else { + copy_fun = cpy_thread_f16_reshape; + } + } else { + return HTP_STATUS_NO_SUPPORT; + } + + worker_pool_run_func(octx->ctx->worker_pool, copy_fun, &ct, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/cumsum-ops.c b/ggml/src/ggml-hexagon/htp/cumsum-ops.c new file mode 100644 index 00000000000..2ced1971236 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/cumsum-ops.c @@ -0,0 +1,270 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hvx-utils.h" +#include "hex-dma.h" + +#define htp_cumsum_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_cumsum_context { + struct htp_ops_context * octx; + size_t src_row_size; + size_t dst_row_size; + size_t src_row_size_aligned; + size_t dst_row_size_aligned; + uint32_t rows_per_thread; + uint32_t total_rows; +}; + +#define htp_cumsum_preamble \ + struct htp_cumsum_context * cctx = (struct htp_cumsum_context *) data; \ + struct htp_ops_context * octx = cctx->octx; \ + htp_cumsum_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; + +// --------------------------------------------------------------------------- +// HVX prefix scan helpers +// --------------------------------------------------------------------------- + +#if __HVX_ARCH__ > 75 +static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vadd_VsfVsf(a, b); +} +#else +static inline HVX_Vector hvx_cumsum_vadd(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)); +} +#endif // __HVX_ARCH__ > 75 + +static inline HVX_Vector hvx_prefix_scan_f32(HVX_Vector v, HVX_Vector carry_in) { + const HVX_Vector zero = Q6_V_vsplat_R(0); + + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 4)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 8)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 16)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 32)); + v = hvx_cumsum_vadd(v, Q6_V_vlalign_VVR(v, zero, 64)); + v = hvx_cumsum_vadd(v, carry_in); + + return v; +} + +static inline HVX_Vector hvx_splat_last_f32(HVX_Vector v) { + return hvx_vec_repl4(Q6_V_vror_VR(v, 124)); +} + +static inline void hvx_cumsum_row_f32(const float * restrict src, float * restrict dst, uint32_t n) { + const uint32_t nvec = n / VLEN_FP32; + const uint32_t nloe = n % VLEN_FP32; + + HVX_Vector carry = Q6_V_vsplat_R(0); + + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v = *((const HVX_UVector *) (src + i * VLEN_FP32)); + v = hvx_prefix_scan_f32(v, carry); + hvx_vec_store_u(dst + i * VLEN_FP32, VLEN, v); + carry = hvx_splat_last_f32(v); + } + + if (nloe) { + float acc = hvx_vec_get_f32(carry); + const float * src_tail = src + nvec * VLEN_FP32; + float * dst_tail = dst + nvec * VLEN_FP32; + for (uint32_t i = 0; i < nloe; i++) { + acc += src_tail[i]; + dst_tail[i] = acc; + } + } +} + +// --------------------------------------------------------------------------- +// Per thread worker: Double-buffered DMA +// --------------------------------------------------------------------------- + +static void cumsum_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_cumsum_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ir0 = cctx->rows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows); + + if (ir0 >= ir1) { + return; + } + + const size_t src_row_size = cctx->src_row_size; + const size_t dst_row_size = cctx->dst_row_size; + const size_t src_row_size_aligned = cctx->src_row_size_aligned; + const size_t dst_row_size_aligned = cctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + uint8_t * src_spad = octx->src0_spad.data + (ith * src_row_size_aligned * 2); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned * 2); + + for (uint32_t ir = ir0, spad_idx = 0; ir < ir1 && spad_idx < 2; ir++, spad_idx++) { + // Dummy dst writeback to establish queue ordering + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_data, dst_spad + (spad_idx * dst_row_size_aligned)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src_spad + (spad_idx * src_row_size_aligned), + src_data + (ir * src_row_size)), + src_row_size_aligned, src_row_size, 1); + } + + for (uint32_t ir = ir0; ir < ir1; ir++) { + float * dst_spad_row = (float *) dma_queue_pop(dma_queue).src; + float * src_spad_row = (float *) dma_queue_pop(dma_queue).dst; + + hvx_cumsum_row_f32(src_spad_row, dst_spad_row, ne00); + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_data + (ir * dst_row_size), (uint8_t *) dst_spad_row), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < ir1) { + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr((uint8_t *) src_spad_row, src_data + (next_row * src_row_size)), + src_row_size_aligned, src_row_size, 1); + } + } + + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "cumsum-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void cumsum_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_cumsum_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ir0 = cctx->rows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + cctx->rows_per_thread, cctx->total_rows); + + for (uint32_t ir = ir0; ir < ir1; ir++) { + const float * restrict src_row = (const float *) (src_data + ir * cctx->src_row_size); + float * restrict dst_row = (float *) (dst_data + ir * cctx->dst_row_size); + hvx_cumsum_row_f32(src_row, dst_row, ne00); + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "cumsum-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_cumsum_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_rows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_rows); + + const size_t src_row_size = src0->nb[1]; + const size_t dst_row_size = dst->nb[1]; + const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 2 ping-pong buffers per thread for src and dst + const size_t spad_per_thread = 2 * (src_row_size_aligned + dst_row_size_aligned); + + octx->src0_spad.size_per_thread = src_row_size_aligned * 2; + octx->dst_spad.size_per_thread = dst_row_size_aligned * 2; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; + + struct htp_cumsum_context cctx = { + .octx = octx, + .src_row_size = src_row_size, + .dst_row_size = dst_row_size, + .src_row_size_aligned = src_row_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + .rows_per_thread = (total_rows + n_threads - 1) / n_threads, + .total_rows = total_rows, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32, &cctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, cumsum_thread_f32_dma, &cctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_cumsum(struct htp_ops_context * octx) { + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_cumsum_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/diag-ops.c b/ggml/src/ggml-hexagon/htp/diag-ops.c new file mode 100644 index 00000000000..9b3194d9084 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/diag-ops.c @@ -0,0 +1,216 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hex-utils.h" +#include "hvx-copy.h" +#include "hex-dma.h" + +#define htp_diag_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne02 = src0->ne[2]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_diag_context { + struct htp_ops_context * octx; + size_t src_batch_size; + size_t dst_row_size; + size_t src_batch_size_aligned; + size_t dst_row_size_aligned; + uint32_t batches_per_thread; + uint32_t total_batches; +}; + +#define htp_diag_preamble \ + struct htp_diag_context * dctx = (struct htp_diag_context *) data; \ + struct htp_ops_context * octx = dctx->octx; \ + htp_diag_tensors_preamble; + +static inline void hvx_diag_row_f32(const float * restrict src, float * restrict dst, + uint32_t row_idx, uint32_t n) { + hvx_splat_f32_a((uint8_t *) dst, 0.0f, n); + dst[row_idx] = src[row_idx]; +} + +// --------------------------------------------------------------------------- +// Per thread worker: DMA src fetch, compute in VTCM, DMA dst writeback +// --------------------------------------------------------------------------- + +static void diag_thread_f32_dma(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + dma_queue * dma_queue = octx->ctx->dma[ith]; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + if (ib0 >= ib1) { + return; + } + + const size_t src_batch_size = dctx->src_batch_size; + const size_t dst_row_size = dctx->dst_row_size; + const size_t src_batch_size_aligned = dctx->src_batch_size_aligned; + const size_t dst_row_size_aligned = dctx->dst_row_size_aligned; + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + // 1 src buffer + 1 dst row buffer per thread in VTCM + uint8_t * src_spad = octx->src0_spad.data + (ith * src_batch_size_aligned); + uint8_t * dst_spad = octx->dst_spad.data + (ith * dst_row_size_aligned); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const uint8_t * src_batch = src_data + i3 * nb03 + i2 * nb02; + + // Fetch source vector into VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src_spad, src_batch), + src_batch_size_aligned, src_batch_size, 1); + dma_queue_flush(dma_queue); + + const float * src_spad_f32 = (const float *) src_spad; + float * dst_spad_f32 = (float *) dst_spad; + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + // Compute row in VTCM + hvx_diag_row_f32(src_spad_f32, dst_spad_f32, i1, ne0); + + // Write completed row back to DDR + uint8_t * dst_row = dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(dst_row, dst_spad), + dst_row_size, dst_row_size_aligned, 1); + dma_queue_flush(dma_queue); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// Per thread worker: Direct HVX (no DMA) +// --------------------------------------------------------------------------- + +static void diag_thread_f32(unsigned int nth, unsigned int ith, void * data) { + htp_diag_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint8_t * src_data = (const uint8_t *) src0->data; + uint8_t * dst_data = (uint8_t *) dst->data; + + const uint32_t ib0 = dctx->batches_per_thread * ith; + const uint32_t ib1 = MIN(ib0 + dctx->batches_per_thread, dctx->total_batches); + + for (uint32_t ib = ib0; ib < ib1; ib++) { + const uint32_t i3 = ib / ne02; + const uint32_t i2 = ib % ne02; + + const float * restrict src_batch = (const float *)(src_data + i3 * nb03 + i2 * nb02); + + for (uint32_t i1 = 0; i1 < ne1; i1++) { + float * restrict dst_row = (float *)(dst_data + i3 * nb3 + i2 * nb2 + i1 * nb1); + hvx_diag_row_f32(src_batch, dst_row, i1, ne0); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "diag-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ib0, ib1, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_diag_f32(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_batches); + + const size_t src_batch_size = src0->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); + const size_t src_batch_size_aligned = hex_round_up(src_batch_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // 1 src buffer + 1 dst row buffer per thread + const size_t spad_per_thread = src_batch_size_aligned + dst_row_size_aligned; + + octx->src0_spad.size_per_thread = src_batch_size_aligned; + octx->dst_spad.size_per_thread = dst_row_size_aligned; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; + + struct htp_diag_context dctx = { + .octx = octx, + .src_batch_size = src_batch_size, + .dst_row_size = dst_row_size, + .src_batch_size_aligned = src_batch_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + .batches_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_batches = total_batches, + }; + + if (octx->ctx->vtcm_size < spad_per_thread * n_threads) { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32, &dctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, diag_thread_f32_dma, &dctx, n_threads); + } + + return HTP_STATUS_OK; +} + +int op_diag(struct htp_ops_context * octx) { + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_diag_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/fill-ops.c b/ggml/src/ggml-hexagon/htp/fill-ops.c new file mode 100644 index 00000000000..3ccfbe74ee4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/fill-ops.c @@ -0,0 +1,123 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <string.h> + +#include "hvx-copy.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +// ggml op_params layout for FILL: +// op_params[0] (as float) - the scalar fill value + +#define fill_preamble \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const uint32_t nr = ne1 * ne2 * ne3; + +struct htp_fill_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint32_t total_rows; // ne1 * ne2 * ne3 + bool opt_path; + HVX_Vector splat_vec; + uint32_t elem_size; +}; + +static void fill_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_fill_context * fctx = (const struct htp_fill_context *) data; + struct htp_ops_context * octx = fctx->octx; + fill_preamble; + + // Parallelise over the flat row index spanning ne1*ne2*ne3 + const uint32_t ir0 = fctx->nrows_per_thread * ith; + const uint32_t ir1 = MIN(ir0 + fctx->nrows_per_thread, fctx->total_rows); + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + if (fctx->opt_path) { + // Opt path: tensor is fully contiguous, treat as flat array + const uint32_t elem_start = ir0 * ne0; + const uint32_t elem_end = ir1 * ne0; + uint8_t * dst_ptr = (uint8_t *) dst->data + elem_start * fctx->elem_size; + hvx_splat_u(dst_ptr, fctx->splat_vec, elem_end - elem_start, fctx->elem_size); + } else { + // Non-contiguous path: must respect strides + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t i1 = ir % ne1; + const uint32_t i2 = (ir / ne1) % ne2; + const uint32_t i3 = ir / (ne1 * ne2); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1*nb1 + i2*nb2 + i3*nb3; + hvx_splat_u(dst_ptr, fctx->splat_vec, ne0, fctx->elem_size); + } + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + FARF(HIGH, "fill %u/%u: rows %u:%u usec %u\n", + ith, nth, ir0, ir1, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_fill(struct htp_ops_context * octx) { + fill_preamble; + + if (dst->type != HTP_TYPE_F32 && dst->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // nr = ne1*ne2*ne3 (flat row count across all outer dims); parallelise over it. + const uint32_t n_threads = MIN(nr, octx->n_threads); + + // Optimize if fully contiguous: skip stride arithmetic, treat as flat array + const bool opt_path = (nb2 == nb1 * ne1) && (nb3 == nb2 * ne2); + + FARF(HIGH, "fill: (%ux%ux%ux%u) type=%u opt=%d\n", + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], dst->type, (int) opt_path); + + float val_f32 = 0.f; + memcpy(&val_f32, &octx->op_params[0], sizeof(float)); + + struct htp_fill_context fctx = { + .octx = octx, + .nrows_per_thread = (nr + n_threads - 1) / n_threads, + .total_rows = nr, + .opt_path = opt_path, + }; + + switch (dst->type) { + case HTP_TYPE_F32: + fctx.splat_vec = hvx_vec_splat_f32(val_f32); + fctx.elem_size = sizeof(float); + break; + case HTP_TYPE_F16: + fctx.splat_vec = hvx_vec_splat_f16((_Float16) val_f32); + fctx.elem_size = sizeof(_Float16); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + worker_pool_run_func(octx->ctx->worker_pool, fill_thread, &fctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 04a7b843ce5..e996214691a 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -2,166 +2,299 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif +#include <assert.h> #include <HAP_farf.h> -#include <HAP_mem.h> #include <HAP_perf.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> #include <math.h> #include <string.h> +#include "hex-dma.h" +#include "hvx-utils.h" +#include "hvx-dump.h" +#include "hvx-flash-attn.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" +#include "htp-ops.h" +#include "hmx-ops.h" + +// Must be multiple of 32 +#define FLASH_ATTN_BLOCK_SIZE (32 * 2) + +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 +// This is a bit of a hack because the compiler is strugling to properly inline +// the default hvx_vec_f32_to_f16 with output into the local array. +static __attribute__((noinline)) void hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) +{ + *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1); +} + +// Dot product of two F16 vectors, accumulating to float +static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; #pragma unroll(4) for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements - HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]); + } - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements - HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + HVX_Vector rsum = HVX_OP_ADD_F32(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p)); + rsum = HVX_OP_MUL_F32(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); + hvx_vec_store_u(r, 4, rsum); +} - // Load x (fp16) - HVX_Vector x_hf = vx[i]; +static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y, + const uint8_t * restrict x, + const size_t stride_x, + const size_t nvec, + const size_t nloe) { + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16 + const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16 + const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x_hf = Q6_V_vand_QV(bmask, x_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); + uint32_t i = 0; - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = vy[i]; + HVX_Vector x0_hf = vx0[i]; + HVX_Vector x1_hf = vx1[i]; + HVX_Vector x2_hf = vx2[i]; + HVX_Vector x3_hf = vx3[i]; + + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); + rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf); + rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); + } - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + if (nloe) { + // Load x (fp16) and zero-out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]); + HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]); + + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); + rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf); + rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); } - rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + HVX_Vector rsum0 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p)); + HVX_Vector rsum1 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p)); + HVX_Vector rsum2 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p)); + HVX_Vector rsum3 = HVX_OP_ADD_F32(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p)); - hvx_vec_store_u(r, 4, rsum); + HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } }; + return hvx_vec_reduce_sum_f32x4(rsum0123); } -// Dot product of two F16 vectors, accumulating to float -static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { - const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 +static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y, + const uint8_t * restrict x, + const size_t stride_x, + const size_t n, + float s) { + + const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + const size_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector sums = Q6_V_vzero(); + const size_t stride_x_4 = stride_x * 4; + for (uint32_t j = 0; j < VLEN_FP32; j += 4) { + HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe); + HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32); + sums = Q6_V_vmux_QVV(pred, sums, sums_x4); + x += stride_x_4; + } + + return HVX_OP_MUL_F32(hvx_vec_splat_f32(s), sums); +} + +// MAD: y (F32) += x (F16) * s (F16) +static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, int n) { + const HVX_Vector * restrict vx0 = (const HVX_Vector *) x; + + HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y; + HVX_Vector * restrict vy = (HVX_Vector *) y; uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_Vector S0 = hvx_vec_splat_f16(*s); uint32_t i = 0; - #pragma unroll(4) - for (i = 0; i < nvec; i++) { - HVX_Vector y_hf = vy[i]; - HVX_Vector x_hf = vx[i]; - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + #pragma unroll(2) + for (i = 0; i < nvec; ++i) { + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0); } if (nloe) { - HVX_Vector y_hf = vy[i]; + HVX_VectorPair xy_p = vy_p[i]; + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0); - // Load x (fp16) and zero-out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); + HVX_Vector xy = Q6_V_lo_W(xy_p); + i = 2 * i; // index for vy - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + if (nloe >= VLEN_FP32) { + vy[i] = xy; + nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p); + } - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + if (nloe) { + hvx_vec_store_a(&vy[i], nloe * 4, xy); + } } - - rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); - hvx_vec_store_u(r, 4, rsum); } -// MAD: y (F32) += x (F16) * v (float) -static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { - const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; - HVX_Vector * restrict ptr_y = (HVX_Vector *) y; +// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16) +static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1, + const __fp16 * restrict s0, const __fp16 * restrict s1, int n) { + const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0; + const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1; - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements + HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y; + HVX_Vector * restrict vy = (HVX_Vector *) y; - HVX_Vector S = hvx_vec_splat_fp16(s); + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector S0 = hvx_vec_splat_f16(*s0); + HVX_Vector S1 = hvx_vec_splat_f16(*s1); uint32_t i = 0; - #pragma unroll(4) + + #pragma unroll(2) for (i = 0; i < nvec; ++i) { - // Multiply x * s -> pair of F32 vectors - HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); - ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2])); - ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1])); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1); } if (nloe) { - HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + HVX_VectorPair xy_p = vy_p[i]; + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0); + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1); - HVX_Vector xs = Q6_V_lo_W(xs_p); - i = 2 * i; // index for ptr_y + HVX_Vector xy = Q6_V_lo_W(xy_p); + i = 2 * i; // index for vy - if (nloe >= 32) { - ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p); + if (nloe >= VLEN_FP32) { + vy[i] = xy; + nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p); } if (nloe) { - HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - hvx_vec_store_u(&ptr_y[i], nloe * 4, xy); + hvx_vec_store_a(&vy[i], nloe * 4, xy); } } } -#define FLASH_ATTN_BLOCK_SIZE 128 +struct htp_fa_context { + const struct htp_ops_context * octx; + + struct fastdiv_values src0_div21; + struct fastdiv_values src0_div1; + + struct fastdiv_values broadcast_rk2; + struct fastdiv_values broadcast_rk3; + struct fastdiv_values broadcast_rv2; + struct fastdiv_values broadcast_rv3; + + struct fastdiv_values src3_div2; + struct fastdiv_values src3_div3; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t n_head_log2; + float m0; + float m1; + float slopes[512]; + + uint32_t n_blocks; + + size_t size_q_row_padded; + size_t size_k_row_padded; + size_t size_v_row_padded; + + size_t size_k_block; + size_t size_v_block; + size_t size_m_block; + + uint32_t qrows; + uint32_t qrows_per_thread; + + bool is_q_fp32; + + uint64_t t_start; +}; + +static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) { + assert((size_t) dst % 128 == 0); + assert((size_t) src % 128 == 0); -static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) { - const struct htp_tensor * q = &octx->src0; - const struct htp_tensor * k = &octx->src1; - const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; - const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL; - struct htp_tensor * dst = &octx->dst; + const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src; + HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst; + + const uint32_t nvec = n / VLEN_FP32; + const uint32_t nloe = n % VLEN_FP32; + + uint32_t i = 0; + #pragma unroll(4) + for (; i < nvec; ++i) { + vdst[i] = HVX_OP_MUL_F32(vsrc[i], vs); + } + if (nloe) { + hvx_vec_store_a(&vdst[i], nloe * sizeof(float), HVX_OP_MUL_F32(vsrc[i], vs)); + } +} + +static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_fa_context * factx = (struct htp_fa_context *) data; + const struct htp_ops_context * octx = factx->octx; + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = octx->src[3]; + const struct htp_tensor * sinks = octx->src[4]; + const struct htp_tensor * dst = octx->dst; const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; @@ -198,22 +331,9 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const uint32_t nb2 = dst->nb[2]; const uint32_t nb3 = dst->nb[3]; - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; - - memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); - - if (logit_softcap != 0) { - scale /= logit_softcap; - } - // total rows in q - const uint32_t nr = neq1*neq2*neq3; - - const uint32_t dr = (nr + nth - 1) / nth; + const uint32_t nr = factx->qrows; + const uint32_t dr = factx->qrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = MIN(ir0 + dr, nr); @@ -225,18 +345,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const uint32_t DV = nev0; const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2); - const size_t size_q_row_padded = htp_round_up(size_q_row, 128); - const size_t size_k_row = DK * sizeof(__fp16); const size_t size_v_row = DV * sizeof(__fp16); - const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask - - const size_t size_k_row_padded = htp_round_up(size_k_row, 128); - const size_t size_v_row_padded = htp_round_up(size_v_row, 128); - - const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; - const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; - const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith; @@ -245,72 +355,81 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith; uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith; - const uint32_t n_head = neq2; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap); + + dma_cache m_cache; + dma_cache_init(&m_cache, spad_m, factx->size_m_block, DMA_CACHE_MAX_SIZE); for (uint32_t ir = ir0; ir < ir1; ++ir) { - const uint32_t iq3 = fastdiv(ir, &octx->src0_div21); - const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); + const uint32_t iq3 = fastdiv(ir, &factx->src0_div21); + const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1); const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1); - const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3); - const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2); + const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3); + const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2); - const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3); - const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2); + const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3); + const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2); // Fetch Q row const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); - dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1); - - const uint32_t h = iq2; // head index - const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f; + dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1); - float S = 0.0f; // sum - float M = -INFINITY; // maximum KQ value - - // Clear accumulator - float * VKQ32 = (float *) spad_a; - memset(VKQ32, 0, DV * sizeof(float)); + // FARF(HIGH, "fa %u: prefetch Q: ir %u iq1 %u iq2 %u iq3 %u q_row_ptr %p size %u : usec %u", ith, ir, iq1, iq2, iq3, q_row_ptr, size_q_row, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); const __fp16 * mp_base = NULL; if (mask) { - const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2); - const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3); + const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2); + const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3); mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]); } - const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; - // Prefetch first two blocks - for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) { + for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); // K const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); - uint8_t * k_dst = spad_k + (ib % 2) * size_k_block; - dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size); + uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block; + dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size); // V const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); - uint8_t * v_dst = spad_v + (ib % 2) * size_v_block; - dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size); + uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block; + dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size); // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); - uint8_t * m_dst = spad_m + (ib % 2) * size_m_block; // Mask is 1D contiguous for this row - dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); + dma_cache_push(dma, &m_cache, m_src, current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } + + // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", + // ith, ir, ib, iq1, iq2, iq3, + // size_k_row, size_v_row, current_block_size, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); } - const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + const uint32_t h = iq2; // head index + const float slope = factx->slopes[h]; - for (uint32_t ib = 0; ib < n_blocks; ++ib) { + HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); + HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); + + // Clear accumulator + hvx_splat_f32_a(spad_a, 0, DV); + float * VKQ32 = (float *) (spad_a + 0); + + uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + if (factx->is_q_fp32) { + hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16 + } + + const HVX_Vector slope_vec = hvx_vec_splat_f16(slope); + for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); @@ -319,156 +438,162 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in uint8_t * v_base = dma_queue_pop(dma).dst; // V __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M + // FARF(HIGH, "fa %u: process: ir %u ib %u : iq1 %u iq2 %u iq3 %u q_ptr_vtcm %p : usec %u", + // ith, ir, ib, iq1, iq2, iq3, q_ptr_vtcm, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); + // Inner loop processing the block from VTCM uint32_t ic = 0; - // Process in blocks of 32 (VLEN_FP32) - for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) { + // Process in sub-blocks of 32 (VLEN_FP32) + HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32]; + HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); + for (uint32_t iv = 0; ic < current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores - float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; - for (int j = 0; j < VLEN_FP32; ++j) { - const uint32_t cur_ic = ic + j; - const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; - if (q->type == HTP_TYPE_F32) { - hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); - } else { - hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); - } - } - - HVX_Vector scores = *(HVX_Vector *) scores_arr; + HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale); // 2. Softcap - if (logit_softcap != 0.0f) { - scores = hvx_vec_tanh_fp32(scores); - scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap)); - scores = Q6_Vsf_equals_Vqf32(scores); + if (factx->logit_softcap != 0.0f) { + scores = hvx_vec_tanh_f32(scores); + scores = HVX_OP_MUL_F32(scores, logit_cap); } // 3. Mask if (mask) { const __fp16 * mp = m_base + ic; - HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp; - - HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00); - HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16); - - HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair)); - - HVX_Vector slope_vec = hvx_vec_splat_fp32(slope); - HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec); - scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); - scores = Q6_Vsf_equals_Vqf32(scores); + HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; + + // Multiplying -INFINITY (0xFC00) by a slope in VhfVhf instructions can incorrectly produce NaN on v79. + // Clamp -INFINITY to the max negative fp16 finite value (-65504.0f). + HVX_Vector vinf = Q6_Vh_vsplat_R(0xFC00); + HVX_Vector vmin = Q6_Vh_vsplat_R(0xFBFF); + HVX_VectorPred is_inf = Q6_Q_vcmp_eq_VhVh(m_vals_f16, vinf); + m_vals_f16 = Q6_V_vmux_QVV(is_inf, vmin, m_vals_f16); + + #if __HVX_ARCH__ >= 79 + HVX_VectorPair m_vals_f32_pair = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vsf_vadd_VsfVsf(add_val, scores); + #else + HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores)); + #endif } - // 4. Online Softmax Update - HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores); - float m_block = hvx_vec_get_fp32(v_max); - - float M_old = M; - float M_new = (m_block > M) ? m_block : M; - M = M_new; - - float ms = expf(M_old - M_new); - - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); - S = S * ms; - - HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new); - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); - HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted)); - - HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P); - float p_sum = hvx_vec_get_fp32(p_sum_vec); - S += p_sum; - - // 5. Accumulate V - float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector*)p_arr = P; - - for (int j = 0; j < VLEN_FP32; ++j) { - const uint32_t cur_ic = ic + j; - const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + // Mask out invalid lanes for leftover handling + uint32_t valid_lanes = current_block_size - ic; + if (valid_lanes < VLEN_FP32) { + HVX_VectorPred valid_pred = Q6_Q_vsetq_R(valid_lanes * 4); // 4 bytes per fp32 lane + scores = Q6_V_vmux_QVV(valid_pred, scores, hvx_vec_splat_f32(-INFINITY)); } - } - // Leftover - for (; ic < current_block_size; ++ic) { - float s_val; - const uint8_t * k_ptr = k_base + ic * size_k_row_padded; - - if (q->type == HTP_TYPE_F32) { - hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); - } else { - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); - } - - if (logit_softcap != 0.0f) { - s_val = logit_softcap * tanhf(s_val); - } - - if (mask) { - const float m_val = m_base[ic]; - s_val += slope * m_val; - } - - const float Mold = M; - float ms = 1.0f; - float vs = 1.0f; + sb_scores[iv] = scores; + v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max + } - if (s_val > M) { - M = s_val; - ms = expf(Mold - M); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); - } else { - vs = expf(s_val - M); + { + // 4. Online Softmax Update + HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec); + HVX_Vector diff_vec = HVX_OP_SUB_F32(M_vec, M_new_vec); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + M_vec = M_new_vec; + + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); + + HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); + for (uint32_t ic2 = 0, iv = 0; ic2 < current_block_size; ic2 += VLEN_FP32, ++iv) { + HVX_Vector scores = sb_scores[iv]; + HVX_Vector scores_shifted = HVX_OP_SUB_F32(scores, M_vec); + HVX_Vector P = hvx_vec_exp_f32(scores_shifted); + + p_sum_vec = HVX_OP_ADD_F32(p_sum_vec, P); + + // 5. Accumulate V + __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16]; + hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0)); + + float __attribute__((aligned(128))) P_arr[VLEN_FP32]; + hvx_vec_store_a(P_arr, 128, P); + + for (uint32_t j = 0; j < VLEN_FP32; j += 2) { + const uint32_t cur_ic = ic2 + j; + if (cur_ic >= current_block_size) { + break; + } + + if (cur_ic + 1 == current_block_size) { + // Odd leftover, process single row + if (P_arr[j] != 0.0f) { + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, (p_arr + j), DV); + } + break; + } + + // Avoid NaN * 0.0 = NaN for uninitialized V cache rows. + // Check the f32 values to safely avoid strict aliasing violations. + if (P_arr[j] == 0.0f && P_arr[j + 1] == 0.0f) { + continue; + } + + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV); + } } - const uint8_t * v_ptr = v_base + ic * size_v_row_padded; - - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); - - S = S * ms + vs; + p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); + S_vec = HVX_OP_ADD_F32(HVX_OP_MUL_F32(S_vec, ms_vec), p_sum_vec); } // Issue DMA for next+1 block (if exists) - if (ib + 2 < n_blocks) { + if (ib + 2 < factx->n_blocks) { const uint32_t next_ib = ib + 2; const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start); // K const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); - dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size); + dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size); // V const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); - dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size); + dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size); // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); - dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); + dma_cache_push(dma, &m_cache, m_src, next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); } + + // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", + // ith, ir, next_ib, iq1, iq2, iq3, + // size_k_row, size_v_row, next_block_size, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); } } // sinks + float M = hvx_vec_get_f32(M_vec); + float S = hvx_vec_get_f32(S_vec); + if (sinks) { const float s = ((float *)((char *) sinks->data))[h]; - float ms = 1.0f; float vs = 1.0f; if (s > M) { - ms = expf(M - s); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + HVX_Vector diff_vec = hvx_vec_splat_f32(M - s); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); + + float ms = hvx_vec_get_f32(ms_vec); + S = S * ms + vs; } else { - vs = expf(s - M); + HVX_Vector diff_vec = hvx_vec_splat_f32(s - M); + vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); + S += vs; } - - S = S * ms + vs; } const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; @@ -480,64 +605,114 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const int i2 = iq2; const int i3 = iq3; - // dst is permuted - uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1; + // dst is permuted: [DV, n_heads, n_tokens, n_seq] + // head stride is nb[1], token stride is nb[2], batch stride is nb[3] + uint8_t * dst_ptr = (uint8_t *) dst->data + i2 * dst->nb[1] + i1 * dst->nb[2] + i3 * dst->nb[3]; if (dst->type == HTP_TYPE_F32) { - hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + hvx_copy_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV); } else if (dst->type == HTP_TYPE_F16) { - hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + hvx_copy_f16_f32_ua(dst_ptr, (uint8_t *) VKQ32, DV); } } } -static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - flash_attn_ext_f16_thread(octx, i, n); -} - int op_flash_attn_ext(struct htp_ops_context * octx) { - const struct htp_tensor * q = &octx->src0; - const struct htp_tensor * k = &octx->src1; - const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = octx->src[3]; + const struct htp_tensor * dst = octx->dst; // Check support - if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || - k->type != HTP_TYPE_F16 || - v->type != HTP_TYPE_F16) { + if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) { return HTP_STATUS_NO_SUPPORT; } - octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); - octx->src0_div1 = init_fastdiv_values(q->ne[1]); +#ifdef HTP_HAS_HMX + // HMX path: head_dim multiple of 64, F16 KV, and no sinks + if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) { + int ret = hmx_flash_attn_ext(octx); + if (ret == HTP_STATUS_OK) { + return ret; + } + // VTCM too small or other failure -> fall through to HVX path + } +#endif + + struct htp_fa_context factx; + factx.octx = octx; + + factx.t_start = HAP_perf_get_qtimer_count(); - octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); - octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); - octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); - octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); + factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); + factx.src0_div1 = init_fastdiv_values(q->ne[1]); + + factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); + factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); + factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); + factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); if (mask) { - octx->src3_div2 = init_fastdiv_values(mask->ne[2]); - octx->src3_div3 = init_fastdiv_values(mask->ne[3]); + factx.src3_div2 = init_fastdiv_values(mask->ne[2]); + factx.src3_div3 = init_fastdiv_values(mask->ne[3]); + } + + factx.is_q_fp32 = (q->type == HTP_TYPE_F32); + factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128); + factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128); + factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128); + + size_t size_q_block = factx.size_q_row_padded * 1; // single row for now + factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; } - size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); - size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128); - size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128); + factx.scale = scale; + factx.max_bias = max_bias; + factx.logit_softcap = logit_softcap; + + uint32_t n_head = q->ne[2]; + factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); + factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + + if (n_head > 512) { + return HTP_STATUS_NO_SUPPORT; + } + for (uint32_t h = 0; h < n_head; ++h) { + factx.slopes[h] = (max_bias > 0.0f) ? alibi_slope(h, factx.n_head_log2, factx.m0, factx.m1) : 1.0f; + } + + // total rows in q + const uint32_t neq0 = q->ne[0]; + const uint32_t neq1 = q->ne[1]; + const uint32_t neq2 = q->ne[2]; + const uint32_t neq3 = q->ne[3]; - size_t size_q_block = size_q_row_padded * 1; // single row for now - size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; - size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; - size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + factx.qrows = neq1*neq2*neq3; + factx.qrows_per_thread = (factx.qrows + octx->n_threads - 1) / octx->n_threads; - size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 + size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 octx->src0_spad.size_per_thread = size_q_block * 1; - octx->src1_spad.size_per_thread = size_k_block * 2; - octx->src2_spad.size_per_thread = size_v_block * 2; - octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0; + octx->src1_spad.size_per_thread = factx.size_k_block * 2; + octx->src2_spad.size_per_thread = factx.size_v_block * 2; + octx->src3_spad.size_per_thread = mask ? factx.size_m_block * DMA_CACHE_MAX_SIZE : 0; octx->dst_spad.size_per_thread = size_vkq_acc; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; @@ -552,14 +727,14 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; - octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->src2_spad.src = NULL; + octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; octx->src3_spad.src = NULL; + octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; octx->dst_spad.src = NULL; if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads); + worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads); } return HTP_STATUS_OK; diff --git a/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c new file mode 100644 index 00000000000..35518e6111c --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c @@ -0,0 +1,1148 @@ +#include <math.h> +#include <stdint.h> +#include <string.h> + +#include "hvx-utils.h" +#include "hex-fastdiv.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +#define HTP_GDN_MAX_SV 128 + + +struct htp_gdn_context { + struct htp_ops_context * octx; + uint32_t rows_per_thread; + size_t state_bytes; + uint8_t * vtcm_base; + size_t vtcm_per_thread; +}; + +static inline HVX_Vector gdn_mul_dot_f32(float * restrict dst, const float * restrict mul, const float * restrict dot, uint32_t n) { + HVX_Vector acc = Q6_V_vzero(); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t nloe = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vmemu(dst + i * epv) = out; + acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); + } + + if (nloe) { + const uint32_t off = nvec * epv; + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vm); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); + HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); + acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); + } + + return hvx_vec_reduce_sum_f32(acc); +} + +static inline HVX_Vector gdn_mul_scalar_dot_f32(float * restrict dst, float mul, const float * restrict dot, uint32_t n) { + HVX_Vector acc = Q6_V_vzero(); + const HVX_Vector vmul = hvx_vec_splat_f32(mul); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t nloe = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vmemu(dst + i * epv) = out; + acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); + } + + if (nloe) { + const uint32_t off = nvec * epv; + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_Vector out = hvx_vec_mul_f32_f32(vd, vmul); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); + HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); + acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); + } + + return hvx_vec_reduce_sum_f32(acc); +} + +static inline HVX_Vector gdn_add_scaled_dot_f32(float * restrict dst, const float * restrict src, + HVX_Vector vscale, const float * restrict dot, uint32_t n) { + HVX_Vector acc = Q6_V_vzero(); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t nloe = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vd = hvx_vmemu(dst + i * epv); + HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vmemu(dst + i * epv) = out; + acc = hvx_vec_add_f32_f32(acc, hvx_vec_mul_f32_f32(out, vdot)); + } + + if (nloe) { + const uint32_t off = nvec * epv; + HVX_Vector vd = hvx_vmemu(dst + off); + HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_Vector out = hvx_vec_add_f32_f32(vd, hvx_vec_mul_f32_f32(vs, vscale)); + hvx_vec_store_u(dst + off, nloe * sizeof(float), out); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); + HVX_Vector prod = hvx_vec_mul_f32_f32(out, vdot); + acc = hvx_vec_add_f32_f32(acc, Q6_V_vmux_QVV(mask, prod, Q6_V_vzero())); + } + + return hvx_vec_reduce_sum_f32(acc); +} + +static inline void gdn_mul_dot4_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, const float * restrict mul, + const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t nloe = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + } + + if (nloe) { + const uint32_t off = nvec * epv; + HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); + + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + } + + HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; + hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); +} + +static inline void gdn_mul_scalar_dot4_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float mul, + const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + const HVX_Vector vmul = hvx_vec_splat_f32(mul); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t nloe = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + } + + if (nloe) { + const uint32_t off = nvec * epv; + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); + + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + } + + HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; + hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); +} + +static inline void gdn_add_scaled_dot4_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, const float * restrict src, + const float * restrict scale, const float * restrict dot, uint32_t n, + float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]); + const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]); + const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]); + const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t nloe = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3)); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + } + + if (nloe) { + const uint32_t off = nvec * epv; + HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); + + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + } + + HVX_Vector_x4 acc = { .v = { acc0, acc1, acc2, acc3 } }; + hvx_vec_store_u(sums, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(acc)); +} + +static inline void gdn_mul_dot8_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float * restrict dst4, + float * restrict dst5, float * restrict dst6, float * restrict dst7, + const float * restrict mul, const float * restrict dot, uint32_t n, + float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + HVX_Vector acc4 = Q6_V_vzero(); + HVX_Vector acc5 = Q6_V_vzero(); + HVX_Vector acc6 = Q6_V_vzero(); + HVX_Vector acc7 = Q6_V_vzero(); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t nloe = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vm = hvx_vmem(mul + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vm); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vm); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vm); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vm); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vm); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + hvx_vmemu(dst4 + i * epv) = out4; + hvx_vmemu(dst5 + i * epv) = out5; + hvx_vmemu(dst6 + i * epv) = out6; + hvx_vmemu(dst7 + i * epv) = out7; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); + acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); + acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); + acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); + } + + if (nloe) { + const uint32_t off = nvec * epv; + HVX_Vector vm = hvx_vmem(mul + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vm); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vm); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vm); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vm); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vm); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vm); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vm); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vm); + + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); + acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); + acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); + acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); + } + + HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; + HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; + hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); + hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); +} + +static inline void gdn_mul_scalar_dot8_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float * restrict dst4, + float * restrict dst5, float * restrict dst6, float * restrict dst7, + float mul, const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + HVX_Vector acc4 = Q6_V_vzero(); + HVX_Vector acc5 = Q6_V_vzero(); + HVX_Vector acc6 = Q6_V_vzero(); + HVX_Vector acc7 = Q6_V_vzero(); + const HVX_Vector vmul = hvx_vec_splat_f32(mul); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t nloe = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + i * epv), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + i * epv), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + i * epv), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + i * epv), vmul); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + i * epv), vmul); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + i * epv), vmul); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + i * epv), vmul); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + i * epv), vmul); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + hvx_vmemu(dst4 + i * epv) = out4; + hvx_vmemu(dst5 + i * epv) = out5; + hvx_vmemu(dst6 + i * epv) = out6; + hvx_vmemu(dst7 + i * epv) = out7; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); + acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); + acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); + acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); + } + + if (nloe) { + const uint32_t off = nvec * epv; + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_mul_f32_f32(hvx_vmemu(dst0 + off), vmul); + HVX_Vector out1 = hvx_vec_mul_f32_f32(hvx_vmemu(dst1 + off), vmul); + HVX_Vector out2 = hvx_vec_mul_f32_f32(hvx_vmemu(dst2 + off), vmul); + HVX_Vector out3 = hvx_vec_mul_f32_f32(hvx_vmemu(dst3 + off), vmul); + HVX_Vector out4 = hvx_vec_mul_f32_f32(hvx_vmemu(dst4 + off), vmul); + HVX_Vector out5 = hvx_vec_mul_f32_f32(hvx_vmemu(dst5 + off), vmul); + HVX_Vector out6 = hvx_vec_mul_f32_f32(hvx_vmemu(dst6 + off), vmul); + HVX_Vector out7 = hvx_vec_mul_f32_f32(hvx_vmemu(dst7 + off), vmul); + + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); + acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); + acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); + acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); + } + + HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; + HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; + hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); + hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); +} + +static inline void gdn_add_scaled_dot8_f32(float * restrict dst0, float * restrict dst1, + float * restrict dst2, float * restrict dst3, float * restrict dst4, + float * restrict dst5, float * restrict dst6, float * restrict dst7, + const float * restrict src, const float * restrict scale, + const float * restrict dot, uint32_t n, float * restrict sums) { + HVX_Vector acc0 = Q6_V_vzero(); + HVX_Vector acc1 = Q6_V_vzero(); + HVX_Vector acc2 = Q6_V_vzero(); + HVX_Vector acc3 = Q6_V_vzero(); + HVX_Vector acc4 = Q6_V_vzero(); + HVX_Vector acc5 = Q6_V_vzero(); + HVX_Vector acc6 = Q6_V_vzero(); + HVX_Vector acc7 = Q6_V_vzero(); + const HVX_Vector scale0 = hvx_vec_splat_f32(scale[0]); + const HVX_Vector scale1 = hvx_vec_splat_f32(scale[1]); + const HVX_Vector scale2 = hvx_vec_splat_f32(scale[2]); + const HVX_Vector scale3 = hvx_vec_splat_f32(scale[3]); + const HVX_Vector scale4 = hvx_vec_splat_f32(scale[4]); + const HVX_Vector scale5 = hvx_vec_splat_f32(scale[5]); + const HVX_Vector scale6 = hvx_vec_splat_f32(scale[6]); + const HVX_Vector scale7 = hvx_vec_splat_f32(scale[7]); + + const uint32_t epv = 128 / sizeof(float); + const uint32_t nvec = n / epv; + const uint32_t nloe = n % epv; + for (uint32_t i = 0; i < nvec; ++i) { + HVX_Vector vs = hvx_vmem(src + i * epv); + HVX_Vector vdot = hvx_vmem(dot + i * epv); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + i * epv), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + i * epv), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + i * epv), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + i * epv), hvx_vec_mul_f32_f32(vs, scale3)); + HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + i * epv), hvx_vec_mul_f32_f32(vs, scale4)); + HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + i * epv), hvx_vec_mul_f32_f32(vs, scale5)); + HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + i * epv), hvx_vec_mul_f32_f32(vs, scale6)); + HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + i * epv), hvx_vec_mul_f32_f32(vs, scale7)); + + hvx_vmemu(dst0 + i * epv) = out0; + hvx_vmemu(dst1 + i * epv) = out1; + hvx_vmemu(dst2 + i * epv) = out2; + hvx_vmemu(dst3 + i * epv) = out3; + hvx_vmemu(dst4 + i * epv) = out4; + hvx_vmemu(dst5 + i * epv) = out5; + hvx_vmemu(dst6 + i * epv) = out6; + hvx_vmemu(dst7 + i * epv) = out7; + + acc0 = hvx_vec_add_f32_f32(acc0, hvx_vec_mul_f32_f32(out0, vdot)); + acc1 = hvx_vec_add_f32_f32(acc1, hvx_vec_mul_f32_f32(out1, vdot)); + acc2 = hvx_vec_add_f32_f32(acc2, hvx_vec_mul_f32_f32(out2, vdot)); + acc3 = hvx_vec_add_f32_f32(acc3, hvx_vec_mul_f32_f32(out3, vdot)); + acc4 = hvx_vec_add_f32_f32(acc4, hvx_vec_mul_f32_f32(out4, vdot)); + acc5 = hvx_vec_add_f32_f32(acc5, hvx_vec_mul_f32_f32(out5, vdot)); + acc6 = hvx_vec_add_f32_f32(acc6, hvx_vec_mul_f32_f32(out6, vdot)); + acc7 = hvx_vec_add_f32_f32(acc7, hvx_vec_mul_f32_f32(out7, vdot)); + } + + if (nloe) { + const uint32_t off = nvec * epv; + HVX_Vector vs = hvx_vmem(src + off); + HVX_Vector vdot = hvx_vmem(dot + off); + HVX_VectorPred mask = Q6_Q_vsetq2_R(nloe * sizeof(float)); + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector out0 = hvx_vec_add_f32_f32(hvx_vmemu(dst0 + off), hvx_vec_mul_f32_f32(vs, scale0)); + HVX_Vector out1 = hvx_vec_add_f32_f32(hvx_vmemu(dst1 + off), hvx_vec_mul_f32_f32(vs, scale1)); + HVX_Vector out2 = hvx_vec_add_f32_f32(hvx_vmemu(dst2 + off), hvx_vec_mul_f32_f32(vs, scale2)); + HVX_Vector out3 = hvx_vec_add_f32_f32(hvx_vmemu(dst3 + off), hvx_vec_mul_f32_f32(vs, scale3)); + HVX_Vector out4 = hvx_vec_add_f32_f32(hvx_vmemu(dst4 + off), hvx_vec_mul_f32_f32(vs, scale4)); + HVX_Vector out5 = hvx_vec_add_f32_f32(hvx_vmemu(dst5 + off), hvx_vec_mul_f32_f32(vs, scale5)); + HVX_Vector out6 = hvx_vec_add_f32_f32(hvx_vmemu(dst6 + off), hvx_vec_mul_f32_f32(vs, scale6)); + HVX_Vector out7 = hvx_vec_add_f32_f32(hvx_vmemu(dst7 + off), hvx_vec_mul_f32_f32(vs, scale7)); + + hvx_vec_store_u(dst0 + off, nloe * sizeof(float), out0); + hvx_vec_store_u(dst1 + off, nloe * sizeof(float), out1); + hvx_vec_store_u(dst2 + off, nloe * sizeof(float), out2); + hvx_vec_store_u(dst3 + off, nloe * sizeof(float), out3); + hvx_vec_store_u(dst4 + off, nloe * sizeof(float), out4); + hvx_vec_store_u(dst5 + off, nloe * sizeof(float), out5); + hvx_vec_store_u(dst6 + off, nloe * sizeof(float), out6); + hvx_vec_store_u(dst7 + off, nloe * sizeof(float), out7); + + acc0 = hvx_vec_add_f32_f32(acc0, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out0, vdot), zero)); + acc1 = hvx_vec_add_f32_f32(acc1, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out1, vdot), zero)); + acc2 = hvx_vec_add_f32_f32(acc2, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out2, vdot), zero)); + acc3 = hvx_vec_add_f32_f32(acc3, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out3, vdot), zero)); + acc4 = hvx_vec_add_f32_f32(acc4, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out4, vdot), zero)); + acc5 = hvx_vec_add_f32_f32(acc5, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out5, vdot), zero)); + acc6 = hvx_vec_add_f32_f32(acc6, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out6, vdot), zero)); + acc7 = hvx_vec_add_f32_f32(acc7, Q6_V_vmux_QVV(mask, hvx_vec_mul_f32_f32(out7, vdot), zero)); + } + + HVX_Vector_x4 accA = { .v = { acc0, acc1, acc2, acc3 } }; + HVX_Vector_x4 accB = { .v = { acc4, acc5, acc6, acc7 } }; + hvx_vec_store_u(sums + 0, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accA)); + hvx_vec_store_u(sums + 4, 4 * sizeof(float), hvx_vec_reduce_sum_f32x4(accB)); +} + +static void gated_delta_net_f32_pp_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; + struct htp_ops_context * octx = gctx->octx; + + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * g = octx->src[3]; + const struct htp_tensor * beta = octx->src[4]; + const struct htp_tensor * state = octx->src[5]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t S_v = v->ne[0]; + const uint32_t H = v->ne[1]; + const uint32_t n_tokens = v->ne[2]; + const uint32_t n_seqs = v->ne[3]; + const uint32_t K = octx->op_params[0]; + + const uint32_t total_rows = H * n_seqs; + if (ith >= total_rows) { + return; + } + + const uint32_t rq3 = n_seqs / q->ne[3]; + const uint32_t rk3 = n_seqs / k->ne[3]; + const float scale = 1.0f / sqrtf((float) S_v); + + float * dst_base = (float *) (uintptr_t) dst->data; + float * state_out_base = dst_base + (uint64_t) S_v * H * n_tokens * n_seqs; + const float * state_in_base = (const float *) (uintptr_t) state->data; + + const bool kda = (g->ne[0] == S_v); + float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); + + dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); + + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); + + const uint64_t state_seq_stride = state->nb[3] / sizeof(float); + const uint64_t state_size_per_snap = (uint64_t) S_v * S_v * H * n_seqs; + + uint32_t ir_prefetch = ith; + int spad_idx = 0; + + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + int curr_spad_idx = 0; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { + dma_queue_pop(dma); + dma_queue_pop(dma); + + float * s_work_curr = s_work[curr_spad_idx]; + + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); + + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); + + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + + float * attn_data = dst_base + ((uint64_t) iv3 * n_tokens * H + iv1) * S_v; + + for (uint32_t t = 0; t < n_tokens; ++t) { + const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data + + (uint64_t) iq3 * q->nb[3] + (uint64_t) t * q->nb[2] + (uint64_t) iq1 * q->nb[1]); + const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data + + (uint64_t) ik3 * k->nb[3] + (uint64_t) t * k->nb[2] + (uint64_t) ik1 * k->nb[1]); + const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data + + (uint64_t) iv3 * v->nb[3] + (uint64_t) t * v->nb[2] + (uint64_t) iv1 * v->nb[1]); + const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data + + (uint64_t) iv3 * g->nb[3] + (uint64_t) t * g->nb[2] + (uint64_t) iv1 * g->nb[1]); + const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + + (uint64_t) iv3 * beta->nb[3] + (uint64_t) t * beta->nb[2] + (uint64_t) iv1 * beta->nb[1]); + + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); + + if (kda) { + hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); + + uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); + } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); + } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); + for (; j < S_v; ++j) { + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); + } + } else { + const float gate = expf(g_t[0]); + uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); + } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); + } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); + for (; j < S_v; ++j) { + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); + } + } + + if (K > 1) { + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + const int64_t target_slot = (int64_t) n_tokens - 1 - (int64_t) t; + if (target_slot >= 0 && target_slot < (int64_t) K) { + float * curr_state_o = state_out_base + (uint64_t) target_slot * state_size_per_snap + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + if (curr_state_o != s_out) { + hvx_copy_f32_uu((uint8_t *) curr_state_o, (const uint8_t *) s_work_curr, S_v * S_v); + } + } + } + + attn_data += (uint64_t) S_v * H; + } + + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + curr_spad_idx ^= 1; + } + dma_queue_flush(dma); +} + + +static void gated_delta_net_f32_tg_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_gdn_context * gctx = (struct htp_gdn_context *) data; + struct htp_ops_context * octx = gctx->octx; + + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * g = octx->src[3]; + const struct htp_tensor * beta = octx->src[4]; + const struct htp_tensor * state = octx->src[5]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t S_v = v->ne[0]; + const uint32_t H = v->ne[1]; + const uint32_t n_seqs = v->ne[3]; + + const uint32_t total_rows = H * n_seqs; + if (ith >= total_rows) { + return; + } + + const uint32_t rq3 = n_seqs / q->ne[3]; + const uint32_t rk3 = n_seqs / k->ne[3]; + const float scale = 1.0f / sqrtf((float) S_v); + + float * dst_base = (float *) (uintptr_t) dst->data; + float * state_out_base = dst_base + (uint64_t) S_v * H * n_seqs; + const float * state_in_base = (const float *) (uintptr_t) state->data; + + const bool kda = (g->ne[0] == S_v); + float local_gate[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_q[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_k[HTP_GDN_MAX_SV] __attribute__((aligned(128))); + float local_sums[32] __attribute__((aligned(128))); + + dma_queue * dma = octx->ctx->dma[ith]; + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + float * s_work[2]; + s_work[0] = (float *) (gctx->vtcm_base + gctx->vtcm_per_thread * ith); + s_work[1] = s_work[0] + state_aligned / sizeof(float); + + struct fastdiv_values fd_H = init_fastdiv_values(H); + struct fastdiv_values fd_q1 = init_fastdiv_values(q->ne[1]); + struct fastdiv_values fd_k1 = init_fastdiv_values(k->ne[1]); + struct fastdiv_values fd_rq3 = init_fastdiv_values(rq3); + struct fastdiv_values fd_rk3 = init_fastdiv_values(rk3); + + const uint64_t state_seq_stride = state->nb[3] / sizeof(float); + + uint32_t ir_prefetch = ith; + int spad_idx = 0; + + // Prefetch preamble (up to 2 steps) + for (int k = 0; k < 2 && ir_prefetch < total_rows; k++) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * ps_out = state_out_base + ((uint64_t) piv3 * H + piv1) * S_v * S_v; + + // Push dummy write-back + dma_queue_push(dma, dma_make_ptr(ps_out, s_work[spad_idx]), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), 0); + + // Push fetch + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + int curr_spad_idx = 0; + for (uint32_t ir = ith; ir < total_rows; ir += nth) { + dma_queue_pop(dma); + dma_queue_pop(dma); + + float * s_work_curr = s_work[curr_spad_idx]; + + const uint32_t iv1 = fastmodulo(ir, H, &fd_H); + const uint32_t iv3 = fastdiv(ir, &fd_H); + + const uint32_t iq1 = fastmodulo(iv1, q->ne[1], &fd_q1); + const uint32_t ik1 = fastmodulo(iv1, k->ne[1], &fd_k1); + const uint32_t iq3 = fastdiv(iv3, &fd_rq3); + const uint32_t ik3 = fastdiv(iv3, &fd_rk3); + + // final state lands in snapshot slot 0 (most-recent-first ordering) + float * s_out = state_out_base + ((uint64_t) iv3 * H + iv1) * S_v * S_v; + + float * attn_data = dst_base + ((uint64_t) iv3 * H + iv1) * S_v; + + const float * q_t = (const float *) ((const uint8_t *) (uintptr_t) q->data + + (uint64_t) iq3 * q->nb[3] + (uint64_t) iq1 * q->nb[1]); + const float * k_t = (const float *) ((const uint8_t *) (uintptr_t) k->data + + (uint64_t) ik3 * k->nb[3] + (uint64_t) ik1 * k->nb[1]); + const float * v_t = (const float *) ((const uint8_t *) (uintptr_t) v->data + + (uint64_t) iv3 * v->nb[3] + (uint64_t) iv1 * v->nb[1]); + const float * g_t = (const float *) ((const uint8_t *) (uintptr_t) g->data + + (uint64_t) iv3 * g->nb[3] + (uint64_t) iv1 * g->nb[1]); + const float beta_val = *(const float *) ((const uint8_t *) (uintptr_t) beta->data + + (uint64_t) iv3 * beta->nb[3] + (uint64_t) iv1 * beta->nb[1]); + + hvx_copy_f32_au((uint8_t *) local_q, (const uint8_t *) q_t, S_v); + hvx_copy_f32_au((uint8_t *) local_k, (const uint8_t *) k_t, S_v); + + if (kda) { + hvx_exp_f32((uint8_t *) local_gate, (const uint8_t *) g_t, S_v, false); + + uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); + } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + gdn_mul_dot4_f32(row0, row1, row2, row3, local_gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); + } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); + for (; j < S_v; ++j) { + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_dot_f32(row, local_gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); + } + } else { + const float gate = expf(g_t[0]); + uint32_t j = 0; + for (; j + 8 <= S_v; j += 8) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + float * row4 = s_work_curr + (uint64_t) (j + 4) * S_v; + float * row5 = s_work_curr + (uint64_t) (j + 5) * S_v; + float * row6 = s_work_curr + (uint64_t) (j + 6) * S_v; + float * row7 = s_work_curr + (uint64_t) (j + 7) * S_v; + gdn_mul_scalar_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot8_f32(row0, row1, row2, row3, row4, row5, row6, row7, + local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 8 * sizeof(float), res_attn); + } + for (; j + 4 <= S_v; j += 4) { + float * row0 = s_work_curr + (uint64_t) (j + 0) * S_v; + float * row1 = s_work_curr + (uint64_t) (j + 1) * S_v; + float * row2 = s_work_curr + (uint64_t) (j + 2) * S_v; + float * row3 = s_work_curr + (uint64_t) (j + 3) * S_v; + gdn_mul_scalar_dot4_f32(row0, row1, row2, row3, gate, local_k, S_v, local_sums); + + float local_delta_b[32] __attribute__((aligned(128))); + HVX_Vector vv_t = hvx_vmemu(v_t + j); + HVX_Vector v_local_sums = hvx_vmem(local_sums); + HVX_Vector diff = hvx_vec_sub_f32_f32(vv_t, v_local_sums); + hvx_vmem(local_delta_b) = hvx_vec_mul_f32_f32(diff, hvx_vec_splat_f32(beta_val)); + + gdn_add_scaled_dot4_f32(row0, row1, row2, row3, local_k, local_delta_b, local_q, S_v, local_sums); + + HVX_Vector res_attn = hvx_vec_mul_f32_f32(hvx_vmem(local_sums), hvx_vec_splat_f32(scale)); + hvx_vec_store_u(attn_data + j, 4 * sizeof(float), res_attn); + } + HVX_Vector vscale_splat = hvx_vec_splat_f32(scale); + for (; j < S_v; ++j) { + float * row = s_work_curr + (uint64_t) j * S_v; + HVX_Vector vsum = gdn_mul_scalar_dot_f32(row, gate, local_k, S_v); + HVX_Vector vv_t = hvx_vec_splat_f32(v_t[j]); + HVX_Vector vdj = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(vv_t, vsum), hvx_vec_splat_f32(beta_val)); + HVX_Vector vres = gdn_add_scaled_dot_f32(row, local_k, vdj, local_q, S_v); + attn_data[j] = hvx_vec_get_f32(hvx_vec_mul_f32_f32(vres, vscale_splat)); + } + } + + // Push real write-back + dma_queue_push(dma, dma_make_ptr(s_out, s_work_curr), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + // Prefetch next block (if any) + if (ir_prefetch < total_rows) { + const uint32_t piv1 = fastmodulo(ir_prefetch, H, &fd_H); + const uint32_t piv3 = fastdiv(ir_prefetch, &fd_H); + const float * ps_in = state_in_base + (uint64_t) piv3 * state_seq_stride + (uint64_t) piv1 * S_v * S_v; + + dma_queue_push(dma, dma_make_ptr(s_work[spad_idx], ps_in), + S_v * sizeof(float), S_v * sizeof(float), + S_v * sizeof(float), S_v); + + ir_prefetch += nth; + spad_idx ^= 1; + } + + curr_spad_idx ^= 1; + } + dma_queue_flush(dma); +} + + +int op_gated_delta_net(struct htp_ops_context * octx) { + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * g = octx->src[3]; + const struct htp_tensor * beta = octx->src[4]; + const struct htp_tensor * state = octx->src[5]; + const struct htp_tensor * dst = octx->dst; + + if (!q || !k || !v || !g || !beta || !state || !dst) { + return HTP_STATUS_INVAL_PARAMS; + } + + if (q->type != HTP_TYPE_F32 || k->type != HTP_TYPE_F32 || v->type != HTP_TYPE_F32 || + g->type != HTP_TYPE_F32 || beta->type != HTP_TYPE_F32 || state->type != HTP_TYPE_F32 || + dst->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t S_v = v->ne[0]; + const uint32_t H = v->ne[1]; + const uint32_t n_tokens = v->ne[2]; + const uint32_t n_seqs = v->ne[3]; + const uint32_t K = octx->op_params[0]; + + if (S_v == 0 || S_v > HTP_GDN_MAX_SV || H == 0 || n_tokens == 0 || n_seqs == 0) { + return HTP_STATUS_NO_SUPPORT; + } + if ((g->ne[0] != 1 && g->ne[0] != S_v) || beta->ne[0] != 1) { + return HTP_STATUS_NO_SUPPORT; + } + if (q->ne[0] != S_v || k->ne[0] != S_v || q->ne[1] == 0 || k->ne[1] == 0 || + q->ne[2] != n_tokens || k->ne[2] != n_tokens || q->ne[3] == 0 || k->ne[3] == 0 || + (n_seqs % q->ne[3]) != 0 || (n_seqs % k->ne[3]) != 0) { + return HTP_STATUS_NO_SUPPORT; + } + // state holds s0 only: [S_v, S_v, H, n_seqs] + if (state->ne[0] != S_v || state->ne[1] != S_v || state->ne[2] != H || state->ne[3] != n_seqs) { + return HTP_STATUS_NO_SUPPORT; + } + if (dst->ne[0] != S_v * H || dst->ne[1] != n_tokens * n_seqs + S_v * n_seqs * K) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + struct htp_gdn_context gctx; + gctx.octx = octx; + gctx.rows_per_thread = (H * n_seqs + octx->n_threads - 1) / octx->n_threads; + gctx.state_bytes = (size_t) S_v * S_v * sizeof(float); + + size_t state_aligned = (size_t) S_v * S_v * sizeof(float); + state_aligned = (state_aligned + 127) & ~(size_t)127; + + assert(octx->ctx->vtcm_base != NULL); + assert(octx->ctx->vtcm_size >= 2 * state_aligned * octx->n_threads); + + gctx.vtcm_base = octx->ctx->vtcm_base; + gctx.vtcm_per_thread = 2 * state_aligned; + + if (n_tokens == 1) { + worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_tg_thread, &gctx, octx->n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, gated_delta_net_f32_pp_thread, &gctx, octx->n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index 54321421eb5..bf7063e9880 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -2,98 +2,164 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include <HAP_farf.h> -#include <HAP_mem.h> #include <HAP_perf.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> + #include <math.h> #include <string.h> #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" +#include "htp-ops.h" #include "htp-ops.h" #include "hvx-utils.h" -#include "ops-utils.h" + +struct get_rows_context { + struct htp_ops_context * octx; + uint32_t tasks_per_thread; + uint32_t total_tasks; + uint32_t chunks_per_row; + uint32_t chunk_size; + struct fastdiv_values get_rows_div_ne10; + struct fastdiv_values get_rows_div_ne10_ne11; + struct fastdiv_values get_rows_div_chunks_per_row; +}; #define get_rows_preamble \ - const uint32_t ne00 = octx->src0.ne[0]; \ - const uint32_t ne01 = octx->src0.ne[1]; \ - const uint32_t ne02 = octx->src0.ne[2]; \ - const uint32_t ne03 = octx->src0.ne[3]; \ - \ - const uint32_t ne10 = octx->src1.ne[0]; \ - const uint32_t ne11 = octx->src1.ne[1]; \ - const uint32_t ne12 = octx->src1.ne[2]; \ - \ - const uint32_t nb01 = octx->src0.nb[1]; \ - const uint32_t nb02 = octx->src0.nb[2]; \ - const uint32_t nb03 = octx->src0.nb[3]; \ - \ - const uint32_t nb10 = octx->src1.nb[0]; \ - const uint32_t nb11 = octx->src1.nb[1]; \ - const uint32_t nb12 = octx->src1.nb[2]; \ - \ - const uint32_t nb1 = octx->dst.nb[1]; \ - const uint32_t nb2 = octx->dst.nb[2]; \ - const uint32_t nb3 = octx->dst.nb[3]; \ - \ + const uint32_t ne00 = octx->src[0]->ne[0]; \ + const uint32_t ne01 = octx->src[0]->ne[1]; \ + const uint32_t ne02 = octx->src[0]->ne[2]; \ + const uint32_t ne03 = octx->src[0]->ne[3]; \ + \ + const uint32_t ne10 = octx->src[1]->ne[0]; \ + const uint32_t ne11 = octx->src[1]->ne[1]; \ + const uint32_t ne12 = octx->src[1]->ne[2]; \ + const uint32_t ne13 = octx->src[1]->ne[3]; \ + \ + const uint32_t ne0 = octx->dst->ne[0]; \ + const uint32_t ne1 = octx->dst->ne[1]; \ + const uint32_t ne2 = octx->dst->ne[2]; \ + const uint32_t ne3 = octx->dst->ne[3]; \ + \ + const uint32_t nb01 = octx->src[0]->nb[1]; \ + const uint32_t nb02 = octx->src[0]->nb[2]; \ + const uint32_t nb03 = octx->src[0]->nb[3]; \ + \ + const uint32_t nb10 = octx->src[1]->nb[0]; \ + const uint32_t nb11 = octx->src[1]->nb[1]; \ + const uint32_t nb12 = octx->src[1]->nb[2]; \ + \ + const uint32_t nb1 = octx->dst->nb[1]; \ + const uint32_t nb2 = octx->dst->nb[2]; \ + const uint32_t nb3 = octx->dst->nb[3]; \ + \ const uint32_t nr = ne10 * ne11 * ne12; -static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { +static void get_rows_thread_f32_f32_dma(unsigned int nth, unsigned int ith, void *data) { + struct get_rows_context * grctx = (struct get_rows_context *)data; + struct htp_ops_context * octx = grctx->octx; get_rows_preamble; - // parallelize by src1 elements (which correspond to dst rows) - const uint32_t dr = octx->src1_nrows_per_thread; + uint64_t qt = HAP_perf_get_qtimer_count(); + + const uint32_t dr = grctx->tasks_per_thread; const uint32_t ir0 = dr * ith; - const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + if (ir0 >= grctx->total_tasks) { + return; + } + const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks); - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); + dma_queue * dma_queue = octx->ctx->dma[ith]; for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11); + const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); const uint32_t rem = i - i12 * ne11 * ne10; - const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10); + const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); const uint32_t i10 = rem - i11 * ne10; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; - + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i01 >= ne01) { - // invalid index, skip for now to avoid crash continue; } - const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03; - const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; - hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03; + const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3; + + while (!dma_queue_push(dma_queue, dma_make_ptr((void *)dst_ptr, (const void *)src0_ptr), nb1, nb01, ne00 * sizeof(float), 1)) { + dma_queue_pop(dma_queue); + } } + dma_queue_flush(dma_queue); - return HTP_STATUS_OK; + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "get-rows-f32-f32-dma %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } -static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { - get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); +static void get_rows_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + struct get_rows_context * grctx = (struct get_rows_context *)data; + struct htp_ops_context * octx = grctx->octx; + get_rows_preamble; + + uint64_t qt = HAP_perf_get_qtimer_count(); + + const uint32_t dr = grctx->tasks_per_thread; + const uint32_t ir0 = dr * ith; + if (ir0 >= grctx->total_tasks) { + return; + } + const uint32_t ir1 = MIN(ir0 + dr, grctx->total_tasks); + + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); + + const uint32_t chunks_per_row = grctx->chunks_per_row; + const uint32_t chunk_size = grctx->chunk_size; + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t row_idx = fastdiv(i, &grctx->get_rows_div_chunks_per_row); + const uint32_t chunk_idx = i - row_idx * chunks_per_row; + + const uint32_t i12 = fastdiv(row_idx, &grctx->get_rows_div_ne10_ne11); + const uint32_t rem = row_idx - i12 * ne11 * ne10; + const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); + const uint32_t i10 = rem - i11 * ne10; + + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; + uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + + if (i01 >= ne01) { + continue; + } + + const uint32_t offset = chunk_idx * chunk_size; + if (offset < ne00) { + const uint32_t copy_size = MIN(chunk_size, ne00 - offset); + const uintptr_t src0_ptr = octx->src[0]->data + i01*nb01 + i11*nb02 + i12*nb03 + offset * sizeof(float); + const uintptr_t dst_ptr = octx->dst->data + i10*nb1 + i11*nb2 + i12*nb3 + offset * sizeof(float); + hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, copy_size); + } + } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "get-rows-f32-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } int op_get_rows(struct htp_ops_context * octx) { get_rows_preamble; - if (octx->src0.type != HTP_TYPE_F32) { + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->dst.type != HTP_TYPE_F32) { + if (octx->dst->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) { return HTP_STATUS_NO_SUPPORT; } @@ -101,12 +167,52 @@ int op_get_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); - octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + const uint32_t nb00 = octx->src[0]->nb[0]; + const uint32_t nb0 = octx->dst->nb[0]; + + const bool can_use_dma = (nb00 == sizeof(float)) && (nb0 == sizeof(float)); + const bool use_dma = can_use_dma && (ne00 >= 2048); + + struct get_rows_context grctx; + grctx.octx = octx; + grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src[1]->ne[0]); + grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src[1]->ne[0] * octx->src[1]->ne[1]); + + if (use_dma) { + grctx.chunks_per_row = 1; + grctx.chunk_size = ne00; + grctx.total_tasks = nr; + grctx.get_rows_div_chunks_per_row = init_fastdiv_values(1); + + const uint32_t n_threads = MIN(nr, octx->n_threads); + grctx.tasks_per_thread = (nr + n_threads - 1) / n_threads; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_dma, &grctx, n_threads); + } else { + uint32_t chunks_per_row = 1; + uint32_t chunk_size = ne00; + uint32_t total_tasks = nr; + + if (nr < octx->n_threads) { + const uint32_t min_chunk_size = 1024; + uint32_t max_chunks = ne00 / min_chunk_size; + if (max_chunks == 0) { + max_chunks = 1; + } + chunks_per_row = MIN((octx->n_threads + nr - 1) / nr, max_chunks); + chunk_size = (ne00 + chunks_per_row - 1) / chunks_per_row; + total_tasks = nr * chunks_per_row; + } + + grctx.chunks_per_row = chunks_per_row; + grctx.chunk_size = chunk_size; + grctx.total_tasks = total_tasks; + grctx.get_rows_div_chunks_per_row = init_fastdiv_values(chunks_per_row); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + const uint32_t n_threads = MIN(total_tasks, octx->n_threads); + grctx.tasks_per_thread = (total_tasks + n_threads - 1) / n_threads; - worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32_hvx, &grctx, n_threads); + } return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.c b/ggml/src/ggml-hexagon/htp/hex-dma.c similarity index 85% rename from ggml/src/ggml-hexagon/htp/htp-dma.c rename to ggml/src/ggml-hexagon/htp/hex-dma.c index 880c4542a0e..b66e2d2603c 100644 --- a/ggml/src/ggml-hexagon/htp/htp-dma.c +++ b/ggml/src/ggml-hexagon/htp/hex-dma.c @@ -1,4 +1,4 @@ -#include "htp-dma.h" +#include "hex-dma.h" #include <stdbool.h> #include <stdlib.h> @@ -31,8 +31,8 @@ dma_queue * dma_queue_create(size_t capacity) { q->capacity = capacity; q->idx_mask = capacity - 1; - q->desc = (hexagon_udma_descriptor_type1_t *) memalign(64, capacity * sizeof(hexagon_udma_descriptor_type1_t)); - memset(q->desc, 0, capacity * sizeof(hexagon_udma_descriptor_type1_t)); + q->desc = (dma_descriptor_2d *) memalign(64, capacity * sizeof(dma_descriptor_2d)); + memset(q->desc, 0, capacity * sizeof(dma_descriptor_2d)); q->dptr = (dma_ptr *) memalign(4, capacity * sizeof(dma_ptr)); memset(q->dptr, 0, capacity * sizeof(dma_ptr)); diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h new file mode 100644 index 00000000000..7685473f463 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -0,0 +1,372 @@ +#ifndef HTP_DMA_H +#define HTP_DMA_H + +#include <HAP_farf.h> +#include <hexagon_types.h> +#include <stdbool.h> +#include <stdint.h> + +#ifdef __cplusplus +extern "C" { +#endif + +// Define the HW descriptor structs here since the ones in HexSDK are a bit out of date +typedef struct dma_descriptor_1d_s { + void * next; + uint32_t size:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; +} dma_descriptor_1d; + +#if __HVX_ARCH__ < 75 + +typedef struct dma_descriptor_2d_s { + void * next; + uint32_t reserved0:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; + uint32_t desc_type:8; + uint32_t reserved1:24; + uint32_t row_size:16; + uint32_t nrows:16; + uint32_t src_stride:16; + uint32_t dst_stride:16; + uint32_t src_offset:16; + uint32_t dst_offset:16; +} dma_descriptor_2d; + +#else + +typedef struct dma_descriptor_2d_s { + void * next; + uint32_t dst_stride:24; + uint32_t desc_size:2; + uint32_t dst_comp:1; + uint32_t src_comp:1; + uint32_t dst_bypass:1; + uint32_t src_bypass:1; + uint32_t order:1; + uint32_t done:1; + void * src; + void * dst; + uint32_t desc_type:8; + uint32_t reserved0:24; + uint32_t row_size:24; + uint32_t nrows_lo:8; + uint32_t nrows_hi:8; + uint32_t src_stride:24; + uint32_t offset:24; + uint32_t reserved1:8; +} dma_descriptor_2d; + +#endif + +typedef struct { + void *dst; + const void *src; +} dma_ptr; + +typedef struct { + dma_descriptor_2d * desc; // descriptor pointers + dma_descriptor_2d * tail; // tail pointer + dma_ptr * dptr; // dst/src pointers + uint32_t push_idx; + uint32_t pop_idx; + uint32_t capacity; + uint32_t idx_mask; +} dma_queue; + +dma_queue * dma_queue_create(size_t capacity); +void dma_queue_delete(dma_queue * q); +void dma_queue_flush(dma_queue * q); + +// TODO: technically we don't need these and could use Q6_dmstart/wait/etc instead +// but those do not seem to always compiler properly. +static inline void dmstart(void * next) { + asm volatile(" release(%0):at" : : "r"(next)); + asm volatile(" dmstart(%0)" : : "r"(next)); +} + +static inline void dmlink(void * cur, void * next) { + asm volatile(" release(%0):at" : : "r"(next)); + asm volatile(" dmlink(%0, %1)" : : "r"(cur), "r"(next)); +} + +static inline unsigned int dmpoll(void) { + unsigned int ret = 0; + asm volatile(" %0 = dmpoll" : "=r"(ret) : : "memory"); + return ret; +} + +static inline unsigned int dmwait(void) { + unsigned int ret = 0; + asm volatile(" %0 = dmwait" : "=r"(ret) : : "memory"); + return ret; +} + +static inline dma_ptr dma_make_ptr(void *dst, const void *src) +{ + dma_ptr p = { dst, src }; + return p; +} + +#if __HVX_ARCH__ < 73 +static const uint32_t dma_src_l2_bypass_on = 1; +static const uint32_t dma_dst_l2_bypass_on = 0; +#else +static const uint32_t dma_src_l2_bypass_on = 1; +static const uint32_t dma_dst_l2_bypass_on = 1; +#endif + +static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) { + if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { + FARF(HIGH, "dma-push: queue full\n"); + return false; + } + + dma_descriptor_1d * desc = (dma_descriptor_1d *) &q->desc[q->push_idx]; + desc->next = NULL; + desc->desc_size = 0; // 1D mode + desc->src_bypass = dma_src_l2_bypass_on; + desc->dst_bypass = dma_dst_l2_bypass_on; + desc->order = 0; + desc->done = 0; + desc->src = (void *) dptr.src; + desc->dst = (void *) dptr.dst; + desc->size = size; + + q->dptr[q->push_idx] = dptr; + + if (size) { + dmlink(q->tail, desc); + q->tail = (dma_descriptor_2d *) desc; + } else { + desc->done = 1; + } + + // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); + q->push_idx = (q->push_idx + 1) & q->idx_mask; + return true; +} + +static inline bool dma_queue_push_single_2d(dma_queue * q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { + FARF(HIGH, "dma-push: queue full\n"); + return false; + } + + dma_descriptor_2d * desc = &q->desc[q->push_idx]; + + desc->next = NULL; + desc->reserved0 = 0; + desc->reserved1 = 0; + desc->desc_size = 1; // 2d mode + desc->src_bypass = dma_src_l2_bypass_on; + desc->dst_bypass = dma_dst_l2_bypass_on; + desc->src_comp = 0; + desc->dst_comp = 0; + desc->order = 0; + desc->done = 0; + desc->src_stride = src_stride; + desc->dst_stride = dst_stride; + desc->src = (void *) dptr.src; + desc->dst = (void *) dptr.dst; + desc->row_size = row_size; + +#if __HVX_ARCH__ < 75 + desc->desc_type = 0; // 2d (16-bit) mode + desc->nrows = nrows; + desc->src_offset = 0; + desc->dst_offset = 0; +#else + desc->desc_type = 9; // 2d (24-bit) mode + desc->nrows_lo = (nrows & 0xff); + desc->nrows_hi = (nrows >> 8); + desc->offset = 0; +#endif + + q->dptr[q->push_idx] = dptr; + + if (nrows) { + dmlink(q->tail, desc); + q->tail = desc; + } else { + desc->done = 1; + } + + // FARF(ERROR, "dma-push: i %u row-size %u nrows %d dst %p src %p\n", q->push_idx, row_size, nrows, dptr.dst, dptr.src); + q->push_idx = (q->push_idx + 1) & q->idx_mask; + return true; +} + +static inline dma_ptr dma_queue_pop(dma_queue * q) { + dma_ptr dptr = { NULL }; + + if (q->push_idx == q->pop_idx) { + return dptr; + } + + dma_descriptor_2d * desc = &q->desc[q->pop_idx]; + + // Wait for desc to complete + while (!desc->done) { + // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); + dmpoll(); + } + + dptr = q->dptr[q->pop_idx]; + + // FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src); + q->pop_idx = (q->pop_idx + 1) & q->idx_mask; + return dptr; +} + +static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) { + dma_ptr dptr = { NULL }; + + if (q->push_idx == q->pop_idx) { + return dptr; + } + + dptr = q->dptr[q->pop_idx]; + + // FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src); + q->pop_idx = (q->pop_idx + 1) & q->idx_mask; + return dptr; +} + +static inline bool dma_queue_empty(dma_queue * q) { + return q->push_idx == q->pop_idx; +} + +static inline uint32_t dma_queue_depth(dma_queue * q) { + return (q->push_idx - q->pop_idx) & q->idx_mask; +} + +static inline uint32_t dma_queue_capacity(dma_queue * q) { + return q->capacity; +} + +#if __HVX_ARCH__ < 75 + +// Overflow-safe DMA push: all 2d descriptor fields (row_size, nrows, src_stride, dst_stride) are 16-bit, max 65535. +// This version transparently handles values that exceed the 16-bit limit and submits chained DMA transtions. + +#define DMA_MAX_FIELD_VAL 65535u + +static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + // Fast path: everything fits in 16 bits + if (nrows == 0 || __builtin_expect( + row_size <= DMA_MAX_FIELD_VAL && + nrows <= DMA_MAX_FIELD_VAL && + src_stride <= DMA_MAX_FIELD_VAL && + dst_stride <= DMA_MAX_FIELD_VAL, 1)) { + return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows); + } + + // Contiguous block + // Use 1d DMA mode which supports sizes up to 24-bits (16MB) + if (nrows == 1 || (row_size == src_stride && row_size == dst_stride)) { + size_t total = row_size * nrows; + return dma_queue_push_single_1d(q, dptr, total); + } + + // Stride overflow — fall back to row-by-row. + { + const uint8_t *src = (const uint8_t *) dptr.src; + uint8_t *dst = (uint8_t *) dptr.dst; + for (size_t r = 0; r < nrows; ++r) { + dma_ptr p = dma_make_ptr(dst + r * dst_stride, src + r * src_stride); + if (!dma_queue_push_single_1d(q, p, row_size)) + return false; + if (r + 1 < nrows) + dma_queue_pop(q); + } + return true; + } +} + +#else // HVX_ARCH >= 75 + +static inline bool dma_queue_push(dma_queue *q, dma_ptr dptr, size_t dst_stride, size_t src_stride, size_t row_size, size_t nrows) { + // On v75 and up we always use 2d 24-bit mode + return dma_queue_push_single_2d(q, dptr, dst_stride, src_stride, row_size, nrows); +} + +#endif + +static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) { + return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows); +} + +static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, dma_ptr dptr, size_t dst_row_size, size_t src_row_size, size_t nrows) { + return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); +} + +#define DMA_CACHE_MAX_SIZE 64U + +typedef struct { + uint8_t *base; + uint32_t line_size; + uint32_t capacity; + uint32_t src[DMA_CACHE_MAX_SIZE]; + uint16_t age[DMA_CACHE_MAX_SIZE]; +} dma_cache; + +static inline void dma_cache_init(dma_cache *c, uint8_t *base, uint32_t line_size, uint32_t capacity) +{ + c->capacity = (capacity > DMA_CACHE_MAX_SIZE) ? DMA_CACHE_MAX_SIZE : capacity; + c->base = base; + c->line_size = line_size; + + for (unsigned i=0; i < c->capacity; i++) { + c->src[i] = 0; + c->age[i] = 0; + } +} + +static inline bool dma_cache_push(dma_queue *q, dma_cache *c, const uint8_t * src, uint32_t dst_stride, uint32_t src_stride, uint32_t row_size, uint32_t nrows) +{ + uint32_t o_idx = 0; + uint16_t o_age = 0; + uint8_t * dst = 0; + + for (unsigned i=0; i < c->capacity; i++) { + if (c->src[i] == (uint32_t) src) { + c->age[i] = 0; + dst = c->base + (i * c->line_size); nrows = 0; // dummy dma + // FARF(ERROR, "dma-cache: found %p", src); + } else { + c->age[i]++; + if (c->age[i] > o_age) { o_age = c->age[i]; o_idx = i; } + } + } + if (!dst) { + // FARF(ERROR, "dma-cache: replacing #%u : age %u %p -> %p", o_idx, c->age[o_idx], (void *) c->src[o_idx], src); + c->age[o_idx] = 0; + c->src[o_idx] = (uint32_t) src; + dst = c->base + o_idx * c->line_size; // normal nrows dma + } + + return dma_queue_push(q, dma_make_ptr(dst, src), dst_stride, src_stride, row_size, nrows); +} + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* HTP_DMA_H */ diff --git a/ggml/src/ggml-hexagon/htp/hex-dump.h b/ggml/src/ggml-hexagon/htp/hex-dump.h new file mode 100644 index 00000000000..19d173c2232 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hex-dump.h @@ -0,0 +1,86 @@ +#ifndef HEX_DUMP_H +#define HEX_DUMP_H + +#include <HAP_farf.h> + +static inline void hex_dump_int8_line(char * pref, const int8_t * x, int n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n && p < p_end; i++) { + p += snprintf(p, p_end - p, "%d, ", x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n && p < p_end; i++) { + p += snprintf(p, p_end - p, "%d, ", x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_uint32_line(char * pref, const uint32_t * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += snprintf(p, p_end - p, "%u, ", (unsigned int) x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_int32_line(char * pref, const int32_t * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += snprintf(p, p_end - p, "%d, ", (int) x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_f16_line(char * pref, const __fp16 * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_f32_line(char * pref, const float * x, uint32_t n) { + char str[1024], *p = str, *p_end = str + sizeof(str); + p += snprintf(p, p_end - p, "%s: ", pref); + for (int i = 0; i < n; i++) { + p += snprintf(p, p_end - p, "%.6f, ", x[i]); + } + FARF(HIGH, "%s\n", str); +} + +static inline void hex_dump_f32(char * pref, const float * x, uint32_t n) { + uint32_t n0 = n / 16; + uint32_t n1 = n % 16; + + uint32_t i = 0; + for (; i < n0; i++) { + hex_dump_f32_line(pref, x + (16 * i), 16); + } + if (n1) { + hex_dump_f32_line(pref, x + (16 * i), n1); + } +} + +static inline void hex_dump_f16(char * pref, const __fp16 * x, uint32_t n) { + uint32_t n0 = n / 16; + uint32_t n1 = n % 16; + + uint32_t i = 0; + for (; i < n0; i++) { + hex_dump_f16_line(pref, x + (16 * i), 16); + } + if (n1) { + hex_dump_f16_line(pref, x + (16 * i), n1); + } +} + +#endif /* HEX_DUMP_H */ diff --git a/ggml/src/ggml-hexagon/htp/hex-fastdiv.h b/ggml/src/ggml-hexagon/htp/hex-fastdiv.h new file mode 100644 index 00000000000..b7b5867593f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hex-fastdiv.h @@ -0,0 +1,37 @@ +#ifndef HEX_FASTDIV_H +#define HEX_FASTDIV_H + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +struct fastdiv_values { + uint32_t mp; + uint32_t l; +}; + +static inline struct fastdiv_values init_fastdiv_values(uint32_t d) { + struct fastdiv_values result = { 0, 0 }; + // compute L = ceil(log2(d)); + while (result.l < 32 && ((uint32_t) 1 << result.l) < d) { + ++(result.l); + } + + result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1); + return result; +} + +static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) { + // Compute high 32 bits of n * mp + const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp) + // add n, apply bit shift + return (hi + n) >> vals->l; +} + +static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) { + return n - fastdiv(n, vals) * d; +} + +#endif /* HEX_FASTDIV_H */ diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h new file mode 100644 index 00000000000..6239ceff4b4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -0,0 +1,137 @@ +#ifndef HEX_UTILS_H +#define HEX_UTILS_H + +#include <stdbool.h> +#include <stdint.h> +#include <qurt_memory.h> +#include <qurt.h> + +#include "hexagon_types.h" +#include "hexagon_protos.h" + +#include "hex-fastdiv.h" +#include "hex-dump.h" + +#ifndef MAX +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#endif + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +static inline uint64_t hex_get_cycles() { + uint64_t cycles = 0; + asm volatile(" %0 = c15:14\n" : "=r"(cycles)); + return cycles; +} + +static inline uint64_t hex_get_pktcnt() { + uint64_t pktcnt; + asm volatile(" %0 = c19:18\n" : "=r"(pktcnt)); + return pktcnt; +} + +static inline uint32_t hex_ceil_pow2(uint32_t x) { + if (x <= 1) { return 1; } + int p = 2; + x--; + while (x >>= 1) { p <<= 1; } + return p; +} + +static inline size_t hmx_ceil_div(size_t num, size_t den) { + return (num + den - 1) / den; +} + +static inline int32_t hex_is_aligned(const void * addr, uint32_t align) { + return ((size_t) addr & (align - 1)) == 0; +} + +static inline size_t hex_align_up(size_t v, size_t align) { + return hmx_ceil_div(v, align) * align; +} + +static inline size_t hex_align_down(size_t v, size_t align) { + return (v / align) * align; +} + +static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { + uint32_t left_off = (size_t) addr & (chunk_size - 1); + uint32_t right_off = left_off + n; + return right_off <= chunk_size; +} + +static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { + return m * ((n + m - 1) / m); +} + +static inline size_t hex_smin(size_t a, size_t b) { + return a < b ? a : b; +} + +static inline size_t hex_smax(size_t a, size_t b) { + return a > b ? a : b; +} + +static inline void hex_swap_ptr(void ** p1, void ** p2) { + void * t = *p1; + *p1 = *p2; + *p2 = t; +} + +static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { + const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); + Q6_l2fetch_AP((void *) p, control); +} + +#define HEX_L2_LINE_SIZE 64 +#define HEX_L2_FLUSH_SIZE (128 * 1024) + +static inline void hex_l2flush(void * addr, size_t size) { + if (size > HEX_L2_FLUSH_SIZE) { + qurt_mem_cache_clean((qurt_addr_t) 0, 0, QURT_MEM_CACHE_FLUSH_INVALIDATE_ALL, QURT_MEM_DCACHE); + } else { + const uint32_t s = (uint32_t) addr; + const uint32_t e = s + size; + for (uint32_t i = s; i < e; i += HEX_L2_LINE_SIZE * 4) { + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 0); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 1); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 2); + Q6_dccleaninva_A((void *) i + HEX_L2_LINE_SIZE * 3); + } + } +} + +static inline void hex_pause() { + asm volatile(" pause(#255)\n"); +} + +#ifndef HEX_NUM_PMU_COUNTERS +#define HEX_NUM_PMU_COUNTERS 8 +#endif + +static inline void hex_get_pmu(uint32_t counters[]) { +#if __HVX_ARCH__ >= 79 + asm volatile("%0 = upmucnt0" : "=r"(counters[0])); + asm volatile("%0 = upmucnt1" : "=r"(counters[1])); + asm volatile("%0 = upmucnt2" : "=r"(counters[2])); + asm volatile("%0 = upmucnt3" : "=r"(counters[3])); + asm volatile("%0 = upmucnt4" : "=r"(counters[4])); + asm volatile("%0 = upmucnt5" : "=r"(counters[5])); + asm volatile("%0 = upmucnt6" : "=r"(counters[6])); + asm volatile("%0 = upmucnt7" : "=r"(counters[7])); +#else + counters[0] = qurt_pmu_get(QURT_PMUCNT0); + counters[1] = qurt_pmu_get(QURT_PMUCNT1); + counters[2] = qurt_pmu_get(QURT_PMUCNT2); + counters[3] = qurt_pmu_get(QURT_PMUCNT3); + counters[4] = qurt_pmu_get(QURT_PMUCNT4); + counters[5] = qurt_pmu_get(QURT_PMUCNT5); + counters[6] = qurt_pmu_get(QURT_PMUCNT6); + counters[7] = qurt_pmu_get(QURT_PMUCNT7); + // qurt_pmu_get_pmucnt(counters); +#endif +} + +#endif /* HEX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c new file mode 100644 index 00000000000..2796564fb75 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -0,0 +1,1878 @@ +// HMX-accelerated Flash Attention for prefill (neq1 >= 32). +// Ported from htp-ops-lib/src/dsp/ops/flash_attn.c, adapted to the htp/ codebase. + +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <assert.h> +#include <HAP_compute_res.h> +#include <HAP_farf.h> +#include <math.h> +#include <stdbool.h> +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "hex-dma.h" +#include "hex-fastdiv.h" +#include "hmx-profile.h" +#include "hmx-queue.h" +#include "hmx-utils.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-dump.h" +#include "hvx-copy.h" +#include "hvx-reduce.h" +#include "hvx-utils.h" +#include "hvx-flash-attn.h" +#include "vtcm-utils.h" +#include "worker-pool.h" + +// ============================================================================ +// Constants +// ============================================================================ + +// Tile constants from hmx-utils.h +// HMX_FP16_TILE_N_ROWS = 32 +// HMX_FP16_TILE_N_COLS = 32 +// HMX_FP16_TILE_N_ELMS = 1024 +// HMX_FP16_TILE_SIZE = 2048 + +// ============================================================================ +// Dynamic block size computation (GQA-aware) +// ============================================================================ + +// Exact VTCM usage for a given (gqa_factor, DK, DV, Br, Bc) configuration. +// g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. +// Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales +// Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) { + const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); + const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] + const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong + const size_t k_dma_size = hex_align_up(Bc * hex_round_up(DK * sizeof(__fp16), 128), 4096); // K DMA: [Bc, DK] x2 double-buf + const size_t v_dma_size = hex_align_up(Bc * hex_round_up(DV * sizeof(__fp16), 128), 4096); // V DMA: [Bc, DV] x2 double-buf + const size_t k_tile_size = hex_align_up(Bc * DK * sizeof(__fp16), 4096); // K tiles: [Bc, DK] interleaved + const size_t v_tile_size = hex_align_up(Bc * DV * sizeof(__fp16), 4096); // V tiles: [Bc, DV] interleaved + const size_t s_tile_size = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); // S/P:[g_br, Bc] + const size_t d_tile_size = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); // D: [g_br, g_br] + const size_t col_vec_size = hex_align_up(g_br * sizeof(__fp16), 256); // m, l, etc. + const size_t row_vec_size = hex_align_up(Bc * sizeof(__fp16), 256); + const size_t m_line_size = hex_align_up(Bc * sizeof(__fp16), 128); + const size_t m_buf_size = hex_align_up(Br * m_line_size, 4096); + const size_t slopes_size = hex_align_up(g_br * sizeof(__fp16), 128); + + return q_tile_size * 1 // Q tiles + + o_tile_size * 2 // O ping-pong + + k_dma_size * 2 // K DMA x2 + + v_dma_size * 2 // V DMA x2 + + k_tile_size * 1 // K tiles + + v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + + s_tile_size * 2 // S + P + + d_tile_size * 1 // D (diagonal matrix) + + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum + + row_vec_size * 2 * n_threads // per-thread softmax row scratch + + m_buf_size * 1 // mask VTCM buffer [Br rows] + + slopes_size // Slopes + + 256 * 2; // HMX scales (id + qk) +} + +// ============================================================================ +// FP16 exp2 polynomial (ported from htp-ops-lib/include/dsp/hvx_math.h) +// ============================================================================ +// 5th-order Horner polynomial for exp2(x) in qf16/hf16 domain. Input must be +// ≤ 0 (safe softmax invariant — overflow handling omitted). ~18 ALU ops per +// 64 fp16 lanes, fully parallel across HVX threads (no scatter/gather engine). +// Replaces the F32 round-trip (qf16→f32→exp→f32→f16, ~44 ops for 2×32 lanes). +static inline HVX_Vector hvx_exp2_hf(HVX_Vector x_v) { + const HVX_Vector zero_v = Q6_V_vzero(); + const HVX_Vector half_hf_v = Q6_Vh_vsplat_R(0x3800); // fp16 0.5 + + // k = round_toward_neg_inf(x); f = (float)k; frac = x - f + HVX_Vector x_minus_half = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vsub_VhfVhf(x_v, half_hf_v)); + HVX_Vector k_v = Q6_Vh_equals_Vhf(x_minus_half); // truncate to int16 + HVX_Vector f_v = Q6_Vhf_equals_Vh(k_v); // back to fp16 + + HVX_Vector x_qf16 = Q6_Vqf16_vsub_VhfVhf(x_v, f_v); // fractional part in qf16 + + // Horner: y = ((((E5*x + E4)*x + E3)*x + E2)*x + E1)*x + E0 + HVX_Vector y = Q6_Vqf16_vmpy_Vqf16Vqf16(Q6_Vh_vsplat_R(0x5082), x_qf16); // E5*x + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x157d)); // + E4 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x20ed)); // + E3 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x2b1b)); // + E2 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x33b0)); // + E1 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x398c)); // + E0 + y = Q6_Vqf16_vmpy_Vqf16Vqf16(y, x_qf16); // y = y * x + y = Q6_Vqf16_vadd_Vqf16Vhf(y, Q6_Vh_vsplat_R(0x3c00)); // + 1.0 + + // Combine polynomial (mantissa) with integer part (exponent): result = y * 2^k + y = Q6_Vhf_equals_Vqf16(y); + HVX_Vector y_exp = Q6_Vuh_vlsr_VuhR(Q6_Vh_vasl_VhR(y, 1), 11); + y_exp = Q6_Vh_vadd_VhVh(k_v, y_exp); + HVX_VectorPred q_underflow = Q6_Q_vcmp_gt_VhVh(zero_v, y_exp); + y = Q6_Vh_vaslacc_VhVhR(y, k_v, 10); + return Q6_V_vmux_QVV(q_underflow, zero_v, y); +} + +#define FA_MIN_KV_BLOCKS 3 + +// Cost-based (Br, Bc) search for flash attention with pipeline constraint. +// +// VTCM model (same as before): +// overhead + g_br * per_gbr + g_br² * per_gbr2 + Bc * per_bc + g_br * Bc * per_gbr_bc +// +// Cost model (minimization objective): +// Q * (c_q_fixed + K * c_iter_fixed), where Q = ceil(qo/Br), K = ceil(kv/Bc) +static int hmx_fa_find_chunk_size(size_t * Br_out, + size_t * Bc_out, + size_t gqa_factor, + size_t DK, + size_t DV, + size_t qo_len, + size_t kv_len, + size_t vtcm_budget, + size_t n_threads) { + const size_t T = HMX_FP16_TILE_N_ROWS; // 32 + const size_t br_unit = hmx_ceil_div(T, gqa_factor); + // Bc must be a multiple of 64 so that n_tiles_per_bc is even. The softmax + // P-tile write uses a dual-tile pattern (vshuff + two stores 16 slots apart) + // that would race across r0 blocks if the last dual-tile is half-occupied. + // See .cursor/todos/hmx-flash-attn-bc-search-space.md for the perf trade-off. + const size_t bc_unit = HMX_FP16_TILE_N_COLS * 2; // 64 + const size_t fp16 = sizeof(__fp16); + const bool can_pipeline = (kv_len >= FA_MIN_KV_BLOCKS * bc_unit && n_threads >= 2); + + // Approximate per-unit VTCM costs (without per-buffer alignment padding). + const size_t per_gbr = (DK + 2 * DV) * fp16 + 4 * fp16; // Q + O×2 + 4 col vectors + const size_t per_gbr2 = fp16; // D diagonal matrix + const size_t per_bc = + 3 * DK * fp16 + (can_pipeline ? 4 : 3) * DV * fp16 + 2 * n_threads * fp16; // K/V DMA x2 + tiles + row bufs + const size_t per_gbr_bc = 2 * fp16; // S + P + + const size_t overhead = 256 * 2 + 13 * 4096; + + if (vtcm_budget <= overhead) { + return -1; + } + const size_t usable = vtcm_budget - overhead; + + // Br_max: largest Br aligned to br_unit that does not exceed qo_len. + const size_t Br_max = qo_len >= br_unit ? hex_align_down(qo_len, br_unit) : br_unit; + + // Pipeline constraint: cap Bc so n_kv_blocks >= FA_MIN_KV_BLOCKS. + // Only relax when kv_len is too short to form enough blocks. + const size_t Bc_limit = can_pipeline ? hex_align_down(kv_len / FA_MIN_KV_BLOCKS, bc_unit) : + (kv_len >= bc_unit ? hex_align_down(kv_len, bc_unit) : bc_unit); + // Cost coefficients calibrated from profiling + const size_t c_q_fixed = 1400; // per-Q-block: q_load + epilogue o_update + o_norm + o_store + const size_t c_iter_fixed = 200; // per-KV-iter: HMX queue push/pop + DMA pop + barriers + + size_t best_cost = SIZE_MAX, best_mn = 0; + size_t best_Br = 0, best_Bc = 0; + + for (size_t Br = Br_max; Br >= br_unit; Br -= br_unit) { + const size_t g_br = hex_align_up(gqa_factor * Br, T); + + // g_br-dependent VTCM cost: g_br * per_gbr + g_br² * per_gbr2 + const size_t gbr_cost = g_br * per_gbr + g_br * g_br * per_gbr2; + if (gbr_cost >= usable) { + if (Br == br_unit) { + break; + } + continue; + } + + // Analytically solve for max Bc: + // remain >= Bc * (per_bc + g_br * per_gbr_bc + Br * fp16_mask) + // The Br * fp16 term accounts for the VTCM mask buffer [Br × Bc]. + const size_t remain = usable - gbr_cost; + const size_t bc_denom = per_bc + g_br * per_gbr_bc + Br * fp16; + size_t Bc = hex_smin(hex_align_down(remain / bc_denom, bc_unit), Bc_limit); + if (Bc < bc_unit) { + if (Br == br_unit) { + break; + } + continue; + } + + // Exact VTCM verification (alignment padding may push over budget) + while (Bc >= bc_unit && hmx_fa_compute_vtcm_usage(gqa_factor, DK, DV, Br, Bc, n_threads, can_pipeline) > vtcm_budget) { + Bc -= bc_unit; + } + if (Bc < bc_unit) { + if (Br == br_unit) { + break; + } + continue; + } + + const size_t q_blocks = (qo_len + Br - 1) / Br; + const size_t kv_blocks = (kv_len + Bc - 1) / Bc; + const size_t cost = q_blocks * (c_q_fixed + kv_blocks * c_iter_fixed); + const size_t mn = Br * Bc; + + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_Br = Br; + best_Bc = Bc; + } + + if (Br == br_unit) { + break; + } + } + + if (best_Br == 0) { + return -1; + } + + *Br_out = best_Br; + *Bc_out = best_Bc; + return 0; +} + +// ============================================================================ +// Tile interleave / extract helpers +// ============================================================================ + +// transpose scatter offsets moved to hmx-utils.h as hmx_transpose_scatter_offsets + +// Scatter offsets for diagonal tile: entry[2i] = i*136, entry[2i+1] = i*136+6 +// 136 = 4 * 32 + 8 = byte offset to diagonal in a 32x32 fp16 interleaved tile +static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = { + 0 * 136, 0 * 136 + 6, + 1 * 136, 1 * 136 + 6, + 2 * 136, 2 * 136 + 6, + 3 * 136, 3 * 136 + 6, + 4 * 136, 4 * 136 + 6, + 5 * 136, 5 * 136 + 6, + 6 * 136, 6 * 136 + 6, + 7 * 136, 7 * 136 + 6, + 8 * 136, 8 * 136 + 6, + 9 * 136, 9 * 136 + 6, + 10 * 136, 10 * 136 + 6, + 11 * 136, 11 * 136 + 6, + 12 * 136, 12 * 136 + 6, + 13 * 136, 13 * 136 + 6, + 14 * 136, 14 * 136 + 6, + 15 * 136, 15 * 136 + 6, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, + 0, 0, +}; + +// hmx_interleave_rows_to_tiles and hmx_interleave_cols_to_tiles are in hmx-utils.h + +// ============================================================================ +// HMX Flash Attention context (GQA-merged) +// ============================================================================ + +struct hmx_fa_context { + const struct htp_ops_context * octx; + bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 + uint32_t n_threads; + + // Op parameters + float scale; + float max_bias; + float logit_softcap; + uint32_t n_head_log2; + float m0, m1; + + // Dimensions + uint32_t DK, DV; + uint32_t n_kv; // kv_len + uint32_t n_kv_heads; // number of KV heads + uint32_t n_heads; // number of Q heads + uint32_t G; // GQA factor = n_heads / n_kv_heads + struct fastdiv_values div_G; + uint32_t n_kv_blocks; + uint32_t neq1; // Q token count + + // Types + bool is_q_fp32; + bool is_dst_fp32; + + // Dynamic block sizes + uint32_t Br; // Q tokens per block (before GQA expansion) + uint32_t Bc; + uint32_t g_br; // hex_align_up(G * Br, 32) - actual tile row dim + + // VTCM buffers (allocated by vtcm_seq_alloc) + __fp16 * vtcm_q_tiles; // Q tile format [g_br, D] + __fp16 * vtcm_o_tiles[2]; // O ping-pong [g_br, D] + __fp16 * vtcm_k_fp16[2]; // K DMA double-buffer [Bc, D] + __fp16 * vtcm_v_fp16[2]; // V DMA double-buffer [Bc, D] + __fp16 * vtcm_k_tiles; // K tiles (transposed) + __fp16 * vtcm_v_tiles[2]; // V tiles (column-major, double-buffered) + __fp16 * vtcm_s_tiles; // S = QK^T [g_br, Bc] + __fp16 * vtcm_p_tiles; // P = softmax(S) [g_br, Bc] + __fp16 * vtcm_d_tiles; // Diagonal rescale [g_br, g_br] + HVX_Vector * vtcm_m_vec; // Row max [g_br] + HVX_Vector * vtcm_l_vec; // Row sum [g_br] + HVX_Vector * vtcm_s_rowmax; // Softmax intermediate [g_br] + HVX_Vector * vtcm_p_rowsum; // Softmax intermediate [g_br] + HVX_Vector * vtcm_row_bufs; // Per-thread softmax row scratch [n_threads][2][Bc/64] + uint8_t * vtcm_hmx_scales_id; // HMX output scales (identity) + uint8_t * vtcm_hmx_scales_qk; // HMX output scales (qk_scale) + __fp16 * vtcm_mask_buf; // VTCM mask buffer [Br × m_line], DMA'd per KV block + __fp16 * vtcm_slopes; // ALiBi slopes [g_br] + size_t row_buf_stride; // HVX vectors per row buffer (Bc/64) + size_t mask_buf_row_stride; // elements (__fp16) per row in mask buffer + bool mask_broadcast; // true when mask->ne[2] == 1 (head-independent, single 2D DMA) +}; + +// ============================================================================ +// Multi-thread K interleave phase +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + int kv_rows; + size_t src_stride; + size_t buf_idx; +} fa_k_int_args_t; + +static void fa_k_interleave_thread(unsigned int n, unsigned int i, void * data) { + fa_k_int_args_t * args = (fa_k_int_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const int total_rows = args->kv_rows; + const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); // ensure even (row pairs) + const int start = i * rows_per_t; + const int end = hex_smin(start + rows_per_t, total_rows); + + if (start >= total_rows) { + return; + } + + hmx_interleave_rows_to_tiles(factx->vtcm_k_tiles, factx->vtcm_k_fp16[args->buf_idx], total_rows, (int) factx->DK, + (int) args->src_stride, start, end); +} + +static void fa_phase_k_interleave(struct hmx_fa_context * factx, int kv_rows, size_t src_stride, size_t buf_idx) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_k_int_args_t args = { factx, kv_rows, src_stride, buf_idx }; + if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_k_interleave_thread, &args, factx->n_threads); + } else { + fa_k_interleave_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread V interleave phase +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + int kv_rows; + size_t src_stride; + size_t buf_idx; + size_t n_col_tiles; +} fa_v_int_args_t; + +static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) { + fa_v_int_args_t * args = (fa_v_int_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const int total_rows = args->kv_rows; + const int rows_per_t = hex_align_up(hmx_ceil_div(total_rows, n), 2); + const int start = i * rows_per_t; + const int end = hex_smin(start + rows_per_t, total_rows); + + if (start >= total_rows) { + return; + } + + __fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; + + hmx_interleave_cols_to_tiles(v_tiles_dest, factx->vtcm_v_fp16[args->buf_idx], total_rows, (int) factx->DV, + (int) args->src_stride, (int) args->n_col_tiles, start, end); +} + +static void fa_phase_v_interleave(struct hmx_fa_context * factx, + int kv_rows, + size_t src_stride, + size_t buf_idx, + size_t n_col_tiles) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_v_int_args_t args = { factx, kv_rows, src_stride, buf_idx, n_col_tiles }; + if (factx->n_threads > 1 && kv_rows >= (int) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_v_interleave_thread, &args, factx->n_threads); + } else { + fa_v_interleave_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread Q load phase: read Q[G × neq1, DK] from DDR, convert F32→F16 +// (or deal F16 pairs), and write interleaved into vtcm_q_tiles. +// Each thread owns a disjoint range of row pairs; writes target distinct tile +// slots (r0 selects tile row, r1 selects intra-tile slot), so there is no +// write conflict. Padding fill (when n_rows_g < g_br) is done single-threaded +// by the caller before dispatching. +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + const struct htp_tensor * q; + uint32_t q_start; + uint32_t kv_head; + uint32_t ib3; + size_t n_rows_g; +} fa_q_load_args_t; + +static void fa_q_load_thread(unsigned int n, unsigned int i, void * data) { + fa_q_load_args_t * args = (fa_q_load_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t G = factx->G; + const size_t DK = factx->DK; + + // Partition row pairs across threads. Keep each thread's start even so r/r+1 + // are always in the same thread's range. + const size_t rows_per_t = hex_align_up(hmx_ceil_div(n_rows_g, n), 2); + const size_t start = (size_t) i * rows_per_t; + const size_t end = hex_smin(start + rows_per_t, n_rows_g); + + if (start >= n_rows_g) { + return; + } + + const struct htp_tensor * q = args->q; + const uint32_t q_start = args->q_start; + const uint32_t kv_head = args->kv_head; + const uint32_t ib3 = args->ib3; + + for (size_t r = start; r < end; r += 2) { + const bool next_row_valid = (r + 1) < n_rows_g; + + const size_t q_idx0 = fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = fastmodulo(r + 0, G, &factx->div_G); + const size_t q_idx1 = fastdiv(r + 1, &factx->div_G); + const size_t h_idx1 = fastmodulo(r + 1, G, &factx->div_G); + + const uint8_t * q_ptr0 = (const uint8_t *) q->data + (q_start + q_idx0) * q->nb[1] + + (kv_head * G + h_idx0) * q->nb[2] + ib3 * q->nb[3]; + const uint8_t * q_ptr1 = next_row_valid ? ((const uint8_t *) q->data + (q_start + q_idx1) * q->nb[1] + + (kv_head * G + h_idx1) * q->nb[2] + ib3 * q->nb[3]) : + NULL; + + size_t r0 = r / HMX_FP16_TILE_N_ROWS; + size_t r1 = r % HMX_FP16_TILE_N_ROWS; + __fp16 * out_base = factx->vtcm_q_tiles + r0 * HMX_FP16_TILE_N_ROWS * DK; + + if (factx->is_q_fp32) { + const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; + const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; + + for (uint32_t d = 0; d < DK / 32; ++d) { + HVX_Vector v0 = pv_in0[d]; + HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); + HVX_Vector v_hf = hvx_vec_f32_to_f16_shuff(v0, v1); + + HVX_Vector * out_tile = (HVX_Vector *) (out_base + d * HMX_FP16_TILE_N_ELMS); + out_tile[r1 / 2] = v_hf; + } + } else { + const HVX_Vector * pv_in0 = (const HVX_Vector *) q_ptr0; + const HVX_Vector * pv_in1 = q_ptr1 ? (const HVX_Vector *) q_ptr1 : NULL; + + for (uint32_t d = 0; d < DK / 64; ++d) { + HVX_Vector v0 = pv_in0[d]; + HVX_Vector v1 = pv_in1 ? pv_in1[d] : Q6_V_vzero(); + HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); + + __fp16 * out_dual_tile = out_base + d * HMX_FP16_TILE_N_ELMS * 2; + HVX_Vector * pv_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; + HVX_Vector * pv_out1 = pv_out0 + 16; + + *pv_out0 = Q6_V_lo_W(vp); + *pv_out1 = Q6_V_hi_W(vp); + } + } + } +} + +static void fa_phase_q_load(struct hmx_fa_context * factx, + const struct htp_tensor * q, + uint32_t q_start, + uint32_t kv_head, + uint32_t ib3, + size_t n_rows_g) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_q_load_args_t args = { factx, q, q_start, kv_head, ib3, n_rows_g }; + // Require >= 2 row pairs per thread so partitioning is worthwhile. + if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_q_load_thread, &args, factx->n_threads); + } else { + fa_q_load_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread O store phase: read O tiles from VTCM, convert F16->F32 (or +// deal F16 pairs), and write to strided DDR dst tensor. Each thread owns a +// disjoint row range; writes target distinct dst rows (different q_idx/h_idx +// pairs produced by r/G and r%G), so there is no write conflict. +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + const struct htp_tensor * dst; + const __fp16 * o_tile_src; + uint32_t q_start; + uint32_t kv_head; + uint32_t ib3; + size_t n_rows_g; +} fa_o_store_args_t; + +static void fa_o_store_thread(unsigned int n, unsigned int i, void * data) { + fa_o_store_args_t * args = (fa_o_store_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t G = factx->G; + const size_t DV = factx->DV; + + const size_t rows_per_t = hmx_ceil_div(n_rows_g, n); + const size_t start = (size_t) i * rows_per_t; + const size_t end = hex_smin(start + rows_per_t, n_rows_g); + + if (start >= n_rows_g) { + return; + } + + const struct htp_tensor * dst = args->dst; + const __fp16 * o_tile_src = args->o_tile_src; + const uint32_t q_start = args->q_start; + const uint32_t kv_head = args->kv_head; + const uint32_t ib3 = args->ib3; + + for (size_t r = start; r < end; ++r) { + const size_t q_idx = fastdiv(r, &factx->div_G); + const size_t h_idx = fastmodulo(r, G, &factx->div_G); + + // FIX(dst-indexing): ggml_flash_attn_ext() creates dst as permute(0,2,1,3) -> + // [DV, n_heads, n_tokens, n_seq], so head stride is nb[1] and token stride is nb[2]. + uint8_t * dst_row = (uint8_t *) dst->data + (kv_head * G + h_idx) * dst->nb[1] + + (q_start + q_idx) * dst->nb[2] + ib3 * dst->nb[3]; + + size_t r0 = r / HMX_FP16_TILE_N_ROWS; + size_t r1 = r % HMX_FP16_TILE_N_ROWS; + const __fp16 * tile_row_base = o_tile_src + r0 * HMX_FP16_TILE_N_ROWS * DV; + + if (factx->is_dst_fp32) { + float * out = (float *) dst_row; + for (uint32_t d = 0; d < DV / 32; ++d) { + const HVX_Vector * in_tile = (const HVX_Vector *) (tile_row_base + d * HMX_FP16_TILE_N_ELMS); + HVX_VectorPair vp = hvx_vec_f16_to_f32_shuff(in_tile[r1 / 2]); + if (r1 % 2 == 0) { + *(HVX_UVector *) (out + d * 32) = Q6_V_lo_W(vp); + } else { + *(HVX_UVector *) (out + d * 32) = Q6_V_hi_W(vp); + } + } + } else { + __fp16 * out = (__fp16 *) dst_row; + for (uint32_t d = 0; d < DV / 64; ++d) { + const __fp16 * in_dual_tile = tile_row_base + d * HMX_FP16_TILE_N_ELMS * 2; + const HVX_Vector * pv_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; + const HVX_Vector * pv_in1 = pv_in0 + 16; + HVX_VectorPair vp = Q6_W_vdeal_VVR(*pv_in1, *pv_in0, -2); + if (r1 % 2 == 0) { + *(HVX_UVector *) (out + d * 64) = Q6_V_lo_W(vp); + } else { + *(HVX_UVector *) (out + d * 64) = Q6_V_hi_W(vp); + } + } + } + } +} + +static void fa_phase_o_store(struct hmx_fa_context * factx, + const struct htp_tensor * dst, + const __fp16 * o_tile_src, + uint32_t q_start, + uint32_t kv_head, + uint32_t ib3, + size_t n_rows_g) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + fa_o_store_args_t args = { factx, dst, o_tile_src, q_start, kv_head, ib3, n_rows_g }; + if (factx->n_threads > 1 && n_rows_g >= (size_t) (factx->n_threads * 2)) { + worker_pool_run_func(wp, fa_o_store_thread, &args, factx->n_threads); + } else { + fa_o_store_thread(1, 0, &args); + } +} + +// ============================================================================ +// Multi-thread softmax phase + serial m/l update + build_D +// ============================================================================ + +typedef struct { + struct hmx_fa_context * factx; + size_t kv_rows; + size_t n_rows_g; + size_t n_col_tiles; + size_t n_tiles_per_bc; + size_t n_row_tiles; + size_t n_row_tiles_g_br; + uint32_t Bc; + uint32_t G; + uint32_t kv_head; + uint32_t kv_start; + uint32_t q_start; + uint32_t ib3; + bool has_alibi; // true when max_bias != 0 (need slope * mask + add) + + // ALiBi per-head slopes (indexed by GQA-merged row: slope[r] for r in [0, n_rows_g)) + // slope[r] = 1.0 when max_bias == 0 (no ALiBi) + // Pointer into hmx_fa_context.vtcm_slopes (sized to g_br) + __fp16 * slopes; + + // Mask info (preloaded before softmax) + const struct htp_tensor * mask; + const __fp16 * mask_vtcm; // VTCM mask buffer base (NULL = DDR fallback) + size_t mask_vtcm_row_stride; // elements (__fp16) per row in VTCM mask buffer +} fa_softmax_args_t; + +static void fa_softmax_thread(unsigned int n, unsigned int i, void * data) { + fa_softmax_args_t * args = (fa_softmax_args_t *) data; + struct hmx_fa_context * factx = args->factx; + + const size_t n_rows_g = args->n_rows_g; + const size_t kv_rows = args->kv_rows; + const size_t Bc = args->Bc; + const size_t G = args->G; + const size_t n_tiles_per_bc = args->n_tiles_per_bc; + const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); + + // Partition r_vec_idx across threads + const size_t vecs_per_t = hmx_ceil_div(n_row_vec_cnt, n); + const size_t vec_start = i * vecs_per_t; + const size_t vec_end = hex_smin(vec_start + vecs_per_t, n_row_vec_cnt); + + if (vec_start >= n_row_vec_cnt) { + return; + } + + // Per-thread row scratch: thread i uses bufs at offset i * 2 * stride + const size_t row_buf_stride = factx->row_buf_stride; + HVX_Vector * my_row_buf0 = factx->vtcm_row_bufs + i * 2 * row_buf_stride; + HVX_Vector * my_row_buf1 = my_row_buf0 + row_buf_stride; + + const HVX_Vector v_neg_inf = Q6_Vh_vsplat_R(0xfbff); + + // Per-row accumulators: each fp16 lane in a 64-lane vector holds one row's scalar. + // CONTRACT: lane bits must be IEEE fp16 (hf), never qf16 — qf16 uses a different + // bit layout, so a later hf-domain read would silently produce wrong values. + // Convert first via Q6_Vhf_equals_Vqf16(). For reference: vtcm_m_vec/vtcm_s_rowmax + // are hf; vtcm_l_vec is qf16 — don't mix them up. + + for (size_t r_vec_idx = vec_start; r_vec_idx < vec_end; ++r_vec_idx) { + HVX_Vector rowmax_acc_v = v_neg_inf; + HVX_Vector rowsum_acc_v = Q6_V_vzero(); + HVX_Vector m_prev_v = factx->vtcm_m_vec[r_vec_idx]; + + for (int r_vec_off = 0; r_vec_off < 64; r_vec_off += 2) { + int r = r_vec_idx * 64 + r_vec_off; + if (r >= (int) hex_align_up(n_rows_g, 2)) { + break; + } + + int r0 = r / HMX_FP16_TILE_N_ROWS; + int r1 = r % HMX_FP16_TILE_N_ROWS; + + const __fp16 * s_ld_base = factx->vtcm_s_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; + __fp16 * p_st_base = factx->vtcm_p_tiles + r0 * HMX_FP16_TILE_N_ROWS * Bc; + + // Decode 2 rows from S tiles into per-thread row buffers + HVX_Vector * pv_row_buf0 = my_row_buf0; + HVX_Vector * pv_row_buf1 = my_row_buf1; + for (size_t c = 0; c < kv_rows; c += 64) { + const __fp16 * in_dual_tile = s_ld_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; + const HVX_Vector * pv_s_in0 = ((const HVX_Vector *) in_dual_tile) + r1 / 2; + const HVX_Vector * pv_s_in1 = pv_s_in0 + 16; + + HVX_VectorPair vp_s_dual_row = Q6_W_vdeal_VVR(*pv_s_in1, *pv_s_in0, -2); + *pv_row_buf0++ = Q6_V_lo_W(vp_s_dual_row); + *pv_row_buf1++ = Q6_V_hi_W(vp_s_dual_row); + } + + // Apply softcap if enabled (in F32 precision) + if (factx->logit_softcap != 0.0f) { + // When EXP2_HF is on, fold log2(e) into v_cap so the output lands in + // log2(e)-scaled space for the downstream exp2. log2(e) is kept OUT + // of qk_scale in this configuration (see scale setup) so tanh sees + // the physical QK/(√d·c) argument. + float cap = factx->logit_softcap; +#ifdef HMX_FA_USE_EXP2_HF + cap *= 1.44269504f; // log2(e) +#endif + const HVX_Vector v_cap = hvx_vec_splat_f32(cap); + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + + HVX_VectorPair r0_f32 = hvx_vec_f16_to_f32(my_row_buf0[ci]); + HVX_Vector t0_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r0_f32)); + HVX_Vector t0_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r0_f32)); + t0_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_lo, v_cap)); + t0_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t0_hi, v_cap)); + my_row_buf0[ci] = hvx_vec_f32_to_f16(t0_lo, t0_hi); + + HVX_VectorPair r1_f32 = hvx_vec_f16_to_f32(my_row_buf1[ci]); + HVX_Vector t1_lo = hvx_vec_tanh_f32(Q6_V_lo_W(r1_f32)); + HVX_Vector t1_hi = hvx_vec_tanh_f32(Q6_V_hi_W(r1_f32)); + t1_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_lo, v_cap)); + t1_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(t1_hi, v_cap)); + my_row_buf1[ci] = hvx_vec_f32_to_f16(t1_lo, t1_hi); + } + } + + // Apply mask & compute rowmax(S) + // + // Optimizations over baseline: + // A. No-ALiBi fast path: when max_bias==0 (slope≡1.0), skip the + // slope multiplication — still add mask (additive bias) but + // avoid the mul_f16_f16. Saves 2 ops/dual-row vs ALiBi path. + // B. GQA mask row dedup: G consecutive Q rows share one mask row + // (qi = r / G). Reuse mask vector when qi is unchanged between + // row0 and row1 (saves ~75% of VTCM loads for G=4). + + // ALiBi slopes — only needed when has_alibi (scheme A) + HVX_Vector v_slope0, v_slope1; + if (args->has_alibi) { + HVX_Vector v_s = hvx_vmemu(args->slopes + r); + v_slope0 = hvx_vec_repl_f16(v_s); + v_slope1 = (r + 1 < (int) n_rows_g) ? hvx_vec_repl_f16(Q6_V_vror_VR(v_s, 2)) : Q6_V_vzero(); + } + + const HVX_Vector v_threshold = Q6_Vh_vsplat_R(0xcc00); // fp16 -16.0 (hoisted outside for-c) + + HVX_Vector v_s_rowmax0 = v_neg_inf; + HVX_Vector v_s_rowmax1 = v_neg_inf; + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + const size_t ne = hex_smin(kv_rows - c, 64); + HVX_VectorPred q_tail_keep = Q6_Q_vsetq2_R(ne * sizeof(__fp16)); + + if (args->mask) { + HVX_Vector v_mask0, v_mask1; + + if (args->mask_vtcm) { + // Read mask from VTCM buffer (DMA'd per KV block). + // GQA dedup (scheme B): skip load when qi unchanged. + const size_t qi0 = fastdiv(r + 0, &factx->div_G); + v_mask0 = *(const HVX_UVector *) (args->mask_vtcm + qi0 * args->mask_vtcm_row_stride + c); + v_mask1 = v_neg_inf; + if (r + 1 < (int) n_rows_g) { + const size_t qi1 = fastdiv(r + 1, &factx->div_G); + if (qi1 == qi0) { + v_mask1 = v_mask0; // scheme B: reuse — same mask row + } else { + v_mask1 = *(const HVX_UVector *) (args->mask_vtcm + qi1 * args->mask_vtcm_row_stride + c); + } + } + } else { + // Fallback: read mask directly from DDR (when mask->ne[2] > 1). + const struct htp_tensor * mask = args->mask; + const size_t q_idx0 = args->q_start + fastdiv(r + 0, &factx->div_G); + const size_t h_idx0 = args->kv_head * G + fastmodulo(r + 0, G, &factx->div_G); + const uint32_t im2_0 = h_idx0 % mask->ne[2]; + const uint32_t im3_0 = args->ib3 % mask->ne[3]; + + const __fp16 * m0_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx0 * mask->nb[1] + + im2_0 * mask->nb[2] + im3_0 * mask->nb[3]) + args->kv_start + c; + v_mask0 = *(const HVX_UVector *) m0_ptr; + v_mask1 = v_neg_inf; + + if (r + 1 < (int) n_rows_g) { + const size_t q_idx1 = args->q_start + fastdiv(r + 1, &factx->div_G); + if (q_idx1 == q_idx0) { + // scheme B: same mask row in DDR path + v_mask1 = v_mask0; + } else { + const size_t h_idx1 = args->kv_head * G + fastmodulo(r + 1, G, &factx->div_G); + const uint32_t im2_1 = h_idx1 % mask->ne[2]; + const uint32_t im3_1 = args->ib3 % mask->ne[3]; + const __fp16 * m1_ptr = (const __fp16 *) ((const uint8_t *) mask->data + q_idx1 * mask->nb[1] + + im2_1 * mask->nb[2] + im3_1 * mask->nb[3]) + args->kv_start + c; + v_mask1 = *(const HVX_UVector *) m1_ptr; + } + } + } + + // Threshold: mask values below -16.0 are treated as -inf (causal mask). + HVX_VectorPred q_keep0 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask0, v_threshold), q_tail_keep); + HVX_VectorPred q_keep1 = Q6_Q_and_QQ(Q6_Q_vcmp_gt_VhfVhf(v_mask1, v_threshold), q_tail_keep); + + if (args->has_alibi) { + // ALiBi path: S += slope * mask (full mul + add) + HVX_Vector v_sm0 = hvx_vec_mul_f16_f16(v_mask0, v_slope0); + HVX_Vector v_sm1 = hvx_vec_mul_f16_f16(v_mask1, v_slope1); + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_sm0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_sm1), v_neg_inf); + } else { + // No-ALiBi fast path (scheme A): slope≡1.0, skip the mul + // but still add mask (additive positional bias). vmux + // clamps mask < -16 to -inf as a numerical safeguard. + my_row_buf0[ci] = Q6_V_vmux_QVV(q_keep0, hvx_vec_add_f16_f16(my_row_buf0[ci], v_mask0), v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_keep1, hvx_vec_add_f16_f16(my_row_buf1[ci], v_mask1), v_neg_inf); + } + } else { + if (ne < 64) { + my_row_buf0[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf0[ci], v_neg_inf); + my_row_buf1[ci] = Q6_V_vmux_QVV(q_tail_keep, my_row_buf1[ci], v_neg_inf); + } + } + + v_s_rowmax0 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax0, my_row_buf0[ci]); + v_s_rowmax1 = Q6_Vhf_vmax_VhfVhf(v_s_rowmax1, my_row_buf1[ci]); + } + + v_s_rowmax0 = hvx_vec_reduce_max_f16(v_s_rowmax0); + v_s_rowmax1 = hvx_vec_reduce_max_f16(v_s_rowmax1); + + // Splat m_prev[r], m_prev[r+1] from the per-row accumulator. + // vror brings the target lane to lane 0, then vdelta replicates it + // across all lanes — stays in the vector domain (no store/reload). + HVX_Vector v_m_prev0 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, r_vec_off * 2)); + HVX_Vector v_m_prev1 = hvx_vec_repl_f16(Q6_V_vror_VR(m_prev_v, (r_vec_off + 1) * 2)); + + // HVX max — both operands are splats, so result is splat of m_new. + HVX_Vector v_dup_m0 = Q6_Vhf_vmax_VhfVhf(v_m_prev0, v_s_rowmax0); + HVX_Vector v_dup_m1 = Q6_Vhf_vmax_VhfVhf(v_m_prev1, v_s_rowmax1); + + // Insert row r, r+1 rowmax into rowmax_acc_v via 2-byte-wide vmux. + // Byte ranges: lane0 = [r_vec_off*2 .. r_vec_off*2+1], lane1 shifted by 2. + // vsetq2 handles the n=128 corner case when r_vec_off reaches 62. + { + HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); + HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); + HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); + HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); + HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); + rowmax_acc_v = Q6_V_vmux_QVV(p_lane0, v_dup_m0, rowmax_acc_v); + rowmax_acc_v = Q6_V_vmux_QVV(p_lane1, v_dup_m1, rowmax_acc_v); + } + + // Compute P = exp(S - m_new), using HVX exp + const HVX_Vector v_zero = Q6_V_vzero(); + HVX_Vector v_p_rowsum0 = v_zero; + HVX_Vector v_p_rowsum1 = v_zero; + +#ifdef HMX_FA_USE_EXP2_HF + // FP16 exp2 polynomial path (matches htp-ops-lib flash_attn.c): + // P = exp2(S - m_new) + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0); + HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1); + + HVX_Vector v_p_row0_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m0)); + HVX_Vector v_p_row1_hf = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_s_minus_m1)); +#else + // F32 exp path: qf16 → f32 → exp → f32 → f16. Higher precision, + for (size_t c = 0; c < kv_rows; c += 64) { + size_t ci = c / 64; + HVX_Vector v_s_minus_m0 = Q6_Vqf16_vsub_VhfVhf(my_row_buf0[ci], v_dup_m0); + HVX_Vector v_s_minus_m1 = Q6_Vqf16_vsub_VhfVhf(my_row_buf1[ci], v_dup_m1); + + HVX_VectorPair vp0 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m0)); + HVX_Vector p0_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp0)); + HVX_Vector p0_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp0)); + HVX_Vector v_p_row0_hf = hvx_vec_f32_to_f16_shuff(p0_lo, p0_hi); + + HVX_VectorPair vp1 = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_s_minus_m1)); + HVX_Vector p1_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp1)); + HVX_Vector p1_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp1)); + HVX_Vector v_p_row1_hf = hvx_vec_f32_to_f16_shuff(p1_lo, p1_hi); +#endif + // Write P to tile format. Dual-tile pattern assumes Bc is a + // multiple of 64 (enforced by bc_unit=64 in hmx_fa_find_chunk_size), + // so both tile halves are always in the current r0 block. + __fp16 * out_dual_tile = p_st_base + (c / 64) * HMX_FP16_TILE_N_ELMS * 2; + HVX_Vector * pv_p_out0 = ((HVX_Vector *) out_dual_tile) + r1 / 2; + HVX_Vector * pv_p_out1 = pv_p_out0 + 16; + + HVX_VectorPair vp_p_dual = Q6_W_vshuff_VVR(v_p_row1_hf, v_p_row0_hf, -2); + *pv_p_out0 = Q6_V_lo_W(vp_p_dual); + *pv_p_out1 = Q6_V_hi_W(vp_p_dual); + + HVX_VectorPair vp_p0 = hvx_vec_f16_to_f32_shuff(v_p_row0_hf); + HVX_VectorPair vp_p1 = hvx_vec_f16_to_f32_shuff(v_p_row1_hf); + + v_p_rowsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum0, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p0), Q6_V_hi_W(vp_p0))); + v_p_rowsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(v_p_rowsum1, Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(vp_p1), Q6_V_hi_W(vp_p1))); + } + + HVX_Vector rowsum0_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum0)); + HVX_Vector rowsum1_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(v_p_rowsum1)); + { + // Both inputs are f32 splats, so the f32->f16 output is an fp16 splat. + HVX_Vector rv0_v = hvx_vec_f32_to_f16(rowsum0_sf, rowsum0_sf); + HVX_Vector rv1_v = hvx_vec_f32_to_f16(rowsum1_sf, rowsum1_sf); + + HVX_VectorPred p_start = Q6_Q_vsetq_R(r_vec_off * 2); + HVX_VectorPred p_mid = Q6_Q_vsetq_R((r_vec_off + 1) * 2); + HVX_VectorPred p_end = Q6_Q_vsetq2_R((r_vec_off + 2) * 2); + HVX_VectorPred p_lane0 = Q6_Q_and_QQn(p_mid, p_start); + HVX_VectorPred p_lane1 = Q6_Q_and_QQn(p_end, p_mid); + rowsum_acc_v = Q6_V_vmux_QVV(p_lane0, rv0_v, rowsum_acc_v); + rowsum_acc_v = Q6_V_vmux_QVV(p_lane1, rv1_v, rowsum_acc_v); + } + } + + factx->vtcm_s_rowmax[r_vec_idx] = rowmax_acc_v; + factx->vtcm_p_rowsum[r_vec_idx] = rowsum_acc_v; + } +} + +// Serial m/l update + build_D. Must run after softmax barrier (s_rowmax written by all threads). +// +// noinline: function boundary acts as a hard compiler barrier so the (size_t)addr scatter +// intrinsics inside cannot be hoisted past the call site. Mirrors the structural protection +// matmul gets for free via worker_pool function-pointer dispatch. Without this, the compiler +// can reorder the scatter past the subsequent hmx_queue_push and the HMX-queue worker thread +// reads stale VTCM (PPL → ~vocab-size). +static __attribute__((noinline)) void fa_ml_update_and_build_d(struct hmx_fa_context * factx, + size_t n_rows_g, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + // Reuse s_rowmax buffer for exp(m_diff) — safe because softmax is fully complete + HVX_Vector * const mvec_exp_m_diff = factx->vtcm_s_rowmax; + + const size_t n_row_vec_cnt = hmx_ceil_div(n_rows_g, 64); + for (size_t i = 0; i < n_row_vec_cnt; ++i) { + HVX_Vector v_m_prev = factx->vtcm_m_vec[i]; + HVX_Vector v_m_curr = Q6_Vhf_vmax_VhfVhf(v_m_prev, factx->vtcm_s_rowmax[i]); + HVX_Vector v_m_diff = Q6_Vqf16_vsub_VhfVhf(v_m_prev, v_m_curr); + +#ifdef HMX_FA_USE_EXP2_HF + // Base-2 path: must match P = exp2(S - m_new) in fa_softmax_thread. + HVX_Vector v_exp_m_diff = hvx_exp2_hf(Q6_Vhf_equals_Vqf16(v_m_diff)); +#else + HVX_VectorPair vp_diff = hvx_vec_f16_to_f32_shuff(Q6_Vhf_equals_Vqf16(v_m_diff)); + HVX_Vector exp_lo = hvx_vec_exp_f32(Q6_V_lo_W(vp_diff)); + HVX_Vector exp_hi = hvx_vec_exp_f32(Q6_V_hi_W(vp_diff)); + HVX_Vector v_exp_m_diff = hvx_vec_f32_to_f16_shuff(exp_lo, exp_hi); +#endif + + HVX_Vector v_l_curr = Q6_Vqf16_vmpy_Vqf16Vhf(factx->vtcm_l_vec[i], v_exp_m_diff); + v_l_curr = Q6_Vqf16_vadd_Vqf16Vhf(v_l_curr, factx->vtcm_p_rowsum[i]); + + factx->vtcm_m_vec[i] = v_m_curr; + factx->vtcm_l_vec[i] = v_l_curr; + mvec_exp_m_diff[i] = v_exp_m_diff; + } + + // Build diagonal tile D = diag(exp(m_diff)) + const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; + const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); + for (size_t i = 0; i < n_row_tiles; ++i) { + const HVX_Vector v_content = Q6_V_vror_VR(mvec_exp_m_diff[i / 2], (i % 2) * 64); + __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); + // Compiler barrier — Q6_vscatter takes (size_t)addr; without this the + // compiler may not recognize the volatile read below as aliasing and + // could reorder it before the scatter, defeating the HW drain. + __asm__ __volatile__("" ::: "memory"); + // Per-tile drain: scatter regions are disjoint (stride > tile size), + // so a single drain at tile 0 does NOT retire later tiles' entries. + (void) *(volatile HVX_Vector *) out_base; + } +} + +// Build D = diag(1/l) tile for the final O = D @ O normalization. +// +// noinline: same rationale as fa_ml_update_and_build_d — keeps Q6_vscatter from +// being hoisted past the subsequent hmx_queue_push at the o_norm call site. +static __attribute__((noinline)) void fa_build_d_diag_inv_l(struct hmx_fa_context * factx, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + const HVX_Vector v_offsets = *(const HVX_Vector *) d_tile_scatter_offsets; + const HVX_VectorPred q_32_mask = Q6_Q_vsetq_R(32 * sizeof(__fp16)); + const HVX_Vector one = hvx_vec_splat_f32(1.0f); + + HVX_Vector v_content = Q6_V_vzero(); + for (size_t i = 0; i < n_row_tiles; ++i) { + if ((i % 2) == 0) { + HVX_Vector v_l_hf = Q6_Vhf_equals_Vqf16(factx->vtcm_l_vec[i / 2]); + HVX_VectorPair vp_l = hvx_vec_f16_to_f32_shuff(v_l_hf); + HVX_Vector inv_lo = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_lo_W(vp_l)))); + HVX_Vector inv_hi = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(one, hvx_vec_inverse_f32(Q6_V_hi_W(vp_l)))); + v_content = hvx_vec_f32_to_f16_shuff(inv_lo, inv_hi); + } else { + v_content = Q6_V_vror_VR(v_content, 64); + } + + __fp16 * out_base = factx->vtcm_d_tiles + i * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + Q6_vscatter_QRMVhV(q_32_mask, (size_t) out_base, HMX_FP16_TILE_SIZE - 1, v_offsets, v_content); + // Compiler barrier — see fa_ml_update_and_build_d for rationale. + __asm__ __volatile__("" ::: "memory"); + (void) *(volatile HVX_Vector *) out_base; + } +} + +// Combined: multi-thread softmax -> barrier -> serial m/l update + build_D +static void fa_phase_softmax_and_build_d(struct hmx_fa_context * factx, + fa_softmax_args_t * sargs, + size_t n_row_tiles, + size_t n_row_tiles_g_br) { + worker_pool_context_t wp = factx->octx->ctx->worker_pool; + const size_t n_row_vec_cnt = hmx_ceil_div(sargs->n_rows_g, 64); + + if (factx->n_threads > 1 && n_row_vec_cnt >= 2) { + uint32_t n_use = (uint32_t) hex_smin((size_t) factx->n_threads, n_row_vec_cnt); + worker_pool_run_func(wp, fa_softmax_thread, sargs, n_use); + } else { + fa_softmax_thread(1, 0, sargs); + } + // barrier implicit in worker_pool_run_func return + + fa_ml_update_and_build_d(factx, sargs->n_rows_g, n_row_tiles, n_row_tiles_g_br); +} + +// ============================================================================ +// HMX job structs and worker functions +// ============================================================================ + +typedef struct { + const __fp16 * q_tiles; + const __fp16 * k_tiles; + __fp16 * s_tiles; + size_t n_row_tiles; + size_t n_col_tiles; + size_t n_dot_tiles; // DK / 32 + size_t n_tiles_per_bc; + uint8_t * hmx_scales; +} hmx_fa_qk_job_t; + +static void hmx_fa_qk_dot_worker(void * data) { + hmx_fa_qk_job_t * job = (hmx_fa_qk_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_col_tiles = job->n_col_tiles; + const size_t n_dot_tiles = job->n_dot_tiles; + const size_t n_tiles_per_bc = job->n_tiles_per_bc; + const __fp16 * restrict q_tiles = job->q_tiles; + const __fp16 * restrict k_tiles = job->k_tiles; + __fp16 * restrict s_tiles = job->s_tiles; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + const __fp16 * row_tiles = q_tiles + r * HMX_FP16_TILE_N_ROWS * n_dot_tiles * HMX_FP16_TILE_N_COLS; + const __fp16 * col_tiles = k_tiles + c * HMX_FP16_TILE_N_COLS * n_dot_tiles * HMX_FP16_TILE_N_COLS; + __fp16 * out_tile = s_tiles + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; + + for (size_t k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } +} + +typedef struct { + __fp16 * o_curr; + const __fp16 * o_prev; + const __fp16 * p_tiles; + const __fp16 * v_tiles; + const __fp16 * d_tiles; + uint8_t * hmx_scales; + size_t n_row_tiles; + size_t n_col_tiles; + size_t n_row_tiles_g_br; + size_t n_tiles_per_bc; + size_t DV; +} hmx_fa_o_update_job_t; + +static void hmx_fa_o_update_worker(void * data) { + hmx_fa_o_update_job_t * job = (hmx_fa_o_update_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_col_tiles = job->n_col_tiles; + const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; + const size_t n_tiles_per_bc = job->n_tiles_per_bc; + const size_t DV_tiles = job->DV / 32; + const __fp16 * restrict d_tiles = job->d_tiles; + const __fp16 * restrict p_tiles = job->p_tiles; + const __fp16 * restrict v_tiles = job->v_tiles; + const __fp16 * restrict o_prev = job->o_prev; + __fp16 * restrict o_curr = job->o_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + // D[r,r] @ O_prev[r,c] — only the diagonal tile + const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + + // P @ V (accumulate on same accumulator) + const __fp16 * p_tile_in = p_tiles + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + const __fp16 * v_tile_in = v_tiles + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_col_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); + p_tile_in += HMX_FP16_TILE_N_ELMS; + v_tile_in += HMX_FP16_TILE_N_ELMS; + } + + __fp16 * o_tile_out = o_curr + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(o_tile_out, 0); + } + } +} + +typedef struct { + __fp16 * o_curr; // output (row-major tile layout) + const __fp16 * o_prev; // input (column-major tile layout) + const __fp16 * d_tiles; // diag(1/l) tiles + uint8_t * hmx_scales; + size_t n_row_tiles; + size_t n_row_tiles_g_br; + size_t DV; +} hmx_fa_o_norm_job_t; + +static void hmx_fa_o_norm_worker(void * data) { + hmx_fa_o_norm_job_t * job = (hmx_fa_o_norm_job_t *) data; + const size_t n_row_tiles = job->n_row_tiles; + const size_t n_row_tiles_g_br = job->n_row_tiles_g_br; + const size_t DV_tiles = job->DV / 32; + const __fp16 * restrict d_tiles = job->d_tiles; + const __fp16 * restrict o_prev = job->o_prev; + __fp16 * restrict o_curr = job->o_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) job->hmx_scales); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_tiles + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = o_prev + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + __fp16 * o_out = o_curr + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; + + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + Q6_mxmem_AR_after_hf(o_out, 0); + } + } +} + +// Populate per-GQA-row ALiBi slopes for a given KV head. +// Row r in the GQA-merged block maps to Q head h = kv_head * G + r % G. +// slope(h) = m0^(h+1) when h < n_head_log2, else m1^(2*(h-n_head_log2)+1). +// When max_bias == 0, all slopes are 1.0 (no ALiBi). +static __attribute__((noinline)) void fa_compute_slopes( + const struct hmx_fa_context * factx, + uint32_t kv_head, + size_t n_rows_g) { + __fp16 * slopes = factx->vtcm_slopes; + if (factx->max_bias == 0.0f) { + hvx_splat_f16_a(slopes, 1.0f, n_rows_g); + return; + } + + const uint32_t G = factx->G; + const uint32_t n_head_log2 = factx->n_head_log2; + const float m0 = factx->m0; + const float m1 = factx->m1; + + __fp16 temp_slopes[512] __attribute__((aligned(128))); + if (G <= 32) { + // Fast path: Compute G unique slope values in vector registers + HVX_Vector v_val = hvx_alibi_slopes(kv_head, G, n_head_log2, m0, m1); + + __fp16 temp_slopes_aligned[64] __attribute__((aligned(128))); + hvx_vmem(temp_slopes_aligned) = hvx_vec_f32_to_f16(v_val, Q6_V_vzero()); + + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = temp_slopes_aligned[i]; + } + } else { + // Fallback path: G > 32 (rare configurations) + for (uint32_t i = 0; i < G; ++i) { + temp_slopes[i] = (__fp16)alibi_slope(kv_head * G + i, n_head_log2, m0, m1); + } + } + + // Allocate stack buffer to avoid scalar writes to VTCM (which generates L2 misses) + __fp16 local_slopes[n_rows_g] __attribute__((aligned(128))); + for (size_t r = 0; r < n_rows_g; ++r) { + local_slopes[r] = temp_slopes[fastmodulo(r, G, &factx->div_G)]; + } + + // Copy to VTCM slopes using HVX block copy (both are aligned to 128 bytes) + hvx_copy_f16_aa((uint8_t *)slopes, (const uint8_t *)local_slopes, n_rows_g); +} + +// ============================================================================ +// Core HMX flash attention algorithm (GQA-merged) +// ============================================================================ + +int hmx_flash_attn_ext(struct htp_ops_context * octx) { + const struct htp_tensor * q = octx->src[0]; + const struct htp_tensor * k = octx->src[1]; + const struct htp_tensor * v = octx->src[2]; + const struct htp_tensor * mask = (octx->src[3] && octx->src[3]->data) ? octx->src[3] : NULL; + const struct htp_tensor * dst = octx->dst; + + struct htp_context * const ctx = octx->ctx; + + if (!ctx->hmx_enabled) { + return HTP_STATUS_NO_SUPPORT; + } + + // Dimensions + const uint32_t neq0 = q->ne[0]; // head_dim (DK) + const uint32_t neq1 = q->ne[1]; // n_tokens + const uint32_t neq2 = q->ne[2]; // n_heads + const uint32_t neq3 = q->ne[3]; // n_seqs + + const uint32_t nek0 = k->ne[0]; // head_dim + const uint32_t nek1 = k->ne[1]; // kv_len + + const uint32_t nev0 = v->ne[0]; // head_dim (DV) + + const uint32_t DK = neq0; + const uint32_t DV = nev0; + + // HMX requires head_dim to be multiple of 32 + if (DK % 32 != 0 || DV % 32 != 0) { + return HTP_STATUS_NO_SUPPORT; + } + + // GQA factor + const uint32_t n_kv_heads = k->ne[2]; + const uint32_t G = neq2 / n_kv_heads; + + // Thread count for multi-thread HVX phases + const uint32_t n_threads_init = octx->n_threads; + + // Compute dynamic block sizes (GQA-aware, accounting for per-thread row bufs) + size_t Br, Bc; + const size_t vtcm_budget = ctx->vtcm_size; + if (hmx_fa_find_chunk_size(&Br, &Bc, G, DK, DV, neq1, nek1, vtcm_budget, n_threads_init) != 0) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); + + const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; + const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); + + // Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1 + const uint32_t n_threads = use_pipeline ? n_threads_init : 1; + + FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", + neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); + + // ======== Build context ======== + struct hmx_fa_context factx; + memset(&factx, 0, sizeof(factx)); + factx.octx = octx; + factx.n_threads = n_threads; + factx.DK = DK; + factx.DV = DV; + factx.n_kv = nek1; + factx.n_kv_heads = n_kv_heads; + factx.n_heads = neq2; + factx.G = G; + factx.div_G = init_fastdiv_values(G); + factx.neq1 = neq1; + factx.Br = (uint32_t) Br; + factx.Bc = (uint32_t) Bc; + factx.g_br = (uint32_t) g_br; + factx.n_kv_blocks = n_kv_blocks; + factx.is_q_fp32 = (q->type == HTP_TYPE_F32); + factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32); + factx.use_pipeline = use_pipeline; + factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1); + + // Extract op parameters (mutable during softcap adjustment, then stored as const in factx) + float scale = 1.0f, max_bias = 0.0f, logit_softcap = 0.0f; + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + +#ifdef HMX_FA_USE_EXP2_HF + // Pre-bake log2(e) into qk_scale so HMX-produced S tiles are in log2(e)-scaled + // space. Then exp2(S - m) in the softmax equals base-e exp((S - m) / log2(e)), + // preserving ggml's base-e softmax semantics. Matches htp-ops-lib flash_attn.c. + // + // When softcap is active we cannot pre-bake log2(e) here — it would land inside + // the tanh argument and shift the softcap knee from x≈c to x≈c/log2(e), giving + // numerically wrong softcapped values. Instead fold log2(e) into the post-tanh + // multiplier (see softcap block: v_cap absorbs log2(e)). + if (logit_softcap == 0.0f) { + scale *= 1.44269504f; // log2(e) + } +#endif + + factx.scale = scale; + factx.max_bias = max_bias; + factx.logit_softcap = logit_softcap; + + factx.n_head_log2 = 1u << (uint32_t) floor(log2(neq2)); + factx.m0 = powf(2.0f, -(max_bias) / factx.n_head_log2); + factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + + // ======== VTCM allocation (GQA-aware) ======== + const size_t size_k_row = DK * sizeof(__fp16); + const size_t size_v_row = DV * sizeof(__fp16); + const size_t size_k_row_padded = hex_round_up(size_k_row, 128); + const size_t size_v_row_padded = hex_round_up(size_v_row, 128); + + const size_t q_tile_bytes = hex_align_up(g_br * DK * sizeof(__fp16), 4096); + const size_t o_tile_bytes = hex_align_up(g_br * DV * sizeof(__fp16), 4096); + const size_t k_dma_bytes = hex_align_up(Bc * size_k_row_padded, 4096); + const size_t v_dma_bytes = hex_align_up(Bc * size_v_row_padded, 4096); + const size_t k_tile_bytes = hex_align_up(Bc * DK * sizeof(__fp16), 4096); + const size_t v_tile_bytes = hex_align_up(Bc * DV * sizeof(__fp16), 4096); + const size_t s_tile_bytes = hex_align_up(g_br * Bc * sizeof(__fp16), 4096); + const size_t d_tile_bytes = hex_align_up(g_br * g_br * sizeof(__fp16), 4096); + const size_t col_vec_bytes = hex_align_up(g_br * sizeof(__fp16), 256); + const size_t row_vec_bytes = hex_align_up(Bc * sizeof(__fp16), 256); + const size_t m_line_bytes = hex_align_up(Bc * sizeof(__fp16), 128); + const size_t m_buf_bytes = hex_align_up(Br * m_line_bytes, 4096); + const size_t slopes_bytes = hex_align_up(g_br * sizeof(__fp16), 128); + + uint8_t * vtcm_cur = ctx->vtcm_base; + + factx.vtcm_q_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, q_tile_bytes); + factx.vtcm_o_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); + factx.vtcm_o_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, o_tile_bytes); + factx.vtcm_k_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); + factx.vtcm_k_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_dma_bytes); + factx.vtcm_v_fp16[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); + factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); + factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); + factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + if (use_pipeline) { + factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); + } else { + factx.vtcm_v_tiles[1] = NULL; + } + factx.vtcm_s_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); + factx.vtcm_p_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, s_tile_bytes); + factx.vtcm_d_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, d_tile_bytes); + factx.vtcm_m_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_l_vec = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_s_rowmax = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_p_rowsum = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, col_vec_bytes); + factx.vtcm_row_bufs = (HVX_Vector *) vtcm_seq_alloc(&vtcm_cur, row_vec_bytes * 2 * n_threads); + factx.row_buf_stride = row_vec_bytes / sizeof(HVX_Vector); + factx.vtcm_hmx_scales_id = vtcm_seq_alloc(&vtcm_cur, 256); + factx.vtcm_hmx_scales_qk = vtcm_seq_alloc(&vtcm_cur, 256); + factx.vtcm_mask_buf = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, m_buf_bytes); + factx.mask_buf_row_stride = m_line_bytes / sizeof(__fp16); + factx.vtcm_slopes = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, slopes_bytes); + + if ((size_t) (vtcm_cur - ctx->vtcm_base) > ctx->vtcm_size) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + // ======== Initialize HMX output scales ======== + // Identity scale (1.0) for O updates and normalization + hmx_init_column_scales(factx.vtcm_hmx_scales_id, Q6_V_vsplat_R(0x3c00)); // 1.0 + + // QK scale embedded in HMX output + hmx_init_column_scales(factx.vtcm_hmx_scales_qk, hvx_vec_splat_f16(factx.scale)); + + // ======== Skip compute if profiling ======== + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + // Profiling timers + TIMER_DEFINE(total); + TIMER_DEFINE(q_load); + TIMER_DEFINE(kv_dma); + TIMER_DEFINE(k_interleave); + TIMER_DEFINE(v_interleave); + TIMER_DEFINE(qk_dot); + TIMER_DEFINE(softmax); + TIMER_DEFINE(o_update); + TIMER_DEFINE(o_norm); + TIMER_DEFINE(o_store); + + TIMER_START(total); + + // ======== DMA setup ======== + dma_queue * const dma = ctx->dma[0]; + + // Padded row sizes for DMA (defined in outer scope) + + const size_t n_row_tiles_g_br = g_br / HMX_FP16_TILE_N_ROWS; + const size_t n_tiles_per_bc = Bc / HMX_FP16_TILE_N_COLS; + + // Q/O element size for Q load and O store + const size_t qo_element_size = factx.is_q_fp32 ? sizeof(float) : sizeof(__fp16); + + // ======== HMX lock strategy ======== + // Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend. + // Fallback: main thread holds the lock (original behavior). + if (!factx.use_pipeline) { + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + } + + // ======== Reusable job descriptors for pipeline ======== + hmx_fa_qk_job_t qk_job; + hmx_fa_o_update_job_t ou_job; + hmx_fa_o_norm_job_t on_job; + + // ======== Main loop: per batch, per KV head, per Q block ======== + for (uint32_t ib3 = 0; ib3 < neq3; ++ib3) { + for (uint32_t kv_head = 0; kv_head < n_kv_heads; ++kv_head) { + const uint32_t ik2 = kv_head; + const uint32_t ik3 = ib3 / (neq3 / k->ne[3]); + const uint32_t iv2 = kv_head; + const uint32_t iv3 = ib3 / (neq3 / v->ne[3]); + + for (uint32_t q_start = 0; q_start < neq1; q_start += Br) { + const uint32_t n_q_rows = hex_smin(Br, neq1 - q_start); + const size_t n_rows_g = n_q_rows * G; + const size_t g_br_actual = hex_align_up(n_rows_g, HMX_FP16_TILE_N_ROWS); + const size_t n_row_tiles = g_br_actual / HMX_FP16_TILE_N_ROWS; + + // ---- Load Q block [g_br, D] -> tiles, interleaving G heads ---- + TIMER_START(q_load); + if (n_rows_g < g_br) { + hvx_splat_u8_a(factx.vtcm_q_tiles, 0, q_tile_bytes); + } + fa_phase_q_load(&factx, q, q_start, kv_head, ib3, n_rows_g); + TIMER_STOP(q_load); + + // ---- Initialize per-block state ---- + hvx_splat_u8_a(factx.vtcm_l_vec, 0, col_vec_bytes); + hvx_splat_u8_a(factx.vtcm_d_tiles, 0, d_tile_bytes); + hvx_splat_u16_a(factx.vtcm_m_vec, 0xfbff, col_vec_bytes/2); + + __fp16 * o_tile_prev = factx.vtcm_o_tiles[0]; + __fp16 * o_tile_curr = factx.vtcm_o_tiles[1]; + hvx_splat_u8_a(o_tile_prev, 0, o_tile_bytes); + + // ---- KV block loop with DMA double-buffering ---- + size_t buf_idx = 0; + + fa_compute_slopes(&factx, kv_head, n_rows_g); + + // Prefetch first KV block + if (factx.n_kv_blocks > 0) { + const uint32_t kv_rows0 = hex_smin(Bc, nek1); + + const uint8_t * k_src = (const uint8_t *) k->data + ik2 * k->nb[2] + ik3 * k->nb[3]; + dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[0], k_src), size_k_row_padded, k->nb[1], + size_k_row, kv_rows0); + + const uint8_t * v_src = (const uint8_t *) v->data + iv2 * v->nb[2] + iv3 * v->nb[3]; + dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[0], v_src), size_v_row_padded, v->nb[1], + size_v_row, kv_rows0); + } + + // Mask DMA: single 2D transfer of n_q_rows unique mask rows into VTCM buffer. + // Only when mask is head-broadcast (ne[2]==1); otherwise softmax reads DDR directly. + #define MASK_DMA_PUSH(kv_start_val, kv_rows_val, has_mask_dma_var) \ + do { \ + has_mask_dma_var = false; \ + if (mask && factx.mask_broadcast) { \ + const uint32_t _im3 = ib3 % mask->ne[3]; \ + const uint8_t * _ms = (const uint8_t *) mask->data + q_start * mask->nb[1] + _im3 * mask->nb[3] + \ + (kv_start_val) * sizeof(__fp16); \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_mask_buf, _ms), m_line_bytes, mask->nb[1], \ + (kv_rows_val) * sizeof(__fp16), n_q_rows); \ + has_mask_dma_var = true; \ + } \ + } while (0) + + #define MASK_DMA_POP(has_mask_dma_var) \ + do { \ + if (has_mask_dma_var) { \ + dma_queue_pop(dma); \ + } \ + } while (0) + + #define DMA_PREFETCH_KV(blk_val) \ + do { \ + if ((blk_val) < factx.n_kv_blocks) { \ + const uint32_t _ns = (blk_val) * Bc; \ + const uint32_t _nr = hex_smin(Bc, nek1 - _ns); \ + size_t _nb = 1 - buf_idx; \ + const uint8_t * _ks = (const uint8_t *) k->data + _ns * k->nb[1] + ik2 * k->nb[2] + ik3 * k->nb[3]; \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_k_fp16[_nb], _ks), size_k_row_padded, k->nb[1], size_k_row, _nr); \ + const uint8_t * _vs = (const uint8_t *) v->data + _ns * v->nb[1] + iv2 * v->nb[2] + iv3 * v->nb[3]; \ + dma_queue_push(dma, dma_make_ptr(factx.vtcm_v_fp16[_nb], _vs), size_v_row_padded, v->nb[1], size_v_row, _nr); \ + } \ + } while (0) + + const size_t k_src_stride = size_k_row_padded / sizeof(__fp16); + const size_t v_src_stride = size_v_row_padded / sizeof(__fp16); + + if (factx.use_pipeline) { + // ================================================================== + // Pipeline path: HVX phases ‖ HMX queue worker + // ================================================================== + struct hmx_queue * hmx_q = ctx->hmx_queue; + + for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { + const uint32_t kv_start = kv_blk * Bc; + const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); + const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); + + // Wait for current KV DMA + TIMER_START(kv_dma); + dma_queue_pop(dma); // K + dma_queue_pop(dma); // V + TIMER_STOP(kv_dma); + + // Push mask DMA for this block (single 2D DMA when broadcast) + bool has_mask_dma = false; + MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); + + // ---- Phase 1: K_int(blk) ‖ O_update(blk-1) ---- + if (kv_blk > 0) { + // Submit O_update for previous block (HMX worker) + ou_job.o_curr = o_tile_curr; + ou_job.o_prev = o_tile_prev; + ou_job.p_tiles = factx.vtcm_p_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; + ou_job.d_tiles = factx.vtcm_d_tiles; + ou_job.hmx_scales = factx.vtcm_hmx_scales_id; + ou_job.n_row_tiles = n_row_tiles; + ou_job.n_col_tiles = hmx_ceil_div(hex_smin(Bc, nek1 - (kv_blk - 1) * Bc), HMX_FP16_TILE_N_COLS); + ou_job.n_row_tiles_g_br = n_row_tiles_g_br; + ou_job.n_tiles_per_bc = n_tiles_per_bc; + ou_job.DV = DV; + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); + } + + TIMER_START(k_interleave); + fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); + TIMER_STOP(k_interleave); + + // ---- Phase 2: qk_dot(blk) on HMX ‖ V_int(blk) + DMA prefetch on HVX ---- + qk_job.q_tiles = factx.vtcm_q_tiles; + qk_job.k_tiles = factx.vtcm_k_tiles; + qk_job.s_tiles = factx.vtcm_s_tiles; + qk_job.n_row_tiles = n_row_tiles; + qk_job.n_col_tiles = n_col_tiles; + qk_job.n_dot_tiles = DK / 32; + qk_job.n_tiles_per_bc = n_tiles_per_bc; + qk_job.hmx_scales = factx.vtcm_hmx_scales_qk; + TIMER_START(qk_dot); + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_qk_dot_worker, &qk_job)); + + // DMA push next block (non-blocking, before worker_pool) + DMA_PREFETCH_KV(kv_blk + 1); + + TIMER_START(v_interleave); + fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); + TIMER_STOP(v_interleave); + + // Pop and swap previous block's output update (deferred HMX pop) + if (kv_blk > 0) { + hmx_queue_pop(hmx_q); + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + // Pop current block's dot product job + hmx_queue_pop(hmx_q); + TIMER_STOP(qk_dot); + + // ---- Phase 3: softmax(blk) + build_D(blk) | HMX idle ---- + // Pop mask DMA before softmax (ensures VTCM buffer is ready) + MASK_DMA_POP(has_mask_dma); + + fa_softmax_args_t sargs; + memset(&sargs, 0, sizeof(sargs)); + sargs.factx = &factx; + sargs.kv_rows = kv_rows; + sargs.n_rows_g = n_rows_g; + sargs.n_col_tiles = n_col_tiles; + sargs.n_tiles_per_bc = n_tiles_per_bc; + sargs.n_row_tiles = n_row_tiles; + sargs.n_row_tiles_g_br = n_row_tiles_g_br; + sargs.Bc = Bc; + sargs.G = G; + sargs.kv_head = kv_head; + sargs.kv_start = kv_start; + sargs.q_start = q_start; + sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); + sargs.mask = mask; + sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; + sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; + sargs.slopes = factx.vtcm_slopes; + + TIMER_START(softmax); + fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); + TIMER_STOP(softmax); + + buf_idx = 1 - buf_idx; + } // end KV block loop (pipeline) + + // Epilogue: O_update for last block + if (factx.n_kv_blocks > 0) { + const uint32_t last_blk = factx.n_kv_blocks - 1; + const size_t last_cols = hmx_ceil_div(hex_smin(Bc, nek1 - last_blk * Bc), HMX_FP16_TILE_N_COLS); + ou_job.o_curr = o_tile_curr; + ou_job.o_prev = o_tile_prev; + ou_job.p_tiles = factx.vtcm_p_tiles; + ou_job.v_tiles = factx.vtcm_v_tiles[1 - buf_idx]; + ou_job.d_tiles = factx.vtcm_d_tiles; + ou_job.hmx_scales = factx.vtcm_hmx_scales_id; + ou_job.n_row_tiles = n_row_tiles; + ou_job.n_col_tiles = last_cols; + ou_job.n_row_tiles_g_br = n_row_tiles_g_br; + ou_job.n_tiles_per_bc = n_tiles_per_bc; + ou_job.DV = DV; + + TIMER_START(o_update); + hmx_queue_push(hmx_q, hmx_queue_make_desc(hmx_fa_o_update_worker, &ou_job)); + hmx_queue_pop(hmx_q); + TIMER_STOP(o_update); + + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + + } else { + // ================================================================== + // Fallback path: sequential with multi-thread HVX phases + // Main thread holds HMX lock, runs HMX inline. + // ================================================================== + + for (uint32_t kv_blk = 0; kv_blk < factx.n_kv_blocks; ++kv_blk) { + const uint32_t kv_start = kv_blk * Bc; + const uint32_t kv_rows = hex_smin(Bc, nek1 - kv_start); + const size_t n_col_tiles = hmx_ceil_div(kv_rows, HMX_FP16_TILE_N_COLS); + + TIMER_START(kv_dma); + dma_queue_pop(dma); // K + dma_queue_pop(dma); // V + TIMER_STOP(kv_dma); + + bool has_mask_dma = false; + MASK_DMA_PUSH(kv_start, kv_rows, has_mask_dma); + DMA_PREFETCH_KV(kv_blk + 1); + + // K interleave (multi-thread HVX) + TIMER_START(k_interleave); + fa_phase_k_interleave(&factx, kv_rows, k_src_stride, buf_idx); + TIMER_STOP(k_interleave); + + // QK dot (inline HMX on main thread) + TIMER_START(qk_dot); + { + const size_t n_dot_tiles = (size_t) (DK / 32); + const __fp16 * restrict q_base = factx.vtcm_q_tiles; + const __fp16 * restrict k_base = factx.vtcm_k_tiles; + __fp16 * restrict s_base = factx.vtcm_s_tiles; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_qk); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + const __fp16 * row_tiles = q_base + r * HMX_FP16_TILE_N_ROWS * DK; + const __fp16 * col_tiles = k_base + c * HMX_FP16_TILE_N_COLS * DK; + __fp16 * out_tile = s_base + (r * n_tiles_per_bc + c) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_dot_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) row_tiles, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) col_tiles, 2047); + row_tiles += HMX_FP16_TILE_N_ELMS; + col_tiles += HMX_FP16_TILE_N_ELMS; + } + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } + } + TIMER_STOP(qk_dot); + + // Pop mask DMA + MASK_DMA_POP(has_mask_dma); + + // Softmax + build_D (multi-thread HVX + serial m/l update) + fa_softmax_args_t sargs; + memset(&sargs, 0, sizeof(sargs)); + sargs.factx = &factx; + sargs.kv_rows = kv_rows; + sargs.n_rows_g = n_rows_g; + sargs.n_col_tiles = n_col_tiles; + sargs.n_tiles_per_bc = n_tiles_per_bc; + sargs.n_row_tiles = n_row_tiles; + sargs.n_row_tiles_g_br = n_row_tiles_g_br; + sargs.Bc = Bc; + sargs.G = G; + sargs.kv_head = kv_head; + sargs.kv_start = kv_start; + sargs.q_start = q_start; + sargs.ib3 = ib3; + sargs.has_alibi = (factx.max_bias != 0.0f); + sargs.mask = mask; + sargs.mask_vtcm = has_mask_dma ? (const __fp16 *) factx.vtcm_mask_buf : NULL; + sargs.mask_vtcm_row_stride = factx.mask_buf_row_stride; + sargs.slopes = factx.vtcm_slopes; + + TIMER_START(softmax); + fa_phase_softmax_and_build_d(&factx, &sargs, n_row_tiles, n_row_tiles_g_br); + TIMER_STOP(softmax); + + // V interleave (multi-thread HVX) + TIMER_START(v_interleave); + // FIX(v-stride): use n_tiles_per_bc (block-invariant) as V tile layout + // stride to match o_update's v_tile access. Using per-block n_col_tiles + // misplaces DV_tile 1..3 in the last partial KV block. + fa_phase_v_interleave(&factx, kv_rows, v_src_stride, buf_idx, n_tiles_per_bc); + TIMER_STOP(v_interleave); + + // O update (inline HMX on main thread) + TIMER_START(o_update); + { + const size_t DV_tiles = (size_t) (DV / 32); + const __fp16 * restrict d_base = factx.vtcm_d_tiles; + const __fp16 * restrict p_base = factx.vtcm_p_tiles; + const __fp16 * restrict v_base = factx.vtcm_v_tiles[0]; + const __fp16 * restrict op_base = o_tile_prev; + __fp16 * restrict oc_base = o_tile_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + + const __fp16 * p_tile_in = p_base + (r * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + const __fp16 * v_tile_in = v_base + (c * n_tiles_per_bc) * HMX_FP16_TILE_N_ELMS; + for (size_t k = 0; k < n_col_tiles; ++k) { + Q6_activation_hf_mxmem_RR((unsigned int) p_tile_in, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) v_tile_in, 2047); + p_tile_in += HMX_FP16_TILE_N_ELMS; + v_tile_in += HMX_FP16_TILE_N_ELMS; + } + + __fp16 * o_tile_out = oc_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(o_tile_out, 0); + } + } + hex_swap_ptr((void **) &o_tile_curr, (void **) &o_tile_prev); + } + TIMER_STOP(o_update); + + buf_idx = 1 - buf_idx; + } // end KV block loop (fallback) + } + + // ---- Final normalization: O = diag(1/l) @ O ---- + TIMER_START(o_norm); + { + fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br); + + // HMX: O_final = diag(1/l) @ O_prev + if (factx.use_pipeline) { + on_job.o_curr = o_tile_curr; + on_job.o_prev = o_tile_prev; + on_job.d_tiles = factx.vtcm_d_tiles; + on_job.hmx_scales = factx.vtcm_hmx_scales_id; + on_job.n_row_tiles = n_row_tiles; + on_job.n_row_tiles_g_br = n_row_tiles_g_br; + on_job.DV = DV; + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_fa_o_norm_worker, &on_job)); + hmx_queue_pop(ctx->hmx_queue); + } else { + const size_t DV_tiles = (size_t) (DV / 32); + const __fp16 * restrict d_base = factx.vtcm_d_tiles; + const __fp16 * restrict op_base = o_tile_prev; + __fp16 * restrict oc_base = o_tile_curr; + __builtin_assume(n_row_tiles > 0); + __builtin_assume(DV_tiles > 0); + + Q6_bias_mxmem2_A((void *) factx.vtcm_hmx_scales_id); + for (size_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < DV_tiles; ++c) { + const __fp16 * d_diag = d_base + r * (n_row_tiles_g_br + 1) * HMX_FP16_TILE_N_ELMS; + const __fp16 * o_rc = op_base + (c * n_row_tiles_g_br + r) * HMX_FP16_TILE_N_ELMS; + __fp16 * o_out = oc_base + (r * DV_tiles + c) * HMX_FP16_TILE_N_ELMS; + + Q6_activation_hf_mxmem_RR((unsigned int) d_diag, 2047); + Q6_weight_hf_mxmem_RR((unsigned int) o_rc, 2047); + Q6_mxmem_AR_after_hf(o_out, 0); + } + } + } + } + TIMER_STOP(o_norm); + + // ---- Store O block ---- + TIMER_START(o_store); + fa_phase_o_store(&factx, dst, o_tile_curr, q_start, kv_head, ib3, n_rows_g); + TIMER_STOP(o_store); + +#undef MASK_DMA_PUSH +#undef MASK_DMA_POP +#undef DMA_PREFETCH_KV + + } // end Q block loop + } // end KV head loop + } // end batch loop + + if (factx.use_pipeline) { + hmx_queue_suspend(ctx->hmx_queue); + } else { + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "hmx-fa: %lld us, q_load=%lld kv_dma=%lld k_interleave=%lld v_interleave=%lld", TIMER_US(total), + TIMER_US(q_load), TIMER_US(kv_dma), TIMER_US(k_interleave), TIMER_US(v_interleave)); + FARF(HIGH, " qk_dot=%lld softmax=%lld o_update=%lld o_norm=%lld o_store=%lld", TIMER_US(qk_dot), TIMER_US(softmax), + TIMER_US(o_update), TIMER_US(o_norm), TIMER_US(o_store)); +#endif + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c new file mode 100644 index 00000000000..dab605210cf --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c @@ -0,0 +1,2066 @@ +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <assert.h> +#include <stdbool.h> +#include <stddef.h> +#include <stdint.h> +#include <string.h> + +#include <HAP_farf.h> +#include <HAP_compute_res.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" + +#include "hex-dma.h" +#include "hex-fastdiv.h" +#include "worker-pool.h" + +#include "hvx-utils.h" +#include "hvx-dump.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +#include "hmx-ops.h" +#include "hmx-utils.h" +#include "hmx-queue.h" +#include "hmx-profile.h" + +#include "vtcm-utils.h" + +static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, +}; + +static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, +}; + +// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value +// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 +static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0, +}; + +static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, + 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, +}; + +// Scales per x4x2 logical block: 8 × sizeof(__fp16) = 16 bytes +#define HMX_X4X2_SCALES_PER_BLK 8 +#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL) +#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4) + +// Compute the byte stride of one row in x4x2 format. +// Numerically equals ggml_row_size(type, k) when k is 256-aligned, because +// x4x2 packing has the same density as block_q4_0 / block_q8_0. +// Layout per row: [quants: nb*128 (Q4) or nb*256 (Q8)][scales: nb*16 bytes] +// Total per row = nb * (128+16) = 144*nb (Q4) or nb * (256+16) = 272*nb (Q8). +// Callers must ensure k is a multiple of 256 (enforced by proc_hmx_matmul_req). +static inline size_t get_x4x2_row_stride(int weight_type, int k) { + int nb = (k + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb + case HTP_TYPE_Q4_1: + return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb + case HTP_TYPE_Q8_0: + return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb + case HTP_TYPE_MXFP4: + return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb + case HTP_TYPE_F16: + return (size_t) k * sizeof(__fp16); + case HTP_TYPE_F32: + return (size_t) k * sizeof(float); + default: + return 0; + } +} + +// --- Overflow-safe arithmetic for VTCM budget calculation --- + +static inline bool hmx_mul_overflow(size_t a, size_t b, size_t *out) { + if (a != 0 && b > SIZE_MAX / a) return true; + *out = a * b; + return false; +} + +static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) { + if (a > SIZE_MAX - b) return true; + *out = a + b; + return false; +} + +// Search for optimal (mc, nc) chunk sizes within VTCM budget. +// +// VTCM model: nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead +// +// Minimize ceil(m/mc) * m_block_cost + ceil(n/nc) * n_block_cost. +// All matmul paths repeat weight processing per M-block and activation loading +// per N-block, so discrete block counts drive total overhead. +// Tie-break: when cost is equal, prefer larger mc * nc. +// +// Caller-provided coefficients: +// m_block_cost: penalty per extra M-block (weight redundancy, scales with n). +// n_block_cost: penalty per extra N-block (activation redundancy, scales with m). +// +// Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max. +// Returns 0 on success, -1 if VTCM is insufficient. +static int hmx_compute_chunks(size_t vtcm_total, + size_t overhead, + size_t per_n_cost, + size_t per_m_cost, + size_t per_mn_cost, + int m, + int n, + size_t m_block_cost, + size_t n_block_cost, + size_t * m_chunk_out, + size_t * n_chunk_out, + size_t * total_out) { + if (m <= 0 || n <= 0) return -1; + if (vtcm_total <= overhead) return -1; + if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; + + const size_t usable = vtcm_total - overhead; + + size_t best_cost = SIZE_MAX; + size_t best_mn = 0; + size_t best_m = 0, best_n = 0; + + const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS); + for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) { + size_t n_fixed = 0, ncmn = 0, mc_denom = 0; + if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue; + if (n_fixed >= usable) goto next_nc; + + if (hmx_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc; + if (hmx_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc; + + { + size_t remain = usable - n_fixed; + size_t mc = remain / mc_denom; + mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS); + mc = hex_smin(mc, (size_t)m); + + if (mc == 0) { + goto next_nc; + } + + size_t mblocks = ((size_t) m + mc - 1) / mc; + size_t nblocks = ((size_t) n + nc - 1) / nc; + size_t cost = mblocks * m_block_cost + nblocks * n_block_cost; + size_t mn = mc * nc; + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_m = mc; + best_n = nc; + } + } + +next_nc: + if (nc == HMX_FP16_TILE_N_COLS) break; // avoid size_t underflow + } + + if (best_m == 0 || best_n == 0) return -1; + + // Compute exact total (with overflow checks) + size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0; + if (hmx_mul_overflow(best_n, per_n_cost, &t0)) return -1; + if (hmx_mul_overflow(best_m, per_m_cost, &t1)) return -1; + if (hmx_mul_overflow(best_m, best_n, &mn)) return -1; + if (hmx_mul_overflow(mn, per_mn_cost, &t2)) return -1; + if (hmx_add_overflow(t0, t1, &total)) return -1; + if (hmx_add_overflow(total, t2, &total)) return -1; + if (hmx_add_overflow(total, overhead, &total)) return -1; + + *m_chunk_out = best_m; + *n_chunk_out = best_n; + *total_out = total; + return 0; +} + +// --- x4x2 format dequantizers --- + +// Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. +// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles +// of the same 32 packed bytes. +static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); + + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_int8)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using +// full HVX vector width. +// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. +static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_4, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); + + HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_int8); + HVX_Vector v_lo = Q6_V_lo_W(vp_int16); + HVX_Vector v_hi = Q6_V_hi_W(vp_int16); + + v_lo = Q6_Vhf_equals_Vh(v_lo); + v_hi = Q6_Vhf_equals_Vh(v_hi); + + HVX_Vector vscale = hvx_vmemu(scales_4); + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + HVX_Vector_x2 r = { v_lo, v_hi }; + return r; +} + +static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_dm = hvx_vmemu(scale_offset); + HVX_Vector v_scales = hvx_vec_repl_f16(v_dm); + HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2)); + + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_quants)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets)); +} + +static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) { + (void)vlut_cvt; + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_quants); + HVX_Vector v_lo = Q6_V_lo_W(vp_int16); + HVX_Vector v_hi = Q6_V_hi_W(vp_int16); + + v_lo = Q6_Vhf_equals_Vh(v_lo); + v_hi = Q6_Vhf_equals_Vh(v_hi); + + HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); + HVX_Vector vd = Q6_V_lo_W(dm_deal); + HVX_Vector vm = Q6_V_hi_W(dm_deal); + + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4)); + + HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm); + HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23)); + + HVX_Vector_x2 r = { v_lo, v_hi }; + return r; +} + +// LUT-based dequantizers for non-linear IQ4_NL format. +static inline HVX_Vector dequantize_x4x2_iq4_nl_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +static inline HVX_Vector_x2 dequantize_x4x2_iq4_nl_x4groups_hvx( + const uint8_t *packed_128, bool upper_nibbles, + const __fp16 *scales_4, const HVX_Vector vlut_cvt) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); + + HVX_Vector vscale = hvx_vmemu(scales_4); + HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); + HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + HVX_Vector_x2 r = { v_lo, v_hi }; + return r; +} + +// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. +static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { + HVX_Vector vq = hvx_vmemu(quants_32); + HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); + HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); + HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); +} + +// --- MXFP4 E8M0 scale conversion and dequantization --- +// +// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack. +// Scalar loads from the stack array execute on the scalar pipeline, in parallel +// with HVX vlut16/vmpy/vscatter — freeing HVX slots in the hot loop. +// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10 +// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15. + +typedef struct { + __fp16 v[8] __attribute__((aligned(16))); +} mxfp4_scales_t; + +static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) { + mxfp4_scales_t s; + HVX_Vector v = hvx_vmemu(e8m0_8); + HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v)); + vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112)); + vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero()); + vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30)); + vh = Q6_Vh_vasl_VhR(vh, 10); + hvx_vec_store_u(s.v, 16, vh); + return s; +} + +static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) { + return hvx_vec_splat_f16(scales.v[idx]); +} + +// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16. +static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32, + bool upper_nibbles, + int sub_blk, + const HVX_Vector vlut_cvt, + mxfp4_scales_t scales) { + HVX_Vector vq = hvx_vmemu(packed_32); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_hf = Q6_V_lo_W(vp); + + return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc)); +} + +// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). +static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, + bool upper_nibbles, + int sub_blk_base, + const HVX_Vector vlut_cvt, + mxfp4_scales_t scales) { + HVX_Vector vq = hvx_vmemu(packed_128); + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; + v_quants = Q6_V_vand_VV(v_quants, mask_h4); + + v_quants = Q6_Vb_vshuff_Vb(v_quants); + + HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); + HVX_Vector v_lo = Q6_V_lo_W(vp); + HVX_Vector v_hi = Q6_V_hi_W(vp); + + HVX_VectorPred q64 = Q6_Q_vsetq_R(64); + HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0), + mxfp4_extract_splat(scales, sub_blk_base + 1)); + HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2), + mxfp4_extract_splat(scales, sub_blk_base + 3)); + + v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); + v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); + + HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) }; + return r; +} + +typedef struct { + __fp16 *dst; + const uint8_t *src; + int n_cols; + int k_block; + size_t row_stride; + int weight_type; + int n_tot_tiles; + int n_tiles_per_task; + int n_tasks; + int n_k_tiles; + struct fastdiv_values n_k_tiles_div; +} x4x2_dequantize_state_t; + +// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. +// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. +// Output: vtcm_dst in tile-major FP16 layout. + +#define DEFINE_DEQUANTIZE_Q4_TASK(suffix, lut_name, helper_prefix, dblk_size, scale_step) \ +static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix( \ + const x4x2_dequantize_state_t *state, \ + int start_tile, int end_tile) { \ + \ + const int n_k_tiles = state->n_k_tiles; \ + const int qrow_size = (unsigned)state->k_block / 2; \ + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; \ + const HVX_Vector vlut_cvt = hvx_vmem(lut_name); \ + \ + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); \ + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); \ + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); \ + \ + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); \ + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); \ + \ + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { \ + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } \ + \ + if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { \ + unsigned blk_idx = ((kt * 32) / QK_Q4_0x4x2); \ + unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; \ + bool upper = (sub_blk_base >= 4); \ + unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); \ + unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk_base * (scale_step); \ + \ + __fp16 *tile_bases[4]; \ + for (unsigned g = 0; g < 4; g++) { \ + tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; \ + } \ + \ + HVX_Vector v_off = v_scat_base; \ + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ + \ + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { \ + const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ + const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ + \ + HVX_Vector_x2 dv0 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ + r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); \ + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + \ + HVX_Vector_x2 dv1 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ + r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); \ + Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); \ + Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + } \ + \ + for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } \ + t += 4; kt += 4; \ + continue; \ + } \ + \ + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; \ + { \ + unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; \ + unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; \ + bool upper = (sub_blk >= 4); \ + unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; \ + unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk * (scale_step); \ + \ + HVX_Vector v_off = v_scat_base; \ + unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ + unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; \ + \ + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { \ + const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ + const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ + \ + HVX_Vector v0 = dequantize_x4x2_##helper_prefix##_group_hvx( \ + r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ + HVX_Vector v1 = (row1 < (unsigned)state->n_cols) \ + ? dequantize_x4x2_##helper_prefix##_group_hvx( \ + r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) \ + : Q6_V_vzero(); \ + \ + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); \ + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ + } \ + (void) *(volatile HVX_Vector *)(tile_base); \ + } \ + ++t; ++kt; \ + } \ + \ + if (start_tile < end_tile) { \ + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); \ + } \ +} \ + \ +static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \ + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \ + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \ + int start = task_id * state->n_tiles_per_task; \ + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \ + dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \ + } \ +} + +DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) +DEFINE_DEQUANTIZE_Q4_TASK(q4_1, q4_1_to_fp16_lut, q4_1, 32, 4) +DEFINE_DEQUANTIZE_Q4_TASK(iq4_nl, iq4_nl_to_fp16_lut, iq4_nl, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) + +static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const int qrow_size = (unsigned)state->k_block / 2; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + // Batch-4 fast path for MXFP4 + if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { + int blk_idx = (kt * 32) / QK_MXFP4x4x2; + int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; + bool upper = (sub_blk_base >= 4); + int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; + + __fp16 * tile_bases[4]; + for (int g = 0; g < 4; g++) { + tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; + } + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + const uint8_t * r0 = state->src + row0 * state->row_stride; + const uint8_t * r1 = state->src + row1 * state->row_stride; + + mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); + + HVX_Vector_x4 dv0, dv1; + dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); + if (row1 < state->n_cols) { + mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); + dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); + } else { + dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero(); + } + + for (int g = 0; g < 4; g++) { + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]); + } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + for (int g = 0; g < 4; g++) { + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]); + } + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + + for (int g = 0; g < 4; g++) { + (void) *(volatile HVX_Vector *) (tile_bases[g]); + } + + t += 4; kt += 4; + continue; + } + + // Single-tile fallback + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int blk_idx = (kt * 32) / QK_MXFP4x4x2; + int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32; + bool upper = (sub_blk >= 4); + int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; + int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t * r0 = state->src + row0 * state->row_stride; + const uint8_t * r1 = state->src + row1 * state->row_stride; + + mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); + + HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8); + HVX_Vector v1; + if (row1 < state->n_cols) { + mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); + v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8); + } else { + v1 = Q6_V_vzero(); + } + + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *) (tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end); + } +} + +static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const int qrow_size = state->k_block; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int blk_idx = (kt * 32) / QK_Q8_0x4x2; + int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32; + int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32; + int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); + HVX_Vector v1 = (row1 < state->n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end); + } +} + +static void convert_f16_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); + HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + convert_f16_weight_to_fp16_tiles_task(state, start, end); + } +} + +static void quantize_f32_weight_to_fp16_tiles_task( + const x4x2_dequantize_state_t *state, + int start_tile, int end_tile) { + + const int n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; + { + int byte_off = kt * 32 * sizeof(float); + + HVX_Vector v_off = v_scat_base; + for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { + int row0 = ct * HMX_FP16_TILE_N_COLS + r; + int row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0_f32 = hvx_vmemu((const float *)(r0 + byte_off)); + HVX_Vector v1_f32 = (row1 < state->n_cols) ? hvx_vmemu((const float *)(r1 + byte_off)) : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out_hi); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); + } +} + +static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { + x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + quantize_f32_weight_to_fp16_tiles_task(state, start, end); + } +} + + +static void dequantize_x4x2_weight_chunk_to_fp16_tiles( + struct htp_context *ctx, __fp16 *vtcm_dst, + const void *vtcm_src, int n_cols, int k_block, + size_t row_stride, int weight_type, + int n_k_tiles, struct fastdiv_values n_k_tiles_div, + worker_callback_t dequant_worker_fn, int n_threads) { + + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + assert(k_block % HMX_FP16_TILE_N_COLS == 0); + + size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; + size_t n_tot_tiles = n_col_tiles * n_k_tiles; + + size_t n_tiles_per_task = (n_threads == 1) ? n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); + + x4x2_dequantize_state_t state; + state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; + state.n_tot_tiles = n_tot_tiles; + state.n_tiles_per_task = n_tiles_per_task; + state.dst = vtcm_dst; + state.src = (const uint8_t *)vtcm_src; + state.n_cols = n_cols; + state.k_block = k_block; + state.row_stride = row_stride; + state.weight_type = weight_type; + state.n_k_tiles = n_k_tiles; + state.n_k_tiles_div = n_k_tiles_div; + + if (state.n_tasks == 1 || n_threads == 1) { + dequant_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_threads); + } +} + +// --- End x4x2 dequantizers --- + +#pragma clang diagnostic ignored "-Wbackend-plugin" // spurios warning for hmx intrinsics + +// requires external HMX lock +static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, + int n_row_tiles, int n_col_tiles, int n_dot_tiles) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)scales); + for (int r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + Q6_mxclracc_hf(); + + const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; + + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; + } + + __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } +} + +// --- Async HMX matmul job (for pipeline overlap) --- + +typedef struct { + __fp16 * output; + const __fp16 * activation; + const __fp16 * weight; + const __fp16 * scales; + uint32_t n_row_tiles; + uint32_t n_col_tiles; + uint32_t n_dot_tiles; +} hmx_matmul_job_t; + +static void hmx_matmul_worker_fn(void * data) { + hmx_matmul_job_t * job = (hmx_matmul_job_t *) data; + FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); + core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); +} + +static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, + __fp16 * output, + const __fp16 * activation, + const __fp16 * weight, + const __fp16 * scales, + int n_row_tiles, + int n_col_tiles, + int n_dot_tiles) { + job->output = output; + job->activation = activation; + job->weight = weight; + job->scales = scales; + job->n_row_tiles = n_row_tiles; + job->n_col_tiles = n_col_tiles; + job->n_dot_tiles = n_dot_tiles; +} + +// output : fp16 -> f32p + +static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r0 = r / HMX_FP16_TILE_N_ROWS; + const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + float *output_row_base = dst + r * n; // global memory row base for row r (and r+1) + + #pragma unroll(4) + for (size_t c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { + const size_t c0 = c / HMX_FP16_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row_base + c + 0); + volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (output_row_base + c + n); // next row in global memory + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (r + 1 < n_rows) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + } +} + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int n_cols; + int n; // DDR row stride (total output columns) +} output_transfer_task_state_t; + +static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_task_state_t *st = (output_transfer_task_state_t *) data; + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + float *dst = st->dst + chunk_idx * st->n; + const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols; + transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n); + } +} + +static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, + int n_rows, int n_cols, int n, int n_threads) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) + + output_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.vtcm_src = vtcm_src; + state.n_cols = n_cols; + state.n = n; + + if (state.n_tasks == 1 || n_threads == 1) { + transfer_output_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, n_threads); + } +} + +// activations : fp32 -> fp16 + +static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) { + const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; + + int r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); + const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + // compute output position + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = r < n_rows; + const bool row1_valid = (r + 1) < n_rows; + + const HVX_Vector *pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL; + const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL; + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); + HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + // compute output position + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + int k_stride; +} activation_transfer_task_state_t; + +static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + // one chunk: one row + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + __fp16 *dst = st->dst + chunk_idx * st->k_block; + const float *src = st->src + chunk_idx * st->k_stride; + transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); + } +} + +static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) { + assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); + assert(VLEN == 32 * sizeof(float)); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address + + activation_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.src = src; + state.k_block = k_block; + state.k_stride = k_stride; + + if (state.n_tasks == 1 || n_threads == 1) { + transfer_activation_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_threads); + } +} + +// C += AB +static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, + const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, + int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)col_scales); + + const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t i = 0; i < n_row_tiles; ++i) { + const __fp16 *row_base = a + i * dot_tile_stride; + __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; + for (size_t j = 0; j < n_col_tiles; ++j) { + Q6_mxclracc_hf(); + + const __fp16 *col_tiles = b + j * dot_tile_stride; + const __fp16 *row_tiles = row_base; + __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; + if (!zero_init) { + Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); + } + + for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HMX_FP16_TILE_N_ELMS; + col_tiles += k_block * HMX_FP16_TILE_N_ELMS; + } + + Q6_mxmem_AR_after_hf(accum_tile, 0); + } + } +} + +int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const uint8_t *restrict permuted_weight, int m, int k, int n, + int act_stride, int weight_stride, int weight_type) { + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } + + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + + // --- Dynamic Mode Configuration --- + const bool use_pipeline = (m > 32); + const int num_threads = (m <= 32) ? 1 : ctx->n_threads; + + // --- Dynamic VTCM layout --- + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; + + // Pipeline = 4-stage DMA→dequant→HMX→store with HMX worker overlap. + const size_t size_per_n = row_stride + (use_pipeline ? 2 * vec_dot_size : vec_dot_size); // Q + S0 + S1 (dequant bufs) + const size_t size_per_mn = (use_pipeline ? 2 : 1) * sizeof(__fp16); // O x 2 (output double buffer) + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-2d: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size, scratch1_size, scratch2_size; + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 + scratch1_size = use_pipeline ? scratch0_size : 0; // dequant buf 1 + scratch2_size = use_pipeline ? output_area_size : 0; // output buf 1 + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; + void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + + TIMER_DEFINE(total); + TIMER_START(total); + + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + + if (use_pipeline) { + // --- Asynchronous Pipelined Loop (Current implementation) --- + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + void *vtcm_qweight = vtcm_weight; + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + + // prologue: A0 + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + { + const uint8_t *qweight_chunk_A0 = permuted_weight; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, weight_stride, row_stride, n_cols_A0); + } + + { + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); + } + + // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) + { + // B0: wait for DMA, dequant weight chunk 0 + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + // A1: issue DMA for weight chunk 1 + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (1 < n_chunk_cnt) { + const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, weight_stride, row_stride, n_cols_A1); + } + + // submit C0 (non-blocking — HMX worker executes in parallel) + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + + // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) + if (1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } + } + + // main loop: wait C_i → submit C_{i+1} → D_i + B_{i+2} (parallel with C_{i+1}) + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + + // issue A_{i+2}: DMA push (non-blocking) + if (i + 2 < n_chunk_cnt) { + const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, weight_stride, row_stride, n_cols_p2); + } + + // wait C_i: block until prologue/previous C completes + hmx_queue_pop(ctx->hmx_queue); + + // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } + + // D_i: store output (multi-thread HVX, parallel with C_{i+1}) + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n, num_threads); + + // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) + if (i + 2 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + } + } + } + hmx_queue_suspend(ctx->hmx_queue); + } else { + // --- Synchronous Loop (Optimized for small/non-pipelined cases) --- + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + + // Load Activation + const float *activation_chunk = activation + mr * act_stride; + transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + // A: DMA Load Weight + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); + dma_queue_pop(ctx->dma[0]); + + // B: Dequantize / Convert Weight + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + // C: HMX Compute (Synchronous) + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); + + // D: Output Store + float *output_chunk = dst + (mr * n + nc); + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, n, num_threads); + } + } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "hex-mm-2d: %lld us : m %d k %d n %d", TIMER_US(total), m, k, n); + if (!use_pipeline) { + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); + size_t weight_size = (size_t)n * row_stride; + float bandwidth = 1e-3f * weight_size / (float)TIMER_US(weight_load); + FARF(HIGH, " weight load bandwidth: %.2f GB/s", bandwidth); + } +#endif + + return 0; +} + +// + +static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) { + return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; +} + +static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) { + return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; +} + +static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + const int r2 = hmx_matmul_batch_r2(params); + const int r3 = hmx_matmul_batch_r3(params); + return (const __fp16 *) ((const uint8_t *) params->permuted_weight + + (size_t) (dst_b2 / r2) * params->src0_nb2 + + (size_t) (dst_b3 / r3) * params->src0_nb3); +} + +static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (const float *) ((const uint8_t *) params->activation + + (size_t) dst_b2 * params->src1_nb2 + + (size_t) dst_b3 * params->src1_nb3); +} + +static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (float *) ((uint8_t *) params->dst + + (size_t) dst_b2 * params->dst_nb2 + + (size_t) dst_b3 * params->dst_nb3); +} + +static int hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx, + const hmx_matmul_f16_f32_batched_params_t *params) { + int ret = 0; + for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { + for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { + ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3), + hmx_matmul_activation_batch_ptr(params, b2, b3), + hmx_matmul_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride); + } + } + return ret; +} + +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) { + if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } + if (!params->m || !params->k || !params->n) { return -1; } + if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } + if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } + if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } + if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } + + if (!hex_is_aligned(params->dst, VLEN) || + !hex_is_aligned(params->activation, VLEN) || + !hex_is_aligned(params->permuted_weight, VLEN)) { + return -1; + } + + const int group_size = hmx_matmul_batch_r2(params); + + if (group_size <= 1) { + FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); + } + + // Grouped path: reuse interleaved weight across all q_heads sharing a + // kv_head. Each q_head gets its own activation buffer in VTCM (so + // activation is loaded once per m_chunk and reused across all n_chunks), + // and each q_head is computed individually to avoid tile-major packing + // issues. m_chunk_n_rows is always a multiple of 32 (from + // hmx_compute_chunks), so per-head tile arrays don't overlap. + const size_t vtcm_budget = ctx->vtcm_size; + const size_t vec_dot_size = params->k * sizeof(__fp16); + + // When the activation has a large stride (e.g. permuted Q tensor with + // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. + // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather + // strided rows into a contiguous block before the F32->F16 conversion. + const bool use_dma_activation = (params->act_stride > params->k); + const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; + // FP16 weight: interleave and activation load have similar per-element cost. + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, + /*per_n=*/3 * vec_dot_size, + /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, + /*per_mn=*/sizeof(__fp16), + hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n, + /*m_block_cost=*/(size_t) params->n, + /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { + FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); + } + + const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + + if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { + FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); + return hmx_matmul_f16_f32_batched_legacy(ctx, params); + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, params->m, params->k, params->n, group_size, params->ne13, + m_chunk_n_rows, n_chunk_n_cols, + (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + + TIMER_DEFINE(activation_load); + TIMER_DEFINE(weight_load); + TIMER_DEFINE(hmx_core); + TIMER_DEFINE(output_store); + TIMER_DEFINE(total); + + TIMER_START(total); + + const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (int b3 = 0; b3 < params->ne13; ++b3) { + for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { + const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); + + for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); + + // Pre-load activations for all heads in the group (once per m_chunk). + // When the source is strided (permuted Q), use 2D DMA to gather + // contiguous rows into a VTCM scratch buffer first, then HVX + // converts from the contiguous VTCM buffer. This avoids L2 cache + // thrashing from HVX loads at large strides. + TIMER_START(activation_load); + for (int g = 0; g < group_size; ++g) { + const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; + __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + if (use_dma_activation) { + const size_t row_bytes = (size_t) params->k * sizeof(float); + const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); + dma_queue_push(ctx->dma[0], + dma_make_ptr(vtcm_f32_act, activation_chunk), + row_bytes, stride_bytes, row_bytes, n_rows); + dma_queue_pop(ctx->dma[0]); + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + vtcm_f32_act, (int) n_rows, + params->k, params->k, ctx->n_threads); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride, ctx->n_threads); + } + } + TIMER_STOP(activation_load); + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + { + const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); + + TIMER_START(weight_load); + { + dma_queue_pop(ctx->dma[0]); + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < (size_t) params->n) { + const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, + 0, n_cols); + hex_swap_ptr(&buf_curr, &buf_next); + } + TIMER_STOP(weight_load); + + // Reuse the interleaved weight for every q_head in this GQA group + for (int g = 0; g < group_size; ++g) { + TIMER_START(hmx_core); + { + const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, + params->k / 32); + } + TIMER_STOP(hmx_core); + + TIMER_START(output_store); + { + float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads); + } + TIMER_STOP(output_store); + } + } + } + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + + TIMER_STOP(total); + +#if defined(ENABLE_PROFILE_TIMERS) + FARF(HIGH, "%s: %lld us, m=%d k=%d n=%d group=%d", __func__, TIMER_US(total), + params->m, params->k, params->n, group_size); + FARF(HIGH, " activation_load: %lld us, weight_load: %lld us, hmx_core: %lld us, output_store: %lld us", + TIMER_US(activation_load), TIMER_US(weight_load), TIMER_US(hmx_core), TIMER_US(output_store)); +#endif + + return 0; +} + +int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, + const __fp16 *restrict permuted_weight, int m, int k, int n, + int act_stride, int weight_stride) { + if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } + return hmx_matmul_2d_f32(ctx, dst, activation, (const uint8_t *)permuted_weight, m, k, n, + act_stride, weight_stride * (int)sizeof(__fp16), HTP_TYPE_F16); +} + +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +typedef struct { + __fp16 *dst; + const float *src; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int k_block; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + int ne11; + struct fastdiv_values ne11_div; + size_t nb11; + size_t nb12; + int start_row; + int cne1; +} activation_transfer_gathered_task_state_t; + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + int n_tasks; + int n_tot_chunks; + int n_chunks_per_task; + int n_cols; + const struct mmid_row_mapping *matrix_rows; + int cur_a; + int mapping_stride; + size_t dst_nb1; + size_t dst_nb2; + int start_row; + int cne1; +} output_transfer_scattered_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_gathered( + __fp16 *restrict vtcm_dst, + const float *restrict src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + const struct fastdiv_values * ne11_div, + size_t nb11, + size_t nb12, + int cne1) { + const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); + const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; + + int r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + int r_idx0 = start_row + r + 0; + int r_idx1 = start_row + r + 1; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + + const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index + int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; + + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + } + + const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; + const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; + + for (int c = 0; c < k_block; c += 32) { + HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); + HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index + int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + __fp16 *dst = st->dst + (size_t)(start_row - st->start_row) * st->k_block; + transfer_activation_chunk_fp32_to_fp16_gathered( + dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1); + } +} + +static void transfer_activation_chunk_gathered_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + size_t nb11, + size_t nb12, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + activation_transfer_gathered_task_state_t state = { + .dst = dst, + .src = src, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .k_block = k_block, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .ne11 = ne11, + .ne11_div = init_fastdiv_values(ne11), + .nb11 = nb11, + .nb12 = nb12, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_activation_chunk_gathered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_gathered_worker_fn, &state, actual_threads); + } +} + +static void transfer_output_chunk_fp16_to_fp32_scattered( + float *restrict dst, + const __fp16 *restrict vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1) { + assert(n_cols % HMX_FP16_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r0 = r / HMX_FP16_TILE_N_ROWS; + const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + + int r_idx0 = start_row + (int)r + 0; + int r_idx1 = start_row + (int)r + 1; + + if (r_idx0 >= cne1) break; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); + + float *output_row1 = NULL; + if (r_idx1 < cne1) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); + } + + #pragma unroll(4) + for (size_t c = 0; c < (size_t)n_cols; c += HMX_FP16_TILE_N_COLS) { + const size_t c0 = c / HMX_FP16_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row0 + c); + volatile HVX_Vector *pv_out1 = output_row1 ? (volatile HVX_Vector *) (output_row1 + c) : NULL; + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (pv_out1) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + } +} + +static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_scattered_task_state_t *st = data; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + const __fp16 *src = st->vtcm_src + (size_t)(start_row - st->start_row) * st->n_cols; + transfer_output_chunk_fp16_to_fp32_scattered( + st->dst, src, start_row, n_rows, st->n_cols, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->dst_nb1, st->dst_nb2, st->cne1); + } +} + +static void transfer_output_chunk_scattered_threaded( + struct htp_context *ctx, + float *dst, + const __fp16 *vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + output_transfer_scattered_task_state_t state = { + .vtcm_src = vtcm_src, + .dst = dst, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .n_cols = n_cols, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .dst_nb1 = dst_nb1, + .dst_nb2 = dst_nb2, + .start_row = start_row, + .cne1 = cne1, + }; + + if (actual_threads <= 1) { + transfer_output_chunk_scattered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); + } +} + +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride) { + const int cne1 = m; + const int m_padded = hex_align_up(m, 32); + + if (k % 32 != 0 || n % 32 != 0) { return -1; } + + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { + return -1; + } + + size_t row_stride = get_x4x2_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + + const int num_threads = ctx->n_threads; + + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; + + const size_t size_per_n = row_stride + vec_dot_size; + const size_t size_per_mn = sizeof(__fp16); + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, + m_padded, n, + /*m_block_cost=*/(size_t) n * 3, + /*n_block_cost=*/(size_t) m_padded * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(HIGH, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); + + size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); + + transfer_activation_chunk_gathered_threaded( + ctx, vtcm_activation, activation, (int) mr, (int) n_rows, k, + matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, num_threads); + + for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); + + const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); + dma_queue_pop(ctx->dma[0]); + + dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); + + core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); + + transfer_output_chunk_scattered_threaded( + ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, + matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, num_threads); + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + return 0; +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.c b/ggml/src/ggml-hexagon/htp/hmx-ops.c new file mode 100644 index 00000000000..114d8c14811 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.c @@ -0,0 +1,6 @@ +// HMX operations compiled as a single translation unit. +// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO. + +#include "hmx-queue.c" +#include "hmx-matmul-ops.c" +#include "hmx-flash-attn-ops.c" diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h new file mode 100644 index 00000000000..a67842f3ffc --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-ops.h @@ -0,0 +1,88 @@ +// HMX operation entry-point declarations. +// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib) + +#ifndef HMX_OPS_H +#define HMX_OPS_H + +#include <stddef.h> +#include <stdint.h> + +#include "htp-ops.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct { + float *dst; + const float *activation; + const __fp16 *permuted_weight; + int m; + int k; + int n; + int act_stride; + int weight_stride; + int dst_stride; + int ne02; + int ne03; + int ne12; + int ne13; + size_t src0_nb2; + size_t src0_nb3; + size_t src1_nb2; + size_t src1_nb3; + size_t dst_nb2; + size_t dst_nb3; +} hmx_matmul_f16_f32_batched_params_t; + +// HMX matrix multiplication — tile-permuted FP16 weights, FP32 activation/output +// act_stride: activation row stride in elements (= k for contiguous, or +// nb[1]/sizeof(float) for permuted tensors like attention Q). +// weight_stride: weight row stride in elements (= k for compact weights, or +// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). +int hmx_matmul_f16_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const __fp16 *permuted_weight, + int m, int k, int n, + int act_stride, + int weight_stride); + +// Batched F16 wrapper over hmx_mat_mul_f16_f32. +// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. +int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); + +// HMX matrix multiplication — all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4) +int hmx_matmul_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int act_stride, + int weight_stride, + int weight_type); + +struct mmid_row_mapping; + +int hmx_matmul_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *permuted_weight, + int m, int k, int n, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride); + +// HMX flash attention +int hmx_flash_attn_ext(struct htp_ops_context * octx); + +#ifdef __cplusplus +} +#endif + +#endif // HMX_OPS_H diff --git a/ggml/src/ggml-hexagon/htp/hmx-profile.h b/ggml/src/ggml-hexagon/htp/hmx-profile.h new file mode 100644 index 00000000000..01eece720c5 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-profile.h @@ -0,0 +1,34 @@ +// Conditional fine-grained profiling macros for HMX operations. +// +// Define ENABLE_PROFILE_TIMERS (via compiler flag or before including this +// header) to instrument sub-operation latencies with HAP qtimer. When the +// macro is not defined the TIMER_* helpers expand to nothing so there is zero +// overhead. +// +// Usage: +// TIMER_DEFINE(my_phase); // declare accumulator variable +// TIMER_START(my_phase); // snapshot start time +// ... work ... +// TIMER_STOP(my_phase); // accumulate elapsed ticks +// FARF(ALWAYS, "my_phase: %lld us", TIMER_US(my_phase)); + +#ifndef HMX_PROFILE_H +#define HMX_PROFILE_H + +#include <HAP_perf.h> + +// #define ENABLE_PROFILE_TIMERS + +#if defined(ENABLE_PROFILE_TIMERS) +# define TIMER_DEFINE(name) int64_t name##_ticks = 0 +# define TIMER_START(name) int64_t name##_t0 = HAP_perf_get_qtimer_count() +# define TIMER_STOP(name) name##_ticks += HAP_perf_get_qtimer_count() - name##_t0 +# define TIMER_US(name) HAP_perf_qtimer_count_to_us(name##_ticks) +#else +# define TIMER_DEFINE(name) +# define TIMER_START(name) +# define TIMER_STOP(name) +# define TIMER_US(name) 0LL +#endif + +#endif // HMX_PROFILE_H diff --git a/ggml/src/ggml-hexagon/htp/hmx-queue.c b/ggml/src/ggml-hexagon/htp/hmx-queue.c new file mode 100644 index 00000000000..5b1d83a0cbf --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-queue.c @@ -0,0 +1,158 @@ +#pragma clang diagnostic ignored "-Wunused-function" + +#include <stdbool.h> +#include <stdlib.h> +#include <string.h> + +#include <qurt_thread.h> +#include <qurt_futex.h> + +#include <HAP_compute_res.h> + +#include "hmx-queue.h" + +#define QURT_LOWEST_PRIO (254) + +static inline void hmx_lock(struct hmx_queue *q) +{ + if (!q->hmx_locked) { + HAP_compute_res_hmx_lock(q->hap_rctx); + q->hmx_locked = true; + } +} + +static inline void hmx_unlock(struct hmx_queue *q) +{ + if (q->hmx_locked) { + HAP_compute_res_hmx_unlock(q->hap_rctx); + q->hmx_locked = false; + } +} + +static inline void hmx_queue_process(struct hmx_queue *q, bool* killed) { + unsigned int ir = atomic_load(&q->idx_read); + + while (ir != atomic_load(&q->idx_write)) { + struct hmx_queue_desc *d = &q->desc[ir]; + if (!d->done) { + FARF(HIGH, "hmx-queue-process: ir %u func %p data %p", ir, d->func, d->data); + + enum hmx_queue_signal sig = (enum hmx_queue_signal) (unsigned int) d->func; + switch (sig) { + case HMX_QUEUE_NOOP: /* noop */; break; + case HMX_QUEUE_KILL: *killed = true; break; + case HMX_QUEUE_SUSPEND: hmx_unlock(q); break; + default: + hmx_lock(q); + d->func(d->data); + break; + } + + atomic_fetch_add(&d->done, 1); + } + + ir = (ir + 1) & q->idx_mask; + atomic_store(&q->idx_read, ir); + } +} + +static void hmx_queue_thread(void * arg) { + struct hmx_queue * q = (struct hmx_queue *) arg; + + FARF(HIGH, "hmx-queue-thread: started"); + + bool killed = false; + + unsigned int poll_cnt = HMX_QUEUE_POLL_COUNT; + unsigned int prev_seqn = 0; + while (!killed) { + unsigned int seqn = atomic_load(&q->seqn); + if (seqn == prev_seqn) { + if (--poll_cnt) { hex_pause(); continue; } + FARF(HIGH, "hmx-queue-thread: sleeping"); + qurt_futex_wait(&q->seqn, prev_seqn); + continue; + } + prev_seqn = seqn; + poll_cnt = HMX_QUEUE_POLL_COUNT; + + FARF(HIGH, "hmx-queue-thread: new work"); + + hmx_queue_process(q, &killed); + } + + FARF(HIGH, "hmx-queue-thread: stopped"); +} + +struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx) { + capacity = hex_ceil_pow2(capacity); + + struct hmx_queue * q = (struct hmx_queue *) memalign(32, sizeof(struct hmx_queue)); + if (q == NULL) { + FARF(ERROR, "%s: failed to allocate DMA queue\n", __FUNCTION__); + return NULL; + } + memset(q, 0, sizeof(struct hmx_queue)); + q->capacity = capacity; + q->idx_mask = capacity - 1; + q->hap_rctx = hap_rctx; + + q->desc = (struct hmx_queue_desc *) memalign(64, capacity * sizeof(struct hmx_queue_desc)); + if (!q->desc) { + FARF(ERROR, "hmx-queue: failed to allocate HMX queue descriptors\n"); + return NULL; + } + memset(q->desc, 0, capacity * sizeof(struct hmx_queue_desc)); + + const size_t stack_size = HMX_QUEUE_THREAD_STACK_SIZE; + q->stack = (unsigned char *) memalign(64, stack_size); + if (!q->stack) { + FARF(ERROR, "hmx-queue: thread stack allocation failed (%zu bytes)", stack_size); + return NULL; + } + memset(q->stack, 0, stack_size); + + // Match caller thread priority (same pattern as worker-pool.c). + int prio = qurt_thread_get_priority(qurt_thread_get_id()); + if (prio < 1) { + prio = 1; + } + if (prio > QURT_LOWEST_PRIO) { + prio = QURT_LOWEST_PRIO; + } + + qurt_thread_attr_t attr; + qurt_thread_attr_init(&attr); + qurt_thread_attr_set_stack_addr(&attr, q->stack); + qurt_thread_attr_set_stack_size(&attr, stack_size); + qurt_thread_attr_set_priority(&attr, prio); + qurt_thread_attr_set_name(&attr, "hmx-queue"); + + int err = qurt_thread_create(&q->thread, &attr, hmx_queue_thread, q); + if (err) { + FARF(ERROR, "hmx-worker: thread create failed (%d)", err); + return NULL; + } + + FARF(HIGH, "hmx-queue: capacity %u\n", capacity); + + return q; +} + +void hmx_queue_delete(struct hmx_queue * q) { + if (!q) { + return; + } + + // Tell the worker to exit. + hmx_queue_flush(q); + hmx_queue_signal(q, HMX_QUEUE_KILL); + hmx_queue_flush(q); + + int status; + qurt_thread_join(q->thread, &status); + + free(q->desc); + free(q->stack); + free(q); +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-queue.h b/ggml/src/ggml-hexagon/htp/hmx-queue.h new file mode 100644 index 00000000000..0d48c280f52 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-queue.h @@ -0,0 +1,134 @@ +#ifndef HMX_QUEUE_H +#define HMX_QUEUE_H + +#include <stdbool.h> +#include <stdint.h> +#include <stdatomic.h> + +#include <hexagon_types.h> +#include <qurt_thread.h> +#include <qurt_futex.h> +#include <HAP_farf.h> + +#include "hex-utils.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define HMX_QUEUE_THREAD_STACK_SIZE (16 * 1024) +#define HMX_QUEUE_POLL_COUNT 2000 + +typedef void (*hmx_queue_func)(void *); + +// Dummy funcs used as signals +enum hmx_queue_signal { + HMX_QUEUE_NOOP = 0, // aka NULL + HMX_QUEUE_SUSPEND, + HMX_QUEUE_KILL +}; + +struct hmx_queue_desc { + hmx_queue_func func; + void * data; + atomic_uint done; +}; + +struct hmx_queue { + struct hmx_queue_desc * desc; + atomic_uint idx_write; // updated by producer (push) + atomic_uint idx_read; // updated by consumer (process) + unsigned int idx_pop; // updated by producer (pop) + uint32_t idx_mask; + uint32_t capacity; + + atomic_uint seqn; // incremented for all pushes, used with futex + qurt_thread_t thread; + void * stack; + uint32_t hap_rctx; + bool hmx_locked; +}; + +struct hmx_queue * hmx_queue_create(size_t capacity, uint32_t hap_rctx); +void hmx_queue_delete(struct hmx_queue * q); + +static inline struct hmx_queue_desc hmx_queue_make_desc(hmx_queue_func func, void * data) { + struct hmx_queue_desc d = { func, data }; + return d; +} + +static inline bool hmx_queue_push(struct hmx_queue * q, struct hmx_queue_desc d) { + unsigned int ir = atomic_load(&q->idx_read); + unsigned int iw = q->idx_write; + + if (((iw + 1) & q->idx_mask) == ir) { + FARF(HIGH, "hmx-queue-push: queue is full\n"); + return false; + } + + atomic_store(&d.done, 0); + + FARF(HIGH, "hmx-queue-push: iw %u func %p data %p\n", iw, d.func, d.data); + + q->desc[iw] = d; + atomic_store(&q->idx_write, (iw + 1) & q->idx_mask); + // wake up our thread + atomic_fetch_add(&q->seqn, 1); + qurt_futex_wake(&q->seqn, 1); + + return true; +} + +static inline bool hmx_queue_signal(struct hmx_queue *q, enum hmx_queue_signal sig) { + return hmx_queue_push(q, hmx_queue_make_desc((hmx_queue_func) sig, NULL)); +} + +static inline bool hmx_queue_empty(struct hmx_queue * q) { + return q->idx_pop == q->idx_write; +} + +static inline uint32_t hmx_queue_depth(struct hmx_queue * q) { + return (q->idx_read - q->idx_read) & q->idx_mask; +} + +static inline uint32_t hmx_queue_capacity(struct hmx_queue * q) { + return q->capacity; +} + +static inline struct hmx_queue_desc hmx_queue_pop(struct hmx_queue * q) { + unsigned int ip = q->idx_pop; + unsigned int iw = q->idx_write; + + struct hmx_queue_desc rd = { NULL, NULL }; + if (ip == iw) { + return rd; + } + + // Wait for desc to complete + struct hmx_queue_desc * d = &q->desc[ip]; + while (!atomic_load(&d->done)) { + FARF(HIGH, "hmx-queue-pop: waiting for HMX queue : %u\n", ip); + hex_pause(); + } + + rd = *d; + q->idx_pop = (ip + 1) & q->idx_mask; + + FARF(HIGH, "hmx-queue-pop: ip %u func %p data %p\n", ip, rd.func, rd.data); + return rd; +} + +static inline void hmx_queue_flush(struct hmx_queue * q) { + while (hmx_queue_pop(q).func != NULL) ; +} + +static inline void hmx_queue_suspend(struct hmx_queue *q) { + hmx_queue_signal(q, HMX_QUEUE_SUSPEND); + hmx_queue_flush(q); +} + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif /* HMX_QUEUE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hmx-utils.h b/ggml/src/ggml-hexagon/htp/hmx-utils.h new file mode 100644 index 00000000000..f448ee3372a --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-utils.h @@ -0,0 +1,200 @@ +// HMX tile-level inline helpers (FP16 32x32 tile operations). +// Ported from htp-ops-lib/include/dsp/hmx_utils.h. (https://github.com/haozixu/htp-ops-lib) + +#ifndef HMX_UTILS_H +#define HMX_UTILS_H + +#include "hvx-base.h" + +#include <assert.h> +#include <hexagon_types.h> +#include <stddef.h> + +#define HMX_FP16_TILE_N_ROWS 32 +#define HMX_FP16_TILE_N_COLS 32 +#define HMX_FP16_TILE_N_ELMS 1024 +#define HMX_FP16_TILE_SIZE 2048 + +// Initialise aligned 256-byte area with scale vector + zero padding. +static inline void hmx_init_column_scales(void *out_scales, HVX_Vector v_scale) { + volatile HVX_Vector *pv = (HVX_Vector *) out_scales; + pv[0] = v_scale; + pv[1] = Q6_V_vzero(); +} + +// --- Shared scatter offsets and interleave helper --- + +// vscatter offsets for fused dequant+transpose: write K-values directly to [K][N] tile. +// word[i] = i*128 maps K-row-pair i to byte offset i*128. +// Column offset (n*4) is added at runtime. Entries 0..15 cover one tile (region 2047); +// entries 16..31 cover the next adjacent tile (region 4095) — pick region size at the +// call site to scatter into one tile (masked) or two contiguous tiles (unmasked). +static const int32_t hmx_transpose_scatter_offsets[32] __attribute__((aligned(VLEN))) = { + 0 * 128, 1 * 128, 2 * 128, 3 * 128, 4 * 128, 5 * 128, 6 * 128, 7 * 128, 8 * 128, 9 * 128, 10 * 128, + 11 * 128, 12 * 128, 13 * 128, 14 * 128, 15 * 128, 16 * 128, 17 * 128, 18 * 128, 19 * 128, 20 * 128, 21 * 128, + 22 * 128, 23 * 128, 24 * 128, 25 * 128, 26 * 128, 27 * 128, 28 * 128, 29 * 128, 30 * 128, 31 * 128, +}; + +// Scatter row-major FP16 data (in VTCM scratch) into transposed [K][N] tiles. +// vtcm_src: [n_cols][src_stride] row-major fp16 (only first k elements per row are used) +// vtcm_dst: [n_col_tiles][n_k_tiles][HMX_FP16_TILE_N_ELMS] tile-major interleaved fp16 +// Processes rows [start_row, end_row) for multi-thread slicing. +// Full range: start_row=0, end_row=n_cols. +static inline void hmx_interleave_rows_to_tiles(__fp16 * restrict vtcm_dst, + const __fp16 * restrict vtcm_src, + int n_cols, + int k, + int src_stride, + int start_row, + int end_row) { + assert(k % HMX_FP16_TILE_N_COLS == 0); + + const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + // Each hvx_vmemu load brings 64 fp16 = 128 bytes covering 2 adjacent K-tiles. + // When n_k_tiles is even, scatter into 2 K-tiles per call (region 4095, no mask) + // using the upper half of hmx_transpose_scatter_offsets. Tail one K-tile (when + // n_k_tiles is odd) falls back to single-tile masked scatter. + const bool pair_scatter = (n_k_tiles & 1) == 0; + const size_t pair_region = (size_t) (2 * HMX_FP16_TILE_SIZE - 1); + const size_t single_region = (size_t) (HMX_FP16_TILE_SIZE - 1); + __builtin_assume(k > 0); + __builtin_assume(end_row > start_row); + + if (pair_scatter) { + // Step c by 64 fp16 (two K-tiles per scatter), advance dst by 2 tiles per iter. + const int c_step = 2 * HMX_FP16_TILE_N_COLS; + const size_t c_byte_step = (size_t) c_step * sizeof(__fp16); + const size_t dst_step = 2 * (size_t) HMX_FP16_TILE_N_ELMS; + const int n_c_iters = k / c_step; + + for (int r = start_row; r < end_row; r += 2) { + const int ct = r / HMX_FP16_TILE_N_ROWS; + const int local_r = r % HMX_FP16_TILE_N_ROWS; + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols; + const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + assert(hex_is_aligned(p0, 128)); + assert(hex_is_aligned(p1, 128)); + assert(c_byte_step % 128 == 0); + + if (p1) { + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step; + HVX_Vector v1 = hvx_vmem(p1); p1 += c_byte_step; + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, v1); + tile_base += dst_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmem(p0); p0 += c_byte_step; + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off0, v0); + Q6_vscatter_RMVwV((size_t) tile_base, pair_region, v_off1, vzero); + tile_base += dst_step; + } + } + } + } else { + // Fallback: scatter one K-tile per call (region 2047, masked). + const int c_step = HMX_FP16_TILE_N_COLS; + const size_t c_byte_step = (size_t) c_step * sizeof(__fp16); + const size_t dst_step = (size_t) HMX_FP16_TILE_N_ELMS; + const int n_c_iters = k / c_step; + + for (int r = start_row; r < end_row; r += 2) { + const int ct = r / HMX_FP16_TILE_N_ROWS; + const int local_r = r % HMX_FP16_TILE_N_ROWS; + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_cols; + const HVX_Vector v_off0 = Q6_Vw_vadd_VwVw(v_scat_base, Q6_V_vsplat_R(local_r * 4)); + const HVX_Vector v_off1 = Q6_Vw_vadd_VwVw(v_off0, v_scat_step); + + __fp16 * tile_base = vtcm_dst + (size_t) ct * n_k_tiles * HMX_FP16_TILE_N_ELMS; + const uint8_t * p0 = (const uint8_t *) (vtcm_src + r * src_stride); + const uint8_t * p1 = next_row_valid ? (const uint8_t *) (vtcm_src + (r + 1) * src_stride) : NULL; + + if (p1) { + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step; + HVX_Vector v1 = hvx_vmemu(p1); p1 += c_byte_step; + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, v1); + tile_base += dst_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int i = 0; i < n_c_iters; ++i) { + HVX_Vector v0 = hvx_vmemu(p0); p0 += c_byte_step; + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off0, v0); + Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, single_region, v_off1, vzero); + tile_base += dst_step; + } + } + } + } +} + +// Interleave row-major FP16 data into column-major tile format. +// Input: [n_rows, head_dim] row-major. Output: tile[dim_tile][row_tile]. +// Processes rows [start_row, end_row) for multi-thread slicing. +// Full range: start_row=0, end_row=n_rows. +static inline void hmx_interleave_cols_to_tiles(__fp16 * restrict tiles_out, + const __fp16 * restrict src, + int n_rows, + int head_dim, + int src_stride, + int n_row_tiles, + int start_row, + int end_row) { + __builtin_assume(head_dim > 0); + const size_t tile_stride_elms = (size_t) n_row_tiles * HMX_FP16_TILE_N_ELMS; + + for (int r = start_row; r < end_row; r += 2) { + const bool next_row_valid = (r + 1) < end_row && (r + 1) < n_rows; + + const HVX_Vector * pv_in0 = (const HVX_Vector *) (src + r * src_stride); + const HVX_Vector * pv_in1 = next_row_valid ? (const HVX_Vector *) (src + (r + 1) * src_stride) : NULL; + + // Row-pair invariants hoisted out of the c loop. + const int r0 = r / HMX_FP16_TILE_N_ROWS; + const int r1_half = (r % HMX_FP16_TILE_N_ROWS) / 2; + + // tb0 starts at tile (c0=0, r0); tb1 at the adjacent dim-tile (c0=1, r0). + // Each c step (+= 64) advances both by 2 dim-tiles worth of fp16. + __fp16 * tb0 = tiles_out + (size_t) r0 * HMX_FP16_TILE_N_ELMS; + __fp16 * tb1 = tb0 + tile_stride_elms; + const size_t tb_step = 2 * tile_stride_elms; + + if (pv_in1) { + for (int c = 0; c < head_dim; c += 64) { + HVX_Vector v0 = *pv_in0++; + HVX_Vector v1 = *pv_in1++; + HVX_VectorPair vp = Q6_W_vshuff_VVR(v1, v0, -2); + ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp); + ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp); + tb0 += tb_step; + tb1 += tb_step; + } + } else { + const HVX_Vector vzero = Q6_V_vzero(); + for (int c = 0; c < head_dim; c += 64) { + HVX_Vector v0 = *pv_in0++; + HVX_VectorPair vp = Q6_W_vshuff_VVR(vzero, v0, -2); + ((HVX_Vector *) tb0)[r1_half] = Q6_V_lo_W(vp); + ((HVX_Vector *) tb1)[r1_half] = Q6_V_hi_W(vp); + tb0 += tb_step; + tb1 += tb_step; + } + } + } +} + +#endif // HMX_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 4bd0ea7a36a..0f1676f077a 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -1,35 +1,118 @@ #ifndef HTP_CTX_H #define HTP_CTX_H -#include "htp-dma.h" +#include "hex-dma.h" +#include "hmx-queue.h" +#include "htp-ops.h" #include "worker-pool.h" #include <assert.h> #include <dspqueue.h> #include <stdatomic.h> #include <stdint.h> +#include <stdbool.h> #define HTP_MAX_NTHREADS 10 +#define HTP_MAX_MMAPS 16 + +// Memory mapping +struct htp_mmap { + uint64_t size; + uint64_t base; + uint32_t fd; + uint32_t reserved; +}; + +// Scratchpad state +struct htp_spad { + const struct htp_tensor * src; // original src of the data (for reuse) + uint8_t * data; // pointer to an area in vtcm + uint32_t stride; // stride used inside this spad + uint32_t size; // total size + uint32_t size_per_thread; // size per thread +}; + +struct htp_context; + +// Context while processing an Op +// TODO: fold this into the main context +struct htp_ops_context { + struct htp_context * ctx; + + enum htp_op_code op; // FIXME: rename to opcode + int32_t op_params[HTP_OP_MAX_PARAMS]; + + const struct htp_tensor * src[HTP_OP_MAX_INPUTS]; + const struct htp_tensor * dst; + + // TODO convert these to an array + struct htp_spad src0_spad; + struct htp_spad src1_spad; + struct htp_spad src2_spad; + struct htp_spad src3_spad; + struct htp_spad dst_spad; + + uint32_t n_threads; + uint32_t flags; +}; // Main context for htp DSP backend struct htp_context { - dspqueue_t queue; - dma_queue * dma[HTP_MAX_NTHREADS]; - worker_pool_context_t worker_pool; - uint32_t n_threads; + dspqueue_t queue; + dma_queue * dma[HTP_MAX_NTHREADS]; + struct htp_mmap mmap[HTP_MAX_MMAPS]; + worker_pool_context_t worker_pool; + uint32_t n_threads; + + int thread_id; + int thread_prio; - int thread_id; - int thread_prio; + bool hmx_enabled; + bool etm; + uint32_t profiler; - uint8_t * vtcm_base; - size_t vtcm_size; - uint32_t vtcm_rctx; + uint8_t * vtcm_base; + size_t vtcm_size; + uint32_t vtcm_rctx; + atomic_bool vtcm_valid; + atomic_bool vtcm_needs_release; - atomic_bool vtcm_valid; - atomic_bool vtcm_inuse; - atomic_bool vtcm_needs_release; + uint64_t max_vmem; - uint32_t opmask; + // Persistent DDR scratchpad for MUL_MAT_ID mappings + void * ddr_spad_base; + size_t ddr_spad_size; + + struct htp_ops_context octx; + +#ifdef HTP_HAS_HMX + struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap +#endif }; +int op_matmul(struct htp_ops_context * octx); +int op_matmul_id(struct htp_ops_context * octx); +int op_binary(struct htp_ops_context * octx); +int op_unary(struct htp_ops_context * octx); +int op_sum_rows(struct htp_ops_context * octx); +int op_activations(struct htp_ops_context * octx); +int op_softmax(struct htp_ops_context * octx); +int op_add_id(struct htp_ops_context * octx); +int op_rope(struct htp_ops_context * octx); +int op_flash_attn_ext(struct htp_ops_context * octx); +int op_set_rows(struct htp_ops_context * octx); +int op_get_rows(struct htp_ops_context * octx); +int op_cpy(struct htp_ops_context * octx); +int op_repeat(struct htp_ops_context * octx); +int op_argsort(struct htp_ops_context * octx); +int op_ssm_conv(struct htp_ops_context * octx); +int op_cumsum(struct htp_ops_context * octx); +int op_fill(struct htp_ops_context * octx); +int op_concat(struct htp_ops_context * octx); +int op_diag(struct htp_ops_context * octx); +int op_solve_tri(struct htp_ops_context * octx); +int op_gated_delta_net(struct htp_ops_context * octx); +int op_tri(struct htp_ops_context * octx); +int op_pad(struct htp_ops_context * octx); + #endif /* HTP_CTX_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-dma.h b/ggml/src/ggml-hexagon/htp/htp-dma.h deleted file mode 100644 index 32fd06e7d46..00000000000 --- a/ggml/src/ggml-hexagon/htp/htp-dma.h +++ /dev/null @@ -1,157 +0,0 @@ -#ifndef HTP_DMA_H -#define HTP_DMA_H - -#include <HAP_farf.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> -#include <stdbool.h> -#include <stdint.h> - -#ifdef __cplusplus -extern "C" { -#endif - -typedef struct { - void *dst; - const void *src; -} dma_ptr; - -typedef struct { - hexagon_udma_descriptor_type1_t * desc; // descriptor pointers - hexagon_udma_descriptor_type1_t * tail; // tail pointer - dma_ptr * dptr; // dst/src pointers - uint32_t push_idx; - uint32_t pop_idx; - uint32_t capacity; - uint32_t idx_mask; -} dma_queue; - -dma_queue * dma_queue_create(size_t capacity); -void dma_queue_delete(dma_queue * q); -void dma_queue_flush(dma_queue * q); - -// TODO: technically we don't need these and could use Q6_dmstart/wait/etc instead -// but those do not seem to always compiler properly. -static inline void dmstart(void * next) { - asm volatile(" release(%0):at" : : "r"(next)); - asm volatile(" dmstart(%0)" : : "r"(next)); -} - -static inline void dmlink(void * cur, void * next) { - asm volatile(" release(%0):at" : : "r"(next)); - asm volatile(" dmlink(%0, %1)" : : "r"(cur), "r"(next)); -} - -static inline unsigned int dmpoll(void) { - unsigned int ret = 0; - asm volatile(" %0 = dmpoll" : "=r"(ret) : : "memory"); - return ret; -} - -static inline unsigned int dmwait(void) { - unsigned int ret = 0; - asm volatile(" %0 = dmwait" : "=r"(ret) : : "memory"); - return ret; -} - -static inline dma_ptr dma_make_ptr(void *dst, const void *src) -{ - dma_ptr p = { dst, src }; - return p; -} - -static inline bool dma_queue_push(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t width, // width in bytes. number of bytes to transfer per row - size_t nrows) { - if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { - FARF(ERROR, "dma-push: queue full\n"); - return false; - } - - hexagon_udma_descriptor_type1_t * desc = &q->desc[q->push_idx]; - - desc->next = NULL; - desc->length = 0; - desc->desctype = HEXAGON_UDMA_DESC_DESCTYPE_TYPE1; - desc->dstbypass = 1; - desc->srcbypass = 1; -#if __HVX_ARCH__ >= 73 - desc->dstbypass = 1; - desc->srcbypass = 1; -#else - desc->dstbypass = 0; - desc->srcbypass = 1; -#endif - desc->order = 0; - desc->dstate = HEXAGON_UDMA_DESC_DSTATE_INCOMPLETE; - desc->src = (void *) dptr.src; - desc->dst = (void *) dptr.dst; - desc->allocation = 0; - desc->padding = 0; - desc->roiwidth = width; - desc->roiheight = nrows; - desc->srcstride = src_row_size; - desc->dststride = dst_row_size; - desc->srcwidthoffset = 0; - desc->dstwidthoffset = 0; - - q->dptr[q->push_idx] = dptr; - - dmlink(q->tail, desc); - q->tail = desc; - - // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src); - q->push_idx = (q->push_idx + 1) & q->idx_mask; - return true; -} - -static inline bool dma_queue_push_ddr_to_vtcm(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t nrows) { - return dma_queue_push(q, dptr, dst_row_size, src_row_size, src_row_size, nrows); -} - - -static inline bool dma_queue_push_vtcm_to_ddr(dma_queue * q, - dma_ptr dptr, - size_t dst_row_size, - size_t src_row_size, - size_t nrows) { - return dma_queue_push(q, dptr, dst_row_size, src_row_size, dst_row_size, nrows); -} - -static inline dma_ptr dma_queue_pop(dma_queue * q) { - dma_ptr dptr = { NULL }; - - if (q->push_idx == q->pop_idx) { - return dptr; - } - - hexagon_udma_descriptor_type1_t * desc = &q->desc[q->pop_idx]; - - // Wait for desc to complete - while (1) { - dmpoll(); - if (desc->dstate == HEXAGON_UDMA_DESC_DSTATE_COMPLETE) { - break; - } - // FARF(ERROR, "dma-pop: waiting for DMA : %u\n", q->pop_idx); - } - - dptr = q->dptr[q->pop_idx]; - - // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst); - q->pop_idx = (q->pop_idx + 1) & q->idx_mask; - return dptr; -} - -#ifdef __cplusplus -} // extern "C" -#endif - -#endif /* HTP_DMA_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h deleted file mode 100644 index 846d0617843..00000000000 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ /dev/null @@ -1,165 +0,0 @@ -#ifndef HTP_MSG_H -#define HTP_MSG_H - -#include <assert.h> - -// ggml-common.h must be included prio to this header - -// Mask to enable various stages of the Ops. -// Used for debugging and profiling. -enum { - HTP_OPMASK_QUEUE = (1 << 0), // Enable Queueing (ie calls into the DSP) - HTP_OPMASK_QUANTIZE = (1 << 1), // Enable Quantize - HTP_OPMASK_COMPUTE = (1 << 2), // Enable Compute -}; - -// Op flags -enum { - HTP_OPFLAGS_SKIP_QUANTIZE = (1 << 0), // Skip dynamic quantization (reuse quantized tensors) - HTP_OPFLAGS_SKIP_COMPUTE = (1 << 1), // Skip actual computation (used for profiling) - HTP_OPFLAGS_EARLY_WAKEUP = (1 << 2) // Send early wakeup notification -}; - -enum htp_status { - HTP_STATUS_OK = 1, - HTP_STATUS_INTERNAL_ERR = 2, - HTP_STATUS_NO_SUPPORT = 3, - HTP_STATUS_INVAL_PARAMS = 4, - HTP_STATUS_VTCM_TOO_SMALL = 5, -}; - -// The values must match the ggml_type. -// Duplicated here because we can't include full ggml.h in the htp build. -// We have some static_asserts in the cpp code to ensure things are in sync. -enum htp_data_type { - HTP_TYPE_F32 = 0, - HTP_TYPE_F16 = 1, - HTP_TYPE_Q4_0 = 2, - HTP_TYPE_Q8_0 = 8, - HTP_TYPE_I32 = 26, - HTP_TYPE_I64 = 27, - HTP_TYPE_MXFP4 = 39, - HTP_TYPE_COUNT -}; - -// These values are manually translated over to HTP -// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!! -enum htp_op { - HTP_OP_MUL = 0, - HTP_OP_ADD = 1, - HTP_OP_SUB = 2, - HTP_OP_DIV = 3, - HTP_OP_MUL_MAT = 4, - HTP_OP_MUL_MAT_ID = 5, - HTP_OP_RMS_NORM = 6, - HTP_OP_UNARY_SILU = 7, - HTP_OP_UNARY_GELU = 8, - HTP_OP_GLU_SWIGLU = 9, - HTP_OP_GLU_SWIGLU_OAI = 10, - HTP_OP_SOFTMAX = 11, - HTP_OP_ADD_ID = 12, - HTP_OP_ROPE = 13, - HTP_OP_FLASH_ATTN_EXT = 14, - HTP_OP_SET_ROWS = 15, - HTP_OP_SCALE = 16, - HTP_OP_GET_ROWS = 17, - INVALID -}; - -static inline size_t htp_type_block_size(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return 1; - case HTP_TYPE_F16: - return 1; - case HTP_TYPE_Q4_0: - return QK4_0; - case HTP_TYPE_Q8_0: - return QK8_0; - case HTP_TYPE_MXFP4: - return QK_MXFP4; - default: - assert(0 && "unsupported HTP data type"); - } - return 0; -} - -static inline size_t htp_type_nbytes(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return 4; - case HTP_TYPE_F16: - return 2; - case HTP_TYPE_Q4_0: - return sizeof(block_q4_0); - case HTP_TYPE_Q8_0: - return sizeof(block_q8_0); - case HTP_TYPE_MXFP4: - return sizeof(block_mxfp4); - default: - assert(0 && "unsupported HTP data type"); - } - return 0; -} - -static const char * htp_type_name(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return "fp32"; - case HTP_TYPE_F16: - return "fp16"; - case HTP_TYPE_Q4_0: - return "q4_0"; - case HTP_TYPE_Q8_0: - return "q8_0"; - case HTP_TYPE_MXFP4: - return "mxfp4"; - } - return 0; -} - -// Internal types -#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) -#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks -#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks - -#define HTP_MAX_DIMS 4 - -struct htp_tensor { - uint32_t data; // Buffer offset in the messages, and data pointer on the NSP - uint32_t type; // Data type - uint32_t ne[HTP_MAX_DIMS]; // Number of elements - uint32_t nb[HTP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor) -}; - -#define HTP_MAX_OP_PARAMS 64 - -struct htp_general_req { - uint32_t op; // GGML/HTP Op - int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; - // Params for the op, e.g. epsilon of RMS norm - uint32_t flags; // Request flags - - struct htp_tensor src0; // Input0 tensor - struct htp_tensor src1; // Input1 tensor - struct htp_tensor src2; // Input2 tensor - struct htp_tensor src3; // Input3 tensor - struct htp_tensor src4; // Input4 tensor - struct htp_tensor dst; // Output tensor - - // should be multiple of 64 bytes (cacheline) -}; - -struct htp_general_rsp { - uint32_t op; // GGML/HTP Op - uint32_t status; // HTP_STATUS_... - uint32_t prof_usecs; // Number of usec per request - uint32_t prof_cycles; // Number of cycles per request - uint32_t prof_pkts; // Number of instruction packets per request - uint8_t unused[44]; // Pad to 64 bytes -}; - -#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) -#define HTP_MAX_PACKET_BUFFERS 8 - -#endif /* HTP_MSG_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 7c828ae6362..fa85bf4ca0c 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -1,92 +1,188 @@ #ifndef HTP_OPS_H #define HTP_OPS_H -#include "htp-ctx.h" -#include "htp-msg.h" -#include "worker-pool.h" -#include "ops-utils.h" - #include <assert.h> -#include <stdint.h> -// ggml-common.h must be included prior to this header +// ggml-common.h must be included prio to this header + +enum htp_status { + HTP_STATUS_OK = 1, + HTP_STATUS_INTERNAL_ERR = 2, + HTP_STATUS_NO_SUPPORT = 3, + HTP_STATUS_INVAL_PARAMS = 4, + HTP_STATUS_VTCM_TOO_SMALL = 5, +}; + +// First set of values must match the ggml_type. +// Duplicated here because we can't include full ggml.h in the htp build. +// We have some static_asserts in the cpp code to ensure things are in sync. +enum htp_data_type { + HTP_TYPE_F32 = 0, + HTP_TYPE_F16 = 1, + HTP_TYPE_Q4_0 = 2, + HTP_TYPE_Q4_1 = 3, + HTP_TYPE_Q8_0 = 8, + HTP_TYPE_IQ4_NL = 20, + HTP_TYPE_I32 = 26, + HTP_TYPE_I64 = 27, + HTP_TYPE_MXFP4 = 39, + + // types used internally for repack, dyn.quant, etc + HTP_TYPE_Q4_0x4x2 = 200, + HTP_TYPE_Q4_1x4x2, + HTP_TYPE_Q8_0x4x2, + HTP_TYPE_MXFP4x4x2, + + HTP_TYPE_INVALID +}; + +// Constats for internal types +#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) +#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks +#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks -struct htp_spad { - uint8_t * data; - size_t stride; - size_t size; - size_t size_per_thread; + +// Mask to enable various stages of the Ops. +// Used for debugging and profiling. +enum htp_op_stage { + HTP_OPSTAGE_QUEUE = (1 << 0), // Enable Queueing (ie calls into NPU) + HTP_OPSTAGE_COMPUTE = (1 << 1), // Enable Compute }; -struct htp_ops_context { - struct htp_context * ctx; +// Do not reorder first 4 (used as an index) +enum htp_op_code { + HTP_OP_MUL = 0, + HTP_OP_ADD = 1, + HTP_OP_SUB = 2, + HTP_OP_DIV = 3, + HTP_OP_MUL_MAT, + HTP_OP_MUL_MAT_ID, + HTP_OP_RMS_NORM, + HTP_OP_RMS_NORM_MUL, + HTP_OP_UNARY_SILU, + HTP_OP_UNARY_GELU, + HTP_OP_UNARY_SIGMOID, + HTP_OP_UNARY_EXP, + HTP_OP_UNARY_NEG, + HTP_OP_UNARY_SOFTPLUS, + HTP_OP_UNARY_TANH, + HTP_OP_GLU_SWIGLU, + HTP_OP_GLU_SWIGLU_OAI, + HTP_OP_GLU_GEGLU, + HTP_OP_SOFTMAX, + HTP_OP_ADD_ID, + HTP_OP_ROPE, + HTP_OP_FLASH_ATTN_EXT, + HTP_OP_SET_ROWS, + HTP_OP_GET_ROWS, + HTP_OP_SCALE, + HTP_OP_CPY, + HTP_OP_ARGSORT, + HTP_OP_SQR, + HTP_OP_SQRT, + HTP_OP_SUM_ROWS, + HTP_OP_SSM_CONV, + HTP_OP_REPEAT, + HTP_OP_CUMSUM, + HTP_OP_FILL, + HTP_OP_DIAG, + HTP_OP_SOLVE_TRI, + HTP_OP_L2_NORM, + HTP_OP_GATED_DELTA_NET, + HTP_OP_TRI, + HTP_OP_PAD, + HTP_OP_NORM, + HTP_OP_CONCAT, + + HTP_OP_INVALID +}; - enum htp_op op; - int32_t op_params[HTP_MAX_OP_PARAMS / sizeof(int32_t)]; +#define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS +#define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS +#define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS - struct htp_tensor src0; - struct htp_tensor src1; - struct htp_tensor src2; - struct htp_tensor src3; - struct htp_tensor src4; - struct htp_tensor dst; +#define HTP_OP_MAX_BUFS 16 +#define HTP_OP_MAX_REQS 256 +#define HTP_OP_MAX_TENSORS (HTP_OP_MAX_REQS * HTP_OP_MAX_INPUTS + HTP_OP_MAX_REQS) - struct htp_spad src0_spad; - struct htp_spad src1_spad; - struct htp_spad src2_spad; - struct htp_spad src3_spad; - struct htp_spad dst_spad; +#define HTP_OP_MAX_VMEM_DEFAULT (3355443200u) - worker_pool_context_t * wpool; // worker pool - uint32_t n_threads; // num threads +#define HTP_MMAP_MAX_VMEM (2147483648u) - uint32_t src0_nrows_per_thread; - uint32_t src1_nrows_per_thread; +enum htp_tensor_flags { + HTP_TENSOR_COMPUTE = (1U << 0), // Tensor buffer temporal compute data (not weights) + HTP_TENSOR_FLUSHED = (1U << 1) // Tensor buffer has been flushed (set by the NPU) +}; - struct fastdiv_values src0_div1; // fastdiv values for ne1 - struct fastdiv_values src0_div2; // fastdiv values for ne2 - struct fastdiv_values src0_div3; // fastdiv values for ne3 - struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1 +// Tensor descriptor +struct htp_tensor { + uint32_t data; // Buffer offset in the messages, and data pointer on the NPU + uint32_t size; // Data size in bytes + uint32_t flags; // Buffer / tensor flags + uint16_t type; // Data type + uint16_t bi; // Buffer index + uint32_t ne[HTP_OP_MAX_DIMS]; // Number of elements + uint32_t nb[HTP_OP_MAX_DIMS]; // Stride in bytes (see ggml.h ggml_tensor) +}; - struct fastdiv_values src1_div1; // fastdiv values for ne1 - struct fastdiv_values src1_div2; // fastdiv values for ne2 - struct fastdiv_values src1_div3; // fastdiv values for ne3 - struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1 +// Buffer descriptor +struct htp_buf_desc { + uint64_t base; // base address + uint64_t size; // total size + uint32_t flags; // buffer flags (unused) + uint32_t fd; // file descriptor +}; - struct fastdiv_values src3_div1; // fastdiv values for ne1 - struct fastdiv_values src3_div2; // fastdiv values for ne2 - struct fastdiv_values src3_div3; // fastdiv values for ne3 - struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1 +enum htp_op_flags { + HTP_OPFLAGS_SKIP_COMPUTE = (1U << 0), // Skip actual computation (used for profiling) +}; - struct fastdiv_values broadcast_rk2; - struct fastdiv_values broadcast_rk3; - struct fastdiv_values broadcast_rv2; - struct fastdiv_values broadcast_rv3; +// Op descriptor +struct htp_op_desc { + uint32_t opcode; // GGML/HTP Op + uint32_t flags; // Op flags + int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm + uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices + uint16_t dst; // Output tensor index +}; - struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1 - struct fastdiv_values mm_div_ne1; // fastdiv values for ne1 - struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02 - struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03 +enum htp_profiler_mode { + HTP_PROF_DISABLED = 0, + HTP_PROF_BASIC = 1, + HTP_PROF_PMU = 2, +}; - struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 - struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 +#define HTP_PROF_PMU_NCNT 8 - struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 - struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 +// Profile descriptor +struct htp_prof_desc { + uint32_t opcode; // GGML/HTP Op + uint32_t usecs; // Number of usec + uint32_t cycles; // Number of cycles + uint32_t pad; // Unused + uint32_t pmu[HTP_PROF_PMU_NCNT]; // PMU counters +}; - uint32_t flags; +struct htp_opbatch_req { + uint32_t id; // Batch id + uint32_t n_bufs; // Number of buffers + uint32_t n_tensors; // Number of tensors + uint32_t n_ops; // Number of ops + uint32_t flags; // unused + uint32_t pad; // unused + // struct htp_buf_desc bufs[]; -- dspqueue buf 0 + // struct htp_tensor tensors[]; -- dspqueue buf 0 + // struct htp_op_desc ops[]; -- dspqueue buf 0 }; -int op_matmul(struct htp_ops_context * octx); -int op_matmul_id(struct htp_ops_context * octx); -int op_binary(struct htp_ops_context * octx); -int op_unary(struct htp_ops_context * octx); -int op_activations(struct htp_ops_context * octx); -int op_softmax(struct htp_ops_context * octx); -int op_add_id(struct htp_ops_context * octx); -int op_rope(struct htp_ops_context * octx); -int op_flash_attn_ext(struct htp_ops_context * octx); -int op_set_rows(struct htp_ops_context * octx); -int op_get_rows(struct htp_ops_context * octx); +struct htp_opbatch_rsp { + uint32_t id; // Batch id + uint32_t status; // HTP_STATUS_... + uint32_t n_bufs; // Number of buffers + uint32_t n_tensors; // Number of tensors + uint32_t n_ops; // Number of op profile descriptors + uint32_t pad; // unused + // struct htp_prof_desc profs[]; -- dspqueue buf 0 +}; #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index 9ebd937e46d..d696a5fba0c 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -6,11 +6,17 @@ #include "AEEStdDef.idl" #include "remote.idl" +struct htp_iface_pmu_conf { + uint32 events[8]; +}; + interface htp_iface : remote_handle64 { - AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx); + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem); AEEResult stop(); - AEEResult enable_etm(); - AEEResult disable_etm(); + AEEResult mmap(in uint32 fd, in uint32 size); + AEEResult munmap(in uint32 fd); + AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); + AEEResult etm(in uint32 enable); }; #endif /* HTP_IDL */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-arith.h b/ggml/src/ggml-hexagon/htp/hvx-arith.h new file mode 100644 index 00000000000..82e3416970b --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-arith.h @@ -0,0 +1,443 @@ +#ifndef HVX_ARITH_H +#define HVX_ARITH_H + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> +#include <math.h> + +#include "hvx-base.h" +#include "hex-utils.h" + +// +// Binary operations (add, mul, sub) +// + +#define UNUSED(x) (void)(x) + +#define hvx_arith_loop_body(dst_type, src0_type, src1_type, elem_size, vec_store, vec_op) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const uint32_t epv = 128 / (elem_size); \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = vec_op(vsrc0[i], vsrc1[i]); \ + } \ + if (nloe) { \ + HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \ + vec_store((void *) &vdst[i], nloe * (elem_size), v); \ + } \ + } while(0) + +#if __HVX_ARCH__ < 79 + +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) + +#else + +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) + +#endif + +#define HVX_OP_ADD_F16(a, b) hvx_vec_add_f16_f16(a, b) +#define HVX_OP_SUB_F16(a, b) hvx_vec_sub_f16_f16(a, b) +#define HVX_OP_MUL_F16(a, b) hvx_vec_mul_f16_f16(a, b) + +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO, ELEM_TYPE) \ +static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ + +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD_F32, float) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB_F32, float) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL_F32, float) + +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f16, HVX_OP_ADD_F16, _Float16) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f16, HVX_OP_SUB_F16, _Float16) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f16, HVX_OP_MUL_F16, _Float16) + +// Dispatcher logic +#define HVX_BINARY_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128)) { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \ + else OP_NAME##_aau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \ + else OP_NAME##_auu(dst, src0, src1, num_elems); \ + } \ + } else { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \ + else OP_NAME##_uau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \ + else OP_NAME##_uuu(dst, src0, src1, num_elems); \ + } \ + } \ +} + +HVX_BINARY_DISPATCHER(hvx_add_f32) +HVX_BINARY_DISPATCHER(hvx_sub_f32) +HVX_BINARY_DISPATCHER(hvx_mul_f32) + +HVX_BINARY_DISPATCHER(hvx_add_f16) +HVX_BINARY_DISPATCHER(hvx_sub_f16) +HVX_BINARY_DISPATCHER(hvx_mul_f16) + +// Mul-Mul Optimized +static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src0 % 128 == 0); + assert((unsigned long) src1 % 128 == 0); + assert((unsigned long) src2 % 128 == 0); + + HVX_Vector * restrict vdst = (HVX_Vector *) dst; + HVX_Vector * restrict vsrc0 = (HVX_Vector *) src0; + HVX_Vector * restrict vsrc1 = (HVX_Vector *) src1; + HVX_Vector * restrict vsrc2 = (HVX_Vector *) src2; + + const uint32_t elem_size = sizeof(float); + const uint32_t epv = 128 / elem_size; + const uint32_t nvec = num_elems / epv; + const uint32_t nloe = num_elems % epv; + + uint32_t i = 0; + + _Pragma("unroll(4)") + for (; i < nvec; i++) { + HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]); + vdst[i] = HVX_OP_MUL(v1, vsrc2[i]); + } + + if (nloe) { + HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]); + HVX_Vector v2 = HVX_OP_MUL_F32(v1, vsrc2[i]); + hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2); + } +} + +// Scalar Operations + +#define hvx_scalar_loop_body(dst_type, src_type, elem_size, vec_store, scalar_op_macro) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t epv = 128 / (elem_size); \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector v = vsrc[i]; \ + vdst[i] = scalar_op_macro(v); \ + } \ + if (nloe) { \ + HVX_Vector v = vsrc[i]; \ + v = scalar_op_macro(v); \ + vec_store((void *) &vdst[i], nloe * (elem_size), v); \ + } \ + } while(0) + +#define HVX_OP_ADD_SCALAR_F32(v) \ + ({ \ + const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \ + HVX_Vector out = HVX_OP_ADD_F32(v, val_vec); \ + Q6_V_vmux_QVV(pred_inf, inf, out); \ + }) + +#define HVX_OP_MUL_SCALAR_F32(v) HVX_OP_MUL_F32(v, val_vec) +#define HVX_OP_SUB_SCALAR_F32(v) HVX_OP_SUB_F32(v, val_vec) + +#define HVX_OP_ADD_SCALAR_F16(v) \ + ({ \ + const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VhVh(inf, v); \ + HVX_Vector out = HVX_OP_ADD_F16(v, val_vec); \ + Q6_V_vmux_QVV(pred_inf, inf, out); \ + }) + +#define HVX_OP_MUL_SCALAR_F16(v) HVX_OP_MUL_F16(v, val_vec) +#define HVX_OP_SUB_SCALAR_F16(v) HVX_OP_SUB_F16(v, val_vec) + +// Scalar Variants + +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(OP_NAME, OP_MACRO, SPLAT_MACRO, ELEM_TYPE) \ +static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src % 128 == 0); \ + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) dst % 128 == 0); \ + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) src % 128 == 0); \ + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ + +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f32, HVX_OP_ADD_SCALAR_F32, hvx_vec_splat_f32, float) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f32, HVX_OP_SUB_SCALAR_F32, hvx_vec_splat_f32, float) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f32, HVX_OP_MUL_SCALAR_F32, hvx_vec_splat_f32, float) + +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f16, HVX_OP_ADD_SCALAR_F16, hvx_vec_splat_f16, _Float16) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f16, HVX_OP_SUB_SCALAR_F16, hvx_vec_splat_f16, _Float16) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f16, HVX_OP_MUL_SCALAR_F16, hvx_vec_splat_f16, _Float16) + +// Dispatcher logic +#define HVX_BINARY_SCALAR_DISPATCHER(OP_NAME, ELEM_TYPE) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_aa(dst, src, val, num_elems); \ + } else if (hex_is_aligned((void *) dst, 128)) { \ + OP_NAME##_au(dst, src, val, num_elems); \ + } else if (hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_ua(dst, src, val, num_elems); \ + } else { \ + OP_NAME##_uu(dst, src, val, num_elems); \ + } \ +} + +HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f32, float) +HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f32, float) +HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f32, float) + +HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f16, _Float16) +HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f16, _Float16) +HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f16, _Float16) + +// MIN Scalar variants + +#define HVX_OP_MIN_SCALAR(v) Q6_Vsf_vmin_VsfVsf(val_vec, v) + +static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { + const HVX_Vector val_vec = hvx_vec_splat_f32(val); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR); +} + +static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_min_scalar_f32_aa(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_min_scalar_f32_au(dst, src, val, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_min_scalar_f32_ua(dst, src, val, num_elems); + } else { + hvx_min_scalar_f32_uu(dst, src, val, num_elems); + } +} + +// CLAMP Scalar variants + +#define HVX_OP_CLAMP_SCALAR(v) \ + ({ \ + HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(v, max_vec); \ + HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(min_vec, v); \ + HVX_Vector tmp = Q6_V_vmux_QVV(pred_cap_right, max_vec, v); \ + Q6_V_vmux_QVV(pred_cap_left, min_vec, tmp); \ + }) + +static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + assert((unsigned long) dst % 128 == 0); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + assert((unsigned long) src % 128 == 0); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { + const HVX_Vector min_vec = hvx_vec_splat_f32(min); + const HVX_Vector max_vec = hvx_vec_splat_f32(max); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); +} + +static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) { + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { + hvx_clamp_scalar_f32_aa(dst, src, min, max, num_elems); + } else if (hex_is_aligned((void *) dst, 128)) { + hvx_clamp_scalar_f32_au(dst, src, min, max, num_elems); + } else if (hex_is_aligned((void *) src, 128)) { + hvx_clamp_scalar_f32_ua(dst, src, min, max, num_elems); + } else { + hvx_clamp_scalar_f32_uu(dst, src, min, max, num_elems); + } +} + +// +// Square +// + +#define hvx_sqr_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \ + } \ + if (nloe) { \ + HVX_Vector v = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * elem_size, v); \ + } \ + } while(0) + +static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sqr_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sqr_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sqr_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sqr_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { + if (hex_is_aligned((void *) dst, 128)) { + if (hex_is_aligned((void *) src, 128)) { + hvx_sqr_f32_aa(dst, src, num_elems); + } else { + hvx_sqr_f32_au(dst, src, num_elems); + } + } else { + if (hex_is_aligned((void *) src, 128)) { + hvx_sqr_f32_ua(dst, src, num_elems); + } else { + hvx_sqr_f32_uu(dst, src, num_elems); + } + } +} + +#undef HVX_OP_ADD_F32 +#undef HVX_OP_SUB_F32 +#undef HVX_OP_MUL_F32 +#undef HVX_OP_ADD_F16 +#undef HVX_OP_SUB_F16 +#undef HVX_OP_MUL_F16 +#undef hvx_arith_loop_body +#undef HVX_OP_ADD_SCALAR_F32 +#undef HVX_OP_SUB_SCALAR_F32 +#undef HVX_OP_MUL_SCALAR_F32 +#undef HVX_OP_ADD_SCALAR_F16 +#undef HVX_OP_SUB_SCALAR_F16 +#undef HVX_OP_MUL_SCALAR_F16 +#undef hvx_scalar_loop_body +#undef HVX_OP_MIN_SCALAR +#undef HVX_OP_CLAMP_SCALAR +#undef DEFINE_HVX_BINARY_OP_VARIANTS +#undef HVX_BINARY_DISPATCHER +#undef UNUSED + +#endif // HVX_ARITH_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h new file mode 100644 index 00000000000..f6cb02951d0 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -0,0 +1,308 @@ +#ifndef HVX_BASE_H +#define HVX_BASE_H + +#include <stdbool.h> +#include <stdint.h> +#include <math.h> +#include <assert.h> + +#include "hex-utils.h" +#include "hvx-types.h" + +#define hvx_vmem(A) *((HVX_Vector *)(A)) +#define hvx_vmemu(A) *((HVX_UVector *)(A)) + +static inline void hvx_vec_store_u(void * restrict dst, uint32_t n, HVX_Vector v) { + // Rotate as needed. + v = Q6_V_vlalign_VVR(v, v, (size_t) dst); + + uint32_t left_off = (size_t) dst & 127; + uint32_t right_off = left_off + n; + + HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) dst); + HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off); + + if (right_off > 128) { + Q6_vmem_QRIV(qr, (HVX_Vector *) dst + 1, v); + // all 1's + qr = Q6_Q_vcmp_eq_VbVb(v, v); + } + + ql_not = Q6_Q_or_QQn(ql_not, qr); + Q6_vmem_QnRIV(ql_not, (HVX_Vector *) dst, v); +} + +static inline void hvx_vec_store_a(void * restrict dst, uint32_t n, HVX_Vector v) { + assert((unsigned long) dst % 128 == 0); + HVX_VectorPred m = Q6_Q_or_QQn(Q6_Q_vsetq_R((unsigned long) dst), Q6_Q_vsetq2_R(n)); + Q6_vmem_QnRIV(m, (HVX_Vector *) dst, v); +} + +static inline HVX_Vector hvx_vec_splat_f32(float v) { + union { float f; uint32_t i; } u = { .f = v }; + return Q6_V_vsplat_R(u.i); +} + +static inline HVX_Vector hvx_vec_splat_f16(_Float16 v) { + union { __fp16 f; uint16_t i; } u = { .f = v }; + return Q6_Vh_vsplat_R(u.i); +} + +static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) { + // vdelta control to replicate first 4 bytes across all elements + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + + HVX_Vector ctrl = *(HVX_Vector *) repl; + return Q6_V_vdelta_VV(v, ctrl); +} + +static inline float hvx_vec_get_f32(HVX_Vector v) { + float __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 4, v); + return x; +} + +static inline int32_t hvx_vec_get_i32(HVX_Vector v) { + int32_t __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 4, v); + return x; +} + +static inline _Float16 hvx_vec_get_f16(HVX_Vector v) { + _Float16 __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 2, v); + return x; +} + +static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { + // abs by clearing the fp16 sign bit + HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); + return Q6_V_vand_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_neg_f16(HVX_Vector v) { + // neg by setting the fp16 sign bit + HVX_Vector mask = Q6_Vh_vsplat_R(0x8000); + return Q6_V_vxor_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_abs_f32(HVX_Vector v) { + // abs by clearing the fp32 sign bit + HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff); + return Q6_V_vand_VV(v, mask); +} + +static inline HVX_Vector hvx_vec_neg_f32(HVX_Vector v) { +#if __HVX_ARCH__ > 75 + return Q6_Vsf_vfneg_Vsf(v); +#else + // neg by setting the fp32 sign bit + HVX_Vector mask = Q6_V_vsplat_R(0x80000000); + return Q6_V_vxor_VV(v, mask); +#endif // __HVX_ARCH__ > 75 +} + +static inline HVX_VectorPred hvx_vec_is_nan_f16(HVX_Vector v) { + const HVX_Vector vnan_exp = Q6_Vh_vsplat_R(0x7C00); + const HVX_Vector vnan_frac = Q6_Vh_vsplat_R(0x7FFF); + + // get pred of which are NaN, i.e., exponent bits all 1s and fraction bits non 0s + HVX_VectorPred p_exp = Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_exp), vnan_exp); + HVX_VectorPred p_frac = Q6_Q_not_Q(Q6_Q_vcmp_eq_VhVh(Q6_V_vand_VV(v, vnan_frac), vnan_exp)); + return Q6_Q_and_QQ(p_exp, p_frac); +} + +static inline HVX_Vector hvx_vec_f32_to_f16_shuff(HVX_Vector v0, HVX_Vector v1) { +#if __HVX_ARCH__ >= 81 + HVX_Vector q0 = Q6_Vqf32_equals_Vsf(v0); + HVX_Vector q1 = Q6_Vqf32_equals_Vsf(v1); +#else + const HVX_Vector zero = Q6_V_vzero(); + HVX_Vector q0 = Q6_Vqf32_vadd_VsfVsf(v0, zero); + HVX_Vector q1 = Q6_Vqf32_vadd_VsfVsf(v1, zero); +#endif + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(q1, q0)); +} + +static inline HVX_Vector hvx_vec_f32_to_f16(HVX_Vector v0, HVX_Vector v1) { + HVX_Vector v = Q6_Vh_vdeal_Vh(hvx_vec_f32_to_f16_shuff(v0, v1)); + +#if __HVX_ARCH__ < 79 + // replace NaNs with -INF, older arches produce NaNs for (-INF + 0.0) + const HVX_Vector neg_inf = hvx_vec_splat_f16(-INFINITY); + HVX_VectorPred nan = hvx_vec_is_nan_f16(v); + v = Q6_V_vmux_QVV(nan, neg_inf, v); +#endif + + return v; +} + +#if __HVX_ARCH__ >= 79 +static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v, one); + return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p)); +} +static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one); + return Q6_W_vcombine_VV(Q6_V_hi_W(p), Q6_V_lo_W(p)); +} +#else +static inline HVX_VectorPair hvx_vec_f16_to_f32_shuff(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v, one); + return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p))); +} +static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { + const HVX_Vector one = hvx_vec_splat_f16(1.0); + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(v), one); + return Q6_W_vcombine_VV(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p))); +} +#endif + +/* Q6_Vsf_equals_Vw is only available on v73+.*/ +#if __HVX_ARCH__ < 73 +static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) +{ + HVX_Vector const vzero = Q6_V_vzero(); + HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero); + HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in); + HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift); + HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift); + HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized); + HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp)); + return ret; +} + +static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) +{ + return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in)); +} +#endif + +static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { + // This looks complicated. + // Ideally should just be Q6_Vh_equals_Vhf(vin) + // but that instruction does not do proper rounding. + + // convert to qf32, multiplying by 1.0 in the process. + HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00)); + + // 'in-range' values are +/32752. + // add 192K to it, convert to sf + HVX_Vector v192K = Q6_V_vsplat_R(0x48400000); + HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K)); + HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K)); + + // for in-range cases, result is {163858... 229360} so the exponent is always 144. + // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer. + // Start by <<10 to get the final 'sign' bit in bit 15... + vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10); + vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10); + + // now round down to 16 + return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0); +} + +#if __HVX_ARCH__ < 79 + +static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y) +{ + HVX_VectorPair m = Q6_Wqf32_vmpy_VhfVhf(x, y); + HVX_Vector a0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(m), Q6_V_lo_W(acc))); + HVX_Vector a1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(m), Q6_V_hi_W(acc))); + return Q6_W_vcombine_VV(a1, a0); +} + +#else + +static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y) +{ + return Q6_Wsf_vmpyacc_WsfVhfVhf(acc, x, y); +} + +#endif + +#if __HVX_ARCH__ < 79 + +static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) +{ + const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16 + const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16 + HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one); + HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone); + HVX_Vector a0 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p)); + HVX_Vector a1 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p)); + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0)); +} + +static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b) +{ + const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16 + const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16 + HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one); + HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone); + HVX_Vector a0 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p)); + HVX_Vector a1 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p)); + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0)); +} + +static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)); +} + +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)); +} + +#else + +static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vadd_VhfVhf(a, b); +} + +static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vsub_VhfVhf(a, b); +} + +static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vmpy_VhfVhf(a, b); +} + +static inline HVX_Vector hvx_vec_add_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vadd_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_sub_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vsub_VsfVsf(a, b); +} + +static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { + return Q6_Vsf_vmpy_VsfVsf(a, b); +} + +#endif // __HVX_ARCH__ < 79 + +#endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h new file mode 100644 index 00000000000..a3e33c3b3af --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -0,0 +1,262 @@ +#ifndef HVX_COPY_H +#define HVX_COPY_H + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "hvx-base.h" + +#define hvx_splat_pragma(x) _Pragma(#x) +#define hvx_splat_loop_body(dst_type, vec_store, unroll_cnt) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + \ + uint32_t nvec = n / (128 / elem_size); \ + uint32_t nloe = n % (128 / elem_size); \ + \ + uint32_t i = 0; \ + \ + hvx_splat_pragma(unroll(unroll_cnt)) \ + for (; i < nvec; i++) { \ + vdst[i] = src; \ + } \ + if (nloe) { \ + vec_store((void *) &vdst[i], nloe * elem_size, src); \ + } \ + } while(0) + +static inline void hvx_splat_a(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { + assert((unsigned long) dst % 128 == 0); + hvx_splat_loop_body(HVX_Vector, hvx_vec_store_a, 4); +} + +static inline void hvx_splat_u(void * restrict dst, HVX_Vector src, uint32_t n, uint32_t elem_size) { + hvx_splat_loop_body(HVX_UVector, hvx_vec_store_u, 4); +} + +static inline void hvx_splat_f32_a(void * restrict dst, float v, uint32_t n) { + hvx_splat_a(dst, hvx_vec_splat_f32(v), n, sizeof(float)); +} + +static inline void hvx_splat_f32_u(void * restrict dst, float v, uint32_t n) { + hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float)); +} + +static inline void hvx_splat_f16_a(void * restrict dst, _Float16 v, uint32_t n) { + hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); +} + +static inline void hvx_splat_f16_u(void * restrict dst, _Float16 v, uint32_t n) { + hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); +} + +static inline void hvx_splat_u16_a(void * restrict dst, uint16_t v, uint32_t n) { + hvx_splat_a(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t)); +} + +static inline void hvx_splat_u16_u(void * restrict dst, uint16_t v, uint32_t n) { + hvx_splat_u(dst, Q6_Vh_vsplat_R(v), n, sizeof(uint16_t)); +} + +static inline void hvx_splat_u8_a(void * restrict dst, uint8_t v, uint32_t n) { + hvx_splat_a(dst, Q6_Vb_vsplat_R(v), n, 1); +} + +static inline void hvx_splat_u8_u(void * restrict dst, uint8_t v, uint32_t n) { + hvx_splat_u(dst, Q6_Vb_vsplat_R(v), n, 1); +} + +#define hvx_copy_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { vdst[i] = vsrc[i]; } \ + if (nloe) { \ + vec_store((void *) &vdst[i], nloe * elem_size, vsrc[i]); \ + } \ + } while(0) + +// Generic copy routines +static inline void hvx_copy_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_copy_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_copy_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) { + assert((unsigned long) dst % 128 == 0); + hvx_copy_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_copy_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) { + assert((unsigned long) src % 128 == 0); + hvx_copy_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_copy_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n, uint32_t elem_size) { + hvx_copy_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +// copy n fp16 elements : source and destination are aligned to HVX Vector (128) +static inline void hvx_copy_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_aa(dst, src, n, sizeof(__fp16)); +} + +// copy n fp16 elements : source is aligned, destination is potentially unaligned +static inline void hvx_copy_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_au(dst, src, n, sizeof(__fp16)); +} + +// copy n fp16 elements : source is aligned, destination is potentially unaligned +static inline void hvx_copy_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_ua(dst, src, n, sizeof(__fp16)); +} + +// copy n fp16 elements : source is aligned, destination is potentially unaligned +static inline void hvx_copy_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_uu(dst, src, n, sizeof(__fp16)); +} + +// copy n fp32 elements : source and destination are aligned to HVX Vector (128) +static inline void hvx_copy_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_aa(dst, src, n, sizeof(float)); +} + +// copy n fp32 elements : source is aligned, destination is unaligned +static inline void hvx_copy_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_ua(dst, src, n, sizeof(float)); +} + +// copy n fp32 elements : source is unaligned, destination is aligned +static inline void hvx_copy_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_au(dst, src, n, sizeof(float)); +} + +// copy n fp32 elements : source is unaligned, destination unaligned +static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_uu(dst, src, n, sizeof(float)); +} + +//// fp32 -> fp16 + +#define hvx_copy_f16_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t elem_size = sizeof(__fp16); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \ + } \ + if (nloe) { \ + HVX_Vector v = hvx_vec_f32_to_f16(vsrc[i*2+0], vsrc[i*2+1]); \ + vec_store((void *) &vdst[i], nloe * elem_size, v); \ + } \ + } while(0) + +// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is aligned +static inline void hvx_copy_f16_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned +static inline void hvx_copy_f16_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_copy_f16_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned +static inline void hvx_copy_f16_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned +static inline void hvx_copy_f16_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_f16_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +//// fp16 -> fp32 + +#define hvx_copy_f32_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector one = hvx_vec_splat_f16(1.0); \ + \ + const uint32_t elem_size = sizeof(__fp16); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (i = 0; i < nvec; ++i) { \ + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \ + vdst[i*2] = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)); \ + vdst[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)); \ + } \ + \ + if (nloe) { \ + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vsrc[i]), one); \ + \ + HVX_Vector vd = Q6_V_lo_W(p); \ + i = 2 * i; \ + \ + if (nloe >= 32) { \ + vdst[i] = Q6_Vsf_equals_Vqf32(vd); \ + nloe -= 32; ++i; vd = Q6_V_hi_W(p); \ + } \ + \ + if (nloe) { \ + vd = Q6_Vsf_equals_Vqf32(vd); \ + hvx_vec_store_u(&vdst[i], nloe * sizeof(float), vd); \ + } \ + } \ + } while(0) + +// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is aligned +static inline void hvx_copy_f32_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is aligned +static inline void hvx_copy_f32_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_copy_f32_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +// copy/convert n fp16 elements into n fp32 elements : source is aligned, destination is unaligned +static inline void hvx_copy_f32_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +// copy/convert n fp16 elements into n fp32 elements : source is unaligned, destination is unaligned +static inline void hvx_copy_f32_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_copy_f32_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +#endif // HVX_COPY_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h new file mode 100644 index 00000000000..53ee304e749 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -0,0 +1,291 @@ +#ifndef HVX_DIV_H +#define HVX_DIV_H + +#include <HAP_farf.h> + +#include <math.h> +#include <string.h> +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "hvx-base.h" +#include "hex-utils.h" +#include "hvx-inverse.h" +#include "hvx-arith.h" + +#if __HVX_ARCH__ < 79 +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#define HVX_OP_MUL_F16(a, b) Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)) +#else +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#define HVX_OP_MUL_F16(a, b) Q6_Vhf_vmpy_VhfVhf(a, b) +#endif + +// Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32. +static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX_Vector vec2_sf_const, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + HVX_VectorPair src_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0); + HVX_Vector src_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(src_to_f32)); + HVX_Vector src_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(src_to_f32)); +#else + HVX_VectorPair src_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0); + HVX_Vector src_to_f32_0 = Q6_V_lo_W(src_to_f32); + HVX_Vector src_to_f32_1 = Q6_V_hi_W(src_to_f32); +#endif + + HVX_Vector div_f32_0 = HVX_OP_MUL_F32(src_to_f32_0, vec2_sf_const); + HVX_Vector div_f32_1 = HVX_OP_MUL_F32(src_to_f32_1, vec2_sf_const); + +#if __HVX_ARCH__ < 79 + HVX_Vector res = hvx_vec_f32_to_f16(div_f32_0, div_f32_1); +#else + HVX_Vector res = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1); +#endif + return res; +} + +// Variant for <v79: Use pre-computed f16 reciprocal constant +static inline HVX_Vector hvx_div_mul_f16_const_using_f16(HVX_Vector vec1_hf, HVX_Vector const_inv_hf) { + // Multiply by pre-computed f16 reciprocal constant + return HVX_OP_MUL_F16(vec1_hf, const_inv_hf); +} + +#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector res; \ + if (__HVX_ARCH__ < 79) { \ + res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \ + } else { \ + res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ + } \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector res; \ + if (__HVX_ARCH__ < 79) { \ + res = hvx_div_mul_f16_const_using_f16(vsrc[i], val_vec_f16); \ + } else { \ + res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ + } \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ + } \ + } while(0) + +static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val); + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} +static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val); + assert((uintptr_t) dst % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} +static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val); + assert((uintptr_t) src % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} +static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + const HVX_Vector val_vec_f16 = hvx_vec_splat_f16(1.0f / val); + hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +// Compute div by using hvx_vec_inverse_f32_guard. Requires first by exapnding fp32 to fp16 and convert the result back to fp32. +static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + // Convert first input to fp32 + HVX_VectorPair vec1_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0 + HVX_Vector vec1_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec1_to_f32)); + HVX_Vector vec1_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec1_to_f32)); + + // Convert second input to fp32 + HVX_VectorPair vec2_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0 + HVX_Vector vec2_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec2_to_f32)); + HVX_Vector vec2_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec2_to_f32)); +#else + // Convert first input to fp32 + HVX_VectorPair vec1_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0 + HVX_Vector vec1_to_f32_0 = Q6_V_lo_W(vec1_to_f32); + HVX_Vector vec1_to_f32_1 = Q6_V_hi_W(vec1_to_f32); + + // Convert second input to fp32 + HVX_VectorPair vec2_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0 + HVX_Vector vec2_to_f32_0 = Q6_V_lo_W(vec2_to_f32); + HVX_Vector vec2_to_f32_1 = Q6_V_hi_W(vec2_to_f32); +#endif + + // Inverse second input in fp32 + HVX_Vector vec2_inv_f32_0 = hvx_vec_inverse_f32_guard(vec2_to_f32_0, f32_nan_inf_mask); + HVX_Vector vec2_inv_f32_1 = hvx_vec_inverse_f32_guard(vec2_to_f32_1, f32_nan_inf_mask); + + // Multiply first input by inverse of second, in fp32 + HVX_Vector div_f32_0 = HVX_OP_MUL_F32(vec1_to_f32_0, vec2_inv_f32_0); + HVX_Vector div_f32_1 = HVX_OP_MUL_F32(vec1_to_f32_1, vec2_inv_f32_1); + + // Convert back to fp16 +#if __HVX_ARCH__ < 79 + HVX_Vector recip = hvx_vec_f32_to_f16(div_f32_0, div_f32_1); +#else + HVX_Vector recip = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1); +#endif + + return recip; +} + +// Hybrid approach: f16 reciprocal for <v79, f32 precision for >=v79 +static inline HVX_Vector hvx_vec_hybrid_div_f16(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector f16_nan_inf_mask, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + // For older architectures, use f16 reciprocal to avoid NaN/-inf issues + HVX_Vector vec2_inv = hvx_vec_inverse_f16_guard(vec2, f16_nan_inf_mask); + return HVX_OP_MUL_F16(vec1, vec2_inv); +#else + return hvx_vec_div_f16_using_f32(vec1, vec2, f32_nan_inf_mask, vec_hf_one_1_0); +#endif +} + +#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const HVX_Vector f32_nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector f16_nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \ + const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \ + f32_nan_inf_mask, f16_nan_inf_mask, \ + hf_one); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector res = hvx_vec_hybrid_div_f16(vsrc0[i], vsrc1[i], \ + f32_nan_inf_mask, f16_nan_inf_mask, \ + hf_one); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ + } \ + } while(0) + +#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ + HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ + HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \ + } \ + } while(0) + +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_DIV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \ +static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); \ +} \ + +// Dispatcher logic +#define HVX_DIV_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128)) { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \ + else OP_NAME##_aau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \ + else OP_NAME##_auu(dst, src0, src1, num_elems); \ + } \ + } else { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \ + else OP_NAME##_uau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \ + else OP_NAME##_uuu(dst, src0, src1, num_elems); \ + } \ + } \ +} + +DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f32, hvx_div_f32_loop_body) +DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f16, hvx_div_f16_loop_body) + +HVX_DIV_DISPATCHER(hvx_div_f32) +HVX_DIV_DISPATCHER(hvx_div_f16) + +#undef HVX_OP_MUL_F32 +#undef HVX_OP_MUL_F16 + +#endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-dump.h b/ggml/src/ggml-hexagon/htp/hvx-dump.h new file mode 100644 index 00000000000..85201fc3453 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-dump.h @@ -0,0 +1,129 @@ +#ifndef HVX_DUMP_H +#define HVX_DUMP_H + +#include <HAP_farf.h> + +#include <stdbool.h> +#include <stdint.h> + +#include "hex-utils.h" +#include "hvx-types.h" + +static void hvx_vec_dump_f16_n(char * pref, HVX_Vector v, uint32_t n) { + HVX_VectorAlias u = { .v = v }; + + const uint32_t n0 = n / 16; + const uint32_t n1 = n % 16; + int i = 0; + for (; i < n0; i++) { + hex_dump_f16_line(pref, u.fp16 + (16 * i), 16); + } + if (n1) { + hex_dump_f16_line(pref, u.fp16 + (16 * i), n1); + } +} + +static void hvx_vec_dump_f16(char * pref, HVX_Vector v) { + hvx_vec_dump_f16_n(pref, v, 64); +} + +static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) { + HVX_VectorAlias u = { .v = v }; + + const uint32_t n0 = n / 16; + const uint32_t n1 = n % 16; + int i = 0; + for (; i < n0; i++) { + hex_dump_f32_line(pref, u.fp32 + (16 * i), 16); + } + if (n1) { + hex_dump_f32_line(pref, u.fp32 + (16 * i), n1); + } +} + +static void hvx_vec_dump_f32_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + float d[32]; + } u = { .v = v }; + + FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1], + u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); +} + +static void hvx_vec_dump_f32(char * pref, HVX_Vector v) { + hvx_vec_dump_f32_n(pref, v, 32); +} + +static void hvx_vec_dump_int32(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int32_t d[32]; + } u = { .v = v }; + + for (int i = 0; i < 32 / 16; i++) { + hex_dump_int32_line(pref, u.d + (16 * i), 16); + } +} + +static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int32_t d[32]; + } u = { .v = v }; + + FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12], + u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); +} + +static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int8_t d[128]; + } u = { .v = v }; + + FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60], + u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]); +} + +static void hvx_vec_dump_int8(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + int8_t d[128]; + } u = { .v = v }; + + for (int i = 0; i < 128 / 16; i++) { + hex_dump_int8_line(pref, u.d + (16 * i), 16); + } +} + +static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) { + union { + HVX_Vector v; + uint8_t d[128]; + } u = { .v = v }; + + for (int i = 0; i < 128 / 16; i++) { + hex_dump_uint8_line(pref, u.d + (16 * i), 16); + } +} + +static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) { + typedef union { + HVX_Vector v; + int8_t d[128]; + } U; + + U u0 = { .v = v0 }; + U u1 = { .v = v1 }; + + for (int i = 0; i < n; i++) { + if (u0.d[i] != u1.d[i]) { + return false; + } + } + + return true; +} + +#endif /* HVX_DUMP_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.c b/ggml/src/ggml-hexagon/htp/hvx-exp.c deleted file mode 100644 index 21bf46a542f..00000000000 --- a/ggml/src/ggml-hexagon/htp/hvx-exp.c +++ /dev/null @@ -1,94 +0,0 @@ -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#include <hexagon_protos.h> -#include <hexagon_types.h> -#include <math.h> -#include <string.h> - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" -#include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" -#include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" - -static inline HVX_Vector hvx_vec_exp_fp32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) { - const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp); - - HVX_Vector out = hvx_vec_exp_fp32(in_vec); - - return Q6_V_vmux_QVV(pred0, inf, out); -} - -void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_exp_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - // assert((0 == unaligned_addr) || (0 == num_elems_whole)); - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_exp_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector vec_out = Q6_V_vzero(); - - static const float kInf = INFINITY; - static const float kMaxExp = 88.02f; // log(INF) - - const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); - const HVX_Vector inf = hvx_vec_splat_fp32(kInf); - - if (0 == unaligned_loop) { - HVX_Vector * p_vec_in1 = (HVX_Vector *) src; - HVX_Vector * p_vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - if (true == negate) { - HVX_Vector neg_vec_in = hvx_vec_neg_fp32(*p_vec_in1++); - *p_vec_out++ = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); - } else { - *p_vec_out++ = hvx_vec_exp_fp32_guard(*p_vec_in1++, max_exp, inf); - } - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - if (true == negate) { - HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); - } else { - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_fp32_guard(in, max_exp, inf); - } - } - } - - if (left_over > 0) { - const float * srcf = (float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - if (true == negate) { - HVX_Vector neg_vec_in = hvx_vec_neg_fp32(in); - - vec_out = hvx_vec_exp_fp32_guard(neg_vec_in, max_exp, inf); - } else { - vec_out = hvx_vec_exp_fp32_guard(in, max_exp, inf); - } - - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out); - } -} diff --git a/ggml/src/ggml-hexagon/htp/hvx-exp.h b/ggml/src/ggml-hexagon/htp/hvx-exp.h new file mode 100644 index 00000000000..e71ec4909a6 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-exp.h @@ -0,0 +1,216 @@ +#ifndef HVX_EXP_H +#define HVX_EXP_H + +#include <stdbool.h> +#include <stdint.h> +#include <math.h> + +#include "hvx-base.h" +#include "hvx-floor.h" + +#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!) +#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!) +#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!) +#define EXP_COEFF_2 (0x3D2AA9C1) // 0.416658 = 1/(4!) +#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!) +#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!) +#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 +#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 +#define EXP_ONE (0x3f800000) // 1.0 +#define EXP_RANGE_R (0x42B17218) // ln(FLT_MAX) approx = 88.7228 +#define EXP_RANGE_L (0xC2B00000) // -88.0 (approx log(FLT_MIN)) + +static inline HVX_Vector hvx_vec_exp_f32(HVX_Vector in_vec) { + HVX_Vector z_qf32_v; + HVX_Vector x_v; + HVX_Vector x_qf32_v; + HVX_Vector y_v; + HVX_Vector k_v; + HVX_Vector f_v; + HVX_Vector epsilon_v; + HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E); + HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2); + HVX_Vector E_const; + HVX_Vector zero_v = Q6_V_vzero(); + + // exp(x) is approximated as follows: + // f = floor(x/ln(2)) = floor(x*log2(e)) + // epsilon = x - f*ln(2) + // exp(x) = exp(epsilon+f*ln(2)) + // = exp(epsilon)*exp(f*ln(2)) + // = exp(epsilon)*2^f + // + // Since epsilon is close to zero, it can be approximated with its Taylor series: + // exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+... + // Preserving the first eight elements, we get: + // exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7 + // = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2 + + HVX_Vector temp_v = in_vec; + + // Clamp inputs to (-88.0, 88.0) to avoid overflow/underflow + HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); + HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); + + in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); + in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), in_vec); + + epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); + epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); + + // f_v is the floating point result and k_v is the integer result + f_v = hvx_vec_floor_f32(epsilon_v); + k_v = hvx_vec_truncate_f32(f_v); + + x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v); + + // x = x - f_v * logn2; + epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2); + x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v); + // normalize before every QFloat's vmpy + x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); + + x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); + + // z = x * x; + z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); + z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); + + // y = E4 + E5 * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_5); + y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); + E_const = Q6_V_vsplat_R(EXP_COEFF_4); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E3 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_3); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E2 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_2); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E1 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_1); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = E0 + y * x; + E_const = Q6_V_vsplat_R(EXP_COEFF_0); + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = x + y * z; + y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v); + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); + + // y = y + 1.0; + y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE)); + + // insert exponents + // y = ldexpf(y, k); + // y_v += k_v; // qf32 + // modify exponent + + y_v = Q6_Vsf_equals_Vqf32(y_v); + + // add k_v to the exponent of y_v + HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1); + + y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1); + y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent); + + // exponent cannot be negative; if overflow is detected, result is set to zero + HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent); + + y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN); + + y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v); + + return y_v; +} + +static inline HVX_Vector hvx_vec_exp_f32_guard(HVX_Vector in_vec, HVX_Vector max_exp, HVX_Vector inf) { + const HVX_VectorPred pred0 = Q6_Q_vcmp_gt_VsfVsf(in_vec, max_exp); + + HVX_Vector out = hvx_vec_exp_f32(in_vec); + + return Q6_V_vmux_QVV(pred0, inf, out); +} + +static inline void hvx_exp_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems, bool negate) { + int left_over = num_elems & (VLEN_FP32 - 1); + int num_elems_whole = num_elems - left_over; + + int unaligned_addr = 0; + int unaligned_loop = 0; + if ((0 == hex_is_aligned((void *) src, VLEN)) || (0 == hex_is_aligned((void *) dst, VLEN))) { + unaligned_addr = 1; + } + // assert((0 == unaligned_addr) || (0 == num_elems_whole)); + if ((1 == unaligned_addr) && (num_elems_whole != 0)) { + unaligned_loop = 1; + } + + HVX_Vector vec_out = Q6_V_vzero(); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + if (0 == unaligned_loop) { + HVX_Vector * p_vec_in1 = (HVX_Vector *) src; + HVX_Vector * p_vec_out = (HVX_Vector *) dst; + + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_f32(*p_vec_in1++); + *p_vec_out++ = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf); + } else { + *p_vec_out++ = hvx_vec_exp_f32_guard(*p_vec_in1++, max_exp, inf); + } + } + } else { + #pragma unroll(4) + for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { + HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); + + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_f32(in); + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf); + } else { + *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_exp_f32_guard(in, max_exp, inf); + } + } + } + + if (left_over > 0) { + const float * srcf = (float *) src + num_elems_whole; + float * dstf = (float *) dst + num_elems_whole; + + HVX_Vector in = *(HVX_UVector *) srcf; + + if (true == negate) { + HVX_Vector neg_vec_in = hvx_vec_neg_f32(in); + + vec_out = hvx_vec_exp_f32_guard(neg_vec_in, max_exp, inf); + } else { + vec_out = hvx_vec_exp_f32_guard(in, max_exp, inf); + } + + hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, vec_out); + } +} + +#endif /* HVX_EXP_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h new file mode 100644 index 00000000000..f1f2e49e455 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h @@ -0,0 +1,47 @@ +#ifndef HVX_FLASH_ATTN_H +#define HVX_FLASH_ATTN_H + +#include <math.h> +#include "hvx-utils.h" + +// Scalar helper to compute a single ALiBi slope. +static inline float alibi_slope(uint32_t h, uint32_t n_head_log2, float m0, float m1) { + return (h < n_head_log2) ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1); +} + +// Vectorized helper to compute 32 ALiBi slopes starting from (kv_head * G). +static inline HVX_Vector hvx_alibi_slopes( + uint32_t kv_head, + uint32_t G, + uint32_t n_head_log2, + float m0, + float m1 +) { + static const float ramp_32[32] __attribute__((aligned(128))) = { + 0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, + 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, + 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, + 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f + }; + HVX_Vector v_ramp = hvx_vmem(ramp_32); + HVX_Vector v_h_base = hvx_vec_splat_f32((float)(kv_head * G)); + HVX_Vector v_h = hvx_vec_add_f32_f32(v_h_base, v_ramp); + + // Compute exponent_m0: h + 1 + HVX_Vector v_exp_m0 = hvx_vec_add_f32_f32(v_h, hvx_vec_splat_f32(1.0f)); + + // Compute exponent_m1: 2 * (h - n_head_log2) + 1 + HVX_Vector v_n_head_log2 = hvx_vec_splat_f32((float)n_head_log2); + HVX_Vector v_h_minus = hvx_vec_sub_f32_f32(v_h, v_n_head_log2); + HVX_Vector v_exp_m1 = hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(hvx_vec_splat_f32(2.0f), v_h_minus), hvx_vec_splat_f32(1.0f)); + + // Compute powers + HVX_Vector v_pow_m0 = hvx_vec_pow_const_base_f32(m0, v_exp_m0); + HVX_Vector v_pow_m1 = hvx_vec_pow_const_base_f32(m1, v_exp_m1); + + // Select based on h < n_head_log2 + HVX_VectorPred p_cond = Q6_Q_vcmp_gt_VsfVsf(v_n_head_log2, v_h); // v_n_head_log2 > v_h <=> h < n_head_log2 + return Q6_V_vmux_QVV(p_cond, v_pow_m0, v_pow_m1); +} + +#endif /* HVX_FLASH_ATTN_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-floor.h b/ggml/src/ggml-hexagon/htp/hvx-floor.h new file mode 100644 index 00000000000..6a1bfde5675 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-floor.h @@ -0,0 +1,100 @@ +#ifndef HVX_FLOOR_H +#define HVX_FLOOR_H + +#include <stdbool.h> +#include <stdint.h> + +#include "hvx-base.h" + +#define IEEE_VSF_EXPLEN (8) +#define IEEE_VSF_EXPBIAS (127) +#define IEEE_VSF_EXPMASK (0xFF) +#define IEEE_VSF_MANTLEN (23) +#define IEEE_VSF_MANTMASK (0x7FFFFF) +#define IEEE_VSF_MIMPMASK (0x800000) + +static inline HVX_Vector hvx_vec_truncate_f32(HVX_Vector in_vec) { + HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); + HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); + HVX_Vector const_zero_v = Q6_V_vzero(); + + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); + + HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; + expval_v &= IEEE_VSF_EXPMASK; + expval_v -= IEEE_VSF_EXPBIAS; + + // negative exp == fractional value + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); + + HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift + + HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa + HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0 + + vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer + vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0 + + HVX_Vector neg_vout = -vout; + + vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives + + return (vout); +} + +static inline HVX_Vector hvx_vec_floor_f32(HVX_Vector in_vec) { + HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); + HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); + HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN); + HVX_Vector const_zero_v = Q6_V_vzero(); + HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf + + HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); + + HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; + expval_v &= IEEE_VSF_EXPMASK; + expval_v -= IEEE_VSF_EXPBIAS; + + HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); + HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v); + HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v); + HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec); + + // if expval < 0 (q_negexp) // <0, floor is 0 + // if vin > 0 + // floor = 0 + // if vin < 0 + // floor = -1 + // if expval < mant_len (q_expltmn) // >0, but fraction may exist + // get sign (q_negative) + // mask >> expval // fraction bits to mask off + // vout = ~(mask) // apply mask to remove fraction + // if (qneg) // negative floor is one less (more, sign bit for neg) + // vout += ((impl_mask) >> expval) + // if (mask && vin) + // vout = vin + // else // already an integer + // ; // no change + + // compute floor + mask_mant_v >>= expval_v; + HVX_Vector neg_addin_v = mask_impl_v >> expval_v; + HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v); + HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec); + + HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set + HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v); + + HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear + HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits + + vout = in_vec; + vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval<mant + vout = Q6_V_vmux_QVV(q_integral, in_vec, vout); // integral values + vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout); // expval<0 x>0 -> 0 + vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1 + + return vout; +} + +#endif /* HVX_FLOOR_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.c b/ggml/src/ggml-hexagon/htp/hvx-inverse.c deleted file mode 100644 index 4d70634fcd4..00000000000 --- a/ggml/src/ggml-hexagon/htp/hvx-inverse.c +++ /dev/null @@ -1,72 +0,0 @@ -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#include <hexagon_protos.h> -#include <hexagon_types.h> -#include <math.h> -#include <string.h> - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" -#include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" -#include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" - -static inline HVX_Vector hvx_vec_inverse_fp32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { - HVX_Vector out = hvx_vec_inverse_fp32(v_sf); - - HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); - const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out); - - return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); -} - -void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_inverse_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - // assert((0 == unaligned_addr) || (0 == num_elems_whole)); - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_inverse_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - static const uint32_t kNanInfMask = 0x7f800000; - const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(kNanInfMask); - - if (0 == unaligned_loop) { - HVX_Vector * p_vec_in = (HVX_Vector *) src; - HVX_Vector * p_vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - *p_vec_out++ = hvx_vec_inverse_fp32_guard(*p_vec_in++, nan_inf_mask); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = hvx_vec_inverse_fp32_guard(in, nan_inf_mask); - } - } - - if (left_over > 0) { - const float * srcf = (float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - HVX_Vector out = hvx_vec_inverse_fp32_guard(in, nan_inf_mask); - - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out); - } -} diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.h b/ggml/src/ggml-hexagon/htp/hvx-inverse.h new file mode 100644 index 00000000000..f2054f45bac --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-inverse.h @@ -0,0 +1,210 @@ +#ifndef HVX_INVERSE_H +#define HVX_INVERSE_H + +#include <HAP_farf.h> + +#include <math.h> +#include <string.h> +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "hvx-base.h" + +// ==================================================== +// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5 +// Order:3; continuity: True; Ends forced: True +// Mode: unsigned; Result fractional bits: 14 +// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05 +// 32769 -32706 31252 -10589 +// 32590 -30635 22793 -4493 +// 32066 -27505 16481 -2348 +// 31205 -24054 11849 -1306 + +static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) { + // input is 0..0xffff representing 0.0 .. 1.0 + HVX_Vector p; + p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull); + p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull); + p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull); + p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull); + return p; // signed result, 14 fractional bits +} + +// Find reciprocal of fp16. +// (1) first, convert to fp32, multiplying by 1.0; this is done to +// handle denormals. Ignoring sign and zero, result should be at +// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000) +// (exponent in range [103,143]) +// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly +// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32 +// (4) convert that to fp16 +// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace +// the result with the max value. +static inline HVX_Vector hvx_vec_inverse_f16(HVX_Vector vals) { + HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF); + HVX_Vector avals = Q6_V_vand_VV(vals, em_mask); + HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals); + // is too small to 1/x ? for 'standard' fp16, this would be 0x101 + HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals); + + HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0 + HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32)); + HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32)); + + // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector + HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9)); + // likewise extract the upper 16 from each, containing the exponents in range 103..142 + HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0); + //Get exponent in IEEE 32-bit representation + exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7); + + // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane + // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0) + // Use poly to transform to 1/x, with 14 fractional bits + // + HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16); + + HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros + + // Get mantissa for 16-bit representation + HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF)); + + //Compute Reciprocal Exponent + HVX_Vector exp_recip = + Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1))); + //Convert it for 16-bit representation + exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15)); + exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10); + + //Merge exponent and mantissa for reciprocal + HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip); + // map 'small' inputs to standard largest value 0x7bff + recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip); + // add sign back + recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000); + return recip; +} + +static inline HVX_Vector hvx_vec_inverse_f32(HVX_Vector v_sf) { + HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3); + HVX_Vector two_sf = hvx_vec_splat_f32(2.0); + + // First approximation + HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf); + + HVX_Vector r_qf; + + // Refine + r_qf = Q6_Vqf32_vmpy_VsfVsf( + i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf))))); + r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( + r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); + r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( + r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); + + return Q6_Vsf_equals_Vqf32(r_qf); +} + +static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { + HVX_Vector out = hvx_vec_inverse_f32(v_sf); + + HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); + const HVX_VectorPred pred = Q6_Q_vcmp_eq_VwVw(nan_inf_mask, masked_out); + + return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); +} + +#define hvx_inverse_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \ + } \ + if (nloe) { \ + HVX_Vector v = hvx_vec_inverse_f32_guard(vsrc[i], nan_inf_mask); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, v); \ + } \ + } while(0) + +static inline HVX_Vector hvx_vec_inverse_f16_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { + HVX_Vector out = hvx_vec_inverse_f16(v_sf); + + HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); + const HVX_VectorPred pred = Q6_Q_vcmp_eq_VhVh(nan_inf_mask, masked_out); + + return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); +} + +#define hvx_inverse_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \ + } \ + if (nloe) { \ + HVX_Vector v = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, v); \ + } \ + } while(0) + +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_INV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \ +static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) src % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, hvx_vec_store_u); \ +} \ + +// Dispatcher logic +#define HVX_INV_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_aa(dst, src, num_elems); \ + } else if (hex_is_aligned((void *) dst, 128)) { \ + OP_NAME##_au(dst, src, num_elems); \ + } else if (hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_ua(dst, src, num_elems); \ + } else { \ + OP_NAME##_uu(dst, src, num_elems); \ + } \ +} + +DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f32, hvx_inverse_f32_loop_body) +DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f16, hvx_inverse_f16_loop_body) + +HVX_INV_DISPATCHER(hvx_inverse_f32) +HVX_INV_DISPATCHER(hvx_inverse_f16) + +#endif // HVX_INVERSE_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-log.h b/ggml/src/ggml-hexagon/htp/hvx-log.h new file mode 100644 index 00000000000..7013dae785a --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-log.h @@ -0,0 +1,65 @@ +#ifndef HVX_LOG_H +#define HVX_LOG_H + +#include "hvx-base.h" + +// Approximates ln(x) element-wise for float vectors. +// x must contain positive float elements. +// Uses Abramowitz & Stegun polynomial approximation 4.1.44 for ln(1+y) over [0, 1]. +static inline HVX_Vector hvx_vec_log_f32(HVX_Vector x) { + // x = m * 2^e, where m in [1, 2) + HVX_Vector biased_e = Q6_Vuw_vlsr_VuwR(x, 23); + HVX_Vector e_int = Q6_Vw_vsub_VwVw(biased_e, Q6_V_vsplat_R(127)); + HVX_Vector e_float = Q6_Vsf_equals_Vw(e_int); + + // Extract mantissa and set exponent to 127 (which represents float value in [1.0, 2.0)) + HVX_Vector mant_mask = Q6_V_vsplat_R(0x007FFFFF); + HVX_Vector exp_127 = Q6_V_vsplat_R(0x3F800000); + HVX_Vector m = Q6_V_vor_VV(Q6_V_vand_VV(x, mant_mask), exp_127); + + // y = m - 1.0f, y in [0, 1) + HVX_Vector y = hvx_vec_sub_f32_f32(m, hvx_vec_splat_f32(1.0f)); + + // Abramowitz & Stegun 4.1.44 polynomial approximation of ln(1+y) + HVX_Vector c; + HVX_Vector res; + + c = hvx_vec_splat_f32(-0.0064535442f); + res = hvx_vec_mul_f32_f32(y, c); + + c = hvx_vec_splat_f32(0.0360884937f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.0953293897f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.1676540711f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.2407338084f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.3317990258f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(-0.4998741238f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + c = hvx_vec_splat_f32(0.9999964239f); + res = hvx_vec_add_f32_f32(res, c); + res = hvx_vec_mul_f32_f32(y, res); + + // ln(x) = e * ln(2) + ln(1+y) + HVX_Vector ln2 = hvx_vec_splat_f32(0.69314718056f); + HVX_Vector term_e = hvx_vec_mul_f32_f32(e_float, ln2); + + return hvx_vec_add_f32_f32(term_e, res); +} + +#endif /* HVX_LOG_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-pow.h b/ggml/src/ggml-hexagon/htp/hvx-pow.h new file mode 100644 index 00000000000..48fe0e8eade --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-pow.h @@ -0,0 +1,42 @@ +#ifndef HVX_POW_H +#define HVX_POW_H + +#include <math.h> +#include "hvx-base.h" +#include "hvx-exp.h" +#include "hvx-log.h" + +// Approximates base^exponent element-wise for float vectors. +// base must be a positive constant. exponent is an HVX f32 vector. +// Uses base^x = exp(x * ln(base)). +static inline HVX_Vector hvx_vec_pow_const_base_f32(float base, HVX_Vector exponent) { + float ln_base = logf(base); + HVX_Vector ln_base_v = hvx_vec_splat_f32(ln_base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base_v); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +// Approximates base^exponent element-wise for float vectors. +// base and exponent are HVX f32 vectors. base elements must be positive. +// Uses base^exponent = exp(exponent * ln(base)). +static inline HVX_Vector hvx_vec_pow_f32(HVX_Vector base, HVX_Vector exponent) { + HVX_Vector ln_base = hvx_vec_log_f32(base); + HVX_Vector x = hvx_vec_mul_f32_f32(exponent, ln_base); + + static const float kInf = INFINITY; + static const float kMaxExp = 88.7228f; + + const HVX_Vector max_exp = hvx_vec_splat_f32(kMaxExp); + const HVX_Vector inf = hvx_vec_splat_f32(kInf); + + return hvx_vec_exp_f32_guard(x, max_exp, inf); +} + +#endif /* HVX_POW_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/ggml/src/ggml-hexagon/htp/hvx-reduce.h new file mode 100644 index 00000000000..3c0073ef6d8 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-reduce.h @@ -0,0 +1,296 @@ +#ifndef HVX_REDUCE_H +#define HVX_REDUCE_H + +#include <math.h> +#include <stdbool.h> +#include <stdint.h> +#include <assert.h> + +#include "hex-utils.h" +#include "hvx-base.h" +#include "hvx-types.h" + +static inline HVX_Vector hvx_vec_reduce_sum_n_i32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // int32 + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum + width = width << 1; + } + return sum; +} + +static inline HVX_Vector hvx_vec_reduce_sum_i32(HVX_Vector in) { + return hvx_vec_reduce_sum_n_i32(in, 32); +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_qf32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right + sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum + width = width << 1; + } + return sum; +} + +static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) { + return hvx_vec_reduce_sum_n_qf32(in, 32); +} + +#if __HVX_ARCH__ > 75 + +static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) { + HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4); + HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4); + HVX_Vector sum_sf01 = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01)); + HVX_Vector sum_sf23 = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23)); + + HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(sum_sf23, sum_sf01, 8); + HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123)); + + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8)); + return sum_sf; +} + +static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { + HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); + HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); + + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16)); + return sum_sf; +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum + width = width << 1; + } + return sum; +} + +#else + +static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) { + HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4); + HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4); + HVX_Vector sum_qf01 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01)); + HVX_Vector sum_qf23 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23)); + + HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(sum_qf23), Q6_Vsf_equals_Vqf32(sum_qf01), 8); + HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123)); + + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8)); + return Q6_Vsf_equals_Vqf32(sum_qf); +} + +static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { + HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); + HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); + + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16)); + return Q6_Vsf_equals_Vqf32(sum_qf); +} + +static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { + unsigned int total = n * 4; // total vec nbytes + unsigned int width = 4; // fp32 nbytes + + HVX_Vector sum = in, sum_t; + while (width < total) { + sum_t = Q6_V_vror_VR(sum, width); // rotate right + sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum + width = width << 1; + } + return sum; +} + +#endif + +static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) { + return hvx_vec_reduce_sum_n_f32(in, 32); +} + +static inline HVX_Vector hvx_vec_reduce_max_f16(HVX_Vector in) { + unsigned total = 128; // total vec nbytes + unsigned width = 2; // fp16 nbytes + + HVX_Vector _max = in, _max_t; + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max2_f16(HVX_Vector in, HVX_Vector _max) { + unsigned total = 128; // total vec nbytes + unsigned width = 2; // fp32 nbytes + + HVX_Vector _max_t; + + _max = Q6_Vhf_vmax_VhfVhf(in, _max); + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max_f32(HVX_Vector in) { + unsigned total = 128; // total vec nbytes + unsigned width = 4; // fp32 nbytes + + HVX_Vector _max = in, _max_t; + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +static inline HVX_Vector hvx_vec_reduce_max2_f32(HVX_Vector in, HVX_Vector _max) { + unsigned total = 128; // total vec nbytes + unsigned width = 4; // fp32 nbytes + + HVX_Vector _max_t; + + _max = Q6_Vsf_vmax_VsfVsf(in, _max); + while (width < total) { + _max_t = Q6_V_vror_VR(_max, width); // rotate right + _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max + width = width << 1; + } + + return _max; +} + +#define hvx_reduce_loop_body(src_type, init_vec, pad_vec, vec_op, reduce_op, scalar_reduce) \ + do { \ + src_type * restrict vsrc = (src_type *) src; \ + HVX_Vector acc = init_vec; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = num_elems / epv; \ + const uint32_t nloe = num_elems % epv; \ + \ + uint32_t i = 0; \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + acc = vec_op(acc, vsrc[i]); \ + } \ + if (nloe) { \ + const float * srcf = (const float *) src + i * epv; \ + HVX_Vector in = *(HVX_UVector *) srcf; \ + HVX_Vector temp = Q6_V_valign_VVR(in, pad_vec, nloe * elem_size); \ + acc = vec_op(acc, temp); \ + } \ + HVX_Vector v = reduce_op(acc); \ + return scalar_reduce(v); \ + } while(0) + +#define HVX_REDUCE_MAX_OP(acc, val) Q6_Vsf_vmax_VsfVsf(acc, val) +#define HVX_REDUCE_SUM_OP(acc, val) Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(acc), val) +#define HVX_SUM_SQ_OP(acc, val) Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(val, val)) +#define HVX_REDUCE_MAX_SCALAR(v) hvx_vec_get_f32(v) +#define HVX_REDUCE_SUM_SCALAR(v) hvx_vec_get_f32(Q6_Vsf_equals_Vqf32(v)) + +// Max variants + +static inline float hvx_reduce_max_f32_a(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]); + assert((unsigned long) src % 128 == 0); + hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR); +} + +static inline float hvx_reduce_max_f32_u(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = hvx_vec_splat_f32(((const float *) src)[0]); + hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_MAX_OP, hvx_vec_reduce_max_f32, HVX_REDUCE_MAX_SCALAR); +} + +static inline float hvx_reduce_max_f32(const uint8_t * restrict src, const int num_elems) { + if (hex_is_aligned((void *) src, 128)) { + return hvx_reduce_max_f32_a(src, num_elems); + } else { + return hvx_reduce_max_f32_u(src, num_elems); + } +} + +// Sum variants + +static inline float hvx_reduce_sum_f32_a(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + assert((unsigned long) src % 128 == 0); + hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_reduce_sum_f32_u(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_REDUCE_SUM_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_reduce_sum_f32(const uint8_t * restrict src, const int num_elems) { + if (hex_is_aligned((void *) src, 128)) { + return hvx_reduce_sum_f32_a(src, num_elems); + } else { + return hvx_reduce_sum_f32_u(src, num_elems); + } +} + +// Sum of squares variants + +static inline float hvx_sum_of_squares_f32_a(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + assert((uintptr_t) src % 128 == 0); + hvx_reduce_loop_body(HVX_Vector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_sum_of_squares_f32_u(const uint8_t * restrict src, const int num_elems) { + HVX_Vector init_vec = Q6_V_vsplat_R(0); + hvx_reduce_loop_body(HVX_UVector, init_vec, init_vec, HVX_SUM_SQ_OP, hvx_vec_reduce_sum_qf32, HVX_REDUCE_SUM_SCALAR); +} + +static inline float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) { + if (hex_is_aligned((void *) src, 128)) { + return hvx_sum_of_squares_f32_a(src, num_elems); + } else { + return hvx_sum_of_squares_f32_u(src, num_elems); + } +} + +#undef hvx_reduce_loop_body +#undef HVX_REDUCE_MAX_OP +#undef HVX_REDUCE_SUM_OP +#undef HVX_REDUCE_MAX_SCALAR +#undef HVX_REDUCE_SUM_SCALAR +#undef HVX_SUM_SQ_OP + +#endif /* HVX_REDUCE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-repl.h b/ggml/src/ggml-hexagon/htp/hvx-repl.h new file mode 100644 index 00000000000..fdc7e6c7d2f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-repl.h @@ -0,0 +1,74 @@ +#ifndef HVX_REPL_H +#define HVX_REPL_H + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "hvx-base.h" + +static inline HVX_Vector hvx_vec_repl(HVX_Vector v, const uint8_t * ctrl) { + return Q6_V_vdelta_VV(v, hvx_vmem(ctrl)); +} + +static inline HVX_Vector hvx_vec_repl_u32(HVX_Vector v) { + // vdelta control to replicate first 4 bytes across all lanes + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + return hvx_vec_repl(v, repl); +} + +static inline HVX_Vector hvx_vec_repl_f32(HVX_Vector v) { + // vdelta control to replicate first 4 bytes across all lanes + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + return hvx_vec_repl(v, repl); +} + +static inline HVX_Vector hvx_vec_repl_f16(HVX_Vector v) { + // vdelta control to replicate first two bytes across all lanes + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + }; + return hvx_vec_repl(v, repl); +} + +static inline HVX_Vector hvx_vec_repl_2x_f16(HVX_Vector v) { + // vdelta control to splat a pair of f16s: first half = f16[0], second half = f16[1] + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, + }; + return hvx_vec_repl(v, repl); +} + +#endif // HVX_REPL_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-scale.h b/ggml/src/ggml-hexagon/htp/hvx-scale.h new file mode 100644 index 00000000000..c65c98639dc --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-scale.h @@ -0,0 +1,133 @@ +#ifndef HVX_SCALE_H +#define HVX_SCALE_H + +#include <assert.h> +#include <stddef.h> +#include <stdint.h> + +#include "hvx-base.h" + +#define hvx_scale_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + HVX_Vector vs = hvx_vec_splat_f32(scale); \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; ++i) { \ + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \ + vdst[i] = Q6_Vsf_equals_Vqf32(v); \ + } \ + if (nloe) { \ + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); \ + vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \ + } \ + } while(0) + +static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + assert((size_t) dst % 128 == 0); + assert((size_t) src % 128 == 0); + hvx_scale_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_scale_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + assert((size_t) dst % 128 == 0); + hvx_scale_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_scale_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + assert((size_t) src % 128 == 0); + hvx_scale_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + hvx_scale_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + if (((size_t) dst & 127) == 0) { + if (((size_t) src & 127) == 0) { + hvx_scale_f32_aa(dst, src, n, scale); + } else { + hvx_scale_f32_au(dst, src, n, scale); + } + } else { + if (((size_t) src & 127) == 0) { + hvx_scale_f32_ua(dst, src, n, scale); + } else { + hvx_scale_f32_uu(dst, src, n, scale); + } + } +} + +#define hvx_scale_offset_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + HVX_Vector vs = hvx_vec_splat_f32(scale); \ + HVX_Vector vo = hvx_vec_splat_f32(offset); \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; ++i) { \ + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \ + vdst[i] = Q6_Vsf_equals_Vqf32(v); \ + } \ + if (nloe) { \ + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); \ + vec_store((void *) &vdst[i], nloe * elem_size, Q6_Vsf_equals_Vqf32(v)); \ + } \ + } while(0) + +static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + assert((size_t) dst % 128 == 0); + assert((size_t) src % 128 == 0); + hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_scale_offset_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + assert((size_t) dst % 128 == 0); + hvx_scale_offset_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_scale_offset_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + assert((size_t) src % 128 == 0); + hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + hvx_scale_offset_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + if (((size_t) dst & 127) == 0) { + if (((size_t) src & 127) == 0) { + hvx_scale_offset_f32_aa(dst, src, n, scale, offset); + } else { + hvx_scale_offset_f32_au(dst, src, n, scale, offset); + } + } else { + if (((size_t) src & 127) == 0) { + hvx_scale_offset_f32_ua(dst, src, n, scale, offset); + } else { + hvx_scale_offset_f32_uu(dst, src, n, scale, offset); + } + } +} + +#endif // HVX_SCALE_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c deleted file mode 100644 index 15ac64697c7..00000000000 --- a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +++ /dev/null @@ -1,49 +0,0 @@ -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#include <hexagon_protos.h> -#include <hexagon_types.h> -#include <math.h> -#include <string.h> - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" -#include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" -#include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" - -#if 0 -// Reference algo used in hvx-utils -static void fast_sigmoid_f32(const float* restrict src, float* restrict dst, const int num_elems) -{ - const float c1 = 0.03138777; - const float c2 = 0.276281267; - const float c_log2f = 1.442695022; - - int32_t store_ints[32]; - float store_floats[3][32]; - - for (int i = 0; i < num_elems; i++) - { - float v = src0[i]; - - v *= c_log2f*0.5; - int intPart = (int)v; - float x = (v - intPart); - float xx = x * x; - float v1 = c_log2f + c2 * xx; - float v2 = x + xx * c1 * x; - float v3 = (v2 + v1); - *((int*)&v3) += intPart << 24; - float v4 = v2 - v1; - float v5 = v3 - v4; - float res = v3 / v5; - - dst[i] = res; - } -} -#endif diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h new file mode 100644 index 00000000000..37f3e7b6fae --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h @@ -0,0 +1,142 @@ +#ifndef HVX_SIGMOID_H +#define HVX_SIGMOID_H + +#include "hvx-base.h" +#include "hvx-inverse.h" + +#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 +#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 +#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267 +#define FAST_SIGMOID_C3 (0x3f000000) // 0.5 + +static inline HVX_Vector hvx_vec_fast_sigmoid_f32(HVX_Vector v) { + v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); + v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3)); + + HVX_Vector in_int = hvx_vec_truncate_f32(Q6_Vsf_equals_Vqf32(v)); + HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int)); + HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x); + + HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2)); + v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); + + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1)); + v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx); + v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x); + + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1)); + HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1); + v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24); + v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent); + v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24); + + HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1)); + HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4)); + + HVX_Vector res = hvx_vec_inverse_f32(v5); + res = Q6_Vqf32_vmpy_VsfVsf(v3, res); + + return Q6_Vsf_equals_Vqf32(res); +} + +static inline HVX_Vector hvx_vec_fast_sigmoid_f32_guard(HVX_Vector v, + HVX_Vector one, + HVX_Vector max_exp, + HVX_Vector min_exp) { + const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v); + const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp); + + HVX_Vector out = hvx_vec_fast_sigmoid_f32(v); + out = Q6_V_vmux_QVV(pred_max, out, one); + return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); +} + +static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) { + // tanh(x) = 2 * sigmoid(2x) - 1 + HVX_Vector two = hvx_vec_splat_f32(2.0f); + HVX_Vector one = hvx_vec_splat_f32(1.0f); + HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two); + + HVX_Vector max_exp = hvx_vec_splat_f32(87.f); + HVX_Vector min_exp = hvx_vec_splat_f32(-87.f); + + HVX_Vector sig2x = hvx_vec_fast_sigmoid_f32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp); + + HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two); + res = Q6_Vqf32_vsub_Vqf32Vsf(res, one); + return Q6_Vsf_equals_Vqf32(res); +} + +#define hvx_sigmoid_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector one = hvx_vec_splat_f32(1.f); \ + const HVX_Vector max_exp = hvx_vec_splat_f32(87.f); \ + const HVX_Vector min_exp = hvx_vec_splat_f32(-87.f); \ + \ + const uint32_t epv = 128 / sizeof(float); \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \ + } \ + if (nloe) { \ + HVX_Vector tmp = hvx_vec_fast_sigmoid_f32_guard(vsrc[i], one, max_exp, min_exp); \ + vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \ + } \ + } while(0) + +#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t epv = 128 / sizeof(float); \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \ + } \ + if (nloe) { \ + HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \ + } \ + } while(0) + +static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sigmoid_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sigmoid_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sigmoid_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_sigmoid_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sigmoid_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +#endif /* HVX_SIGMOID_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h b/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h new file mode 100644 index 00000000000..c5b9a5d47c1 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h @@ -0,0 +1,90 @@ +#ifndef HVX_SIN_COS_H +#define HVX_SIN_COS_H + +#include "hvx-base.h" +#include "hvx-floor.h" + +static inline HVX_Vector hvx_vec_cos_f32(HVX_Vector x) { + HVX_Vector const_inv_pi = hvx_vec_splat_f32(0.3183098861837907f); + HVX_Vector const_half = hvx_vec_splat_f32(0.5f); + HVX_Vector const_pi = hvx_vec_splat_f32(3.141592653589793f); + HVX_Vector const_one = hvx_vec_splat_f32(1.0f); + HVX_Vector const_neg_one = hvx_vec_splat_f32(-1.0f); + + // n = floor(x * (1/pi) + 0.5) + HVX_Vector n_float = hvx_vec_floor_f32(hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(x, const_inv_pi), const_half)); + + // y = x - n * pi + HVX_Vector y = hvx_vec_sub_f32_f32(x, hvx_vec_mul_f32_f32(n_float, const_pi)); + + // Sign determination: if n is odd, sign is -1.0f, else 1.0f + // half_n = n * 0.5f + HVX_Vector half_n = hvx_vec_mul_f32_f32(n_float, const_half); + // floor_half_n = floor(half_n) + HVX_Vector floor_half_n = hvx_vec_floor_f32(half_n); + // is_odd = half_n > floor_half_n + HVX_VectorPred is_odd = Q6_Q_vcmp_gt_VsfVsf(half_n, floor_half_n); + // sign = vmux(is_odd, -1.0f, 1.0f) + HVX_Vector sign = Q6_V_vmux_QVV(is_odd, const_neg_one, const_one); + + // z = y^2 + HVX_Vector z = hvx_vec_mul_f32_f32(y, y); + + // Chebyshev approximation for cos(y) + HVX_Vector c4 = hvx_vec_splat_f32(2.3557242013849433e-05f); + HVX_Vector c3 = hvx_vec_splat_f32(-0.0013871428263450528f); + HVX_Vector c2 = hvx_vec_splat_f32(0.041665895266688284f); + HVX_Vector c1 = hvx_vec_splat_f32(-0.4999999360426369f); + HVX_Vector c0 = hvx_vec_splat_f32(0.9999999999071725f); + + HVX_Vector cos_y = hvx_vec_add_f32_f32(c3, hvx_vec_mul_f32_f32(z, c4)); + cos_y = hvx_vec_add_f32_f32(c2, hvx_vec_mul_f32_f32(z, cos_y)); + cos_y = hvx_vec_add_f32_f32(c1, hvx_vec_mul_f32_f32(z, cos_y)); + cos_y = hvx_vec_add_f32_f32(c0, hvx_vec_mul_f32_f32(z, cos_y)); + + return hvx_vec_mul_f32_f32(cos_y, sign); +} + +static inline HVX_Vector hvx_vec_sin_f32(HVX_Vector x) { + HVX_Vector const_inv_pi = hvx_vec_splat_f32(0.3183098861837907f); + HVX_Vector const_half = hvx_vec_splat_f32(0.5f); + HVX_Vector const_pi = hvx_vec_splat_f32(3.141592653589793f); + HVX_Vector const_one = hvx_vec_splat_f32(1.0f); + HVX_Vector const_neg_one = hvx_vec_splat_f32(-1.0f); + + // n = floor(x * (1/pi) + 0.5) + HVX_Vector n_float = hvx_vec_floor_f32(hvx_vec_add_f32_f32(hvx_vec_mul_f32_f32(x, const_inv_pi), const_half)); + + // y = x - n * pi + HVX_Vector y = hvx_vec_sub_f32_f32(x, hvx_vec_mul_f32_f32(n_float, const_pi)); + + // Sign determination: if n is odd, sign is -1.0f, else 1.0f + // half_n = n * 0.5f + HVX_Vector half_n = hvx_vec_mul_f32_f32(n_float, const_half); + // floor_half_n = floor(half_n) + HVX_Vector floor_half_n = hvx_vec_floor_f32(half_n); + // is_odd = half_n > floor_half_n + HVX_VectorPred is_odd = Q6_Q_vcmp_gt_VsfVsf(half_n, floor_half_n); + // sign = vmux(is_odd, -1.0f, 1.0f) + HVX_Vector sign = Q6_V_vmux_QVV(is_odd, const_neg_one, const_one); + + // z = y^2 + HVX_Vector z = hvx_vec_mul_f32_f32(y, y); + + // Chebyshev approximation for sin(y) + HVX_Vector s4 = hvx_vec_splat_f32(2.642186986152672e-06f); + HVX_Vector s3 = hvx_vec_splat_f32(-0.00019825318964070864f); + HVX_Vector s2 = hvx_vec_splat_f32(0.00833326283319605f); + HVX_Vector s1 = hvx_vec_splat_f32(-0.16666666082087775f); + HVX_Vector s0 = hvx_vec_splat_f32(0.999999999915155f); + + HVX_Vector sin_y = hvx_vec_add_f32_f32(s3, hvx_vec_mul_f32_f32(z, s4)); + sin_y = hvx_vec_add_f32_f32(s2, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_add_f32_f32(s1, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_add_f32_f32(s0, hvx_vec_mul_f32_f32(z, sin_y)); + sin_y = hvx_vec_mul_f32_f32(y, sin_y); + + return hvx_vec_mul_f32_f32(sin_y, sign); +} + +#endif /* HVX_SIN_COS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-sqrt.h b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h new file mode 100644 index 00000000000..e31a1006d21 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h @@ -0,0 +1,126 @@ +#ifndef HVX_SQRT_H +#define HVX_SQRT_H + +#include <stdbool.h> +#include <stdint.h> + +#include "hex-utils.h" + +#include "hvx-base.h" + +#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation +#define RSQRT_ONE_HALF 0x3f000000 // 0.5 +#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5 + +#if __HVX_ARCH__ < 79 +#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) { + //Algorithm : + // x2 = input*0.5 + // y = * (long *) &input + // y = 0x5f3759df - (y>>1) + // y = y*(threehalfs - x2*y*y) + + HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST); + HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF); + HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES); + + HVX_Vector x2, y, ypower2, temp; + + x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf); + x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero()); + + y = Q6_Vw_vasr_VwR(in_vec, 1); + y = Q6_Vw_vsub_VwVw(rsqrtconst, y); + + // 1st iteration + ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp)); + + // 2nd iteration + y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); + ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); + + // 3rd iteration + y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); + ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); + ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); + temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); + temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); + + return Q6_Vsf_equals_Vqf32(temp); +} + +// Compute sqrt(x) as x*inv_sqrt(x) +#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \ + HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \ + vdst[i] = sqrt_res; \ + } \ + if (nloe) { \ + HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \ + HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \ + } \ + } while(0) + +static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) { + if ((unsigned long) dst % 128 == 0) { + if ((unsigned long) src % 128 == 0) { + hvx_sqrt_f32_aa(dst, src, num_elems); + } else { + hvx_sqrt_f32_au(dst, src, num_elems); + } + } else { + if ((unsigned long) src % 128 == 0) { + hvx_sqrt_f32_ua(dst, src, num_elems); + } else { + hvx_sqrt_f32_uu(dst, src, num_elems); + } + } +} + +#endif /* HVX_SQRT_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-types.h b/ggml/src/ggml-hexagon/htp/hvx-types.h new file mode 100644 index 00000000000..d495a59fbea --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-types.h @@ -0,0 +1,36 @@ +#ifndef HVX_TYPES_H +#define HVX_TYPES_H + +#include <stdbool.h> +#include <stdint.h> + +#include <hexagon_types.h> + +#define SIZEOF_FP32 (4) +#define SIZEOF_FP16 (2) +#define VLEN (128) +#define VLEN_FP32 (VLEN / SIZEOF_FP32) +#define VLEN_FP16 (VLEN / SIZEOF_FP16) + +typedef union { + HVX_Vector v; + uint8_t b[VLEN]; + uint16_t h[VLEN_FP16]; + uint32_t w[VLEN_FP32]; + __fp16 fp16[VLEN_FP16]; + float fp32[VLEN_FP32]; +} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias; + +typedef struct { + HVX_Vector v[2]; +} HVX_Vector_x2; + +typedef struct { + HVX_Vector v[4]; +} HVX_Vector_x4; + +typedef struct { + HVX_Vector v[8]; +} HVX_Vector_x8; + +#endif /* HVX_TYPES_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.c b/ggml/src/ggml-hexagon/htp/hvx-utils.c deleted file mode 100644 index 29d73b8622b..00000000000 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.c +++ /dev/null @@ -1,1020 +0,0 @@ -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif - -#include <HAP_farf.h> -#include <HAP_mem.h> -#include <HAP_perf.h> -#include <HAP_ps.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> -#include <math.h> -#include <string.h> - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" -#include "hvx-utils.h" - -#define htp_binary_ops_preamble \ - int step_of_4 = num_elems >> 7; \ - int step_of_2 = (num_elems - step_of_4 * VLEN_FP32 * 4) >> 6; \ - int step_of_1 = (num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2) >> 5; \ - int remaining = num_elems - step_of_4 * VLEN_FP32 * 4 - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32; \ - \ - const uint8_t * restrict src0_curr = src0; \ - const uint8_t * restrict src1_curr = src1; \ - uint8_t * restrict dst_curr = dst; - -void hvx_mul_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) || - (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_mul_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_mul_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - - bool handled_leftover = false; - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; - HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, *vec_in2++); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - int step_of_1 = num_elems_whole >> 5; // divby 32, because 32 float = 128 bytes per HVX vector - int leftover_size = left_over * sizeof(float); - - - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; - HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; - HVX_UVector * restrict vec_out = (HVX_UVector *) dst; - - HVX_Vector slinep; - HVX_Vector slinec; - HVX_Vector sline; - HVX_Vector sline2p; - HVX_Vector sline2c; - HVX_Vector sline2; - - slinep = *vec_in1++; - sline2p = *vec_in2++; - #pragma unroll(4) - for (int i = step_of_1 - 1; i > 0; i--) { - slinec = *vec_in1++; - sline2c = *vec_in2++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0); - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1); - - *((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2)); - slinep = slinec; - sline2p = sline2c; - } - if (step_of_1 > 1) { - slinec = htp_is_aligned(vec_in1, VLEN) && left_over == 0 ? slinep : *vec_in1++; - sline2c = htp_is_aligned(vec_in2, VLEN) && left_over == 0 ? sline2p : *vec_in2++; - - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0); - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1); - *((HVX_UVector *) (vec_out++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, sline2)); - slinep = slinec; - sline2p = sline2c; - } - if (left_over > 0) { - slinec = (is_in_one_chunk(vec_in1, leftover_size, VLEN) ? slinep : *vec_in1++); - - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src0); - sline2c = (is_in_one_chunk(vec_in2, leftover_size, VLEN) ? sline2p : *vec_in2++); - sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t) src1); - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(sline, sline2); - hvx_vec_store_u(vec_out, leftover_size, Q6_Vsf_equals_Vqf32(out)); - handled_leftover = true; - } - } - - - if (left_over > 0 && !handled_leftover) { - const float * src0f = (const float *) src0 + num_elems_whole; - const float * src1f = (const float *) src1 + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in1 = *(HVX_UVector *) src0f; - HVX_Vector in2 = *(HVX_UVector *) src1f; - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in1, in2); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - -void hvx_mul_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - htp_binary_ops_preamble; - - for (int i = 0; i < step_of_4; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN); - - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN); - - HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN); - - src0_curr += 4 * VLEN; - - HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(v3a, v3b); - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN); - - *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3); - - HVX_Vector v4 = Q6_Vqf32_vmpy_VsfVsf(v4a, v4b); - - src1_curr += 4 * VLEN; - - *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4); - - dst_curr += 4 * VLEN; - } - - for (int i = 0; i < step_of_2; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - src0_curr += 2 * VLEN; - - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b); - - src1_curr += 2 * VLEN; - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - dst_curr += 2 * VLEN; - } - - for (int i = 0; i < step_of_1; i++) { - HVX_Vector va = *(HVX_Vector *) src0_curr; - - src0_curr += VLEN; - - HVX_Vector vb = *(HVX_Vector *) src1_curr; - - src1_curr += VLEN; - - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(va, vb); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v); - - dst_curr += VLEN; - } - - if (remaining > 0) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); - hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v)); - } -} - -void hvx_mul_mul_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - const uint8_t * restrict src2, - uint8_t * restrict dst, - const int num_elems) { - const uint8_t * restrict src0_curr = src0; - const uint8_t * restrict src1_curr = src1; - const uint8_t * restrict src2_curr = src2; - uint8_t * restrict dst_curr = dst; - - int step_of_2 = num_elems >> 6; - int step_of_1 = (num_elems - step_of_2 * VLEN_FP32 * 2) >> 5; - int remaining = num_elems - step_of_2 * VLEN_FP32 * 2 - step_of_1 * VLEN_FP32; - - for (int i = 0; i < step_of_2; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - HVX_Vector v1c = *(HVX_Vector *) src2_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1_ = Q6_Vqf32_vmpy_VsfVsf(v1a, v1b); - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1_), v1c); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - HVX_Vector v2c = *(HVX_Vector *) (src2_curr + VLEN); - - src0_curr += 2 * VLEN; - - HVX_Vector v2_ = Q6_Vqf32_vmpy_VsfVsf(v2a, v2b); - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2_), v2c); - - src1_curr += 2 * VLEN; - src2_curr += 2 * VLEN; - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - dst_curr += 2 * VLEN; - } - for (int i = 0; i < step_of_1; i++) { - HVX_Vector va = *(HVX_Vector *) src0_curr; - src0_curr += VLEN; - - HVX_Vector vb = *(HVX_Vector *) src1_curr; - src1_curr += VLEN; - - HVX_Vector vc = *(HVX_Vector *) src2_curr; - src2_curr += VLEN; - - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(va, vb); - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), vc); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v2); - dst_curr += VLEN; - } - if (remaining > 0) { - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v1), *(HVX_Vector *) src2_curr); - hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v2)); - } -} - -void hvx_add_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) || - (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_add_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_add_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; - HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*vec_in1++, *vec_in2++); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32); - HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32); - - HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); - } - } - - if (left_over > 0) { - const float * src0f = (const float *) src0 + num_elems_whole; - const float * src1f = (const float *) src1 + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in1 = *(HVX_UVector *) src0f; - HVX_Vector in2 = *(HVX_UVector *) src1f; - - HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in1, in2); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - -void hvx_add_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - htp_binary_ops_preamble; - - for (int i = 0; i < step_of_4; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN); - - HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN); - - HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN); - - src0_curr += 4 * VLEN; - - HVX_Vector v3 = Q6_Vqf32_vadd_VsfVsf(v3a, v3b); - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN); - - *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3); - - HVX_Vector v4 = Q6_Vqf32_vadd_VsfVsf(v4a, v4b); - - src1_curr += 4 * VLEN; - - *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4); - - dst_curr += 4 * VLEN; - } - for (int i = 0; i < step_of_2; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vadd_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - src0_curr += 2 * VLEN; - - HVX_Vector v2 = Q6_Vqf32_vadd_VsfVsf(v2a, v2b); - - src1_curr += 2 * VLEN; - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - dst_curr += 2 * VLEN; - } - for (int i = 0; i < step_of_1; i++) { - HVX_Vector va = *(HVX_Vector *) src0_curr; - - src0_curr += VLEN; - - HVX_Vector vb = *(HVX_Vector *) src1_curr; - - src1_curr += VLEN; - - HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(va, vb); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v); - - dst_curr += VLEN; - } - if (remaining > 0) { - HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); - hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v)); - } -} - -void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_add_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_add_scalar_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - static const float kInf = INFINITY; - const HVX_Vector inf = hvx_vec_splat_fp32(kInf); - HVX_Vector val_vec = hvx_vec_splat_fp32(val); - - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *vec_in1++; - const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in); - HVX_Vector v = Q6_Vqf32_vadd_VsfVsf(in, val_vec); - v = Q6_Vsf_equals_Vqf32(v); - v = Q6_V_vmux_QVV(pred_inf, inf, v); - *vec_out++ = v; - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in); - HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec); - out = Q6_Vsf_equals_Vqf32(out); - out = Q6_V_vmux_QVV(pred_inf, inf, out); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = out; - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, in); - HVX_Vector out = Q6_Vqf32_vadd_VsfVsf(in, val_vec); - out = Q6_Vsf_equals_Vqf32(out); - out = Q6_V_vmux_QVV(pred_inf, inf, out); - - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, out); - } -} - -void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_mul_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_mul_scalar_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector val_vec = hvx_vec_splat_fp32(val); - bool handled_leftover = false; - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, val_vec); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - int step_of_1 = num_elems >> 5; // divby 32, because 32 float = 128 bytes per HVX vector - int leftover_size = left_over * sizeof(float); - - HVX_Vector * input_v_ptr = (HVX_Vector *) src; - HVX_UVector * output_v_ptr = (HVX_UVector *) dst; - - HVX_Vector slinep; - HVX_Vector slinec; - HVX_Vector sline; - - slinep = *input_v_ptr++; - - #pragma unroll(4) - for (int i = step_of_1 - 1; i > 0; i--) { - slinec = *input_v_ptr++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src); - *((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec)); - /* Prepare slinep for next iteration */ - slinep = slinec; - } - - if (step_of_1 > 0) { - slinec = htp_is_aligned(input_v_ptr, VLEN) && left_over == 0 ? slinep : *input_v_ptr++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src); - *((HVX_UVector *) (output_v_ptr++)) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec)); - - slinep = slinec; - } - - if (leftover_size > 0) { - slinec = (is_in_one_chunk(input_v_ptr, leftover_size, VLEN) ? slinep : *input_v_ptr++); - - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) src); - - HVX_Vector sout = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(sline, val_vec)); - hvx_vec_store_u(output_v_ptr, leftover_size, sout); - handled_leftover = true; - } - } - - if (left_over > 0 && !handled_leftover) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, val_vec); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - -void hvx_sub_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src0, VLEN)) || (0 == htp_is_aligned((void *) src1, VLEN)) || - (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_sub_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_sub_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src0; - HVX_Vector * restrict vec_in2 = (HVX_Vector *) src1; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, *vec_in2++); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in1 = *(HVX_UVector *) (src0 + i * SIZEOF_FP32); - HVX_Vector in2 = *(HVX_UVector *) (src1 + i * SIZEOF_FP32); - - HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); - } - } - - if (left_over > 0) { - const float * src0f = (const float *) src0 + num_elems_whole; - const float * src1f = (const float *) src1 + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in1 = *(HVX_UVector *) src0f; - HVX_Vector in2 = *(HVX_UVector *) src1f; - - HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in1, in2); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - -void hvx_sub_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems) { - htp_binary_ops_preamble; - - for (int i = 0; i < step_of_4; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - HVX_Vector v3a = *(HVX_Vector *) (src0_curr + 2 * VLEN); - - HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - HVX_Vector v3b = *(HVX_Vector *) (src1_curr + 2 * VLEN); - - HVX_Vector v4a = *(HVX_Vector *) (src0_curr + 3 * VLEN); - - src0_curr += 4 * VLEN; - - HVX_Vector v3 = Q6_Vqf32_vsub_VsfVsf(v3a, v3b); - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - HVX_Vector v4b = *(HVX_Vector *) (src1_curr + 3 * VLEN); - - *(HVX_Vector *) (dst_curr + 2 * VLEN) = Q6_Vsf_equals_Vqf32(v3); - - HVX_Vector v4 = Q6_Vqf32_vsub_VsfVsf(v4a, v4b); - - src1_curr += 4 * VLEN; - - *(HVX_Vector *) (dst_curr + 3 * VLEN) = Q6_Vsf_equals_Vqf32(v4); - - dst_curr += 4 * VLEN; - } - for (int i = 0; i < step_of_2; i++) { - HVX_Vector v1a = *(HVX_Vector *) src0_curr; - - HVX_Vector v1b = *(HVX_Vector *) src1_curr; - - HVX_Vector v2a = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v1 = Q6_Vqf32_vsub_VsfVsf(v1a, v1b); - - HVX_Vector v2b = *(HVX_Vector *) (src1_curr + VLEN); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v1); - - src0_curr += 2 * VLEN; - - HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v2a, v2b); - - src1_curr += 2 * VLEN; - - *(HVX_Vector *) (dst_curr + VLEN) = Q6_Vsf_equals_Vqf32(v2); - - dst_curr += 2 * VLEN; - } - for (int i = 0; i < step_of_1; i++) { - HVX_Vector va = *(HVX_Vector *) src0_curr; - - src0_curr += VLEN; - - HVX_Vector vb = *(HVX_Vector *) src1_curr; - - src1_curr += VLEN; - - HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(va, vb); - - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v); - - dst_curr += VLEN; - } - if (remaining > 0) { - HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*(HVX_Vector *) src0_curr, *(HVX_Vector *) src1_curr); - hvx_vec_store_u((void *) dst_curr, remaining * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(v)); - } -} - -void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_sub_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_sub_scalar_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector val_vec = hvx_vec_splat_fp32(val); - - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vsub_VsfVsf(*vec_in1++, val_vec); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - HVX_Vector out = Q6_Vqf32_vsub_VsfVsf(in, val_vec); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - -float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - if (0 == htp_is_aligned((void *) src, VLEN)) { - FARF(HIGH, "hvx_sum_of_squares_f32: unaligned address in hvx op, possibly slower execution\n"); - } - - assert((1 == htp_is_aligned((void *) src, VLEN)) || (0 == num_elems_whole)); - - HVX_Vector * restrict vec_in1 = (HVX_Vector *) src; - - HVX_Vector sum_vec_acc = Q6_V_vsplat_R(0x00000000); - HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000); - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1, *vec_in1); - sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, v); - vec_in1++; - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - - HVX_Vector vec_left = *(HVX_UVector *) srcf; - - HVX_Vector vec_left_sq = Q6_Vqf32_vmpy_VsfVsf(vec_left, vec_left); - HVX_Vector vec_tmp = Q6_V_valign_VVR(vec_left_sq, zero_vec, left_over * SIZEOF_FP32); - - sum_vec_acc = Q6_Vqf32_vadd_Vqf32Vqf32(sum_vec_acc, vec_tmp); - } - - HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec_acc); - return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); -} - -float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if (0 == htp_is_aligned((void *) src, VLEN)) { - FARF(HIGH, "hvx_self_sum_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_self_sum_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000); - HVX_Vector zero_vec = Q6_V_vsplat_R(0x00000000); - - if (0 == unaligned_loop) { - HVX_Vector * vec_in = (HVX_Vector *) src; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - // sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, *vec_in++); - sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), *vec_in++); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), in); - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - - HVX_Vector vec_left = *(HVX_UVector *) srcf; - HVX_Vector vec_tmp = Q6_V_valign_VVR(vec_left, zero_vec, left_over * SIZEOF_FP32); - // sum_vec = Q6_Vqf32_vadd_Vqf32Vsf(sum_vec, vec_tmp); - sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), vec_tmp); - } - - HVX_Vector v = hvx_vec_qf32_reduce_sum(sum_vec); - return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); -} - -float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if (0 == htp_is_aligned((void *) src, VLEN)) { - FARF(HIGH, "hvx_self_max_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_self_max_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector vec_max = hvx_vec_splat_fp32(((const float *) src)[0]); - HVX_Vector vec_first = hvx_vec_splat_fp32(((const float *) src)[0]); - - if (0 == unaligned_loop) { - HVX_Vector * restrict vec_in = (HVX_Vector *) src; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, *vec_in++); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, in); - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - HVX_Vector temp = Q6_V_valign_VVR(in, vec_first, left_over * SIZEOF_FP32); - vec_max = Q6_Vsf_vmax_VsfVsf(vec_max, temp); - } - - HVX_Vector v = hvx_vec_reduce_max_fp32(vec_max); - return hvx_vec_get_fp32(v); -} - -void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - int unalign_address = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_min_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); - unalign_address = 1; - } - - const float * src_f = (const float *) src; - - HVX_Vector vec_min = hvx_vec_splat_fp32(val); - - if(unalign_address == 0){ - HVX_Vector * restrict vec_in = (HVX_Vector *) src; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++); - *vec_out++ = (min_clamp); - } - }else{ - HVX_UVector * restrict vec_in = (HVX_Vector *) src; - HVX_UVector * restrict vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, *vec_in++); - *vec_out++ = (min_clamp); - } - } - - if (left_over > 0 ) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_UVector in = *(HVX_UVector *) srcf; - - HVX_UVector min_clamp = Q6_Vsf_vmin_VsfVsf(vec_min, in); - - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, (min_clamp)); - } -} - -void hvx_clamp_scalar_f32(const uint8_t * restrict src, - const float limit_left, - const float limit_right, - uint8_t * restrict dst, - const int num_elems) { - size_t left_over = num_elems & (VLEN_FP32 - 1); - size_t num_elems_whole = num_elems - left_over; - - int unalign_address = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_clamp_scalar_f32: unaligned address in hvx op, possibly slower execution\n"); - unalign_address = 1; - } - - HVX_Vector range_left = hvx_vec_splat_fp32(limit_left); - HVX_Vector range_right = hvx_vec_splat_fp32(limit_right); - - if(unalign_address == 0){ - HVX_Vector * restrict vec_in = (HVX_Vector *) src; - HVX_Vector * restrict vec_out = (HVX_Vector *) dst; - - - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in_vec = *vec_in++; - HVX_Vector temp_v = in_vec; - - HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right); - HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec); - - in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec); - - *vec_out++ = in_vec; - } - - }else{ - - HVX_UVector * restrict vec_in = (HVX_UVector *) src; - HVX_UVector * restrict vec_out = (HVX_UVector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in_vec = *vec_in++; - HVX_Vector temp_v = in_vec; - - HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right); - HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec); - - in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec); - - *vec_out++ = in_vec; - } - - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in_vec = *(HVX_UVector *) srcf; - - HVX_Vector temp_v = in_vec; - - HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, range_right); - HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(range_left, in_vec); - - in_vec = Q6_V_vmux_QVV(pred_cap_right, range_right, temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, range_left, in_vec); - - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec); - } -} - - diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 22876e6dbaa..23373f73ae2 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -1,1353 +1,23 @@ #ifndef HVX_UTILS_H #define HVX_UTILS_H -#include "ops-utils.h" - -#include <stdbool.h> -#include <stdint.h> - -#define SIZEOF_FP32 (4) -#define SIZEOF_FP16 (2) -#define VLEN (128) -#define VLEN_FP32 (VLEN / SIZEOF_FP32) -#define VLEN_FP16 (VLEN / SIZEOF_FP16) - -typedef union { - HVX_Vector v; - uint8_t b[VLEN]; - uint16_t h[VLEN_FP16]; - uint32_t w[VLEN_FP32]; - __fp16 fp16[VLEN_FP16]; - float fp32[VLEN_FP32]; -} __attribute__((aligned(VLEN), packed)) HVX_VectorAlias; - -/* Q6_Vsf_equals_Vw is only available on v73+.*/ -#if __HVX_ARCH__ < 73 -static inline HVX_Vector int32_to_qfloat(HVX_Vector const in) -{ - HVX_Vector const vzero = Q6_V_vzero(); - HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero); - HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in); - HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift); - HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift); - HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized); - HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp)); - return ret; -} - -static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) -{ - return Q6_Vsf_equals_Vqf32(int32_to_qfloat(in)); -} -#endif - -static inline HVX_Vector hvx_vec_splat_fp32(float v) { - union { - float f; - uint32_t i; - } fp32 = { .f = v }; - - return Q6_V_vsplat_R(fp32.i); -} - -static inline HVX_Vector hvx_vec_splat_fp16(float v) { - union { - __fp16 f; - uint16_t i; - } fp16 = { .f = v }; - - return Q6_Vh_vsplat_R(fp16.i); -} - -static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) { - // Rotate as needed. - v = Q6_V_vlalign_VVR(v, v, (size_t) addr); - - uint32_t left_off = (size_t) addr & 127; - uint32_t right_off = left_off + n; - - HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) addr); - HVX_VectorPred qr = Q6_Q_vsetq2_R(right_off); - - if (right_off > 128) { - Q6_vmem_QRIV(qr, (HVX_Vector *) addr + 1, v); - // all 1's - qr = Q6_Q_vcmp_eq_VbVb(v, v); - } - - ql_not = Q6_Q_or_QQn(ql_not, qr); - Q6_vmem_QnRIV(ql_not, (HVX_Vector *) addr, v); -} - -static inline void hvx_vec_store_a(void * ptr, size_t n, HVX_Vector v) { - assert((unsigned long) ptr % 128 == 0); - - HVX_VectorPred ql_not = Q6_Q_vsetq_R((size_t) ptr); - HVX_VectorPred qr = Q6_Q_vsetq2_R(n); - ql_not = Q6_Q_or_QQn(ql_not, qr); - Q6_vmem_QnRIV(ql_not, (HVX_Vector *) ptr, v); -} - -static inline HVX_Vector hvx_vec_repl4(HVX_Vector v) { - // vdelta control to replicate first 4 bytes across all elements - static const uint8_t __attribute__((aligned(128))) repl[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, - }; - - HVX_Vector ctrl = *(HVX_Vector *) repl; - return Q6_V_vdelta_VV(v, ctrl); -} - -// copy n fp16 elements : source and destination are aligned to HVX Vector (128) -static inline void hvx_copy_fp16_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - HVX_Vector * restrict vsrc = (HVX_Vector *) src; - - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v); - } -} - -// copy n fp16 elements : source is aligned, destination is potentially unaligned -static inline void hvx_copy_fp16_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_UVector * restrict vdst = (HVX_UVector *) dst; - HVX_Vector * restrict vsrc = (HVX_Vector *) src; - - assert((unsigned long) src % 128 == 0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v); - } -} - -// copy n fp16 elements : source is aligned, destination is potentially unaligned -static inline void hvx_copy_fp16_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - HVX_UVector * restrict vsrc = (HVX_UVector *) src; - - assert((unsigned long) dst % 128 == 0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), v); - } -} - -// copy n fp32 elements : source and destination are aligned to HVX Vector (128) -static inline void hvx_copy_fp32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - HVX_Vector * restrict vsrc = (HVX_Vector *) src; - - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - - uint32_t nvec = n / 32; - uint32_t nloe = n % 32; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); - } -} - -// copy n fp32 elements : source is aligned, destination is unaligned -static inline void hvx_copy_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_UVector * restrict vdst = (HVX_UVector *) dst; - HVX_Vector * restrict vsrc = (HVX_Vector *) src; - - assert((unsigned long) src % 128 == 0); - - uint32_t nvec = n / 32; - uint32_t nloe = n % 32; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); - } -} - -// copy n fp32 elements : source is unaligned, destination is aligned -static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - HVX_UVector * restrict vsrc = (HVX_UVector *) src; - - assert((unsigned long) dst % 128 == 0); - - uint32_t nvec = n / 32; - uint32_t nloe = n % 32; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); - } -} - -// copy n fp32 elements : source is unaligned, destination unaligned -static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_UVector * restrict vdst = (HVX_UVector *) dst; - HVX_UVector * restrict vsrc = (HVX_UVector *) src; - - assert((unsigned long) dst % 128 == 0); - - uint32_t nvec = n / 32; - uint32_t nloe = n % 32; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - HVX_Vector v = vsrc[i]; - vdst[i] = v; - } - - if (nloe) { - HVX_Vector v = vsrc[i]; - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); - } -} - -// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned -static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 - HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 - - const HVX_Vector zero = Q6_V_vsplat_R(0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - vdst[i] = Q6_Vh_vdeal_Vh(s_hf); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); - } -} - -// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned -static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 - HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32 - - const HVX_Vector zero = Q6_V_vsplat_R(0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - vdst[i] = Q6_Vh_vdeal_Vh(s_hf); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); - } -} - -// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned -static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16 - HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 - - const HVX_Vector zero = Q6_V_vsplat_R(0); - - uint32_t nvec = n / 64; - uint32_t nloe = n % 64; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - vdst[i] = Q6_Vh_vdeal_Vh(s_hf); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements - HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements - HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); - } -} - -// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned -static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) { - HVX_Vector * restrict vdst = (HVX_Vector *) dst; - - HVX_Vector velem = hvx_vec_splat_fp32(elem); - - assert((unsigned long) dst % 128 == 0); - - uint32_t nvec = n / 32; - uint32_t nloe = n % 32; - - uint32_t i = 0; - - #pragma unroll(4) - for (; i < nvec; i++) { - vdst[i] = velem; - } - - if (nloe) { - hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), velem); - } -} - - -/* Return whether 'n' elements from vector are in the one chunk of 'chunk_size'. */ -static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { - uint32_t left_off = (size_t) addr & (chunk_size - 1); - uint32_t right_off = left_off + n; - return right_off <= chunk_size; -} - -static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) { - HVX_VectorAlias u = { .v = v }; - - const uint32_t n0 = n / 16; - const uint32_t n1 = n % 16; - int i = 0; - for (; i < n0; i++) { - htp_dump_fp16_line(pref, u.fp16 + (16 * i), 16); - } - if (n1) { - htp_dump_fp16_line(pref, u.fp16 + (16 * i), n1); - } -} - -static void hvx_vec_dump_fp16(char * pref, HVX_Vector v) { - hvx_vec_dump_fp16_n(pref, v, 64); -} - -static void hvx_vec_dump_fp32_n(char * pref, HVX_Vector v, uint32_t n) { - union { - HVX_Vector v; - float d[32]; - } u = { .v = v }; - - const uint32_t n0 = n / 16; - const uint32_t n1 = n % 16; - int i = 0; - for (; i < n0; i++) { - htp_dump_fp32_line(pref, u.d + (16 * i), 16); - } - if (n1) { - htp_dump_fp32_line(pref, u.d + (16 * i), n1); - } -} - -static void hvx_vec_dump_fp32_hmt(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - float d[32]; - } u = { .v = v }; - - FARF(HIGH, "%s: %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f ... %.6f %.6f %.6f %.6f\n", pref, u.d[0], u.d[1], - u.d[2], u.d[3], u.d[12], u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); -} - -static void hvx_vec_dump_fp32(char * pref, HVX_Vector v) { - hvx_vec_dump_fp32_n(pref, v, 32); -} - -static void hvx_vec_dump_int32(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - int32_t d[32]; - } u = { .v = v }; - - for (int i = 0; i < 32 / 16; i++) { - htp_dump_int32_line(pref, u.d + (16 * i), 16); - } -} - -static void hvx_vec_dump_int32_hmt(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - int32_t d[32]; - } u = { .v = v }; - - FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[12], - u.d[13], u.d[14], u.d[15], u.d[28], u.d[29], u.d[30], u.d[31]); -} - -static void hvx_vec_dump_int8_hmt(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - int8_t d[128]; - } u = { .v = v }; - - FARF(HIGH, "%s: %d %d %d %d ... %d %d %d %d ... %d %d %d %d\n", pref, u.d[0], u.d[1], u.d[2], u.d[3], u.d[60], - u.d[61], u.d[62], u.d[63], u.d[124], u.d[125], u.d[126], u.d[127]); -} - -static void hvx_vec_dump_int8(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - int8_t d[128]; - } u = { .v = v }; - - for (int i = 0; i < 128 / 16; i++) { - htp_dump_int8_line(pref, u.d + (16 * i), 16); - } -} - -static void hvx_vec_dump_uint8(char * pref, HVX_Vector v) { - union { - HVX_Vector v; - uint8_t d[128]; - } u = { .v = v }; - - for (int i = 0; i < 128 / 16; i++) { - htp_dump_uint8_line(pref, u.d + (16 * i), 16); - } -} - -static bool hvx_vec_eq(HVX_Vector v0, HVX_Vector v1, size_t n) { - typedef union { - HVX_Vector v; - int8_t d[128]; - } U; - - U u0 = { .v = v0 }; - U u1 = { .v = v1 }; - - for (int i = 0; i < n; i++) { - if (u0.d[i] != u1.d[i]) { - return false; - } - } - - return true; -} - -static inline float hvx_vec_get_fp32(HVX_Vector v) { - float __attribute__((aligned(128))) x; - hvx_vec_store_a(&x, 4, v); - return x; -} - -static inline HVX_Vector hvx_vec_int32_reduce_sum_n(HVX_Vector in, unsigned int n) { - unsigned int total = n * 4; // total vec nbytes - unsigned int width = 4; // int32 - - HVX_Vector sum = in, sum_t; - while (width < total) { - sum_t = Q6_V_vror_VR(sum, width); // rotate right - sum = Q6_Vw_vadd_VwVw(sum_t, sum); // elementwise sum - width = width << 1; - } - return sum; -} - -static inline HVX_Vector hvx_vec_int32_reduce_sum(HVX_Vector in) { - return hvx_vec_int32_reduce_sum_n(in, 32); -} - -static inline HVX_Vector hvx_vec_qf32_reduce_sum_n(HVX_Vector in, unsigned int n) { - unsigned int total = n * 4; // total vec nbytes - unsigned int width = 4; // fp32 nbytes - - HVX_Vector sum = in, sum_t; - while (width < total) { - sum_t = Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum), width); // rotate right - sum = Q6_Vqf32_vadd_Vqf32Vsf(sum, sum_t); // elementwise sum - width = width << 1; - } - return sum; -} - -static inline HVX_Vector hvx_vec_qf32_reduce_sum(HVX_Vector in) { - return hvx_vec_qf32_reduce_sum_n(in, 32); -} - -static inline HVX_Vector hvx_vec_fp32_reduce_sum_n(HVX_Vector in, unsigned int n) { - unsigned int total = n * 4; // total vec nbytes - unsigned int width = 4; // fp32 nbytes - - HVX_Vector sum = in, sum_t; - while (width < total) { - sum_t = Q6_V_vror_VR(sum, width); // rotate right - sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(sum, sum_t)); // elementwise sum - width = width << 1; - } - return sum; -} - -static inline HVX_Vector hvx_vec_fp32_reduce_sum(HVX_Vector in) { - return hvx_vec_fp32_reduce_sum_n(in, 32); -} - -static inline HVX_Vector hvx_vec_reduce_max_fp16(HVX_Vector in) { - unsigned total = 128; // total vec nbytes - unsigned width = 2; // fp16 nbytes - - HVX_Vector _max = in, _max_t; - while (width < total) { - _max_t = Q6_V_vror_VR(_max, width); // rotate right - _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max - width = width << 1; - } - - return _max; -} - -static inline HVX_Vector hvx_vec_reduce_max2_fp16(HVX_Vector in, HVX_Vector _max) { - unsigned total = 128; // total vec nbytes - unsigned width = 2; // fp32 nbytes - - HVX_Vector _max_t; - - _max = Q6_Vhf_vmax_VhfVhf(in, _max); - while (width < total) { - _max_t = Q6_V_vror_VR(_max, width); // rotate right - _max = Q6_Vhf_vmax_VhfVhf(_max_t, _max); // elementwise max - width = width << 1; - } - - return _max; -} - -static inline HVX_Vector hvx_vec_reduce_max_fp32(HVX_Vector in) { - unsigned total = 128; // total vec nbytes - unsigned width = 4; // fp32 nbytes - - HVX_Vector _max = in, _max_t; - while (width < total) { - _max_t = Q6_V_vror_VR(_max, width); // rotate right - _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max - width = width << 1; - } - - return _max; -} - -static inline HVX_Vector hvx_vec_reduce_max2_fp32(HVX_Vector in, HVX_Vector _max) { - unsigned total = 128; // total vec nbytes - unsigned width = 4; // fp32 nbytes - - HVX_Vector _max_t; - - _max = Q6_Vsf_vmax_VsfVsf(in, _max); - while (width < total) { - _max_t = Q6_V_vror_VR(_max, width); // rotate right - _max = Q6_Vsf_vmax_VsfVsf(_max_t, _max); // elementwise max - width = width << 1; - } - - return _max; -} - -static inline HVX_Vector hvx_vec_abs_fp16(HVX_Vector v) { - // abs by clearing the fp16 sign bit - HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); - return Q6_V_vand_VV(v, mask); -} - -static inline HVX_Vector hvx_vec_neg_fp16(HVX_Vector v) { - // neg by setting the fp16 sign bit - HVX_Vector mask = Q6_Vh_vsplat_R(0x8000); - return Q6_V_vxor_VV(v, mask); -} - -static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) { - // abs by clearing the fp32 sign bit - HVX_Vector mask = Q6_V_vsplat_R(0x7fffffff); - return Q6_V_vand_VV(v, mask); -} - -static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) { -#if __HVX_ARCH__ > 75 - return Q6_Vsf_vfneg_Vsf(v); -#else - // neg by setting the fp32 sign bit - HVX_Vector mask = Q6_V_vsplat_R(0x80000000); - return Q6_V_vxor_VV(v, mask); -#endif // __HVX_ARCH__ > 75 -} - -// ==================================================== -// FUNCTION: 1/(x+1) y(0) = 1, y(0.5) = 0.6667, y(1) = 0.5 -// Order:3; continuity: True; Ends forced: True -// Mode: unsigned; Result fractional bits: 14 -// Peak Error: 1.1295e-04 Rms Error: 2.8410e-05 Mean Error: 1.1370e-05 -// 32769 -32706 31252 -10589 -// 32590 -30635 22793 -4493 -// 32066 -27505 16481 -2348 -// 31205 -24054 11849 -1306 - -static inline HVX_Vector hvx_vec_recip_xp1_O3_unsigned(HVX_Vector vx) { - // input is 0..0xffff representing 0.0 .. 1.0 - HVX_Vector p; - p = Q6_Vh_vlut4_VuhPh(vx, 0xFAE6F6D4EE73D6A3ull); - p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x2E49406159097A14ull); - p = Q6_Vh_vmps_VhVhVuhPuh_sat(p, vx, 0x5DF66B7177AB7FC2ull); - p = Q6_Vh_vmpa_VhVhVuhPuh_sat(p, vx, 0x79E57D427F4E8001ull); - return p; // signed result, 14 fractional bits -} - -// Find reciprocal of fp16. -// (1) first, convert to fp32, multiplying by 1.0; this is done to -// handle denormals. Ignoring sign and zero, result should be at -// least 5.9604645e-08 (32-bit code 0x33800000) and at most 131008 (0x47ffe000) -// (exponent in range [103,143]) -// (2) extract the mantissa into 16-bit unsigned; find reciprocal using a fitted poly -// (3) put this, along with '253-exp' (exp from (1)) together to make an qf32 -// (4) convert that to fp16 -// (5) put sign back in. Also, if the original value (w/o sign) was <0x81, replace -// the result with the max value. -static inline HVX_Vector hvx_vec_inverse_fp16(HVX_Vector vals) { - HVX_Vector em_mask = Q6_Vh_vsplat_R(0x7FFF); - HVX_Vector avals = Q6_V_vand_VV(vals, em_mask); - HVX_VectorPred is_neg = Q6_Q_vcmp_gt_VhVh(avals, vals); - // is too small to 1/x ? for 'standard' fp16, this would be 0x101 - HVX_VectorPred is_small = Q6_Q_vcmp_gt_VhVh(Q6_Vh_vsplat_R(0x101), avals); - - HVX_VectorPair to_qf32 = Q6_Wqf32_vmpy_VhfVhf(avals, Q6_Vh_vsplat_R(0x3C00)); // *1.0 - HVX_Vector to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(to_qf32)); - HVX_Vector to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(to_qf32)); - - // bits 22..13 contain the mantissa now (w/o hidden bit); move to bit 14..5 of a 16-bit vector - HVX_Vector mant_u16 = Q6_Vh_vshuffo_VhVh(Q6_Vw_vasl_VwR(to_f32_1, 9), Q6_Vw_vasl_VwR(to_f32_0, 9)); - // likewise extract the upper 16 from each, containing the exponents in range 103..142 - HVX_Vector exp_u16 = Q6_Vh_vshuffo_VhVh(to_f32_1, to_f32_0); - //Get exponent in IEEE 32-bit representation - exp_u16 = Q6_Vuh_vlsr_VuhR(exp_u16, 7); - - // so, mant_u16 contains an unbiased mantissa in upper 10 bits of each u16 lane - // We can consider it to be x-1.0, with 16 fractional bits, where 'x' is in range [1.0,2.0) - // Use poly to transform to 1/x, with 14 fractional bits - // - HVX_Vector rm = hvx_vec_recip_xp1_O3_unsigned(mant_u16); - - HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros - - // Get mantissa for 16-bit represenation - HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF)); - - //Compute Reciprocal Exponent - HVX_Vector exp_recip = - Q6_Vh_vsub_VhVh(Q6_Vh_vsub_VhVh(Q6_Vh_vsplat_R(254), exp_u16), Q6_Vh_vsub_VhVh(vcl0, Q6_Vh_vsplat_R(1))); - //Convert it for 16-bit representation - exp_recip = Q6_Vh_vadd_VhVh_sat(Q6_Vh_vsub_VhVh(exp_recip, Q6_Vh_vsplat_R(127)), Q6_Vh_vsplat_R(15)); - exp_recip = Q6_Vh_vasl_VhR(exp_recip, 10); - - //Merge exponent and mantissa for reciprocal - HVX_Vector recip = Q6_V_vor_VV(exp_recip, mant_recip); - // map 'small' inputs to standard largest value 0x7bff - recip = Q6_V_vmux_QVV(is_small, Q6_Vh_vsplat_R(0x7bff), recip); - // add sign back - recip = Q6_V_vandor_VQR(recip, is_neg, 0x80008000); - return recip; -} - -#define IEEE_VSF_EXPLEN (8) -#define IEEE_VSF_EXPBIAS (127) -#define IEEE_VSF_EXPMASK (0xFF) -#define IEEE_VSF_MANTLEN (23) -#define IEEE_VSF_MANTMASK (0x7FFFFF) -#define IEEE_VSF_MIMPMASK (0x800000) - -static inline HVX_Vector hvx_vec_truncate_fp32(HVX_Vector in_vec) { - HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); - HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); - HVX_Vector const_zero_v = Q6_V_vzero(); - - HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); - - HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; - expval_v &= IEEE_VSF_EXPMASK; - expval_v -= IEEE_VSF_EXPBIAS; - - // negative exp == fractional value - HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); - - HVX_Vector rshift_v = IEEE_VSF_MANTLEN - expval_v; // fractional bits - exp shift - - HVX_Vector mant_v = in_vec & mask_mant_v; // obtain mantissa - HVX_Vector vout = Q6_Vw_vadd_VwVw(mant_v, mask_impl_v); // add implicit 1.0 - - vout = Q6_Vw_vasr_VwVw(vout, rshift_v); // shift to obtain truncated integer - vout = Q6_V_vmux_QVV(q_negexp, const_zero_v, vout); // expval<0 -> 0 - - HVX_Vector neg_vout = -vout; - - vout = Q6_V_vmux_QVV(q_negative, neg_vout, vout); // handle negatives - - return (vout); -} - -static inline HVX_Vector hvx_vec_floor_fp32(HVX_Vector in_vec) { - HVX_Vector mask_mant_v = Q6_V_vsplat_R(IEEE_VSF_MANTMASK); - HVX_Vector mask_impl_v = Q6_V_vsplat_R(IEEE_VSF_MIMPMASK); - HVX_Vector const_mnlen_v = Q6_V_vsplat_R(IEEE_VSF_MANTLEN); - HVX_Vector const_zero_v = Q6_V_vzero(); - HVX_Vector const_negone_v = Q6_V_vsplat_R(0xbf800000); // -1 IEEE vsf - - HVX_VectorPred q_negative = Q6_Q_vcmp_gt_VwVw(const_zero_v, in_vec); - - HVX_Vector expval_v = in_vec >> IEEE_VSF_MANTLEN; - expval_v &= IEEE_VSF_EXPMASK; - expval_v -= IEEE_VSF_EXPBIAS; - - HVX_VectorPred q_negexp = Q6_Q_vcmp_gt_VwVw(const_zero_v, expval_v); - HVX_VectorPred q_expltmn = Q6_Q_vcmp_gt_VwVw(const_mnlen_v, expval_v); - HVX_VectorPred q_negexp_pos = Q6_Q_vcmp_gtand_QVwVw(q_negexp, in_vec, const_zero_v); - HVX_VectorPred q_negexp_neg = Q6_Q_vcmp_gtand_QVwVw(q_negexp, const_zero_v, in_vec); - - // if expval < 0 (q_negexp) // <0, floor is 0 - // if vin > 0 - // floor = 0 - // if vin < 0 - // floor = -1 - // if expval < mant_len (q_expltmn) // >0, but fraction may exist - // get sign (q_negative) - // mask >> expval // fraction bits to mask off - // vout = ~(mask) // apply mask to remove fraction - // if (qneg) // negative floor is one less (more, sign bit for neg) - // vout += ((impl_mask) >> expval) - // if (mask && vin) - // vout = vin - // else // already an integer - // ; // no change - - // compute floor - mask_mant_v >>= expval_v; - HVX_Vector neg_addin_v = mask_impl_v >> expval_v; - HVX_Vector vout_neg_addin = Q6_Vw_vadd_VwVw(in_vec, neg_addin_v); - HVX_Vector vout = Q6_V_vmux_QVV(q_negative, vout_neg_addin, in_vec); - - HVX_Vector mask_chk_v = Q6_V_vand_VV(in_vec, mask_mant_v); // chk if bits set - HVX_VectorPred q_integral = Q6_Q_vcmp_eq_VwVw(const_zero_v, mask_chk_v); - - HVX_Vector not_mask_v = Q6_V_vnot_V(mask_mant_v); // frac bits to clear - HVX_Vector vfrfloor_v = Q6_V_vand_VV(vout, not_mask_v); // clear frac bits - - vout = in_vec; - vout = Q6_V_vmux_QVV(q_expltmn, vfrfloor_v, vout); // expval<mant - vout = Q6_V_vmux_QVV(q_integral, in_vec, vout); // integral values - vout = Q6_V_vmux_QVV(q_negexp_pos, const_zero_v, vout); // expval<0 x>0 -> 0 - vout = Q6_V_vmux_QVV(q_negexp_neg, const_negone_v, vout); // expval<0 x<0 -> -1 - - return vout; -} - -static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { - // This looks complicated. - // Ideally should just be Q6_Vh_equals_Vhf(vin) - // but that instruction does not do proper rounding. - - // convert to qf32, multiplying by 1.0 in the process. - HVX_VectorPair v32 = Q6_Wqf32_vmpy_VhfVhf(vin, Q6_Vh_vsplat_R(0x3C00)); - - // 'in-range' values are +/32752. - // add 192K to it, convert to sf - HVX_Vector v192K = Q6_V_vsplat_R(0x48400000); - HVX_Vector vsf_0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(v32), v192K)); - HVX_Vector vsf_1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(v32), v192K)); - - // for in-range cases, result is {163858... 229360} so the exponent is always 144. - // if we extract bits 21..0 as a signed quantity, and round 6 bits off, that will be the answer. - // Start by <<10 to get the final 'sign' bit in bit 15... - vsf_0 = Q6_Vw_vasl_VwR(vsf_0, 10); - vsf_1 = Q6_Vw_vasl_VwR(vsf_1, 10); - - // now round down to 16 - return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0); -} - -static inline HVX_Vector hvx_vec_inverse_fp32(HVX_Vector v_sf) { - HVX_Vector inv_aprox_sf = Q6_V_vsplat_R(0x7EEEEBB3); - HVX_Vector two_sf = hvx_vec_splat_fp32(2.0); - - // First approximation - HVX_Vector i_sf = Q6_Vw_vsub_VwVw(inv_aprox_sf, v_sf); - - HVX_Vector r_qf; - - // Refine - r_qf = Q6_Vqf32_vmpy_VsfVsf( - i_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(i_sf, v_sf))))); - r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( - r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); - r_qf = Q6_Vqf32_vmpy_Vqf32Vqf32( - r_qf, Q6_Vqf32_vsub_VsfVsf(two_sf, Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(r_qf), v_sf)))); - - return Q6_Vsf_equals_Vqf32(r_qf); -} - -#define FAST_SIGMOID_LOG2F (0x3fb8aa3b) // 1.442695022 -#define FAST_SIGMOID_C1 (0x3d009076) // 0.03138777 -#define FAST_SIGMOID_C2 (0x3e8d74bd) // 0.276281267 -#define FAST_SIGMOID_C3 (0x3f000000) // 0.5 - -static inline HVX_Vector hvx_vec_fast_sigmoid_fp32(HVX_Vector v) { - v = Q6_Vqf32_vmpy_VsfVsf(v, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); - v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v), Q6_V_vsplat_R(FAST_SIGMOID_C3)); - - HVX_Vector in_int = hvx_vec_truncate_fp32(Q6_Vsf_equals_Vqf32(v)); - HVX_Vector x = Q6_Vqf32_vsub_Vqf32Vsf(v, Q6_Vsf_equals_Vw(in_int)); - HVX_Vector xx = Q6_Vqf32_vmpy_Vqf32Vqf32(x, x); - - HVX_Vector v1 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(xx), Q6_V_vsplat_R(FAST_SIGMOID_C2)); - v1 = Q6_Vqf32_vadd_Vqf32Vsf(v1, Q6_V_vsplat_R(FAST_SIGMOID_LOG2F)); - - HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(x), Q6_V_vsplat_R(FAST_SIGMOID_C1)); - v2 = Q6_Vqf32_vmpy_Vqf32Vqf32(v2, xx); - v2 = Q6_Vqf32_vadd_Vqf32Vqf32(v2, x); - - HVX_Vector v3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32(v2, v1)); - HVX_Vector v3_exponent = Q6_Vw_vasl_VwR(v3, 1); - v3_exponent = Q6_Vuw_vlsr_VuwR(v3_exponent, 24); - v3_exponent = Q6_Vw_vadd_VwVw(in_int, v3_exponent); - v3 = Q6_Vw_vaslacc_VwVwR(v3, in_int, 24); - - HVX_Vector v4 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32(v2, v1)); - HVX_Vector v5 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(v3, v4)); - - HVX_Vector res = hvx_vec_inverse_fp32(v5); - res = Q6_Vqf32_vmpy_VsfVsf(v3, res); - - return Q6_Vsf_equals_Vqf32(res); -} - -#define EXP_COEFF_5 (0x39506967) // 0.000198757 = 1/(7!) -#define EXP_COEFF_4 (0x3AB743CE) // 0.0013982 = 1/(6!) -#define EXP_COEFF_3 (0x3C088908) // 0.00833345 = 1/(5!) -#define EXP_COEFF_2 (0x3D2AA9C1) // 0.416658 = 1/(4!) -#define EXP_COEFF_1 (0x3E2AAAAA) // 0.16666667 = 1/(3!) -#define EXP_COEFF_0 (0x3F000000) // 0.5 = 1/(2!) -#define EXP_LOGN2 (0x3F317218) // ln(2) = 0.6931471805 -#define EXP_LOG2E (0x3FB8AA3B) // log2(e) = 1/ln(2) = 1.4426950408 -#define EXP_ONE (0x3f800000) // 1.0 -#define EXP_RANGE_R (0x41a00000) // 20.0 -#define EXP_RANGE_L (0xc1a00000) // -20.0 - -static inline HVX_Vector hvx_vec_exp_fp32(HVX_Vector in_vec) { - HVX_Vector z_qf32_v; - HVX_Vector x_v; - HVX_Vector x_qf32_v; - HVX_Vector y_v; - HVX_Vector k_v; - HVX_Vector f_v; - HVX_Vector epsilon_v; - HVX_Vector log2e = Q6_V_vsplat_R(EXP_LOG2E); - HVX_Vector logn2 = Q6_V_vsplat_R(EXP_LOGN2); - HVX_Vector E_const; - HVX_Vector zero_v = Q6_V_vzero(); - - // exp(x) is approximated as follows: - // f = floor(x/ln(2)) = floor(x*log2(e)) - // epsilon = x - f*ln(2) - // exp(x) = exp(epsilon+f*ln(2)) - // = exp(epsilon)*exp(f*ln(2)) - // = exp(epsilon)*2^f - // - // Since epsilon is close to zero, it can be approximated with its Taylor series: - // exp(x) ~= 1+x+x^2/2!+x^3/3!+...+x^n/n!+... - // Preserving the first eight elements, we get: - // exp(x) ~= 1+x+e0*x^2+e1*x^3+e2*x^4+e3*x^5+e4*x^6+e5*x^7 - // = 1+x+(E0+(E1+(E2+(E3+(E4+E5*x)*x)*x)*x)*x)*x^2 - - HVX_Vector temp_v = in_vec; - - // Clamp inputs to (-20.0, 20.0) - HVX_VectorPred pred_cap_right = Q6_Q_vcmp_gt_VsfVsf(in_vec, Q6_V_vsplat_R(EXP_RANGE_R)); - HVX_VectorPred pred_cap_left = Q6_Q_vcmp_gt_VsfVsf(Q6_V_vsplat_R(EXP_RANGE_L), in_vec); - - in_vec = Q6_V_vmux_QVV(pred_cap_right, Q6_V_vsplat_R(EXP_RANGE_R), temp_v); - in_vec = Q6_V_vmux_QVV(pred_cap_left, Q6_V_vsplat_R(EXP_RANGE_L), temp_v); - - epsilon_v = Q6_Vqf32_vmpy_VsfVsf(log2e, in_vec); - epsilon_v = Q6_Vsf_equals_Vqf32(epsilon_v); - - // f_v is the floating point result and k_v is the integer result - f_v = hvx_vec_floor_fp32(epsilon_v); - k_v = hvx_vec_truncate_fp32(f_v); - - x_qf32_v = Q6_Vqf32_vadd_VsfVsf(in_vec, zero_v); - - // x = x - f_v * logn2; - epsilon_v = Q6_Vqf32_vmpy_VsfVsf(f_v, logn2); - x_qf32_v = Q6_Vqf32_vsub_Vqf32Vqf32(x_qf32_v, epsilon_v); - // normalize before every QFloat's vmpy - x_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(x_qf32_v, zero_v); - - // z = x * x; - z_qf32_v = Q6_Vqf32_vmpy_Vqf32Vqf32(x_qf32_v, x_qf32_v); - z_qf32_v = Q6_Vqf32_vadd_Vqf32Vsf(z_qf32_v, zero_v); - - x_v = Q6_Vsf_equals_Vqf32(x_qf32_v); - - // y = E4 + E5 * x; - E_const = Q6_V_vsplat_R(EXP_COEFF_5); - y_v = Q6_Vqf32_vmpy_VsfVsf(E_const, x_v); - E_const = Q6_V_vsplat_R(EXP_COEFF_4); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = E3 + y * x; - E_const = Q6_V_vsplat_R(EXP_COEFF_3); - y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = E2 + y * x; - E_const = Q6_V_vsplat_R(EXP_COEFF_2); - y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = E1 + y * x; - E_const = Q6_V_vsplat_R(EXP_COEFF_1); - y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = E0 + y * x; - E_const = Q6_V_vsplat_R(EXP_COEFF_0); - y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, x_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, E_const); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = x + y * z; - y_v = Q6_Vqf32_vmpy_Vqf32Vqf32(y_v, z_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vqf32(y_v, x_qf32_v); - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, zero_v); - - // y = y + 1.0; - y_v = Q6_Vqf32_vadd_Vqf32Vsf(y_v, Q6_V_vsplat_R(EXP_ONE)); - - // insert exponents - // y = ldexpf(y, k); - // y_v += k_v; // qf32 - // modify exponent - - y_v = Q6_Vsf_equals_Vqf32(y_v); - - // add k_v to the exponent of y_v - HVX_Vector y_v_exponent = Q6_Vw_vasl_VwR(y_v, 1); - - y_v_exponent = Q6_Vuw_vlsr_VuwR(y_v_exponent, IEEE_VSF_MANTLEN + 1); - y_v_exponent = Q6_Vw_vadd_VwVw(k_v, y_v_exponent); - - // exponent cannot be negative; if overflow is detected, result is set to zero - HVX_VectorPred qy_v_negative_exponent = Q6_Q_vcmp_gt_VwVw(zero_v, y_v_exponent); - - y_v = Q6_Vw_vaslacc_VwVwR(y_v, k_v, IEEE_VSF_MANTLEN); - - y_v = Q6_V_vmux_QVV(qy_v_negative_exponent, zero_v, y_v); - - return y_v; -} - -#define RSQRT_CONST 0x5f3759df // Constant for fast inverse square root calculation -#define RSQRT_ONE_HALF 0x3f000000 // 0.5 -#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5 - -static inline HVX_Vector hvx_vec_rsqrt_fp32(HVX_Vector in_vec) { - //Algorithm : - // x2 = input*0.5 - // y = * (long *) &input - // y = 0x5f3759df - (y>>2) - // y = y*(threehalfs - x2*y*y) - - HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST); - HVX_Vector onehalf = Q6_V_vsplat_R(RSQRT_ONE_HALF); - HVX_Vector threehalfs = Q6_V_vsplat_R(RSQRT_THREE_HALVES); - - HVX_Vector x2, y, ypower2, temp; - - x2 = Q6_Vqf32_vmpy_VsfVsf(in_vec, onehalf); - x2 = Q6_Vqf32_vadd_Vqf32Vsf(x2, Q6_V_vzero()); - - y = Q6_Vw_vasr_VwR(in_vec, 1); - y = Q6_Vw_vsub_VwVw(rsqrtconst, y); - - // 1st iteration - ypower2 = Q6_Vqf32_vmpy_VsfVsf(y, y); - ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); - temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); - temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); - temp = Q6_Vqf32_vmpy_VsfVsf(y, Q6_Vsf_equals_Vqf32(temp)); - - // 2nd iteration - y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); - ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); - ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); - temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); - temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); - temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); - - // 3rd iteration - y = Q6_Vqf32_vadd_Vqf32Vsf(temp, Q6_V_vzero()); - ypower2 = Q6_Vqf32_vmpy_Vqf32Vqf32(y, y); - ypower2 = Q6_Vqf32_vadd_Vqf32Vsf(ypower2, Q6_V_vzero()); - temp = Q6_Vqf32_vmpy_Vqf32Vqf32(x2, ypower2); - temp = Q6_Vqf32_vsub_VsfVsf(threehalfs, Q6_Vsf_equals_Vqf32(temp)); - temp = Q6_Vqf32_vmpy_Vqf32Vqf32(y, temp); - - return Q6_Vsf_equals_Vqf32(temp); -} - -static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v, - HVX_Vector one, - HVX_Vector max_exp, - HVX_Vector min_exp) { - const HVX_VectorPred pred_max = Q6_Q_vcmp_gt_VsfVsf(max_exp, v); - const HVX_VectorPred pred_min = Q6_Q_vcmp_gt_VsfVsf(v, min_exp); - - HVX_Vector out = hvx_vec_fast_sigmoid_fp32(v); - out = Q6_V_vmux_QVV(pred_max, out, one); - return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); -} - -static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) { - // tanh(x) = 2 * sigmoid(2x) - 1 - HVX_Vector two = hvx_vec_splat_fp32(2.0f); - HVX_Vector one = hvx_vec_splat_fp32(1.0f); - HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two); - - static const float kMinExp = -87.f; // 0 - static const float kMaxExp = 87.f; // 1 - HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); - HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); - - HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp); - - HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two); - res = Q6_Vqf32_vsub_Vqf32Vsf(res, one); - return Q6_Vsf_equals_Vqf32(res); -} - -static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { - int step_of_1 = num_elems >> 5; - int remaining = num_elems - step_of_1 * VLEN_FP32; - - const HVX_Vector * restrict v_src = (HVX_Vector *) src; - HVX_Vector * restrict v_dst = (HVX_Vector *) dst; - - static const float kMinExp = -87.f; // 0 - static const float kMaxExp = 87.f; // 1 - - const HVX_Vector one = hvx_vec_splat_fp32(1.f); - const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); - const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); - - #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { - v_dst[i] = hvx_vec_fast_sigmoid_fp32_guard(v_src[i], one, max_exp, min_exp); - } - - if (remaining > 0) { - const float * srcf = ((const float *) src) + step_of_1* VLEN_FP32; - float * dstf = (float *) dst + step_of_1*VLEN_FP32; - - HVX_Vector in = *(HVX_UVector *) srcf; - HVX_Vector out = hvx_vec_fast_sigmoid_fp32_guard(in, one, max_exp, min_exp); - hvx_vec_store_u((void *) dstf, remaining * SIZEOF_FP32, out); - } -} - -static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems){ - int step_of_1 = num_elems >> 5; // divby 32, because 32 float = 128 bytes per HVX vector - int leftover = num_elems - (step_of_1 * VLEN_FP32); - - int32_t leftover_size = leftover * sizeof(float); - - static const float kMinExp = -87.f; // 0 - static const float kMaxExp = 87.f; // 1 - - const HVX_Vector one = hvx_vec_splat_fp32(1.f); - const HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); - const HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); - - const float *input = (float *)src; - float *output = (float *)dst; - - HVX_Vector * input_v_ptr = (HVX_Vector *) input; - HVX_UVector * output_v_ptr = (HVX_UVector *) output; - - HVX_Vector slinep; - HVX_Vector slinec; - HVX_Vector sline; - - slinep = *input_v_ptr++; - #pragma unroll(4) - for (int i = step_of_1 - 1; i > 0; i--) { - slinec = *input_v_ptr++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input); - *((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp); - /* Prepare slinep for next iteration */ - slinep = slinec; - } - - if (step_of_1 > 0) { - slinec = htp_is_aligned(input_v_ptr, 128) && leftover == 0 ? slinep : *input_v_ptr++; - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input); - *((HVX_UVector *) (output_v_ptr++)) = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp); - ; - - slinep = slinec; - } - if (leftover > 0) { - slinec = (is_in_one_chunk(input_v_ptr, leftover_size, 128) ? slinep : *input_v_ptr++); - - sline = Q6_V_valign_VVR(slinec, slinep, (size_t) input); - - HVX_Vector sout = hvx_vec_fast_sigmoid_fp32_guard(sline, one, max_exp, min_exp); - hvx_vec_store_u(output_v_ptr, leftover_size, sout); - } -} - -static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { - int nvec = n / VLEN_FP32; - int nloe = n % VLEN_FP32; - - HVX_Vector vs = hvx_vec_splat_fp32(scale); - - HVX_Vector * vsrc = (HVX_Vector *) src; - HVX_Vector * vdst = (HVX_Vector *) dst; - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; ++i) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - vdst[i] = Q6_Vsf_equals_Vqf32(v); - } - - if (nloe) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); - } -} - -static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { - int nvec = n / VLEN_FP32; - int nloe = n % VLEN_FP32; - - HVX_Vector vs = hvx_vec_splat_fp32(scale); - - HVX_UVector * vsrc = (HVX_UVector *) src; - HVX_UVector * vdst = (HVX_UVector *) dst; - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; ++i) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - vdst[i] = Q6_Vsf_equals_Vqf32(v); - } - - if (nloe) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); - hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); - } -} - -static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { - if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { - hvx_scale_f32_aa(dst, src, n, scale); - } else { - hvx_scale_f32_uu(dst, src, n, scale); - } -} - -static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { - int nvec = n / VLEN_FP32; - int nloe = n % VLEN_FP32; - - HVX_Vector vs = hvx_vec_splat_fp32(scale); - HVX_Vector vo = hvx_vec_splat_fp32(offset); - - HVX_Vector * vsrc = (HVX_Vector *) src; - HVX_Vector * vdst = (HVX_Vector *) dst; - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; ++i) { - HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); - vdst[i] = Q6_Vsf_equals_Vqf32(v); - } - - if (nloe) { - HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); - hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); - } -} - -static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { - int nvec = n / VLEN_FP32; - int nloe = n % VLEN_FP32; - - HVX_Vector vs = hvx_vec_splat_fp32(scale); - HVX_Vector vo = hvx_vec_splat_fp32(offset); - - HVX_UVector * vsrc = (HVX_UVector *) src; - HVX_UVector * vdst = (HVX_UVector *) dst; - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; ++i) { - HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); - vdst[i] = Q6_Vsf_equals_Vqf32(v); - } - - if (nloe) { - HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); - hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); - } -} - -static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { - if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { - hvx_scale_offset_f32_aa(dst, src, n, scale, offset); - } else { - hvx_scale_offset_f32_uu(dst, src, n, scale, offset); - } -} - -float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems); -void hvx_mul_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_mul_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_mul_mul_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - const uint8_t * restrict src2, - uint8_t * restrict dst, - const int num_elems); -void hvx_mul_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_add_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_add_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_add_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_sub_f32(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_sub_f32_opt(const uint8_t * restrict src0, - const uint8_t * restrict src1, - uint8_t * restrict dst, - const int num_elems); -void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); -void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); -void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate); -float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems); -float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems); -void hvx_min_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_clamp_scalar_f32(const uint8_t * restrict src, - const float limit_left, - const float limit_right, - uint8_t * restrict dst, - const int num_elems); +#include "hex-utils.h" + +#include "hvx-types.h" +#include "hvx-copy.h" +#include "hvx-repl.h" +#include "hvx-scale.h" +#include "hvx-exp.h" +#include "hvx-inverse.h" +#include "hvx-reduce.h" +#include "hvx-sigmoid.h" +#include "hvx-sqrt.h" +#include "hvx-arith.h" +#include "hvx-div.h" +#include "hvx-floor.h" +#include "hvx-sin-cos.h" +#include "hvx-base.h" +#include "hvx-pow.h" +#include "hvx-log.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 24b3e90e4b6..3715227d2c7 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -1,31 +1,34 @@ #pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" #pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" -#define FARF_ERROR 1 -#define FARF_HIGH 1 -#define FARF_MEDIUM 0 -#define FARF_LOW 0 +#include <HAP_farf.h> +#include <HAP_perf.h> #include <AEEStdErr.h> #include <dspqueue.h> #include <HAP_compute_res.h> #include <HAP_etm_config.h> -#include <HAP_farf.h> #include <HAP_mem.h> -#include <HAP_perf.h> #include <HAP_power.h> #include <HAP_ps.h> +#include <HAP_dcvs.h> #include <qurt.h> #include <qurt_thread.h> +#include <qurt_memory.h> #include <remote.h> #include <string.h> +#include "hex-utils.h" +#include "hex-dma.h" +#include "hmx-queue.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" #include "htp-ops.h" -#include "ops-utils.h" +#include "htp-ops.h" +#include "htp_iface.h" #include "worker-pool.h" AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { @@ -37,7 +40,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { return AEE_ENOMEMORY; } - // Use the context structure as a handle + // Use the context structure as the handle *handle = (remote_handle64) ctx; // Enable FARF logs @@ -61,8 +64,7 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.type = HAP_power_set_DCVS_v3; request.dcvs_v3.set_dcvs_enable = TRUE; - request.dcvs_v3.dcvs_enable = TRUE; - request.dcvs_v3.dcvs_option = HAP_DCVS_V2_PERFORMANCE_MODE; + request.dcvs_v3.dcvs_enable = FALSE; request.dcvs_v3.set_bus_params = TRUE; request.dcvs_v3.bus_params.min_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.bus_params.max_corner = HAP_DCVS_VCORNER_MAX; @@ -73,6 +75,10 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.dcvs_v3.core_params.target_corner = HAP_DCVS_VCORNER_MAX; request.dcvs_v3.set_sleep_disable = TRUE; request.dcvs_v3.sleep_disable = TRUE; + +#if (__HEXAGON_ARCH__ >= 79) + HAP_set_dcvs_v3_protected_bus_corners(&request, 1); +#endif if ((err = HAP_power_set((void *) ctx, &request)) != 0) { return err; } @@ -85,6 +91,27 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { } } +#if __HVX_ARCH__ >= 75 + { + // Power on HMX and set HMX clock + HAP_power_request_t request; + memset(&request, 0, sizeof(HAP_power_request_t)); + request.type = HAP_power_set_HMX_v2; + request.hmx_v2.set_power = TRUE; + request.hmx_v2.power_up = TRUE; + request.hmx_v2.set_clock = TRUE; + request.hmx_v2.target_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.min_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.max_corner = HAP_DCVS_EXP_VCORNER_MAX; + request.hmx_v2.perf_mode = HAP_CLK_PERF_HIGH; + FARF(ALWAYS, "Setting HMX clock\n"); + err = HAP_power_set((void *) ctx, &request); + if (err != AEE_SUCCESS) { + FARF(ERROR, "ggml-hex: error setting HMX clock."); + return err; + } + } +#else { // Power on HMX HAP_power_request_t request; @@ -92,12 +119,61 @@ AEEResult htp_iface_open(const char * uri, remote_handle64 * handle) { request.type = HAP_power_set_HMX; request.hmx.power_up = TRUE; FARF(ALWAYS, "Powering HMX on\n"); - err = HAP_power_set((void *) &ctx, &request); + err = HAP_power_set((void *) ctx, &request); if (err != AEE_SUCCESS) { - FARF(ERROR, "Error powering on HMX."); + FARF(ERROR, "ggml-hex: error powering on HMX."); return err; } } +#endif + + return AEE_SUCCESS; +} + +AEEResult htp_iface_etm(remote_handle64 handle, uint32_t enable) { + int err = enable ? HAP_user_etm_enable() : HAP_user_etm_disable(); + if (err) { + if (err == AEE_EVERSIONNOTSUPPORT) { + FARF(ERROR, "API HAP_user_etm_enable/disable is not supported\n"); + } else { + FARF(ERROR, "Error executing HAP_user_etm_enable/disable with error code : 0x%x\n", err); + } + } + return err; +} + +AEEResult htp_iface_profiler(remote_handle64 handle, uint32_t mode, const htp_iface_pmu_conf* pmu_conf) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + if (mode == HTP_PROF_PMU) { + const uint32_t* events = pmu_conf->events; + + // Pack 4 event IDs (low 8 bits) into each 32-bit config register + uint32_t evtcfg = 0, evtcfg1 = 0, cfg = 0, i = 0; + for (; i < HEX_NUM_PMU_COUNTERS/2; i++) { + evtcfg |= ((events[i + 0] & 0xFF) << (i * 8)); + evtcfg1 |= ((events[i + 4] & 0xFF) << (i * 8)); + } + + // For events >255 pack high 2 bits of all 8 event IDs into cfg register + // 2 bits per counter: bits [1:0] for counter 0, [3:2] for counter 1, etc. + for (i = 0; i < HEX_NUM_PMU_COUNTERS; i++) { + cfg |= (((events[i] >> 8) & 3) << (i * 2)); + } + + FARF(ALWAYS, "Configuring PMU registers: evtcfg = 0x%x, evtcfg1 = 0x%x, pmucfg = 0x%x", evtcfg, evtcfg1, cfg); + + // Configure PMU registers + qurt_pmu_set(QURT_PMUCFG, cfg); + qurt_pmu_set(QURT_PMUEVTCFG, evtcfg); + qurt_pmu_set(QURT_PMUEVTCFG1, evtcfg1); + qurt_pmu_enable(1); + } + + ctx->profiler = mode; return AEE_SUCCESS; } @@ -114,91 +190,128 @@ AEEResult htp_iface_close(remote_handle64 handle) { return AEE_EITEMBUSY; } + // release the mmaps (if any) + for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) { + if (ctx->mmap[i].size) { +#if __HVX_ARCH__ > 73 + HAP_munmap2((void *) ctx->mmap[i].base, ctx->mmap[i].size); +#else + HAP_munmap((void *) ctx->mmap[i].base, ctx->mmap[i].size); +#endif + ctx->mmap[i].size = 0; + ctx->mmap[i].base = NULL; + ctx->mmap[i].fd = -1; + } + } + + if (ctx->profiler) { + qurt_pmu_enable(1); + } + + if (ctx->etm) { + HAP_user_etm_disable(); + } + free(ctx); return AEE_SUCCESS; } -AEEResult htp_iface_enable_etm(remote_handle64 handle) { - int err = HAP_user_etm_enable(); - if (err) { - if (err == AEE_EVERSIONNOTSUPPORT) { - FARF(ERROR, "API HAP_user_etm_enable is not supported\n"); - } else { - FARF(ERROR, "Error executing HAP_user_etm_enable with error code : 0x%x\n", err); +AEEResult htp_iface_mmap(remote_handle64 handle, uint32_t fd, uint32_t size) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + // See if we already have this mapping + for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) { + struct htp_mmap *m = &ctx->mmap[i]; + if (m->fd == fd) { + return AEE_SUCCESS; } } - return err; + + // Add new mapping + for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) { + struct htp_mmap *m = &ctx->mmap[i]; + if (!m->size) { + FARF(HIGH, "mmap : fd %u size %u", fd, size); +#if __HVX_ARCH__ > 73 + void *va = HAP_mmap2(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); +#else + if (size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB + FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) size); + abort(); // can't do much else at this point + } + + void *va = HAP_mmap(NULL, size, HAP_PROT_READ | HAP_PROT_WRITE, 0, fd, 0); +#endif + if (va == (void*)-1) { + FARF(ERROR, "mmap failed : va %p fd %u size %u", va, fd, (uint32_t) size); + return AEE_EFAILED; + } + + m->base = (uint64_t) va; + m->fd = fd; + m->size = size; + + return AEE_SUCCESS; + } + } + + return AEE_ENOMEMORY; } -AEEResult htp_iface_disable_etm(remote_handle64 handle) { - int err = HAP_user_etm_disable(); - if (err) { - if (err == AEE_EVERSIONNOTSUPPORT) { - FARF(ERROR, "API HAP_user_etm_disable is not supported\n"); - } else { - FARF(ERROR, "Error executing HAP_user_etm_disable with error code : 0x%x\n", err); +AEEResult htp_iface_munmap(remote_handle64 handle, uint32 fd) { + struct htp_context * ctx = (struct htp_context *) handle; + if (!ctx) { + return AEE_EBADPARM; + } + + for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) { + struct htp_mmap *m = &ctx->mmap[i]; + if (fd < 0 || m->fd == fd) { + FARF(HIGH, "unmmap : base %p fd %u size %u", (void*) m->base, m->fd, (uint32_t) m->size); +#if __HVX_ARCH__ > 73 + HAP_munmap2((void *) m->base, m->size); +#else + HAP_munmap((void *) m->base, m->size); +#endif + m->size = 0; + m->base = NULL; + m->fd = -1; } } - return err; + + return AEE_SUCCESS; } -static int vtcm_acquire(struct htp_context * ctx) { - int err; +static void vtcm_acquire(struct htp_context * ctx) { if (!ctx->vtcm_valid) { - // Temporarily bump thread priority to make sure it's higher than other sessions. - // This way the resource manager will notify the other thread to release VTCM. - // Note that we need to reaquire VTCM at normal priority for this to work next time. - qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio - 10); - err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); + int err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000u); if (err != 0) { - FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); + FARF(ERROR, "ggml-hex: failed to acquire VTCM: 0x%08x", (unsigned)err); abort(); } - HAP_compute_res_release_cached(ctx->vtcm_rctx); - qurt_thread_set_priority(qurt_thread_get_id(), ctx->thread_prio); - err = HAP_compute_res_acquire_cached(ctx->vtcm_rctx, 1000000); - if (err != 0) { - FARF(ERROR, "Failed to acquire VTCM: 0x%08x", (unsigned)err); - abort(); - } + ctx->vtcm_needs_release = false; ctx->vtcm_valid = true; - } - ctx->vtcm_inuse = true; - return 0; + // Drop the priority to make sure we get the release callback from other GGML-HTP and QNN-HTP sessions + HAP_compute_res_update_priority(ctx->vtcm_rctx, ctx->thread_prio + 10); + } } -static int vtcm_release(struct htp_context * ctx) { - ctx->vtcm_inuse = false; - - if (ctx->vtcm_valid && ctx->vtcm_needs_release) { +static void vtcm_release(struct htp_context * ctx) { + if (ctx->vtcm_valid) { ctx->vtcm_valid = false; ctx->vtcm_needs_release = false; HAP_compute_res_release_cached(ctx->vtcm_rctx); } - - return 0; } static int vtcm_release_callback(unsigned int rctx, void * state) { struct htp_context * ctx = (struct htp_context *) state; - - if (!ctx || ctx->vtcm_rctx != rctx) { - return AEE_EBADPARM; - } - - // If VTCM is not inuse (not processing Ops) release it right here - // otherwise we'll release it once we're done with the current Op. - - if (ctx->vtcm_inuse) { - ctx->vtcm_needs_release = false; - return 0; - } - - ctx->vtcm_valid = false; - HAP_compute_res_release_cached(ctx->vtcm_rctx); - + ctx->vtcm_needs_release = true; return 0; } @@ -210,7 +323,7 @@ static int vtcm_alloc(struct htp_context * ctx) { HAP_compute_res_attr_init(&attr); HAP_compute_res_attr_set_serialize(&attr, 0); HAP_compute_res_attr_set_cache_mode(&attr, 1); - HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, 0, vtcm_size); + HAP_compute_res_attr_set_vtcm_param_v2(&attr, vtcm_size, vtcm_size, vtcm_size); // single page HAP_compute_res_attr_set_release_callback(&attr, vtcm_release_callback, (void *) ctx); HAP_compute_res_attr_set_hmx_param(&attr, 1); @@ -232,7 +345,6 @@ static int vtcm_alloc(struct htp_context * ctx) { ctx->vtcm_size = vtcm_size; ctx->vtcm_rctx = rctx; ctx->vtcm_valid = false; - ctx->vtcm_inuse = false; ctx->vtcm_needs_release = false; return 0; @@ -249,7 +361,7 @@ static void vtcm_free(struct htp_context * ctx) { static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context); -AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx) { +AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { @@ -267,12 +379,12 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que htp_error_callback, // Error callback; no errors expected on the DSP (void *) ctx, // Callback context &ctx->queue); - if (err) { FARF(ERROR, "Queue import failed with 0x%08x", (unsigned) err); return err; } + ctx->max_vmem = max_vmem; ctx->thread_id = qurt_thread_get_id(); ctx->thread_prio = qurt_thread_get_priority(ctx->thread_id); @@ -283,6 +395,19 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que return AEE_ENOMEMORY; } +#ifdef HTP_HAS_HMX + ctx->hmx_enabled = use_hmx; + ctx->hmx_queue = NULL; + if (use_hmx) { + ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx); + if (!ctx->hmx_queue) { + FARF(ERROR, "hmx-queue-create failed"); + ctx->hmx_enabled = false; + } + } + FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx); +#endif + qurt_sysenv_max_hthreads_t hw_threads; qurt_sysenv_get_max_hw_threads(&hw_threads); uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; @@ -299,14 +424,21 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->n_threads = n_hvx; for (int i = 0; i < ctx->n_threads; i++) { - // see discussion https://github.com/ggml-org/llama.cpp/pull/18151#discussion_r2632388541 - ctx->dma[i] = dma_queue_create(64); + ctx->dma[i] = dma_queue_create(256); // queue depth } + ctx->ddr_spad_size = 512 * 1024; // 512 KB + ctx->ddr_spad_base = memalign(128, ctx->ddr_spad_size); + // init worker pool err = worker_pool_init(&ctx->worker_pool, n_hvx); if (err != AEE_SUCCESS) { FARF(ERROR, "Unable to create worker pool"); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } return err; } @@ -344,8 +476,22 @@ AEEResult htp_iface_stop(remote_handle64 handle) { dma_queue_delete(ctx->dma[i]); } +#ifdef HTP_HAS_HMX + if (ctx->hmx_queue) { + hmx_queue_delete(ctx->hmx_queue); + ctx->hmx_queue = NULL; + } + ctx->hmx_enabled = false; +#endif + vtcm_free(ctx); + if (ctx->ddr_spad_base) { + free(ctx->ddr_spad_base); + ctx->ddr_spad_base = NULL; + ctx->ddr_spad_size = 0; + } + return AEE_SUCCESS; } @@ -357,645 +503,411 @@ static void htp_error_callback(dspqueue_t queue, int error, void * context) { struct profile_data { uint64_t usecs; uint64_t cycles; - uint64_t pkts; + uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS]; }; -static inline void profile_start(struct profile_data * d) { - d->usecs = HAP_perf_get_qtimer_count(); - d->cycles = htp_get_cycles(); - d->pkts = htp_get_pktcnt(); +static inline void profile_start(uint32_t mode, struct profile_data * d) { + switch (mode) { + case HTP_PROF_PMU: + hex_get_pmu(d->pmu_counters); + // fallthrough + case HTP_PROF_BASIC: + d->usecs = HAP_perf_get_qtimer_count(); + d->cycles = hex_get_cycles(); + break; + default: + break; + } } -static inline void profile_stop(struct profile_data * d) { - d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); - d->cycles = htp_get_cycles() - d->cycles; - d->pkts = htp_get_pktcnt() - d->pkts; +static inline void profile_stop(uint32_t mode, struct profile_data * d) { + uint32_t pmu_counters[HEX_NUM_PMU_COUNTERS]; + switch (mode) { + case HTP_PROF_PMU: + hex_get_pmu(pmu_counters); + for (int i = 0; i < HEX_NUM_PMU_COUNTERS; i++) { + d->pmu_counters[i] = pmu_counters[i] - d->pmu_counters[i]; + } + // fallthrough + case HTP_PROF_BASIC: + d->usecs = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - d->usecs); + d->cycles = hex_get_cycles() - d->cycles; + break; + default: + break; + } } -static int send_htp_rsp(struct htp_context * c, - uint32_t op, - uint32_t status, - struct dspqueue_buffer * bufs, - size_t n_bufs, - struct profile_data * prof) { - // Prep response struct - struct htp_general_rsp rsp; - rsp.op = op; - rsp.status = status; - rsp.prof_usecs = prof->usecs; - rsp.prof_cycles = prof->cycles; - rsp.prof_pkts = prof->pkts; - - int err = dspqueue_write(c->queue, - 0, // Flags - n_bufs, - bufs, // Buffer references - sizeof(rsp), - (const uint8_t *) &rsp, // Message - DSPQUEUE_TIMEOUT_NONE); +static int execute_op(struct htp_ops_context * octx) { + switch (octx->op) { + case HTP_OP_MUL_MAT: + return op_matmul(octx); - if (err != 0) { - FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); - } + case HTP_OP_MUL_MAT_ID: + return op_matmul_id(octx); - return err; -} + case HTP_OP_MUL: + case HTP_OP_ADD: + case HTP_OP_SUB: + case HTP_OP_DIV: + case HTP_OP_ADD_ID: + return op_binary(octx); -static void proc_matmul_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - size_t n_bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_matmul(&octx); - vtcm_release(ctx); - } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_NORM: + case HTP_OP_RMS_NORM: + case HTP_OP_RMS_NORM_MUL: + case HTP_OP_SCALE: + case HTP_OP_SQR: + case HTP_OP_SQRT: + case HTP_OP_UNARY_SOFTPLUS: + case HTP_OP_UNARY_SIGMOID: + case HTP_OP_UNARY_NEG: + case HTP_OP_UNARY_EXP: + case HTP_OP_UNARY_TANH: + case HTP_OP_L2_NORM: + return op_unary(octx); -static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_get_rows(&octx); - vtcm_release(ctx); - } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_UNARY_SILU: + case HTP_OP_UNARY_GELU: + case HTP_OP_GLU_SWIGLU: + case HTP_OP_GLU_SWIGLU_OAI: + case HTP_OP_GLU_GEGLU: + return op_activations(octx); -static void proc_matmul_id_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - size_t n_bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[3].fd; - rsp_bufs[0].ptr = bufs[3].ptr; - rsp_bufs[0].size = bufs[3].size; - rsp_bufs[0].offset = bufs[3].offset; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_matmul_id(&octx); - vtcm_release(ctx); - } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_SOFTMAX: + return op_softmax(octx); -static void proc_binary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_binary(&octx); - vtcm_release(ctx); - } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_ROPE: + return op_rope(octx); -static void proc_add_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[3].fd; - rsp_bufs[0].ptr = bufs[3].ptr; - rsp_bufs[0].offset = bufs[3].offset; - rsp_bufs[0].size = bufs[3].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_binary(&octx); - vtcm_release(ctx); - } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); -} + case HTP_OP_FLASH_ATTN_EXT: + return op_flash_attn_ext(octx); + + case HTP_OP_SET_ROWS: + return op_set_rows(octx); + + case HTP_OP_GET_ROWS: + return op_get_rows(octx); + + case HTP_OP_SUM_ROWS: + return op_sum_rows(octx); + + case HTP_OP_CPY: + return op_cpy(octx); + + case HTP_OP_REPEAT: + return op_repeat(octx); -static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + case HTP_OP_ARGSORT: + return op_argsort(octx); - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[1].fd; - rsp_bufs[0].ptr = bufs[1].ptr; - rsp_bufs[0].offset = bufs[1].offset; - rsp_bufs[0].size = bufs[1].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + case HTP_OP_SSM_CONV: + return op_ssm_conv(octx); - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; + case HTP_OP_CUMSUM: + return op_cumsum(octx); - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + case HTP_OP_FILL: + return op_fill(octx); - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.dst.data = (uint32_t) bufs[1].ptr; - octx.n_threads = ctx->n_threads; + case HTP_OP_DIAG: + return op_diag(octx); - struct profile_data prof; - profile_start(&prof); + case HTP_OP_SOLVE_TRI: + return op_solve_tri(octx); - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_unary(&octx); - vtcm_release(ctx); + case HTP_OP_PAD: + return op_pad(octx); + + case HTP_OP_CONCAT: + return op_concat(octx); + + case HTP_OP_GATED_DELTA_NET: + return op_gated_delta_net(octx); + + case HTP_OP_TRI: + return op_tri(octx); + + case HTP_OP_INVALID: + break; + + // No default to catch missing cases } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + FARF(ERROR, "Unknown Op %u", octx->op); + return -1; } -static void proc_activations_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - int write_idx = (n_bufs == 3) ? 2 : 1; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[write_idx].fd; - rsp_bufs[0].ptr = bufs[write_idx].ptr; - rsp_bufs[0].offset = bufs[write_idx].offset; - rsp_bufs[0].size = bufs[write_idx].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - if (3 == n_bufs) { - octx.src1 = req->src1; - } - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - if (3 == n_bufs) { - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - } else { - octx.dst.data = (uint32_t) bufs[1].ptr; - } - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - if (octx.op == HTP_OP_SOFTMAX) { - rsp_status = op_softmax(&octx); - } else { - rsp_status = op_activations(&octx); +static inline bool reuse_buf(struct htp_context *ctx, uint32_t *m_reuse, struct htp_buf_desc *b) { + b->base = NULL; + + for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) { + struct htp_mmap *m = ctx->mmap + i; + if (m->size && m->fd == b->fd) { + b->base = m->base; + *m_reuse |= (1 << i); + return true; } - vtcm_release(ctx); } - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); + return false; } -static void proc_rope_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - - int write_idx = n_bufs - 1; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[write_idx].fd; - rsp_bufs[0].ptr = bufs[write_idx].ptr; - rsp_bufs[0].offset = bufs[write_idx].offset; - rsp_bufs[0].size = bufs[write_idx].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - if (4 == n_bufs) { - octx.src2 = req->src2; - } - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - if (4 == n_bufs) { - octx.src2.data = (uint32_t) bufs[2].ptr; - octx.dst.data = (uint32_t) bufs[3].ptr; - } else { - octx.dst.data = (uint32_t) bufs[2].ptr; - } - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_rope(&octx); - vtcm_release(ctx); - } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +static inline void drop_mmap(struct htp_context *ctx, struct htp_mmap *m) { + if (m->size) { + FARF(HIGH, "unmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size); +#if __HVX_ARCH__ > 73 + HAP_munmap2((void *) m->base, m->size); +#else + HAP_munmap((void *) m->base, m->size); +#endif + m->size = 0; + m->base = 0; + m->fd = -1; + } } -static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { - struct dspqueue_buffer rsp_bufs[1]; - - // We had written to the output buffer, we'd also need to flush it - rsp_bufs[0].fd = bufs[2].fd; - rsp_bufs[0].ptr = bufs[2].ptr; - rsp_bufs[0].offset = bufs[2].offset; - rsp_bufs[0].size = bufs[2].size; - rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU - - // Setup Op context - struct htp_ops_context octx = { 0 }; - octx.ctx = ctx; - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; - - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.dst.data = (uint32_t) bufs[2].ptr; - octx.n_threads = ctx->n_threads; - - struct profile_data prof; - profile_start(&prof); - - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_set_rows(&octx); - vtcm_release(ctx); - } - - profile_stop(&prof); - send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +static inline void mmap_buf(struct htp_context *ctx, struct htp_buf_desc *b) { + if (b->base) return; // already mapped + + // find unused mapping + for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { + struct htp_mmap *m = &ctx->mmap[i]; + if (!m->size) { +#if __HVX_ARCH__ > 73 + void *va = HAP_mmap2(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0); +#else + if (b->size > HTP_MMAP_MAX_VMEM) { // HAP_mmap has a size limit of 2GB + FARF(ERROR, "mmap failed : size %u exceeds 2GB limit for HAP_mmap", (uint32_t) b->size); + abort(); // can't do much else at this point + } + + void *va = HAP_mmap(NULL, b->size, HAP_PROT_READ | HAP_PROT_WRITE, 0, b->fd, 0); +#endif + if (va == (void*)-1) { + FARF(ERROR, "mmap failed : va %p fd %u size %u", va, b->fd, (uint32_t) b->size); + abort(); // can't do much else at this point + } + + m->base = b->base = (uint64_t) va; + m->fd = b->fd; + m->size = b->size; + + FARF(HIGH, "mmap : fd %u base %p size %u", m->fd, (void*) m->base, (uint32_t) m->size); + return; + } + } } -static void proc_flash_attn_ext_req(struct htp_context * ctx, - struct htp_general_req * req, - struct dspqueue_buffer * bufs, - uint32_t n_bufs) { - // Setup Op context - struct htp_ops_context octx; - memset(&octx, 0, sizeof(octx)); +static void prep_op_bufs(struct htp_context *ctx, struct htp_buf_desc *bufs, uint32_t n_bufs) { + uint32_t m_reuse = 0; // mmap reuse mask (index from ctx->mmap array) + uint32_t b_reuse = 0; // buf reuse count - octx.ctx = ctx; - octx.n_threads = ctx->n_threads; + uint64_t m_vmem = 0; // mapped vmem + uint64_t e_vmem = 0; // extra vmem - octx.src0 = req->src0; - octx.src1 = req->src1; - octx.src2 = req->src2; - octx.src3 = req->src3; - octx.src4 = req->src4; - octx.dst = req->dst; - octx.flags = req->flags; - octx.op = req->op; + // See what we can reuse + for (uint32_t i=0; i < n_bufs; i++) { + struct htp_buf_desc *b = bufs + i; + if (reuse_buf(ctx, &m_reuse, b)) { b_reuse++; } else { e_vmem += b->size; } + FARF(HIGH, "prep-buf #%u : pass0 fd %u base %p size %u flags 0x%x", i, b->fd, (void*) b->base, (uint32_t) b->size, b->flags); + } - memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + if (b_reuse == n_bufs) return; // all bufs reuse existing mappings - // Update data pointers - octx.src0.data = (uint32_t) bufs[0].ptr; - octx.src1.data = (uint32_t) bufs[1].ptr; - octx.src2.data = (uint32_t) bufs[2].ptr; + // See how much vmem we have mmaped right now + for (uint32_t i=0; i<HTP_MAX_MMAPS; i++) { m_vmem += ctx->mmap[i].size; } - int last_buf = 3; + FARF(HIGH, "prep-bufs : pass1 mmap-vmem %zu extra-vmem %zu max-vmem %zu : n-bufs %u b-reuse %u", + (size_t) m_vmem, (size_t) e_vmem, (size_t) ctx->max_vmem, n_bufs, b_reuse); - if (octx.src3.ne[0]) { - octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid + if ((m_vmem + e_vmem) > ctx->max_vmem) { + // Drop unused mappings + for (uint32_t i=0; i < HTP_MAX_MMAPS; i++) { + bool used = m_reuse & (1<<i); + if (!used) { drop_mmap(ctx, ctx->mmap + i); } + } } - if (octx.src4.ne[0]) { - octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid + // Create missing mappings + for (uint32_t i=0; i < n_bufs; i++) { + struct htp_buf_desc *b = bufs + i; + mmap_buf(ctx, b); + FARF(HIGH, "prep-buf #%u : pass1 fd %u base %p size %u flags 0x%x", i, b->fd, (void*) b->base, (uint32_t) b->size, b->flags); } +} + +static void prep_tensor(struct htp_context *ctx, struct htp_buf_desc *bufs, uint32_t idx, struct htp_tensor *t) { + uint32_t offset = t->data; + uint32_t size = t->size; + uint32_t bi = t->bi; - octx.dst.data = (uint32_t) bufs[last_buf].ptr; + t->data = bufs[bi].base + offset; // update data to the actual pointer - struct profile_data prof; - profile_start(&prof); + FARF(HIGH, "prep-tensor #%u: bi %u offset %u size %u data %p : %u:%u:%u:%u", idx, t->bi, offset, t->size, (void*) t->data, + t->ne[0], t->ne[1], t->ne[3], t->ne[3]); +} - uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; - if (vtcm_acquire(ctx) == AEE_SUCCESS) { - rsp_status = op_flash_attn_ext(&octx); - vtcm_release(ctx); +static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, struct htp_tensor *tens, uint32_t n_tens) { + for (uint32_t i=0; i < n_tens; i++) { + prep_tensor(ctx, bufs, i, tens + i); } +} + +static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) { + memcpy(octx->op_params, op->params, sizeof(octx->op_params)); + octx->flags = op->flags; + octx->op = op->opcode; + + FARF(HIGH, "proc-op #%u: opcode %u flags 0x%x", idx, octx->op, octx->flags); - profile_stop(&prof); + // Prep input tensors + for (uint32_t i=0; i<HTP_OP_MAX_INPUTS; i++) { + struct htp_tensor *src = op->src[i] == 0xffff ? NULL : tens + op->src[i]; - struct dspqueue_buffer rsp_buf = bufs[last_buf]; - rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP - DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + octx->src[i] = src; + if (!src) continue; + + if (!(src->flags & HTP_TENSOR_FLUSHED) && (src->flags & HTP_TENSOR_COMPUTE)) { + // flush compute buffers on input + hex_l2flush((void *) src->data, src->size); + } - send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); + FARF(HIGH, "prep-src #%u: data %p size %u : %u:%u:%u:%u", op->src[i], (void*) src->data, src->size, + src->ne[0], src->ne[1], src->ne[3], src->ne[3]); + } + + // Prep output tensor + struct htp_tensor *dst = tens + op->dst; + + octx->dst = dst; + + FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); + + (void) execute_op(octx); + + // flush buffers on output + hex_l2flush((void *) dst->data, dst->size); + dst->flags |= HTP_TENSOR_FLUSHED; + + FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); } +#define DSPQUEUE_POLL_TIMEOUT_USEC 100 +#define DSPQUEUE_POLL_COUNT 100 + static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_context * ctx = (struct htp_context *) context; - // Repeatedly read packets from the queue until it's empty. We don't - // necessarily get a separate callback for each packet, and new packets - // may arrive while we're processing the previous one. This ensures we - // keep the DSP busy as much as possible and avoid waiting for the CPU. + int err; + + uint32_t poll_count = DSPQUEUE_POLL_COUNT; - while (1) { - struct htp_general_req req; - uint32_t req_size; + vtcm_acquire(ctx); - struct dspqueue_buffer bufs[HTP_MAX_PACKET_BUFFERS]; - uint32_t n_bufs; - uint32_t flags; + while (!ctx->vtcm_needs_release) { + struct htp_opbatch_req req; + uint32_t r_size = sizeof(req); - // Read packet from queue - int err = dspqueue_read_noblock(queue, &flags, - HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references - &n_bufs, // Number of buffer references - bufs, // Buffer references - sizeof(req), // Max message length - &req_size, // Message length - (uint8_t *) &req); // Message + struct dspqueue_buffer dbuf; + uint32_t n_dbufs = 1; + uint32_t flags = 0; + err = dspqueue_read_noblock(queue, &flags, n_dbufs, &n_dbufs, &dbuf, r_size, &r_size, (uint8_t *) &req); if (err == AEE_EWOULDBLOCK) { - // Consumed all packets available for now - return; + if (--poll_count) { + qurt_sleep(DSPQUEUE_POLL_TIMEOUT_USEC); + continue; + } + break; } if (err != 0) { FARF(ERROR, "dspqueue_read_noblock failed: 0x%08x", (unsigned) err); - return; + break; } - if (req_size != sizeof(req)) { - FARF(ERROR, "Invalid request size"); + if (r_size < sizeof(req) || n_dbufs != 1) { + FARF(ERROR, "invalid request : size %u n-dbufs %u", r_size, n_dbufs); continue; } - if (req.flags & HTP_OPFLAGS_EARLY_WAKEUP) { - // Host wants early notification - dspqueue_write_early_wakeup_noblock(ctx->queue, 10, 0); + // Reset poll count for valid requests + poll_count = DSPQUEUE_POLL_COUNT; + + const uint32_t n_bufs = req.n_bufs; + const uint32_t n_tens = req.n_tensors; + const uint32_t n_ops = req.n_ops; + + const uint32_t b_size = sizeof(struct htp_buf_desc) * n_bufs; + const uint32_t t_size = sizeof(struct htp_tensor) * n_tens; + const uint32_t o_size = sizeof(struct htp_op_desc) * n_ops; + const uint32_t p_size = sizeof(struct htp_prof_desc) * n_ops; + + if (dbuf.size < b_size + t_size + o_size + p_size) { + FARF(ERROR, "invalid opbatch memory block size %u", dbuf.size); + break; } - // Process packet based on its message type - switch (req.op) { - case HTP_OP_MUL_MAT: - if (n_bufs != 3) { - FARF(ERROR, "Bad matmul-req buffer list"); - continue; - } - proc_matmul_req(ctx, &req, bufs, n_bufs); - break; + FARF(HIGH, "processing opbatch #%u: n-bufs %u n-tensors %u n-ops %u : m-size %u b-size %u t-size %u o-size %u", req.id, + n_bufs, n_tens, n_ops, dbuf.size, b_size, t_size, o_size); - case HTP_OP_MUL_MAT_ID: - if (n_bufs != 4) { - FARF(ERROR, "Bad matmul-id-req buffer list"); - continue; - } - proc_matmul_id_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_MUL: - case HTP_OP_ADD: - case HTP_OP_SUB: - if (n_bufs != 3) { - FARF(ERROR, "Bad binary-req buffer list"); - continue; - } - proc_binary_req(ctx, &req, bufs); - break; - - case HTP_OP_RMS_NORM: - case HTP_OP_SCALE: - if (n_bufs != 2) { - FARF(ERROR, "Bad unary-req buffer list"); - continue; - } + // Setup descriptor pointers + uint8_t * m_ptr = dbuf.ptr; + struct htp_buf_desc* bufs = (struct htp_buf_desc*) m_ptr; m_ptr += b_size; + struct htp_tensor* tens = (struct htp_tensor*) m_ptr; m_ptr += t_size; + struct htp_op_desc* ops = (struct htp_op_desc*) m_ptr; m_ptr += o_size; + struct htp_prof_desc* pds = (struct htp_prof_desc*) m_ptr; - proc_unary_req(ctx, &req, bufs); - break; + prep_op_bufs(ctx, bufs, n_bufs); + prep_tensors(ctx, bufs, tens, n_tens); - case HTP_OP_UNARY_SILU: - case HTP_OP_UNARY_GELU: - if (n_bufs != 2) { - FARF(ERROR, "Bad act-req buffer list"); - continue; - } - proc_activations_req(ctx, &req, bufs, n_bufs); - break; - - case HTP_OP_GLU_SWIGLU: - case HTP_OP_GLU_SWIGLU_OAI: - case HTP_OP_SOFTMAX: - if ((n_bufs != 2) && (n_bufs != 3)) { - FARF(ERROR, "Bad act-req buffer list"); - continue; - } - proc_activations_req(ctx, &req, bufs, n_bufs); - break; + struct htp_ops_context *octx = &ctx->octx; + memset(octx, 0, sizeof(*octx)); + octx->n_threads = ctx->n_threads; + octx->ctx = ctx; - case HTP_OP_ADD_ID: - if (n_bufs != 4) { - FARF(ERROR, "Bad add-id-req buffer list"); - continue; - } - proc_add_id_req(ctx, &req, bufs); - break; + for (uint32_t i=0; i < n_ops; i++) { + struct profile_data prof; - case HTP_OP_ROPE: - if ((n_bufs != 3) && (n_bufs != 4)) { - FARF(ERROR, "Bad rope-req buffer list"); - continue; - } - proc_rope_req(ctx, &req, bufs, n_bufs); - break; + if (i == (n_ops-1)) { + // wake up the host before starting the last op + dspqueue_write_early_wakeup_noblock(queue, 0, 0); + } - case HTP_OP_FLASH_ATTN_EXT: - if (!(n_bufs >= 4 && n_bufs <= 6)) { - FARF(ERROR, "Bad flash-attn-ext-req buffer list"); - continue; - } - proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs); - break; + profile_start(ctx->profiler, &prof); - case HTP_OP_SET_ROWS: - if (n_bufs != 3) { - FARF(ERROR, "Bad set-rows-req buffer list"); - continue; - } - proc_set_rows_req(ctx, &req, bufs); - break; + proc_op_req(octx, tens, i, &ops[i]); + + profile_stop(ctx->profiler, &prof); - case HTP_OP_GET_ROWS: - if (n_bufs != 3) { - FARF(ERROR, "Bad get-rows-req buffer list"); - continue; + if (ctx->profiler) { + pds[i].opcode = ops[i].opcode; + pds[i].usecs = prof.usecs; + pds[i].cycles = prof.cycles; + for (int j = 0; j < HEX_NUM_PMU_COUNTERS; j++) { + pds[i].pmu[j] = prof.pmu_counters[j]; } - proc_get_rows_req(ctx, &req, bufs); - break; + } + } + + struct htp_opbatch_rsp rsp; + rsp.id = req.id; + rsp.status = HTP_STATUS_OK; + rsp.n_bufs = n_bufs; + rsp.n_tensors = n_tens; + rsp.n_ops = n_ops; + + dbuf.flags = DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT; - default: - FARF(ERROR, "Unknown Op %u", req.op); - break; + err = dspqueue_write(queue, 0, 1, &dbuf, sizeof(rsp), (const uint8_t *) &rsp, DSPQUEUE_TIMEOUT_NONE); + if (err != 0) { + FARF(ERROR, "dspqueue_write failed: 0x%08x", (unsigned) err); + break; } } + + vtcm_release(ctx); } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 9bb39db9fcb..5121c6f9bad 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -3,105 +3,61 @@ #pragma clang diagnostic ignored "-Wunused-variable" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif - #include <HAP_farf.h> -#include <HAP_mem.h> #include <HAP_perf.h> -#include <HAP_ps.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> + #include <math.h> -#include <qurt_thread.h> #include <string.h> +#include "hex-dma.h" +#include "hvx-utils.h" +#include "hvx-dump.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" +#include "htp-ops.h" +#include "hmx-ops.h" #define MM_SPAD_SRC0_NROWS 16 #define MM_SPAD_SRC1_NROWS 16 #define MM_SPAD_DST_NROWS 2 -struct htp_matmul_type { +struct htp_matmul_context { const char * type; - void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); - void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy); -}; - -typedef struct { - HVX_Vector v[2]; -} HVX_Vector_x2; - -typedef struct { - HVX_Vector v[4]; -} HVX_Vector_x4; - -typedef struct { - HVX_Vector v[8]; -} HVX_Vector_x8; - -// vdelta control to replicate first 4x fp32 values across lanes -static const uint8_t __attribute__((aligned(128))) repl_4x_fp32[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, - 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, - 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, - 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, - 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, - 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, -}; - -// vdelta control to replicate and interleave first 8x fp32 values across lanes -static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_fp32[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00, - 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, - 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, - 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, - 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44, - 0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, - 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, -}; - -// vdelta control to replicate first fp32 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_1x_fp32[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, - 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, - 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, - 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, - 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, - 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, - 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, -}; - -// vdelta control to replicate first fp16 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_1x_fp16[128] = { - 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, - 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, - 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, - 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, - 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, - 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, - 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, -}; - -// vdelta control to replicate first fp16 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_2x_fp16[128] = { - 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, + struct htp_ops_context * octx; + + void (*vec_dot_1x1)(const int n, float * restrict s0, + const void * restrict vx0, + const void * restrict vy0); + + void (*vec_dot_2x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0); + + void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1); + + void (*vec_dot_4x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0); + + // Precomputed values + uint32_t src0_nrows_per_thread; + uint32_t src1_nrows_per_thread; + + struct fastdiv_values mm_div_ne12_ne1; + struct fastdiv_values mm_div_ne1; + struct fastdiv_values mm_div_r2; + struct fastdiv_values mm_div_r3; + + // Fields for scattered mapping & HMX support in MUL_MAT_ID + const uint32_t * matrix_row_counts; + const struct mmid_row_mapping * matrix_rows; + bool hmx_eligible; }; // vdelta control to expand first 32 e8m0 values into 32 uint32 elements @@ -115,6 +71,16 @@ static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { 0x00, 0x00, 0x09, 0x08, 0x00, 0x00, 0x22, 0x20, 0x24, 0x20, 0x21, 0x22, 0x20, 0x20, }; +// IQ4_NL dequantization LUT: maps 4-bit index (0-15) to int8 kvalue +// kvalues: -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 +static const uint8_t __attribute__((aligned(VLEN))) kvalues_iq4nl_lut[] = { + 0x81, 0, 0x98, 0, 0xAD, 0, 0xBF, 0, 0xCF, 0, 0xDD, 0, 0xEA, 0, 0xF6, 0, 0x01, 0, 0x0D, 0, 0x19, 0, 0x26, 0, + 0x35, 0, 0x45, 0, 0x59, 0, 0x71, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +}; + static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 6, 0, 8, 0, 12, 0, 0, 0, 0xff, 0, 0xfe, 0, 0xfd, 0, 0xfc, 0, 0xfa, 0, 0xf8, 0, 0xf4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, @@ -123,16 +89,90 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; +static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 + + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); + r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); + } + + return r; +} + // q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales static inline size_t q8x4x2_row_size(uint32_t ne) { // ensures perfect alignment of quants and full row const uint32_t qk = QK_Q8_0x4x2; const uint32_t nb = (ne + qk - 1) / qk; - return htp_round_up(ne + nb * 8 * sizeof(__fp16), 128); + return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); +} + +static inline size_t q8_1x4x2_row_size(uint32_t ne) { + // ensures perfect alignment of quants and full row + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (ne + qk - 1) / qk; + return hex_round_up(ne + nb * 8 * 2 * sizeof(__fp16), 128); } -static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { +static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) @@ -141,10 +181,11 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { HVX_Vector v6_7 = vptr[3]; // ... const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 @@ -152,21 +193,110 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 // Convert uint4 to int4 (i.e. x - 8) - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - v0 = Q6_Vb_vsub_VbVb(v0, i8); - v1 = Q6_Vb_vsub_VbVb(v1, i8); - v2 = Q6_Vb_vsub_VbVb(v2, i8); - v3 = Q6_Vb_vsub_VbVb(v3, i8); - v4 = Q6_Vb_vsub_VbVb(v4, i8); - v5 = Q6_Vb_vsub_VbVb(v5, i8); - v6 = Q6_Vb_vsub_VbVb(v6, i8); - v7 = Q6_Vb_vsub_VbVb(v7, i8); + v0 = Q6_Vb_vsub_VbVb(v0, i8); + v1 = Q6_Vb_vsub_VbVb(v1, i8); + v2 = Q6_Vb_vsub_VbVb(v2, i8); + v3 = Q6_Vb_vsub_VbVb(v3, i8); + v4 = Q6_Vb_vsub_VbVb(v4, i8); + v5 = Q6_Vb_vsub_VbVb(v5, i8); + v6 = Q6_Vb_vsub_VbVb(v6, i8); + v7 = Q6_Vb_vsub_VbVb(v7, i8); + + HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; + return r; +} + +static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8); + r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8); + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8); + r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8); + } + + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8_full(const uint8_t * restrict ptr) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) + HVX_Vector v2_3 = vptr[1]; // ... + HVX_Vector v4_5 = vptr[2]; // ... + HVX_Vector v6_7 = vptr[3]; // ... + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements + HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... + HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 + HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F + HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 + HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F + HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; return r; } -static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) { +static HVX_Vector_x8 hvx_vec_load_q4_1x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = v0; + r.v[i*2+1] = v1; + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i*2+0] = Q6_V_lo_W(v0_1_p); + r.v[i*2+1] = Q6_V_hi_W(v0_1_p); + } + + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) @@ -175,6 +305,7 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) HVX_Vector v6_7 = vptr[3]; // ... const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 @@ -185,21 +316,54 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; return r; } -static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { +static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; + + const uint32_t qk = QK_Q4_0x4x2; // 256 + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + + HVX_Vector_x8 r; + uint32_t i = 0; + + #pragma unroll(2) + for (i=0; i < nb; i++) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements + r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + } + + if (nloe) { + HVX_Vector v = vptr[i]; // 256 elements (128 bytes) + HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements + HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements + HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... + r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); + r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); + } + + return r; +} + +static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) { const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; HVX_Vector v0 = vptr[0]; // first 128 vals @@ -215,44 +379,8 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { return r; } -static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0 = vptr[0]; // first 64 vals - HVX_Vector v1 = vptr[1]; // second 64 vals - HVX_Vector v2 = vptr[2]; // third 64 vals - HVX_Vector v3 = vptr[3]; // forth 64 vals - - HVX_Vector_x4 r = { v0, v1, v2, v3 }; - return r; -} - -static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) { - const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr; - - HVX_VectorPair v0 = vptr[0]; // first 64 vals - HVX_VectorPair v1 = vptr[1]; // second 64 vals - HVX_VectorPair v2 = vptr[2]; // third 64 vals - HVX_VectorPair v3 = vptr[3]; // forth 64 vals - - HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero()); - HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero()); - HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero()); - HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero()); - HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero()); - HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero()); - HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero()); - HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero()); - - HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo)); - HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo)); - HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo)); - HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo)); - - // vcombine does a shuffle, use vdeal to undo - - HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) }; - return r; +static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) { + return hvx_vec_load_q8x4x8_full(ptr); } // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). @@ -262,14 +390,14 @@ static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict // if() checks are optimized out at compile time -- make sure to pass N as a constexpr. static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - HVX_Vector r0 = Q6_V_vsplat_R(0); - HVX_Vector r1 = Q6_V_vsplat_R(0); - HVX_Vector r2 = Q6_V_vsplat_R(0); - HVX_Vector r3 = Q6_V_vsplat_R(0); - HVX_Vector r4 = Q6_V_vsplat_R(0); - HVX_Vector r5 = Q6_V_vsplat_R(0); - HVX_Vector r6 = Q6_V_vsplat_R(0); - HVX_Vector r7 = Q6_V_vsplat_R(0); + HVX_Vector r0 = Q6_V_vzero(); + HVX_Vector r1 = Q6_V_vzero(); + HVX_Vector r2 = Q6_V_vzero(); + HVX_Vector r3 = Q6_V_vzero(); + HVX_Vector r4 = Q6_V_vzero(); + HVX_Vector r5 = Q6_V_vzero(); + HVX_Vector r6 = Q6_V_vzero(); + HVX_Vector r7 = Q6_V_vzero(); HVX_VectorPair p3; HVX_VectorPair p2; @@ -308,370 +436,1917 @@ static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, uns } static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) { - return hvx_vec_rmpy_x8_n(x, y, 1024); + HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); + HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); + HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); + HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); + HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); + HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); + HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); + HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); + + HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4); + HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4); + HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4); + HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4); + + r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); + r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); + r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); + r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); + + p0 = Q6_W_vdeal_VVR(r1, r0, -4); + p1 = Q6_W_vdeal_VVR(r3, r2, -4); + + r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); + r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); + + p0 = Q6_W_vdeal_VVR(r1, r0, -4); + r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); + + return r0; } -// Handle most common cases of tensors not multiple of 1024. -static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - if (n <= 256) { return hvx_vec_rmpy_x8_n(x, y, 256); }; - if (n <= 512) { return hvx_vec_rmpy_x8_n(x, y, 512); }; - if (n <= 768) { return hvx_vec_rmpy_x8_n(x, y, 768); }; - return hvx_vec_rmpy_x8_n(x, y, 1024); +static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { + if (n >= 512) + return hvx_vec_rmpy_x8_full(x, y); + + return hvx_vec_rmpy_x8_partial(x, y, 512); } -static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_q4_1x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales/offsets - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); - - hvx_vec_store_u(&s[0], 4, r0_sum); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_q4_1x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_sum = Q6_V_vsplat_R(0); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks const uint32_t nloe = n % qk; // num leftover elemements uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks + // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - // Zero out unused scales + // Zero out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r1_ms = Q6_V_vand_QV(bmask, r1_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); - - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_q4_1x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) + const uint32_t nloe = n % qk; // num leftover elements uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_full(r3_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); + HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); + + HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); + HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); + + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); + + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); + + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); + + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); + + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); } - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); + HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); + HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); + + HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); + HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); - // Zero out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ms = Q6_V_vand_QV(bmask, r0_ms); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r1_ms = Q6_V_vand_QV(bmask, r1_ms); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r2_ms = Q6_V_vand_QV(bmask, r2_ms); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r3_ms = Q6_V_vand_QV(bmask, r3_ms); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - } + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - // Reduce and convert into fp32 - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); - hvx_vec_store_u(&s[0], 4, r0_sum); -} + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); -static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); + } - const uint32_t qk = QK_Q4_0x4x2 * 4; + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) +static void vec_dot_q4_1x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales + const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_sum = Q6_V_vsplat_R(0); + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales/sums + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales/sums + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) + const uint32_t nloe = n % qk; // num leftover elements uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + // Load src1 columns + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + + // Load src0 rows + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); + + HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); + + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); + + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); + + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); + + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); + HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); + HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); + HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); + } - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); + HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); + HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); - } + HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); + HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); + HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy_q, nloe)); + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); + + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms); + + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); + HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); + HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); + HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elemements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elemements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q4x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + // Multiply and accumulate into int32. + // Compute combined scale (fp32). + // Apply scale to acc and accumulate into the row sum (qf32). + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_q8x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (qf32) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q8_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +// ======== IQ4_NL x Q8_0 vec_dot kernels ======== +// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue). +// Scale format is identical to Q4_0 (fp16 scales). + +static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + } + + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); + + hvx_vec_store_u(s0, 4, r0_sum); +} + +static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_iq4nlx4x2_q8x4x2_4x1(const int n, + float * restrict s0, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vx2, + const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); + + HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); + HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, + float * restrict s0, + float * restrict s1, + const void * restrict vx0, + const void * restrict vx1, + const void * restrict vy0, + const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; + + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; + const uint32_t nloe = n % qk; + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); } -static void vec_dot_mxfp4x4x2_q8x4x2(const int n, - float * restrict s, - const void * restrict vx, - const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; @@ -683,14 +2358,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). @@ -701,8 +2376,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); @@ -728,17 +2403,17 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving @@ -761,62 +2436,60 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, // Zero-out unused scales HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); } - // Reduce and convert into fp32 - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); + r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_sum = Q6_V_vsplat_R(0); + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); // Multiply and accumulate into int32. // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). + // Apply scale to acc and accumulate into the row sum (f32). const uint32_t nb = n / qk; // num full blocks int32_t nloe = n % qk; // num leftover elemements (must be signed) uint32_t i = 0; for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); @@ -843,155 +2516,710 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + + // Zero-out unused values + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_mxfp4x4x2_q8x4x2_4x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vx2, const void * restrict vx3, + const void * restrict vy0) { + assert(n % 32 == 0); // min sub-block size + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vx2 % 128 == 0); + assert((unsigned long) vx3 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first + const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales + const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first + const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales + + const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales + + // Row sum (sf) + HVX_Vector r0_sum = Q6_V_vzero(); + HVX_Vector r1_sum = Q6_V_vzero(); + HVX_Vector r2_sum = Q6_V_vzero(); + HVX_Vector r3_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + int32_t nloe = n % qk; // num leftover elemements (must be signed) + + uint32_t i = 0; + for (; i < nb; i++) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); + HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_full(r2_x_q + i * x_qblk_size); + HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_full(r3_x_q + i * x_qblk_size); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + r2_d = Q6_V_vdelta_VV(r2_d, expand); + r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); + r2_d = Q6_Vw_vasl_VwR(r2_d, 23); + r3_d = Q6_V_vdelta_VV(r3_d, expand); + r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); + r3_d = Q6_Vw_vasl_VwR(r3_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + if (nloe) { + HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); + HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); + HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); + + HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); + HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); + vy_d = Q6_Vsf_equals_Vqf32(vy_d); + + // Convert rX_d scales from e8m0 to fp32 + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + r2_d = Q6_V_vdelta_VV(r2_d, expand); + r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); + r2_d = Q6_Vw_vasl_VwR(r2_d, 23); + r3_d = Q6_V_vdelta_VV(r3_d, expand); + r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); + r3_d = Q6_Vw_vasl_VwR(r3_d, 23); + + HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); + HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); + HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); + + // Zero-out unused values + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_dd = Q6_V_vand_QV(bmask, r0_dd); + r1_dd = Q6_V_vand_QV(bmask, r1_dd); + r2_dd = Q6_V_vand_QV(bmask, r2_dd); + r3_dd = Q6_V_vand_QV(bmask, r3_dd); + r0_ia = Q6_V_vand_QV(bmask, r0_ia); + r1_ia = Q6_V_vand_QV(bmask, r1_ia); + r2_ia = Q6_V_vand_QV(bmask, r2_ia); + r3_ia = Q6_V_vand_QV(bmask, r3_ia); + + HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); + HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); + HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); + + r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); + r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); + r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); + r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); + } + + HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; + HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); + hvx_vec_store_u(s0, 16, rsum); +} + + +static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial( y0_q + i * y_qblk_size, nloe); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial( y1_q + i * y_qblk_size, nloe); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +#if __HVX_ARCH__ < 79 +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; + + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements + + HVX_Vector rsum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector prod = HVX_OP_MUL_F32(x[i], y[i]); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector x_sf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector prod = HVX_OP_MUL_F32(x_sf, y_sf); + rsum = HVX_OP_ADD_F32(rsum, prod); + } + + *s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum)); +} + +static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; + + HVX_Vector rsum0 = Q6_V_vzero(); + HVX_Vector rsum1 = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_sf = y[i]; + HVX_Vector prod0 = HVX_OP_MUL_F32(x0[i], y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1[i], y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector y_sf = Q6_V_vand_QV(bmask, y[i]); + HVX_Vector x0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector prod0 = HVX_OP_MUL_F32(x0_sf, y_sf); + HVX_Vector prod1 = HVX_OP_MUL_F32(x1_sf, y_sf); + rsum0 = HVX_OP_ADD_F32(rsum0, prod0); + rsum1 = HVX_OP_ADD_F32(rsum1, prod1); + } + + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1); + HVX_VectorAlias va; + va.v = rsum; + s0[0] = va.fp32[0]; + s0[1] = va.fp32[1]; +} - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); +static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP32; + uint32_t nloe = n % VLEN_FP32; - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + HVX_Vector r0_c0_sum = Q6_V_vzero(); + HVX_Vector r0_c1_sum = Q6_V_vzero(); + HVX_Vector r1_c0_sum = Q6_V_vzero(); + HVX_Vector r1_c1_sum = Q6_V_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_sf = x0[i]; + HVX_Vector r1_sf = x1[i]; + HVX_Vector c0_sf = y0[i]; + HVX_Vector c1_sf = y1[i]; + + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); } - // Process leftovers if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); + HVX_Vector r0_sf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_sf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_sf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_sf = Q6_V_vand_QV(bmask, y1[i]); - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + r0_c0_sum = HVX_OP_ADD_F32(r0_c0_sum, HVX_OP_MUL_F32(r0_sf, c0_sf)); + r0_c1_sum = HVX_OP_ADD_F32(r0_c1_sum, HVX_OP_MUL_F32(r0_sf, c1_sf)); + r1_c0_sum = HVX_OP_ADD_F32(r1_c0_sum, HVX_OP_MUL_F32(r1_sf, c0_sf)); + r1_c1_sum = HVX_OP_ADD_F32(r1_c1_sum, HVX_OP_MUL_F32(r1_sf, c1_sf)); + } - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + HVX_VectorAlias va0, va1; + va0.v = r0_r1_c0_sum; + va1.v = r0_r1_c1_sum; + s0[0] = va0.fp32[0]; + s0[1] = va0.fp32[1]; + s1[0] = va1.fp32[0]; + s1[1] = va1.fp32[1]; +} - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); +static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); + uint32_t nvec = n / VLEN_FP32; // num full fp32 hvx vectors + uint32_t nloe = n % VLEN_FP32; // leftover elements - // Zero-out unused scales - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); + HVX_Vector rsum = Q6_V_vzero(); - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); + uint32_t i = 0; - r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); - r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); } - // Convert into fp32 and reduce - r0_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r0_sum)); - r1_sum = hvx_vec_fp32_reduce_sum(Q6_Vsf_equals_Vqf32(r1_sum)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4); + if (nloe) { + HVX_Vector x_sf = vx[i]; + HVX_Vector y_sf = vy[i]; + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + x_sf = Q6_V_vand_QV(bmask, x_sf); + y_sf = Q6_V_vand_QV(bmask, y_sf); + + rsum = HVX_OP_ADD_F32(rsum, HVX_OP_MUL_F32(x_sf, y_sf)); + } - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + rsum = hvx_vec_reduce_sum_f32(rsum); + hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_VectorPair rsum_p = Q6_W_vzero(); uint32_t i = 0; #pragma unroll(4) for (i = 0; i < nvec; i++) { - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]); } if (nloe) { HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); - hvx_vec_store_u(&s[0], 4, rsum); + HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p))); + hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum)); } -static void vec_dot_f16_f16_aa_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { - const HVX_Vector * restrict x0 = (const HVX_Vector *) vx; - const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size); - const HVX_Vector * restrict y = (const HVX_Vector *) vy; +static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16; - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); + HVX_VectorPair rsum0_p = Q6_W_vzero(); + HVX_VectorPair rsum1_p = Q6_W_vzero(); uint32_t i = 0; #pragma unroll(2) for (i = 0; i < nvec; i++) { HVX_Vector y_hf = y[i]; - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf); - - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf); } if (nloe) { HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]); HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]); - HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); + } + + HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p))); + HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p))); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1); + hvx_vec_store_u(s0, 8, rsum); +} + +static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_VectorPair r0_c0_sum_p = Q6_W_vzero(); + HVX_VectorPair r0_c1_sum_p = Q6_W_vzero(); + HVX_VectorPair r1_c0_sum_p = Q6_W_vzero(); + HVX_VectorPair r1_c1_sum_p = Q6_W_vzero(); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_hf = x0[i]; + HVX_Vector r1_hf = x1[i]; + HVX_Vector c0_hf = y0[i]; + HVX_Vector c1_hf = y1[i]; + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf); + r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf); + r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf); + r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf); + } - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]); + + r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf); + r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf); + r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf); + r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf); } - rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1)); - HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4); + HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p))); + HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p))); + HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p))); + HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p))); + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } -static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_UVector * restrict x = (const HVX_UVector *) vx; const HVX_UVector * restrict y = (const HVX_UVector *) vy; uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vzero(); uint32_t i = 0; @@ -1010,20 +3238,20 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) { +static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); + const HVX_Vector zero = Q6_V_vzero(); - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vzero(); uint32_t i = 0; @@ -1062,15 +3290,16 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + // Convert into fp32 and reduce + rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); hvx_vec_store_u(&s[0], 4, rsum); } -#define htp_matmul_tensors_preamble \ - struct htp_tensor * restrict src0 = &octx->src0; \ - struct htp_tensor * restrict src1 = &octx->src1; \ - struct htp_tensor * restrict src2 = &octx->src2; \ - struct htp_tensor * restrict dst = &octx->dst; \ +#define htp_matmul_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict src2 = octx->src[2]; \ + const struct htp_tensor * restrict dst = octx->dst; \ struct htp_spad * restrict src0_spad = &octx->src0_spad; \ struct htp_spad * restrict src1_spad = &octx->src1_spad; \ struct htp_spad * restrict dst_spad = &octx->dst_spad; \ @@ -1110,14 +3339,16 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -#define htp_matmul_preamble \ - htp_matmul_tensors_preamble; \ - dma_queue *dma_queue = octx->ctx->dma[ith]; \ - uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; +#define htp_matmul_preamble \ + struct htp_matmul_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + htp_matmul_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; // *** matmul with support for 4d tensors and full broadcasting -static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; uint64_t t1, t2; @@ -1163,13 +3394,13 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { - const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1); - const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1); + const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1); + const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1); const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); // broadcast src0 into src1 - const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3); - const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2); + const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3); + const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2); const uint32_t i1 = i11; const uint32_t i2 = i12; @@ -1182,7 +3413,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { const uint8_t * restrict src0_row = src0_base + ir0 * nb01; - mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col); } } } @@ -1197,7 +3428,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx } // src1 tensor is already in VTCM spad -static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows @@ -1222,7 +3453,7 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; uint8_t * restrict src1_data = src1_spad->data; @@ -1246,11 +3477,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - #pragma unroll(2) - for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + // Process src1 columns in pairs (2×2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); + float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); + } + + // Handle remaining src1 rows (fallback to 2×1) + for (; ir1 < src1_nrows; ++ir1) { const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col); } // Prefetch next (n + spad_nrows) row @@ -1265,7 +3506,7 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const int is0 = (ir0 - src0_start_row); + const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -1274,27 +3515,26 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // q8x4x2 src1 tensor is already in VTCM spad -static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; const uint32_t src0_nrows = ne01; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); // no work for this thread if (src0_start_row >= src0_end_row) { @@ -1324,46 +3564,96 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx const uint8_t * restrict src1_col = (const uint8_t *) src1_data; float * restrict dst_col = (float *) dst->data; - // Prefill spad with 2x src0 rows - #pragma unroll(2) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; + if (mmctx->vec_dot_4x1 != NULL) { + const uint32_t src0_end_row_x4 = src0_start_row + ((src0_end_row - src0_start_row) & ~3U); + + // Prefill spad with 4x src0 rows + #pragma unroll(4) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 4); } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - } - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col); + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col); - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x4) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 4); + } + } + + // Process leftovers + uint32_t ir0 = src0_end_row_x4; + if (ir0 + 2 <= src0_end_row) { + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 2); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + ir0 += 2; } - } + if (ir0 < src0_end_row) { + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + ir0 += 1; + } + } else { + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + // Prefill spad with 2x src0 rows + #pragma unroll(2) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= MM_SPAD_SRC0_NROWS) { + break; + } + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + + // Prefetch next (n + spad_nrows) row + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + if (pr0 < src0_end_row_x2) { + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); + } + } + + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + const uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + } } - hvx_copy_fp32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); + hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); @@ -1377,11 +3667,11 @@ struct mmid_row_mapping { }; // src1 tensor is already in VTCM spad -static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_tensor * restrict ids = &octx->src2; - struct htp_spad * restrict src2_spad = &octx->src2_spad; + const struct htp_tensor * restrict ids = octx->src[2]; + struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1401,17 +3691,14 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const uint32_t n_ids = ids->ne[0]; // n_expert_used const uint32_t n_as = ne02; // n_expert - const size_t matrix_row_counts_size = n_as * sizeof(uint32_t); - const size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); - - const uint32_t * matrix_row_counts = (const uint32_t *) src2_spad->data + 0; - const struct mmid_row_mapping * matrix_rows = (const void *) src2_spad->data + matrix_row_counts_size; + const uint32_t * matrix_row_counts = mmctx->matrix_row_counts; + const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows; const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; const size_t src1_row_size = q8x4x2_row_size(ne10); - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM @@ -1427,6 +3714,10 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx continue; } + if (mmctx->hmx_eligible) { + continue; + } + const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); // Prefill spad with src0 rows @@ -1450,11 +3741,10 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const int rm2 = row_mapping.i2; // token idx const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = - (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); } // Prefetch next (n + spad_nrows) row @@ -1469,7 +3759,7 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; @@ -1480,29 +3770,28 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const int rm2 = row_mapping.i2; // token idx const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = - (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // src1 tensor is already in VTCM spad -static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matvec_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_tensor * restrict ids = &octx->src2; - struct htp_spad * restrict src2_spad = &octx->src2_spad; + const struct htp_tensor * restrict ids = octx->src[2]; + struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1524,7 +3813,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const size_t src0_row_size = nb01; const size_t src1_row_size = q8x4x2_row_size(ne10); - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); const uint32_t n_aids = src2->ne[0]; // num activated experts const uint32_t n_ids = ne02; // num experts @@ -1558,7 +3847,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); @@ -1572,17 +3861,17 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row); + const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); @@ -1590,18 +3879,19 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx // *** dynamic quant -static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { +static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { assert((unsigned long) x % 128 == 0); assert((unsigned long) y_q % 128 == 0); HVX_Vector * vx = (HVX_Vector *) x; - HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector zero = Q6_V_vzero(); // Use reduce max fp32 to find max(abs(e)) first - HVX_Vector vmax0_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[0])); - HVX_Vector vmax1_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[1])); - HVX_Vector vmax2_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[2])); - HVX_Vector vmax3_sf = hvx_vec_reduce_max_fp32(hvx_vec_abs_fp32(vx[3])); + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + // Load and convert into QF32 HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements @@ -1622,10 +3912,92 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_fp16; - vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl); - vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl); + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + // Divide input by the scale + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; + + // --- Sum calculation --- + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); // sum every 4 consecutive elements + // Sum 8 elements: + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + // Copy to stack to extract sums and vmaxes + float vmax0[32] __attribute__((aligned(128))); + float vmax1[32] __attribute__((aligned(128))); + float vmax2[32] __attribute__((aligned(128))); + float vmax3[32] __attribute__((aligned(128))); + int32_t sums[32] __attribute__((aligned(128))); + + hvx_vec_store_u(vmax0, 128, vmax0_sf); + hvx_vec_store_u(vmax1, 128, vmax1_sf); + hvx_vec_store_u(vmax2, 128, vmax2_sf); + hvx_vec_store_u(vmax3, 128, vmax3_sf); + hvx_vec_store_u(sums, 128, v_sums); + + float d0 = vmax0[0] / 127.0f; + float d1 = vmax1[0] / 127.0f; + float d2 = vmax2[0] / 127.0f; + float d3 = vmax3[0] / 127.0f; + + __fp16 * y_d_half = (__fp16 *) y_d; + y_d_half[0] = d0; + y_d_half[1] = (float) sums[0] * d0; + y_d_half[2] = d1; + y_d_half[3] = (float) sums[8] * d1; + y_d_half[4] = d2; + y_d_half[5] = (float) sums[16] * d2; + y_d_half[6] = d3; + y_d_half[7] = (float) sums[24] * d3; +} + +static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_q % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + // Use reduce max fp32 to find max(abs(e)) first + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + // Load and convert into QF32 + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements + + // Convert to QF32 + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes + + // Combine and convert to fp16 + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + // Convert into fp16 + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 @@ -1641,8 +4013,8 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf); // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf); + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); @@ -1654,14 +4026,14 @@ static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restri *(HVX_Vector *) y_q = vx_i8; } -static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { +static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { assert((unsigned long) x % 128 == 0); assert((unsigned long) y_q % 128 == 0); HVX_Vector * vx = (HVX_Vector *) x; // Load and convert into QF32 - HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector zero = Q6_V_vzero(); HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements @@ -1672,13 +4044,8 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); // Compute max and scale - HVX_Vector vmax01_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf)); - HVX_Vector vmax23_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx23_hf)); - - // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16; - vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl); - vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl); + HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes + HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 @@ -1689,8 +4056,8 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri hvx_vec_store_u(y_d + 4, 4, vd23_hf); // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_fp16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_fp16(vd23_hf); + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); @@ -1702,14 +4069,14 @@ static inline void quantize_block_fp32_q8x2(float * restrict x, uint8_t * restri *(HVX_Vector *) y_q = vx_i8; } -static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { +static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { assert((unsigned long) x % 128 == 0); assert((unsigned long) y_q % 128 == 0); HVX_Vector * vx = (HVX_Vector *) x; // Load and convert into QF32 - HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector zero = Q6_V_vzero(); HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements @@ -1719,74 +4086,144 @@ static inline void quantize_block_fp32_q8x4(float * restrict x, uint8_t * restri HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - // Compute max and scale - HVX_Vector vmax_hf = hvx_vec_reduce_max_fp16(hvx_vec_abs_fp16(vx01_hf)); - vmax_hf = hvx_vec_reduce_max2_fp16(hvx_vec_abs_fp16(vx23_hf), vmax_hf); + // Compute max and scale + HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); + vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes + + HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); + + *(HVX_UVector *) y_d = vd_hf; + + // Divide input by the scale + HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); + + // Convert to int8 + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + *(HVX_Vector *) y_q = vx_i8; +} + +// Overrides input x +static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0x4x2; + const uint32_t nb = (k + qk - 1) / qk; + + const uint32_t qrow_size = k; // int8 + + const uint32_t dblk_size = 8 * 2; // 8x __fp16 + const uint32_t qblk_size = QK_Q8_0x4x2; // int8 + + uint8_t * restrict y_q = (y + 0); // quants first + uint8_t * restrict y_d = (y + qrow_size); // then scales + + // Temp scales override input since we're working off of the aligned temp buffer in VTCM + uint8_t * restrict t_d = (uint8_t *) x; + + for (uint32_t i = 0; i < nb; i++) { +#if FP32_QUANTIZE_GROUP_SIZE == 32 + quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); +#elif FP32_QUANTIZE_GROUP_SIZE == 64 + quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); +#elif FP32_QUANTIZE_GROUP_SIZE == 128 + quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); +#else +#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" +#endif + } + + // now copy the scales into final location + hvx_copy_f16_ua(y_d, t_d, nb * 8); +} + +static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = src->nb[1]; + const size_t dst_row_size = q8x4x2_row_size(ne0); - // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_fp16; - vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl); + uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); + uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); - HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); + memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding - *(HVX_UVector *) y_d = vd_hf; + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); - // Divide input by the scale - HVX_Vector vd_inv_hf = hvx_vec_inverse_fp16(vd_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); + // FARF(HIGH, "quantize-q8x4-row: %u\n", i); + quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + uint64_t t2 = HAP_perf_get_qtimer_count(); - *(HVX_Vector *) y_q = vx_i8; + FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// Overrides input x -static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { +static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { assert(k % 32 == 0); const uint32_t qk = QK_Q8_0x4x2; const uint32_t nb = (k + qk - 1) / qk; const uint32_t qrow_size = k; // int8 - const uint32_t dblk_size = 8 * 2; // 8x __fp16 + const uint32_t dblk_size = 8 * 4; // 8x (d, s) __fp16 = 32 bytes const uint32_t qblk_size = QK_Q8_0x4x2; // int8 uint8_t * restrict y_q = (y + 0); // quants first - uint8_t * restrict y_d = (y + qrow_size); // then scales + uint8_t * restrict y_d = (y + qrow_size); // then scales/sums // Temp scales override input since we're working off of the aligned temp buffer in VTCM uint8_t * restrict t_d = (uint8_t *) x; for (uint32_t i = 0; i < nb; i++) { -#if FP32_QUANTIZE_GROUP_SIZE == 32 - quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#elif FP32_QUANTIZE_GROUP_SIZE == 64 - quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#elif FP32_QUANTIZE_GROUP_SIZE == 128 - quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#else -#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" -#endif + quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); } - // now copy the scales into final location - hvx_copy_fp16_ua(y_d, t_d, nb * 8); + // now copy the scales/sums into final location + hvx_copy_f16_ua(y_d, t_d, nb * 16); } -static void quantize_fp32_q8x4x2(const struct htp_tensor * src, - uint8_t * restrict dst, - struct htp_spad * spad, - uint32_t nth, - uint32_t ith, - uint32_t nrows_per_thread) { +static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1801,33 +4238,38 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src, const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row const size_t src_row_size = src->nb[1]; - const size_t dst_row_size = q8x4x2_row_size(ne0); + const size_t dst_row_size = q8_1x4x2_row_size(ne0); uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); - const size_t src_row_size_padded = htp_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding for (uint32_t i = ir_first; i < ir_last; ++i) { - htp_l2fetch(src_data, 2, src_row_size, src_row_size); - hvx_copy_fp32_aa(tmp_data, src_data, ne0); + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); - // FARF(HIGH, "quantize-q8x4-row: %u\n", i); - quantize_row_fp32_q8x4x2((float *) tmp_data, dst_data, ne0); + quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0); dst_data += dst_row_size; src_data += src_row_size; } uint64_t t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "quantize-fp32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, + FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, - uint32_t nrows_per_thread, uint32_t dst_stride) { +static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1848,8 +4290,8 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); for (uint32_t i = ir_first; i < ir_last; ++i) { - htp_l2fetch(src_data, 2, src_row_size, src_stride); - hvx_copy_fp16_fp32_au(dst_data, src_data, ne0); + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f32_au(dst_data, src_data, ne0); dst_data += dst_stride; src_data += src_stride; @@ -1857,13 +4299,18 @@ static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict uint64_t t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// TODO just a plain copy that should be done via the DMA during the Op setup -static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, - uint32_t nrows_per_thread, uint32_t dst_stride) { +static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1884,8 +4331,8 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); for (uint32_t i = ir_first; i < ir_last; ++i) { - htp_l2fetch(src_data, 2, src_row_size, src_stride); - hvx_copy_fp16_au(dst_data, src_data, ne0); + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_f32_au(dst_data, src_data, ne0); dst_data += dst_stride; src_data += src_stride; @@ -1893,450 +4340,465 @@ static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict uint64_t t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); -} - -static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); -} - -static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); -} - -// ** matmul/matvec callbacks for worker_pool - -static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_aa; - mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_aa; - mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f32"; - mt.vec_dot = vec_dot_f16_f32_uu; - - matmul_4d(&mt, octx, n, i); -} - -static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_uu; +// TODO just a plain copy that should be done via the DMA during the Op setup +static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; - matmul_4d(&mt, octx, n, i); -} + const struct htp_tensor * src = octx->src[1]; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; -// ** matmul-id callbacks for worker_pool + uint64_t t1 = HAP_perf_get_qtimer_count(); -static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - matvec_id(&mt, octx, n, i); -} + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row -static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); - matmul_id(&mt, octx, n, i); -} + for (uint32_t i = ir_first; i < ir_last; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_au(dst_data, src_data, ne0); -static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; + dst_data += dst_stride; + src_data += src_stride; + } - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; + uint64_t t2 = HAP_perf_get_qtimer_count(); - matvec_id(&mt, octx, n, i); + FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matmul_id(&mt, octx, n, i); +static inline bool htp_is_permuted(const struct htp_tensor * t) { + return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; } -static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); +static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) { + switch (type) { + case HTP_TYPE_Q4_0: + mmctx->type = "q4x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q4x4x2_q8x4x2_4x1; + return 0; + case HTP_TYPE_Q4_1: + mmctx->type = "q4_1x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q4_1x4x2_q8x4x2_4x1; + return 0; + case HTP_TYPE_Q8_0: + mmctx->type = "q8x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_q8x4x2_q8x4x2_4x1; + return 0; + case HTP_TYPE_IQ4_NL: + mmctx->type = "iq4nlx4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_iq4nlx4x2_q8x4x2_4x1; + return 0; + case HTP_TYPE_MXFP4: + mmctx->type = "mxfp4x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; + mmctx->vec_dot_4x1 = vec_dot_mxfp4x4x2_q8x4x2_4x1; + return 0; + default: + return -1; + } } -static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} +static void htp_mminit_spad(struct htp_ops_context * octx, + size_t dst_row_size, + size_t src0_row_size_padded, + size_t src1_row_size, + uint32_t src1_nrows, + size_t src2_spad_size_per_thread) { + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + if (src2_spad_size_per_thread > 0) { + octx->src2_spad.size_per_thread = src2_spad_size_per_thread; + octx->src2_spad.size = octx->src2_spad.size_per_thread; + } -// ** main matmul entry point + // src0 spad is also used in dynamic quantizer to store padded src1 rows + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } -static inline bool htp_is_permuted(const struct htp_tensor * t) { - return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; } -int op_matmul(struct htp_ops_context * octx) { +static int op_matmul_hvx(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - const char * op_type; + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; const uint32_t src0_nrows = ne01 * ne02 * ne03; const uint32_t src1_nrows = ne11 * ne12 * ne13; + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; size_t src1_row_size = nb11; - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); size_t src1_row_size_padded; worker_callback_t quant_job_func; - worker_callback_t matmul_job_func; + worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d; - bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE); + bool need_quant = true; - switch (src0->type) { - case HTP_TYPE_Q4_0: - op_type = "q4x4x2-fp32"; - quant_job_func = htp_quantize_fp32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2; - } else { - matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2; - } + if (src0->type == HTP_TYPE_F16) { + // Try optimized f16-f16 path first (src1 in VTCM) + const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); + const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); + const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size + // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). + // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } + src1_row_size = f16_src1_row_size; // row size post quantization + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_Q8_0: - op_type = "q8x4x2-fp32"; - quant_job_func = htp_quantize_fp32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2; + } else { + // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required + quant_job_func = NULL; + if (src1->type == HTP_TYPE_F32) { + mmctx->type = "f16-f32"; + mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; + matmul_job_func = matmul_4d; } else { - matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; + matmul_job_func = matmul_4d; } - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + src1_row_size = nb11; // original row size in DDR - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - case HTP_TYPE_MXFP4: - op_type = "mxfp4x4x2-f32"; - quant_job_func = htp_quantize_fp32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2; - } else { - matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2; - } + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + need_quant = false; + } + } else if (src0->type == HTP_TYPE_F32) { + // Try optimized f32-f32 path first (src1 in VTCM) + const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128); + const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256); + const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size + const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size; - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } + if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = quantize_f32_f32; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; + + src1_row_size = f32_src1_row_size; + + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; + } else { + // Fallback to DDR / broadcasting + quant_job_func = NULL; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; + matmul_job_func = matmul_4d; - case HTP_TYPE_F16: - { - // Try optimized f16-f16 path first (src1 in VTCM) - const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128); - const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256); - const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - - const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - - // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). - // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); - - if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { - // Optimized path - op_type = "f16-f16"; - quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_f16_f16; - } else { - matmul_job_func = htp_matvec_2d_f16_f16; - } - - src1_row_size = f16_src1_row_size; // row size post quantization - - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required - quant_job_func = NULL; - if (src1->type == HTP_TYPE_F32) { - op_type = "f16-f32"; - matmul_job_func = htp_matmul_4d_f16_f32; - } else { - op_type = "f16-f16"; - matmul_job_func = htp_matmul_4d_f16_f16; - } - - src1_row_size = nb11; // original row size in DDR - - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - - // Init fastdiv for matmul_4d (supports broadcasting) - octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - - need_quant = false; - } - } - break; + src1_row_size = nb11; - default: + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + + need_quant = false; + } + } else { + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { return HTP_STATUS_NO_SUPPORT; + } + + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + } + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } // VTCM scratchpads for all tensors size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type, + FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size); - FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0], + FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, + FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + // Place src1 spad first. We use it for dyn.quant and may reuse between ops + octx->src1_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; - if (need_quant) { - // Run quant jobs - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); - } + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - // Run matmul jobs - const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs); + if (need_quant && !octx->src1_spad.src) { + const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + octx->src1_spad.src = src1; } + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); + return HTP_STATUS_OK; } -// ** main matmul-id entry point +int op_matmul(struct htp_ops_context * octx) { + htp_matmul_tensors_preamble; + +#ifndef HTP_HAS_HMX + return op_matmul_hvx(octx); +#else + if (!octx->ctx->hmx_enabled) { + return op_matmul_hvx(octx); + } + + // HMX weight tile requires N to be 32-aligned. + if (src0->ne[1] % 32 != 0) { + return op_matmul_hvx(octx); + } + + // HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. + // Other types fall back to HVX. + uint32_t wtype = src0->type; + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { + return op_matmul_hvx(octx); + } + + // Quantised HMX path requires K aligned to 256 (x4x2 super-block). + // F16 and F32 HMX paths require K aligned to 32 (tile width). + if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) { + return op_matmul_hvx(octx); + } + + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) { + return op_matmul_hvx(octx); + } + + const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1); + + // Quantised HMX kernels only handle flat 2D matmul (host already rejects + // batched quantised, but guard here too). F16 batched matmul is handled + // by the dedicated wrapper in hmx-matmul-ops.c. + if (is_batched && src0->type != HTP_TYPE_F16) { + return op_matmul_hvx(octx); + } + + // HMX assumes contiguous row-major layout. Fall back for permuted + // tensors where strides are non-monotonic (e.g. transposed KV cache). + if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) { + return op_matmul_hvx(octx); + } + + // M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows) + // is handled by HMX itself; when M < 32 fall back to HVX. + const int m_total = (int) src1->ne[1]; + const int m_hmx = m_total & ~31; // 0 when M < 32 + if (m_hmx == 0) { + return op_matmul_hvx(octx); + } + + // Always re-quantize src1 since HMX kernel overwrites vtcm/spad, + // so any previously cached quantized data is invalid. + octx->src1_spad.src = NULL; + + int k = (int) src0->ne[0]; // inner dimension + int n = (int) src0->ne[1]; // weight columns + + int ret = -1; + + // Row strides in elements. For compact tensors these equal k; for + // permuted attention views they can be larger, so pass the real stride. + const int act_stride = (int)(src1->nb[1] / sizeof(float)); + const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + if (is_batched) { + if (src0->type == HTP_TYPE_F16) { + hmx_matmul_f16_f32_batched_params_t batch_params = { + .dst = (float *) dst->data, + .activation = (float *) src1->data, + .permuted_weight = (const __fp16 *) src0->data, + .m = m_total, + .k = k, + .n = n, + .act_stride = act_stride, + .weight_stride = wgt_stride, + .dst_stride = (int) (dst->nb[1] / sizeof(float)), + .ne02 = ne02, + .ne03 = ne03, + .ne12 = ne12, + .ne13 = ne13, + .src0_nb2 = src0->nb[2], + .src0_nb3 = src0->nb[3], + .src1_nb2 = src1->nb[2], + .src1_nb3 = src1->nb[3], + .dst_nb2 = dst->nb[2], + .dst_nb3 = dst->nb[3], + }; + ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); + } else { + return op_matmul_hvx(octx); + } + } else { + ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type); + } + + if (ret != 0) { + FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); + return op_matmul(octx); + } + + return 0; +#endif // HTP_HAS_HMX +} int op_matmul_id(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - struct htp_tensor * restrict ids = &octx->src2; + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; - const char * op_type; - - worker_callback_t quant_job_func; - worker_callback_t matmul_id_job_func; + const struct htp_tensor * restrict ids = octx->src[2]; const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); const uint32_t src0_nrows = ne01; // per expert const uint32_t src1_nrows = ne11 * ne12 * ne13; + worker_callback_t quant_job_func; + worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + size_t src1_row_size; size_t src1_row_size_padded; @@ -2346,158 +4808,151 @@ int op_matmul_id(struct htp_ops_context * octx) { size_t matrix_row_counts_size = n_as * sizeof(uint32_t); size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); + const size_t total_map_size = matrix_row_counts_size + matrix_row_map_size; + + void * mapping_buf = NULL; + bool must_free_mapping = false; + + if (octx->ctx->ddr_spad_base && total_map_size <= octx->ctx->ddr_spad_size) { + mapping_buf = octx->ctx->ddr_spad_base; + } else { + mapping_buf = memalign(128, total_map_size); + if (mapping_buf) { + must_free_mapping = true; + } else { + return HTP_STATUS_INTERNAL_ERR; + } + } - switch (src0->type) { - case HTP_TYPE_Q4_0: - op_type = "q4x2x2-f32"; - quant_job_func = htp_quantize_fp32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_Q8_0: - op_type = "q8x2x2-f32"; - quant_job_func = htp_quantize_fp32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_MXFP4: - op_type = "mxfp4x2x2-f32"; - quant_job_func = htp_quantize_fp32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2; - } + uint32_t * matrix_row_counts = (uint32_t *) mapping_buf; + struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) ((uint8_t *) mapping_buf + matrix_row_counts_size); - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = htp_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } + mmctx->matrix_row_counts = matrix_row_counts; + mmctx->matrix_rows = matrix_rows; - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_NO_SUPPORT; + } - default: - return HTP_STATUS_NO_SUPPORT; + if (src0->type == HTP_TYPE_Q4_1) { + quant_job_func = quantize_f32_q8_1x4x2; + src1_row_size = q8_1x4x2_row_size(ne10); + } else { + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); } + const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR! + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); + size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type, + FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size); - FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, + FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, - octx->ctx->vtcm_size, spad_size); + FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + // Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops. + octx->src1_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; - octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->dst_spad.src = NULL; + + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; if (src1_nrows > 1) { // initialize matrix_row_counts and map - uint32_t * matrix_row_counts = (uint32_t *) octx->src2_spad.data + 0; - struct mmid_row_mapping * matrix_rows = (void *) octx->src2_spad.data + matrix_row_counts_size; - memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); // group rows by src0 matrix for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx for (uint32_t id = 0; id < n_ids; ++id) { // expert idx - const uint32_t i02 = - *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); assert(i02 >= 0 && i02 < n_as); - MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) { id, iid1 }; + matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 }; matrix_row_counts[i02] += 1; } } } - // Setup worker pool callbacks - if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) { - // Run quant jobs - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; + } + + bool hmx_eligible = false; +#ifdef HTP_HAS_HMX + if (octx->ctx->hmx_enabled && src1_nrows > 1) { + uint32_t wtype = src0->type; + if (ne01 % 32 == 0 && + (wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) { + if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) { + hmx_eligible = true; + } else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) { + hmx_eligible = true; + } + } + } +#endif + + mmctx->hmx_eligible = hmx_eligible; + + if (hmx_eligible) { + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) continue; + + int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, + (const uint8_t *) src0->data + cur_a * nb02, + cne1, ne00, ne01, + ne11, + nb11, nb12, + nb1, nb2, + (int) src0->nb[1], (int) src0->type, + matrix_rows, cur_a, n_ids * ids->ne[1]); + if (ret != 0) { + FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_NO_SUPPORT; + } + } + + // HMX has overwritten VTCM, so force dynamic quantization cache to clear + octx->src1_spad.src = NULL; + + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - // Run matmul-id jobs - const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs); + if (octx->src1_spad.src != src1) { + const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + octx->src1_spad.src = src1; } + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + + if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/ops-utils.h b/ggml/src/ggml-hexagon/htp/ops-utils.h deleted file mode 100644 index af9c3305f61..00000000000 --- a/ggml/src/ggml-hexagon/htp/ops-utils.h +++ /dev/null @@ -1,149 +0,0 @@ -#ifndef OPS_UTILS_H -#define OPS_UTILS_H - -#include "htp-msg.h" - -#ifndef MAX -# define MAX(a, b) ((a) > (b) ? (a) : (b)) -#endif - -#ifndef MIN -# define MIN(a, b) ((a) < (b) ? (a) : (b)) -#endif - -static inline uint64_t htp_get_cycles() { - uint64_t cycles = 0; - asm volatile(" %0 = c15:14\n" : "=r"(cycles)); - return cycles; -} - -static inline uint64_t htp_get_pktcnt() { - uint64_t pktcnt; - asm volatile(" %0 = c19:18\n" : "=r"(pktcnt)); - return pktcnt; -} - -static inline int32_t htp_is_aligned(void * addr, uint32_t align) { - return ((size_t) addr & (align - 1)) == 0; -} - -static inline uint32_t htp_round_up(uint32_t n, uint32_t m) { - return m * ((n + m - 1) / m); -} - -// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. -// Precompute mp (m' in the paper) and L such that division -// can be computed using a multiply (high 32b of 64b result) -// and a shift: -// -// n/d = (mulhi(n, mp) + n) >> L; -struct fastdiv_values { - uint32_t mp; - uint32_t l; -}; - -static inline struct fastdiv_values init_fastdiv_values(uint32_t d) { - struct fastdiv_values result = { 0, 0 }; - // compute L = ceil(log2(d)); - while (result.l < 32 && ((uint32_t) 1 << result.l) < d) { - ++(result.l); - } - - result.mp = (uint32_t) (((uint64_t) 1 << 32) * (((uint64_t) 1 << result.l) - d) / d + 1); - return result; -} - -static inline uint32_t fastdiv(uint32_t n, const struct fastdiv_values * vals) { - // Compute high 32 bits of n * mp - const uint32_t hi = (uint32_t) (((uint64_t) n * vals->mp) >> 32); // mulhi(n, mp) - // add n, apply bit shift - return (hi + n) >> vals->l; -} - -static inline uint32_t fastmodulo(uint32_t n, uint32_t d, const struct fastdiv_values * vals) { - return n - fastdiv(n, vals) * d; -} - -static inline void htp_l2fetch(const void * p, uint32_t height, uint32_t width, uint32_t stride) { - const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); - asm volatile(" l2fetch(%0,%1) " : : "r"(p), "r"(control)); -} - -static inline int32_t htp_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { - uint32_t left_off = (size_t) addr & (chunk_size - 1); - uint32_t right_off = left_off + n; - return right_off <= chunk_size; -} - -static inline void htp_dump_int8_line(char * pref, const int8_t * x, int n) { - char str[1024], *p = str, *p_end = str + sizeof(str); - p += snprintf(p, p_end - p, "%s: ", pref); - for (int i = 0; i < n && p < p_end; i++) { - p += snprintf(p, p_end - p, "%d, ", x[i]); - } - FARF(HIGH, "%s\n", str); -} - -static inline void htp_dump_uint8_line(char * pref, const uint8_t * x, uint32_t n) { - char str[1024], *p = str, *p_end = str + sizeof(str); - p += snprintf(p, p_end - p, "%s: ", pref); - for (int i = 0; i < n && p < p_end; i++) { - p += snprintf(p, p_end - p, "%d, ", x[i]); - } - FARF(HIGH, "%s\n", str); -} - -static inline void htp_dump_int32_line(char * pref, const int32_t * x, uint32_t n) { - char str[1024], *p = str, *p_end = str + sizeof(str); - p += snprintf(p, p_end - p, "%s: ", pref); - for (int i = 0; i < n; i++) { - p += snprintf(p, p_end - p, "%d, ", (int) x[i]); - } - FARF(HIGH, "%s\n", str); -} - -static inline void htp_dump_fp16_line(char * pref, const __fp16 * x, uint32_t n) { - char str[1024], *p = str, *p_end = str + sizeof(str); - p += snprintf(p, p_end - p, "%s: ", pref); - for (int i = 0; i < n; i++) { - p += snprintf(p, p_end - p, "%.6f, ", (float) x[i]); - } - FARF(HIGH, "%s\n", str); -} - -static inline void htp_dump_fp32_line(char * pref, const float * x, uint32_t n) { - char str[1024], *p = str, *p_end = str + sizeof(str); - p += snprintf(p, p_end - p, "%s: ", pref); - for (int i = 0; i < n; i++) { - p += snprintf(p, p_end - p, "%.6f, ", x[i]); - } - FARF(HIGH, "%s\n", str); -} - -static inline void htp_dump_f32(char * pref, const float * x, uint32_t n) { - uint32_t n0 = n / 16; - uint32_t n1 = n % 16; - - uint32_t i = 0; - for (; i < n0; i++) { - htp_dump_fp32_line(pref, x + (16 * i), 16); - } - if (n1) { - htp_dump_fp32_line(pref, x + (16 * i), n1); - } -} - -static inline void htp_dump_f16(char * pref, const __fp16 * x, uint32_t n) { - uint32_t n0 = n / 16; - uint32_t n1 = n % 16; - - uint32_t i = 0; - for (; i < n0; i++) { - htp_dump_fp16_line(pref, x + (16 * i), 16); - } - if (n1) { - htp_dump_fp16_line(pref, x + (16 * i), n1); - } -} - -#endif /* OPS_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/pad-ops.c b/ggml/src/ggml-hexagon/htp/pad-ops.c new file mode 100644 index 00000000000..aaa72b31590 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/pad-ops.c @@ -0,0 +1,547 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <string.h> + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" + +/* Circular wrap: maps any integer x into [0, n) */ +static inline uint32_t wrap_around(int32_t x, uint32_t n) { + return (uint32_t)(((x % (int32_t)n) + (int32_t)n) % (int32_t)n); +} + +/* Decompose a flat dst row index into (i1, i2, i3) */ +static inline void pad_decompose_row(uint32_t ir, uint32_t ne1, uint32_t ne2, + uint32_t *i1, uint32_t *i2, uint32_t *i3) { + *i1 = ir % ne1; + *i2 = (ir / ne1) % ne2; + *i3 = ir / (ne1 * ne2); +} + +/* Return non-zero if row (i1,i2,i3) falls in the non-padded interior */ +static inline int pad_is_interior(uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t rp1, uint32_t ne1, + int32_t lp2, int32_t rp2, uint32_t ne2, + int32_t lp3, int32_t rp3, uint32_t ne3) { + return ((int32_t)i1 >= lp1 && (int32_t)i1 < (int32_t)ne1 - rp1) && + ((int32_t)i2 >= lp2 && (int32_t)i2 < (int32_t)ne2 - rp2) && + ((int32_t)i3 >= lp3 && (int32_t)i3 < (int32_t)ne3 - rp3); +} + +/* Compute the DDR src row pointer for a zero-pad interior row */ +static inline const uint8_t * pad_src_row_ptr(const struct htp_tensor * src, + uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t lp2, int32_t lp3) { + return (const uint8_t *) src->data + + (i1 - (uint32_t)lp1) * src->nb[1] + + (i2 - (uint32_t)lp2) * src->nb[2] + + (i3 - (uint32_t)lp3) * src->nb[3]; +} + +/* Compute the DDR src row pointer for a circular row (wrap-around indexing) */ +static inline const uint8_t * pad_circ_src_row_ptr(const struct htp_tensor * src, + uint32_t i1, uint32_t i2, uint32_t i3, + int32_t lp1, int32_t lp2, int32_t lp3) { + return (const uint8_t *) src->data + + wrap_around((int32_t)i1 - lp1, src->ne[1]) * src->nb[1] + + wrap_around((int32_t)i2 - lp2, src->ne[2]) * src->nb[2] + + wrap_around((int32_t)i3 - lp3, src->ne[3]) * src->nb[3]; +} + +struct htp_pad_context { + struct htp_ops_context * octx; + + int32_t lp0, rp0; + int32_t lp1, rp1; + int32_t lp2, rp2; + int32_t lp3, rp3; + + uint32_t nrows_per_thread; + uint32_t total_dst_rows; + + size_t type_size; + + // Row sizes for DMA kernel (populated when VTCM is available) + size_t src_row_size; + size_t src_row_size_aligned; + size_t dst_row_size; + size_t dst_row_size_aligned; +}; + +#define htp_pad_preamble \ + const struct htp_tensor * src = octx->src[0]; \ + const struct htp_tensor * dst = octx->dst; \ + \ + const uint32_t ne00 = src->ne[0]; \ + const uint32_t nb00 = src->nb[0]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + \ + const int32_t lp0 = pctx->lp0, rp0 = pctx->rp0; \ + const int32_t lp1 = pctx->lp1, rp1 = pctx->rp1; \ + const int32_t lp2 = pctx->lp2, rp2 = pctx->rp2; \ + const int32_t lp3 = pctx->lp3, rp3 = pctx->rp3; \ + \ + const size_t type_size = pctx->type_size; \ + \ + const uint32_t row_start = pctx->nrows_per_thread * ith; \ + const uint32_t row_end = MIN(row_start + pctx->nrows_per_thread, pctx->total_dst_rows); + + +#define htp_pad_dma_preamble \ + const size_t src_row_size = pctx->src_row_size; \ + const size_t src_row_size_aligned = pctx->src_row_size_aligned; \ + const size_t dst_row_size = pctx->dst_row_size; \ + const size_t dst_row_size_aligned = pctx->dst_row_size_aligned; \ + \ + uint8_t * src_spad_base = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; \ + uint8_t * dst_spad_base = octx->dst_spad.data + ith * octx->dst_spad.size_per_thread; \ + \ + dma_queue * dma = octx->ctx->dma[ith]; + +// --------------------------------------------------------------------------- +// HVX vectorized PAD kernel +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + uint32_t i1, i2, i3; + pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + if (!interior) { + hvx_splat_f32_u(dst_ptr, 0.0f, ne0); + } else { + const uint8_t * src_ptr = pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3); + + if (lp0 > 0) { + hvx_splat_f32_u(dst_ptr, 0.0f, (uint32_t)lp0); + } + + uint8_t * dst_row_start = dst_ptr + (size_t)lp0 * type_size; + if (nb00 == type_size) { + hvx_copy_f32_uu(dst_row_start, src_ptr, ne00); + } else { + for (uint32_t i = 0; i < ne00; i++) { + memcpy(dst_row_start + i * type_size, + src_ptr + (size_t)i * nb00, + type_size); + } + } + + if (rp0 > 0) { + hvx_splat_f32_u(dst_ptr + ((size_t)lp0 + ne00) * type_size, 0.0f, (uint32_t)rp0); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX + DMA PAD kernel — aligned, double-buffered +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_dma(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + htp_pad_dma_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + // ----------------------------------------------------------------------- + // Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the + // double-buffer pipeline before the main loop begins. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) { + uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned; + uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned; + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr((uint8_t *)dst->data, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 0); + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + const uint8_t * src_ptr = interior + ? pad_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3) : NULL; + + // Interior row: real DMA (1 row) from DDR to VTCM. + // Border row: null DMA (nrows=0) + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + src_ptr ? src_ptr : (const uint8_t *)src_spad_cur), + src_row_size_aligned, src_row_size, src_ptr ? 1 : 0); + } + + // ----------------------------------------------------------------------- + // Main loop: pop completed DMAs, compute in VTCM with aligned HVX ops, + // push dst DMA and prefetch src for the next+1 row. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start; ir < row_end; ir++) { + uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src; + uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst; + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + const int interior = pad_is_interior(i1, i2, i3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + + if (!interior) { + hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0); + } else { + hvx_splat_f32_a(dst_spad_cur, 0.0f, ne0); + + uint8_t * dst_interior = dst_spad_cur + (size_t)lp0 * type_size; + + if ((uintptr_t)dst_interior % VLEN == 0) { + hvx_copy_f32_aa(dst_interior, src_spad_cur, ne00); + } else { + hvx_copy_f32_ua(dst_interior, src_spad_cur, ne00); + } + } + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr(dst_ptr, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < row_end) { + uint32_t ni1, ni2, ni3; + pad_decompose_row(next_row, ne1, ne2, &ni1, &ni2, &ni3); + const int next_interior = pad_is_interior(ni1, ni2, ni3, + lp1, rp1, ne1, + lp2, rp2, ne2, + lp3, rp3, ne3); + const uint8_t * next_src_ptr = next_interior + ? pad_src_row_ptr(src, ni1, ni2, ni3, lp1, lp2, lp3) : NULL; + + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + next_src_ptr ? next_src_ptr : (const uint8_t *)src_spad_cur), + src_row_size_aligned, src_row_size, next_src_ptr ? 1 : 0); + } + } + + dma_queue_flush(dma); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX circular PAD kernel +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_circular(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + uint32_t i1, i2, i3; + pad_decompose_row(dst_row, ne1, ne2, &i1, &i2, &i3); + + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + const uint8_t * src_row = pad_circ_src_row_ptr(src, i1, i2, i3, lp1, lp2, lp3); + + if (nb00 == type_size) { + + if (lp0 > 0) { + if ((uint32_t)lp0 < 32) { + memcpy(dst_ptr, + src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size, + (size_t)lp0 * type_size); + } else { + hvx_copy_f32_uu(dst_ptr, + src_row + (size_t)(ne00 - (uint32_t)lp0) * type_size, + (uint32_t)lp0); + } + } + hvx_copy_f32_uu(dst_ptr + (size_t)lp0 * type_size, src_row, ne00); + if (rp0 > 0) { + if ((uint32_t)rp0 < 32) { + memcpy(dst_ptr + ((size_t)lp0 + ne00) * type_size, + src_row, + (size_t)rp0 * type_size); + } else { + hvx_copy_f32_uu(dst_ptr + ((size_t)lp0 + ne00) * type_size, + src_row, + (uint32_t)rp0); + } + } + } else { + for (uint32_t i = 0; i < (uint32_t)lp0; i++) { + *(float *)(dst_ptr + i * type_size) = + *(const float *)(src_row + (size_t)(ne00 - (uint32_t)lp0 + i) * nb00); + } + for (uint32_t i = 0; i < ne00; i++) { + *(float *)(dst_ptr + ((size_t)lp0 + i) * type_size) = + *(const float *)(src_row + (size_t)i * nb00); + } + for (uint32_t i = 0; i < (uint32_t)rp0; i++) { + *(float *)(dst_ptr + ((size_t)lp0 + ne00 + i) * type_size) = + *(const float *)(src_row + (size_t)i * nb00); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-circ %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// --------------------------------------------------------------------------- +// HVX + DMA circular PAD kernel — aligned, double-buffered +// --------------------------------------------------------------------------- + +static void pad_job_per_thread_hvx_circular_dma(unsigned int nth, unsigned int ith, void * data) { + const struct htp_pad_context * pctx = (const struct htp_pad_context *) data; + struct htp_ops_context * octx = pctx->octx; + htp_pad_preamble; + htp_pad_dma_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + // ----------------------------------------------------------------------- + // Priming phase: push 2 pairs of (dummy_dst_DMA, src_DMA) to seed the + // double-buffer pipeline. Every row is a real src DMA (no null DMAs). + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start, spad_idx = 0; ir < row_end && spad_idx < 2; ir++, spad_idx++) { + uint8_t * src_spad_cur = src_spad_base + spad_idx * src_row_size_aligned; + uint8_t * dst_spad_cur = dst_spad_base + spad_idx * dst_row_size_aligned; + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr((uint8_t *)dst->data, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 0); + + uint32_t pi1, pi2, pi3; + pad_decompose_row(ir, ne1, ne2, &pi1, &pi2, &pi3); + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, pad_circ_src_row_ptr(src, pi1, pi2, pi3, lp1, lp2, lp3)), + src_row_size_aligned, src_row_size, 1); + } + + // ----------------------------------------------------------------------- + // Main loop: pop completed DMAs, assemble circular row in VTCM with + // aligned HVX ops, push dst DMA and prefetch src for the next+1 row. + // ----------------------------------------------------------------------- + for (uint32_t ir = row_start; ir < row_end; ir++) { + uint8_t * dst_spad_cur = (uint8_t *) dma_queue_pop(dma).src; + uint8_t * src_spad_cur = (uint8_t *) dma_queue_pop(dma).dst; + + uint32_t i1, i2, i3; + pad_decompose_row(ir, ne1, ne2, &i1, &i2, &i3); + uint8_t * dst_ptr = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + + if (lp0 > 0) { + uint8_t * dst_left = dst_spad_cur; + const uint8_t * src_left = src_spad_cur + (size_t)(ne00 - (uint32_t)lp0) * type_size; + if ((uint32_t)lp0 < 32) { + memcpy(dst_left, src_left, (size_t)lp0 * type_size); + } else { + hvx_copy_f32_uu(dst_left, src_left, (uint32_t)lp0); + } + } + + { + uint8_t * dst_mid = dst_spad_cur + (size_t)lp0 * type_size; + if ((uintptr_t)dst_mid % VLEN == 0) { + hvx_copy_f32_aa(dst_mid, src_spad_cur, ne00); + } else { + hvx_copy_f32_ua(dst_mid, src_spad_cur, ne00); + } + } + + if (rp0 > 0) { + uint8_t * dst_right = dst_spad_cur + ((size_t)lp0 + ne00) * type_size; + if ((uint32_t)rp0 < 32) { + memcpy(dst_right, src_spad_cur, (size_t)rp0 * type_size); + } else { + if ((uintptr_t)dst_right % VLEN == 0) { + hvx_copy_f32_aa(dst_right, src_spad_cur, (uint32_t)rp0); + } else { + hvx_copy_f32_ua(dst_right, src_spad_cur, (uint32_t)rp0); + } + } + } + + dma_queue_push_vtcm_to_ddr(dma, + dma_make_ptr(dst_ptr, dst_spad_cur), + dst_row_size, dst_row_size_aligned, 1); + + const uint32_t next_row = ir + 2; + if (next_row < row_end) { + uint32_t nri1, nri2, nri3; + pad_decompose_row(next_row, ne1, ne2, &nri1, &nri2, &nri3); + dma_queue_push_ddr_to_vtcm(dma, + dma_make_ptr(src_spad_cur, + pad_circ_src_row_ptr(src, nri1, nri2, nri3, lp1, lp2, lp3)), + src_row_size_aligned, src_row_size, 1); + } + } + + dma_queue_flush(dma); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "pad-hvx-circ-dma %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, + src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_pad(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + // Only F32 supported + size_t type_size; + switch (src0->type) { + case HTP_TYPE_F32: type_size = 4; break; + default: + FARF(ERROR, "pad-hvx: unsupported type %u\n", src0->type); + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const int32_t lp0 = octx->op_params[0]; + const int32_t rp0 = octx->op_params[1]; + const int32_t lp1 = octx->op_params[2]; + const int32_t rp1 = octx->op_params[3]; + const int32_t lp2 = octx->op_params[4]; + const int32_t rp2 = octx->op_params[5]; + const int32_t lp3 = octx->op_params[6]; + const int32_t rp3 = octx->op_params[7]; + const int32_t circular = octx->op_params[8]; + + const uint32_t ne0 = dst->ne[0]; + const uint32_t ne00 = src0->ne[0]; + + const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows > 0 ? total_dst_rows : 1); + + const size_t src_row_size = (size_t)ne00 * type_size; + const size_t dst_row_size = (size_t)ne0 * type_size; + const size_t src_row_size_aligned = hex_round_up(src_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + // Total VTCM needed: 2 buffers (ping+pong) for src and dst, per thread + const size_t vtcm_needed = (size_t)n_threads * 2 * (src_row_size_aligned + dst_row_size_aligned); + + const int use_dma = (src0->nb[0] == (uint32_t)type_size) && + (ne00 >= 512) && + (octx->ctx->vtcm_base != NULL) && + (octx->ctx->vtcm_size >= vtcm_needed); + + if (use_dma) { + octx->src0_spad.size_per_thread = 2 * src_row_size_aligned; + octx->dst_spad.size_per_thread = 2 * dst_row_size_aligned; + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->dst_spad.src = NULL; + } + + struct htp_pad_context pctx = { + .octx = octx, + .lp0 = lp0, .rp0 = rp0, + .lp1 = lp1, .rp1 = rp1, + .lp2 = lp2, .rp2 = rp2, + .lp3 = lp3, .rp3 = rp3, + .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads, + .total_dst_rows = total_dst_rows, + .type_size = type_size, + .src_row_size = src_row_size, + .src_row_size_aligned = src_row_size_aligned, + .dst_row_size = dst_row_size, + .dst_row_size_aligned = dst_row_size_aligned, + }; + + FARF(HIGH, "pad-hvx%s%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) pads=(%d,%d,%d,%d,%d,%d,%d,%d)\n", + circular ? "-circ" : "", + use_dma ? "-dma" : "", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + + if (circular && use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular_dma, &pctx, n_threads); } + else if (circular) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_circular, &pctx, n_threads); } + else if (use_dma) { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx_dma, &pctx, n_threads); } + else { worker_pool_run_func(octx->ctx->worker_pool, pad_job_per_thread_hvx, &pctx, n_threads); } + + return HTP_STATUS_OK; +} + diff --git a/ggml/src/ggml-hexagon/htp/repeat-ops.c b/ggml/src/ggml-hexagon/htp/repeat-ops.c new file mode 100644 index 00000000000..a6f2f0ed5f3 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/repeat-ops.c @@ -0,0 +1,148 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <string.h> + +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "htp-ops.h" + +struct htp_repeat_context { + struct htp_ops_context * octx; + + uint32_t nr0; + uint32_t nr1; + uint32_t nr2; + uint32_t nr3; + + uint32_t nrows_per_thread; + uint32_t total_dst_rows; // ne1 * ne2 * ne3 + + size_t type_size; +}; + +static void repeat_job_per_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_repeat_context * rctx = (const struct htp_repeat_context *) data; + struct htp_ops_context * octx = rctx->octx; + const struct htp_tensor * src = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + const uint32_t ne00 = src->ne[0]; + const uint32_t ne01 = src->ne[1]; + const uint32_t ne02 = src->ne[2]; + const uint32_t ne03 = src->ne[3]; + + const uint32_t nb00 = src->nb[0]; + const uint32_t nb01 = src->nb[1]; + const uint32_t nb02 = src->nb[2]; + const uint32_t nb03 = src->nb[3]; + + const uint32_t ne0 = dst->ne[0]; + const uint32_t ne1 = dst->ne[1]; + const uint32_t ne2 = dst->ne[2]; + const uint32_t ne3 = dst->ne[3]; + + const uint32_t nb0 = dst->nb[0]; + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + const uint32_t nr0 = rctx->nr0; + const uint32_t nr1 = rctx->nr1; + const uint32_t nr2 = rctx->nr2; + const uint32_t nr3 = rctx->nr3; + + const size_t row_bytes = ne00 * rctx->type_size; + + const uint32_t row_start = rctx->nrows_per_thread * ith; + const uint32_t row_end = MIN(row_start + rctx->nrows_per_thread, rctx->total_dst_rows); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t dst_row = row_start; dst_row < row_end; dst_row++) { + // Decompose flat dst row index into (i1, i2, i3) + const uint32_t i1 = dst_row % ne1; + const uint32_t i2 = (dst_row / ne1) % ne2; + const uint32_t i3 = dst_row / (ne1 * ne2); + + // Map to source indices (tiling) + const uint32_t k1 = i1 % ne01; + const uint32_t k2 = i2 % ne02; + const uint32_t k3 = i3 % ne03; + + const uint8_t * src_row = (const uint8_t *) src->data + k1 * nb01 + k2 * nb02 + k3 * nb03; + uint8_t * dst_base = (uint8_t *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3; + + // Tile along dimension 0 + for (uint32_t i0 = 0; i0 < nr0; i0++) { + uint8_t * dst_ptr = dst_base + i0 * ne00 * nb0; + memcpy(dst_ptr, src_row, row_bytes); + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "repeat %d/%d: (%ux%ux%ux%u) -> (%ux%ux%ux%u) rows %u:%u usec %u\n", + ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + row_start, row_end, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_repeat(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + + // Validate that dst dims are multiples of src dims + if (dst->ne[0] % src0->ne[0] != 0 || + dst->ne[1] % src0->ne[1] != 0 || + dst->ne[2] % src0->ne[2] != 0 || + dst->ne[3] % src0->ne[3] != 0) { + FARF(ERROR, "repeat: dst dims must be multiples of src dims\n"); + return HTP_STATUS_INVAL_PARAMS; + } + + size_t type_size; + switch (src0->type) { + case HTP_TYPE_F32: type_size = 4; break; + case HTP_TYPE_F16: type_size = 2; break; + default: + FARF(ERROR, "repeat: unsupported type %u\n", src0->type); + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t total_dst_rows = dst->ne[1] * dst->ne[2] * dst->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, total_dst_rows); + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + struct htp_repeat_context rctx = { + .octx = octx, + .nr0 = dst->ne[0] / src0->ne[0], + .nr1 = dst->ne[1] / src0->ne[1], + .nr2 = dst->ne[2] / src0->ne[2], + .nr3 = dst->ne[3] / src0->ne[3], + .nrows_per_thread = (total_dst_rows + n_threads - 1) / n_threads, + .total_dst_rows = total_dst_rows, + .type_size = type_size, + }; + + FARF(HIGH, "repeat: (%ux%ux%ux%u) -> (%ux%ux%ux%u) nr=(%u,%u,%u,%u)\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + rctx.nr0, rctx.nr1, rctx.nr2, rctx.nr3); + + worker_pool_run_func(octx->ctx->worker_pool, repeat_job_per_thread, &rctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index a4399704fcb..c839044b84f 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -2,31 +2,31 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include <HAP_farf.h> -#include <HAP_mem.h> #include <HAP_perf.h> -#include <HAP_ps.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> + #include <math.h> -#include <qurt_thread.h> #include <string.h> +#include <stdlib.h> + +#include "hex-dma.h" +#include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" +#include "htp-ops.h" -// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h +// Redefined the rope type constants as we can't include ggml.h #define HTP_ROPE_TYPE_NORMAL 0 #define HTP_ROPE_TYPE_NEOX 2 +#define HTP_ROPE_TYPE_MROPE 8 +#define HTP_ROPE_TYPE_IMROPE 40 + +#define HTP_ROPE_SPAD_NROWS 16 +#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2) #define htp_rope_preamble \ const uint32_t ne00 = src0->ne[0]; \ @@ -49,7 +49,7 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -struct rope_th_ctx { +struct htp_rope_context { int32_t n_dims; int32_t mode; int32_t n_ctx_orig; @@ -64,7 +64,22 @@ struct rope_th_ctx { float theta_scale; float corr_dims[2]; + uint32_t src0_nrows_per_thread; + size_t spad_stride; + struct htp_ops_context * octx; + + size_t src0_row_size; + size_t dst_row_size; + size_t src0_row_size_aligned; + size_t dst_row_size_aligned; + size_t theta_cache_offset; + uint32_t src0_nrows; + + struct fastdiv_values div_ne2_ne1; + struct fastdiv_values div_ne1; + + uint64_t t_start; }; static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -73,7 +88,30 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { return (1 - MIN(1, MAX(0, y))); } -static void rope_cache_init(const float theta_base, +// Compute one (cos, sin) pair into cache[i0], cache[i0+1] applying YaRN scaling. +static inline void rope_yarn_one(float theta, float freq_scale, float * corr_dims, + uint32_t i0, float ext_factor, float mscale, + float * cache) { + float theta_extrap = theta; + + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = freq_scale * theta_extrap; + float theta_final = theta_interp; + float mscale_final = mscale; + + if (ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; + theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); + } + + cache[i0 + 0] = cosf(theta_final) * mscale_final; + cache[i0 + 1] = sinf(theta_final) * mscale_final; +} + +static __attribute__((noinline)) void rope_cache_init(const float theta_base, const float freq_scale, const float * freq_factors, float * corr_dims, @@ -83,30 +121,137 @@ static void rope_cache_init(const float theta_base, float * cache, const float theta_scale) { // ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py - float theta = theta_base; +#if __HVX_ARCH__ >= 79 + const bool is_v79_or_newer = true; +#else + const bool is_v79_or_newer = false; +#endif - for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { - const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + if (is_v79_or_newer && ext_factor == 0.0f) { + // Fast path: fully vectorized + // We process 32 pairs (64 elements) per iteration. + const uint32_t n_blocks = ne0 / 64; + + // Initialize theta scale powers: [1.0f, theta_scale, theta_scale^2, ..., theta_scale^31] + float __attribute__((aligned(128))) theta_powers[32]; + theta_powers[0] = 1.0f; + for (int j = 1; j < 32; j++) { + theta_powers[j] = theta_powers[j - 1] * theta_scale; + } + HVX_Vector v_theta_powers = hvx_vmem(theta_powers); + + HVX_Vector v_freq_scale = hvx_vec_splat_f32(freq_scale); + HVX_Vector v_mscale = hvx_vec_splat_f32(mscale); + + // Base theta starts at theta_base + float theta_block = theta_base; + // The scale factor for the next block is theta_scale^32 + float theta_scale_32 = 1.0f; + for (int j = 0; j < 32; j++) { + theta_scale_32 *= theta_scale; + } - float theta_extrap = theta / ff; + for (uint32_t b = 0; b < n_blocks; b++) { + uint32_t i0 = b * 64; + HVX_Vector v_theta_base = hvx_vec_splat_f32(theta_block); + HVX_Vector v_theta = hvx_vec_mul_f32_f32(v_theta_base, v_theta_powers); - // Get n-d rotational scaling corrected for extrapolation - float theta_interp = freq_scale * theta_extrap; - float theta_final = theta_interp; - float mscale_final = mscale; + if (freq_factors) { + // Load 32 elements of freq_factors + HVX_Vector v_ff = hvx_vmemu(freq_factors + i0 / 2); + HVX_Vector v_inv_ff = hvx_vec_inverse_f32(v_ff); + v_theta = hvx_vec_mul_f32_f32(v_theta, v_inv_ff); + } + + HVX_Vector v_theta_final = hvx_vec_mul_f32_f32(v_theta, v_freq_scale); + + HVX_Vector vcos = hvx_vec_cos_f32(v_theta_final); + HVX_Vector vsin = hvx_vec_sin_f32(v_theta_final); + + vcos = hvx_vec_mul_f32_f32(vcos, v_mscale); + vsin = hvx_vec_mul_f32_f32(vsin, v_mscale); - if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor; - theta_final = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + HVX_VectorPair vstore = Q6_W_vshuff_VVR(vsin, vcos, -4); - // Get n-d magnitude scaling corrected for interpolation - mscale_final *= 1.0f + 0.1f * logf(1.0f / freq_scale); + if (((uintptr_t)cache) % 128 == 0) { + hvx_vmem(cache + i0 + 0) = Q6_V_lo_W(vstore); + hvx_vmem(cache + i0 + 32) = Q6_V_hi_W(vstore); + } else { + hvx_vec_store_u(cache + i0 + 0, 32 * sizeof(float), Q6_V_lo_W(vstore)); + hvx_vec_store_u(cache + i0 + 32, 32 * sizeof(float), Q6_V_hi_W(vstore)); + } + + theta_block *= theta_scale_32; } - cache[i0 + 0] = cosf(theta_final) * mscale_final; - cache[i0 + 1] = sinf(theta_final) * mscale_final; + // Leftovers + float theta = theta_block; + for (uint32_t i0 = n_blocks * 64; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + theta *= theta_scale; + } + } else { + // Fallback to original scalar loop + float theta = theta_base; + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); + theta *= theta_scale; + } + } +} + +// pos_t/h/w/e: the four position ids for this sequence step (t=time, h=height, w=width, e=extra). +// sections[4]: number of head dims assigned to each position component. +static __attribute__((noinline)) void mrope_cache_init(const float pos_t, + const float pos_h, + const float pos_w, + const float pos_e, + const int32_t sections[4], + const bool is_imrope, + const float freq_scale, + const float * freq_factors, + float * corr_dims, + const uint32_t ne0, + const float ext_factor, + const float mscale, + float * cache, + const float theta_scale) { + const int sect_dims = sections[0] + sections[1] + sections[2] + sections[3]; + const int sec_w = sections[0] + sections[1]; + const int sec_e = sec_w + sections[2]; + + float theta_t = pos_t; + float theta_h = pos_h; + float theta_w = pos_w; + float theta_e = pos_e; + + for (uint32_t i0 = 0; i0 < ne0; i0 += 2) { + const float ff = freq_factors ? freq_factors[i0 / 2] : 1.0f; + const int sector = (i0 / 2) % sect_dims; + + float theta; + if (is_imrope) { + // Interleaved: sector mod 3 selects component + if (sector % 3 == 0 && sector < 3 * sections[0]) { theta = theta_t; } + else if (sector % 3 == 1 && sector < 3 * sections[1]) { theta = theta_h; } + else if (sector % 3 == 2 && sector < 3 * sections[2]) { theta = theta_w; } + else { theta = theta_e; } + } else { + // Contiguous sections + if (sector < sections[0]) { theta = theta_t; } + else if (sector < sec_w) { theta = theta_h; } + else if (sector < sec_e) { theta = theta_w; } + else { theta = theta_e; } + } + + rope_yarn_one(theta / ff, freq_scale, corr_dims, i0, ext_factor, mscale, cache); - theta *= theta_scale; + theta_t *= theta_scale; + theta_h *= theta_scale; + theta_w *= theta_scale; + theta_e *= theta_scale; } } @@ -124,66 +269,40 @@ static void rope_corr_dims(int n_dims, dims[1] = MIN(n_dims - 1, end); } -static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) { - memset(rope_ctx, 0, sizeof(struct rope_th_ctx)); - - const int32_t * op_params = &octx->op_params[0]; - - rope_ctx->n_dims = ((const int32_t *) op_params)[1]; - rope_ctx->mode = ((const int32_t *) op_params)[2]; - rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4]; - - memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float)); - memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float)); - memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float)); - memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float)); - memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float)); - memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float)); - memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4); - - rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims); - - rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast, - rope_ctx->beta_slow, rope_ctx->corr_dims); - - rope_ctx->octx = octx; - FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims, - rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor); -} +static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { + const uint32_t he = ne / 2; + const uint32_t nvec = he / 32; + const uint32_t nloe = he % 32; -static void hvx_calc_rope_neox_f32(const float * restrict src0, - float * restrict dst, - const int num_elems, - const float * restrict theta_cache) { - // for (int i = 0; i < num_elems; i += 2) { - //const float cos_theta = theta_cache[i + 0]; - //const float sin_theta = theta_cache[i + 1]; + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v0 = ((const HVX_Vector *) src0)[i]; + HVX_Vector v1 = hvx_vmemu(src0 + he + i * 32); - //const float x0 = src[0]; - //const float x1 = src[num_elems/2]; + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1]; - //dst[0] = x0*cos_theta - x1*sin_theta; - //dst[num_elems/2] = x0*sin_theta + x1*cos_theta; + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); - //src += 1; - //dst += 1; - // } + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(v1, Q6_V_hi_W(vcos_sin)); - const uint8_t * restrict src0_curr = (const uint8_t *) src0; - const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; - uint8_t * restrict dst_curr = (uint8_t *) dst; + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once - int half_size = (sizeof(float) * (num_elems / 2)); + ((HVX_Vector *) dst)[i] = Q6_Vsf_equals_Vqf32(v4); + hvx_vmemu(dst + he + i * 32) = Q6_Vsf_equals_Vqf32(v5); + } - for (int i = 0; i < step_of_1; i++) { - HVX_Vector v0 = *(HVX_Vector *) src0_curr; - HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size); + if (nloe > 0) { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 32); + HVX_Vector v1 = hvx_vmemu(src0 + he + nvec * 32); - HVX_Vector v2 = *(HVX_Vector *) theta_curr; - HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[nvec * 2 + 1]; - HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_lo_W(vcos_sin)); HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(v0, Q6_V_hi_W(vcos_sin)); @@ -193,48 +312,24 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0, HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); - *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5); - - src0_curr += VLEN; - theta_curr += 2 * VLEN; - dst_curr += VLEN; + hvx_vec_store_u(dst + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v4)); + hvx_vec_store_u(dst + he + nvec * 32, nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v5)); } } -static void hvx_calc_rope_f32(const float * restrict src0, - float * restrict dst, - const int num_elems, - const float * restrict theta_cache) { - // for (int i = 0; i < num_elems; i += 2) { - //const float cos_theta = theta_cache[i + 0]; - //const float sin_theta = theta_cache[i + 1]; +static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { + const uint32_t nvec = ne / 64; + const uint32_t nloe = ne % 64; - //const float x0 = src[0]; - //const float x1 = src[1]; + for (uint32_t i = 0; i < nvec; i++) { + HVX_Vector v0 = ((const HVX_Vector *) src0)[i * 2 + 0]; + HVX_Vector v1 = ((const HVX_Vector *) src0)[i * 2 + 1]; - //dst[0] = x0*cos_theta - x1*sin_theta; - //dst[1] = x0*sin_theta + x1*cos_theta; + HVX_Vector v2 = ((const HVX_Vector *) theta_cache)[i * 2 + 0]; + HVX_Vector v3 = ((const HVX_Vector *) theta_cache)[i * 2 + 1]; - //src += 2; - //dst += 2; - // } - - const uint8_t * restrict src0_curr = (const uint8_t *) src0; - const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; - uint8_t * restrict dst_curr = (uint8_t *) dst; - - int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once - - for (int i = 0; i < step_of_1; i++) { - HVX_Vector v0 = *(HVX_Vector *) src0_curr; - HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v2 = *(HVX_Vector *) theta_curr; - HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); - - HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1 - HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); @@ -246,116 +341,100 @@ static void hvx_calc_rope_f32(const float * restrict src0, HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); - *(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore); - *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore); - - src0_curr += 2 * VLEN; - theta_curr += 2 * VLEN; - dst_curr += 2 * VLEN; + ((HVX_Vector *) dst)[i * 2 + 0] = Q6_V_lo_W(vstore); + ((HVX_Vector *) dst)[i * 2 + 1] = Q6_V_hi_W(vstore); } -} -static void rope_hex_f32(struct rope_th_ctx * rope_ctx, - const uint32_t ir0, - const uint32_t ir1, - int nth, - int ith, - const int opt_path) { - struct htp_ops_context * octx = rope_ctx->octx; + if (nloe > 0) { + if (nloe <= 32) { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64); + HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64); - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(Q6_V_vzero(), v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(Q6_V_vzero(), v2, -4); - const int32_t mode = rope_ctx->mode; - const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); - htp_rope_preamble; + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - const int32_t * pos = (const int32_t *) src1->data; - - float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01)); - - const float * freq_factors = NULL; - if (src2 != NULL) { - freq_factors = (const float *) src2->data; - } + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); - const uint32_t i1_end = MIN(ir1, ne1); - const int32_t half_dims = rope_ctx->n_dims / 2; - const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float); - for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch - for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len - const int32_t p = pos[i2]; + hvx_vec_store_u(dst + nvec * 64, nloe * sizeof(float), Q6_V_lo_W(vstore)); + } else { + HVX_Vector v0 = hvx_vmemu(src0 + nvec * 64); + HVX_Vector v1 = hvx_vmemu(src0 + nvec * 64 + 32); - rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor, - rope_ctx->attn_factor, wp0, rope_ctx->theta_scale); + HVX_Vector v2 = hvx_vmemu(theta_cache + nvec * 64); + HVX_Vector v3 = hvx_vmemu(theta_cache + nvec * 64 + 32); - for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads - const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01); - float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1); + HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); + HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); - const float * src_loc = src; - float * dst_data_loc = dst_data; + HVX_Vector vx0_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx0_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_lo_W(vx0_x1), Q6_V_hi_W(vcos_sin)); + HVX_Vector vx1_c = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_lo_W(vcos_sin)); + HVX_Vector vx1_s = Q6_Vqf32_vmpy_VsfVsf(Q6_V_hi_W(vx0_x1), Q6_V_hi_W(vcos_sin)); - if (1 == opt_path) { - if (is_neox) { - hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); - } else { - hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); - } + HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); + HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - src_loc += rope_ctx->n_dims; - dst_data_loc += rope_ctx->n_dims; - } else { - for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { - const float cos_theta = wp0[i0 + 0]; - const float sin_theta = wp0[i0 + 1]; + HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); - if (is_neox) { - const float x0 = src_loc[0]; - const float x1 = src_loc[half_dims]; + ((HVX_Vector *) dst)[nvec * 2 + 0] = Q6_V_lo_W(vstore); + hvx_vec_store_u(dst + nvec * 64 + 32, (nloe - 32) * sizeof(float), Q6_V_hi_W(vstore)); + } + } +} - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta; +static void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src, + uint32_t nr, uint32_t ne0, const float * restrict theta_cache) { + #pragma unroll(4) + for (uint32_t i = 0; i < nr; i++) { + float * d = (float *) (dst + i * rctx->dst_row_size_aligned); + float * s = (float *) (src + i * rctx->src0_row_size_aligned); - src_loc += 1; - dst_data_loc += 1; - } else { - const float x0 = src_loc[0]; - const float x1 = src_loc[1]; + hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache); - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; + // fill the remain channels with data from src tensor + if (rctx->n_dims < ne0) { + hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims); + } + } +} - src_loc += 2; - dst_data_loc += 2; - } - } +static void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src, + uint32_t nr, uint32_t ne0, const float * restrict theta_cache) { + #pragma unroll(4) + for (uint32_t i = 0; i < nr; i++) { + float * d = (float *) (dst + i * rctx->dst_row_size_aligned); + float * s = (float *) (src + i * rctx->src0_row_size_aligned); - src_loc += (is_neox ? half_dims : 0); - dst_data_loc += (is_neox ? half_dims : 0); - } + hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache); - // TODO: use simd to speed up the remaining elements copy - memcpy(dst_data_loc, src_loc, remain_bytes); - } + // fill the remain channels with data from src tensor + if (rctx->n_dims < ne0) { + hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims); } } } -static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) { - struct htp_ops_context * octx = rope_ctx->octx; +static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_rope_context * rctx = (struct htp_rope_context *) data; + struct htp_ops_context * octx = rctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; htp_rope_preamble; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const uint32_t src0_nrows = rctx->src0_nrows; + const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -365,53 +444,149 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int return; } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + uint64_t tt = HAP_perf_get_qtimer_count(); - int is_aligned = 1; - int opt_path = 0; - if ((0 == htp_is_aligned((void *) src0->data, VLEN)) || (0 == htp_is_aligned((void *) src1->data, VLEN)) || - (0 == htp_is_aligned((void *) dst->data, VLEN))) { - FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n"); - is_aligned = 0; - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; - } + const int32_t mode = rctx->mode; + // MROPE and IMROPE use NEOX-style pairing for the rotation + const bool is_neox = (mode & HTP_ROPE_TYPE_NEOX) || (mode & HTP_ROPE_TYPE_MROPE); - rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path); + // VTCM setup + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + float * theta_cache = (float *) (src0_spad_base); + src0_spad_base = src0_spad_base + rctx->theta_cache_offset; + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - t2 = HAP_perf_get_qtimer_count(); + dma_queue * dma_queue = octx->ctx->dma[ith]; + const int32_t * pos = (const int32_t *) src1->data; + const float * freq_factors = src2 ? (const float *) src2->data : NULL; - FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row, - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} + const uint32_t i3_start = fastdiv(src0_start_row, &rctx->div_ne2_ne1); + const uint32_t rem = fastmodulo(src0_start_row, ne2 * ne1, &rctx->div_ne2_ne1); + const uint32_t i2_start = fastdiv(rem, &rctx->div_ne1); + const uint32_t i1_start = fastmodulo(rem, ne1, &rctx->div_ne1); + + uint32_t ir = src0_start_row; + uint32_t prev_i2 = (uint32_t) -1; + + for (uint32_t i3 = i3_start; i3 < ne3; i3++) { // batch + const uint32_t i2_init = (i3 == i3_start) ? i2_start : 0; + for (uint32_t i2 = i2_init; i2 < ne2; i2++) { // seq-len + const uint32_t i1_init = (i3 == i3_start && i2 == i2_start) ? i1_start : 0; + for (uint32_t i1 = i1_init; i1 < ne1; ) { // attn-heads + if (ir >= src0_end_row) goto done; + + // Rows in this block + const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1); + + // Depth before prefetch + uint32_t dma_depth = dma_queue_depth(dma_queue); + + // FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); + + // Prefetch loop + for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) { + pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK); + + uint32_t pi1 = i1 + pr; + uint32_t pir = ir + pr; + + // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0); + + const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01; + uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr), + rctx->src0_row_size_aligned, rctx->src0_row_size, pnr); + + // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr); + } + + // Update theta cache + if (i2 != prev_i2) { + prev_i2 = i2; + + const bool is_mrope = (rctx->mode & HTP_ROPE_TYPE_MROPE) != 0; + if (is_mrope) { + // src1 holds four position arrays stacked along ne0: + // pos[i2], pos[i2+ne2], pos[i2+ne2*2], pos[i2+ne2*3] + const bool is_imrope = (rctx->mode == HTP_ROPE_TYPE_IMROPE); + mrope_cache_init( + (float) pos[i2], + (float) pos[i2 + ne2], + (float) pos[i2 + ne2 * 2], + (float) pos[i2 + ne2 * 3], + rctx->sections, is_imrope, + rctx->freq_scale, freq_factors, rctx->corr_dims, + ne0, rctx->ext_factor, rctx->attn_factor, + theta_cache, rctx->theta_scale); + } else { + rope_cache_init(pos[i2], rctx->freq_scale, freq_factors, rctx->corr_dims, + ne0, rctx->ext_factor, rctx->attn_factor, + theta_cache, rctx->theta_scale); + } + } + + // Skip output DMA transactions from prev block (if any) + // No need to wait for those here since we're explicitly waiting for the latest prefecthes below. + for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); } + + // Compute loop + for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) { + // Number of rows to compute + cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK); + + uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src; + uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst; + + // FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); + + if (is_neox) { + rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache); + } else { + rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache); + } + + uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr); + + // Prefetch more rows (if any) + if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) { + uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK); + uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS; + uint32_t pir = ir + HTP_ROPE_SPAD_NROWS; + + const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr), + rctx->src0_row_size_aligned, rctx->src0_row_size, pnr); + + // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr); + } + } + } + } + } -static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data; +done: + dma_queue_flush(dma_queue); + tt = HAP_perf_get_qtimer_count() - tt; - rope_job_f32_per_thread(rope_ctx, n, i); + FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt)); } static int execute_op_rope_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * src2 = &octx->src2; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * src2 = octx->src[2]; + const struct htp_tensor * dst = octx->dst; - worker_callback_t op_func; - const char * op_type = NULL; - - struct rope_th_ctx rope_ctx; + const char * op_type = "rope-f32"; switch (octx->op) { case HTP_OP_ROPE: - op_func = rope_job_dispatcher_f32; - op_type = "rope-f32"; - - init_rope_ctx(&rope_ctx, octx); break; default: @@ -419,52 +594,83 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t n_threads = octx->n_threads; + const uint32_t ne0 = dst->ne[0]; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); const size_t src0_row_size = src0->nb[1]; - const size_t src1_row_size = src0_row_size; const size_t dst_row_size = dst->nb[1]; - // VTCM scratchpads for all tensors - // N rows per thread, padded to HVX vector size - octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads; - - size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; - - if (src2->ne[0]) { - FARF(HIGH, - "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u " - "dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], - dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); - } else { - FARF(HIGH, - "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, - octx->dst_spad.size); - } + // Aligned row sizes for VTCM + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 256); - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + // Calculate spad sizes per thread + size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned; + size_t dst_spad_per_thread = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned; + size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread; + + // Check if we fit in VTCM + size_t total_vtcm_needed = spad_per_thread * n_threads; + if (octx->ctx->vtcm_size < total_vtcm_needed) { + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed); return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.size_per_thread = src0_spad_per_thread; + octx->dst_spad.size_per_thread = dst_spad_per_thread; + octx->src0_spad.size = n_threads * src0_spad_per_thread; + octx->dst_spad.size = n_threads * dst_spad_per_thread; + octx->src1_spad.size = 0; + + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = NULL; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.src = NULL; + + struct htp_rope_context rctx; + memset(&rctx, 0, sizeof(struct htp_rope_context)); + + rctx.t_start = HAP_perf_get_qtimer_count(); + + rctx.octx = octx; + + const int32_t * op_params = &octx->op_params[0]; + rctx.n_dims = ((const int32_t *) op_params)[1]; + rctx.mode = ((const int32_t *) op_params)[2]; + rctx.n_ctx_orig = ((const int32_t *) op_params)[4]; + + memcpy(&rctx.freq_base, (int32_t *) op_params + 5, sizeof(float)); + memcpy(&rctx.freq_scale, (int32_t *) op_params + 6, sizeof(float)); + memcpy(&rctx.ext_factor, (int32_t *) op_params + 7, sizeof(float)); + memcpy(&rctx.attn_factor, (int32_t *) op_params + 8, sizeof(float)); + memcpy(&rctx.beta_fast, (int32_t *) op_params + 9, sizeof(float)); + memcpy(&rctx.beta_slow, (int32_t *) op_params + 10, sizeof(float)); + memcpy(&rctx.sections, (int32_t *) op_params + 11, sizeof(int) * 4); + + rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims); + + rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims); + + rctx.src0_row_size = src0_row_size; + rctx.dst_row_size = dst_row_size; + rctx.src0_row_size_aligned = src0_row_size_aligned; + rctx.dst_row_size_aligned = dst_row_size_aligned; + rctx.theta_cache_offset = theta_cache_size_aligned; + + rctx.src0_nrows = src0_nrows; + rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + + if (src0_nrows > 0) { + rctx.div_ne2_ne1 = init_fastdiv_values(dst->ne[2] * dst->ne[1]); + rctx.div_ne1 = init_fastdiv_values(dst->ne[1]); + } - uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, + rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_threads); } return err; @@ -473,7 +679,7 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { int op_rope(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_rope_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index bdd64fcc8f7..58c54967db0 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -2,69 +2,84 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include <HAP_farf.h> -#include <HAP_mem.h> #include <HAP_perf.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> + #include <math.h> #include <string.h> +#include "hex-dma.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" - -#define set_rows_preamble \ - const uint32_t ne00 = octx->src0.ne[0]; \ - const uint32_t ne01 = octx->src0.ne[1]; \ - const uint32_t ne02 = octx->src0.ne[2]; \ - const uint32_t ne03 = octx->src0.ne[3]; \ - \ - const uint32_t ne10 = octx->src1.ne[0]; \ - const uint32_t ne11 = octx->src1.ne[1]; \ - const uint32_t ne12 = octx->src1.ne[2]; \ - \ - const uint32_t nb01 = octx->src0.nb[1]; \ - const uint32_t nb02 = octx->src0.nb[2]; \ - const uint32_t nb03 = octx->src0.nb[3]; \ - \ - const uint32_t nb10 = octx->src1.nb[0]; \ - const uint32_t nb11 = octx->src1.nb[1]; \ - const uint32_t nb12 = octx->src1.nb[2]; \ - \ - const uint32_t nb1 = octx->dst.nb[1]; \ - const uint32_t nb2 = octx->dst.nb[2]; \ - const uint32_t nb3 = octx->dst.nb[3]; \ - \ - const uint32_t ne1 = octx->dst.ne[1]; \ - \ +#include "htp-ops.h" + +#define set_rows_preamble \ + const uint32_t ne00 = octx->src[0]->ne[0]; \ + const uint32_t ne01 = octx->src[0]->ne[1]; \ + const uint32_t ne02 = octx->src[0]->ne[2]; \ + const uint32_t ne03 = octx->src[0]->ne[3]; \ + \ + const uint32_t ne10 = octx->src[1]->ne[0]; \ + const uint32_t ne11 = octx->src[1]->ne[1]; \ + const uint32_t ne12 = octx->src[1]->ne[2]; \ + const uint32_t ne13 = octx->src[1]->ne[3]; \ + \ + const uint32_t nb01 = octx->src[0]->nb[1]; \ + const uint32_t nb02 = octx->src[0]->nb[2]; \ + const uint32_t nb03 = octx->src[0]->nb[3]; \ + \ + const uint32_t nb10 = octx->src[1]->nb[0]; \ + const uint32_t nb11 = octx->src[1]->nb[1]; \ + const uint32_t nb12 = octx->src[1]->nb[2]; \ + \ + const uint32_t nb1 = octx->dst->nb[1]; \ + const uint32_t nb2 = octx->dst->nb[2]; \ + const uint32_t nb3 = octx->dst->nb[3]; \ + \ + const uint32_t ne0 = octx->dst->ne[0]; \ + const uint32_t ne1 = octx->dst->ne[1]; \ + const uint32_t ne2 = octx->dst->ne[2]; \ + const uint32_t ne3 = octx->dst->ne[3]; \ + \ const uint32_t nr = ne01; -static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { +struct htp_set_rows_context { + struct htp_ops_context * octx; + struct fastdiv_values div_ne12; + struct fastdiv_values div_ne11; + uint32_t src0_nrows_per_thread; +}; + +static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data; + struct htp_ops_context * octx = srctx->octx; + set_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by rows of src0 - const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; + if (ir0 >= nr) { + return; + } const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); - const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i1 >= ne1) { @@ -72,36 +87,46 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, continue; } - const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; - const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + const uintptr_t src0_ptr = octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03; + const uintptr_t dst_ptr = octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3; // copy row - hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } } } - return HTP_STATUS_OK; + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "set-rows-f32-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } -static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) { +static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) { + struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data; + struct htp_ops_context * octx = srctx->octx; + set_rows_preamble; + uint64_t qt = HAP_perf_get_qtimer_count(); + // parallelize by rows of src0 - const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; + if (ir0 >= nr) { + return; + } const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; - const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + const bool is_i32 = (octx->src[1]->type == HTP_TYPE_I32); for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); - const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; - const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + const uintptr_t src1_addr = octx->src[1]->data + i10*nb10 + i11*nb11 + i12*nb12; uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; if (i1 >= ne1) { @@ -109,37 +134,33 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, continue; } - const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; - uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + const uint8_t* src0_ptr = (const uint8_t *) octx->src[0]->data + i*nb01 + i02*nb02 + i03*nb03; + uint8_t* dst_ptr = (uint8_t *) octx->dst->data + i1*nb1 + i02*nb2 + i03*nb3; - hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00); + hvx_copy_f16_f32_uu(dst_ptr, src0_ptr, ne00); } } } - return HTP_STATUS_OK; -} - -static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) { - set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i); -} - -static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { - set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "set-rows-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, ir0, ir1, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) qt); } int op_set_rows(struct htp_ops_context * octx) { set_rows_preamble; - if (octx->src0.type != HTP_TYPE_F32) { + const uint32_t n_threads = MIN(nr, octx->n_threads); + + if (octx->src[0]->type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } - if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) { + if (octx->dst->type != HTP_TYPE_F32 && octx->dst->type != HTP_TYPE_F16) { return HTP_STATUS_NO_SUPPORT; } - if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + if (octx->src[1]->type != HTP_TYPE_I32 && octx->src[1]->type != HTP_TYPE_I64) { return HTP_STATUS_NO_SUPPORT; } @@ -147,18 +168,19 @@ int op_set_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - octx->set_rows_div_ne12 = init_fastdiv_values(ne12); - octx->set_rows_div_ne11 = init_fastdiv_values(ne11); + struct htp_set_rows_context srctx; + srctx.octx = octx; + srctx.div_ne12 = init_fastdiv_values(ne12); + srctx.div_ne11 = init_fastdiv_values(ne11); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; - switch(octx->dst.type) { + switch(octx->dst->type) { case HTP_TYPE_F32: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads); break; case HTP_TYPE_F16: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads); break; default: return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 80d249a22c6..d78bcc0eb24 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -2,92 +2,127 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif #include <HAP_farf.h> -#include <HAP_mem.h> #include <HAP_perf.h> -#include <HAP_ps.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> + #include <math.h> -#include <qurt_thread.h> #include <string.h> +#include "hex-dma.h" +#include "hvx-utils.h" +#include "hex-fastdiv.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" - -#define htp_softmax_preamble3 \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t ne10 = (src1->ne[0]) ? src1->ne[0] : 1; \ - const uint32_t ne11 = (src1->ne[0]) ? src1->ne[1] : 1; \ - const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; \ - const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; \ - \ - const uint32_t nb10 = (src1->ne[0]) ? src1->nb[0] : 1; \ - const uint32_t nb11 = (src1->ne[0]) ? src1->nb[1] : 1; \ - const uint32_t nb12 = (src1->ne[0]) ? src1->nb[2] : 1; \ - const uint32_t nb13 = (src1->ne[0]) ? src1->nb[3] : 1; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#include "htp-ops.h" + +#define htp_softmax_preamble3 \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne10 = src1 ? src1->ne[0] : 1; \ + const uint32_t ne11 = src1 ? src1->ne[1] : 1; \ + const uint32_t ne12 = src1 ? src1->ne[2] : 1; \ + const uint32_t ne13 = src1 ? src1->ne[3] : 1; \ + \ + const uint32_t nb10 = src1 ? src1->nb[0] : 1; \ + const uint32_t nb11 = src1 ? src1->nb[1] : 1; \ + const uint32_t nb12 = src1 ? src1->nb[2] : 1; \ + const uint32_t nb13 = src1 ? src1->nb[3] : 1; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -struct softmax_th_ctx { +struct htp_softmax_context { + struct htp_ops_context * octx; + bool use_f16; bool use_src1; + uint32_t n_head; uint32_t n_head_log2; - float scale; - float max_bias; - float m0; - float m1; + float scale; + float max_bias; + float m0; + float m1; - struct htp_ops_context * octx; + struct fastdiv_values fastdiv_ne01; + struct fastdiv_values fastdiv_ne02; + struct fastdiv_values fastdiv_ne12; // For mask broadcasting + struct fastdiv_values fastdiv_ne13; // For mask broadcasting + + uint32_t src0_nrows_per_thread; }; -static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; +static void apply_mask(float * restrict wp0, + const float * restrict mp_f32, + const __fp16 * restrict mp_f16, + uint32_t ne00, + float slope, + bool use_f16) { + if (!mp_f32) { + return; + } + if (use_f16) { + for (uint32_t i = 0; i < ne00; ++i) { + wp0[i] += slope * (float) mp_f16[i]; + } + } else { + for (uint32_t i = 0; i < ne00; ++i) { + wp0[i] += slope * mp_f32[i]; + } + } +} + +static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + + memset(smctx, 0, sizeof(struct htp_softmax_context)); + + memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float)); + memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); + + smctx->n_head = src0->ne[2]; + smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head)); + + smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2); + smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2); - memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx)); + smctx->use_src1 = (src1 != 0); + smctx->use_f16 = (src1 != 0) && (src1->type == HTP_TYPE_F16); - memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float)); - memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); + smctx->octx = octx; - softmax_ctx->n_head = src0->ne[2]; - softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head)); + // Initialize fastdiv values + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; - softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2); - softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2); + if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01); + if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02); - softmax_ctx->use_src1 = (src1->ne[0] != 0); - softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); + const uint32_t ne12 = src1 ? src1->ne[2] : 1; + const uint32_t ne13 = src1 ? src1->ne[3] : 1; - softmax_ctx->octx = octx; + if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12); + if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13); } static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, @@ -100,8 +135,8 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, uint8_t * restrict dst_curr = dst; const uint8_t * restrict mask_curr = mask; - HVX_Vector scale_vec = hvx_vec_splat_fp32(scale); - HVX_Vector slope_vec = hvx_vec_splat_fp32(slope); + HVX_Vector scale_vec = hvx_vec_splat_f32(scale); + HVX_Vector slope_vec = hvx_vec_splat_f32(slope); int step_of_1 = num_elems >> 5; @@ -125,18 +160,15 @@ static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, } } -static void hvx_fast_softmax_f32(const uint8_t * restrict src, - uint8_t * restrict dst, - uint8_t * restrict pad, - const int num_elems) { +static void hvx_fast_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict pad, const int num_elems) { const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_pad = (HVX_Vector *) pad; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; HVX_Vector sum_vec = Q6_V_vsplat_R(0x00000000); - HVX_Vector max_vec = hvx_vec_splat_fp32(((const float *) src)[0]); + HVX_Vector max_vec = hvx_vec_splat_f32(((const float *) src)[0]); HVX_Vector zero_v = Q6_V_vzero(); - HVX_Vector one_v = hvx_vec_splat_fp32(1.0); + HVX_Vector one_v = hvx_vec_splat_f32(1.0); int step_of_1 = num_elems >> 5; @@ -146,26 +178,24 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1); } - HVX_Vector v = hvx_vec_reduce_max_fp32(max_vec); - max_vec = hvx_vec_repl4(v); + max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes #pragma unroll(4) for (int i = 0; i < step_of_1; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, max_vec); - HVX_Vector v3 = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(v2)); + HVX_Vector v3 = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(v2)); sum_vec = Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(sum_vec), v3); v_pad[i] = v3; } - v = hvx_vec_qf32_reduce_sum(sum_vec); - sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v)); + sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v); - HVX_Vector v4 = hvx_vec_inverse_fp32(sum_vec); + HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec); HVX_Vector scale_vec = Q6_V_vmux_QVV(pos_sum, v4, one_v); #pragma unroll(4) @@ -176,106 +206,25 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, } } -static float hvx_softmax_f32(const uint8_t * restrict src, - uint8_t * restrict dst, - uint8_t * restrict spad, - const int num_elems, - const float max) { - hvx_sub_scalar_f32(src, max, spad, num_elems); +static float hvx_softmax_f32(const uint8_t * restrict src, uint8_t * restrict dst, uint8_t * restrict spad, const int num_elems, const float max) { + hvx_sub_scalar_f32(spad, src, max, num_elems); - hvx_exp_f32(spad, dst, num_elems, false); - - float sum = hvx_self_sum_f32(dst, num_elems); - - return sum; + hvx_exp_f32(dst, spad, num_elems, false); + return hvx_reduce_sum_f32(dst, num_elems); } -static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) { - struct htp_ops_context * octx = softmax_ctx->octx; +static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_softmax_context * smctx = (struct htp_softmax_context *) data; + struct htp_ops_context * octx = smctx->octx; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * dst = &octx->dst; - - htp_softmax_preamble3; - - uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01); - uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01); - uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1); - - float * wp0 = (float *) src0_spad_data; - float * wp1 = (float *) src1_spad_data; - float * wp2 = (float *) dst_spad_data; - - for (uint32_t i03 = 0; i03 < ne03; i03++) { - for (uint32_t i02 = 0; i02 < ne02; i02++) { - for (uint32_t i01 = ith; i01 < ne01; i01 += nth) { - const uint32_t i11 = i01; - const uint32_t i12 = i02 % ne12; - const uint32_t i13 = i03 % ne13; - - // ALiBi - const uint32_t h = i02; // head - - const float slope = (softmax_ctx->max_bias > 0.0f) ? - h < softmax_ctx->n_head_log2 ? - powf(softmax_ctx->m0, h + 1) : - powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) : - 1.0f; - - float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03); - float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3); - - // broadcast the mask across rows - __fp16 * mp_f16 = (softmax_ctx->use_src1) ? - (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - float * mp_f32 = (softmax_ctx->use_src1) ? - (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - - if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) { - hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, - (const uint8_t *) mp_f32, slope); - } else { - hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale); - if (mp_f32) { - if (softmax_ctx->use_f16) { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * (float) mp_f16[i]; - } - } else { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * mp_f32[i]; - } - } - } - } - - if (1 == opt_path) { - hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); - } else { - float max = hvx_self_max_f32((const uint8_t *) wp0, ne00); - float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); - sum = sum > 0.0 ? (1.0 / sum) : 1; - hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); - } - } - } - } -} - -static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) { - struct htp_ops_context * octx = softmax_ctx->octx; - - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; htp_softmax_preamble3; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -285,75 +234,136 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int return; } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + uint64_t qt = HAP_perf_get_qtimer_count(); int is_aligned = 1; int opt_path = 0; - if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { + + if (!hex_is_aligned((void *) src0->data, VLEN) || !hex_is_aligned((void *) dst->data, VLEN)) { is_aligned = 0; FARF(HIGH, "softmax-f32: unaligned addresses in elementwise op, possibly slower execution\n"); } + + // Only use the fast path when aligned AND row size is multiple of VLEN (128 bytes) + // The fast path (hvx_fast_softmax_f32) doesn't handle tail elements + // The non-opt path uses hvx_softmax_f32 which properly handles all sizes via its helper functions if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { opt_path = 1; } - softmax_htp_f32(nth, ith, softmax_ctx, opt_path); + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - t2 = HAP_perf_get_qtimer_count(); + float * wp0 = (float *) src0_spad_data; + float * wp1 = (float *) src1_spad_data; + float * wp2 = (float *) dst_spad_data; - FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, - ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} + uint32_t prev_i2 = (uint32_t)-1; + float slope = 1.0f; + + for (uint32_t r = src0_start_row; r < src0_end_row; ++r) { + uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01); + uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01); + uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02); + uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02); + + // Map to original logic indices + // i01 = i1 + // i02 = i2 + // i03 = i3 + + const uint32_t i11 = i1; + // const uint32_t i12 = i2 % ne12; + // const uint32_t i13 = i3 % ne13; + + uint32_t i12, i13; + if (ne12 == ne02) { + i12 = i2; + } else { + i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12); + } -static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) { - struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data; - softmax_job_f32_per_thread(p_softmax_ctx, n, i); + if (ne13 == ne03) { + i13 = i3; + } else { + i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13); + } + + // ALiBi + if (i2 != prev_i2) { + const uint32_t h = i2; // head + slope = (smctx->max_bias > 0.0f) ? h < smctx->n_head_log2 ? powf(smctx->m0, h + 1) : powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : 1.0f; + prev_i2 = i2; + } + + float * sp = (float *) ((char *) src0->data + i1 * nb01 + i2 * nb02 + i3 * nb03); + float * dp = (float *) ((char *) dst->data + i1 * nb1 + i2 * nb2 + i3 * nb3); + + // broadcast the mask across rows + __fp16 * mp_f16 = (smctx->use_src1) ? (__fp16 *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL; + float * mp_f32 = (smctx->use_src1) ? (float *) ((char *) src1->data + i11 * nb11 + i12 * nb12 + i13 * nb13) : NULL; + + if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) { + hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, (const uint8_t *) mp_f32, slope); + hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); + } else if (1 == opt_path) { + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); + apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16); + hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); + } else { + // Non-optimized path: uses HVX helper functions that properly handle all tensor sizes + // including non-multiples of 32 (the HVX vector lane count for f32) + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); + apply_mask(wp0, mp_f32, mp_f16, ne00, slope, smctx->use_f16); + float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); + float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); + sum = sum > 0.0 ? (1.0 / sum) : 1; + hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); + } + } + + qt = HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - qt); + FARF(HIGH, "softmax-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u : opt %u f16 %u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, + ne0, ne1, ne2, ne3, opt_path, smctx->use_f16, (unsigned) qt); } static int execute_op_softmax_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * src1 = octx->src[1]; + const struct htp_tensor * dst = octx->dst; - worker_callback_t op_func; - const char * op_type = NULL; + struct htp_softmax_context smctx; + const char * op_type = "softmax-f32"; - struct softmax_th_ctx softmax_ctx; + init_softmax_ctx(&smctx, octx); - switch (octx->op) { - case HTP_OP_SOFTMAX: - op_func = softmax_job_dispatcher_f32; - op_type = "softmax-f32"; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); - init_softmax_ctx(&softmax_ctx, octx); - break; - - default: - FARF(ERROR, "Unsupported Op %u\n", octx->op); - return HTP_STATUS_NO_SUPPORT; - } - - const uint32_t n_threads = octx->n_threads; + smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; const size_t src0_row_size = src0->nb[1]; const size_t src1_row_size = src0_row_size; const size_t dst_row_size = dst->nb[1]; // VTCM scratchpads for all tensors - // N rows per thread, padded to HVX vector size - octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = htp_round_up(src1_row_size, 128) * n_threads; + // 4 rows per thread, padded to HVX vector size + octx->src0_spad.size_per_thread = hex_round_up(4 * src0_row_size, 128); + octx->src1_spad.size_per_thread = hex_round_up(4 * src1_row_size, 128); + octx->dst_spad.size_per_thread = hex_round_up(4 * dst_row_size, 128); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; - if (src1->ne[0]) { - FARF(HIGH, - "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", + if (src1) { + FARF(HIGH, "%s: %ux%ux%ux%u x %ux%ux%ux%u -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); @@ -365,22 +375,17 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src0_spad.data = octx->ctx->vtcm_base; octx->src0_spad.src = NULL; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->src1_spad.src = NULL; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.src = NULL; - uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) return err; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs); - } + worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads); return err; } @@ -388,7 +393,7 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { int op_softmax(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_softmax_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/solve-tri-ops.c b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c new file mode 100644 index 00000000000..ae8e1a50495 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/solve-tri-ops.c @@ -0,0 +1,267 @@ +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> +#include <string.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "hvx-types.h" +#include "hvx-utils.h" + +struct htp_solve_tri_context { + struct htp_ops_context * octx; + uint32_t jobs_per_thread; + uint32_t total_jobs; + uint32_t k_chunks; + uint32_t col_block; +}; + +static inline void solve_tri_row_scalar(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + for (uint32_t col = col0; col < col0 + coln; ++col) { + float sum = 0.0f; + for (uint32_t t = 0; t < row; ++t) { + sum += A_row[t] * X[t * k + col]; + } + X[row * k + col] = (B_row[col] - sum) * inv_diag; + } +} + +static inline HVX_Vector hvx_load_partial_f32(const float * src, uint32_t n) { + HVX_Vector v = *((const HVX_UVector *) src); + HVX_VectorPred mask = Q6_Q_vsetq2_R(n * sizeof(float)); + return Q6_V_vmux_QVV(mask, v, Q6_V_vzero()); +} + +static inline void solve_tri_row_hvx(const float * A_row, + const float * B_row, + float * X, + uint32_t row, + uint32_t k, + uint32_t col0, + uint32_t coln, + float inv_diag) { + const bool full = (coln == VLEN_FP32); + + HVX_Vector sum_v = Q6_V_vzero(); + for (uint32_t t = 0; t < row; ++t) { + const float a = A_row[t]; + const float * x_row_col = X + t * k + col0; + + HVX_Vector x_v = full ? *((const HVX_UVector *) x_row_col) : hvx_load_partial_f32(x_row_col, coln); + HVX_Vector a_v = hvx_vec_splat_f32(a); + sum_v = hvx_vec_add_f32_f32(sum_v, hvx_vec_mul_f32_f32(x_v, a_v)); + } + + const float * b_row_col = B_row + col0; + float * x_out_col = X + row * k + col0; + + HVX_Vector b_v = full ? *((const HVX_UVector *) b_row_col) : hvx_load_partial_f32(b_row_col, coln); + HVX_Vector inv_diag_v = hvx_vec_splat_f32(inv_diag); + + HVX_Vector out_v = hvx_vec_mul_f32_f32(hvx_vec_sub_f32_f32(b_v, sum_v), inv_diag_v); + hvx_vec_store_u((void *) x_out_col, coln * sizeof(float), out_v); +} + +// Batch-level thread: each job is one full batch. +static void solve_tri_batch_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_full = (k / col_block) * col_block; + + const uint32_t start_batch = sctx->jobs_per_thread * ith; + const uint32_t end_batch = MIN(start_batch + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t batch = start_batch; batch < end_batch; ++batch) { + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + uint32_t col0 = 0; + for (; col0 < k_full; col0 += col_block) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, col_block, inv_diag); + } + + if (col0 < k) { + const uint32_t coln = k - col0; + if (coln >= 8) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-batch %d/%d: A=(%ux%u) B=(%ux%u) batch %u:%u usec %u\n", + ith, nth, n, n, k, n, start_batch, end_batch, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// Chunk-level thread: each job is one (batch, col_chunk) pair. +static void solve_tri_chunk_thread_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_solve_tri_context * sctx = (struct htp_solve_tri_context *) data; + struct htp_ops_context * octx = sctx->octx; + + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + const uint32_t n = src0->ne[0]; + const uint32_t k = src1->ne[0]; + + const uint32_t ne02 = src0->ne[2]; + + const uint32_t start_job = sctx->jobs_per_thread * ith; + const uint32_t end_job = MIN(start_job + sctx->jobs_per_thread, sctx->total_jobs); + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + for (uint32_t job = start_job; job < end_job; ++job) { + const uint32_t batch = job / sctx->k_chunks; + const uint32_t chunk = job - batch * sctx->k_chunks; + + const uint32_t i03 = batch / ne02; + const uint32_t i02 = batch - i03 * ne02; + + const uint32_t col0 = chunk * sctx->col_block; + const uint32_t coln = MIN(sctx->col_block, k - col0); + + const float * A_batch = + (const float *) ((const uint8_t *) (uintptr_t) src0->data + i02 * src0->nb[2] + i03 * src0->nb[3]); + const float * B_batch = + (const float *) ((const uint8_t *) (uintptr_t) src1->data + i02 * src1->nb[2] + i03 * src1->nb[3]); + float * X_batch = (float *) ((uint8_t *) (uintptr_t) dst->data + i02 * dst->nb[2] + i03 * dst->nb[3]); + + const bool use_hvx = (coln >= 8); + + for (uint32_t row = 0; row < n; ++row) { + const float diag = A_batch[row * n + row]; + const float inv_diag = 1.0f / diag; + + const float * A_row = A_batch + row * n; + const float * B_row = B_batch + row * k; + + if (use_hvx) { + solve_tri_row_hvx(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } else { + solve_tri_row_scalar(A_row, B_row, X_batch, row, k, col0, coln, inv_diag); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "solve-tri-chunk %d/%d: A=(%ux%u) B=(%ux%u) job %u:%u usec %u\n", + ith, nth, n, n, k, n, start_job, end_job, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_solve_tri(struct htp_ops_context * octx) { + const struct htp_tensor * src0 = octx->src[0]; // A + const struct htp_tensor * src1 = octx->src[1]; // B + const struct htp_tensor * dst = octx->dst; // X + + if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // left=true, lower=true, uni=false only + if (src0->ne[0] != src0->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[1] != src1->ne[1]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + if (dst->ne[0] != src1->ne[0] || dst->ne[1] != src1->ne[1] || dst->ne[2] != src1->ne[2] || + dst->ne[3] != src1->ne[3]) { + return HTP_STATUS_INVAL_PARAMS; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t k = src1->ne[0]; + + const uint32_t col_block = VLEN_FP32; + const uint32_t k_chunks = (k + col_block - 1) / col_block; + const uint32_t total_batches = src0->ne[2] * src0->ne[3]; + const bool batched = total_batches >= (uint32_t) octx->n_threads; + + FARF(HIGH, "solve-tri: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : batched %d\n", + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], batched); + + if (batched) { + // Batch-level parallelism + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, total_batches); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_batches + n_threads - 1) / n_threads, + .total_jobs = total_batches, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_batch_thread_f32, &sctx, n_threads); + } else { + // Chunk-level parallelism + const uint32_t total_jobs = total_batches * k_chunks; + const uint32_t n_threads = MIN((uint32_t) octx->n_threads, MAX(total_jobs, 1)); + + struct htp_solve_tri_context sctx = { + .octx = octx, + .jobs_per_thread = (total_jobs + n_threads - 1) / n_threads, + .total_jobs = total_jobs, + .k_chunks = k_chunks, + .col_block = col_block, + }; + + worker_pool_run_func(octx->ctx->worker_pool, solve_tri_chunk_thread_f32, &sctx, n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c new file mode 100644 index 00000000000..d574da2e2bc --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -0,0 +1,432 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_mem.h> +#include <HAP_perf.h> +#include <HAP_ps.h> +#include <hexagon_protos.h> +#include <hexagon_types.h> +#include <math.h> +#include <qurt_thread.h> +#include <string.h> + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "hex-dma.h" +#include "htp-ops.h" +#include "htp-ops.h" +#include "hvx-utils.h" + +#define htp_ssm_conv_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_ssm_conv_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint32_t d_inner_tile; + uint64_t t_start; +}; + +#define htp_ssm_conv_preamble \ + struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \ + struct htp_ops_context * octx = scctx->octx; \ + htp_ssm_conv_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; + +// Scalar FP32 SSM_CONV implementation +static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + + const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension + const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); // stride for sequence dimension + const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension + const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); // stride for token dimension + const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); // stride for sequence dimension + + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; + + // Calculate row range for this thread + const uint32_t d_inner_per_thread = scctx->nrows_per_thread; + const uint32_t d_inner_start = d_inner_per_thread * ith; + const uint32_t d_inner_end = MIN(d_inner_start + d_inner_per_thread, d_inner); + + // No work for this thread + if (d_inner_start >= d_inner_end) { + return; + } + + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + for (uint32_t i2 = 0; i2 < n_t; ++i2) { + for (uint32_t i1 = d_inner_start; i1 < d_inner_end; ++i1) { + float sumf = 0.0f; + + for (uint32_t i0 = 0; i0 < d_conv; ++i0) { + const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq; + const uint32_t src1_idx = i0 + i1 * src1_stride_inner; + + sumf += src0_data[src0_idx] * src1_data[src1_idx]; + } + + const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq; + dst_data[dst_idx] = sumf; + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end, + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + + +// In-register 32x32 fp32 transpose using std 5-stage HVX vshuff butterfly. +static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) { + HVX_Vector tmp[32]; + + // Stage 0 (R = -4): pair (2i, 2i+1) for i = 0..15. m -> tmp. + for (int i = 0; i < 16; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[2*i + 1], m[2*i], -4); + tmp[2*i + 0] = Q6_V_lo_W(p); + tmp[2*i + 1] = Q6_V_hi_W(p); + } + + // Stage 1 (R = -8): per block of 4, pair (b+0, b+2) and (b+1, b+3). tmp -> m. + for (int b = 0; b < 32; b += 4) { + HVX_VectorPair p0 = Q6_W_vshuff_VVR(tmp[b + 2], tmp[b + 0], -8); + HVX_VectorPair p1 = Q6_W_vshuff_VVR(tmp[b + 3], tmp[b + 1], -8); + m[b + 0] = Q6_V_lo_W(p0); m[b + 1] = Q6_V_hi_W(p0); + m[b + 2] = Q6_V_lo_W(p1); m[b + 3] = Q6_V_hi_W(p1); + } + + // Stage 2 (R = -16): per block of 8, pair (b+i, b+i+4) for i = 0..3. m -> tmp. + for (int b = 0; b < 32; b += 8) { + for (int i = 0; i < 4; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[b + i + 4], m[b + i], -16); + tmp[b + 2*i + 0] = Q6_V_lo_W(p); + tmp[b + 2*i + 1] = Q6_V_hi_W(p); + } + } + + // Stage 3 (R = -32): per block of 16, pair (b+i, b+i+8) for i = 0..7. tmp -> m. + for (int b = 0; b < 32; b += 16) { + for (int i = 0; i < 8; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(tmp[b + i + 8], tmp[b + i], -32); + m[b + 2*i + 0] = Q6_V_lo_W(p); + m[b + 2*i + 1] = Q6_V_hi_W(p); + } + } + + // Stage 4 (R = -64): pair (i, i+16) for i = 0..15. m -> tmp -> m. + for (int i = 0; i < 16; ++i) { + HVX_VectorPair p = Q6_W_vshuff_VVR(m[i + 16], m[i], -64); + tmp[2 * i + 0] = Q6_V_lo_W(p); + tmp[2 * i + 1] = Q6_V_hi_W(p); + } + + for (int i = 0; i < 32; ++i) { + m[i] = tmp[i]; + } +} + +// HVX FP32 SSM_CONV implementation - channel-vectorized HVX kernel with src0/src1 +// transposed into VTCM. +// +// VTCM layouts (per thread): +// src1_T : {d_inner_per_thread, d_conv} — staged once per launch (small). +// src0_T : {d_inner_tile, ncs} — staged per d_inner-tile. +// +// d_inner_tile is chosen so that per-thread VTCM stays under the budget. +// Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially. +#define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread + +// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM) +static inline void transpose_src1(const float * src1_data, + uint32_t src1_stride_inner, + uint32_t i1_off, + uint32_t d_inner_per_thread, + uint32_t d_conv, + float * src1_T) { + for (uint32_t i = 0; i < d_inner_per_thread; ++i) { + const float * src_row = src1_data + (i1_off + i) * src1_stride_inner; + for (uint32_t j = 0; j < d_conv; ++j) { + src1_T[j * d_inner_per_thread + i] = src_row[j]; + } + } +} + +// HVX 32x32 src0 transpose: src0 {ncs, d_inner} (DDR) -> src0_T {d_inner_tile, ncs} (VTCM) +static inline void transpose_src0_block(const float * src0_block, + uint32_t ncs, + uint32_t cb_n, + uint32_t d_inner_tile, + float * src0_T_block_dst, + uint32_t cb /* dst column offset */) { + const uint32_t T_TILE = VLEN_FP32; + + HVX_Vector __attribute__((aligned(VLEN))) sub[32]; + + for (uint32_t t0 = 0; t0 < ncs; t0 += T_TILE) { + const uint32_t t_n = MIN(T_TILE, ncs - t0); + + // Load 32 rows (channels) of T_TILE samples; pad missing channels with zeros. + for (uint32_t r = 0; r < cb_n; ++r) { + const float * src_row = src0_block + r * ncs + t0; + if (t_n == T_TILE) { + sub[r] = *(const HVX_UVector *) src_row; + } else { + HVX_Vector v = hvx_vec_splat_f32(0.0f); + hvx_vec_store_u(&v, t_n * sizeof(float), hvx_vec_splat_f32(0.0f)); + + float __attribute__((aligned(VLEN))) tmp[VLEN_FP32] = { 0 }; + for (uint32_t k = 0; k < t_n; ++k) tmp[k] = src_row[k]; + v = *(const HVX_Vector *) tmp; + sub[r] = v; + } + } + for (uint32_t r = cb_n; r < T_TILE; ++r) { + sub[r] = hvx_vec_splat_f32(0.0f); + } + + hvx_transpose_32x32_f32(sub); + + // Store transposed sub-tile to src0_T at offsets (t0 + j) * d_inner_tile + cb. + // Only write the valid t_n rows of the transposed result. + for (uint32_t r = 0; r < t_n; ++r) { + float * dst = src0_T_block_dst + (t0 + r) * d_inner_tile + cb; + if (cb_n == T_TILE) { + *(HVX_UVector *) dst = sub[r]; + } else { + hvx_vec_store_u(dst, cb_n * sizeof(float), sub[r]); + } + } + } +} + +static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + const uint32_t ncs = src0->ne[0]; + + const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); + const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); + const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); + const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); + const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); + + const uint32_t dr = scctx->nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, d_inner); + + if (ir0 >= ir1) { + return; + } + + const uint32_t d_inner_per_thread = ir1 - ir0; + const uint32_t d_inner_tile = scctx->d_inner_tile; + + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; + + // Per-thread VTCM regions. + float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread); + float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread); + + // Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout. + transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T); + + const uint32_t C_TILE = VLEN_FP32; + + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + for (uint32_t tile_off = 0; tile_off < d_inner_per_thread; tile_off += d_inner_tile) { + const uint32_t tile_n = MIN(d_inner_tile, d_inner_per_thread - tile_off); + + // Place src0 chunk into VTCM in {d_inner_tile, ncs} layout. + const float * src0_block = src0_data + i3 * src0_stride_seq + (ir0 + tile_off) * src0_stride_inner; + + for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) { + const uint32_t cb_n = MIN(C_TILE, tile_n - cb); + transpose_src0_block(src0_block + cb * src0_stride_inner, ncs, cb_n, d_inner_tile, src0_T, cb); + } + + for (uint32_t t = 0; t < n_t; ++t) { + for (uint32_t cb = 0; cb < tile_n; cb += C_TILE) { + const uint32_t cb_n = MIN(C_TILE, tile_n - cb); + + HVX_Vector acc = hvx_vec_splat_f32(0.0f); + for (uint32_t j = 0; j < d_conv; ++j) { + HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb); + HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb); + acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w)); + } + HVX_Vector res = Q6_Vsf_equals_Vqf32(acc); + + float * dst_ptr = dst_data + i3 * dst_stride_seq + t * dst_stride_token + (ir0 + tile_off + cb); + if (cb_n == C_TILE) { + *(HVX_UVector *) dst_ptr = res; + } else { + hvx_vec_store_u(dst_ptr, cb_n * sizeof(float), res); + } + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) tile=%u * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, d_inner_tile, + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_ssm_conv_f32(struct htp_ops_context * octx) { + htp_ssm_conv_tensors_preamble; + + if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { + FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported"); + return HTP_STATUS_NO_SUPPORT; + } + + struct htp_ssm_conv_context scctx = { 0 }; + scctx.octx = octx; + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; // tokens per sequence + const uint32_t n_s = dst->ne[2]; // number of sequences in the batch + + const uint32_t n_threads = MIN(octx->n_threads, d_inner); + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t use_hvx = 0; + if (d_inner >= VLEN_FP32 && n_t >= VLEN_FP32) { + use_hvx = 1; + } + + scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; + scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); + + const uint32_t d_inner_per_thread = scctx.nrows_per_thread; + const uint32_t ncs = src0->ne[0]; + + const uint32_t src1_T_size = hex_round_up(d_conv * d_inner_per_thread * sizeof(float), 256); + const uint32_t src0_T_max = HTP_SSM_CONV_VTCM_BUDGET > src1_T_size ? HTP_SSM_CONV_VTCM_BUDGET - src1_T_size : 0; + + uint32_t d_inner_tile = (src0_T_max / sizeof(float)) / ncs; + d_inner_tile -= (d_inner_tile % VLEN_FP32); + if (d_inner_tile == 0) { + FARF(HIGH, "ssm_conv-f32: inner tile rounds to 0 (ncs=%u), falling back to scalar\n", ncs); + use_hvx = 0; + } else { + scctx.d_inner_tile = d_inner_tile; + + octx->src0_spad.size_per_thread = hex_round_up(d_inner_tile * ncs * sizeof(float), 256); + octx->src1_spad.size_per_thread = src1_T_size; + octx->dst_spad.size_per_thread = 0; + + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = 0; + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + + const size_t total_spad = octx->src0_spad.size + octx->src1_spad.size; + if (total_spad > octx->ctx->vtcm_size) { + FARF(HIGH, "ssm_conv-f32: scratchpad %zu exceeds VTCM %zu, falling back to scalar\n", + total_spad, octx->ctx->vtcm_size); + use_hvx = 0; + } + } + + FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n", src0->ne[0], + src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], + dst->ne[1], dst->ne[2], dst->ne[3], use_hvx); + + if (use_hvx) { + worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32_hvx, &scctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32, &scctx, n_threads); + } + } + + return HTP_STATUS_OK; +} + +int op_ssm_conv(struct htp_ops_context * octx) { + const struct htp_tensor * dst = octx->dst; + + int err = HTP_STATUS_OK; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_ssm_conv_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c new file mode 100644 index 00000000000..874c41ab2ac --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -0,0 +1,128 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include <HAP_farf.h> +#include <HAP_perf.h> + +#include <string.h> +#include <math.h> + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-ops.h" +#include "htp-ops.h" + +#define sum_rows_preamble \ + const struct htp_tensor *src0 = octx->src[0]; \ + const struct htp_tensor *dst = octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + +struct sum_rows_context { + const uint8_t * src_data; + uint8_t * dst_data; + uint32_t ne00; + size_t src_stride; + size_t dst_stride; + uint32_t rows_per_thread; + uint32_t total_rows; + bool opt_path; +}; + +static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) { + const struct sum_rows_context * smctx = (const struct sum_rows_context *) data; + + const uint32_t rows_per_thread = smctx->rows_per_thread; + const uint32_t total_rows = smctx->total_rows; + + const uint32_t start_row = rows_per_thread * ith; + const uint32_t end_row = MIN(start_row + rows_per_thread, total_rows); + + if (start_row >= end_row) { + return; + } + + const size_t src_stride = smctx->src_stride; + const size_t dst_stride = smctx->dst_stride; + const uint32_t ne00 = smctx->ne00; + const bool opt_path = smctx->opt_path; + + const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride)); + float * restrict dst_th = (float *) (smctx->dst_data + (start_row * dst_stride)); + + // Calculate actual number of rows for this thread + const uint32_t n_rows = end_row - start_row; + + for (uint32_t ir = 0; ir < n_rows; ir++) { + const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float))); + + if (ir + 1 < n_rows) { + hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1); + } + + if (opt_path) { + dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00); + } else { + dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00); + } + } +} + +int op_sum_rows(struct htp_ops_context * octx) { + sum_rows_preamble; + + if (octx->src[0]->type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); + const uint32_t rows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + + bool opt_path = false; + if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { + opt_path = true; + } + + struct sum_rows_context smctx = { + .src_data = (const uint8_t *) src0->data, + .dst_data = (uint8_t *) dst->data, + .ne00 = ne00, + .src_stride = nb01, + .dst_stride = nb1, + .rows_per_thread = rows_per_thread, + .total_rows = src0_nrows, + .opt_path = opt_path, + }; + + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_threads); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 8ed1e5b6619..71fab2cdbcb 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -2,28 +2,82 @@ #pragma clang diagnostic ignored "-Wunused-function" #pragma clang diagnostic ignored "-Wunused-but-set-variable" -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif - #include <HAP_farf.h> -#include <HAP_mem.h> #include <HAP_perf.h> -#include <HAP_ps.h> -#include <hexagon_protos.h> -#include <hexagon_types.h> + #include <math.h> -#include <qurt_thread.h> #include <string.h> +#include "hex-dma.h" +#include "hvx-exp.h" +#include "hvx-sigmoid.h" +#include "hvx-utils.h" + #define GGML_COMMON_DECL_C #include "ggml-common.h" #include "htp-ctx.h" -#include "htp-dma.h" -#include "htp-msg.h" #include "htp-ops.h" -#include "hvx-utils.h" -#include "ops-utils.h" + +struct htp_unary_context { + struct htp_ops_context * octx; + + // Precomputed values + const uint8_t * data_src0; + const uint8_t * data_src1; // weight/scale tensor for RMS_NORM_MUL + uint8_t * data_dst; + + size_t src0_data_row_size; // actual data bytes per row + size_t src1_data_row_size; + size_t dst_data_row_size; // actual data bytes per row + + size_t src0_row_size_aligned; + size_t src1_row_size_aligned; + size_t dst_row_size_aligned; + + size_t src0_spad_half_size; + size_t src1_spad_half_size; + size_t dst_spad_half_size; + + uint32_t block; + uint32_t src0_nrows; + uint32_t src0_nrows_per_thread; + uint32_t nc; + bool broadcast_weight; +}; + +// Convert flat row index to DDR byte offset using the tensor's actual strides. +// ir = i1 + ne1*(i2 + ne2*i3) => offset = i1*nb1 + i2*nb2 + i3*nb3 +static inline size_t unary_row_offset(uint32_t ir, + uint32_t ne1, uint32_t ne2, + size_t nb1, size_t nb2, size_t nb3) { + const uint32_t i1 = ir % ne1; + const uint32_t i2 = (ir / ne1) % ne2; + const uint32_t i3 = ir / (ne1 * ne2); + return i1 * nb1 + i2 * nb2 + i3 * nb3; +} +// Safe DMA block size from row `ir`: clamp to the tighter dim-1 slice +// boundary of src and dst so the nb1 stride stays valid for all rows. +static inline uint32_t unary_block_size(uint32_t ir, + uint32_t end_row, + uint32_t block, + bool src_contig, + bool dst_contig, + uint32_t src_ne1, + uint32_t dst_ne1) { + uint32_t limit = MIN(block, end_row - ir); + + if (!src_contig) { + const uint32_t src_slice_end = (ir / src_ne1 + 1) * src_ne1; + limit = MIN(limit, src_slice_end - ir); + } + + if (!dst_contig) { + const uint32_t dst_slice_end = (ir / dst_ne1 + 1) * dst_ne1; + limit = MIN(limit, dst_slice_end - ir); + } + + return limit; +} #define htp_unary_preamble \ const uint32_t ne00 = src->ne[0]; \ @@ -51,110 +105,578 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, uint8_t * restrict pad, const int num_elems, float epsilon) { + (void)pad; + const HVX_Vector * restrict v_src = (HVX_Vector *) src; HVX_Vector * restrict v_dst = (HVX_Vector *) dst; - HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); - HVX_Vector epsilon_v = hvx_vec_splat_fp32(epsilon); + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares for full vectors + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } + + // Reduce HVX sum + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); + HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); + + // Scale full vectors + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); - int step_of_1 = num_elems >> 5; #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v2); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); + } +} + +static void hvx_fast_rms_norm_mul_f32(const uint8_t * restrict src, + const uint8_t * restrict weight, + uint8_t * restrict dst, + const int num_elems, + float epsilon) { + const HVX_Vector * restrict v_src = (const HVX_Vector *) src; + const HVX_Vector * restrict v_weight = (const HVX_Vector *) weight; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares for full vectors + HVX_Vector sum_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); - sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - HVX_Vector reduced_sum = hvx_vec_qf32_reduce_sum(sum_v); - sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum)); + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); + } - HVX_Vector t_v = hvx_vec_splat_fp32((float) num_elems); - HVX_Vector denom_v = hvx_vec_inverse_fp32(t_v); + // Reduce HVX sum + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); HVX_Vector mean_v = Q6_Vqf32_vmpy_VsfVsf(sum_v, denom_v); HVX_Vector mean_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(mean_v, epsilon_v); - HVX_Vector scale_v = hvx_vec_rsqrt_fp32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); + // Scale and multiply + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(mean_epsilon_v)); #pragma unroll(4) - for (int i = 0; i < step_of_1; i++) { + for (int i = 0; i < nvec; i++) { HVX_Vector v1 = v_src[i]; HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); - v_dst[i] = Q6_Vsf_equals_Vqf32(v2); + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2); + HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[i]); + v_dst[i] = Q6_Vsf_equals_Vqf32(result); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, scale_v); + HVX_Vector v3 = Q6_Vsf_equals_Vqf32(v2); + HVX_Vector result = Q6_Vqf32_vmpy_VsfVsf(v3, v_weight[nvec]); + HVX_Vector res_v = Q6_Vsf_equals_Vqf32(result); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, res_v); } } -static void scale_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void hvx_fast_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + (void)pad; + + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + const int nvec = num_elems / VLEN_FP32; // number of full vectors + const int nloe = num_elems % VLEN_FP32; // leftover elements + + // Compute sum of squares and sum of values for full vectors + HVX_Vector sum_sq_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector sum_x_v = Q6_V_vsplat_R(0x00000000); + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_sq_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_sq_v, v2); + sum_x_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_x_v, Q6_Vqf32_vadd_VsfVsf(v1, Q6_V_vzero())); + } + + // Reduce HVX sums + sum_sq_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_sq_v)); + sum_x_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_x_v)); + + HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); + HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); + HVX_Vector mean_sq_v = Q6_Vqf32_vmpy_VsfVsf(sum_sq_v, denom_v); + HVX_Vector mean_x_v = Q6_Vqf32_vmpy_VsfVsf(sum_x_v, denom_v); + HVX_Vector mean_x_sq_v = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(mean_x_v), Q6_Vsf_equals_Vqf32(mean_x_v)); + HVX_Vector var_v = Q6_Vqf32_vsub_Vqf32Vqf32(mean_sq_v, mean_x_sq_v); + HVX_Vector var_epsilon_v = Q6_Vqf32_vadd_Vqf32Vsf(var_v, epsilon_v); + + // scale = rsqrt(variance + epsilon), mean_x broadcast for subtraction + HVX_Vector scale_v = hvx_vec_rsqrt_f32(Q6_Vsf_equals_Vqf32(var_epsilon_v)); + HVX_Vector mean_x_b = hvx_vec_repl_f32(Q6_Vsf_equals_Vqf32(mean_x_v)); + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + v_dst[i] = Q6_Vsf_equals_Vqf32(v3); + } + + // Handle tail elements using vectorized ops with masking + if (nloe > 0) { + + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector v2 = Q6_Vqf32_vsub_VsfVsf(v1, mean_x_b); + HVX_Vector v3 = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(v2), scale_v); + HVX_Vector result = Q6_Vsf_equals_Vqf32(v3); + + // Store with masking to avoid overwriting memory beyond the tensor + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); + } +} + +static void scale_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { float scale = 0.f; float bias = 0.f; memcpy(&scale, &op_params[0], sizeof(float)); memcpy(&bias, &op_params[1], sizeof(float)); for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (ir + 1 < num_rows) { - htp_l2fetch(src_local + row_elems, 1, row_size, row_size); - } + hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); + } +} + +static void rms_norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); - hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); } } -static void rms_norm_htp_f32(const float * restrict src, +static void rms_norm_mul_f32(const float * restrict src, + const float * restrict weight, float * restrict dst, - uint8_t * restrict spad, const uint32_t num_rows, const uint32_t row_elems, const size_t row_size, + const size_t weight_row_size, int32_t * op_params, - int opt_path) { + bool broadcast_weight) { float epsilon = 0.f; memcpy(&epsilon, op_params, sizeof(float)); for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + const uint8_t * restrict w_local = (const uint8_t *)weight + (broadcast_weight ? 0 : ir * weight_row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_rms_norm_mul_f32(src_local, w_local, dst_local, row_elems, epsilon); + } +} + +static void norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_fast_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); + } +} + +static void sqr_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } +} + +static void sqrt_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } +} + +static void neg_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_scale_f32_aa(dst_local, src_local, row_elems, -1.0f); + } +} + +static void exp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_exp_f32(dst_local, src_local, row_elems, false); + } +} + +static void sigmoid_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (ir + 1 < num_rows) { - htp_l2fetch(src_local + row_elems, 1, row_size, row_size); + hvx_sigmoid_f32_aa(dst_local, src_local, row_elems); + } +} + +static void tri_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + const uint32_t ir, + const struct htp_unary_context * uctx) { + + const int32_t ttype = op_params[0]; + const HVX_Vector zero = hvx_vec_splat_f32(0.0f); + const uint32_t nvec = row_elems / VLEN_FP32; + const uint32_t nloe = row_elems % VLEN_FP32; + + const uint32_t ne01 = uctx->octx->src[0]->ne[1]; + + for (uint32_t b = 0; b < num_rows; b++) { + const uint32_t abs_row = ir + b; + const uint32_t i01 = abs_row % ne01; + + const HVX_Vector * restrict v_src = (const HVX_Vector *) ((const uint8_t *) src + b * row_size); + HVX_Vector * restrict v_dst = (HVX_Vector *) ((uint8_t *) dst + b * row_size); + + uint32_t boundary; + int keep_left; + switch (ttype) { + case 0: boundary = i01; keep_left = 0; break; // keep col >= row + case 1: boundary = i01 + 1; keep_left = 0; break; // keep col > row + case 2: boundary = i01 + 1; keep_left = 1; break; // keep col <= row + case 3: boundary = i01; keep_left = 1; break; // keep col < row + default: boundary = 0; keep_left = 0; break; } + if (boundary > row_elems) boundary = row_elems; - if (1 == opt_path) { - hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); - } else { - float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems); + // Full HVX vectors — each starts at a 128-byte aligned offset + for (uint32_t i = 0; i < nvec; i++) { + const uint32_t vec_start = i * VLEN_FP32; + const uint32_t vec_end = vec_start + VLEN_FP32; + if (keep_left) { + if (vec_end <= boundary) { + v_dst[i] = v_src[i]; + } else if (vec_start >= boundary) { + v_dst[i] = zero; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + v_dst[i] = Q6_V_vmux_QVV(mask, v_src[i], zero); + } + } else { + if (vec_end <= boundary) { + v_dst[i] = zero; + } else if (vec_start >= boundary) { + v_dst[i] = v_src[i]; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + v_dst[i] = Q6_V_vmux_QVV(mask, zero, v_src[i]); + } + } + } + + // Tail elements (row_elems not a multiple of VLEN_FP32) + if (nloe > 0) { + const uint32_t vec_start = nvec * VLEN_FP32; + const uint32_t vec_end = vec_start + nloe; + HVX_Vector tail_val; + if (keep_left) { + if (vec_end <= boundary) { + tail_val = v_src[nvec]; + } else if (vec_start >= boundary) { + tail_val = zero; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + tail_val = Q6_V_vmux_QVV(mask, v_src[nvec], zero); + } + } else { + if (vec_end <= boundary) { + tail_val = zero; + } else if (vec_start >= boundary) { + tail_val = v_src[nvec]; + } else { + HVX_VectorPred mask = Q6_Q_vsetq_R((boundary - vec_start) * sizeof(float)); + tail_val = Q6_V_vmux_QVV(mask, zero, v_src[nvec]); + } + } + hvx_vec_store_a(&v_dst[nvec], nloe * sizeof(float), tail_val); + } + } +} - const float mean = sum / row_elems; - const float scale = 1.0f / sqrtf(mean + epsilon); +static void softplus_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + // softplus(x) = log(1 + exp(x)) + // Match CPU reference: ggml_compute_softplus_f32() in ggml-impl.h + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size)); + float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size)); - hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); + for (uint32_t i = 0; i < row_elems; i++) { + float x = src_f[i]; + // For x > 20: softplus(x) ≈ x (avoids exp overflow) + dst_f[i] = (x > 20.0f) ? x : logf(1.0f + expf(x)); } } } -static void unary_job_f32_per_thread(const struct htp_tensor * src, - struct htp_tensor * dst, - uint8_t * spad, - int htp_op, - int32_t * op_params, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread) { +// --- L2_NORM HVX kernel --- +// Computes y[i] = x[i] / fmax(sqrt(sum(x[j]^2)), epsilon) for each row. +// scale = 1/fmax(sqrt(sum), epsilon) is computed entirely in HVX registers +// using rsqrt + inverse to avoid scalar extraction. +static void hvx_fast_l2_norm_f32(const uint8_t * restrict src, + uint8_t * restrict dst, + uint8_t * restrict pad, + const int num_elems, + float epsilon) { + (void)pad; + + const HVX_Vector * restrict v_src = (HVX_Vector *) src; + HVX_Vector * restrict v_dst = (HVX_Vector *) dst; + + HVX_Vector sum_v = hvx_vec_splat_f32(0.0f); + + const int nvec = num_elems / VLEN_FP32; + const int nloe = num_elems % VLEN_FP32; + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq); + } + + // Include tail elements in the sum-of-squares using a predicate mask + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector sq = Q6_Vqf32_vmpy_VsfVsf(v1, v1); + sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, sq); + } + + // Compute scale = 1/fmax(sqrt(sum), epsilon) entirely in HVX registers. + // hvx_vec_rsqrt_f32 + hvx_vec_inverse_f32 avoids scalar extraction. + HVX_Vector sum_sf = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); + HVX_Vector rsqrt_v = hvx_vec_rsqrt_f32(sum_sf); // 1/sqrt(sum) + HVX_Vector sqrt_v = hvx_vec_inverse_f32(rsqrt_v); // sqrt(sum) + HVX_Vector epsilon_v = hvx_vec_splat_f32(epsilon); + HVX_Vector denom_v = Q6_Vsf_vmax_VsfVsf(sqrt_v, epsilon_v); // fmax(sqrt(sum), epsilon) + HVX_Vector scale_v = hvx_vec_inverse_f32(denom_v); // 1/fmax(sqrt(sum), epsilon) + + #pragma unroll(4) + for (int i = 0; i < nvec; i++) { + HVX_Vector v1 = v_src[i]; + v_dst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v)); + } + + if (nloe > 0) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 4); + HVX_Vector v1 = Q6_V_vand_QV(bmask, v_src[nvec]); + HVX_Vector result = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(v1, scale_v)); + hvx_vec_store_a(&v_dst[nvec], nloe * 4, result); + } +} + +static void l2_norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + float epsilon = 0.f; + memcpy(&epsilon, op_params, sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_f = (const float *)((const uint8_t *)src + (ir * row_size)); + float * restrict dst_f = (float *)((uint8_t *)dst + (ir * row_size)); + + hvx_fast_l2_norm_f32((const uint8_t *)src_f, (uint8_t *)dst_f, spad, row_elems, epsilon); + } +} + +static void tanh_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { + for (uint32_t ir = 0; ir < num_rows; ir++) { + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); + + hvx_tanh_f32_aa(dst_local, src_local, row_elems); + } +} + +static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; + struct htp_ops_context * octx = uctx->octx; + const struct htp_tensor * src = octx->src[0]; + const struct htp_tensor * dst = octx->dst; + htp_unary_preamble; - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; + int htp_op = octx->op; + int32_t * op_params = octx->op_params; + uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const size_t src0_data_row_size = uctx->src0_data_row_size; + const size_t dst_data_row_size = uctx->dst_data_row_size; + const size_t src0_row_size_aligned = uctx->src0_row_size_aligned; + const size_t dst_row_size_aligned = uctx->dst_row_size_aligned; + + const uint32_t src0_nrows = uctx->src0_nrows; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -166,66 +688,212 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - int is_aligned = 1; - int opt_path = 0; - if ((0 == htp_is_aligned((void *) src->data, VLEN)) || (0 == htp_is_aligned((void *) dst->data, VLEN))) { - is_aligned = 0; - FARF(HIGH, "unary-f32: unaligned addresses in unary op, possibly slower execution\n"); + const uint8_t * restrict data_src = uctx->data_src0; + const uint8_t * restrict data_src1 = uctx->data_src1; + uint8_t * restrict data_dst = uctx->data_dst; + + const struct htp_tensor * src1 = (htp_op == HTP_OP_RMS_NORM_MUL) ? octx->src[1] : NULL; + const uint32_t nb11 = src1 ? src1->nb[1] : 0; + const uint32_t nb12 = src1 ? src1->nb[2] : 0; + const uint32_t nb13 = src1 ? src1->nb[3] : 0; + + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + size_t src0_spad_half_size = uctx->src0_spad_half_size; + size_t src1_spad_half_size = uctx->src1_spad_half_size; + size_t dst_spad_half_size = uctx->dst_spad_half_size; + + // Non-contiguous tensors have gaps at dim-2/3 boundaries that a single-stride + // 2D DMA descriptor cannot span. Clamp BLOCK to ne1 (one dim-1 slice) so every + // transfer stays within a nb1-uniform region. Skipped for contiguous tensors. + const bool src0_contig = (nb02 == (size_t)ne01 * nb01) && + (nb03 == (size_t)ne02 * nb02); + const bool dst_contig = (nb2 == (size_t)ne1 * nb1) && + (nb3 == (size_t)ne2 * nb2); + const uint32_t src0_max_block = src0_contig ? uctx->block : MIN((uint32_t)uctx->block, ne01); + const uint32_t dst_max_block = dst_contig ? uctx->block : MIN((uint32_t)uctx->block, ne1); + const uint32_t BLOCK = MIN(src0_max_block, dst_max_block); + if (BLOCK == 0) { + FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + octx->src0_spad.size_per_thread, src0_row_size_aligned); + return; } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; + + dma_queue * dma_queue = octx->ctx->dma[ith]; + + // If weight is broadcasted, load it once per thread at the beginning of execution + if (htp_op == HTP_OP_RMS_NORM_MUL && uctx->broadcast_weight) { + dma_queue_push(dma_queue, dma_make_ptr(src1_spad_data, data_src1), uctx->src1_row_size_aligned, 0, uctx->src1_data_row_size, 1); + dma_queue_flush(dma_queue); } - const uint8_t * restrict data_src = (const uint8_t *) src->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; spad_idx++) { + const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); - const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); - float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); - uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01); + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + nb1, dst_row_size_aligned, dst_data_row_size, 0); - switch (htp_op) { - case HTP_OP_RMS_NORM: - rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; - case HTP_OP_SCALE: - scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; + const size_t src0_off = unary_row_offset(ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + src0_off), + src0_row_size_aligned, nb01, src0_data_row_size, block_size); - default: - break; + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + const size_t src1_off = unary_row_offset(ir, ne01, ne02, nb11, nb12, nb13); + dma_queue_push(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + src1_off), + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, block_size); + } + + ir += block_size; } + for (uint32_t ir = src0_start_row; ir < src0_end_row; ) { + const uint32_t block_size = unary_block_size(ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = NULL; + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + src1_spad = (float *) dma_queue_pop(dma_queue).dst; + } + + // Process block in VTCM + switch (htp_op) { + case HTP_OP_NORM: + norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_RMS_NORM: + rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_RMS_NORM_MUL: + { + const float * w_ptr = uctx->broadcast_weight ? (const float *) src1_spad_data : src1_spad; + rms_norm_mul_f32(src0_spad, w_ptr, dst_spad, block_size, ne0, src0_row_size_aligned, uctx->src1_row_size_aligned, op_params, uctx->broadcast_weight); + } + break; + case HTP_OP_SCALE: + scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SQR: + sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SQRT: + sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_NEG: + neg_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_EXP: + exp_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_SIGMOID: + sigmoid_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_SOFTPLUS: + softplus_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_UNARY_TANH: + tanh_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_L2_NORM: + l2_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_TRI: + tri_f32(src0_spad, dst_spad, NULL, block_size, ne00, src0_row_size_aligned, op_params, ir, uctx); + break; + default: + break; + } + + const size_t dst_off = unary_row_offset(ir, ne1, ne2, nb1, nb2, nb3); + dma_queue_push(dma_queue, + dma_make_ptr(data_dst + dst_off, dst_spad), + nb1, dst_row_size_aligned, dst_data_row_size, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t next_ir = ir + block_size; + if (next_ir < src0_end_row) { + const uint32_t next_block_size = unary_block_size(next_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); + const uint32_t pref_ir = next_ir + next_block_size; + if (pref_ir < src0_end_row) { + const uint32_t pref_block_size = unary_block_size(pref_ir, src0_end_row, BLOCK, src0_contig, dst_contig, ne01, ne1); + const size_t src0_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb01, nb02, nb03); + dma_queue_push(dma_queue, + dma_make_ptr(src0_spad, data_src + src0_pref_off), + src0_row_size_aligned, nb01, src0_data_row_size, pref_block_size); + + if (htp_op == HTP_OP_RMS_NORM_MUL && !uctx->broadcast_weight) { + const size_t src1_pref_off = unary_row_offset(pref_ir, ne01, ne02, nb11, nb12, nb13); + dma_queue_push(dma_queue, + dma_make_ptr(src1_spad, data_src1 + src1_pref_off), + uctx->src1_row_size_aligned, nb11, uctx->src1_data_row_size, pref_block_size); + } + } + } + ir += block_size; + } + + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0], + FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - - unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i, - octx->src0_nrows_per_thread); -} - static int execute_op_unary_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * src0 = octx->src[0]; + const struct htp_tensor * dst = octx->dst; - worker_callback_t unary_op_func; - const char * op_type = NULL; + const char * op_type = NULL; switch (octx->op) { + case HTP_OP_NORM: + op_type = "norm-f32"; + break; case HTP_OP_RMS_NORM: - unary_op_func = unary_job_dispatcher_f32; - op_type = "rmsnorm-f32"; + op_type = "rmsnorm-f32"; + break; + case HTP_OP_RMS_NORM_MUL: + op_type = "rmsnorm-mul-f32"; break; case HTP_OP_SCALE: - unary_op_func = unary_job_dispatcher_f32; - op_type = "scale-f32"; + op_type = "scale-f32"; + break; + case HTP_OP_SQR: + op_type = "sqr-f32"; + break; + case HTP_OP_SQRT: + op_type = "sqrt-f32"; + break; + case HTP_OP_UNARY_NEG: + op_type = "neg-f32"; + break; + case HTP_OP_UNARY_EXP: + op_type = "exp-f32"; + break; + case HTP_OP_UNARY_SIGMOID: + op_type = "sigmoid-f32"; + break; + case HTP_OP_UNARY_SOFTPLUS: + op_type = "softplus-f32"; + break; + case HTP_OP_UNARY_TANH: + op_type = "tanh-f32"; + break; + case HTP_OP_L2_NORM: + op_type = "l2norm-f32"; + break; + case HTP_OP_TRI: + op_type = "tri-f32"; break; default: @@ -233,38 +901,139 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const int n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); - const size_t src0_row_size = src0->nb[1]; - const size_t dst_row_size = dst->nb[1]; + const size_t src0_data_row_size = src0->ne[0] * sizeof(float); + const size_t dst_data_row_size = dst->ne[0] * sizeof(float); + + const size_t src0_row_size_aligned = hex_round_up(src0_data_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_data_row_size, VLEN); + + size_t src1_data_row_size = 0; + size_t src1_row_size_aligned = 0; + bool broadcast_weight = false; + const struct htp_tensor * src1 = NULL; + + if (octx->op == HTP_OP_RMS_NORM_MUL) { + src1 = octx->src[1]; + src1_data_row_size = src1->ne[0] * sizeof(float); + src1_row_size_aligned = hex_round_up(src1_data_row_size, VLEN); + broadcast_weight = (src1->ne[1] * src1->ne[2] * src1->ne[3] == 1); + } // VTCM scratchpads for all tensors - octx->dst_spad.size = htp_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = htp_round_up(src0_row_size, 128) * n_threads; + // N rows per thread, padded to HVX vector size + // Double buffering requires 2x size per buffer - size_t spad_size = octx->src0_spad.size + octx->dst_spad.size; + size_t spad_size_per_row = 0; + size_t vtcm_row_per_thread = 0; - FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); + if (octx->op == HTP_OP_RMS_NORM_MUL) { + if (broadcast_weight) { + size_t available_vtcm = octx->ctx->vtcm_size; + size_t src1_spad_total = n_threads * src1_row_size_aligned; + if (available_vtcm > src1_spad_total) { + available_vtcm -= src1_spad_total; + } else { + available_vtcm = 0; + } + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + vtcm_row_per_thread = available_vtcm / (n_threads * spad_size_per_row); + } else { + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned + src1_row_size_aligned); + vtcm_row_per_thread = (octx->ctx->vtcm_size) / (n_threads * spad_size_per_row); + } + } else { + spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row); + } // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { + if (vtcm_row_per_thread == 0) { FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + spad_size_per_row * n_threads); return HTP_STATUS_VTCM_TOO_SMALL; } + octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2; + octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + if (octx->op == HTP_OP_RMS_NORM_MUL) { + if (broadcast_weight) { + octx->src1_spad.size_per_thread = src1_row_size_aligned; + } else { + octx->src1_spad.size_per_thread = src1_row_size_aligned * vtcm_row_per_thread * 2; + } + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + } else { + octx->src1_spad.size = 0; + octx->src1_spad.size_per_thread = 0; + } + octx->src0_spad.data = octx->ctx->vtcm_base; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + if (octx->op == HTP_OP_RMS_NORM_MUL) { + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + } else { + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + } + + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->dst_spad.src = NULL; + + FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); + struct htp_unary_context uctx = { + .octx = octx, + .src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads, + .src0_nrows = src0_nrows, + + .data_src0 = (const uint8_t *)src0->data, + .data_src1 = (octx->op == HTP_OP_RMS_NORM_MUL) ? (const uint8_t *)src1->data : NULL, + .data_dst = (uint8_t *)dst->data, + + .src0_data_row_size = src0_data_row_size, + .src1_data_row_size = src1_data_row_size, + .dst_data_row_size = dst_data_row_size, - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + .src0_row_size_aligned = src0_row_size_aligned, + .src1_row_size_aligned = src1_row_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, - worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs); + .src0_spad_half_size = octx->src0_spad.size_per_thread / 2, + .src1_spad_half_size = (octx->op == HTP_OP_RMS_NORM_MUL) ? (octx->src1_spad.size_per_thread / (broadcast_weight ? 1 : 2)) : 0, + .dst_spad_half_size = octx->dst_spad.size_per_thread / 2, + + .block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned, + .nc = src0->ne[0], + .broadcast_weight = broadcast_weight, + }; + + worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads); + } + + return err; +} + +int op_tri(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + + switch (octx->src[0]->type) { + case HTP_TYPE_F32: + err = execute_op_unary_f32(octx); + break; + + default: + err = HTP_STATUS_NO_SUPPORT; + break; } return err; @@ -273,7 +1042,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { int op_unary(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; - switch (octx->src0.type) { + switch (octx->src[0]->type) { case HTP_TYPE_F32: err = execute_op_unary_f32(octx); break; diff --git a/ggml/src/ggml-hexagon/htp/vtcm-utils.h b/ggml/src/ggml-hexagon/htp/vtcm-utils.h new file mode 100644 index 00000000000..b129fb74e31 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/vtcm-utils.h @@ -0,0 +1,16 @@ +#ifndef VTCM_UTILS_H +#define VTCM_UTILS_H + +#include "hex-utils.h" + +#include <assert.h> +#include <stdint.h> +#include <hexagon_types.h> + +static inline uint8_t *vtcm_seq_alloc(uint8_t **vtcm_ptr, size_t size) { + uint8_t *p = *vtcm_ptr; + *vtcm_ptr += size; + return p; +} + +#endif // VTCM_UTILS_H diff --git a/ggml/src/ggml-hexagon/htp/worker-pool.c b/ggml/src/ggml-hexagon/htp/worker-pool.c index cd38c2126c7..172e28908eb 100644 --- a/ggml/src/ggml-hexagon/htp/worker-pool.c +++ b/ggml/src/ggml-hexagon/htp/worker-pool.c @@ -7,10 +7,6 @@ #include <stdlib.h> #include <string.h> -#ifdef HTP_DEBUG -# define FARF_HIGH 1 -#endif - #include "HAP_farf.h" #define WORKER_THREAD_STACK_SZ (2 * 16384) @@ -60,7 +56,7 @@ static void worker_pool_main(void * context) { unsigned int n = atomic_load(&pool->n_jobs); unsigned int i = atomic_fetch_add(&pool->next_job, 1); if (i >= n) { - // Spurios wakeup + // Spurious wakeup continue; } diff --git a/ggml/src/ggml-hexagon/libdl.h b/ggml/src/ggml-hexagon/libdl.h new file mode 100644 index 00000000000..8ca5016f039 --- /dev/null +++ b/ggml/src/ggml-hexagon/libdl.h @@ -0,0 +1,79 @@ +#pragma once + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include <windows.h> +# include <winevt.h> +#else +# include <dlfcn.h> +# include <unistd.h> +#endif +#include <filesystem> + +namespace fs = std::filesystem; + +#ifdef _WIN32 + +using dl_handle = std::remove_pointer_t<HMODULE>; + +struct dl_handle_deleter { + void operator()(HMODULE handle) { + FreeLibrary(handle); + } +}; + +static inline dl_handle * dl_load_library(const fs::path & path) { + // suppress error dialogs for missing DLLs + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + HMODULE handle = LoadLibraryW(path.wstring().c_str()); + + SetErrorMode(old_mode); + + return handle; +} + +static inline void * dl_get_sym(dl_handle * handle, const char * name) { + DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); + SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); + + void * p = (void *) GetProcAddress(handle, name); + + SetErrorMode(old_mode); + + return p; +} + +static inline const char * dl_error() { + return ""; +} + +#else + +using dl_handle = void; + +struct dl_handle_deleter { + void operator()(void * handle) { + dlclose(handle); + } +}; + +static inline dl_handle * dl_load_library(const fs::path & path) { + dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL); + return handle; +} + +static inline void * dl_get_sym(dl_handle * handle, const char * name) { + return dlsym(handle, name); +} + +static inline const char * dl_error() { + const char *rslt = dlerror(); + return rslt != nullptr ? rslt : ""; +} + +#endif diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf new file mode 100644 index 00000000000..39cefcdda38 --- /dev/null +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -0,0 +1,40 @@ +[Version] +Signature = "$WINDOWS NT$" +Class = ComputeAccelerator +ClassGuid = {F01A9D53-3FF6-48D2-9F97-C8A7004BE10C} +Provider = %GGML% +DriverVer = 01/01/2026,1.0.0.0 +CatalogFile = libggml-htp.cat +PnpLockDown = 1 + +[DestinationDirs] +Drivers_Dir = 13 + +[SourceDisksNames] +1 = %DiskId% + +[SourceDisksFiles] +libggml-htp-v68.so = 1 +libggml-htp-v69.so = 1 +libggml-htp-v73.so = 1 +libggml-htp-v75.so = 1 +libggml-htp-v79.so = 1 +libggml-htp-v81.so = 1 + +[ControlFlags] +ExcludeFromSelect = * + +[DefaultInstall.NTarm64] +CopyFiles=Drivers_Dir + +[Drivers_Dir] +libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE +libggml-htp-v81.so,,,0x10 ;COPYFLG_NO_OVERWRITE + +[Strings] +GGML = 'GGML' +DiskId = 'GGML HTP library' diff --git a/ggml/src/ggml-hexagon/op-desc.h b/ggml/src/ggml-hexagon/op-desc.h deleted file mode 100644 index a1e8ddd8b97..00000000000 --- a/ggml/src/ggml-hexagon/op-desc.h +++ /dev/null @@ -1,153 +0,0 @@ -#ifndef OP_DESC_H -#define OP_DESC_H - -#define GGML_COMMON_IMPL_CPP -#include "ggml-backend-impl.h" -#include "ggml-common.h" - -#include <string> -#include <stdio.h> - -struct op_desc { - char strides[64 * GGML_MAX_SRC]; - char dims[64 * GGML_MAX_SRC]; - char types[16 * GGML_MAX_SRC]; - char buffs[64 * GGML_MAX_SRC]; - char names[64 * GGML_MAX_SRC]; - - int format_tensor_dims(char * str, const struct ggml_tensor * t) { - if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); - } else { - return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); - } - } - - void format_op_dims(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += format_tensor_dims(p, t->src[0]); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += format_tensor_dims(p, t->src[i]); - } - - p += sprintf(p, " -> "); - } - - // format self dims separately for better visual alignment - char self[64]; - format_tensor_dims(self, t); - - p += sprintf(p, "%s", self); - } - - int format_tensor_strides(char * str, const struct ggml_tensor * t) { - const char * c = ggml_is_contiguous(t) ? "" : "!"; - - if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); - } else { - return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); - } - } - - void format_op_strides(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += format_tensor_strides(p, t->src[0]); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += format_tensor_strides(p, t->src[i]); - } - - p += sprintf(p, " -> "); - } - - // format self dims separately for better visual alignment - char self[64]; - format_tensor_strides(self, t); - - p += sprintf(p, "%s", self); - } - - void format_op_types(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", ggml_type_name(t->src[0]->type)); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", ggml_type_name(t->src[i]->type)); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", ggml_type_name(t->type)); - } - - const char * tensor_buff_name(const struct ggml_tensor * t) { - if (t->buffer) { - return ggml_backend_buffer_name(t->buffer); - } - return "NONE"; - } - - void format_op_buffs(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", tensor_buff_name(t->src[0])); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", tensor_buff_name(t->src[i])); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", tensor_buff_name(t)); - } - - void format_op_names(char * str, const struct ggml_tensor * t) { - char * p = str; - - // append src0 and src1 (if any) - if (t->src[0]) { - p += sprintf(p, "%s", t->src[0]->name); - - for (int i = 1; i < GGML_MAX_SRC && t->src[i]; i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", t->src[i]->name); - } - - p += sprintf(p, " -> "); - } - - p += sprintf(p, "%s", t->name); - } - - void format(const ggml_tensor * op) { - format_op_dims(dims, op); - format_op_strides(strides, op); - format_op_types(types, op); - format_op_buffs(buffs, op); - format_op_names(names, op); - } - - op_desc() {} - op_desc(const ggml_tensor * op) { format(op); } -}; - -#endif // OP_DESC_H diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 23b6889919f..a7d4e0ea2b5 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -11,6 +11,10 @@ endif() list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake") +if (NOT DEFINED CMAKE_HIP_FLAGS_DEBUG) + set(CMAKE_HIP_FLAGS_DEBUG "-g -O2") +endif() + # CMake on Windows doesn't support the HIP language yet if (WIN32) set(CXX_IS_HIPCC TRUE) @@ -43,15 +47,16 @@ find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) +if (GGML_HIP_RCCL) + find_package(rccl REQUIRED) +endif() + if (${hip_VERSION} VERSION_LESS 6.1) message(FATAL_ERROR "At least ROCM/HIP V6.1 is required") endif() message(STATUS "HIP and hipBLAS found") -# Workaround old compilers -set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024") - file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") @@ -62,18 +67,19 @@ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) +file(GLOB SRCS "../ggml-cuda/template-instances/mmf*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_ROCM ${SRCS}) + list(APPEND GGML_SOURCES_ROCM + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() ggml_add_backend_library(ggml-hip @@ -116,6 +122,10 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() +if (GGML_HIP_RCCL) + add_compile_definitions(GGML_USE_NCCL) # RCCL has the same interface as NCCL. +endif() + if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif() @@ -126,6 +136,11 @@ endif() if (CXX_IS_HIPCC) set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) + if (WIN32 AND CMAKE_BUILD_TYPE STREQUAL "Debug") + # CMake on Windows doesn't support the HIP language yet. + # Therefore we workaround debug build's failure on HIP backend this way. + set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES COMPILE_FLAGS "-O2 -g") + endif() target_link_libraries(ggml-hip PRIVATE hip::device) else() set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP) @@ -135,4 +150,8 @@ if (GGML_STATIC) message(FATAL_ERROR "Static linking not supported for HIP/ROCm") endif() +if (GGML_HIP_RCCL) + target_link_libraries(ggml-hip PRIVATE ggml-base roc::rccl) +endif() + target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas) diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 80e0fd2ff8b..62b76abbcec 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -30,6 +30,8 @@ extern "C" { void ggml_print_backtrace(void); +uint64_t ggml_graph_next_uid(void); + #ifndef MIN # define MIN(a, b) ((a) < (b) ? (a) : (b)) #endif @@ -98,6 +100,10 @@ static bool ggml_op_is_empty(enum ggml_op op) { } } +static inline bool ggml_impl_is_view(const struct ggml_tensor * t) { + return t->view_src != NULL; +} + static inline float ggml_compute_softplus_f32(float input) { return (input > 20.0f) ? input : logf(1 + expf(input)); } @@ -334,6 +340,10 @@ struct ggml_cgraph { struct ggml_hash_set visited_hash_set; enum ggml_cgraph_eval_order order; + + // an optional identifier that can be utilized to recognize same graphs if two non-zero values match + // a value of 0 means it is not set and should be ignored + uint64_t uid; }; // returns a slice of cgraph with nodes [i0, i1) @@ -487,6 +497,61 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) { #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x) #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x) +// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits +// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float) +static inline float ggml_ue4m3_to_fp32(uint8_t x) { + if (x == 0 || x == 0x7F) { + return 0.0f; + } + int exp = (x >> 3) & 0xF; + int man = x & 0x7; + float raw; + if (exp == 0) { + raw = ldexpf((float) man, -9); + } else { + raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7); + } + return raw * 0.5f; +} + +static inline uint8_t ggml_fp32_to_ue4m3(float x) { + if (!(x > 0.0f)) { + return 0; + } + if (x > 448.0f) { + x = 448.0f; + } + uint32_t bits; + memcpy(&bits, &x, 4); + int fp32_exp = ((bits >> 23) & 0xFF) - 127; + int fp32_man = (bits >> 20) & 0x7; + int ue4m3_exp = fp32_exp + 7; + if (ue4m3_exp <= 0) { + // subnormal: value = man * 2^-9, man = round(x * 2^9) + int man = (int) (x * 512.0f + 0.5f); + if (man > 7) { + man = 7; + } + if (man < 1) { + return 0; + } + return (uint8_t) man; + } + if (ue4m3_exp >= 15) { + return 0x7E; + } + int round_bit = (bits >> 19) & 1; + int ue4m3_man = fp32_man + round_bit; + if (ue4m3_man > 7) { + ue4m3_man = 0; + ue4m3_exp++; + if (ue4m3_exp >= 15) { + return 0x7E; + } + } + return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man); +} + /** * Converts brain16 to float32. * @@ -611,6 +676,9 @@ static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const in if (node->op != ops[i]) { return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) { return false; } @@ -711,6 +779,5 @@ inline bool ggml_check_edges(const struct ggml_cgraph * cgraph, // expose GGUF internals for test code GGML_API size_t gguf_type_size(enum gguf_type type); -GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params); GGML_API void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta); #endif // __cplusplus diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index 63418fe1430..42054d841aa 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -23,11 +23,6 @@ if (GGML_METAL_NDEBUG) add_compile_definitions(GGML_METAL_NDEBUG) endif() -# copy metal files to bin directory -configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) -configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) -configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) - set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") if (GGML_METAL_EMBED_LIBRARY) enable_language(ASM) @@ -37,12 +32,12 @@ if (GGML_METAL_EMBED_LIBRARY) set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") - file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") + file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/autogenerated") # merge ggml-common.h and ggml-metal.metal into a single file - set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") - set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") - set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") + set(METALLIB_EMBED_ASM "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_CURRENT_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") add_custom_command( OUTPUT "${METALLIB_EMBED_ASM}" @@ -62,6 +57,11 @@ if (GGML_METAL_EMBED_LIBRARY) target_sources(ggml-metal PRIVATE "${METALLIB_EMBED_ASM}") else() + # copy metal files to bin directory + configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) + configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) + configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) + if (GGML_METAL_SHADER_DEBUG) # custom command to do the following: # xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air @@ -71,7 +71,7 @@ else() # disabling fast math is needed in order to pass tests/test-backend-ops # note: adding -fno-inline fixes the tests when using MTL_SHADER_VALIDATION=1 # note: unfortunately, we have to call it default.metallib instead of ggml.metallib - # ref: https://github.com/ggerganov/whisper.cpp/issues/1720 + # ref: https://github.com/ggml-org/whisper.cpp/issues/1720 # note: adding -g causes segmentation fault during compile #set(XC_FLAGS -fno-fast-math -fno-inline -g) set(XC_FLAGS -fno-fast-math -fno-inline) diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index 95627d38665..2eb9820bff9 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -264,15 +264,26 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_GROUP_NORM: + case GGML_OP_L2_NORM: case GGML_OP_SUM_ROWS: + case GGML_OP_SSM_CONV: + case GGML_OP_SSM_SCAN: + case GGML_OP_CLAMP: + case GGML_OP_TRI: + case GGML_OP_DIAG: case GGML_OP_MUL: case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_DIV: case GGML_OP_GLU: case GGML_OP_SCALE: + case GGML_OP_UNARY: case GGML_OP_GET_ROWS: - case GGML_OP_CPY: case GGML_OP_SET_ROWS: + case GGML_OP_SET: + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_REPEAT: return true; default: return ggml_op_is_empty(op); @@ -312,7 +323,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node h_add(mrs1, node0); // that many nodes forward to search for a concurrent node - constexpr int N_FORWARD = 8; + constexpr int N_FORWARD = 64; for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) { if (used[i1]) { diff --git a/ggml/src/ggml-metal/ggml-metal-context.h b/ggml/src/ggml-metal/ggml-metal-context.h index ec2b686b733..abf4b06ed2a 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.h +++ b/ggml/src/ggml-metal/ggml-metal-context.h @@ -15,14 +15,22 @@ typedef struct ggml_metal * ggml_metal_t; ggml_metal_t ggml_metal_init(ggml_metal_device_t dev); void ggml_metal_free(ggml_metal_t ctx); +const char * ggml_metal_get_name(ggml_metal_t ctx); + void ggml_metal_synchronize(ggml_metal_t ctx); void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); +bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst); enum ggml_status ggml_metal_graph_compute (ggml_metal_t ctx, struct ggml_cgraph * gf); void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf); +void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev); +void ggml_metal_event_wait (ggml_metal_t ctx, ggml_metal_event_t ev); + +ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx); + void ggml_metal_set_n_cb (ggml_metal_t ctx, int n_cb); void ggml_metal_set_abort_callback (ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data); bool ggml_metal_supports_family (ggml_metal_t ctx, int family); diff --git a/ggml/src/ggml-metal/ggml-metal-context.m b/ggml/src/ggml-metal/ggml-metal-context.m index 42a35736eea..32d97cd5d0a 100644 --- a/ggml/src/ggml-metal/ggml-metal-context.m +++ b/ggml/src/ggml-metal/ggml-metal-context.m @@ -24,9 +24,13 @@ }; struct ggml_metal { + char name[128]; + ggml_metal_device_t dev; ggml_metal_library_t lib; + ggml_metal_event_t ev_cpy; // for async copies + dispatch_queue_t d_queue; // additional, inference-time compiled pipelines @@ -43,7 +47,7 @@ uint64_t fuse_cnt[GGML_OP_COUNT]; // capture state - bool capture_next_compute; + int capture_compute; bool capture_started; id<MTLCaptureScope> capture_scope; @@ -71,6 +75,10 @@ // abort ggml_metal_graph_compute if callback returns true ggml_abort_callback abort_callback; void * abort_callback_data; + + // error state - set when a command buffer fails during synchronize + // once set, graph_compute will return GGML_STATUS_FAILED until the backend is recreated + bool has_error; }; ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { @@ -117,7 +125,11 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { } } - //const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + res->ev_cpy = ggml_metal_device_event_init(dev); + + const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev); + + snprintf(res->name, sizeof(res->name), "%s", props_dev->name); res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT); @@ -146,10 +158,19 @@ ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) { GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, res->use_concurrency ? "true" : "false"); GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false"); - res->capture_next_compute = false; + res->capture_compute = 0; res->capture_started = false; res->capture_scope = nil; + { + const char * val = getenv("GGML_METAL_CAPTURE_COMPUTE"); + if (val) { + res->capture_compute = atoi(val); + } + } + + res->has_error = false; + res->gf = nil; res->encode_async = nil; for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { @@ -206,9 +227,15 @@ void ggml_metal_free(ggml_metal_t ctx) { dispatch_release(ctx->d_queue); + ggml_metal_device_event_free(ctx->dev, ctx->ev_cpy); + free(ctx); } +const char * ggml_metal_get_name(ggml_metal_t ctx) { + return ctx->name; +} + void ggml_metal_synchronize(ggml_metal_t ctx) { // wait for any backend operations to finish if (ctx->cmd_buf_last) { @@ -232,7 +259,8 @@ void ggml_metal_synchronize(ggml_metal_t ctx) { if (status == MTLCommandBufferStatusError) { GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); } - GGML_ABORT("fatal error"); + ctx->has_error = true; + return; } } } @@ -248,7 +276,15 @@ void ggml_metal_synchronize(ggml_metal_t ctx) { if (status == MTLCommandBufferStatusError) { GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]); } - GGML_ABORT("fatal error"); + + // release this and all remaining command buffers before returning + for (size_t j = i; j < ctx->cmd_bufs_ext.count; ++j) { + [ctx->cmd_bufs_ext[j] release]; + } + [ctx->cmd_bufs_ext removeAllObjects]; + + ctx->has_error = true; + return; } [cmd_buf release]; @@ -273,8 +309,8 @@ void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, // wrap the source data into a Metal buffer id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev); id<MTLBuffer> buf_src = [device newBufferWithBytes:data - length:size - options:MTLResourceStorageModeShared]; + length:size + options:MTLResourceStorageModeShared]; GGML_ASSERT(buf_src); @@ -316,9 +352,9 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te @autoreleasepool { id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev); id<MTLBuffer> buf_dst = [device newBufferWithBytesNoCopy:data - length:size - options:MTLResourceStorageModeShared - deallocator:nil]; + length:size + options:MTLResourceStorageModeShared + deallocator:nil]; GGML_ASSERT(buf_dst); @@ -356,9 +392,57 @@ void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * te } } +bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { + @autoreleasepool { + struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(src); + struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(dst); + + if (bid_src.metal == nil || bid_dst.metal == nil) { + return false; + } + + // queue the copy operation into the Metal context + // this will be queued at the end, after any currently ongoing GPU operations + id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx_src->dev); + id<MTLCommandBuffer> cmd_buf = [queue commandBuffer]; + id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:bid_src.metal + sourceOffset:bid_src.offs + toBuffer:bid_dst.metal + destinationOffset:bid_dst.offs + size:ggml_nbytes(src)]; + + [encoder endEncoding]; + + ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src); + ggml_metal_event_encode_signal(ev_cpy, cmd_buf); + + [cmd_buf commit]; + + // do not wait here for completion + //[cmd_buf waitUntilCompleted]; + + // instead, remember a reference to the command buffer and wait for it later if needed + [ctx_src->cmd_bufs_ext addObject:cmd_buf]; + ctx_src->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + + ggml_metal_event_wait(ctx_dst, ev_cpy); + + return true; + } +} + enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) { + if (ctx->has_error) { + GGML_LOG_ERROR("%s: backend is in error state from a previous command buffer failure - recreate the backend to recover\n", __func__); + return GGML_STATUS_FAILED; + } + // number of nodes encoded by the main thread (empirically determined) - const int n_main = 64; + const int n_main = MAX(64, 0.1*gf->n_nodes); // number of threads in addition to the main thread const int n_cb = ctx->n_cb; @@ -381,9 +465,13 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb; - const bool use_capture = ctx->capture_next_compute; + if (ctx->capture_compute >= 0) { + ctx->capture_compute--; + } + + const bool use_capture = ctx->capture_compute == 0; if (use_capture) { - ctx->capture_next_compute = false; + ctx->capture_compute = -1; // make sure all previous computations have finished before starting the capture if (ctx->cmd_buf_last) { @@ -392,6 +480,10 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * } if (!ctx->capture_started) { + NSString * path = [NSString stringWithFormat:@"/tmp/perf-metal-%d.gputrace", getpid()]; + + GGML_LOG_WARN("%s: capturing graph in %s\n", __func__, [path UTF8String]); + // create capture scope id<MTLDevice> device = ggml_metal_device_get_obj(ctx->dev); ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:device]; @@ -399,7 +491,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new]; descriptor.captureObject = ctx->capture_scope; descriptor.destination = MTLCaptureDestinationGPUTraceDocument; - descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]]; + descriptor.outputURL = [NSURL fileURLWithPath:path]; NSError * error = nil; if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) { @@ -462,7 +554,7 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * // enter here only when capturing in order to wait for all computation to finish // otherwise, we leave the graph to compute asynchronously - if (!use_capture && ctx->capture_started) { + if (use_capture && ctx->capture_started) { // wait for completion and check status of each command buffer // needed to detect if the device ran out-of-memory for example (#1881) { @@ -514,6 +606,8 @@ enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * [ctx->capture_scope endScope]; [[MTLCaptureManager sharedCaptureManager] stopCapture]; + + ctx->capture_started = false; } } @@ -530,6 +624,42 @@ void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) { //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0); } +void ggml_metal_event_record(ggml_metal_t ctx, ggml_metal_event_t ev) { + @autoreleasepool { + id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev); + id<MTLCommandBuffer> cmd_buf = [queue commandBuffer]; + + ggml_metal_event_encode_signal(ev, cmd_buf); + + [cmd_buf commit]; + + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +void ggml_metal_event_wait(ggml_metal_t ctx, ggml_metal_event_t ev) { + @autoreleasepool { + id<MTLCommandQueue> queue = ggml_metal_device_get_queue(ctx->dev); + id<MTLCommandBuffer> cmd_buf = [queue commandBuffer]; + + ggml_metal_event_encode_wait(ev, cmd_buf); + + [cmd_buf commit]; + + [ctx->cmd_bufs_ext addObject:cmd_buf]; + ctx->cmd_buf_last = cmd_buf; + + [cmd_buf retain]; + } +} + +ggml_metal_event_t ggml_metal_get_ev_cpy(ggml_metal_t ctx) { + return ctx->ev_cpy; +} + void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) { if (ctx->n_cb != n_cb) { ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS); @@ -570,7 +700,7 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) { idx_end, ctx->use_fusion, ctx->use_concurrency, - ctx->capture_next_compute, + ctx->capture_compute, ctx->debug_graph, ctx->debug_fusion); @@ -605,5 +735,5 @@ bool ggml_metal_supports_family(ggml_metal_t ctx, int family) { } void ggml_metal_capture_next_compute(ggml_metal_t ctx) { - ctx->capture_next_compute = true; + ctx->capture_compute = 1; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index b0734797f19..4f4f073cb61 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -17,10 +17,12 @@ struct ggml_metal_device_deleter { typedef std::unique_ptr<ggml_metal_device, ggml_metal_device_deleter> ggml_metal_device_ptr; -ggml_metal_device_t ggml_metal_device_get(void) { - static ggml_metal_device_ptr ctx { ggml_metal_device_init() }; +ggml_metal_device_t ggml_metal_device_get(int device) { + static std::vector<ggml_metal_device_ptr> devs; - return ctx.get(); + devs.emplace_back(ggml_metal_device_init(device)); + + return devs.back().get(); } struct ggml_metal_pipelines { @@ -94,6 +96,31 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_l return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); + + const char * pool_str = "undefined"; + switch (op_pool) { + case GGML_OP_POOL_AVG: pool_str = "avg"; break; + case GGML_OP_POOL_MAX: pool_str = "max"; break; + default: GGML_ASSERT(false && "not implemented"); + }; + + char base[256]; + char name[256]; + + snprintf(base, sizeof(base), "kernel_pool_1d_%s_%s", pool_str, ggml_type_name(op->src[0]->type)); + snprintf(name, sizeof(name), "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) { GGML_ASSERT(ggml_is_contiguous(op->src[0])); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type); @@ -149,6 +176,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_me return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const int n = op->src[0]->ne[0]; + + snprintf(base, 256, "kernel_diag_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_n=%d", base, n); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + res.nsg = 1; + res.smem = 0; + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) { char base[256]; char name[256]; @@ -165,61 +212,74 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta } ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) { - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - char base[256]; char name[256]; - const int64_t n = ggml_nelements(op); + int op_num = -1; - const char * op_str = "undefined"; switch (op->op) { - case GGML_OP_SCALE: op_str = "scale"; break; - case GGML_OP_FILL: op_str = "fill"; break; - case GGML_OP_CLAMP: op_str = "clamp"; break; - case GGML_OP_SQR: op_str = "sqr"; break; - case GGML_OP_SQRT: op_str = "sqrt"; break; - case GGML_OP_SIN: op_str = "sin"; break; - case GGML_OP_COS: op_str = "cos"; break; - case GGML_OP_LOG: op_str = "log"; break; - case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break; + case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break; + case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break; + case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break; + case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break; + case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break; + case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break; + case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break; + case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break; + case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_TANH: op_str = "tanh"; break; - case GGML_UNARY_OP_RELU: op_str = "relu"; break; - case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break; - case GGML_UNARY_OP_GELU: op_str = "gelu"; break; - case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break; - case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break; - case GGML_UNARY_OP_SILU: op_str = "silu"; break; - case GGML_UNARY_OP_ELU: op_str = "elu"; break; - case GGML_UNARY_OP_NEG: op_str = "neg"; break; - case GGML_UNARY_OP_ABS: op_str = "abs"; break; - case GGML_UNARY_OP_SGN: op_str = "sgn"; break; - case GGML_UNARY_OP_STEP: op_str = "step"; break; - case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break; - case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; - case GGML_UNARY_OP_EXP: op_str = "exp"; break; - case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break; - case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break; + case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break; + case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break; + case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break; + case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break; + case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break; + case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break; + case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break; + case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break; + case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break; + case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break; + case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break; + case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break; + case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break; + case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break; + case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break; + case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break; + case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break; + case GGML_UNARY_OP_FLOOR: op_num = OP_UNARY_NUM_FLOOR; break; + case GGML_UNARY_OP_CEIL: op_num = OP_UNARY_NUM_CEIL; break; + case GGML_UNARY_OP_ROUND: op_num = OP_UNARY_NUM_ROUND; break; + case GGML_UNARY_OP_TRUNC: op_num = OP_UNARY_NUM_TRUNC; break; + case GGML_UNARY_OP_XIELU: op_num = OP_UNARY_NUM_XIELU; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); }; - const char * suffix = ""; - if (n % 4 == 0) { - suffix = "_4"; - } + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix); - snprintf(name, 256, "%s", base); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768; + + snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0); + ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } + res.c4 = is_c4; + res.cnt = is_cnt; + return res; } @@ -273,31 +333,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l } ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { - GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); char base[256]; char name[256]; - const char * op_str = "undefined"; + int op_num = -1; + switch (op->op) { - case GGML_OP_SUM_ROWS: - op_str = "sum_rows"; break; - case GGML_OP_MEAN: - op_str = "mean"; break; + case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break; + case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break; default: GGML_ABORT("fatal error"); }; - snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(name, 256, "%s", base); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + + snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d", base, op_num); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } res.smem = 32*sizeof(float); + if (is_c4) { + res.smem *= 4; + } + + res.c4 = is_c4; + return res; } @@ -507,19 +582,98 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_ return res; } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + // v is src[2], dimensions: S_v = ne[0], H = ne[1] + const int ne20 = op->src[2]->ne[0]; // S_v + const int ne21 = op->src[2]->ne[1]; // H + const int ne30 = op->src[3]->ne[0]; // G + // state is src[5], 4D [S_v, S_v, H_v, n_seqs] (s0 only); K is op param 0. + const int K = ggml_get_op_params_i32(op, 0); + + const int nsg = op->src[2]->ne[0]/32; + + GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(op->ne[0] == ne20 * ne21); + GGML_ASSERT(ne20 % 32 == 0); + + snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); + snprintf(name, 256, "%s_ne20=%d_ne30=%d_K=%d", base, ne20, ne30, K); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); + ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + ggml_metal_cv_set_int16(cv, K, FC_GATED_DELTA_NET + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.nsg = nsg; + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + const int nsg = 8; + const int n = op->src[1]->ne[1]; + const int k = op->src[1]->ne[0]; + + snprintf(base, 256, "kernel_solve_tri_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d_n=%d_k=%d", base, nsg, n, k); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, nsg, FC_SOLVE_TRI + 0); + ggml_metal_cv_set_int16(cv, n, FC_SOLVE_TRI + 1); + ggml_metal_cv_set_int16(cv, k, FC_SOLVE_TRI + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.nsg = nsg; + res.smem = GGML_PAD(GGML_PAD(n, 32)*nsg*sizeof(float), 16); + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, const ggml_tensor * op, int nsg, int nxpsg, int r1ptg) { char base[256]; char name[256]; + const ggml_type tsrc0 = op->src[0]->type; + const ggml_type tsrc1 = op->src[1]->type; + const int ne12 = op->src[1]->ne[2]; + const int r2 = ne12 / op->src[0]->ne[2]; + const int r3 = op->src[1]->ne[3] / op->src[0]->ne[3]; + + GGML_ASSERT(ne12 <= INT16_MAX && r2 <= INT16_MAX && r3 <= INT16_MAX); + snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg); - snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg); + snprintf(name, 256, "%s_nsg=%d_nxpsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, nxpsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); - ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, (int16_t) r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, (int16_t) r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -537,10 +691,25 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta const ggml_type tsrc1 = op->src[1]->type; const bool bc_inp = op->src[0]->ne[0] % 32 != 0; - const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0; + + constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; + constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; + + const bool has_tensor = ggml_metal_device_get_props(ggml_metal_library_get_device(lib))->has_tensor; + + const bool bc_out = has_tensor + ? (op->ne[0] % NRA != 0 || op->ne[1] % NRB != 0) + : (op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0); + + GGML_ASSERT(op->src[1]->ne[2] <= INT16_MAX && op->src[1]->ne[3] <= INT16_MAX); + const int16_t ne12 = (int16_t) op->src[1]->ne[2]; + const int16_t ne13 = (int16_t) op->src[1]->ne[3]; + const int16_t r2 = (int16_t) (ne12 / op->src[0]->ne[2]); + const int16_t r3 = (int16_t) (ne13 / op->src[0]->ne[3]); snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); - snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out); + snprintf(name, 256, "%s_bci=%d_bco=%d_ne12=%d_ne13=%d_r2=%d_r3=%d", + base, bc_inp, bc_out, ne12, ne13, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { @@ -548,14 +717,30 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_meta ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0); ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1); + ggml_metal_cv_set_int16(cv, ne12, FC_MUL_MM + 2); + ggml_metal_cv_set_int16(cv, ne13, FC_MUL_MM + 3); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MM + 4); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MM + 5); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); ggml_metal_cv_free(cv); } - // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes - res.smem = bc_out ? 8192 : 4096 + 2048; + if (has_tensor) { + res.nr0 = NRA; + res.nr1 = NRB; + + const size_t smem_a = NRA * N_MM_NK_TOTAL * sizeof(ggml_fp16_t); + res.smem = smem_a; + } else { + res.nr0 = 64; + res.nr1 = 32; + + res.smem = bc_out ? 8192 : (4096 + 2048); + } + + res.nsg = N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y; return res; } @@ -597,6 +782,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta suffix = ne00 % 4 == 0 ? "_4" : ""; } } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; @@ -712,14 +902,21 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_meta } }; + GGML_ASSERT(ne12 <= INT16_MAX && ne13 <= INT16_MAX); + const int16_t r2 = (int16_t) (ne12 / ne02); + const int16_t r3 = (int16_t) (ne13 / ne03); + snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); - snprintf(name, 256, "%s_nsg=%d", base, nsg); + snprintf(name, 256, "%s_nsg=%d_ne12=%d_r2=%d_r3=%d", base, nsg, ne12, r2, r3); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { ggml_metal_cv_t cv = ggml_metal_cv_init(); - ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, (int16_t) ne12, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, r2, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, r3, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -809,6 +1006,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m smem = 32*sizeof(float)*nr0; suffix = ne00 % 4 == 0 ? "_4" : ""; } break; + case GGML_TYPE_Q1_0: + { + nsg = N_SG_Q1_0; + nr0 = N_R0_Q1_0; + } break; case GGML_TYPE_Q4_0: { nsg = N_SG_Q4_0; @@ -932,6 +1134,9 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_m ggml_metal_cv_t cv = ggml_metal_cv_init(); ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 2); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 3); + ggml_metal_cv_set_int16(cv, 1, FC_MUL_MV + 4); res = ggml_metal_library_compile_pipeline(lib, base, name, cv); @@ -1315,34 +1520,80 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v GGML_UNUSED(op); } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( - ggml_metal_library_t lib, - ggml_op op, - int32_t n_fuse, - bool row) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) { char base[256]; char name[256]; - const char * op_str = "undefined"; - switch (op) { - case GGML_OP_ADD: op_str = "add"; break; - case GGML_OP_SUB: op_str = "sub"; break; - case GGML_OP_MUL: op_str = "mul"; break; - case GGML_OP_DIV: op_str = "div"; break; + int op_num = -1; + + switch (op->op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; default: GGML_ABORT("fatal error"); }; - if (row) { - snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse); - } else { - snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse); + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t1_str = ggml_type_name(op->src[1]->type); + const char * t_str = ggml_type_name(op->type); + + const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); + + const bool is_cb = op->src[0]->ne[0] != op->src[1]->ne[0]; + const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d_cb=%d", base, op_num, n_fuse, is_rb, is_cb); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2); + ggml_metal_cv_set_bool (cv, is_cb, FC_BIN + 3); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } - snprintf(name, 256, "%s", base); + res.c4 = is_c4; + res.cnt = is_rb; + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) { + char base[256]; + char name[256]; + + int op_num = -1; + + switch (op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32"); + snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, false, FC_BIN + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } return res; @@ -1351,13 +1602,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_L2_NORM); - GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); - char base[256]; char name[256]; - snprintf(base, 256, "kernel_l2_norm_f32"); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); + + snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1365,6 +1618,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } + res.c4 = is_c4; res.smem = 32*sizeof(float); return res; @@ -1478,14 +1732,24 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_IM2COL); + GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne); + GGML_ASSERT(ggml_is_contiguous(op->src[1])); GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32); + const bool is_2D = ((const int32_t *)(op->op_params))[6] == 1; + const int64_t KH = is_2D ? ne01 : 1; + const int64_t KW = ne00; + char base[256]; char name[256]; - snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); + if (KH*KW <= 1024) { + snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type)); + } else { + snprintf(base, 256, "kernel_im2col_ext_%s", ggml_type_name(op->type)); + } snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1564,13 +1828,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_met return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_CONV_3D); + + GGML_ASSERT(ggml_is_contiguous(op->src[0])); + GGML_ASSERT(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); + GGML_ASSERT(op->type == GGML_TYPE_F32); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_conv_3d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_UPSCALE); char base[256]; char name[256]; - snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type)); + const int32_t mode_flags = ggml_get_op_params_i32(op, 0); + const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF); + + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); + + if (mode == GGML_SCALE_MODE_BILINEAR) { + snprintf(base, 256, "kernel_upscale_bilinear_%s", ggml_type_name(op->src[0]->type)); + } else if (mode == GGML_SCALE_MODE_BICUBIC) { + snprintf(base, 256, "kernel_upscale_bicubic_%s", ggml_type_name(op->src[0]->type)); + } else { + snprintf(base, 256, "kernel_upscale_nearest_%s", ggml_type_name(op->src[0]->type)); + } + snprintf(name, 256, "%s_aa=%d", base, antialias); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, antialias, FC_UPSCALE + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_roll(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_ROLL); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_roll_%s", ggml_type_name(op->src[0]->type)); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1587,7 +1907,11 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l char base[256]; char name[256]; - snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type)); + // note: this is slower + //const bool is_c4 = op->src[0]->ne[0] % 4 == 0 && op->ne[0] % 4 == 0; + const bool is_c4 = false; + + snprintf(base, 256, "kernel_pad_%s%s", ggml_type_name(op->src[0]->type), is_c4 ? "_4" : ""); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1597,6 +1921,8 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_l res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + res.c4 = is_c4; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 9c3b0014878..4a3ebb5569d 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params { int nr1; size_t smem; + + bool c4; + bool cnt; }; int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); @@ -99,14 +102,18 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev void ggml_metal_library_free(ggml_metal_library_t lib); +ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib); + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline (ggml_metal_library_t lib, const char * name); struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_1d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_diag (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); @@ -120,7 +127,9 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, const struct ggml_tensor * op, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_metal_library_t lib, int ne02, int ne20); @@ -131,7 +140,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse ); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); @@ -140,9 +150,11 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_3d (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_roll (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); @@ -203,8 +215,34 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets); // device // +enum ggml_metal_device_id { + GGML_METAL_DEVICE_GENERIC = 0, + + GGML_METAL_DEVICE_M1, + GGML_METAL_DEVICE_M1_PRO, + GGML_METAL_DEVICE_M1_MAX, + GGML_METAL_DEVICE_M1_ULTRA, + GGML_METAL_DEVICE_M2, + GGML_METAL_DEVICE_M2_PRO, + GGML_METAL_DEVICE_M2_MAX, + GGML_METAL_DEVICE_M2_ULTRA, + GGML_METAL_DEVICE_M3, + GGML_METAL_DEVICE_M3_PRO, + GGML_METAL_DEVICE_M3_MAX, + GGML_METAL_DEVICE_M3_ULTRA, + GGML_METAL_DEVICE_M4, + GGML_METAL_DEVICE_M4_PRO, + GGML_METAL_DEVICE_M4_MAX, + GGML_METAL_DEVICE_M5, + GGML_METAL_DEVICE_M5_PRO, + GGML_METAL_DEVICE_M5_MAX, + GGML_METAL_DEVICE_M5_ULTRA, +}; + struct ggml_metal_device_props { + int device; char name[128]; + char desc[128]; size_t max_buffer_size; size_t max_working_set_size; @@ -220,14 +258,20 @@ struct ggml_metal_device_props { bool supports_gpu_family_apple7; + enum ggml_metal_device_id device_id; + int op_offload_min_batch_size; }; -ggml_metal_device_t ggml_metal_device_init(void); +typedef struct ggml_metal_event * ggml_metal_event_t; + +void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf); +void ggml_metal_event_encode_wait (ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf); + +ggml_metal_device_t ggml_metal_device_init(int device); void ggml_metal_device_free(ggml_metal_device_t dev); -// return a singleton that is automatically destroyed when the program exits -ggml_metal_device_t ggml_metal_device_get(void); +ggml_metal_device_t ggml_metal_device_get(int device); void * ggml_metal_device_get_obj (ggml_metal_device_t dev); // id<MTLDevice> void * ggml_metal_device_get_queue(ggml_metal_device_t dev); // id<MTLCommandQueue> @@ -239,6 +283,10 @@ void ggml_metal_device_rsets_rm (ggml_metal_device_t dev, ggml_metal_rset_t rset void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev); +ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev); +void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev); +void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev); + void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total); bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op); @@ -260,6 +308,7 @@ bool ggml_metal_buffer_is_shared(ggml_metal_buffer_t buf); void ggml_metal_buffer_memset_tensor(ggml_metal_buffer_t buf, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); void ggml_metal_buffer_set_tensor (ggml_metal_buffer_t buf, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); void ggml_metal_buffer_get_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); +bool ggml_metal_buffer_cpy_tensor (ggml_metal_buffer_t buf, const struct ggml_tensor * src, struct ggml_tensor * dst); void ggml_metal_buffer_clear (ggml_metal_buffer_t buf, uint8_t value); // finds the Metal buffer that contains the tensor data on the GPU device diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index ff899a81709..d583bd6efc0 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1,6 +1,7 @@ #import "ggml-metal-device.h" #import "ggml-impl.h" +#import "ggml-backend-impl.h" #include <Foundation/Foundation.h> @@ -24,9 +25,6 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; static const NSInteger MTLGPUFamilyMetal4_GGML = 5002; -// virtual address for GPU memory allocations -static atomic_uintptr_t g_addr_device = 0x000000400ULL; - #if !GGML_METAL_EMBED_LIBRARY // Here to assist with NSBundle Path Hack @interface GGMLMetalClass : NSObject @@ -98,8 +96,8 @@ int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_wi struct ggml_metal_library { id<MTLLibrary> obj; - id<MTLDevice> device; + ggml_metal_device_t dev; ggml_metal_pipelines_t pipelines; // cache of compiled pipelines NSLock * lock; @@ -254,7 +252,7 @@ ggml_metal_library_t ggml_metal_library_init(ggml_metal_device_t dev) { ggml_metal_library_t res = calloc(1, sizeof(struct ggml_metal_library)); res->obj = library; - res->device = device; + res->dev = dev; res->pipelines = ggml_metal_pipelines_init(); res->lock = [NSLock new]; @@ -321,7 +319,7 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev } res->obj = library; - res->device = device; + res->dev = dev; res->pipelines = ggml_metal_pipelines_init(); res->lock = [NSLock new]; @@ -344,15 +342,21 @@ void ggml_metal_library_free(ggml_metal_library_t lib) { free(lib); } +ggml_metal_device_t ggml_metal_library_get_device(ggml_metal_library_t lib) { + return lib->dev; +} + struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) { [lib->lock lock]; struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, + /*.cnt =*/ false, }; res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); @@ -365,10 +369,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, + /*.cnt =*/ false, }; [lib->lock lock]; @@ -404,7 +410,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_ return res; } - id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error]; + id<MTLDevice> device = ggml_metal_device_get_obj(lib->dev); + id<MTLComputePipelineState> obj = [device newComputePipelineStateWithFunction:mtl_function error:&error]; [mtl_function release]; @@ -523,6 +530,9 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) { ggml_metal_library_t library; struct ggml_metal_device_props props; + + // virtual address for GPU memory allocations + atomic_uintptr_t addr_virt; }; // @@ -537,6 +547,8 @@ void ggml_metal_encoder_end_encoding(ggml_metal_encoder_t encoder) { // number of seconds since the last graph computation // keep the residency sets wired for that amount of time to avoid being collected by the OS int keep_alive_s; + int loops_per_s; + int time_per_loop_ms; // background heartbeat thread to keep the residency sets alive atomic_bool d_stop; @@ -563,10 +575,13 @@ ggml_metal_rsets_t ggml_metal_rsets_init(void) { res->keep_alive_s = 3*60; } + res->time_per_loop_ms = 5; + res->loops_per_s = 1000/res->time_per_loop_ms; + GGML_LOG_INFO("%s: creating a residency set collection (keep_alive = %d s)\n", __func__, res->keep_alive_s); atomic_store_explicit(&res->d_stop, false, memory_order_relaxed); - atomic_store_explicit(&res->d_loop, 2*res->keep_alive_s, memory_order_relaxed); + atomic_store_explicit(&res->d_loop, res->loops_per_s*res->keep_alive_s, memory_order_relaxed); res->d_group = dispatch_group_create(); @@ -589,8 +604,7 @@ ggml_metal_rsets_t ggml_metal_rsets_init(void) { [res->lock unlock]; } - // half a second - usleep(500 * 1000); + usleep(res->time_per_loop_ms * 1000); } } #endif @@ -618,7 +632,51 @@ void ggml_metal_rsets_free(ggml_metal_rsets_t rsets) { free(rsets); } -ggml_metal_device_t ggml_metal_device_init(void) { +static enum ggml_metal_device_id ggml_metal_device_id_parse(const char * name) { + if (!name) { + return GGML_METAL_DEVICE_GENERIC; + } + + static const char prefix[] = "Apple "; + if (strncmp(name, prefix, sizeof(prefix) - 1) != 0) { + return GGML_METAL_DEVICE_GENERIC; + } + const char * suffix = name + sizeof(prefix) - 1; + + static const struct { + const char * name; + enum ggml_metal_device_id id; + } table[] = { + {"M1", GGML_METAL_DEVICE_M1}, + {"M1 Pro", GGML_METAL_DEVICE_M1_PRO}, + {"M1 Max", GGML_METAL_DEVICE_M1_MAX}, + {"M1 Ultra", GGML_METAL_DEVICE_M1_ULTRA}, + {"M2", GGML_METAL_DEVICE_M2}, + {"M2 Pro", GGML_METAL_DEVICE_M2_PRO}, + {"M2 Max", GGML_METAL_DEVICE_M2_MAX}, + {"M2 Ultra", GGML_METAL_DEVICE_M2_ULTRA}, + {"M3", GGML_METAL_DEVICE_M3}, + {"M3 Pro", GGML_METAL_DEVICE_M3_PRO}, + {"M3 Max", GGML_METAL_DEVICE_M3_MAX}, + {"M3 Ultra", GGML_METAL_DEVICE_M3_ULTRA}, + {"M4", GGML_METAL_DEVICE_M4}, + {"M4 Pro", GGML_METAL_DEVICE_M4_PRO}, + {"M4 Max", GGML_METAL_DEVICE_M4_MAX}, + {"M5", GGML_METAL_DEVICE_M5}, + {"M5 Pro", GGML_METAL_DEVICE_M5_PRO}, + {"M5 Max", GGML_METAL_DEVICE_M5_MAX}, + {"M5 Ultra", GGML_METAL_DEVICE_M5_ULTRA}, + }; + + for (size_t i = 0; i < sizeof(table)/sizeof(table[0]); ++i) { + if (strcmp(suffix, table[i].name) == 0) { + return table[i].id; + } + } + return GGML_METAL_DEVICE_GENERIC; +} + +ggml_metal_device_t ggml_metal_device_init(int device) { ggml_metal_device_t dev = calloc(1, sizeof(struct ggml_metal_device)); assert(dev != NULL); @@ -632,6 +690,9 @@ ggml_metal_device_t ggml_metal_device_init(void) { GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__); } + dev->addr_virt = 0x000000400ULL; + + dev->props.device = device; dev->props.has_simdgroup_reduction = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; dev->props.has_simdgroup_reduction |= [dev->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; @@ -659,7 +720,7 @@ ggml_metal_device_t ggml_metal_device_init(void) { ![[dev->mtl_device name] containsString:@"M6"] && ![[dev->mtl_device name] containsString:@"A19"] && ![[dev->mtl_device name] containsString:@"A20"]) { - GGML_LOG_WARN("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); + GGML_LOG_INFO("%s: tensor API disabled for pre-M5 and pre-A19 devices\n", __func__); dev->props.has_tensor = false; } @@ -683,7 +744,7 @@ ggml_metal_device_t ggml_metal_device_init(void) { " auto tB = B.slice((int)tgid.x, 0); \n" " \n" " matmul2d< \n" - " matmul2d_descriptor(8, 8, dynamic_extent), \n" + " matmul2d_descriptor(16, 16, dynamic_extent), \n" " execution_simdgroups<4>> mm; \n" " \n" " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n" @@ -692,7 +753,7 @@ ggml_metal_device_t ggml_metal_device_init(void) { " auto sB = tB.slice(0, 0); \n" " mm.run(sB, sA, cT); \n" " \n" - " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n" + " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(16, 16)); \n" " \n" " cT.store(tC); \n" "}"; @@ -733,7 +794,7 @@ ggml_metal_device_t ggml_metal_device_init(void) { " auto tB = B.slice((int)tgid.x, 0); \n" " \n" " matmul2d< \n" - " matmul2d_descriptor(8, 8, dynamic_extent), \n" + " matmul2d_descriptor(16, 16, dynamic_extent), \n" " execution_simdgroups<4>> mm; \n" " \n" " auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); \n" @@ -742,7 +803,7 @@ ggml_metal_device_t ggml_metal_device_init(void) { " auto sB = tB.slice(0, 0); \n" " mm.run(sB, sA, cT); \n" " \n" - " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(4, 4)); \n" + " auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(16, 16)); \n" " \n" " cT.store(tC); \n" "}"; @@ -782,13 +843,20 @@ ggml_metal_device_t ggml_metal_device_init(void) { dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.device_id = ggml_metal_device_id_parse([[dev->mtl_device name] UTF8String]); + dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; - dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength; + if (@available(macOS 10.12, iOS 16.0, *)) { + dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; + } else { + dev->props.max_working_set_size = dev->mtl_device.maxBufferLength; + } - strncpy(dev->props.name, [[dev->mtl_device name] UTF8String], sizeof(dev->props.name) - 1); + snprintf(dev->props.name, sizeof(dev->props.name), "%s%d", "MTL", device); + snprintf(dev->props.desc, sizeof(dev->props.desc), "%s", [[dev->mtl_device name] UTF8String]); dev->library = ggml_metal_library_init(dev); if (!dev->library) { @@ -802,7 +870,7 @@ ggml_metal_device_t ggml_metal_device_init(void) { } // print MTL GPU family: - GGML_LOG_INFO("%s: GPU name: %s\n", __func__, dev->props.name); + GGML_LOG_INFO("%s: GPU name: %s (%s)\n", __func__, dev->props.name, dev->props.desc); // determine max supported GPU family // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf @@ -915,7 +983,59 @@ void ggml_metal_device_rsets_keep_alive(ggml_metal_device_t dev) { return; } - atomic_store_explicit(&dev->rsets->d_loop, 2*dev->rsets->keep_alive_s, memory_order_relaxed); + atomic_store_explicit(&dev->rsets->d_loop, dev->rsets->loops_per_s*dev->rsets->keep_alive_s, memory_order_relaxed); +} + +struct ggml_metal_event { + void * obj; // id<MTLSharedEvent> + + atomic_int value; +}; + +void ggml_metal_event_encode_signal(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { + id<MTLSharedEvent> event = (id<MTLSharedEvent>)ev->obj; + + id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw; + + [cmd_buf encodeSignalEvent:event value:atomic_fetch_add_explicit(&ev->value, 1, memory_order_relaxed) + 1]; +} + +void ggml_metal_event_encode_wait(ggml_metal_event_t ev, ggml_metal_cmd_buf_t cmd_buf_raw) { + id<MTLSharedEvent> event = (id<MTLSharedEvent>)ev->obj; + + id<MTLCommandBuffer> cmd_buf = (id<MTLCommandBuffer>) cmd_buf_raw; + + [cmd_buf encodeWaitForEvent:event value:atomic_load_explicit(&ev->value, memory_order_relaxed)]; +} + +ggml_metal_event_t ggml_metal_device_event_init(ggml_metal_device_t dev) { + id<MTLSharedEvent> event = [dev->mtl_device newSharedEvent]; + + ggml_metal_event_t ev = calloc(1, sizeof(struct ggml_metal_event)); + + ev->obj = (__bridge void *)event; + ev->value = 0; + + return ev; +} + +void ggml_metal_device_event_free(ggml_metal_device_t dev, ggml_metal_event_t ev) { + id<MTLSharedEvent> event = ev->obj; + [event release]; + + free(ev); + + GGML_UNUSED(dev); +} + +void ggml_metal_device_event_synchronize(ggml_metal_device_t dev, ggml_metal_event_t ev) { + id<MTLSharedEvent> event = ev->obj; + const bool res = [event waitUntilSignaledValue:atomic_load_explicit(&ev->value, memory_order_relaxed) timeoutMS:60000]; + if (!res) { + GGML_ABORT("%s: failed to wait for event\n", __func__); + } + + GGML_UNUSED(dev); } void ggml_metal_device_get_memory(ggml_metal_device_t dev, size_t * free, size_t * total) { @@ -946,6 +1066,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te } switch (op->op) { + case GGML_OP_SCALE: + case GGML_OP_FILL: + case GGML_OP_CLAMP: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_LOG: + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_TANH: @@ -965,7 +1094,12 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_EXPM1: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_UNARY_OP_FLOOR: + case GGML_UNARY_OP_CEIL: + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_XIELU: + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; } @@ -977,7 +1111,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_1(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; } @@ -986,18 +1120,25 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_VIEW: case GGML_OP_TRANSPOSE: case GGML_OP_PERMUTE: - case GGML_OP_CONCAT: return true; + case GGML_OP_CONCAT: + { + // kernel_concat copies one float-sized value per element. + // Other scalar types need a type-generic copy kernel first. + const enum ggml_type src0_type = op->src[0]->type; + const enum ggml_type src1_type = op->src[1]->type; + return src0_type == src1_type && + src0_type == op->type && + (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_I32); + } case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_ADD_ID: - return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: + return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_REPEAT: - case GGML_OP_SCALE: - case GGML_OP_FILL: case GGML_OP_CONV_TRANSPOSE_1D: return true; case GGML_OP_CONV_TRANSPOSE_2D: @@ -1005,14 +1146,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; - case GGML_OP_CLAMP: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_CONV_3D: + return ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && + op->src[1]->type == GGML_TYPE_F32; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_TRI: @@ -1022,9 +1160,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_L2_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_COUNT_EQUAL: return has_simdgroup_reduction && op->src[0]->type == GGML_TYPE_I32 && @@ -1044,10 +1181,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32); - case GGML_OP_POOL_1D: - return false; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_POOL_1D: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_POOL_2D: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: @@ -1065,6 +1202,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_ARGSORT: case GGML_OP_TOP_K: case GGML_OP_ARANGE: + case GGML_OP_ROLL: return true; case GGML_OP_FLASH_ATTN_EXT: // for new head sizes, add checks here @@ -1078,17 +1216,32 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te op->src[0]->ne[0] != 112 && op->src[0]->ne[0] != 128 && op->src[0]->ne[0] != 192 && - op->src[0]->ne[0] != 256) { - return false; - } - if (op->src[0]->ne[0] == 576) { - // DeepSeek sizes - // TODO: disabled for now, until optmized + op->src[0]->ne[0] != 256 && + op->src[0]->ne[0] != 320 && + op->src[0]->ne[0] != 512 && + op->src[0]->ne[0] != 576) { return false; } if (op->src[1]->type != op->src[2]->type) { return false; } + switch (op->src[1]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + break; + case GGML_TYPE_BF16: + if (!has_bfloat) { + return false; + } + break; + default: + return false; + } return has_simdgroup_mm; // TODO: over-restricted for vec-kernels case GGML_OP_SSM_CONV: case GGML_OP_SSM_SCAN: @@ -1096,9 +1249,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; + case GGML_OP_GATED_DELTA_NET: + return has_simdgroup_reduction && op->src[2]->ne[0] % 32 == 0; + case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: - return has_simdgroup_reduction; + return has_simdgroup_reduction && op->src[0]->type != GGML_TYPE_NVFP4; + case GGML_OP_SET: case GGML_OP_CPY: case GGML_OP_DUP: case GGML_OP_CONT: @@ -1110,6 +1267,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1136,6 +1294,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te default: return false; } + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -1155,7 +1314,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te }; } case GGML_OP_GET_ROWS: - return true; + return op->src[0]->type != GGML_TYPE_NVFP4; case GGML_OP_SET_ROWS: { if (op->src[0]->type != GGML_TYPE_F32) { @@ -1177,6 +1336,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return false; }; } + case GGML_OP_DIAG: + return true; case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return has_simdgroup_reduction; @@ -1218,7 +1379,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te bool use_residency_sets; // optional MTLResidencySet - // note: cannot use explicity "id<MTLResidencySet>" here because it is not available on certain OSes + // note: cannot use explicitly "id<MTLResidencySet>" here because it is not available on certain OSes id rset; // pointers to global device @@ -1344,8 +1505,8 @@ ggml_metal_buffer_t ggml_metal_buffer_init(ggml_metal_device_t dev, size_t size, res->all_data = ggml_metal_host_malloc(size_aligned); res->is_shared = true; } else { - // use virtual address from g_addr_device counter - res->all_data = (void *) atomic_fetch_add_explicit(&g_addr_device, size_aligned, memory_order_relaxed); + // use virtual address + res->all_data = (void *) atomic_fetch_add_explicit(&dev->addr_virt, size_aligned, memory_order_relaxed); res->is_shared = false; } res->all_size = size_aligned; @@ -1636,6 +1797,47 @@ void ggml_metal_buffer_get_tensor(ggml_metal_buffer_t buf, const struct ggml_ten } } +bool ggml_metal_buffer_cpy_tensor(ggml_metal_buffer_t buf_dst, const struct ggml_tensor * src, struct ggml_tensor * dst) { + ggml_metal_buffer_t buf_src = (ggml_metal_buffer_t)src->buffer->context; + + const size_t size = ggml_nbytes(src); + + // if both buffers are shared, we can use memcpy directly + if (buf_dst->is_shared && buf_src->is_shared) { + memcpy(dst->data, src->data, size); + return true; + } + + // for private buffers, we need to use Metal blit commands + @autoreleasepool { + struct ggml_metal_buffer_id bid_src = ggml_metal_buffer_get_id(buf_src, src); + struct ggml_metal_buffer_id bid_dst = ggml_metal_buffer_get_id(buf_dst, dst); + + if (bid_src.metal == nil || bid_dst.metal == nil) { + return false; + } + + id<MTLCommandBuffer> cmd_buf = [buf_dst->dev->mtl_queue commandBufferWithUnretainedReferences]; + + { + id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder]; + + [encoder copyFromBuffer:bid_src.metal + sourceOffset:bid_src.offs + toBuffer:bid_dst.metal + destinationOffset:bid_dst.offs + size:size]; + + [encoder endEncoding]; + } + + [cmd_buf commit]; + [cmd_buf waitUntilCompleted]; + } + + return true; +} + void ggml_metal_buffer_clear(ggml_metal_buffer_t buf, uint8_t value) { if (buf->is_shared) { memset(buf->all_data, value, buf->all_size); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index d3b0e732ec4..ff74cafb5b7 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -1,6 +1,19 @@ #ifndef GGML_METAL_IMPL #define GGML_METAL_IMPL +// kernel parameters for mat-mat threadgroups +// +// TODO: become function constants + +#define SZ_SIMDGROUP 16 +#define N_MM_NK 2 +#define N_MM_NK_TOTAL (SZ_SIMDGROUP * N_MM_NK) + +#define N_MM_BLOCK_X 4 +#define N_MM_BLOCK_Y 2 +#define N_MM_SIMD_GROUP_X 2 +#define N_MM_SIMD_GROUP_Y 2 + // kernel parameters for mat-vec threadgroups // // N_R0: number of src0 rows to process per simdgroup @@ -8,6 +21,9 @@ // // TODO: for optimal performance, become function of the device and work size +#define N_R0_Q1_0 8 +#define N_SG_Q1_0 2 + #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 @@ -35,7 +51,7 @@ #define N_R0_Q4_K 2 #define N_SG_Q4_K 2 -#define N_R0_Q5_K 2 +#define N_R0_Q5_K 1 #define N_SG_Q5_K 2 #define N_R0_Q6_K 2 @@ -78,15 +94,57 @@ #define FC_MUL_MM 700 #define FC_ROPE 800 #define FC_SSM_CONV 900 -#define FC_COUNT_EQUAL 1000 +#define FC_SOLVE_TRI 1000 +#define FC_COUNT_EQUAL 1100 +#define FC_UNARY 1200 +#define FC_BIN 1300 +#define FC_SUM_ROWS 1400 +#define FC_UPSCALE 1500 +#define FC_GATED_DELTA_NET 1600 // op-specific constants -#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NQPSG 8 #define OP_FLASH_ATTN_EXT_NCPSG 64 -#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 +#define OP_UNARY_NUM_SCALE 10 +#define OP_UNARY_NUM_FILL 11 +#define OP_UNARY_NUM_CLAMP 12 +#define OP_UNARY_NUM_SQR 13 +#define OP_UNARY_NUM_SQRT 14 +#define OP_UNARY_NUM_SIN 15 +#define OP_UNARY_NUM_COS 16 +#define OP_UNARY_NUM_LOG 17 +#define OP_UNARY_NUM_LEAKY_RELU 18 + +#define OP_UNARY_NUM_TANH 100 +#define OP_UNARY_NUM_RELU 101 +#define OP_UNARY_NUM_SIGMOID 102 +#define OP_UNARY_NUM_GELU 103 +#define OP_UNARY_NUM_GELU_ERF 104 +#define OP_UNARY_NUM_GELU_QUICK 105 +#define OP_UNARY_NUM_SILU 106 +#define OP_UNARY_NUM_ELU 107 +#define OP_UNARY_NUM_NEG 108 +#define OP_UNARY_NUM_ABS 109 +#define OP_UNARY_NUM_SGN 110 +#define OP_UNARY_NUM_STEP 111 +#define OP_UNARY_NUM_HARDSWISH 112 +#define OP_UNARY_NUM_HARDSIGMOID 113 +#define OP_UNARY_NUM_EXP 114 +#define OP_UNARY_NUM_SOFTPLUS 115 +#define OP_UNARY_NUM_EXPM1 116 +#define OP_UNARY_NUM_FLOOR 117 +#define OP_UNARY_NUM_CEIL 118 +#define OP_UNARY_NUM_ROUND 119 +#define OP_UNARY_NUM_TRUNC 120 +#define OP_UNARY_NUM_XIELU 121 + +#define OP_SUM_ROWS_NUM_SUM_ROWS 10 +#define OP_SUM_ROWS_NUM_MEAN 11 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -122,6 +180,31 @@ typedef struct { int32_t dim; } ggml_metal_kargs_concat; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float slope; + float scale; + float bias; + float val; + float min; + float max; +} ggml_metal_kargs_unary; + typedef struct { int32_t ne00; int32_t ne01; @@ -179,20 +262,6 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_repeat; -typedef struct { - float scale; - float bias; -} ggml_metal_kargs_scale; - -typedef struct { - float val; -} ggml_metal_kargs_fill; - -typedef struct { - float min; - float max; -} ggml_metal_kargs_clamp; - typedef struct { int64_t nk0; int64_t ne00; @@ -496,8 +565,21 @@ typedef struct { typedef struct { int32_t ne00; - int32_t ne00_4; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float eps; } ggml_metal_kargs_l2_norm; @@ -582,6 +664,42 @@ typedef struct { int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources } ggml_metal_kargs_im2col; +typedef struct { + int32_t IW; + int32_t IH; + int32_t ID; + int32_t OW; + int32_t OH; + int32_t OD; + int32_t KW; + int32_t KH; + int32_t KD; + int32_t s0; + int32_t s1; + int32_t s2; + int32_t p0; + int32_t p1; + int32_t p2; + int32_t d0; + int32_t d1; + int32_t d2; + int32_t IC; + int32_t N; + int32_t OC; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_conv_3d; + typedef struct{ int32_t ne00; uint64_t nb01; @@ -733,6 +851,71 @@ typedef struct { uint64_t nb0; } ggml_metal_kargs_ssm_scan; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne20; + int32_t ne21; + int32_t ne22; + int32_t ne23; + uint64_t nb20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ns02; + int32_t ns12; + int32_t ns22; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_gated_delta_net; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_solve_tri; + typedef struct { int32_t ne00t; int32_t ne00; @@ -764,6 +947,25 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_set_rows; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_diag; + typedef struct { int64_t ne00; int64_t ne01; @@ -785,6 +987,7 @@ typedef struct { float sf1; float sf2; float sf3; + float poffs; } ggml_metal_kargs_upscale; typedef struct { @@ -827,16 +1030,35 @@ typedef struct { int32_t p1; } ggml_metal_kargs_pad_reflect_1d; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t s0; + int32_t s1; + int32_t s2; + int32_t s3; +} ggml_metal_kargs_roll; + typedef struct { uint64_t nb1; int dim; int max_period; } ggml_metal_kargs_timestep_embedding; -typedef struct { - float slope; -} ggml_metal_kargs_leaky_relu; - typedef struct { int32_t ne00; int32_t ne01; @@ -928,6 +1150,15 @@ typedef struct { int64_t np; } ggml_metal_kargs_pool_2d; +typedef struct { + int32_t k0; + int32_t s0; + int32_t p0; + int64_t IW; + int64_t OW; + int64_t np; +} ggml_metal_kargs_pool_1d; + typedef struct { int64_t ne00; uint64_t nb01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a50b12b6f3b..e2ce56e9e28 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported op"); } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return 1; + } + int n_fuse = 1; // check if the current node can run concurrently with other nodes before it @@ -283,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { n_fuse = ggml_metal_op_acc(ctx, idx); } break; case GGML_OP_SCALE: - { - n_fuse = ggml_metal_op_scale(ctx, idx); - } break; case GGML_OP_FILL: - { - n_fuse = ggml_metal_op_fill(ctx, idx); - } break; case GGML_OP_CLAMP: - { - n_fuse = ggml_metal_op_clamp(ctx, idx); - } break; + case GGML_OP_LEAKY_RELU: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: @@ -337,6 +333,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_rwkv(ctx, idx); } break; + case GGML_OP_GATED_DELTA_NET: + { + n_fuse = ggml_metal_op_gated_delta_net(ctx, idx); + } break; + case GGML_OP_SOLVE_TRI: + { + n_fuse = ggml_metal_op_solve_tri(ctx, idx); + } break; case GGML_OP_MUL_MAT: { n_fuse = ggml_metal_op_mul_mat(ctx, idx); @@ -353,6 +357,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_set_rows(ctx, idx); } break; + case GGML_OP_DIAG: + { + n_fuse = ggml_metal_op_diag(ctx, idx); + } break; case GGML_OP_L2_NORM: { n_fuse = ggml_metal_op_l2_norm(ctx, idx); @@ -386,6 +394,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_conv_transpose_2d(ctx, idx); } break; + case GGML_OP_CONV_3D: + { + n_fuse = ggml_metal_op_conv_3d(ctx, idx); + } break; case GGML_OP_UPSCALE: { n_fuse = ggml_metal_op_upscale(ctx, idx); @@ -398,6 +410,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_pad_reflect_1d(ctx, idx); } break; + case GGML_OP_ROLL: + { + n_fuse = ggml_metal_op_roll(ctx, idx); + } break; case GGML_OP_ARANGE: { n_fuse = ggml_metal_op_arange(ctx, idx); @@ -414,10 +430,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_top_k(ctx, idx); } break; - case GGML_OP_LEAKY_RELU: - { - n_fuse = ggml_metal_op_leaky_relu(ctx, idx); - } break; case GGML_OP_TRI: { n_fuse = ggml_metal_op_tri(ctx, idx); @@ -426,12 +438,20 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); } break; + case GGML_OP_SET: + { + n_fuse = ggml_metal_op_set(ctx, idx); + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: { n_fuse = ggml_metal_op_cpy(ctx, idx); } break; + case GGML_OP_POOL_1D: + { + n_fuse = ggml_metal_op_pool_1d(ctx, idx); + } break; case GGML_OP_POOL_2D: { n_fuse = ggml_metal_op_pool_2d(ctx, idx); @@ -544,9 +564,20 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const int nth = std::min(1024, ne0); + int nth = std::min(256, ne0); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + if (nth < 256) { + nrptg = std::min((256 + nth - 1) / nth, ne1); + if (nrptg * nth > 256) { + nrptg = 256 / nth; + } + } + + const int nw0 = (ne1 + nrptg - 1) / nrptg; + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0, ne2, ne3, nth, nrptg, 1); return 1; } @@ -612,8 +643,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); const size_t pnb1 = ((const int32_t *) op->op_params)[0]; const size_t pnb2 = ((const int32_t *) op->op_params)[1]; @@ -623,7 +654,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; if (!inplace) { - // run a separete kernel to cpy src->dst + // run a separate kernel to cpy src->dst // not sure how to avoid this // TODO: make a simpler cpy_bytes kernel @@ -663,10 +694,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { } ggml_metal_kargs_bin args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, + /*.ne00 =*/ ne10, + /*.ne01 =*/ ne11, + /*.ne02 =*/ ne12, + /*.ne03 =*/ ne13, /*.nb00 =*/ nb00, /*.nb01 =*/ pnb1, /*.nb02 =*/ pnb2, @@ -679,10 +710,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.nb11 =*/ nb11, /*.nb12 =*/ nb12, /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, + /*.ne0 =*/ ne10, + /*.ne1 =*/ ne11, + /*.ne2 =*/ ne12, + /*.ne3 =*/ ne13, /*.nb0 =*/ nb0, /*.nb1 =*/ pnb1, /*.nb2 =*/ pnb2, @@ -691,7 +722,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -699,53 +730,20 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); - - ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1); + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - return 1; -} - -int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float scale; - float bias; - memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float)); - memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float)); - - ggml_metal_kargs_scale args = { - /*.scale =*/ scale, - /*.bias =*/ bias, - }; - - int64_t n = ggml_nelements(op); + int nth = 1; - if (n % 4 == 0) { - n /= 4; + while (2*nth < args.ne0 && nth < nth_max) { + nth *= 2; } - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1); return 1; } -int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) { +int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); ggml_metal_library_t lib = ctx->lib; @@ -756,94 +754,85 @@ int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - const float val = ggml_get_op_params_f32(op, 0); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); - ggml_metal_kargs_fill args = { - /*.val =*/ val - }; + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); - int64_t n = ggml_nelements(op); + ggml_metal_kargs_unary args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.slope =*/ 0.0, + /*.scale =*/ 0.0, + /*.bias =*/ 0.0, + /*.val =*/ 0.0, + /*.min =*/ 0.0, + /*.max =*/ 0.0, + }; - if (n % 4 == 0) { - n /= 4; + if (op->op == GGML_OP_LEAKY_RELU) { + args.slope = ggml_get_op_params_f32(op, 0); } - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - -int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float min; - float max; - memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float)); - memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float)); + if (op->op == GGML_OP_SCALE) { + args.scale = ggml_get_op_params_f32(op, 0); + args.bias = ggml_get_op_params_f32(op, 1); + } - ggml_metal_kargs_clamp args = { - /*.min =*/ min, - /*.max =*/ max, - }; + if (op->op == GGML_OP_FILL) { + args.val = ggml_get_op_params_f32(op, 0); + } - int64_t n = ggml_nelements(op); + if (op->op == GGML_OP_CLAMP) { + args.min = ggml_get_op_params_f32(op, 0); + args.max = ggml_get_op_params_f32(op, 1); + } - if (n % 4 == 0) { - n /= 4; + if (op->op == GGML_OP_UNARY && ggml_get_unary_op(op) == GGML_UNARY_OP_XIELU) { + args.slope = ggml_get_op_params_f32(op, 1); // alpha_n + args.scale = ggml_get_op_params_f32(op, 2); // alpha_p + args.bias = ggml_get_op_params_f32(op, 3); // beta + args.val = ggml_get_op_params_f32(op, 4); // eps } auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - -int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + if (pipeline.cnt) { + const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); - int64_t n = ggml_nelements(op); + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + } else { + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const int nth = MIN(args.ne00, nth_max); + const int nk0 = (args.ne00 + nth - 1)/nth; - if (n % 4 == 0) { - n /= 4; + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); } - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - return 1; } @@ -953,6 +942,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + ggml_metal_kargs_sum_rows args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -974,21 +968,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + int nth = 32; // SIMD width - while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00); + nth = std::min(nth, (int) args.ne00); const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ -1247,6 +1246,48 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_diag(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS(int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_diag args = { + /*.ne00 =*/ne00, + /*.ne01 =*/ne01, + /*.ne02 =*/ne02, + /*.ne03 =*/ne03, + /*.nb00 =*/nb00, + /*.nb01 =*/nb01, + /*.nb02 =*/nb02, + /*.nb03 =*/nb03, + /*.ne0 =*/ne0, + /*.ne1 =*/ne1, + /*.ne2 =*/ne2, + /*.ne3 =*/ne3, + /*.nb0 =*/nb0, + /*.nb1 =*/nb1, + /*.nb2 =*/nb2, + /*.nb3 =*/nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_diag(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, 32, 1, 1); + + return 1; +} + int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -1524,27 +1565,287 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { const int64_t C = op->ne[0]; const int64_t H = op->src[0]->ne[1]; - auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op); + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); + if (op->op == GGML_OP_RWKV_WKV7) { + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++); + } + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++); + + ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1); + + return 1; +} + +int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op); + + int ida = 0; + + ggml_metal_kargs_gated_delta_net args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne20 =*/ ne20, + /*.ne21 =*/ ne21, + /*.ne22 =*/ ne22, + /*.ne23 =*/ ne23, + /*.nb20 =*/ nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ns02 =*/ (int32_t) (nb02/sizeof(float)), + /*.ns12 =*/ (int32_t) (nb12/sizeof(float)), + /*.ns22 =*/ (int32_t) (nb22/sizeof(float)), + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst + + const int nsg = pipeline.nsg; + + ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1); + + return 1; +} + +int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_kargs_solve_tri args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + auto pipeline = ggml_metal_library_get_pipeline_solve_tri(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + const int nsg = pipeline.nsg; + + ggml_metal_encoder_set_threadgroup_memory_size(enc, pipeline.smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, (ne10 + nsg - 1)/nsg, ne02, ne03, 32, nsg, 1); + + return 1; +} + +int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + const size_t pnb1 = ((const int32_t *) op->op_params)[0]; + const size_t pnb2 = ((const int32_t *) op->op_params)[1]; + const size_t pnb3 = ((const int32_t *) op->op_params)[2]; + const size_t offs = ((const int32_t *) op->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; + + if (!inplace) { + // run a separate kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } + + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type); + + GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0); + + int64_t nk0 = ne10; + if (ggml_is_quantized(op->src[1]->type)) { + nk0 = ne10/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne10/ggml_blck_size(op->type); + } + + int nth = std::min<int>(nk0*ne11, 256); + + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + + // TODO: relax this constraint in the future + if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) { + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > 256) { + nrptg--; + } + } + } + + nth = std::min<int>(nth, nk0); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne10, + /*.ne01 =*/ ne11, + /*.ne02 =*/ ne12, + /*.ne03 =*/ ne13, + /*.nb00 =*/ nb10, + /*.nb01 =*/ nb11, + /*.nb02 =*/ nb12, + /*.nb03 =*/ nb13, + /*.ne0 =*/ ne10, + /*.ne1 =*/ ne11, + /*.ne2 =*/ ne12, + /*.ne3 =*/ ne13, + /*.nb0 =*/ ggml_element_size(op), + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + }; + + const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1; - int ida = 0; + bid_dst.offs += offs; ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); - if (op->op == GGML_OP_RWKV_WKV7) { - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[6]), ida++); - } - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); - ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++); - ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++); - ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++); - ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); - ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1); return 1; } @@ -1571,7 +1872,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nk0 = ne00/ggml_blck_size(op->type); } - int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min<int>(nk0*ne01, 256); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; @@ -1582,7 +1883,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { nrptg = (nth + nk0 - 1)/nk0; nth = nk0; - if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (nrptg*nth > 256) { nrptg--; } } @@ -1622,6 +1923,54 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_pool_1d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t * opts = op->op_params; + ggml_op_pool op_pool = (ggml_op_pool) opts[0]; + + const int32_t k0 = opts[1]; + const int32_t s0 = opts[2]; + const int32_t p0 = opts[3]; + + const int64_t IW = op->src[0]->ne[0]; + const int64_t OW = op->ne[0]; + + const int64_t np = ggml_nelements(op); + + ggml_metal_kargs_pool_1d args_pool_1d = { + /* .k0 = */ k0, + /* .s0 = */ s0, + /* .p0 = */ p0, + /* .IW = */ IW, + /* .OW = */ OW, + /* .np = */ np + }; + + auto pipeline = ggml_metal_library_get_pipeline_pool_1d(lib, op, op_pool); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np); + const int ntg = (np + nth - 1) / nth; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args_pool_1d, sizeof(args_pool_1d), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg, 1, 1, nth, 1, 1); + + return 1; +} + + int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -1717,6 +2066,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { ( op->src[0]->type == GGML_TYPE_F32 || // TODO: helper function op->src[0]->type == GGML_TYPE_F16 || + op->src[0]->type == GGML_TYPE_BF16 || + op->src[0]->type == GGML_TYPE_Q1_0 || op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || op->src[0]->type == GGML_TYPE_Q5_0 || @@ -1731,6 +2082,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q5_K || op->src[0]->type == GGML_TYPE_Q6_K || + op->src[0]->type == GGML_TYPE_Q2_K || + op->src[0]->type == GGML_TYPE_Q3_K || false) && (ne11 >= 4 && ne11 <= 8) ) ) @@ -1759,7 +2112,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup int16_t r1ptg = 4; // num src1 rows per threadgroup - // note: not sure how optimal are those across all different hardware. there might be someting cleverer + // note: not sure how optimal are those across all different hardware. there might be something cleverer switch (ne11) { case 2: r1ptg = 2; break; @@ -1776,7 +2129,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { GGML_ABORT("unsupported ne11"); }; - auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg); + auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op, nsg, nxpsg, r1ptg); ggml_metal_kargs_mul_mv_ext args = { /*.ne00 =*/ ne00, @@ -1851,7 +2204,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { const size_t smem = pipeline.smem; ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1); + + const int nr0 = pipeline.nr0; + const int nr1 = pipeline.nr1; + const int nsg = pipeline.nsg; + + ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + nr1 - 1) / nr1), ((ne01 + nr0 - 1) / nr0), ne12 * ne13, 32, nsg, 1); } else { auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); @@ -2239,7 +2597,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { // return res; //} - const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG; + const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPSG : OP_FLASH_ATTN_EXT_NQPSG; const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; @@ -2355,7 +2713,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel - const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup + const int nqptg = OP_FLASH_ATTN_EXT_NQPSG; // queries per threadgroup const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup GGML_ASSERT(nqptg <= 32); @@ -2464,7 +2822,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { // simdgroups per threadgroup (a.k.a. warps) //nsg = ne01 <= nqptg ? MAX(4, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))) : 4; - int32_t nsg = 4; + int32_t nsg = ne00 >= 512 ? 8 : 4; const size_t smem = FATTN_SMEM(nsg); @@ -2522,9 +2880,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { #undef FATTN_SMEM } else { // half4x4 kernel - const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup + const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! - const int nkpsg = 1*ncpsg; + const int nhptg = 1; // heads per threadgroup GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); @@ -2576,6 +2934,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_op_concurrency_reset(ctx); } + // note: for simplicity assume the K is larger or equal than V + GGML_ASSERT(ne10 >= ne20); + // ne00 + 2*ncpsg*(nsg) // for each query, we load it as f16 in shared memory (ne00) // and store the soft_max values and the mask @@ -2583,28 +2944,9 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { // ne20*(nsg) // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*GGML_PAD(ne20, 128)*(nsg))*(sizeof(float)/2), 16)) - - int64_t nsgmax = 2; - while (true) { - const size_t smem = FATTN_SMEM(nsgmax); - // avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes - if (smem > props_dev->max_theadgroup_memory_size/2) { - break; - } - nsgmax *= 2; - } - nsgmax /= 2; - - // simdgroups per threadgroup (a.k.a. warps) - //const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32))); - const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) 1024/32))); +#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16)) int64_t nsg = 1; - while (nsg <= nsgt) { - nsg *= 2; - } - nsg /= 2; // workgroups // each workgroup handles nsg*nkpsg cache values @@ -2617,7 +2959,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { } else { nwg = 32; nsg = 1; - while (2*nwg*nsg*nkpsg < ne11 && nsg < 4) { + while (2*nwg*nsg*ncpsg < ne11 && nsg < 4) { nsg *= 2; } } @@ -2683,7 +3025,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); } else { // sanity checks assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); @@ -2696,7 +3038,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, (ne02 + nhptg - 1)/nhptg, ne03*nwg, 32, nsg, 1); // sync the 2 kernels ggml_metal_op_concurrency_reset(ctx); @@ -2748,8 +3090,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); - bool bcast_row = false; - ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); @@ -2843,18 +3183,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { struct ggml_metal_pipeline_with_params pipeline; - if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true); - - bcast_row = true; - } else { - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false); - } + pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse); if (n_fuse > 1) { bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); @@ -2868,20 +3197,26 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } } + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne10 = ne10/4; + args.ne0 = ne0/4; + } + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, bid_src0, 1); ggml_metal_encoder_set_buffer (enc, bid_src1, 2); ggml_metal_encoder_set_buffer (enc, bid_dst, 3); - if (bcast_row) { - const int64_t n = ggml_nelements(op)/4; - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + if (pipeline.cnt) { + ggml_metal_encoder_dispatch_threadgroups(enc, args.ne0, ggml_nrows(op), 1, 1, 1, 1); } else { - int nth = 32; + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + int nth = 1; + + while (2*nth < args.ne0 && nth < nth_max) { nth *= 2; } @@ -2902,39 +3237,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + float eps; memcpy(&eps, op->op_params, sizeof(float)); - int nth = 32; // SIMD width - ggml_metal_kargs_l2_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, - /*.eps =*/ eps, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, }; auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); - while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00/4); const size_t smem = pipeline.smem; - const int64_t nrows = ggml_nrows(op->src[0]); - ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); return 1; } @@ -3280,16 +3635,26 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op); - GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + if (KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); - const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N); + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + } else { + const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N); + const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW); + ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1); + } return 1; } @@ -3372,6 +3737,77 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_conv_3d(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + // 1. Extract standard dimensions and byte strides + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + // 2. Extract hyperparams from op_params + const int32_t s0 = ((const int32_t *)(op->op_params))[0]; + const int32_t s1 = ((const int32_t *)(op->op_params))[1]; + const int32_t s2 = ((const int32_t *)(op->op_params))[2]; + const int32_t p0 = ((const int32_t *)(op->op_params))[3]; + const int32_t p1 = ((const int32_t *)(op->op_params))[4]; + const int32_t p2 = ((const int32_t *)(op->op_params))[5]; + const int32_t d0 = ((const int32_t *)(op->op_params))[6]; + const int32_t d1 = ((const int32_t *)(op->op_params))[7]; + const int32_t d2 = ((const int32_t *)(op->op_params))[8]; + const int32_t IC = ((const int32_t *)(op->op_params))[9]; + const int32_t N = ((const int32_t *)(op->op_params))[10]; + const int32_t OC = ((const int32_t *)(op->op_params))[11]; + + // 3. Build the parameter struct using the macro-generated variables + ggml_metal_kargs_conv_3d args = { + /*.IW =*/ (int32_t)op->src[1]->ne[0], + /*.IH =*/ (int32_t)op->src[1]->ne[1], + /*.ID =*/ (int32_t)op->src[1]->ne[2], + /*.OW =*/ (int32_t)op->ne[0], + /*.OH =*/ (int32_t)op->ne[1], + /*.OD =*/ (int32_t)op->ne[2], + /*.KW =*/ (int32_t)op->src[0]->ne[0], + /*.KH =*/ (int32_t)op->src[0]->ne[1], + /*.KD =*/ (int32_t)op->src[0]->ne[2], + s0, s1, s2, + p0, p1, p2, + d0, d1, d2, + IC, N, OC, + nb00, nb01, nb02, nb03, // Weight strides + nb10, nb11, nb12, nb13, // Input strides + nb0, nb1, nb2, nb3 // Output strides + }; + + // 4. Fetch the JIT pipeline + auto pipeline = ggml_metal_library_get_pipeline_conv_3d(lib, op); + + // 5. Grid mapping + int nth0 = 32; // Standard SIMD width for Apple Silicon + int nth1 = 1; + int nth2 = 1; + + int64_t spatial_volume = args.OW * args.OH * args.OD; + + int ntg0 = (spatial_volume + nth0 - 1) / nth0; + int ntg1 = args.OC; + int ntg2 = args.N; + + // 6. Bind and Dispatch via the ggml C wrapper + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, ntg0, ntg1, ntg2, nth0, nth1, nth2); + + return 1; +} + int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -3484,12 +3920,76 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - const float sf0 = (float)ne0/op->src[0]->ne[0]; - const float sf1 = (float)ne1/op->src[0]->ne[1]; - const float sf2 = (float)ne2/op->src[0]->ne[2]; - const float sf3 = (float)ne3/op->src[0]->ne[3]; + float sf0 = (float)ne0/op->src[0]->ne[0]; + float sf1 = (float)ne1/op->src[0]->ne[1]; + float sf2 = (float)ne2/op->src[0]->ne[2]; + float sf3 = (float)ne3/op->src[0]->ne[3]; + + const int32_t mode_flags = ggml_get_op_params_i32(op, 0); + + float poffs = 0.5f; + + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + poffs = 0.0f; + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; + } ggml_metal_kargs_upscale args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.sf0 =*/ sf0, + /*.sf1 =*/ sf1, + /*.sf2 =*/ sf2, + /*.sf3 =*/ sf3, + /*.poffs =*/ poffs, + }; + + auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_roll(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + const int32_t s0 = ggml_get_op_params_i32(op, 0); + const int32_t s1 = ggml_get_op_params_i32(op, 1); + const int32_t s2 = ggml_get_op_params_i32(op, 2); + const int32_t s3 = ggml_get_op_params_i32(op, 3); + + ggml_metal_kargs_roll args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, @@ -3498,23 +3998,23 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) { /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, - /*.nb0 =*/ nb0, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, - /*.nb3 =*/ nb3, - /*.sf0 =*/ sf0, - /*.sf1 =*/ sf1, - /*.sf2 =*/ sf2, - /*.sf3 =*/ sf3 + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.s0 =*/ s0, + /*.s1 =*/ s1, + /*.s2 =*/ s2, + /*.s3 =*/ s3 }; - auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op); + auto pipeline = ggml_metal_library_get_pipeline_roll(lib, op); - const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int nth = std::min(1024, ne0); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -3558,14 +4058,21 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op); - const int nth = std::min(1024, ne0); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + const int nth_max = MIN(64, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const int nth = MIN(args.ne0, nth_max); + const int nk0 = (args.ne0 + 1024 - 1)/1024; // note: 1024 is hardcoded in the kernel! ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne1, ne2, ne3, nth, 1, 1); return 1; } @@ -3942,42 +4449,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { return 1; } -int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float slope; - memcpy(&slope, op->op_params, sizeof(float)); - - ggml_metal_kargs_leaky_relu args = { - /*.slope =*/ slope - }; - - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; - } - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index c1025d35677..36c61071b4f 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); @@ -56,11 +53,16 @@ int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cumsum (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_diag (ggml_metal_op_t ctx, int idx); int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_gated_delta_net (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_set (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_mul_mat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_mul_mat_id (ggml_metal_op_t ctx, int idx); @@ -73,17 +75,18 @@ int ggml_metal_op_norm (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rope (ggml_metal_op_t ctx, int idx); int ggml_metal_op_im2col (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_2d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_conv_3d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_conv_transpose_2d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_upscale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_roll (ggml_metal_op_t ctx, int idx); int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 56b59f0afdf..a1003b3acff 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -7,11 +7,18 @@ #include "ggml-metal-context.h" #include "ggml-metal-ops.h" -// globals +#include <mutex> +#include <string> -// initialized in ggml_backend_metal_reg -static ggml_backend_reg g_ggml_metal_reg; -static ggml_backend_device g_ggml_metal_device; +#define GGML_METAL_NAME "MTL" +#define GGML_METAL_MAX_DEVICES 16 + +// number of Metal devices +// note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices +static int g_devices = 1; + +// forward declaration +static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer); //////////////////////////////////////////////////////////////////////////////// // backend interface @@ -64,11 +71,11 @@ static bool ggml_backend_metal_buffer_shared_cpy_tensor(ggml_backend_buffer_t bu GGML_ASSERT(ggml_metal_buffer_is_shared(ctx)); - GGML_UNUSED(buffer); - GGML_UNUSED(src); - GGML_UNUSED(dst); + if (!ggml_backend_buffer_is_metal(src->buffer)) { + return false; + } - return false; + return ggml_metal_buffer_cpy_tensor(ctx, src, dst); } static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, uint8_t value) { @@ -80,15 +87,17 @@ static void ggml_backend_metal_buffer_shared_clear(ggml_backend_buffer_t buffer, } static ggml_backend_buffer_i ggml_backend_metal_buffer_shared_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, - /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_shared_clear, - /* .reset = */ NULL, + /* .free_buffer = */ ggml_backend_metal_buffer_shared_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_shared_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_shared_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_shared_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_shared_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_metal_buffer_shared_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_shared_clear, + /* .reset = */ NULL, }; // private buffer @@ -138,11 +147,11 @@ static bool ggml_backend_metal_buffer_private_cpy_tensor(ggml_backend_buffer_t b GGML_ASSERT(!ggml_metal_buffer_is_shared(ctx)); - GGML_UNUSED(buffer); - GGML_UNUSED(src); - GGML_UNUSED(dst); + if (!ggml_backend_buffer_is_metal(src->buffer)) { + return false; + } - return false; + return ggml_metal_buffer_cpy_tensor(ctx, src, dst); } static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer, uint8_t value) { @@ -154,21 +163,41 @@ static void ggml_backend_metal_buffer_private_clear(ggml_backend_buffer_t buffer } static ggml_backend_buffer_i ggml_backend_metal_buffer_private_i = { - /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, - /* .get_base = */ ggml_backend_metal_buffer_private_get_base, - /* .init_tensor = */ NULL, - /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, - /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, - /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, - /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, - /* .clear = */ ggml_backend_metal_buffer_private_clear, - /* .reset = */ NULL, + /* .free_buffer = */ ggml_backend_metal_buffer_private_free_buffer, + /* .get_base = */ ggml_backend_metal_buffer_private_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_private_memset_tensor, + /* .set_tensor = */ ggml_backend_metal_buffer_private_set_tensor, + /* .get_tensor = */ ggml_backend_metal_buffer_private_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_metal_buffer_private_cpy_tensor, + /* .clear = */ ggml_backend_metal_buffer_private_clear, + /* .reset = */ NULL, }; +static bool ggml_backend_buffer_is_metal(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_metal_buffer_shared_free_buffer || + buffer->iface.free_buffer == ggml_backend_metal_buffer_private_free_buffer; +} + // // buffer types // +struct ggml_backend_metal_buffer_type { + int device; + std::string name; +}; + +struct ggml_backend_metal_buffer_type_deleter { + void operator()(ggml_backend_metal_buffer_type * ctx) const { + delete ctx; + } +}; + +typedef std::unique_ptr<ggml_backend_metal_buffer_type, ggml_backend_metal_buffer_type_deleter> ggml_backend_metal_buffer_type_ptr; + // common method for allocating shread or private Metal buffers static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size, bool shared) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)buft->device->context; @@ -218,9 +247,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ // default (shared) buffer type static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { - return "Metal"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_shared_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -249,29 +278,54 @@ static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_ty GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) { - static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, - }; +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(int device) { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + static std::vector<ggml_backend_buffer_type> bufts; + static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs; + + static bool initialized = false; + if (!initialized) { + bufts.reserve(g_devices); + ctxs.reserve(g_devices); + + for (int i = 0; i < g_devices; ++i) { + ggml_backend_metal_buffer_type * raw_ctx = + new ggml_backend_metal_buffer_type { + /* .device = */ i, + /* .name = */ GGML_METAL_NAME + std::to_string(i), + }; + ctxs.emplace_back(raw_ctx); + + ggml_backend_buffer_type buft = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_shared_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ raw_ctx, + }; + + bufts.emplace_back(buft); + } + + initialized = true; + } - return &ggml_backend_buffer_type_metal; + return &bufts[device]; } // default (private) buffer type static const char * ggml_backend_metal_buffer_type_private_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Private"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_private_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -300,29 +354,53 @@ static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_t GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) { - static ggml_backend_buffer_type ggml_backend_buffer_type_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, - }; +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(int device) { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + static std::vector<ggml_backend_buffer_type> bufts; + static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs; + + static bool initialized = false; + if (!initialized) { + bufts.reserve(g_devices); + ctxs.reserve(g_devices); + + for (int i = 0; i < g_devices; ++i) { + ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{ + /* .device = */ i, + /* .name = */ GGML_METAL_NAME + std::to_string(i) + "_Private" + }; + ctxs.emplace_back(raw_ctx); + + ggml_backend_buffer_type buft = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_private_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ raw_ctx, + }; + + bufts.emplace_back(buft); + } + + initialized = true; + } - return &ggml_backend_buffer_type_metal; + return &bufts[device]; } // mapped buffer type static const char * ggml_backend_metal_buffer_type_mapped_get_name(ggml_backend_buffer_type_t buft) { - return "Metal_Mapped"; + ggml_backend_metal_buffer_type * ctx = (ggml_backend_metal_buffer_type *)buft->context; - GGML_UNUSED(buft); + return ctx->name.c_str(); } static ggml_backend_buffer_t ggml_backend_metal_buffer_type_mapped_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -352,31 +430,55 @@ static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_ty GGML_UNUSED(buft); } -static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) { - // note: not obvious, but this buffer type still needs to implement .alloc_buffer: - // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 - static ggml_backend_buffer_type ggml_backend_buffer_type_mapped_metal = { - /* .iface = */ { - /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, - /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, - /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, - /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, - /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, - /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, - }, - /* .device = */ &g_ggml_metal_device, - /* .context = */ NULL, - }; +static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(int device) { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + static std::vector<ggml_backend_buffer_type> bufts; + static std::vector<ggml_backend_metal_buffer_type_ptr> ctxs; + + static bool initialized = false; + if (!initialized) { + bufts.reserve(g_devices); + ctxs.reserve(g_devices); + + for (int i = 0; i < g_devices; ++i) { + ggml_backend_metal_buffer_type * raw_ctx = new ggml_backend_metal_buffer_type{ + /* .device = */ i, + /* .name = */ GGML_METAL_NAME + std::to_string(i) + "_Mapped" + }; + ctxs.emplace_back(raw_ctx); + + // note: not obvious, but this buffer type still needs to implement .alloc_buffer: + // https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2333177099 + ggml_backend_buffer_type buft = { + /* .iface = */ { + /* .get_name = */ ggml_backend_metal_buffer_type_mapped_get_name, + /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, + /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, + /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, + /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size, + /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_metal_reg(), i), + /* .context = */ raw_ctx, + }; + + bufts.emplace_back(buft); + } + + initialized = true; + } - return &ggml_backend_buffer_type_mapped_metal; + return &bufts[device]; } // backend static const char * ggml_backend_metal_name(ggml_backend_t backend) { - return "Metal"; + ggml_metal_t ctx = (ggml_metal_t)backend->context; - GGML_UNUSED(backend); + return ggml_metal_get_name(ctx); } static void ggml_backend_metal_free(ggml_backend_t backend) { @@ -409,12 +511,24 @@ static void ggml_backend_metal_get_tensor_async(ggml_backend_t backend, const gg } static bool ggml_backend_metal_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { - return false; + if (!ggml_backend_is_metal(backend_src) || !ggml_backend_is_metal(backend_dst)) { + return false; + } + + if (!ggml_backend_buffer_is_metal(src->buffer) || !ggml_backend_buffer_is_metal(dst->buffer)) { + return false; + } + + ggml_metal_t ctx_src = (ggml_metal_t)backend_src->context; + ggml_metal_t ctx_dst = (ggml_metal_t)backend_dst->context; - GGML_UNUSED(backend_src); - GGML_UNUSED(backend_dst); - GGML_UNUSED(src); - GGML_UNUSED(dst); + //ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; + //ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; + + //ggml_metal_buffer_t buf_ctx_src = (ggml_metal_buffer_t)buf_src->context; + //ggml_metal_buffer_t buf_ctx_dst = (ggml_metal_buffer_t)buf_dst->context; + + return ggml_metal_cpy_tensor_async(ctx_src, ctx_dst, src, dst); } static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { @@ -423,6 +537,20 @@ static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, return ggml_metal_graph_compute(ctx, cgraph); } +static void ggml_backend_metal_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + ggml_metal_event_t ev = (ggml_metal_event_t)event->context; + + ggml_metal_event_record(ctx, ev); +} + +static void ggml_backend_metal_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_metal_t ctx = (ggml_metal_t)backend->context; + ggml_metal_event_t ev = (ggml_metal_event_t)event->context; + + ggml_metal_event_wait(ctx, ev); +} + static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_metal_t ctx = (ggml_metal_t)backend->context; @@ -435,7 +563,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { ggml_metal_t ctx = (ggml_metal_t)backend->context; ggml_metal_set_n_cb(ctx, n_cb); - } static ggml_backend_i ggml_backend_metal_i = { @@ -443,6 +570,8 @@ static ggml_backend_i ggml_backend_metal_i = { /* .free = */ ggml_backend_metal_free, /* .set_tensor_async = */ ggml_backend_metal_set_tensor_async, /* .get_tensor_async = */ ggml_backend_metal_get_tensor_async, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ ggml_backend_metal_cpy_tensor_async, // only needed for multi-GPU setups /* .synchronize = */ ggml_backend_metal_synchronize, /* .graph_plan_create = */ NULL, @@ -450,12 +579,8 @@ static ggml_backend_i ggml_backend_metal_i = { /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_metal_graph_compute, - - // the events API is needed only for multi-GPU setups, so likely no need to implement it for Metal - // in any case, these docs seem relevant if we ever decide to implement it: - // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events - /* .event_record = */ NULL, - /* .event_wait = */ NULL, + /* .event_record = */ ggml_backend_metal_event_record, + /* .event_wait = */ ggml_backend_metal_event_wait, /* .graph_optimize = */ ggml_backend_metal_graph_optimize, }; @@ -519,15 +644,17 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) { // backend device static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) { - return "Metal"; + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; - GGML_UNUSED(dev); + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); + + return props_dev->name; } static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; - return ggml_metal_device_get_props(ctx_dev)->name; + return ggml_metal_device_get_props(ctx_dev)->desc; } static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { @@ -550,14 +677,14 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); props->caps = { - /* .async = */ true, - /* .host_buffer = */ false, - /* .buffer_from_host_ptr = */ true, - /* .events = */ false, + /* .async = */ true, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ true, + /* .events = */ true, }; } -static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) { +static ggml_backend_t ggml_backend_metal_device_init_backend(ggml_backend_dev_t dev, const char * params) { ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; ggml_metal_t ctx = ggml_metal_init(ctx_dev); @@ -587,7 +714,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); - return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared() : ggml_backend_metal_buffer_type_private(); + return props_dev->use_shared_buffers ? ggml_backend_metal_buffer_type_shared(props_dev->device) : ggml_backend_metal_buffer_type_private(props_dev->device); } static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { @@ -595,7 +722,9 @@ static ggml_backend_buffer_t ggml_backend_metal_device_buffer_mapped(ggml_backen ggml_metal_buffer_t res = ggml_metal_buffer_map(ctx_dev, ptr, size, max_tensor_size); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(), ggml_backend_metal_buffer_shared_i, res, size); + const ggml_metal_device_props * props_dev = ggml_metal_device_get_props(ctx_dev); + + return ggml_backend_buffer_init(ggml_backend_metal_buffer_type_mapped(props_dev->device), ggml_backend_metal_buffer_shared_i, res, size); } static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { @@ -606,9 +735,10 @@ static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { return + buft->device == dev && ( buft->iface.get_name == ggml_backend_metal_buffer_type_shared_get_name || buft->iface.get_name == ggml_backend_metal_buffer_type_private_get_name || - buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name; + buft->iface.get_name == ggml_backend_metal_buffer_type_mapped_get_name); GGML_UNUSED(dev); } @@ -632,45 +762,97 @@ static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const g get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size; } +static ggml_backend_event_t ggml_backend_metal_device_event_new(ggml_backend_dev_t dev) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_event_t event = ggml_metal_device_event_init(ctx_dev); + GGML_ASSERT(event); + + ggml_backend_event_t ev = new ggml_backend_event { + /* .device = */ dev, + /* .context = */ event, + }; + + return ev; +} + +static void ggml_backend_metal_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_event_t ev = (ggml_metal_event_t)event->context; + + ggml_metal_device_event_free(ctx_dev, ev); + + delete event; +} + +static void ggml_backend_metal_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; + + ggml_metal_event_t evt = (ggml_metal_event_t)event->context; + + ggml_metal_device_event_synchronize(ctx_dev, evt); +} + static ggml_backend_device_i ggml_backend_metal_device_i = { /* .get_name = */ ggml_backend_metal_device_get_name, /* .get_description = */ ggml_backend_metal_device_get_description, /* .get_memory = */ ggml_backend_metal_device_get_memory, /* .get_type = */ ggml_backend_metal_device_get_type, /* .get_props = */ ggml_backend_metal_device_get_props, - /* .init_backend = */ ggml_backend_metal_device_init, + /* .init_backend = */ ggml_backend_metal_device_init_backend, /* .get_buffer_type = */ ggml_backend_metal_device_get_buffer_type, /* .get_host_buffer_type = */ NULL, /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_mapped, /* .supports_op = */ ggml_backend_metal_device_supports_op, /* .supports_buft = */ ggml_backend_metal_device_supports_buft, /* .offload_op = */ ggml_backend_metal_device_offload_op, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, + /* .event_new = */ ggml_backend_metal_device_event_new, + /* .event_free = */ ggml_backend_metal_device_event_free, + /* .event_synchronize = */ ggml_backend_metal_device_event_synchronize, }; // backend registry +struct ggml_backend_metal_reg { + std::vector<ggml_backend_dev_t> devices; +}; + +typedef struct ggml_backend_metal_reg * ggml_backend_metal_reg_t; + +static ggml_backend_metal_reg_t ggml_backend_metal_reg_init(void) { + ggml_backend_metal_reg_t ctx = new struct ggml_backend_metal_reg; + + return ctx; +} + +static void ggml_backend_metal_reg_free(ggml_backend_metal_reg_t ctx) { + delete ctx; +} + +struct ggml_backend_metal_reg_deleter { + void operator()(ggml_backend_metal_reg_t ctx) { + ggml_backend_metal_reg_free(ctx); + } +}; + +typedef std::unique_ptr<struct ggml_backend_metal_reg, ggml_backend_metal_reg_deleter> ggml_backend_metal_reg_ptr; + static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) { - return "Metal"; + return GGML_METAL_NAME; GGML_UNUSED(reg); } static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) { - return 1; - - GGML_UNUSED(reg); + ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context; + return ctx->devices.size(); } static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) { - GGML_ASSERT(index == 0); - - return &g_ggml_metal_device; - - GGML_UNUSED(reg); - GGML_UNUSED(index); + ggml_backend_metal_reg_t ctx = (ggml_backend_metal_reg_t)reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; } static ggml_backend_feature g_ggml_backend_metal_features[] = { @@ -698,27 +880,71 @@ static void * ggml_backend_metal_get_proc_address(ggml_backend_reg_t reg, const static ggml_backend_reg_i ggml_backend_metal_reg_i = { /* .get_name = */ ggml_backend_metal_reg_get_name, - /* .device_count = */ ggml_backend_metal_reg_device_count, - /* .device_get = */ ggml_backend_metal_reg_device_get, + /* .get_device_count = */ ggml_backend_metal_reg_device_count, + /* .get_device = */ ggml_backend_metal_reg_device_get, /* .get_proc_address = */ ggml_backend_metal_get_proc_address, }; +static ggml_backend_dev_t ggml_backend_metal_device_init(ggml_backend_reg_t reg, int device) { + return new ggml_backend_device { + /* .iface = */ ggml_backend_metal_device_i, + /* .reg = */ reg, + /* .context = */ ggml_metal_device_get(device), + }; +} + +static void ggml_backend_metal_device_free(ggml_backend_dev_t dev) { + delete dev; +} + +struct ggml_backend_device_deleter { + void operator()(ggml_backend_dev_t ctx) { + ggml_backend_metal_device_free(ctx); + } +}; + +typedef std::unique_ptr<ggml_backend_device, ggml_backend_device_deleter> ggml_backend_device_ptr; + ggml_backend_reg_t ggml_backend_metal_reg(void) { + static ggml_backend_reg reg; + static bool initialized = false; + { - g_ggml_metal_reg = { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_metal_reg_i, - /* .context = */ NULL, - }; - - g_ggml_metal_device = { - /* .iface = */ ggml_backend_metal_device_i, - /* .reg = */ &g_ggml_metal_reg, - /* .context = */ ggml_metal_device_get(), - }; + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + const char * env = getenv("GGML_METAL_DEVICES"); + if (env) { + g_devices = atoi(env); + } + + static std::vector<ggml_backend_device_ptr> devs; + + if (!initialized) { + // workaround macOS limitation (kIOGPUCommandBufferCallbackErrorImpactingInteractivity) until proper fix becomes possible + // ref: https://github.com/ggml-org/llama.cpp/issues/20141#issuecomment-4272947703 + setenv("AGX_RELAX_CDM_CTXSTORE_TIMEOUT", "1", true); + + static ggml_backend_metal_reg_ptr reg_ctx(ggml_backend_metal_reg_init()); + + for (int i = 0; i < g_devices; ++i) { + auto * dev = ggml_backend_metal_device_init(®, i); + devs.emplace_back(dev); + + reg_ctx->devices.push_back(dev); + } + + reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_metal_reg_i, + /* .context = */ reg_ctx.get(), + }; + } + + initialized = true; } - return &g_ggml_metal_reg; + return ® } GGML_BACKEND_DL_IMPL(ggml_backend_metal_reg) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 16d17d26af8..0aea68455fb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -77,6 +77,14 @@ static inline float dot(float x, float y) { return x*y; } +static inline float sum(float x) { + return x; +} + +static inline float sum(float4 x) { + return x[0] + x[1] + x[2] + x[3]; +} + // NOTE: this is not dequantizing - we are simply fitting the template template <typename type4x4> void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -110,6 +118,56 @@ void dequantize_bf16_t4(device const bfloat4 * src, short il, thread type4 & reg } #endif +template <typename type4x4> +void dequantize_q1_0(device const block_q1_0 * xb, short il, thread type4x4 & reg) { + device const uint8_t * qs = xb->qs; + const float d = xb->d; + const float neg_d = -d; + + const int byte_offset = il * 2; // il*16 bits = il*2 bytes + const uint8_t b0 = qs[byte_offset]; + const uint8_t b1 = qs[byte_offset + 1]; + + float4x4 reg_f; + + reg_f[0][0] = select(neg_d, d, bool(b0 & 0x01)); + reg_f[0][1] = select(neg_d, d, bool(b0 & 0x02)); + reg_f[0][2] = select(neg_d, d, bool(b0 & 0x04)); + reg_f[0][3] = select(neg_d, d, bool(b0 & 0x08)); + reg_f[1][0] = select(neg_d, d, bool(b0 & 0x10)); + reg_f[1][1] = select(neg_d, d, bool(b0 & 0x20)); + reg_f[1][2] = select(neg_d, d, bool(b0 & 0x40)); + reg_f[1][3] = select(neg_d, d, bool(b0 & 0x80)); + + reg_f[2][0] = select(neg_d, d, bool(b1 & 0x01)); + reg_f[2][1] = select(neg_d, d, bool(b1 & 0x02)); + reg_f[2][2] = select(neg_d, d, bool(b1 & 0x04)); + reg_f[2][3] = select(neg_d, d, bool(b1 & 0x08)); + reg_f[3][0] = select(neg_d, d, bool(b1 & 0x10)); + reg_f[3][1] = select(neg_d, d, bool(b1 & 0x20)); + reg_f[3][2] = select(neg_d, d, bool(b1 & 0x40)); + reg_f[3][3] = select(neg_d, d, bool(b1 & 0x80)); + + reg = (type4x4) reg_f; +} + +template <typename type4> +void dequantize_q1_0_t4(device const block_q1_0 * xb, short il, thread type4 & reg) { + const float d = xb->d; + const float neg_d = -d; + const int base = il * 4; + const uint8_t byte = xb->qs[base / 8]; + const int s = base % 8; + + float4 reg_f; + reg_f[0] = select(neg_d, d, bool((byte >> (s )) & 1)); + reg_f[1] = select(neg_d, d, bool((byte >> (s + 1)) & 1)); + reg_f[2] = select(neg_d, d, bool((byte >> (s + 2)) & 1)); + reg_f[3] = select(neg_d, d, bool((byte >> (s + 3)) & 1)); + + reg = (type4) reg_f; +} + template <typename type4x4> void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); @@ -144,6 +202,23 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r } } +void quantize_q1_0(device const float * src, device block_q1_0 & dst) { + float sum_abs = 0.0f; + for (int j = 0; j < QK1_0; j++) { + sum_abs += fabs(src[j]); + } + dst.d = sum_abs / QK1_0; + + for (int j = 0; j < QK1_0 / 8; j++) { + dst.qs[j] = 0; + } + for (int j = 0; j < QK1_0; j++) { + if (src[j] >= 0.0f) { + dst.qs[j / 8] |= (1 << (j % 8)); + } + } +} + void quantize_q4_0(device const float * src, device block_q4_0 & dst) { #pragma METAL fp math_mode(safe) float amax = 0.0f; // absolute max @@ -895,753 +970,459 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; -// general-purpose kernel for addition, subtraction, multiplication and division of two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -template <int F> -kernel void kernel_add_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +constant float SQRT_2_INV = 0.70710678118654752440084436210484f; - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +constant float p_erf = 0.3275911f; +constant float a1_erf = 0.254829592f; +constant float a2_erf = -0.284496736f; +constant float a3_erf = 1.421413741f; +constant float a4_erf = -1.453152027f; +constant float a5_erf = 1.061405429f; - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); - device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); +template<typename T> +inline T erf_approx(T x) { + T sign_x = sign(x); + x = fabs(x); + T t = 1.0f / (1.0f + p_erf * x); + T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + return sign_x * y; +} - device const float * src1_ptr[F]; - for (short j = 0; j < F; ++j) { - src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); - } +template<typename T> T elu_approx(T x); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; +template<> inline float elu_approx<float>(float x) { + return (x > 0.f) ? x : (exp(x) - 1); +} - float res = src0_ptr[i0]; +template<> inline float4 elu_approx<float4>(float4 x) { + float4 res; -#pragma unroll - for (short j = 0; j < F; ++j) { - res += src1_ptr[j][i10]; - } + res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); + res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); + res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); + res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); - dst_ptr[i0] = res; - } + return res; } -typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; - -template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; -template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; -template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; -template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; -template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; -template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; -template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; -template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; +constant short FC_unary_op [[function_constant(FC_UNARY + 0)]]; +constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]]; -kernel void kernel_sub_fuse_1( - constant ggml_metal_kargs_bin & args, +template <typename T0, typename T, typename TC> +kernel void kernel_unary_impl( + constant ggml_metal_kargs_unary & args, device const char * src0, - device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +#define FC_OP FC_unary_op +#define FC_CNT FC_unary_cnt - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + device const T0 * src0_ptr; + device T * dst_ptr; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + int i0; - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); - } -} + if (FC_CNT) { + i0 = tgpig.x; -kernel void kernel_mul_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; - - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; - - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - - if (args.ne10 == 1) { - const float x = *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; - } + src0_ptr = (device const T0 *) (src0); + dst_ptr = (device T *) (dst); } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); - } - } -} + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int k0 = tgpig.x/args.ne01; + const int i01 = tgpig.x - k0*args.ne01; -kernel void kernel_div_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + i0 = k0*ntg.x + tpitg.x; - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 ); + } - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + { + //threadgroup_barrier(mem_flags::mem_none); - if (args.ne10 == 1) { - const float x = 1.0f / *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; - } - } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); + if (!FC_CNT) { + if (i0 >= args.ne0) { + return; + } } - } -} -kernel void kernel_add_id( - constant ggml_metal_kargs_add_id & args, - device const char * src0, - device const char * src1, - device const char * src2, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i1 = tgpig.x; - const int i2 = tgpig.y; + const TC x = (TC) src0_ptr[i0]; - const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21)); - - const size_t nb1 = args.ne0 * sizeof(float); - const size_t nb2 = args.ne1 * nb1; + if (FC_OP == OP_UNARY_NUM_SCALE) { + dst_ptr[i0] = (T) (args.scale * x + args.bias); + } - device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); - device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02); - device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11); + if (FC_OP == OP_UNARY_NUM_FILL) { + dst_ptr[i0] = (T) args.val; + } - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - dst_row[i0] = src0_row[i0] + src1_row[i0]; - } -} + if (FC_OP == OP_UNARY_NUM_CLAMP) { + dst_ptr[i0] = (T) clamp(x, args.min, args.max); + } -template<typename T> -kernel void kernel_repeat( - constant ggml_metal_kargs_repeat & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i3 = tgpig.z; - const int i2 = tgpig.y; - const int i1 = tgpig.x; + if (FC_OP == OP_UNARY_NUM_SQR) { + dst_ptr[i0] = (T) (x * x); + } - const int i03 = i3%args.ne03; - const int i02 = i2%args.ne02; - const int i01 = i1%args.ne01; + if (FC_OP == OP_UNARY_NUM_SQRT) { + dst_ptr[i0] = (T) sqrt(x); + } - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; - device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1; + if (FC_OP == OP_UNARY_NUM_SIN) { + dst_ptr[i0] = (T) sin(x); + } - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i00 = i0%args.ne00; - *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); - } -} + if (FC_OP == OP_UNARY_NUM_COS) { + dst_ptr[i0] = (T) cos(x); + } -typedef decltype(kernel_repeat<float>) kernel_repeat_t; + if (FC_OP == OP_UNARY_NUM_LOG) { + dst_ptr[i0] = (T) log(x); + } -template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>; -template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>; -template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>; -template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>; + if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) { + dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope)); + } -// assumption: src1 is a row -// broadcast src1 into src0 -template <short F> -kernel void kernel_add_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - const uint nb = args.ne00/4; - const uint i = tpig % nb; + if (FC_OP == OP_UNARY_NUM_TANH) { + dst_ptr[i0] = (T) precise::tanh(x); + } - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); + if (FC_OP == OP_UNARY_NUM_RELU) { + dst_ptr[i0] = (T) fmax(0, x); + } - float4 res = src0_row[tpig]; + if (FC_OP == OP_UNARY_NUM_SIGMOID) { + dst_ptr[i0] = (T) (1 / (1 + exp(-x))); + } -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res += ((device const float4 *) (src1 + args.o1[j]))[i]; - } + if (FC_OP == OP_UNARY_NUM_GELU) { + dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x)))); + } - dst_row[tpig] = res; -} + if (FC_OP == OP_UNARY_NUM_GELU_ERF) { + dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x))); + } -typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; + if (FC_OP == OP_UNARY_NUM_GELU_QUICK) { + dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x)))); + } -template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; -template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; -template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; -template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; -template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>; -template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>; -template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>; -template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>; + if (FC_OP == OP_UNARY_NUM_SILU) { + dst_ptr[i0] = (T) (x / (1 + exp(-x))); + } -template <short F> -kernel void kernel_sub_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { + if (FC_OP == OP_UNARY_NUM_ELU) { + dst_ptr[i0] = (T) elu_approx(x); + } - const uint nb = args.ne00/4; - const uint i = tpig % nb; + if (FC_OP == OP_UNARY_NUM_NEG) { + dst_ptr[i0] = (T) -x; + } - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); + if (FC_OP == OP_UNARY_NUM_ABS) { + dst_ptr[i0] = (T) fabs(x); + } - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } + if (FC_OP == OP_UNARY_NUM_SGN) { + dst_ptr[i0] = T(x > 0) - T(x < 0); + } - float4 res = src0_row[tpig]; + if (FC_OP == OP_UNARY_NUM_STEP) { + dst_ptr[i0] = T(x > 0); + } -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res -= src1_row[j][i]; - } + if (FC_OP == OP_UNARY_NUM_HARDSWISH) { + dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5))); + } - dst_row[tpig] = res; -} + if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) { + dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5)); + } -typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; + if (FC_OP == OP_UNARY_NUM_EXP) { + dst_ptr[i0] = (T) exp(x); + } -template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; + if (FC_OP == OP_UNARY_NUM_SOFTPLUS) { + dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20); + } -template <short F> -kernel void kernel_mul_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { + if (FC_OP == OP_UNARY_NUM_EXPM1) { + // TODO: precise implementation + dst_ptr[i0] = (T) (exp(x) - 1); + } - const uint nb = args.ne00/4; - const uint i = tpig % nb; + if (FC_OP == OP_UNARY_NUM_FLOOR) { + dst_ptr[i0] = (T) floor(x); + } - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); + if (FC_OP == OP_UNARY_NUM_CEIL) { + dst_ptr[i0] = (T) ceil(x); + } - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } + if (FC_OP == OP_UNARY_NUM_ROUND) { + dst_ptr[i0] = (T) round(x); + } - float4 res = src0_row[tpig]; + if (FC_OP == OP_UNARY_NUM_TRUNC) { + dst_ptr[i0] = (T) trunc(x); + } -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res *= src1_row[j][i]; + if (FC_OP == OP_UNARY_NUM_XIELU) { + const TC xi = x; + const TC gate = TC(xi > TC(0.0f)); + const TC clamped = fmin(xi, TC(args.val)); + const TC y_pos = TC(args.scale) * xi * xi + TC(args.bias) * xi; + const TC y_neg = (exp(clamped) - TC(1.0f) - xi) * TC(args.slope) + TC(args.bias) * xi; + dst_ptr[i0] = (T) (gate * y_pos + (TC(1.0f) - gate) * y_neg); + } } - dst_row[tpig] = res; +#undef FC_OP +#undef FC_CNT } -typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; +typedef decltype(kernel_unary_impl<float, float, float>) kernel_unary_t; -template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; +template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl<float, float, float>; +template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl<float4, float4, float4>; +template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl<half, half, float>; +template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl<half4, half4, float4>; -template <short F> -kernel void kernel_div_row_c4_fuse_impl( +// OP: 0 - add, 1 - sub, 2 - mul, 3 - div +constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; +constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; +constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]]; +constant bool FC_bin_cb [[function_constant(FC_BIN + 3)]]; + +template <typename T0, typename T1, typename T> +kernel void kernel_bin_fuse_impl( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res /= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; - -template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; - -kernel void kernel_scale_f32( - constant ggml_metal_kargs_scale & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * args.scale + args.bias; -} - -kernel void kernel_scale_f32_4( - constant ggml_metal_kargs_scale & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * args.scale + args.bias; -} - -kernel void kernel_fill_f32( - constant ggml_metal_kargs_fill & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = args.val; -} - -kernel void kernel_fill_f32_4( - constant ggml_metal_kargs_fill & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = args.val; -} - -kernel void kernel_clamp_f32( - constant ggml_metal_kargs_clamp & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = clamp(src0[tpig], args.min, args.max); -} - -kernel void kernel_clamp_f32_4( - constant ggml_metal_kargs_clamp & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = clamp(src0[tpig], args.min, args.max); -} - -kernel void kernel_relu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_relu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_sigmoid_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_sigmoid_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_tanh_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = precise::tanh(src0[tpig]); -} - -kernel void kernel_tanh_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = precise::tanh(src0[tpig]); -} - -constant float GELU_COEF_A = 0.044715f; -constant float GELU_QUICK_COEF = -1.702f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; -constant float SQRT_2_INV = 0.70710678118654752440084436210484f; - -kernel void kernel_gelu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { +#define FC_OP FC_bin_op +#define FC_F FC_bin_f +#define FC_RB FC_bin_rb +#define FC_CB FC_bin_cb -kernel void kernel_gelu_quick_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + if (FC_RB) { + // row broadcast + const uint i0 = tgpig.y*args.ne00 + tgpig.x; + const uint i1 = FC_CB ? tgpig.x%args.ne10 : tgpig.x; - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} + device const T0 * src0_row = (device const T0 *) (src0); + device T * dst_row = (device T *) (dst); -kernel void kernel_gelu_quick_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; + if (FC_F == 1) { + device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]); - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} + if (FC_OP == 0) { + dst_row[i0] = src0_row[i0] + src1_row[i1]; + } -// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation -// ref: https://www.johndcook.com/blog/python_erf/ -constant float p_erf = 0.3275911f; -constant float a1_erf = 0.254829592f; -constant float a2_erf = -0.284496736f; -constant float a3_erf = 1.421413741f; -constant float a4_erf = -1.453152027f; -constant float a5_erf = 1.061405429f; + if (FC_OP == 1) { + dst_row[i0] = src0_row[i0] - src1_row[i1]; + } -template<typename T> -T erf_approx(T x) { - T sign_x = sign(x); - x = fabs(x); - T t = 1.0f / (1.0f + p_erf * x); - T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); - return sign_x * y; -} + if (FC_OP == 2) { + dst_row[i0] = src0_row[i0] * src1_row[i1]; + } -kernel void kernel_gelu_erf_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; + if (FC_OP == 3) { + dst_row[i0] = src0_row[i0] / src1_row[i1]; + } + } else { + T0 res = src0_row[i0]; - dst[tpig] = 0.5f*x*(1.0f+erf_approx<float>(x*SQRT_2_INV)); -} + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } -kernel void kernel_gelu_erf_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } - dst[tpig] = 0.5f*x*(1.0f+erf_approx<float4>(x*SQRT_2_INV)); -} + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } -kernel void kernel_silu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } -kernel void kernel_silu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} + dst_row[i0] = res; + } + } else { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; -kernel void kernel_elu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); -} + if (i01 >= args.ne01) { + return; + } -kernel void kernel_elu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); - dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); - dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); - dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); -} + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; -kernel void kernel_sqr_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} + device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); + device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); -kernel void kernel_sqr_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} + if (FC_F == 1) { + device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); -kernel void kernel_sqrt_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sqrt(src0[tpig]); -} + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = FC_CB ? i0%args.ne10 : i0; -kernel void kernel_sqrt_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sqrt(src0[tpig]); -} + if (FC_OP == 0) { + dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; + } -kernel void kernel_sin_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sin(src0[tpig]); -} + if (FC_OP == 1) { + dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10]; + } -kernel void kernel_sin_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sin(src0[tpig]); -} + if (FC_OP == 2) { + dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; + } -kernel void kernel_cos_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = cos(src0[tpig]); -} + if (FC_OP == 3) { + dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10]; + } + } + } else { + device const T1 * src1_ptr[8]; + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + } -kernel void kernel_cos_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = cos(src0[tpig]); -} + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = FC_CB ? i0%args.ne10 : i0; -kernel void kernel_log_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = log(src0[tpig]); -} + T res = src0_ptr[i0]; -kernel void kernel_log_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = log(src0[tpig]); -} + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += src1_ptr[j][i10]; + } + } -kernel void kernel_neg_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = -src0[tpig]; -} + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= src1_ptr[j][i10]; + } + } -kernel void kernel_neg_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = -src0[tpig]; -} + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= src1_ptr[j][i10]; + } + } -kernel void kernel_abs_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = fabs(src0[tpig]); -} + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= src1_ptr[j][i10]; + } + } -kernel void kernel_abs_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = fabs(src0[tpig]); -} + dst_ptr[i0] = res; + } + } + } -kernel void kernel_sgn_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sign(src0[tpig]); +#undef FC_OP +#undef FC_F +#undef FC_RB +#undef FC_CB } -kernel void kernel_sgn_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sign(src0[tpig]); -} +typedef decltype(kernel_bin_fuse_impl<float, float, float>) kernel_bin_fuse_t; -kernel void kernel_step_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = step(0.0f, src0[tpig]); -} +template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float, float, float>; +template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl<float4, float4, float4>; -kernel void kernel_step_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = step(0.0f, src0[tpig]); -} +kernel void kernel_add_id( + constant ggml_metal_kargs_add_id & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i1 = tgpig.x; + const int i2 = tgpig.y; -kernel void kernel_hardswish_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} + const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21)); -kernel void kernel_hardswish_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} + const size_t nb1 = args.ne0 * sizeof(float); + const size_t nb2 = args.ne1 * nb1; -kernel void kernel_hardsigmoid_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} + device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); + device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02); + device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11); -kernel void kernel_hardsigmoid_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + dst_row[i0] = src0_row[i0] + src1_row[i0]; + } } -kernel void kernel_exp_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]); -} +template<typename T> +kernel void kernel_repeat( + constant ggml_metal_kargs_repeat & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; -kernel void kernel_exp_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]); -} + const int i03 = i3%args.ne03; + const int i02 = i2%args.ne02; + const int i01 = i1%args.ne01; -kernel void kernel_softplus_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); -} + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1; -kernel void kernel_softplus_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i00 = i0%args.ne00; + *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); + } } -kernel void kernel_expm1_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]) - 1.0f; -} +typedef decltype(kernel_repeat<float>) kernel_repeat_t; -kernel void kernel_expm1_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]) - 1.0f; -} +template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>; +template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>; +template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>; +template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>; -kernel void kernel_reglu_f32( +template<typename T> +kernel void kernel_reglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1649,19 +1430,25 @@ kernel void kernel_reglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; const float x1 = src1_row[i0]; - dst_row[i0] = x0*x1*(x0 > 0.0f); + dst_row[i0] = (T)(x0*x1*(x0 > 0.0f)); } } -kernel void kernel_geglu_f32( +typedef decltype(kernel_reglu<float>) kernel_reglu_t; + +template [[host_name("kernel_reglu_f32")]] kernel kernel_reglu_t kernel_reglu<float>; +template [[host_name("kernel_reglu_f16")]] kernel kernel_reglu_t kernel_reglu<half>; + +template<typename T> +kernel void kernel_geglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1669,9 +1456,9 @@ kernel void kernel_geglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1679,11 +1466,17 @@ kernel void kernel_geglu_f32( const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0))); - dst_row[i0] = gelu*x1; + dst_row[i0] = (T)(gelu*x1); } } -kernel void kernel_swiglu_f32( +typedef decltype(kernel_geglu<float>) kernel_geglu_t; + +template [[host_name("kernel_geglu_f32")]] kernel kernel_geglu_t kernel_geglu<float>; +template [[host_name("kernel_geglu_f16")]] kernel kernel_geglu_t kernel_geglu<half>; + +template<typename T> +kernel void kernel_swiglu( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1691,9 +1484,9 @@ kernel void kernel_swiglu_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1701,11 +1494,17 @@ kernel void kernel_swiglu_f32( const float silu = x0 / (1.0f + exp(-x0)); - dst_row[i0] = silu*x1; + dst_row[i0] = (T)(silu*x1); } } -kernel void kernel_swiglu_oai_f32( +typedef decltype(kernel_swiglu<float>) kernel_swiglu_t; + +template [[host_name("kernel_swiglu_f32")]] kernel kernel_swiglu_t kernel_swiglu<float>; +template [[host_name("kernel_swiglu_f16")]] kernel kernel_swiglu_t kernel_swiglu<half>; + +template<typename T> +kernel void kernel_swiglu_oai( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1713,9 +1512,9 @@ kernel void kernel_swiglu_oai_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { float x0 = src0_row[i0]; @@ -1727,11 +1526,17 @@ kernel void kernel_swiglu_oai_f32( float out_glu = x0 / (1.0f + exp(-x0 * args.alpha)); out_glu = out_glu * (1.0f + x1); - dst_row[i0] = out_glu; + dst_row[i0] = (T)out_glu; } } -kernel void kernel_geglu_erf_f32( +typedef decltype(kernel_swiglu_oai<float>) kernel_swiglu_oai_t; + +template [[host_name("kernel_swiglu_oai_f32")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<float>; +template [[host_name("kernel_swiglu_oai_f16")]] kernel kernel_swiglu_oai_t kernel_swiglu_oai<half>; + +template<typename T> +kernel void kernel_geglu_erf( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1739,9 +1544,9 @@ kernel void kernel_geglu_erf_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1749,11 +1554,17 @@ kernel void kernel_geglu_erf_f32( const float gelu_erf = 0.5f*x0*(1.0f+erf_approx<float>(x0*SQRT_2_INV)); - dst_row[i0] = gelu_erf*x1; + dst_row[i0] = (T)(gelu_erf*x1); } } -kernel void kernel_geglu_quick_f32( +typedef decltype(kernel_geglu_erf<float>) kernel_geglu_erf_t; + +template [[host_name("kernel_geglu_erf_f32")]] kernel kernel_geglu_erf_t kernel_geglu_erf<float>; +template [[host_name("kernel_geglu_erf_f16")]] kernel kernel_geglu_erf_t kernel_geglu_erf<half>; + +template<typename T> +kernel void kernel_geglu_quick( constant ggml_metal_kargs_glu & args, device const char * src0, device const char * src1, @@ -1761,9 +1572,9 @@ kernel void kernel_geglu_quick_f32( uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; - device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; - device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1); + device const T * src0_row = (device const T *) ((device const char *) src0 + tgpig*args.nb01) + args.i00; + device const T * src1_row = (device const T *) ((device const char *) src1 + tgpig*args.nb11) + args.i10; + device T * dst_row = (device T *) ((device char *) dst + tgpig*args.nb1); for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) { const float x0 = src0_row[i0]; @@ -1771,10 +1582,15 @@ kernel void kernel_geglu_quick_f32( const float gelu_quick = x0*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x0))); - dst_row[i0] = gelu_quick*x1; + dst_row[i0] = (T)(gelu_quick*x1); } } +typedef decltype(kernel_geglu_quick<float>) kernel_geglu_quick_t; + +template [[host_name("kernel_geglu_quick_f32")]] kernel kernel_geglu_quick_t kernel_geglu_quick<float>; +template [[host_name("kernel_geglu_quick_f16")]] kernel kernel_geglu_quick_t kernel_geglu_quick<half>; + kernel void kernel_op_sum_f32( constant ggml_metal_kargs_sum & args, device const float * src0, @@ -1824,33 +1640,35 @@ kernel void kernel_op_sum_f32( } } -template <bool norm> -kernel void kernel_sum_rows( +constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]]; + +template <typename T0, typename T> +kernel void kernel_sum_rows_impl( constant ggml_metal_kargs_sum_rows & args, - device const float * src0, - device float * dst, - threadgroup float * shmem_f32 [[threadgroup(0)]], + device const char * src0, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - int64_t i3 = tgpig.z; - int64_t i2 = tgpig.y; - int64_t i1 = tgpig.x; +#define FC_OP FC_sum_rows_op - if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { - return; - } + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + threadgroup T0 * shmem_t = (threadgroup T0 *) shmem; if (sgitg == 0) { - shmem_f32[tiisg] = 0.0f; + shmem_t[tiisg] = 0.0f; } - device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); - float sumf = 0; + T0 sumf = T0(0.0f); for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { sumf += src_row[i0]; @@ -1861,23 +1679,33 @@ kernel void kernel_sum_rows( threadgroup_barrier(mem_flags::mem_threadgroup); if (tiisg == 0) { - shmem_f32[sgitg] = sumf; + shmem_t[sgitg] = sumf; } threadgroup_barrier(mem_flags::mem_threadgroup); - sumf = shmem_f32[tiisg]; + sumf = shmem_t[tiisg]; sumf = simd_sum(sumf); if (tpitg.x == 0) { - dst_row[0] = norm ? sumf / args.ne00 : sumf; + if (FC_OP == OP_SUM_ROWS_NUM_MEAN) { + if (is_same<float4, T0>::value) { + dst_row[0] = sum(sumf) / (4*args.ne00); + } else { + dst_row[0] = sum(sumf) / args.ne00; + } + } else { + dst_row[0] = sum(sumf); + } } + +#undef FC_OP } -typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t; +typedef decltype(kernel_sum_rows_impl<float, float>) kernel_sum_rows_t; -template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>; -template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<true>; +template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float, float>; +template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl<float4, float>; template<typename T> kernel void kernel_cumsum_blk( @@ -2737,6 +2565,329 @@ kernel void kernel_rwkv_wkv7_f32( } } +constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; +constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; +constant short FC_gated_delta_net_K [[function_constant(FC_GATED_DELTA_NET + 2)]]; + +#if 1 +template<short NSG> +kernel void kernel_gated_delta_net_impl( + constant ggml_metal_kargs_gated_delta_net & args, + device const char * q, + device const char * k, + device const char * v, + device const char * g, + device const char * b, + device const char * s, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +#define S_v FC_gated_delta_net_ne20 +#define G FC_gated_delta_net_ne30 +#define K FC_gated_delta_net_K + + const uint tx = tpitg.x; + const uint ty = tpitg.y; + + const uint i23 = tgpig.z; // B (n_seqs) + const uint i21 = tgpig.y; // H (head) + const uint i20 = tgpig.x*NSG + ty; // row within S_v + + const uint i01 = i21 % args.ne01; + const uint i11 = i21 % args.ne11; + + const float scale = 1.0f / sqrt((float)S_v); + + // input state layout [S_v, S_v, H, n_seqs] (s0 only): per-seq stride is H*D. + // state is stored transposed: M[i20][is] = S[is][i20], so row i20 is contiguous + const uint state_in_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + device const float * s_ptr = (device const float *) (s) + state_in_base; + + float ls[NSG]; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] = s_ptr[is]; + } + + device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; + + device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); + device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); + device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21); + + device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); + device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned. + + // output state base offset: after attention scores + const uint attn_size = args.ne22 * args.ne21 * S_v * args.ne23; + // output state per-slot size: S_v * S_v * H * n_seqs + const uint state_size_per_snap = S_v * S_v * args.ne21 * args.ne23; + // per-(seq,head) offset within a slot + const uint state_out_base = (i23*args.ne21 + i21)*S_v*S_v + i20*S_v; + + for (short t = 0; t < args.ne22; t++) { + float s_k = 0.0f; + + if (G == 1) { + const float g_exp = exp(g_ptr[0]); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] *= g_exp; + + s_k += ls[j]*k_ptr[is]; + } + } else { + // KDA + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] *= exp(g_ptr[is]); + + s_k += ls[j]*k_ptr[is]; + } + } + + s_k = simd_sum(s_k); + + const float d = (v_ptr[i20] - s_k)*b_ptr[0]; + + float y = 0.0f; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] += k_ptr[is]*d; + + y += ls[j]*q_ptr[is]; + } + + y = simd_sum(y); + + if (tx == 0) { + dst_attn[t*args.ne21*S_v] = y*scale; + } + + q_ptr += args.ns02; + k_ptr += args.ns12; + v_ptr += args.ns22; + + b_ptr += args.ne21; + g_ptr += args.ne21*G; + + if (K > 1) { + const int target_slot = (int)args.ne22 - 1 - (int)t; + if (target_slot >= 0 && target_slot < (int)K) { + device float * dst_state = (device float *) (dst) + attn_size + (uint)target_slot * state_size_per_snap + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } + } + } + } + + if (K == 1) { + device float * dst_state = (device float *) (dst) + attn_size + state_out_base; + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is] = ls[j]; + } + } + +#undef S_v +#undef G +#undef K +} + +typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; + +template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>; +template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>; +template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>; + +#else +// a simplified version of the above +// no performance improvement, so keep the above version for now + +template<typename T, short NSG> +kernel void kernel_gated_delta_net_impl( + constant ggml_metal_kargs_gated_delta_net & args, + device const char * q, + device const char * k, + device const char * v, + device const char * g, + device const char * b, + device const char * s, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +#define S_v FC_gated_delta_net_ne20 +#define G FC_gated_delta_net_ne30 + + const uint tx = tpitg.x; + const uint ty = tpitg.y; + + const uint i23 = tgpig.z; // B + const uint i21 = tgpig.y; // H + const uint i20 = tgpig.x*NSG + ty; + + const uint i01 = i21 % args.ne01; + const uint i11 = i21 % args.ne11; + + const float scale = 1.0f / sqrt((float)S_v); + + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + + float lsf[NSG]; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + lsf[j] = s_ptr[is*S_v]; + } + + thread T * ls = (thread T *) (lsf); + + device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; + + device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); + device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); + device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21); + + device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); + device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + + for (short t = 0; t < args.ne22; t++) { + device const T * qt_ptr = (device const T *) (q_ptr); + device const T * kt_ptr = (device const T *) (k_ptr); + device const T * gt_ptr = (device const T *) (g_ptr); + + if (G == 1) { + *ls *= exp(g_ptr[0]); + } else { + // KDA + *ls *= exp(gt_ptr[tx]); + } + + const float s_k = simd_sum(dot(*ls, kt_ptr[tx])); + + const float d = (v_ptr[i20] - s_k)*b_ptr[0]; + + *ls += kt_ptr[tx]*d; + + const float y = simd_sum(dot(*ls, qt_ptr[tx])); + + if (tx == 0) { + *dst_attn = y*scale; + } + + q_ptr += args.ns02; + k_ptr += args.ns12; + v_ptr += args.ns22; + + b_ptr += args.ne21; + g_ptr += args.ne21*G; + + dst_attn += args.ne21*S_v; + } + + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + device T * dstt_state = (device T *) (dst_state); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is*S_v] = lsf[j]; + } + +#undef S_v +#undef G +} + +typedef decltype(kernel_gated_delta_net_impl<float4, 4>) kernel_gated_delta_net_t; + +template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float, 1>; +template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float2, 2>; +template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<float4, 4>; +#endif + +constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; +constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; +constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]]; + +kernel void kernel_solve_tri_f32( + constant ggml_metal_kargs_solve_tri & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + ushort3 tgpig[[threadgroup_position_in_grid]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + constexpr short NW = N_SIMDWIDTH; + + const short NSG = FC_solve_tri_nsg; + const short N = FC_solve_tri_n; + const short K = FC_solve_tri_k; + const short NP = PAD2(N, NW); + + const int32_t i03 = tgpig.z; + const int32_t i02 = tgpig.y; + const int32_t i01 = tgpig.x*NSG + sgitg; + + threadgroup float * sh0 = (threadgroup float *) shmem; + + device const float * src0_ptr = (device const float *)(src0 + i02 * args.nb02 + i03 * args.nb03) + sgitg*N; + device const float * src1_ptr = (device const float *)(src1 + i02 * args.nb12 + i03 * args.nb13) + i01; + device float * dst_ptr = (device float *)(dst + i02 * args.nb2 + i03 * args.nb3) + i01; + + for (short rr = 0; rr < N; rr += NSG) { + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + threadgroup float * sh0_cur = sh0 + sgitg*NP; + + for (short t = 0; t*NW < N; ++t) { + const short idx = t*NW + tiisg; + sh0_cur[idx] = src0_ptr[idx]; + } + + src0_ptr += NSG*N; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (i01 >= args.ne10) { + continue; + } + + for (short ir = 0; ir < NSG && rr + ir < N; ++ir) { + const short r = rr + ir; + + threadgroup float * sh0_cur = sh0 + ir*NP; + + float sum = 0.0f; + + for (short t = 0; t*NW < r; ++t) { + const short idx = t*NW + tiisg; + sum += sh0_cur[idx] * dst_ptr[idx*K] * (idx < r); + } + + sum = simd_sum(sum); + + if (tiisg == 0) { + const float diag = sh0_cur[r]; + + dst_ptr[r*K] = (src1_ptr[r*K] - sum) / diag; + } + } + } +} + kernel void kernel_argmax_f32( constant ggml_metal_kargs_argmax & args, device const char * src0, @@ -2970,26 +3121,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 2>; template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl<float4, 3>; -kernel void kernel_l2_norm_f32( +template <typename T0, typename T> +kernel void kernel_l2_norm_impl( constant ggml_metal_kargs_l2_norm & args, device const char * src0, device char * dst, threadgroup float * shmem_f32 [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - ushort tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + if (sgitg == 0) { shmem_f32[tiisg] = 0.0f; } - device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); float sumf = 0.0f; // parallel sum - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { sumf += dot(x[i00], x[i00]); } sumf = simd_sum(sumf); @@ -3005,14 +3162,18 @@ kernel void kernel_l2_norm_f32( sumf = shmem_f32[tiisg]; sumf = simd_sum(sumf); - const float scale = 1.0f/sqrt(max(sumf, args.eps)); + const float scale = 1.0f/max(sqrt(sumf), args.eps); - device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { y[i00] = x[i00] * scale; } } +typedef decltype(kernel_l2_norm_impl<float, float>) kernel_l2_norm_t; + +template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float, float>; +template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl<float4, float4>; + kernel void kernel_group_norm_f32( constant ggml_metal_kargs_group_norm & args, device const float * src0, @@ -3094,6 +3255,35 @@ kernel void kernel_group_norm_f32( } } +// Q1_0 dot product: dot = d * (2 * Σ(yl[i] where bit=1) - sumy) +inline float block_q_n_dot_y(device const block_q1_0 * qb_curr, float sumy, thread float * yl, int il) { + device const uint8_t * qs = qb_curr->qs + il / 8; + const uint8_t b0 = qs[0]; + const uint8_t b1 = qs[1]; + + float acc = 0.0f; + + acc += select(0.0f, yl[ 0], bool(b0 & 0x01)); + acc += select(0.0f, yl[ 1], bool(b0 & 0x02)); + acc += select(0.0f, yl[ 2], bool(b0 & 0x04)); + acc += select(0.0f, yl[ 3], bool(b0 & 0x08)); + acc += select(0.0f, yl[ 4], bool(b0 & 0x10)); + acc += select(0.0f, yl[ 5], bool(b0 & 0x20)); + acc += select(0.0f, yl[ 6], bool(b0 & 0x40)); + acc += select(0.0f, yl[ 7], bool(b0 & 0x80)); + + acc += select(0.0f, yl[ 8], bool(b1 & 0x01)); + acc += select(0.0f, yl[ 9], bool(b1 & 0x02)); + acc += select(0.0f, yl[10], bool(b1 & 0x04)); + acc += select(0.0f, yl[11], bool(b1 & 0x08)); + acc += select(0.0f, yl[12], bool(b1 & 0x10)); + acc += select(0.0f, yl[13], bool(b1 & 0x20)); + acc += select(0.0f, yl[14], bool(b1 & 0x40)); + acc += select(0.0f, yl[15], bool(b1 & 0x80)); + + return qb_curr->d * (2.0f * acc - sumy); +} + // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i]) // il indicates where the q4 quants begin (0 or QK4_0/4) // we assume that the yl's have been multiplied with the appropriate scale factor @@ -3226,6 +3416,9 @@ static inline void helper_mv_reduce_and_write( constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]]; constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]]; +constant short FC_mul_mv_ne12 [[function_constant(FC_MUL_MV + 2)]]; +constant short FC_mul_mv_r2 [[function_constant(FC_MUL_MV + 3)]]; +constant short FC_mul_mv_r3 [[function_constant(FC_MUL_MV + 4)]]; template<typename block_q_type, short NR0, typename args_t> void mul_vec_q_n_f32_impl( @@ -3249,72 +3442,151 @@ void mul_vec_q_n_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); - // pointers to src0 rows - device const block_q_type * ax[NR0]; - FOR_UNROLL (int row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + // pointers to src0 rows + device const block_q_type * ax[NR0]; + FOR_UNROLL (int row = 0; row < NR0; ++row) { + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; + + ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); + } + + float sumf[NR0] = {0.f}; + + const short ix = (tiisg/(NW/NQ)); + const short il = (tiisg%(NW/NQ))*8; + + //const int ib0 = sgitg*NQ + ix; + const int ib0 = ix; + + float yl[16]; // src1 vector cache + + //device const float * yb = y + ix*QK4_0 + il; + device const float * yb = y + ib0*QK4_0 + il; + + // each thread in a SIMD group deals with half a block. + //for (int ib = ib0; ib < nb; ib += NSG*NQ) { + for (int ib = ib0; ib < nb; ib += NQ) { + float sumy[2] = { 0.f, 0.f }; + + FOR_UNROLL (short i = 0; i < 8; i += 2) { + sumy[0] += yb[i + 0] + yb[i + 1]; + yl[i + 0] = yb[i + 0]; + yl[i + 1] = yb[i + 1]/256.f; + + sumy[1] += yb[i + 16] + yb[i + 17]; + yl[i + 8] = yb[i + 16]/16.f; + yl[i + 9] = yb[i + 17]/4096.f; + } + + FOR_UNROLL (short row = 0; row < NR0; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); + } + + yb += QK4_0 * 16; + //yb += NSG*NQ*QK4_0; + } + + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + + //helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); + + for (int row = 0; row < NR0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && r0 + row < args.ne01) { + dst_f32[r0 + row] = tot; + } + } +} + +template<int nr0, typename args_t> +void kernel_mul_mv_q1_0_f32_impl( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const short NSG = FC_mul_mv_nsg; + + const int nb = args.ne00/QK1_0; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * NSG + sgitg) * nr0; - ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); - } + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - float sumf[NR0] = {0.f}; + const uint64_t offset1 = r1*args.nb11 + (i12)*args.nb12 + (i13)*args.nb13; - const short ix = (tiisg/(NW/NQ)); - const short il = (tiisg%(NW/NQ))*8; + device const float * y = (device const float *) (src1 + offset1); - //const int ib0 = sgitg*NQ + ix; - const int ib0 = ix; + device const block_q1_0 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; + ax[row] = (device const block_q1_0 *) ((device char *) src0 + offset0); + } - float yl[16]; // src1 vector cache + float yl[16]; + float sumf[nr0] = {0.f}; - //device const float * yb = y + ix*QK4_0 + il; - device const float * yb = y + ib0*QK4_0 + il; + const short ix = (tiisg/8); + const short il = (tiisg%8)*16; - // each thread in a SIMD group deals with half a block. - //for (int ib = ib0; ib < nb; ib += NSG*NQ) { - for (int ib = ib0; ib < nb; ib += NQ) { - float sumy[2] = { 0.f, 0.f }; + device const float * yb = y + ix*QK1_0 + il; - FOR_UNROLL (short i = 0; i < 8; i += 2) { - sumy[0] += yb[i + 0] + yb[i + 1]; - yl[i + 0] = yb[i + 0]; - yl[i + 1] = yb[i + 1]/256.f; + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/8) { + float sumy = 0.f; - sumy[1] += yb[i + 16] + yb[i + 17]; - yl[i + 8] = yb[i + 16]/16.f; - yl[i + 9] = yb[i + 17]/4096.f; + FOR_UNROLL (short i = 0; i < 16; i++) { + yl[i] = yb[i]; + sumy += yb[i]; } - FOR_UNROLL (short row = 0; row < NR0; row++) { - sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); + FOR_UNROLL (short row = 0; row < nr0; row++) { + sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il); } - yb += QK4_0 * 16; - //yb += NSG*NQ*QK4_0; + yb += QK1_0 * (N_SIMDWIDTH/8); } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; - - //helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < NR0; ++row) { + for (int row = 0; row < nr0; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && r0 + row < args.ne01) { - dst_f32[r0 + row] = tot; + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; } } } +[[host_name("kernel_mul_mv_q1_0_f32")]] +kernel void kernel_mul_mv_q1_0_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); +} + kernel void kernel_mul_mv_q4_0_f32( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3384,10 +3656,10 @@ void kernel_mul_mv_q8_0_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); @@ -3396,7 +3668,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -3476,10 +3748,10 @@ void kernel_mul_mv_ext_q4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -3579,10 +3851,10 @@ void kernel_mul_mv_ext_q4x4_f32_impl( const int i11 = tgpig.y*r1ptg; const int i1m = tgpig.z; - const int i12 = i1m%args.ne12; - const int i13 = i1m/args.ne12; + const int i12 = i1m%FC_mul_mv_ne12; + const int i13 = i1m/FC_mul_mv_ne12; - const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = i01*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0; @@ -3700,6 +3972,18 @@ template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4 template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>; template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4, 4, dequantize_f16_t4>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, bfloat4, 4, dequantize_bf16_t4>; +template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, bfloat4, 4, dequantize_bf16_t4>; +template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, bfloat4, 4, dequantize_bf16_t4>; +template [[host_name("kernel_mul_mv_ext_bf16_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, bfloat4, 4, dequantize_bf16_t4>; +#endif + +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q1_0, 128, dequantize_q1_0_t4>; +template [[host_name("kernel_mul_mv_ext_q1_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q1_0, 128, dequantize_q1_0_t4>; + template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0, 32, dequantize_q4_0_t4>; template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0, 32, dequantize_q4_0_t4>; @@ -3750,6 +4034,16 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; +template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q2_K, 256, dequantize_q2_K>; +template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q2_K, 256, dequantize_q2_K>; +template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q2_K, 256, dequantize_q2_K>; +template [[host_name("kernel_mul_mv_ext_q2_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q2_K, 256, dequantize_q2_K>; + +template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q3_K, 256, dequantize_q3_K>; +template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q3_K, 256, dequantize_q3_K>; +template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q3_K, 256, dequantize_q3_K>; +template [[host_name("kernel_mul_mv_ext_q3_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q3_K, 256, dequantize_q3_K>; + template<typename T0, typename T1, short NR0, typename args_t> void kernel_mul_mv_t_t_impl( args_t args, @@ -3772,10 +4066,10 @@ void kernel_mul_mv_t_t_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const T0 * x = (device const T0 *) (src0 + offset0); @@ -3784,7 +4078,7 @@ void kernel_mul_mv_t_t_impl( // pointers to src0 rows device const T0 * ax [NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax[row] = (device const T0 *) ((device char *) src0 + offset0); } @@ -3894,10 +4188,10 @@ void kernel_mul_mv_t_t_4_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + //const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -3907,7 +4201,7 @@ void kernel_mul_mv_t_t_4_impl( device const T0 * ax [NR0]; device const T04 * ax4[NR0]; FOR_UNROLL (short row = 0; row < NR0; ++row) { - const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; ax [row] = (device const T0 *) ((device char *) src0 + offset0); ax4[row] = (device const T04 *) ((device char *) src0 + offset0); @@ -4011,10 +4305,10 @@ void kernel_mul_mv_t_t_short_impl( return; } - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); @@ -4437,59 +4731,59 @@ kernel void kernel_im2col( template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>; -// TODO: obolete -- remove -//typedef void (im2col_ext_t)( -// constant ggml_metal_kargs_im2col & args, -// device const float * x, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], -// uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]); -// -//template <typename T> -//kernel void kernel_im2col_ext( -// constant ggml_metal_kargs_im2col & args, -// device const float * x, -// device char * dst, -// uint3 tgpig[[threadgroup_position_in_grid]], -// uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW -// uint3 tpitg[[thread_position_in_threadgroup]], -// uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] -// const int64_t KHW = (int64_t)args.KHW; -// -// const int64_t d = tgpig[0] / args.CHW; -// const int64_t chw = tgpig[0] % args.CHW; -// const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) -// const int64_t HW = tgpig[0] % KHW; -// -// const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; -// if (tpitg_0 >= args.N) { -// return; -// } -// -// const int64_t tpitg_1 = HW / args.KW; -// const int64_t tpitg_2 = HW % args.KW; -// -// const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; -// const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; -// -// const int64_t offset_dst = -// (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + -// (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); -// -// device T * pdst = (device T *) (dst); -// -// if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { -// pdst[offset_dst] = 0.0f; -// } else { -// const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; -// pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; -// } -//} -// -//template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; -//template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; +// TODO: optimize +typedef void (im2col_ext_t)( + constant ggml_metal_kargs_im2col & args, + device const float * x, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]); + +template <typename T> +kernel void kernel_im2col_ext( + constant ggml_metal_kargs_im2col & args, + device const float * x, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tgpg[[threadgroups_per_grid]], // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { // [M, 1, 1] + const int64_t KHW = (int64_t)args.KHW; + + const int64_t d = tgpig[0] / args.CHW; + const int64_t chw = tgpig[0] % args.CHW; + const int64_t tgpig_0 = chw / KHW; // 0 ~ (IC - 1) + const int64_t HW = tgpig[0] % KHW; + + const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0]; + if (tpitg_0 >= args.N) { + return; + } + + const int64_t tpitg_1 = HW / args.KW; + const int64_t tpitg_2 = HW % args.KW; + + const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0; + const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1; + + const int64_t offset_dst = + (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW + + (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2); + + device T * pdst = (device T *) (dst); + + if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) { + pdst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1; + pdst[offset_dst] = x[offset_src + iih * args.IW + iiw]; + } +} + +template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext<float>; +template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext<half>; template <typename TK> kernel void kernel_conv_2d( @@ -4622,15 +4916,32 @@ kernel void kernel_conv_transpose_1d( uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpg[[threadgroups_per_grid]]) { - float v = 0.0f; + // For output position j on the time axis, only input positions + // i such that i*s0 <= j < i*s0 + K + // contribute -- i.e. i in [ceil((j - K + 1)/s0), floor(j/s0)] + // intersected with [0, IL-1]. That's at most ceil(K/s0) values + // (typically 2 for stride==K/2 transposed convs). + const int32_t j = tgpig[0]; + const int32_t s0 = args.s0; + const int32_t K = args.K; + const int32_t IL = args.IL; + + int32_t i_min; + { + int32_t a = j - K + 1; + i_min = a <= 0 ? 0 : (a + s0 - 1) / s0; // ceil(a/s0) for a>0 + } + int32_t i_max = j / s0; + if (i_max > IL - 1) i_max = IL - 1; - for (int64_t c = 0; c < args.IC; c++) { - const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1]; - const int32_t input_offset = c * args.IL; + float v = 0.0f; + if (i_min <= i_max) { + for (int64_t c = 0; c < args.IC; c++) { + const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1]; + const int32_t input_offset = c * IL; - for (int64_t i = 0; i < args.IL; i++) { - if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) { - v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i]; + for (int32_t i = i_min; i <= i_max; i++) { + v += float(src0[kernel_offset + j - i * s0]) * src1[input_offset + i]; } } } @@ -4749,7 +5060,9 @@ kernel void kernel_conv_transpose_2d<half>( uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]); -kernel void kernel_upscale_f32( +constant bool FC_upscale_aa [[function_constant(FC_UPSCALE + 0)]]; + +kernel void kernel_upscale_nearest_f32( constant ggml_metal_kargs_upscale & args, device const char * src0, device char * dst, @@ -4775,8 +5088,12 @@ kernel void kernel_upscale_f32( } } -kernel void kernel_pad_f32( - constant ggml_metal_kargs_pad & args, +static inline float bilinear_tri(float x) { + return MAX(0.0f, 1.0f - fabs(x)); +} + +kernel void kernel_upscale_bilinear_f32( + constant ggml_metal_kargs_upscale & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], @@ -4787,30 +5104,306 @@ kernel void kernel_pad_f32( const int64_t i2 = tgpig.y; const int64_t i1 = tgpig.x; - const int64_t i03 = i3; - const int64_t i02 = i2; - const int64_t i01 = i1; + const int64_t i03 = i3 / args.sf3; + const int64_t i02 = i2 / args.sf2; - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); - device float * dst_ptr = (device float *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs; + const int64_t i01 = MAX(0, MIN(args.ne01 - 1, (int64_t)floor(f01))); + const int64_t i01p = MAX(0, MIN(args.ne01 - 1, i01 + 1)); + const float fd1 = MAX(0.0f, MIN(1.0f, f01 - (float)i01)); + + src0 += i03*args.nb03 + i02*args.nb02; + + device float * dst_ptr = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + + if (FC_upscale_aa) { + const float support0 = MAX(1.0f, 1.0f / args.sf0); + const float invscale0 = 1.0f / support0; + const float support1 = MAX(1.0f, 1.0f / args.sf1); + const float invscale1 = 1.0f / support1; - if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - if (i0 < args.ne00) { - dst_ptr[i0] = src0_ptr[i0]; - } else { - dst_ptr[i0] = 0.0f; + const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs; + + int64_t x_min = MAX((int64_t)0, (int64_t)floor(f00 - support0 + args.poffs)); + int64_t x_max = MIN(args.ne00, (int64_t)ceil (f00 + support0 + args.poffs)); + + int64_t y_min = MAX((int64_t)0, (int64_t)floor(f01 - support1 + args.poffs)); + int64_t y_max = MIN(args.ne01, (int64_t)ceil (f01 + support1 + args.poffs)); + + float sum = 0.0f; + float wsum = 0.0f; + + for (int64_t sy = y_min; sy < y_max; ++sy) { + const float wy = MAX(0.0f, 1.0f - fabs((float)sy - f01) * invscale1); + for (int64_t sx = x_min; sx < x_max; ++sx) { + const float wx = MAX(0.0f, 1.0f - fabs((float)sx - f00) * invscale0); + const float w = wx * wy; + device const float * src_ptr = (device const float *)(src0 + sy*args.nb01 + sx*args.nb00); + sum += (*src_ptr) * w; + wsum += w; + } } + + const float v = (wsum > 0.0f) ? (sum / wsum) : 0.0f; + dst_ptr[i0] = v; } + } else { + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs; + const int64_t i00 = MAX(0, MIN(args.ne00 - 1, (int64_t)floor(f00))); + const int64_t i00p = MAX(0, MIN(args.ne00 - 1, i00 + 1)); + const float fd0 = MAX(0.0f, MIN(1.0f, f00 - (float)i00)); - return; + device const float * src00 = (device const float *)(src0 + i01*args.nb01 + i00*args.nb00); + device const float * src10 = (device const float *)(src0 + i01*args.nb01 + i00p*args.nb00); + device const float * src01 = (device const float *)(src0 + i01p*args.nb01 + i00*args.nb00); + device const float * src11 = (device const float *)(src0 + i01p*args.nb01 + i00p*args.nb00); + + const float v = + (*src00) * (1.0f - fd0) * (1.0f - fd1) + + (*src10) * fd0 * (1.0f - fd1) + + (*src01) * (1.0f - fd0) * fd1 + + (*src11) * fd0 * fd1; + + dst_ptr[i0] = v; + } + } +} + +template <typename T> +kernel void kernel_conv_3d( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, // Weights [IC * OC, KD, KH, KW] + device const char * src1, // Inputs [IC * N, ID, IH, IW] + device char * dst, // Outputs [OC * N, OD, OH, OW] + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + + // 1. Un-flatten the spatial dimension from Grid X + int64_t spatial_idx = tgpig.x * 32 + tpitg.x; + + if (spatial_idx >= args.OW * args.OH * args.OD) { + return; // Thread falls outside the spatial volume + } + + int64_t od = spatial_idx / (args.OW * args.OH); + int64_t oh = (spatial_idx / args.OW) % args.OH; + int64_t ow = spatial_idx % args.OW; + + // 2. Map Y to Channels, Z to Batch + int64_t oc = tgpig.y; + int64_t batch_idx = tgpig.z; + + // 3. Calculate anchor coordinates in the Input volume + int64_t i_w_base = ow * args.s0 - args.p0; + int64_t i_h_base = oh * args.s1 - args.p1; + int64_t i_d_base = od * args.s2 - args.p2; + + float sum = 0.0f; + + // 4. Gather Loop (Iterate over Input Channels -> Depth -> Height -> Width) + for (int64_t ic = 0; ic < args.IC; ++ic) { + + // ggml packs batch and channel together in the 4th dimension + int64_t src_cn_idx = batch_idx * args.IC + ic; + int64_t w_cn_idx = oc * args.IC + ic; + + for (int64_t kz = 0; kz < args.KD; ++kz) { + int64_t id = i_d_base + kz * args.d2; + if (id < 0 || id >= args.ID) continue; // Boundary check (Padding) + + for (int64_t ky = 0; ky < args.KH; ++ky) { + int64_t ih = i_h_base + ky * args.d1; + if (ih < 0 || ih >= args.IH) continue; + + for (int64_t kx = 0; kx < args.KW; ++kx) { + int64_t iw = i_w_base + kx * args.d0; + if (iw < 0 || iw >= args.IW) continue; + + // Convert multi-dimensional coordinates to flat byte offsets + int64_t w_idx = kx*args.nb00 + ky*args.nb01 + kz*args.nb02 + w_cn_idx*args.nb03; + int64_t i_idx = iw*args.nb10 + ih*args.nb11 + id*args.nb12 + src_cn_idx*args.nb13; + + // Dereference memory and cast weights to f32 if they were f16 + float w_val = (float)*(device const T*)((device const char*)src0 + w_idx); + float i_val = *(device const float*)((device const char*)src1 + i_idx); + + sum += w_val * i_val; + } + } + } + } + + // 5. Write the accumulated value out to RAM + int64_t dst_cn_idx = batch_idx * args.OC + oc; + int64_t d_idx = ow*args.nb0 + oh*args.nb1 + od*args.nb2 + dst_cn_idx*args.nb3; + + *(device float*)(dst + d_idx) = sum; +} + +// Explicit instantiations so the JIT compiler can find them by name +template [[host_name("kernel_conv_3d_f32_f32")]] +kernel void kernel_conv_3d<float>( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +// Explicit instantiation for f16 weights +template [[host_name("kernel_conv_3d_f16_f32")]] +kernel void kernel_conv_3d<half>( + constant ggml_metal_kargs_conv_3d & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + + +static inline float bicubic_weight1(float x) { + const float a = -0.75f; + return ((a + 2) * x - (a + 3)) * x * x + 1; +} + +static inline float bicubic_weight2(float x) { + const float a = -0.75f; + return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; +} + +kernel void kernel_upscale_bicubic_f32( + constant ggml_metal_kargs_upscale & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + const int64_t i03 = i3 / args.sf3; + const int64_t i02 = i2 / args.sf2; + + const float f01 = ((float)i1 + args.poffs) / args.sf1 - args.poffs; + const int64_t i01 = (int64_t)floor(f01); + const float fd1 = f01 - (float)i01; + + const float w_y0 = bicubic_weight2(fd1 + 1.0f); + const float w_y1 = bicubic_weight1(fd1); + const float w_y2 = bicubic_weight1(1.0f - fd1); + const float w_y3 = bicubic_weight2(2.0f - fd1); + + const device const char * src_slice = src0 + i03 * args.nb03 + i02 * args.nb02; + + device float * dst_ptr = (device float *)(dst + i3 * args.nb3 + i2 * args.nb2 + i1 * args.nb1); + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const float f00 = ((float)i0 + args.poffs) / args.sf0 - args.poffs; + const int64_t i00 = (int64_t)floor(f00); + const float fd0 = f00 - (float)i00; + + const float w_x0 = bicubic_weight2(fd0 + 1.0f); + const float w_x1 = bicubic_weight1(fd0); + const float w_x2 = bicubic_weight1(1.0f - fd0); + const float w_x3 = bicubic_weight2(2.0f - fd0); + + float sum = 0.0f; + + for (int dy = -1; dy <= 2; ++dy) { + const int64_t iy = MAX(0, MIN(args.ne01 - 1, i01 + dy)); + const float wy = (dy == -1) ? w_y0 : (dy == 0) ? w_y1 : (dy == 1) ? w_y2 : w_y3; + + for (int dx = -1; dx <= 2; ++dx) { + const int64_t ix = MAX(0, MIN(args.ne00 - 1, i00 + dx)); + const float wx = (dx == -1) ? w_x0 : (dx == 0) ? w_x1 : (dx == 1) ? w_x2 : w_x3; + + device const float * src_ptr = (device const float *)(src_slice + iy * args.nb01 + ix * args.nb00); + sum += (*src_ptr) * wx * wy; + } + } + + dst_ptr[i0] = sum; } +} + +kernel void kernel_roll_f32( + constant ggml_metal_kargs_roll & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i3 = tgpig.z; + const int64_t i2 = tgpig.y; + const int64_t i1 = tgpig.x; + + device const float * src0_ptr = (device const float *) src0; + device float * dst_ptr = (device float *) dst; for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - dst_ptr[i0] = 0.0f; + // apply shifts and wrap around + int64_t i00 = i0 - args.s0; + int64_t i01 = i1 - args.s1; + int64_t i02 = i2 - args.s2; + int64_t i03 = i3 - args.s3; + + if (i00 < 0) { i00 += args.ne00; } else if (i00 >= args.ne00) { i00 -= args.ne00; } + if (i01 < 0) { i01 += args.ne01; } else if (i01 >= args.ne01) { i01 -= args.ne01; } + if (i02 < 0) { i02 += args.ne02; } else if (i02 >= args.ne02) { i02 -= args.ne02; } + if (i03 < 0) { i03 += args.ne03; } else if (i03 >= args.ne03) { i03 -= args.ne03; } + + int64_t src_idx = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00 + i00; + int64_t dst_idx = i3 *args.ne2 *args.ne1 *args.ne0 + i2 *args.ne1 *args.ne0 + i1 *args.ne0 + i0; + + dst_ptr[dst_idx] = src0_ptr[src_idx]; + } +} + +template <typename T> +kernel void kernel_pad_impl( + constant ggml_metal_kargs_pad & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int32_t i3 = tgpig.z; + const int32_t i2 = tgpig.y; + const int32_t k0 = tgpig.x/args.ne1; + const int32_t i1 = tgpig.x - k0*args.ne1; + + const int32_t i03 = i3; + const int32_t i02 = i2; + const int32_t i01 = i1; + + device const T * src0_ptr = (device const T *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T * dst_ptr = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1); + + for (int32_t l0 = 0; l0 < 1024; l0 += ntg.x) { + const int32_t i0 = k0*1024 + tpitg.x + l0; + if (i0 >= args.ne0) { + break; + } + + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + dst_ptr[i0] = src0_ptr[i0]; + } else { + dst_ptr[i0] = 0.0f; + } } } +typedef decltype(kernel_pad_impl<float>) kernel_pad_t; + +template [[host_name("kernel_pad_f32")]] kernel kernel_pad_t kernel_pad_impl<float>; +template [[host_name("kernel_pad_f32_4")]] kernel kernel_pad_t kernel_pad_impl<float4>; + +// TODO: this is slow - optimize kernel void kernel_pad_reflect_1d_f32( constant ggml_metal_kargs_pad_reflect_1d & args, device const char * src0, @@ -5114,24 +5707,6 @@ kernel void kernel_argsort_merge_f32_i32( template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_ASC>; template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32<GGML_SORT_ORDER_DESC>; -kernel void kernel_leaky_relu_f32( - constant ggml_metal_kargs_leaky_relu & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = x > 0.0f ? x : x * args.slope; -} - -kernel void kernel_leaky_relu_f32_4( - constant ggml_metal_kargs_leaky_relu & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); -} - constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; @@ -5208,6 +5783,7 @@ constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_E // scan the blocks of the mask that are not masked // 0 - masked (i.e. full of -INF, skip) // 1 - not masked (i.e. at least one element of the mask is not -INF) +// 2 - all zero kernel void kernel_flash_attn_ext_blk( constant ggml_metal_kargs_flash_attn_ext_blk & args, device const char * mask, @@ -5229,27 +5805,29 @@ kernel void kernel_flash_attn_ext_blk( device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; - // fast route - if (res == 0) { - if (simd_max(*mask_src) > -MAXHALF/2) { - res = 1; - } - } - // detailed check of the elements of the block if ((C > NW || Q > 1) && res == 0) { - half m = -MAXHALF; + half mmin = MAXHALF; + half mmax = -MAXHALF; FOR_UNROLL (short j = 0; j < Q; ++j) { FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { - m = max(m, mask_src[ii*NW]); + mmin = min(mmin, mask_src[ii*NW]); + mmax = max(mmax, mask_src[ii*NW]); } mask_src += args.nb31/2; } - if (simd_max(m) > -MAXHALF/2) { - res = 1; + mmin = simd_min(mmin); + mmax = simd_max(mmax); + + if (mmax > -MAXHALF) { + if (mmin == 0.0 && mmax == 0.0) { + res = 2; + } else { + res = 1; + } } } @@ -5491,9 +6069,13 @@ void kernel_flash_attn_ext_impl( ic = 0; } + char blk_cur = 1; + // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { - if (blk[ic0] == 0) { + blk_cur = blk[ic0]; + + if (blk_cur == 0) { FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { pm2[jj] += NW; } @@ -5501,16 +6083,22 @@ void kernel_flash_attn_ext_impl( continue; } - FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { - const short j = jj*NSG + sgitg; + if (blk_cur == 1) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; - if (FC_flash_attn_ext_bc_mask) { - sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); - } else { - sm2[j*SH + tiisg] = pm2[jj][tiisg]; - } + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } - pm2[jj] += NW; + pm2[jj] += NW; + } + } else if (blk_cur == 2) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } } #if 0 @@ -5552,9 +6140,7 @@ void kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // note: do not unroll for large heads - #pragma unroll (DK <= 64 ? NC : 1) - for (short cc = 0; cc < NC; ++cc) { + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix<qk_t, 8>((qk_t) 0.0f); if (DK % 16 != 0) { @@ -5575,7 +6161,9 @@ void kernel_flash_attn_ext_impl( k8x8_t mk[2]; q8x8_t mq[2]; - FOR_UNROLL (short i = 0; i < DK8/2; ++i) { + // note: too much unroll can tank the performance for large heads + #pragma unroll (MIN(DK8/2, 4*NSG)) + for (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); @@ -5675,10 +6263,12 @@ void kernel_flash_attn_ext_impl( } // mqk = mqk + slope*mask - if (FC_flash_attn_ext_has_bias) { - s2 += s2_t(sm2[j*SH + tiisg])*slope; - } else { - s2 += s2_t(sm2[j*SH + tiisg]); + if (blk_cur != 2) { + if (FC_flash_attn_ext_has_bias) { + s2 += s2_t(sm2[j*SH + tiisg])*slope; + } else { + s2 += s2_t(sm2[j*SH + tiisg]); + } } M[jj] = simd_max(max(M[jj], max(s2[0], s2[1]))); @@ -5749,7 +6339,9 @@ void kernel_flash_attn_ext_impl( pv += 8*NS20; } } else { - FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + constexpr short NC = (C/8)/2; + + FOR_UNROLL (short cc = 0; cc < NC; ++cc) { s8x8_t vs[2]; simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); @@ -5929,7 +6521,7 @@ template< void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size short DV, // V head size - short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short Q = OP_FLASH_ATTN_EXT_NQPSG, // queries per threadgroup short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext( constant ggml_metal_kargs_flash_attn_ext & args, @@ -5952,6 +6544,7 @@ kernel void kernel_flash_attn_ext( //case 1: kernel_flash_attn_ext_impl<FWD_TMPL, 1>(FWD_ARGS); break; //case 2: kernel_flash_attn_ext_impl<FWD_TMPL, 2>(FWD_ARGS); break; case 4: kernel_flash_attn_ext_impl<FWD_TMPL, 4>(FWD_ARGS); break; + case 8: kernel_flash_attn_ext_impl<FWD_TMPL, 8>(FWD_ARGS); break; } #undef FWD_TMPL #undef FWD_ARGS @@ -6001,6 +6594,8 @@ template [[host_name("kernel_flash_attn_ext_f32_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f32_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 192>; template [[host_name("kernel_flash_attn_ext_f32_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 192, 128>; template [[host_name("kernel_flash_attn_ext_f32_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 256, 256>; +template [[host_name("kernel_flash_attn_ext_f32_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 320, 256>; +template [[host_name("kernel_flash_attn_ext_f32_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 512, 512>; template [[host_name("kernel_flash_attn_ext_f32_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_F32, float4x4, 1, dequantize_f32, float4x4, 1, dequantize_f32, 576, 512>; template [[host_name("kernel_flash_attn_ext_f16_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 32, 32>; @@ -6015,6 +6610,8 @@ template [[host_name("kernel_flash_attn_ext_f16_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_f16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 192>; template [[host_name("kernel_flash_attn_ext_f16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 192, 128>; template [[host_name("kernel_flash_attn_ext_f16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 256, 256>; +template [[host_name("kernel_flash_attn_ext_f16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 320, 256>; +template [[host_name("kernel_flash_attn_ext_f16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 512, 512>; template [[host_name("kernel_flash_attn_ext_f16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 576, 512>; #if defined(GGML_METAL_HAS_BF16) @@ -6030,6 +6627,8 @@ template [[host_name("kernel_flash_attn_ext_bf16_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 192>; template [[host_name("kernel_flash_attn_ext_bf16_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 192, 128>; template [[host_name("kernel_flash_attn_ext_bf16_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 256, 256>; +template [[host_name("kernel_flash_attn_ext_bf16_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 320, 256>; +template [[host_name("kernel_flash_attn_ext_bf16_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 512, 512>; template [[host_name("kernel_flash_attn_ext_bf16_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES_BF, bfloat4x4, 1, dequantize_bf16, bfloat4x4, 1, dequantize_bf16, 576, 512>; #endif @@ -6045,6 +6644,8 @@ template [[host_name("kernel_flash_attn_ext_q4_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 192>; template [[host_name("kernel_flash_attn_ext_q4_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 192, 128>; template [[host_name("kernel_flash_attn_ext_q4_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 256, 256>; +template [[host_name("kernel_flash_attn_ext_q4_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 320, 256>; +template [[host_name("kernel_flash_attn_ext_q4_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 512, 512>; template [[host_name("kernel_flash_attn_ext_q4_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, block_q4_0, 2, dequantize_q4_0, 576, 512>; template [[host_name("kernel_flash_attn_ext_q4_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 32, 32>; @@ -6059,6 +6660,8 @@ template [[host_name("kernel_flash_attn_ext_q4_1_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 192>; template [[host_name("kernel_flash_attn_ext_q4_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 192, 128>; template [[host_name("kernel_flash_attn_ext_q4_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 256, 256>; +template [[host_name("kernel_flash_attn_ext_q4_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 320, 256>; +template [[host_name("kernel_flash_attn_ext_q4_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 512, 512>; template [[host_name("kernel_flash_attn_ext_q4_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, block_q4_1, 2, dequantize_q4_1, 576, 512>; template [[host_name("kernel_flash_attn_ext_q5_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 32, 32>; @@ -6073,6 +6676,8 @@ template [[host_name("kernel_flash_attn_ext_q5_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 192>; template [[host_name("kernel_flash_attn_ext_q5_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 192, 128>; template [[host_name("kernel_flash_attn_ext_q5_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 256, 256>; +template [[host_name("kernel_flash_attn_ext_q5_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 320, 256>; +template [[host_name("kernel_flash_attn_ext_q5_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 512, 512>; template [[host_name("kernel_flash_attn_ext_q5_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, block_q5_0, 2, dequantize_q5_0, 576, 512>; template [[host_name("kernel_flash_attn_ext_q5_1_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 32, 32>; @@ -6087,6 +6692,8 @@ template [[host_name("kernel_flash_attn_ext_q5_1_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 192>; template [[host_name("kernel_flash_attn_ext_q5_1_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 192, 128>; template [[host_name("kernel_flash_attn_ext_q5_1_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 256, 256>; +template [[host_name("kernel_flash_attn_ext_q5_1_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 320, 256>; +template [[host_name("kernel_flash_attn_ext_q5_1_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 512, 512>; template [[host_name("kernel_flash_attn_ext_q5_1_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, block_q5_1, 2, dequantize_q5_1, 576, 512>; template [[host_name("kernel_flash_attn_ext_q8_0_dk32_dv32" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 32, 32>; @@ -6101,6 +6708,8 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk128_dv128")]] kernel flash_at template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv192")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 192>; template [[host_name("kernel_flash_attn_ext_q8_0_dk192_dv128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 192, 128>; template [[host_name("kernel_flash_attn_ext_q8_0_dk256_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 256, 256>; +template [[host_name("kernel_flash_attn_ext_q8_0_dk320_dv256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 320, 256>; +template [[host_name("kernel_flash_attn_ext_q8_0_dk512_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 512, 512>; template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, block_q8_0, 2, dequantize_q8_0, 576, 512>; #undef FA_TYPES @@ -6138,11 +6747,10 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE, // head elements per thread - short Q, // queries per threadgroup - short C, // cache items per threadgroup - short NSG> // number of simd groups -void kernel_flash_attn_ext_vec_impl( + short NE = 4, // head elements per thread + short Q = OP_FLASH_ATTN_EXT_VEC_NQPSG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup +kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, device const char * k, @@ -6159,6 +6767,7 @@ void kernel_flash_attn_ext_vec_impl( static_assert(DV % 32 == 0, "DV must be divisible by 32"); #define NWG (FC_flash_attn_ext_vec_nwg) +#define NSG (FC_flash_attn_ext_vec_nsg) #define NS10 (FC_flash_attn_ext_vec_ns10) #define NS20 (FC_flash_attn_ext_vec_ns20) @@ -6185,14 +6794,14 @@ void kernel_flash_attn_ext_vec_impl( static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); - const short T = PK + NSG*SH; // shared memory size per query in (half) + //const short T = PK + NSG*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*PK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*PK); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + Q*PK); // scratch buffer for mask - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results // store the result for all queries in shared memory (the O matrix from the paper) so4 += tiisg; @@ -6210,11 +6819,13 @@ void kernel_flash_attn_ext_vec_impl( // load heads from Q to shared memory device const float4 * q4 = (device const float4 *) ((device const char *) q); - for (short i = tiisg; i < PK4; i += NW) { - if (iq1 < args.ne01 && i < DK4) { - sq4[i] = (q4_t) q4[i]; - } else { - sq4[i] = (q4_t) 0.0f; + if (iq1 < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (i < DK4) { + sq4[i] = (q4_t) q4[i]; + } else { + sq4[i] = (q4_t) 0.0f; + } } } @@ -6292,7 +6903,7 @@ void kernel_flash_attn_ext_vec_impl( } // skip -INF blocks - if (simd_max(sm[tiisg]) == -INFINITY) { + if (simd_max(sm[tiisg]) <= -MAXHALF) { continue; } @@ -6566,57 +7177,11 @@ void kernel_flash_attn_ext_vec_impl( } #undef NWG +#undef NSG #undef NS10 #undef NS20 } -template< - typename q4_t, // query types in shared memory - typename k4_t, // key types in shared memory - typename v4_t, // value types in shared memory - typename qk_t, // Q*K types - typename s_t, // soft-max types - typename s4_t, - typename o4_t, // attention accumulation types - typename kd4_t, // key type in device memory - short nl_k, - void (*deq_k_t4)(device const kd4_t *, short, thread k4_t &), - typename vd4_t, // value type in device memory - short nl_v, - void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), - short DK, // K head size - short DV, // V head size - short NE = 4, // head elements per thread - short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup - short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup -kernel void kernel_flash_attn_ext_vec( - constant ggml_metal_kargs_flash_attn_ext_vec & args, - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device const char * sinks, - device const char * pad, - device char * dst, - threadgroup half * shmem_f16 [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { -#define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg - switch (FC_flash_attn_ext_vec_nsg) { - // note: disabled cases to reduce library load time - case 1: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 1>(FWD_ARGS); break; - case 2: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 2>(FWD_ARGS); break; - case 4: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 4>(FWD_ARGS); break; - //case 8: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 8>(FWD_ARGS); break; - //case 16: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 16>(FWD_ARGS); break; - //case 32: kernel_flash_attn_ext_vec_impl<FWD_TMPL, 32>(FWD_ARGS); break; - } -#undef FWD_TMPL -#undef FWD_ARGS -} - // note: I think the s_t can be half instead of float, because the Q*K scaling is done before storing to shared mem // in the other (non-vec) kernel, we need s_t to also be float because we scale during the soft_max // @@ -6715,6 +7280,28 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 256, 256, 1>; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 256, 256, 1>; +template [[host_name("kernel_flash_attn_ext_vec_f32_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 320, 256, 2>; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 320, 256, 2>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 320, 256, 2>; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 320, 256, 2>; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 320, 256, 2>; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 320, 256, 2>; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 320, 256, 2>; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 320, 256, 2>; + +template [[host_name("kernel_flash_attn_ext_vec_f32_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 512, 512, 1>; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 512, 512, 1>; +#if defined(GGML_METAL_HAS_BF16) +template [[host_name("kernel_flash_attn_ext_vec_bf16_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 512, 512, 1>; +#endif +template [[host_name("kernel_flash_attn_ext_vec_q4_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 512, 512, 1>; +template [[host_name("kernel_flash_attn_ext_vec_q4_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 512, 512, 1>; +template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 512, 512, 1>; +template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 512, 512, 1>; +template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk512_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 512, 512, 1>; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES_F32, float4, 1, dequantize_f32_t4, float4, 1, dequantize_f32_t4, 576, 512, 2>; template [[host_name("kernel_flash_attn_ext_vec_f16_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 576, 512, 2>; #if defined(GGML_METAL_HAS_BF16) @@ -6780,23 +7367,27 @@ kernel void kernel_cpy_t_t( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.ne00;) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; break; @@ -6828,23 +7419,27 @@ kernel void kernel_cpy_f32_q( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; + const int32_t i3 = n / (args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); quantize_func(src, dst_data[i00]); @@ -6856,6 +7451,7 @@ kernel void kernel_cpy_f32_q( typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t; template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>; +template [[host_name("kernel_cpy_f32_q1_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK1_0, block_q1_0, quantize_q1_0>; template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>; template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>; template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>; @@ -6868,24 +7464,28 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; - const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + const int32_t i03 = tgpig[2]; + const int32_t i02 = tgpig[1]; + const int32_t i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tpitg.y; + const int32_t iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; + + if (i01 >= args.ne01) { + return; + } const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); + const int32_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int32_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int32_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int32_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + for (int32_t i00 = iw0*ntg[0] + tpitg.x; i00 < args.nk0;) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; @@ -6896,12 +7496,14 @@ kernel void kernel_cpy_q_f32( typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t; +template [[host_name("kernel_cpy_q1_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q1_0, 8, dequantize_q1_0>; template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>; template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>; template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>; +template [[host_name("kernel_cpy_q1_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q1_0, 8, dequantize_q1_0>; template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>; @@ -6919,7 +7521,11 @@ kernel void kernel_concat( const int i3 = tgpig.z; const int i2 = tgpig.y; - const int i1 = tgpig.x; + const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y; + + if (i1 >= args.ne1) { + return; + } int o[4] = {0, 0, 0, 0}; o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); @@ -6959,10 +7565,10 @@ void kernel_mul_mv_q2_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); @@ -7064,10 +7670,10 @@ void kernel_mul_mv_q3_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); @@ -7238,10 +7844,10 @@ void kernel_mul_mv_q4_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); @@ -7350,10 +7956,10 @@ void kernel_mul_mv_q5_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); @@ -7486,10 +8092,10 @@ void kernel_mul_mv_q6_K_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); @@ -7591,10 +8197,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); @@ -7699,10 +8305,10 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); @@ -7818,10 +8424,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); @@ -7930,10 +8536,10 @@ void kernel_mul_mv_iq3_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); @@ -8042,10 +8648,10 @@ void kernel_mul_mv_iq2_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); @@ -8155,10 +8761,10 @@ void kernel_mul_mv_iq1_s_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); @@ -8254,10 +8860,10 @@ void kernel_mul_mv_iq1_m_f32_impl( const int first_row = (r0 * NSG + sgitg) * nr0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); @@ -8363,10 +8969,10 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); @@ -8472,10 +9078,10 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int im = tgpig.z; const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); @@ -8583,10 +9189,10 @@ void kernel_mul_mv_mxfp4_f32_impl( const int first_row = (r0 * NSG + sgitg) * NR0; - const uint i12 = im%args.ne12; - const uint i13 = im/args.ne12; + const uint i12 = im%FC_mul_mv_ne12; + const uint i13 = im/FC_mul_mv_ne12; - const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = first_row*args.nb01 + (i12/FC_mul_mv_r2)*args.nb02 + (i13/FC_mul_mv_r3)*args.nb03; const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); @@ -8779,11 +9385,165 @@ kernel void kernel_set_rows_f( } } +kernel void kernel_diag_f32( + constant ggml_metal_kargs_diag & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]]) { + constexpr short NW = N_SIMDWIDTH; + + const int32_t i3 = tgpig.z; + const int32_t i2 = tgpig.y; + const int32_t i1 = tgpig.x; + + device const float * src0_ptr = (device const float *)(src0 + i2*args.nb02 + i3*args.nb03); + device float * dst_ptr = (device float *)(dst + i1*args.nb01 + i2*args.nb2 + i3*args.nb3); + + for (int i0 = tiitg; i0 < args.ne0; i0 += NW) { + dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f; + } +} + constant bool FC_mul_mm_bc_inp [[function_constant(FC_MUL_MM + 0)]]; constant bool FC_mul_mm_bc_out [[function_constant(FC_MUL_MM + 1)]]; +constant short FC_mul_mm_ne12 [[function_constant(FC_MUL_MM + 2)]]; +constant short FC_mul_mm_ne13 [[function_constant(FC_MUL_MM + 3)]]; +constant short FC_mul_mm_r2 [[function_constant(FC_MUL_MM + 4)]]; +constant short FC_mul_mm_r3 [[function_constant(FC_MUL_MM + 5)]]; // each block_q contains 16*nl weights -template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4> +#ifdef GGML_METAL_HAS_TENSOR +template< + typename SA, typename SA_4x4, typename SA_8x8, + typename SB, typename SB_2x4, typename SB_8x8, + typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread SA_4x4 &), + typename T0, typename T0_4x4, typename T1, typename T1_2x4> +kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, + device const char * srcA, + device const char * srcB, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig [[threadgroup_position_in_grid]], + ushort tiitg [[thread_index_in_threadgroup]], + ushort sgitg [[simdgroup_index_in_threadgroup]]) { + (void) sgitg; + + // Matrix dimensions: A(M,K) x B(K,N) -> C(M,N) + const int K = args.ne00; + const int M = args.ne0; + const int N = args.ne1; + + // Batch dimension handling + const int im = tgpig.z; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; + + // Batch offsets for srcA and srcB + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; + + // Tile dimensions + constexpr int NRB = SZ_SIMDGROUP * N_MM_BLOCK_X * N_MM_SIMD_GROUP_X; + constexpr int NRA = SZ_SIMDGROUP * N_MM_BLOCK_Y * N_MM_SIMD_GROUP_Y; + + // Tile offsets in output matrix + const int ra = tgpig.y * NRA; + const int rb = tgpig.x * NRB; + + // Threadgroup memory for dequantized A tile only + threadgroup SA * sa = (threadgroup SA *)(shmem); + + // Work-item count for A loading + constexpr int A_WORK_ITEMS = NRA * N_MM_NK; + constexpr int NUM_THREADS = N_SIMDWIDTH * N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y; + + // tA wraps threadgroup memory + auto tA = tensor(sa, dextents<int32_t, 2>(N_MM_NK_TOTAL, NRA)); + + // tB wraps device memory directly + device T1 * ptrB = (device T1 *)(srcB + args.nb12*i12 + args.nb13*i13); + const int strideB = args.nb11 / sizeof(T1); + auto tB = tensor(ptrB, dextents<int32_t, 2>(K, N), array<int, 2>({1, strideB})); + + // Configure matmul operation + mpp::tensor_ops::matmul2d< + mpp::tensor_ops::matmul2d_descriptor( + NRB, NRA, N_MM_NK_TOTAL, false, true, true, + mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), + execution_simdgroups<N_MM_SIMD_GROUP_X * N_MM_SIMD_GROUP_Y>> mm; + + auto cT = mm.get_destination_cooperative_tensor<decltype(tB), decltype(tA), float>(); + + // Accumulate partial results over K dimension + for (int loop_k = 0; loop_k < K; loop_k += N_MM_NK_TOTAL) { + // === PHASE 1: Dequantization of A into threadgroup memory === + for (int work = tiitg; work < A_WORK_ITEMS; work += NUM_THREADS) { + const int row = work / N_MM_NK; + const int k_chunk = work % N_MM_NK; + const int k_pos = loop_k + k_chunk * 16; + const short k_base = k_chunk * 16; + + // Bounds check: skip device read if row is out of matrix bounds + if (ra + row < M) { + if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) { + // Element-wise reads when K is not aligned (nb01 not aligned for half4x4/float4x4). + // MSL spec Table 2.5: half4x4 requires 8-byte alignment. When K is odd, + // nb01 = K*2 is not 8-byte aligned, so odd-row pointers are misaligned. + // Mirrors the legacy kernel's existing guard. + device const T0 * row_ptr = (device const T0 *)(srcA + args.nb01 * (ra + row) + offset0); + + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? (SA) row_ptr[k_pos + i] : (SA)0; + } + } else { + const int block_idx = k_pos / (16 * nl); + const short il = (k_pos / 16) % nl; + + device const block_q * row_ptr = (device const block_q *)(srcA + args.nb01 * (ra + row) + offset0); + + SA_4x4 temp_a; + dequantize_func(row_ptr + block_idx, il, temp_a); + + FOR_UNROLL (short i = 0; i < 16; i++) { + // Zero-pad A for K positions beyond valid range (handles partial K iterations) + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (k_pos + i < K) ? temp_a[i/4][i%4] : (SA)0; + } + } + } else { + // Zero-pad rows beyond matrix bounds + FOR_UNROLL (short i = 0; i < 16; i++) { + sa[row * N_MM_NK_TOTAL + (k_base + i)] = (SA)0; + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // === PHASE 2: Tensor matmul === + auto mA = tA.slice(0, 0); + auto mB = tB.slice(loop_k, rb); + + mm.run(mB, mA, cT); + + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Store result tile to output matrix (with batch offset) + // cT.store handles bounds checking via tD's extents (M, N) + device float * dstBatch = (device float *)dst + im * N * M; + + auto tD = tensor(dstBatch, dextents<int32_t, 2>(M, N), array<int, 2>({1, M})); + cT.store(tD.slice(ra, rb)); +} + +#else + +template< + typename S0, typename S0_4x4, typename S0_8x8, + typename S1, typename S1_2x4, typename S1_8x8, + typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), + typename T0, typename T0_4x4, typename T1, typename T1_2x4> kernel void kernel_mul_mm( constant ggml_metal_kargs_mul_mm & args, device const char * src0, @@ -8797,8 +9557,6 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); - threadgroup float * sc = (threadgroup float *)(shmem); - constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -8822,10 +9580,10 @@ kernel void kernel_mul_mm( short il = il0; - const int i12 = im%args.ne12; - const int i13 = im/args.ne12; + const int i12 = im % FC_mul_mm_ne12; + const int i13 = im / FC_mul_mm_ne12; - const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (i12/FC_mul_mm_r2)*args.nb02 + (i13/FC_mul_mm_r3)*args.nb03; const short offset1 = il0/nl; device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1; @@ -8838,7 +9596,6 @@ kernel void kernel_mul_mm( + args.nb11*(r1 + lr1) + args.nb10*iy); -#ifndef GGML_METAL_HAS_TENSOR S0_8x8 ma[4]; S1_8x8 mb[2]; @@ -8847,19 +9604,8 @@ kernel void kernel_mul_mm( for (short i = 0; i < 8; i++){ mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f); } -#else - auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0)); - auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK )); - - mpp::tensor_ops::matmul2d< - mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate), - execution_simdgroups<4>> mm; - - auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>(); -#endif for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) { -#ifndef GGML_METAL_HAS_TENSOR // load data and store to threadgroup memory if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) { threadgroup_barrier(mem_flags::mem_threadgroup); @@ -8920,8 +9666,8 @@ kernel void kernel_mul_mm( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; @@ -8929,66 +9675,6 @@ kernel void kernel_mul_mm( *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y)); } -#else - // load data and store to threadgroup memory - if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) { - threadgroup_barrier(mem_flags::mem_threadgroup); - - // no need for dequantization - for (short i = 0; i < 16; i++) { - const short sx = 2*il0 + i/8; - const short sy = (tiitg/NL0)/8; - - const short lx = i%8; - const short ly = (tiitg/NL0)%8; - //const short lx = (tiitg/NL0)%8; - //const short ly = i%8; - - *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; - } - } else { - S0_4x4 temp_a; - dequantize_func(x, il, temp_a); - - threadgroup_barrier(mem_flags::mem_threadgroup); - - FOR_UNROLL (short i = 0; i < 16; i++) { - const short sx = 2*il0 + i/8; - const short sy = (tiitg/NL0)/8; - - const short lx = i%8; - const short ly = (tiitg/NL0)%8; - //const short lx = (tiitg/NL0)%8; - //const short ly = i%8; - - *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4]; - } - } - - if (FC_mul_mm_bc_inp) { - for (short i = 0; i < 8; ++i) { - const short sx = (tiitg%NL1); - const short sy = (tiitg/NL1)/8; - - const short lx = i; - const short ly = (tiitg/NL1)%8; - //const short lx = (tiitg/NL1)%8; - //const short ly = i; - - *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; - } - } else { - const short sx = (tiitg%NL1); - const short sy = (tiitg/NL1)/8; - - //const short lx = i; - const short ly = (tiitg/NL1)%8; - //const short lx = (tiitg/NL1)%8; - //const short ly = i; - - *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); - } -#endif il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? x + (2 + nl - 1)/nl : x; @@ -8997,7 +9683,6 @@ kernel void kernel_mul_mm( threadgroup_barrier(mem_flags::mem_threadgroup); -#ifndef GGML_METAL_HAS_TENSOR // load matrices from threadgroup memory and conduct outer products threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2)); threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2)); @@ -9024,24 +9709,10 @@ kernel void kernel_mul_mm( lsma += 8*64; lsmb += 4*64; } -#else - auto sA = tA.slice(0, 0); - auto sB = tB.slice(0, 0); - - mm.run(sB, sA, cT); -#endif } if (!FC_mul_mm_bc_out || (r0 + NR0 <= args.ne0 && r1 + NR1 <= args.ne1)) { // if no bounds checks on the output are needed, we can directly write to device memory -#ifdef GGML_METAL_HAS_TENSOR - device float * C = (device float *) dst + - r0 + \ - r1 * args.ne0 + im*args.ne1*args.ne0; - - auto tC = tensor<device float, dextents<int32_t, 2>, tensor_inline>(C, dextents<int32_t, 2>(args.ne0, NR1)); - cT.store(tC); -#else device float * C = (device float *) dst + (r0 + 32*(sgitg & 1)) + \ (r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; @@ -9049,21 +9720,15 @@ kernel void kernel_mul_mm( for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], C + 8*(i%4) + 8*args.ne0*(i/4), args.ne0, 0, false); } -#endif } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0; -#ifdef GGML_METAL_HAS_TENSOR - auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1)); - cT.store(tC); -#else for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false); } -#endif threadgroup_barrier(mem_flags::mem_threadgroup); @@ -9089,6 +9754,8 @@ kernel void kernel_mul_mm( } } +#endif // GGML_METAL_HAS_TENSOR + template<short ne20> // n_expert_used kernel void kernel_mul_mm_id_map0( constant ggml_metal_kargs_mul_mm_id_map0 & args, @@ -9153,6 +9820,7 @@ template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_ template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; template [[host_name("kernel_mul_mm_id_map0_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; +template [[host_name("kernel_mul_mm_id_map0_ne20_22")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<22>; template<typename S0, typename S0_4x4, typename S0_8x8, typename S1, typename S1_2x4, typename S1_8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread S0_4x4 &), typename T0, typename T0_4x4, typename T1, typename T1_2x4> kernel void kernel_mul_mm_id( @@ -9170,7 +9838,9 @@ kernel void kernel_mul_mm_id( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -9261,7 +9931,7 @@ kernel void kernel_mul_mm_id( const short ib = 8*sx + sy; - *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0; + *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? (S0) *((device T0 *) x + i) : (S0) 0; } } else { S0_4x4 temp_a; @@ -9305,8 +9975,8 @@ kernel void kernel_mul_mm_id( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; @@ -9474,6 +10144,7 @@ template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_ro typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t; +template [[host_name("kernel_get_rows_q1_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q1_0, 8, dequantize_q1_0>; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>; template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>; @@ -9536,6 +10207,7 @@ template [[host_name("kernel_mul_mm_f16_f32")]] kernel mul_mm_t kernel_mul_m #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_bf16_f32")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>; #endif +template [[host_name("kernel_mul_mm_q1_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>; template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>; @@ -9559,6 +10231,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_q1_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>; @@ -9591,6 +10264,7 @@ template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mul_mm_id kernel_m #if defined(GGML_METAL_HAS_BF16) template [[host_name("kernel_mul_mm_id_bf16_f32")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat, bfloat2x4, simdgroup_bfloat8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, float, float2x4>; #endif +template [[host_name("kernel_mul_mm_id_q1_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, float, float2x4>; template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, float, float2x4>; template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, float, float2x4>; template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, float, float2x4>; @@ -9614,6 +10288,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>; +template [[host_name("kernel_mul_mm_id_q1_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q1_0, 8, dequantize_q1_0, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>; @@ -9768,6 +10443,7 @@ template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4 template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>; +template [[host_name("kernel_mul_mv_id_q1_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q1_0_f32_impl<N_R0_Q1_0>>>; template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>; template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>; template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>; @@ -9869,6 +10545,74 @@ kernel void kernel_pool_2d_avg_f32( o_ptr[cur_oh * args.OW + cur_ow] = res; } + +kernel void kernel_pool_1d_max_f32( + constant ggml_metal_kargs_pool_1d & args, + device const float * src, + device float * dst, + uint gid [[thread_position_in_grid]] +) { + + if (gid >= args.np) { + return; + } + + const int ow = (int)gid % args.OW; + const int row = (int)gid / args.OW; + + const int base = ow * args.s0 - args.p0; + + float acc = -INFINITY; + + const int src_off = row * args.IW; + const int dst_off = row * args.OW; + + for (int ki = 0; ki < args.k0; ++ki) { + int j = base + ki; + if (j < 0 || j >= args.IW){ + continue; + } + float v = src[src_off + j]; + acc = max(acc, v); + } + + dst[dst_off + ow] = acc; +} + +kernel void kernel_pool_1d_avg_f32( + constant ggml_metal_kargs_pool_1d & args, + device const float * src, + device float * dst, + uint gid [[thread_position_in_grid]] +) { + + if (gid >= args.np) { + return; + } + + const int ow = (int)gid % args.OW; + const int row = (int)gid / args.OW; + + const int base = ow * args.s0 - args.p0; + + float acc = 0.0f; + int cnt = 0; + + const int src_off = row * args.IW; + const int dst_off = row * args.OW; + + for (int ki = 0; ki < args.k0; ++ki) { + const int j = base + ki; + if (j < 0 || j >= args.IW) { + continue; + } + acc += src[src_off + j]; + cnt += 1; + } + + dst[dst_off + ow] = (cnt > 0) ? (acc / (float)cnt) : 0.0f; +} + kernel void kernel_opt_step_adamw_f32( constant ggml_metal_kargs_opt_step_adamw & args, device float * x, @@ -9919,7 +10663,7 @@ kernel void kernel_opt_step_sgd_f32( template<typename T> kernel void kernel_memset( - constant ggml_metal_kargs_fill & args, + constant ggml_metal_kargs_memset & args, device T * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = args.val; diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index d76cb51977f..cc53c812ce5 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -48,12 +48,11 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_SOURCES_MUSA ${SRCS}) add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS) else() - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) + list(APPEND GGML_SOURCES_MUSA + ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu) endif() set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX) diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index d8fa53109b7..82ce61d72c6 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -57,22 +57,22 @@ set(GGML_OPENCL_KERNELS add add_id argsort + tri fill clamp cpy cvt diag_mask_inf + diag div gelu - gemv_noshuffle_general - gemv_noshuffle get_rows glu group_norm + solve_tri im2col_f32 im2col_f16 mean - mul_mat_Ab_Bi_8x4 mul_mv_f16_f16 mul_mv_f16_f32_1row mul_mv_f16_f32_l4 @@ -83,9 +83,22 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_0_f32_8x_flat mul_mv_q4_0_f32_1d_8x_flat mul_mv_q4_0_f32_1d_16x_flat - mul_mv_q6_k + mul_mv_q4_1_f32 + mul_mv_q4_1_f32_flat + mul_mv_q4_k_f32 + mul_mv_q4_k_f32_flat + mul_mv_q5_0_f32 + mul_mv_q5_0_f32_flat + mul_mv_q5_1_f32 + mul_mv_q5_1_f32_flat + mul_mv_q5_k_f32 + mul_mv_q5_k_f32_flat + mul_mv_q6_k_f32 + mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 mul_mv_q8_0_f32_flat + mul_mv_iq4_nl_f32 + mul_mv_iq4_nl_f32_flat mul_mv_mxfp4_f32 mul_mv_mxfp4_f32_flat mul_mv_id_q4_0_f32_8x_flat @@ -93,14 +106,61 @@ set(GGML_OPENCL_KERNELS mul_mv_id_q8_0_f32_flat mul_mv_id_mxfp4_f32 mul_mv_id_mxfp4_f32_flat + gemm_moe_q4_0_f32_ns + gemv_moe_q4_0_f32_ns + gemm_moe_q4_1_f32_ns + gemv_moe_q4_1_f32_ns + gemm_moe_q5_0_f32_ns + gemv_moe_q5_0_f32_ns + gemm_moe_q5_1_f32_ns + gemv_moe_q5_1_f32_ns + gemm_moe_q4_k_f32_ns + gemv_moe_q4_k_f32_ns + gemm_moe_q5_k_f32_ns + gemv_moe_q5_k_f32_ns + gemm_moe_q6_k_f32_ns + gemv_moe_q6_k_f32_ns gemm_moe_mxfp4_f32 gemv_moe_mxfp4_f32 + gemm_moe_mxfp4_f32_ns + gemv_moe_mxfp4_f32_ns + moe_reorder_b + moe_sort_by_expert mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm + mul_mm_q4_0_f32_l4_lm + mul_mm_q4_1_f32_l4_lm + mul_mm_q5_0_f32_l4_lm + mul_mm_q5_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_iq4_nl_f32_l4_lm + mul_mm_q4_k_f32_l4_lm + mul_mm_q5_k_f32_l4_lm + mul_mm_q6_k_f32_l4_lm + gemv_noshuffle_q4_0_f32 + gemv_noshuffle_q4_0_f32_spec + gemm_noshuffle_q4_0_f32 + gemv_noshuffle_q4_1_f32 + gemm_noshuffle_q4_1_f32 + gemv_noshuffle_q5_0_f32 + gemm_noshuffle_q5_0_f32 + gemv_noshuffle_q5_1_f32 + gemm_noshuffle_q5_1_f32 + gemv_noshuffle_iq4_nl_f32 + gemm_noshuffle_iq4_nl_f32 + gemv_noshuffle_q8_0_f32 + gemm_noshuffle_q8_0_f32 + gemv_noshuffle_q4_k_f32 + gemm_noshuffle_q4_k_f32 + gemv_noshuffle_q6_k_f32 + gemm_noshuffle_q6_k_f32 + gemv_noshuffle_q5_k_f32 + gemm_noshuffle_q5_k_f32 mul + neg norm relu + l2_norm rms_norm rope scale @@ -114,13 +174,16 @@ set(GGML_OPENCL_KERNELS sqr sqrt ssm_conv + gated_delta_net sub sum_rows + cumsum transpose concat tsembd upscale tanh + exp expm1 softplus pad @@ -134,6 +197,10 @@ set(GGML_OPENCL_KERNELS flash_attn_f32 ) +if (GGML_OPENCL_USE_ADRENO_KERNELS) + list(APPEND GGML_OPENCL_KERNELS gemm_xmem_f16_f32_os8) +endif () + foreach (K ${GGML_OPENCL_KERNELS}) ggml_opencl_add_kernel(${K}) endforeach() diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index d925f67f065..ca2002424df 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -28,6 +28,7 @@ #include <memory> #include <charconv> #include <mutex> +#include <regex> #undef MIN #undef MAX @@ -226,7 +227,8 @@ static ADRENO_GPU_GEN get_adreno_gpu_gen(const char *device_name) { return ADRENO_GPU_GEN::A7X; } - if (strstr(device_name, "830")) { + if (strstr(device_name, "830") || + strstr(device_name, "840")) { return ADRENO_GPU_GEN::A8X; } @@ -312,7 +314,7 @@ struct ProfilingInfo { cl_ulong cmd_duration_ns; // The time for the kernel to complete - COMPLETE - END cl_ulong cmd_complete_duration_ns; - // Total time to finish the kernel - COMPELTE - QUEUED + // Total time to finish the kernel - COMPLETE - QUEUED cl_ulong cmd_total_duration_ns; // Global and local work sizes. size_t global_size[3]; @@ -373,6 +375,13 @@ struct ggml_backend_opencl_device_context { ggml_backend_buffer_type buffer_type; cl_context context = nullptr; + + GPU_FAMILY gpu_family = GPU_FAMILY::UNKNOWN; + ADRENO_GPU_GEN adreno_gen = ADRENO_GPU_GEN::ADRENO_UNKNOWN; + + std::regex *opfilter = nullptr; // regex of ops to not claim + std::string opfilter_str = ""; // regex string for opfilter + size_t global_mem_size = 0; }; // backend context @@ -382,22 +391,44 @@ struct ggml_backend_opencl_context { cl_device_id device; std::string device_name; + ggml_cl_version platform_version; + ggml_cl_version opencl_c_version; + + // argsort is loaded in supports_op because its availability depends on how + // many workgroups are allowed, which requires kernel compilation. + bool kernels_loaded_argsort = false; + // flash attn is loaded in supports_op because it contains multiple variants + // and takes time to compile, so we want to only compile it when needed. + bool kernels_loaded_flash_attn = false; + // rest of the kernels are currently always loaded in alloc_buffer. + bool kernels_loaded = false; + std::string driver_version; GPU_FAMILY gpu_family; ADRENO_GPU_GEN adreno_gen; cl_int alignment; + size_t global_mem_size; size_t max_alloc_size; size_t max_workgroup_size; bool fp16_support; bool has_vector_subgroup_broadcast; + bool has_qcom_subgroup_shuffle = false; // cl_qcom_subgroup_shuffle bool disable_fusion; + + bool adreno_has_large_buffer; + bool adreno_use_large_buffer; ggml_cl_compiler_version adreno_cl_compiler_version; int adreno_wave_size; cl_bool non_uniform_workgroups; + size_t image_max_buffer_size; + size_t image2d_max_width; + size_t image2d_max_height; + + cl_device_svm_capabilities svm_caps; cl_context context; cl_command_queue queue; @@ -407,10 +438,27 @@ struct ggml_backend_opencl_context { ggml_cl_buffer prealloc_scales_trans; ggml_cl_buffer prealloc_act_trans; + // prealloc buffers for src0 and src1 + ggml_cl_buffer prealloc_src0; + ggml_cl_buffer prealloc_src1; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + ggml_cl_buffer prealloc_adreno_xmem_const; + bool adreno_xmem_gemm_enabled = false; +#endif + + // prealloc buffers for MoE router table preprocess + bool toggle_reorder = false; + ggml_cl_buffer prealloc_post_router; + ggml_cl_buffer prealloc_emap; + ggml_cl_buffer prealloc_hist; + ggml_cl_buffer prealloc_tile_offset; + ggml_cl_buffer prealloc_total_tiles; + ggml_cl_buffer prealloc_slot_counter; + cl_program program_add; cl_program program_add_id; cl_program program_clamp; - cl_program program_cpy; cl_program program_cvt; cl_program program_diag_mask_inf; cl_program program_gelu; @@ -447,7 +495,6 @@ struct ggml_backend_opencl_context { cl_program program_rms_norm; cl_program program_group_norm; cl_program program_rope; - cl_program program_scale; cl_program program_silu; cl_program program_sigmoid; cl_program program_softmax_f32; @@ -456,11 +503,8 @@ struct ggml_backend_opencl_context { cl_program program_softmax_4_f16; cl_program program_argsort_f32_i32; cl_program program_sum_rows_f32; - cl_program program_repeat; cl_program program_pad; - cl_program program_tanh; cl_program program_upscale; - cl_program program_concat; cl_program program_conv_2d_f16; cl_program program_conv_2d_f32; cl_program program_conv_2d_f16_f32; @@ -479,24 +523,27 @@ struct ggml_backend_opencl_context { cl_kernel kernel_div, kernel_div_row, kernel_div_f16, kernel_div_row_f16; cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16; cl_kernel kernel_add_id; - cl_kernel kernel_scale; + cl_kernel kernel_scale_f32, kernel_scale_f32_4; cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4; cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4; - cl_kernel kernel_mean_f32; + cl_kernel kernel_mean_f32, kernel_mean_f32_4; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; cl_kernel kernel_gelu_erf, kernel_gelu_erf_4; cl_kernel kernel_gelu_quick, kernel_gelu_quick_4; cl_kernel kernel_relu; cl_kernel kernel_sigmoid_f32, kernel_sigmoid_f16; + cl_kernel kernel_tri; cl_kernel kernel_fill; cl_kernel kernel_clamp; cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_swiglu_oai, kernel_geglu_erf, kernel_geglu_quick, kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16; cl_kernel kernel_norm, kernel_norm_mul_add; cl_kernel kernel_rms_norm, kernel_rms_norm_mul; + cl_kernel kernel_l2_norm_f32; cl_kernel kernel_group_norm, kernel_group_norm_mul_add; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; + cl_kernel kernel_diag_f32; cl_kernel kernel_soft_max, kernel_soft_max_4; cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16; @@ -511,61 +558,133 @@ struct ggml_backend_opencl_context { cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; - cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; + cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_f32_f32_pack, kernel_cpy_i32_i32; cl_kernel kernel_mul_mat_f32_f32; cl_kernel kernel_mul_mat_f16_f16; cl_kernel kernel_mul_mat_f16_f32_1row; cl_kernel kernel_mul_mat_f16_f32; cl_kernel kernel_mul_mat_f16_f32_l4; cl_kernel kernel_mul_mat_f16_f32_tiled; + cl_kernel kernel_adreno_xmem_pack_src_f32; + cl_kernel kernel_adreno_xmem_prepack_weight_f16; + cl_kernel kernel_gemm_xmem_f16_f32_os8; + cl_kernel kernel_adreno_xmem_store_dst_f32; cl_kernel kernel_mul_mm_f16_f32_kqv; cl_kernel kernel_mul_mm_f16_f32_kq; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; + cl_kernel kernel_convert_block_q4_0_trans4_ns, kernel_restore_block_q4_0_trans4_ns; + cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; + cl_kernel kernel_convert_block_q4_1_trans4_ns, kernel_restore_block_q4_1_trans4_ns; + cl_kernel kernel_convert_block_q5_0, kernel_restore_block_q5_0; + cl_kernel kernel_convert_block_q5_0_trans4_ns, kernel_restore_block_q5_0_trans4_ns; + cl_kernel kernel_convert_block_q5_1, kernel_restore_block_q5_1; + cl_kernel kernel_convert_block_q5_1_trans4_ns, kernel_restore_block_q5_1_trans4_ns; + cl_kernel kernel_convert_block_q4_k_trans4_ns, kernel_restore_block_q4_k_trans4_ns; + cl_kernel kernel_convert_block_q5_k_trans4_ns, kernel_restore_block_q5_k_trans4_ns; + cl_kernel kernel_convert_block_q6_k_trans4_ns, kernel_restore_block_q6_k_trans4_ns; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; - cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0; + cl_kernel kernel_convert_block_mxfp4_trans4_ns, kernel_restore_block_mxfp4_trans4_ns; + cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; + cl_kernel kernel_convert_block_q6_K_noshuffle, kernel_restore_block_q6_K_noshuffle; + cl_kernel kernel_convert_bf16_to_f16, kernel_convert_f16_to_bf16; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; + cl_kernel kernel_convert_block_q4_1_noshuffle; + cl_kernel kernel_restore_block_q4_1_noshuffle; + cl_kernel kernel_convert_block_q5_0_noshuffle; + cl_kernel kernel_restore_block_q5_0_noshuffle; + cl_kernel kernel_convert_block_q5_1_noshuffle; + cl_kernel kernel_restore_block_q5_1_noshuffle; + cl_kernel kernel_convert_block_q4_K_noshuffle; + cl_kernel kernel_restore_block_q4_K_noshuffle; + cl_kernel kernel_convert_block_q4_K, kernel_restore_block_q4_K; + cl_kernel kernel_convert_block_q5_K, kernel_restore_block_q5_K; + cl_kernel kernel_convert_block_q5_K_noshuffle; + cl_kernel kernel_restore_block_q5_K_noshuffle; + cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; + cl_kernel kernel_convert_block_iq4_nl, kernel_restore_block_iq4_nl; + cl_kernel kernel_convert_block_iq4_nl_noshuffle; + cl_kernel kernel_restore_block_iq4_nl_noshuffle; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; + cl_kernel kernel_mul_mv_q4_1_f32; + cl_kernel kernel_mul_mv_q4_1_f32_flat; + cl_kernel kernel_mul_mv_q5_0_f32; + cl_kernel kernel_mul_mv_q5_0_f32_flat; + cl_kernel kernel_mul_mv_q5_1_f32; + cl_kernel kernel_mul_mv_q5_1_f32_flat; + cl_kernel kernel_mul_mv_q4_K_f32; + cl_kernel kernel_mul_mv_q4_K_f32_flat; + cl_kernel kernel_mul_mv_q5_K_f32; + cl_kernel kernel_mul_mv_q5_K_f32_flat; cl_kernel kernel_mul_mv_q6_K_f32; + cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; cl_kernel kernel_mul_mv_q8_0_f32, kernel_mul_mv_q8_0_f32_flat; + cl_kernel kernel_mul_mv_iq4_nl_f32; + cl_kernel kernel_mul_mv_iq4_nl_f32_flat; + cl_kernel kernel_solve_tri_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; - cl_kernel kernel_sum_rows_f32; - cl_kernel kernel_repeat; + cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4; + cl_kernel kernel_cumsum_blk, kernel_cumsum_add; + cl_kernel kernel_repeat_f32; cl_kernel kernel_pad; - cl_kernel kernel_tanh_f32_nd; - cl_kernel kernel_tanh_f16_nd; - cl_kernel kernel_expm1_f32_nd; - cl_kernel kernel_expm1_f16_nd; - cl_kernel kernel_softplus_f32_nd; - cl_kernel kernel_softplus_f16_nd; + cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc; + cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc; + cl_kernel kernel_neg_f32, kernel_neg_f32_4, kernel_neg_f32_nc; + cl_kernel kernel_neg_f16, kernel_neg_f16_4, kernel_neg_f16_nc; + cl_kernel kernel_exp_f32, kernel_exp_f32_4, kernel_exp_f32_nc; + cl_kernel kernel_exp_f16, kernel_exp_f16_4, kernel_exp_f16_nc; + cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc; + cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc; + cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc; + cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc; cl_kernel kernel_upscale; cl_kernel kernel_upscale_bilinear; - cl_kernel kernel_concat_f32_contiguous; - cl_kernel kernel_concat_f32_non_contiguous; + cl_kernel kernel_concat_f32, kernel_concat_f32_pack; cl_kernel kernel_conv_2d_f16; cl_kernel kernel_conv_2d_f32; cl_kernel kernel_conv_2d_f16_f32; cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4; + // [size_idx][kda][tgpp] where size_idx: 0=S_V=16, 1=32, 2=64, 3=128; kda: 0 or 1. + // tgpp 0 = TG variant (COLS_PER_LANE_GROUP=1), tgpp 1 = prefill variant (COLS_PER_LANE_GROUP=4). + cl_kernel kernel_gated_delta_net_f32[4][2][2] = {}; + cl_kernel kernel_timestep_embedding; + cl_kernel kernel_gemv_moe_q4_0_f32_ns, kernel_gemm_moe_q4_0_f32_ns; + cl_kernel kernel_gemv_moe_q4_1_f32_ns, kernel_gemm_moe_q4_1_f32_ns; + cl_kernel kernel_gemv_moe_q5_0_f32_ns, kernel_gemm_moe_q5_0_f32_ns; + cl_kernel kernel_gemv_moe_q5_1_f32_ns, kernel_gemm_moe_q5_1_f32_ns; + cl_kernel kernel_gemv_moe_q4_k_f32_ns, kernel_gemm_moe_q4_k_f32_ns; + cl_kernel kernel_gemv_moe_q5_k_f32_ns, kernel_gemm_moe_q5_k_f32_ns; + cl_kernel kernel_gemv_moe_q6_k_f32_ns, kernel_gemm_moe_q6_k_f32_ns; cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32; + cl_kernel kernel_gemv_moe_mxfp4_f32_ns, kernel_gemm_moe_mxfp4_f32_ns; + cl_kernel kernel_moe_reorder_b; + cl_kernel kernel_moe_histogram, kernel_moe_scan, kernel_moe_fill, kernel_moe_scatter; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; cl_kernel kernel_mul_mv_id_q8_0_f32, kernel_mul_mv_id_q8_0_f32_flat; cl_kernel kernel_mul_mv_id_mxfp4_f32; cl_kernel kernel_mul_mv_id_mxfp4_f32_flat; cl_kernel kernel_mul_mm_f32_f32_l4_lm; cl_kernel kernel_mul_mm_f16_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_k_f32_l4_lm; + cl_kernel kernel_mul_mm_q5_k_f32_l4_lm; + cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; + cl_kernel kernel_mul_mm_iq4_nl_f32_l4_lm; std::vector<ProfilingInfo> profiling_info; + std::vector<ProfilingInfo> profiling_results; - void write_profiling_info() { - FILE * fperf = fopen("cl_profiling.csv", "w"); - if (!fperf) { - GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + void flush_profiling_batch() { + if (profiling_info.empty()) { return; } @@ -589,6 +708,7 @@ struct ggml_backend_opencl_context { CL_CHECK(clGetEventProfilingInfo( info.evt, CL_PROFILING_COMMAND_COMPLETE, sizeof(cl_ulong), &cmd_complete, NULL)); CL_CHECK(clReleaseEvent(info.evt)); + info.evt = nullptr; char kernel_name[512]; CL_CHECK(clGetKernelInfo(info.kernel, CL_KERNEL_FUNCTION_NAME, @@ -606,10 +726,26 @@ struct ggml_backend_opencl_context { info.cmd_complete_duration_ns = cmd_complete - cmd_end; info.cmd_total_duration_ns = cmd_complete - cmd_queued; } + profiling_results.insert(profiling_results.end(), + std::make_move_iterator(profiling_info.begin()), + std::make_move_iterator(profiling_info.end())); + profiling_info.clear(); + } + + void write_profiling_info() { + if (profiling_results.empty()) { + return; + } // Dump a csv + FILE * fperf = fopen("cl_profiling.csv", "w"); + if (!fperf) { + GGML_LOG_ERROR("Failed to open cl_profiling.csv\n"); + return; + } + fprintf(fperf, "op name, kernel name, exec duration (ms), global size, local size, output size\n"); - for (const ProfilingInfo & info : profiling_info) { + for (const ProfilingInfo & info : profiling_results) { fprintf(fperf, "%s,%s,%f,%zux%zux%zu,%zux%zux%zu,%zux%zux%zux%zu\n", info.op_name.c_str(), info.kernel_name.c_str(), info.cmd_duration_ns/1.e6f, @@ -620,14 +756,14 @@ struct ggml_backend_opencl_context { fclose(fperf); // Dump a simple chrome trace - FILE* ftrace = fopen("cl_trace.json", "w"); + FILE * ftrace = fopen("cl_trace.json", "w"); if (!ftrace) { GGML_LOG_ERROR("Failed to open cl_trace.json\n"); return; } fprintf(ftrace, "[\n"); - for (const ProfilingInfo & info : profiling_info) { + for (const ProfilingInfo & info : profiling_results) { fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"B\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", info.kernel_name.c_str(), info.cmd_queued/1000); fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Host\"},\n", @@ -638,6 +774,7 @@ struct ggml_backend_opencl_context { fprintf(ftrace, "{\"name\": \"%s\", \"cat\": \"OpenCL\", \"ph\": \"E\", \"ts\": %" PRIu64 ", \"pid\": \"\", \"tid\": \"Device\"},\n", info.kernel_name.c_str(), info.cmd_end/1000); } + fprintf(ftrace, "]\n"); fclose(ftrace); } @@ -658,6 +795,9 @@ struct ggml_backend_opencl_context { profiling_info.emplace_back(); populateProfilingInfo(profiling_info.back(), evt, kernel, work_dim, global_work_size, local_work_size, tensor); + if (profiling_info.size() >= 2048) { + flush_profiling_batch(); + } #else GGML_UNUSED(tensor); CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, work_dim, NULL, global_work_size, local_work_size, 0, NULL, NULL)); @@ -671,30 +811,44 @@ struct ggml_backend_opencl_context { cl_kernel kernel_transpose_32; cl_kernel kernel_transpose_32_16; cl_kernel kernel_transpose_16; + cl_kernel kernel_transpose_8_buf; cl_kernel kernel_transpose_16_buf; + cl_kernel kernel_transpose_32_buf; cl_kernel kernel_transpose_16_4x1; // Gemm and Gemv related programs, kernels, etc - cl_program program_CL_gemm; - cl_program program_CL_gemv_general; - cl_program program_CL_gemv_4096_1_11008; - cl_program program_CL_gemv_4096_1_4096; - cl_program program_CL_gemv_11008_1_4096; - cl_program program_CL_gemv_32000_1_4096; - cl_kernel CL_mul_mat_Ab_Bi_8x4; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; - cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; + cl_kernel kernel_gemm_noshuffle_q4_0_f32; + cl_kernel kernel_gemv_noshuffle_q4_0_f32; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_4096_1_11008; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_4096_1_4096; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_11008_1_4096; + cl_kernel kernel_gemv_noshuffle_q4_0_f32_32000_1_4096; + cl_kernel kernel_gemv_noshuffle_q4_1_f32; + cl_kernel kernel_gemm_noshuffle_q4_1_f32; + cl_kernel kernel_gemm_noshuffle_q8_0_f32; + cl_kernel kernel_gemv_noshuffle_q8_0_f32; + cl_kernel kernel_gemv_noshuffle_q4_k_f32; + cl_kernel kernel_gemm_noshuffle_q4_k_f32; + cl_kernel kernel_gemv_noshuffle_q6_K_f32; + cl_kernel kernel_gemm_noshuffle_q6_K_f32; + cl_kernel kernel_gemv_noshuffle_q5_k_f32; + cl_kernel kernel_gemm_noshuffle_q5_k_f32; + cl_kernel kernel_gemv_noshuffle_q5_0_f32; + cl_kernel kernel_gemm_noshuffle_q5_0_f32; + cl_kernel kernel_gemv_noshuffle_q5_1_f32; + cl_kernel kernel_gemm_noshuffle_q5_1_f32; + cl_kernel kernel_gemv_noshuffle_iq4_nl_f32; + cl_kernel kernel_gemm_noshuffle_iq4_nl_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS void free() { + clFinish(queue); + ref_count--; if (ref_count == 0) { #ifdef GGML_OPENCL_PROFILING write_profiling_info(); - profiling_info.clear(); + profiling_results.clear(); #endif } } @@ -702,18 +856,21 @@ struct ggml_backend_opencl_context { // All registered devices with a default device in the front. static std::vector<ggml_backend_device> g_ggml_backend_opencl_devices; +// All device contexts associated with the devices above. +// The devices live as long as the process, so do the contexts. +static std::vector<std::unique_ptr<ggml_backend_opencl_device_context>> g_ggml_backend_opencl_dev_ctxs; inline std::string read_file(const std::string &path) { - std::ifstream ifs(path); - if (!ifs) { - return ""; - } - std::string text; - ifs.seekg(0, std::ios::end); - text.resize(ifs.tellg()); - ifs.seekg(0, std::ios::beg); - ifs.read(&text[0], text.size()); - return text; + std::ifstream ifs(path); + if (!ifs) { + return ""; + } + std::string text; + ifs.seekg(0, std::ios::end); + text.resize(ifs.tellg()); + ifs.seekg(0, std::ios::beg); + ifs.read(&text[0], text.size()); + return text; } static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, const char* program_buffer, const std::string &compile_opts) { @@ -745,16 +902,128 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co return p; } -static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_version opencl_c_version) { +static void load_cl_kernels_argsort(ggml_backend_opencl_context *backend_ctx) { + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + // argsort + if (!backend_ctx->kernels_loaded_argsort) { + cl_int err; +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "argsort.cl.h" + }; +#else + const std::string kernel_src = read_file("argsort.cl"); +#endif + backend_ctx->program_argsort_f32_i32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); + backend_ctx->kernels_loaded_argsort = true; + } +} + +static void load_cl_kernels_flash_attn(ggml_backend_opencl_context *backend_ctx) { + // compiler options for general kernels + auto opencl_c_std = + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-unsafe-math-optimizations" + " -cl-finite-math-only -cl-fast-relaxed-math"; + + // flash_attn + if (!backend_ctx->kernels_loaded_flash_attn) { + cl_int err; + + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_f16 { + #include "flash_attn_f16.cl.h" + }; + const std::string kernel_src_f32 { + #include "flash_attn_f32.cl.h" + }; + const std::string kernel_src_f32_f16 { + #include "flash_attn_f32_f16.cl.h" + }; + #else + const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); + const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); + const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); + #endif + + if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { + const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { + { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, + {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, + {192, 192, 16, 16}, {256, 256, 16, 16}, + }; + + for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { + const int dk = fa_dims[i].dk; + const int dv = fa_dims[i].dv; + const int bm = fa_dims[i].bm; + const int bn = fa_dims[i].bn; + std::string OPTS = compile_opts + + " -D DK=" + std::to_string(dk) + + " -D DV=" + std::to_string(dv) + + " -D BLOCK_M=" + std::to_string(bm) + + " -D BLOCK_N=" + std::to_string(bn); + + cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); + cl_kernel k_f16, k_f16_q1; + CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); + CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); + backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; + backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; + CL_CHECK(clReleaseProgram(prog_f16)); + + cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); + cl_kernel k_f32, k_f32_q1; + CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); + CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); + backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; + backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; + CL_CHECK(clReleaseProgram(prog_f32)); + + cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); + cl_kernel k_f32_f16, k_f32_f16_q1; + CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); + CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); + backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; + backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; + CL_CHECK(clReleaseProgram(prog_f32_f16)); + + backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; + backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; + } + backend_ctx->kernels_loaded_flash_attn = true; + } + } +} + +static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx) { + if (backend_ctx->kernels_loaded) { + return; + } + cl_int err; // compiler options for general kernels auto opencl_c_std = - std::string("CL") + std::to_string(opencl_c_version.major) + "." + std::to_string(opencl_c_version.minor); + std::string("CL") + std::to_string(backend_ctx->opencl_c_version.major) + "." + std::to_string(backend_ctx->opencl_c_version.minor); std::string compile_opts = std::string("-cl-std=") + opencl_c_std + " -cl-mad-enable -cl-unsafe-math-optimizations" " -cl-finite-math-only -cl-fast-relaxed-math"; + if (backend_ctx->adreno_use_large_buffer) { + compile_opts += " -qcom-enable-large-buffer "; + } + GGML_LOG_INFO("ggml_opencl: loading OpenCL kernels"); // add @@ -792,6 +1061,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // tri + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "tri.cl.h" + }; +#else + const std::string kernel_src = read_file("tri.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_tri = clCreateKernel(prog, "kernel_tri_f32", &err), err)); + GGML_LOG_CONT("."); + + CL_CHECK(clReleaseProgram(prog)); + } + // fill { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -835,13 +1122,15 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("cpy.cl"); #endif - backend_ctx->program_cpy = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(prog, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(prog, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(prog, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(prog, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32_pack = clCreateKernel(prog, "kernel_cpy_f32_f32_pack", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_i32_i32 = clCreateKernel(prog, "kernel_cpy_i32_i32", &err), err)); GGML_LOG_CONT("."); } @@ -861,12 +1150,59 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_0_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_0_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_1_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_1_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_k_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_k_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_k_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans4_ns", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans4_ns = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans4_ns", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q8_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q8_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q8_0_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q8_0_trans", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q5_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q5_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q5_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q5_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_K = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q6_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q6_K_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q6_K_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_iq4_nl_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_iq4_nl_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_iq4_nl_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_bf16_to_f16 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_bf16_to_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_f16_to_bf16 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_f16_to_bf16", &err), err)); GGML_LOG_CONT("."); } @@ -887,6 +1223,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // diag + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "diag.cl.h" + }; +#else + const std::string kernel_src = read_file("diag.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_diag_f32 = clCreateKernel(prog, "kernel_diag_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gelu { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -952,6 +1305,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // solve_tri_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "solve_tri.cl.h" + }; +#else + const std::string kernel_src = read_file("solve_tri.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_solve_tri_f32 = clCreateKernel(prog, "kernel_solve_tri_f32", &err), err)); + GGML_LOG_CONT("."); + CL_CHECK(clReleaseProgram(prog)); + } + // im2col_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1072,3608 +1442,7741 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } - // mul_mv_q6_k + // mul_mv_q4_1_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_q6_k.cl.h" + #include "mul_mv_q4_1_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_q6_k.cl"); + const std::string kernel_src = read_file("mul_mv_q4_1_f32.cl"); #endif - backend_ctx->program_mul_mv_q6_K = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_mul_mv_q6_K, "kernel_mul_mv_q6_K_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mv_q8_0_f32 + // mul_mv_q4_1_f32_flat { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_q8_0_f32.cl.h" + #include "mul_mv_q4_1_f32_flat.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_q8_0_f32.cl"); + const std::string kernel_src = read_file("mul_mv_q4_1_f32_flat.cl"); #endif - backend_ctx->program_mul_mv_q8_0_f32 = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mv_q8_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_q8_0_f32, "kernel_mul_mv_q8_0_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mv_q8_0_f32_flat + // mul_mv_q4_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_q8_0_f32_flat.cl.h" + #include "mul_mv_q4_k_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_q8_0_f32_flat.cl"); + const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl"); #endif - backend_ctx->program_mul_mv_q8_0_f32_flat = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mv_q8_0_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_q8_0_f32_flat, "kernel_mul_mv_q8_0_f32_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mv_mxfp4_f32 + // mul_mv_q4_k_f32_flat { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_mxfp4_f32.cl.h" + #include "mul_mv_q4_k_f32_flat.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_mxfp4_f32.cl"); + const std::string kernel_src = read_file("mul_mv_q4_k_f32_flat.cl"); #endif - backend_ctx->program_mul_mv_mxfp4_f32 = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32, "kernel_mul_mv_mxfp4_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mv_mxfp4_f32_flat + // mul_mv_q5_0_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_mxfp4_f32_flat.cl.h" + #include "mul_mv_q5_0_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_mxfp4_f32_flat.cl"); + const std::string kernel_src = read_file("mul_mv_q5_0_f32.cl"); #endif - backend_ctx->program_mul_mv_mxfp4_f32_flat = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32_flat, "kernel_mul_mv_mxfp4_f32_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q5_0_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mv_f16_f16 + // mul_mv_q5_0_f32_flat { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_f16_f16.cl.h" + #include "mul_mv_q5_0_f32_flat.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_f16_f16.cl"); + const std::string kernel_src = read_file("mul_mv_q5_0_f32_flat.cl"); #endif - backend_ctx->program_mul_mv_f16_f16 = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program_mul_mv_f16_f16, "kernel_mul_mat_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q5_0_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_0_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mv_f16_f32_1row + // mul_mv_q5_1_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_f16_f32_1row.cl.h" + #include "mul_mv_q5_1_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_f16_f32_1row.cl"); + const std::string kernel_src = read_file("mul_mv_q5_1_f32.cl"); #endif - backend_ctx->program_mul_mv_f16_f32_1row = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_1row, "kernel_mul_mat_f16_f32_1row", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q5_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mv_f16_f32_l4 + // mul_mv_q5_1_f32_flat { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_f16_f32_l4.cl.h" + #include "mul_mv_q5_1_f32_flat.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_f16_f32_l4.cl"); + const std::string kernel_src = read_file("mul_mv_q5_1_f32_flat.cl"); #endif - backend_ctx->program_mul_mv_f16_f32_l4 = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_l4, "kernel_mul_mat_f16_f32_l4", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q5_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_1_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mv_f16_f32 + // mul_mv_q5_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_f16_f32.cl.h" + #include "mul_mv_q5_k_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_f16_f32.cl"); + const std::string kernel_src = read_file("mul_mv_q5_k_f32.cl"); #endif - backend_ctx->program_mul_mv_f16_f32 = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32, "kernel_mul_mat_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q5_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q5_K_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mv_f32_f32 + // mul_mv_q5_k_f32_flat { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_f32_f32.cl.h" + #include "mul_mv_q5_k_f32_flat.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_f32_f32.cl"); + const std::string kernel_src = read_file("mul_mv_q5_k_f32_flat.cl"); #endif - backend_ctx->program_mul_mv_f32_f32 = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program_mul_mv_f32_f32, "kernel_mul_mat_f32_f32", &err), err)); - GGML_LOG_CONT("."); + CL_CHECK((backend_ctx->kernel_mul_mv_q5_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q5_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); } - // mul_mat_f16_f32_tiled + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mat_f16_f32.cl.h" + #include "mul_mv_q6_k_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mat_f16_f32.cl"); + const std::string kernel_src = read_file("mul_mv_q6_k_f32.cl"); #endif - backend_ctx->program_mul_mat_f16_f32_tiled = + backend_ctx->program_mul_mv_q6_K = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, "mul_mat_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32 = clCreateKernel(backend_ctx->program_mul_mv_q6_K, "kernel_mul_mv_q6_K_f32", &err), err)); GGML_LOG_CONT("."); } - // mul_mm_f32_f32_l4_lm + // mul_mv_q6_k_f32_flat { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mm_f32_f32_l4_lm.cl.h" + #include "mul_mv_q6_k_f32_flat.cl.h" }; #else - const std::string kernel_src = read_file("mul_mm_f32_f32_l4_lm.cl"); + const std::string kernel_src = read_file("mul_mv_q6_k_f32_flat.cl"); #endif - backend_ctx->program_mul_mm_f32_f32_l4_lm = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mm_f32_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f32_f32_l4_lm, "kernel_mul_mm_f32_f32_l4_lm", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q6_K_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q6_K_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mm_f16_f32_l4_lm + // mul_mv_q8_0_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mm_f16_f32_l4_lm.cl.h" + #include "mul_mv_q8_0_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mm_f16_f32_l4_lm.cl"); + const std::string kernel_src = read_file("mul_mv_q8_0_f32.cl"); #endif - backend_ctx->program_mul_mm_f16_f32_l4_lm = + backend_ctx->program_mul_mv_q8_0_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_l4_lm, "kernel_mul_mm_f16_f32_l4_lm", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q8_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_q8_0_f32, "kernel_mul_mv_q8_0_f32", &err), err)); GGML_LOG_CONT("."); } - // mul_mm_q8_0_f32_l4_lm + // mul_mv_q8_0_f32_flat { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mm_q8_0_f32_l4_lm.cl.h" + #include "mul_mv_q8_0_f32_flat.cl.h" }; #else - const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl"); + const std::string kernel_src = read_file("mul_mv_q8_0_f32_flat.cl"); #endif - backend_ctx->program_mul_mm_q8_0_f32_l4_lm = + backend_ctx->program_mul_mv_q8_0_f32_flat = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_q8_0_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_q8_0_f32_flat, "kernel_mul_mv_q8_0_f32_flat", &err), err)); GGML_LOG_CONT("."); } - // mul_mm_f16_f32_kq_kqv + // mul_mv_iq4_nl_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mm_f16_f32_kq_kqv.cl.h" + #include "mul_mv_iq4_nl_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mm_f16_f32_kq_kqv.cl"); + const std::string kernel_src = read_file("mul_mv_iq4_nl_f32.cl"); #endif - backend_ctx->program_mul_mm_f16_f32_kqv = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts+" -DKQV "); - backend_ctx->program_mul_mm_f16_f32_kq = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kqv = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kqv, "mul_mm_f16_f32_kqv", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kq = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kq, "mul_mm_f16_f32_kq", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_iq4_nl_f32 = clCreateKernel(prog, "kernel_mul_mv_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul + // mul_mv_iq4_nl_f32_flat { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul.cl.h" + #include "mul_mv_iq4_nl_f32_flat.cl.h" }; #else - const std::string kernel_src = read_file("mul.cl"); + const std::string kernel_src = read_file("mul_mv_iq4_nl_f32_flat.cl"); #endif - backend_ctx->program_mul = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_mul_row_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_iq4_nl_f32_flat = clCreateKernel(prog, "kernel_mul_mv_iq4_nl_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // norm + // mul_mv_mxfp4_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "norm.cl.h" + #include "mul_mv_mxfp4_f32.cl.h" }; #else - const std::string kernel_src = read_file("norm.cl"); + const std::string kernel_src = read_file("mul_mv_mxfp4_f32.cl"); #endif - backend_ctx->program_norm = + backend_ctx->program_mul_mv_mxfp4_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32, "kernel_mul_mv_mxfp4_f32", &err), err)); GGML_LOG_CONT("."); } - // relu + // mul_mv_mxfp4_f32_flat { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "relu.cl.h" + #include "mul_mv_mxfp4_f32_flat.cl.h" }; #else - const std::string kernel_src = read_file("relu.cl"); + const std::string kernel_src = read_file("mul_mv_mxfp4_f32_flat.cl"); #endif - backend_ctx->program_relu = + backend_ctx->program_mul_mv_mxfp4_f32_flat = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program_relu, "kernel_relu", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mv_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_mxfp4_f32_flat, "kernel_mul_mv_mxfp4_f32_flat", &err), err)); GGML_LOG_CONT("."); } - // rms_norm + // mul_mv_f16_f16 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "rms_norm.cl.h" + #include "mul_mv_f16_f16.cl.h" }; #else - const std::string kernel_src = read_file("rms_norm.cl"); + const std::string kernel_src = read_file("mul_mv_f16_f16.cl"); #endif - backend_ctx->program_rms_norm = + backend_ctx->program_mul_mv_f16_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f16 = clCreateKernel(backend_ctx->program_mul_mv_f16_f16, "kernel_mul_mat_f16_f16", &err), err)); GGML_LOG_CONT("."); } - // rope + // mul_mv_f16_f32_1row { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "rope.cl.h" + #include "mul_mv_f16_f32_1row.cl.h" }; #else - const std::string kernel_src = read_file("rope.cl"); + const std::string kernel_src = read_file("mul_mv_f16_f32_1row.cl"); #endif - backend_ctx->program_rope = + backend_ctx->program_mul_mv_f16_f32_1row = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_1row = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_1row, "kernel_mul_mat_f16_f32_1row", &err), err)); GGML_LOG_CONT("."); } - // scale + // mul_mv_f16_f32_l4 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "scale.cl.h" + #include "mul_mv_f16_f32_l4.cl.h" }; #else - const std::string kernel_src = read_file("scale.cl"); + const std::string kernel_src = read_file("mul_mv_f16_f32_l4.cl"); #endif - backend_ctx->program_scale = + backend_ctx->program_mul_mv_f16_f32_l4 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_scale = clCreateKernel(backend_ctx->program_scale, "kernel_scale", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_l4 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32_l4, "kernel_mul_mat_f16_f32_l4", &err), err)); GGML_LOG_CONT("."); } - // silu + // mul_mv_f16_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "silu.cl.h" + #include "mul_mv_f16_f32.cl.h" }; #else - const std::string kernel_src = read_file("silu.cl"); + const std::string kernel_src = read_file("mul_mv_f16_f32.cl"); #endif - backend_ctx->program_silu = + backend_ctx->program_mul_mv_f16_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program_silu, "kernel_silu", &err), err)); - CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program_silu, "kernel_silu_4", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32 = clCreateKernel(backend_ctx->program_mul_mv_f16_f32, "kernel_mul_mat_f16_f32", &err), err)); GGML_LOG_CONT("."); } - // softmax_f32 + // mul_mv_f32_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "softmax_f32.cl.h" + #include "mul_mv_f32_f32.cl.h" }; #else - const std::string kernel_src = read_file("softmax_f32.cl"); + const std::string kernel_src = read_file("mul_mv_f32_f32.cl"); #endif - backend_ctx->program_softmax_f32 = + backend_ctx->program_mul_mv_f32_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program_softmax_f32, "kernel_soft_max", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f32_f32 = clCreateKernel(backend_ctx->program_mul_mv_f32_f32, "kernel_mul_mat_f32_f32", &err), err)); GGML_LOG_CONT("."); } - // softmax_f16 + // mul_mat_f16_f32_tiled { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "softmax_f16.cl.h" + #include "mul_mat_f16_f32.cl.h" }; #else - const std::string kernel_src = read_file("softmax_f16.cl"); + const std::string kernel_src = read_file("mul_mat_f16_f32.cl"); #endif - backend_ctx->program_softmax_f16 = + backend_ctx->program_mul_mat_f16_f32_tiled = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program_softmax_f16, "kernel_soft_max_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mat_f16_f32_tiled = clCreateKernel(backend_ctx->program_mul_mat_f16_f32_tiled, "mul_mat_f16_f32", &err), err)); GGML_LOG_CONT("."); } - // softmax_4_f32 +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // gemm_xmem_f16_f32_os8 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "softmax_4_f32.cl.h" + #include "gemm_xmem_f16_f32_os8.cl.h" }; #else - const std::string kernel_src = read_file("softmax_4_f32.cl"); + const std::string kernel_src = read_file("gemm_xmem_f16_f32_os8.cl"); #endif - backend_ctx->program_softmax_4_f32 = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program_softmax_4_f32, "kernel_soft_max_4", &err), err)); + CL_CHECK((backend_ctx->kernel_adreno_xmem_pack_src_f32 = + clCreateKernel(prog, "adreno_xmem_pack_src_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_adreno_xmem_prepack_weight_f16 = + clCreateKernel(prog, "adreno_xmem_prepack_weight_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_gemm_xmem_f16_f32_os8 = + clCreateKernel(prog, "kernel_gemm_xmem_f16_f32_os8", &err), err)); + CL_CHECK((backend_ctx->kernel_adreno_xmem_store_dst_f32 = + clCreateKernel(prog, "adreno_xmem_store_dst_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS - // softmax_4_f16 + // mul_mm_f32_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "softmax_4_f16.cl.h" + #include "mul_mm_f32_f32_l4_lm.cl.h" }; #else - const std::string kernel_src = read_file("softmax_4_f16.cl"); + const std::string kernel_src = read_file("mul_mm_f32_f32_l4_lm.cl"); #endif - backend_ctx->program_softmax_4_f16 = + backend_ctx->program_mul_mm_f32_f32_l4_lm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program_softmax_4_f16, "kernel_soft_max_4_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mm_f32_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f32_f32_l4_lm, "kernel_mul_mm_f32_f32_l4_lm", &err), err)); GGML_LOG_CONT("."); } - // flash_attn - { - #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_f16 { - #include "flash_attn_f16.cl.h" - }; - const std::string kernel_src_f32 { - #include "flash_attn_f32.cl.h" - }; - const std::string kernel_src_f32_f16 { - #include "flash_attn_f32_f16.cl.h" - }; - #else - const std::string kernel_src_f16 = read_file("flash_attn_f16.cl"); - const std::string kernel_src_f32 = read_file("flash_attn_f32.cl"); - const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl"); - #endif - - if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { - const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { - { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, - {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, - {192, 192, 16, 16}, {256, 256, 16, 16}, - }; - - for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) { - const int dk = fa_dims[i].dk; - const int dv = fa_dims[i].dv; - const int bm = fa_dims[i].bm; - const int bn = fa_dims[i].bn; - std::string OPTS = compile_opts + - " -D DK=" + std::to_string(dk) + - " -D DV=" + std::to_string(dv) + - " -D BLOCK_M=" + std::to_string(bm) + - " -D BLOCK_N=" + std::to_string(bn); - - cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS); - cl_kernel k_f16, k_f16_q1; - CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err)); - CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16; - backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1; - CL_CHECK(clReleaseProgram(prog_f16)); - - cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS); - cl_kernel k_f32, k_f32_q1; - CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err)); - CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32; - backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1; - CL_CHECK(clReleaseProgram(prog_f32)); - - cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS); - cl_kernel k_f32_f16, k_f32_f16_q1; - CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err)); - CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err)); - backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16; - backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1; - CL_CHECK(clReleaseProgram(prog_f32_f16)); - - backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm; - backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn; - } - GGML_LOG_CONT("."); - } - } - - // argsort + // mul_mm_f16_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "argsort.cl.h" + #include "mul_mm_f16_f32_l4_lm.cl.h" }; #else - const std::string kernel_src = read_file("argsort.cl"); + const std::string kernel_src = read_file("mul_mm_f16_f32_l4_lm.cl"); #endif - backend_ctx->program_argsort_f32_i32 = + backend_ctx->program_mul_mm_f16_f32_l4_lm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_argsort_f32_i32 = clCreateKernel(backend_ctx->program_argsort_f32_i32, "kernel_argsort_f32_i32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_l4_lm, "kernel_mul_mm_f16_f32_l4_lm", &err), err)); GGML_LOG_CONT("."); } - // div + // mul_mm_q4_0_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "div.cl.h" + #include "mul_mm_q4_0_f32_l4_lm.cl.h" }; #else - const std::string kernel_src = read_file("div.cl"); + const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl"); #endif - std::string compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable -cl-finite-math-only "; - - backend_ctx->program_div = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err)); - CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err)); - CL_CHECK((backend_ctx->kernel_div_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_div_row_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_row_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err)); GGML_LOG_CONT("."); } - // sqr + // mul_mm_q4_1_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "sqr.cl.h" + #include "mul_mm_q4_1_f32_l4_lm.cl.h" }; #else - const std::string kernel_src = read_file("sqr.cl"); + const std::string kernel_src = read_file("mul_mm_q4_1_f32_l4_lm.cl"); #endif cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_sqr_cont_f32 = clCreateKernel(prog, "kernel_sqr_cont_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_sqr_cont_f32_4 = clCreateKernel(prog, "kernel_sqr_cont_f32_4", &err), err)); - CL_CHECK((backend_ctx->kernel_sqr_cont_f16 = clCreateKernel(prog, "kernel_sqr_cont_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_sqr_cont_f16_4 = clCreateKernel(prog, "kernel_sqr_cont_f16_4", &err), err)); - - CL_CHECK(clReleaseProgram(prog)); + CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_1_f32_l4_lm", &err), err)); GGML_LOG_CONT("."); } - // sqrt + // mul_mm_q5_0_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "sqrt.cl.h" + #include "mul_mm_q5_0_f32_l4_lm.cl.h" }; #else - const std::string kernel_src = read_file("sqrt.cl"); + const std::string kernel_src = read_file("mul_mm_q5_0_f32_l4_lm.cl"); #endif cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_sqrt_cont_f32 = clCreateKernel(prog, "kernel_sqrt_cont_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_sqrt_cont_f32_4 = clCreateKernel(prog, "kernel_sqrt_cont_f32_4", &err), err)); - CL_CHECK((backend_ctx->kernel_sqrt_cont_f16 = clCreateKernel(prog, "kernel_sqrt_cont_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_sqrt_cont_f16_4 = clCreateKernel(prog, "kernel_sqrt_cont_f16_4", &err), err)); - - CL_CHECK(clReleaseProgram(prog)); + CL_CHECK((backend_ctx->kernel_mul_mm_q5_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_0_f32_l4_lm", &err), err)); GGML_LOG_CONT("."); } - // mean + // mul_mm_q5_1_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mean.cl.h" + #include "mul_mm_q5_1_f32_l4_lm.cl.h" }; #else - const std::string kernel_src = read_file("mean.cl"); + const std::string kernel_src = read_file("mul_mm_q5_1_f32_l4_lm.cl"); #endif cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mm_q5_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_1_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mm_q8_0_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q8_0_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q8_0_f32_l4_lm.cl"); +#endif + backend_ctx->program_mul_mm_q8_0_f32_l4_lm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm = clCreateKernel(backend_ctx->program_mul_mm_q8_0_f32_l4_lm, "kernel_mul_mm_q8_0_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mm_iq4_nl_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_iq4_nl_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_iq4_nl_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_mul_mm_iq4_nl_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_iq4_nl_f32_l4_lm", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // sub + // mul_mm_q4_k_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "sub.cl.h" + #include "mul_mm_q4_k_f32_l4_lm.cl.h" }; #else - const std::string kernel_src = read_file("sub.cl"); + const std::string kernel_src = read_file("mul_mm_q4_k_f32_l4_lm.cl"); #endif - backend_ctx->program_sub = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err)); - CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err)); - CL_CHECK((backend_ctx->kernel_sub_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_sub_row_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mm_q4_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // sum_rows + // mul_mm_q6_k_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "sum_rows.cl.h" + #include "mul_mm_q6_k_f32_l4_lm.cl.h" }; #else - const std::string kernel_src = read_file("sum_rows.cl"); + const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl"); #endif - backend_ctx->program_sum_rows_f32 = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // sigmoid + // mul_mm_q5_k_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "sigmoid.cl.h" + #include "mul_mm_q5_k_f32_l4_lm.cl.h" }; #else - const std::string kernel_src = read_file("sigmoid.cl"); + const std::string kernel_src = read_file("mul_mm_q5_k_f32_l4_lm.cl"); #endif - backend_ctx->program_sigmoid = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_sigmoid_f32 = clCreateKernel(backend_ctx->program_sigmoid, "kernel_sigmoid_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_sigmoid_f16 = clCreateKernel(backend_ctx->program_sigmoid, "kernel_sigmoid_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mm_q5_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q5_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // group_norm + // mul_mm_f16_f32_kq_kqv { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "group_norm.cl.h" + #include "mul_mm_f16_f32_kq_kqv.cl.h" }; #else - const std::string kernel_src = read_file("group_norm.cl"); + const std::string kernel_src = read_file("mul_mm_f16_f32_kq_kqv.cl"); #endif - backend_ctx->program_group_norm = + backend_ctx->program_mul_mm_f16_f32_kqv = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts+" -DKQV "); + backend_ctx->program_mul_mm_f16_f32_kq = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err)); - CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kqv = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kqv, "mul_mm_f16_f32_kqv", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_mm_f16_f32_kq = clCreateKernel(backend_ctx->program_mul_mm_f16_f32_kq, "mul_mm_f16_f32_kq", &err), err)); GGML_LOG_CONT("."); } - // repeat + // mul { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "repeat.cl.h" + #include "mul.cl.h" }; #else - const std::string kernel_src = read_file("repeat.cl"); + const std::string kernel_src = read_file("mul.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_repeat = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_repeat = clCreateKernel(backend_ctx->program_repeat, "kernel_repeat", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: repeat kernel source not found or empty. Repeat operations will not be available.\n"); - backend_ctx->program_repeat = nullptr; - backend_ctx->kernel_repeat = nullptr; - } + backend_ctx->program_mul = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul = clCreateKernel(backend_ctx->program_mul, "kernel_mul", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_row = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_mul_row_f16 = clCreateKernel(backend_ctx->program_mul, "kernel_mul_row_f16", &err), err)); + GGML_LOG_CONT("."); } - // pad + // norm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "pad.cl.h" + #include "norm.cl.h" }; #else - const std::string kernel_src = read_file("pad.cl"); + const std::string kernel_src = read_file("norm.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_pad = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_pad = clCreateKernel(backend_ctx->program_pad, "kernel_pad", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: pad kernel source not found or empty. Pad operations will not be available.\n"); - backend_ctx->program_pad = nullptr; - backend_ctx->kernel_pad = nullptr; - } + backend_ctx->program_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_norm = clCreateKernel(backend_ctx->program_norm, "kernel_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_norm_mul_add = clCreateKernel(backend_ctx->program_norm, "kernel_norm_mul_add", &err), err)); + GGML_LOG_CONT("."); } - // tanh + // relu { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "tanh.cl.h" + #include "relu.cl.h" }; #else - const std::string kernel_src = read_file("tanh.cl"); + const std::string kernel_src = read_file("relu.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_tanh = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_tanh_f32_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f32_nd", &err), err)); - CL_CHECK((backend_ctx->kernel_tanh_f16_nd = clCreateKernel(backend_ctx->program_tanh, "kernel_tanh_f16_nd", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: tanh kernel source not found or empty. Tanh operation will not be available.\n"); - backend_ctx->program_tanh = nullptr; - backend_ctx->kernel_tanh_f32_nd = nullptr; - backend_ctx->kernel_tanh_f16_nd = nullptr; - } + backend_ctx->program_relu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_relu = clCreateKernel(backend_ctx->program_relu, "kernel_relu", &err), err)); + GGML_LOG_CONT("."); } - // expm1 + // rms_norm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "expm1.cl.h" + #include "rms_norm.cl.h" }; #else - const std::string kernel_src = read_file("expm1.cl"); + const std::string kernel_src = read_file("rms_norm.cl"); #endif - cl_program prog; - if (!kernel_src.empty()) { - prog = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_expm1_f32_nd = clCreateKernel(prog, "kernel_expm1_f32_nd", &err), err)); - CL_CHECK((backend_ctx->kernel_expm1_f16_nd = clCreateKernel(prog, "kernel_expm1_f16_nd", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: expm1 kernel source not found or empty. Expm1 operation will not be available.\n"); - prog = nullptr; - backend_ctx->kernel_expm1_f32_nd = nullptr; - backend_ctx->kernel_expm1_f16_nd = nullptr; - } - CL_CHECK(clReleaseProgram(prog)); + backend_ctx->program_rms_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err)); + GGML_LOG_CONT("."); } - // softplus + // l2_norm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "softplus.cl.h" + #include "l2_norm.cl.h" }; #else - const std::string kernel_src = read_file("softplus.cl"); + const std::string kernel_src = read_file("l2_norm.cl"); #endif - cl_program prog; - if (!kernel_src.empty()) { - prog = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err)); - CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n"); - prog = nullptr; - backend_ctx->kernel_softplus_f32_nd = nullptr; - backend_ctx->kernel_softplus_f16_nd = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_l2_norm_f32 = clCreateKernel(prog, "kernel_l2_norm_f32", &err), err)); CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - // upscale + // rope { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "upscale.cl.h" + #include "rope.cl.h" }; #else - const std::string kernel_src = read_file("upscale.cl"); + const std::string kernel_src = read_file("rope.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_upscale = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_upscale = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale", &err), err)); - if (backend_ctx->program_upscale) { - cl_int err_bilinear; - backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear); - if (err_bilinear != CL_SUCCESS) { - GGML_LOG_WARN("ggml_opencl: kernel_upscale_bilinear not found in upscale.cl. Bilinear upscale will not be available. Error: %d\n", err_bilinear); - backend_ctx->kernel_upscale_bilinear = nullptr; - } - } else { - backend_ctx->kernel_upscale_bilinear = nullptr; - } - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: upscale kernel source not found or empty. Upscale operations will not be available.\n"); - backend_ctx->program_upscale = nullptr; - backend_ctx->kernel_upscale = nullptr; - backend_ctx->kernel_upscale_bilinear = nullptr; - } + backend_ctx->program_rope = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_rope_norm_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_norm_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_norm_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_neox_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_neox_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_multi_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_multi_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f32 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_rope_vision_f16 = clCreateKernel(backend_ctx->program_rope, "kernel_rope_vision_f16", &err), err)); + GGML_LOG_CONT("."); } - // concat + // scale { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "concat.cl.h" + #include "scale.cl.h" }; #else - - const std::string kernel_src = read_file("concat.cl"); + const std::string kernel_src = read_file("scale.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_concat = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_concat_f32_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_contiguous", &err), err)); - CL_CHECK((backend_ctx->kernel_concat_f32_non_contiguous = clCreateKernel(backend_ctx->program_concat, "kernel_concat_f32_non_contiguous", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: concat kernel source not found or empty. Concat operations will not be available.\n"); - backend_ctx->program_concat = nullptr; - backend_ctx->kernel_concat_f32_contiguous = nullptr; - backend_ctx->kernel_concat_f32_non_contiguous = nullptr; - } + CL_CHECK((backend_ctx->kernel_scale_f32 = clCreateKernel(prog, "kernel_scale_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_scale_f32_4 = clCreateKernel(prog, "kernel_scale_f32_4", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - // timestep_embedding + // silu { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "tsembd.cl.h" + #include "silu.cl.h" }; #else - - const std::string kernel_src = read_file("tsembd.cl"); + const std::string kernel_src = read_file("silu.cl"); #endif - if (!kernel_src.empty()) { - backend_ctx->program_tsembd = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_timestep_embedding = clCreateKernel(backend_ctx->program_tsembd, "kernel_timestep_embedding", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: timestep_embedding kernel source not found or empty. This op will not be available.\n"); - backend_ctx->program_tsembd = nullptr; - backend_ctx->kernel_timestep_embedding = nullptr; - } + backend_ctx->program_silu = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_silu = clCreateKernel(backend_ctx->program_silu, "kernel_silu", &err), err)); + CL_CHECK((backend_ctx->kernel_silu_4 = clCreateKernel(backend_ctx->program_silu, "kernel_silu_4", &err), err)); + GGML_LOG_CONT("."); } - // set_rows + // softmax_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "set_rows.cl.h" + #include "softmax_f32.cl.h" }; #else - const std::string kernel_src = read_file("set_rows.cl"); + const std::string kernel_src = read_file("softmax_f32.cl"); #endif - backend_ctx->program_set_rows = + backend_ctx->program_softmax_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_set_rows_f32_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i64", &err), err)); - CL_CHECK((backend_ctx->kernel_set_rows_f32_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i32", &err), err)); - CL_CHECK((backend_ctx->kernel_set_rows_f16_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i64", &err), err)); - CL_CHECK((backend_ctx->kernel_set_rows_f16_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i32", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max = clCreateKernel(backend_ctx->program_softmax_f32, "kernel_soft_max", &err), err)); GGML_LOG_CONT("."); } - // conv2d - { - #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src { - #include "conv2d.cl.h" - }; - const std::string kernel_src_f16_f32 { - #include "conv2d_f16_f32.cl.h" - }; - #else - const std::string kernel_src = read_file("conv2d.cl"); - const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl"); - #endif - if (!kernel_src.empty()) { - backend_ctx->program_conv_2d_f16 = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str()); - CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err)); - GGML_LOG_CONT("."); - backend_ctx->program_conv_2d_f32 = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n"); - backend_ctx->program_conv_2d_f16 = nullptr; - backend_ctx->kernel_conv_2d_f16 = nullptr; - backend_ctx->program_conv_2d_f32 = nullptr; - backend_ctx->kernel_conv_2d_f32 = nullptr; - } - if (!kernel_src_f16_f32.empty()) { - backend_ctx->program_conv_2d_f16_f32 = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n"); - backend_ctx->program_conv_2d_f16_f32 = nullptr; - backend_ctx->kernel_conv_2d_f16_f32 = nullptr; - } - } - - // ssm_conv + // softmax_f16 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "ssm_conv.cl.h" + #include "softmax_f16.cl.h" }; #else - const std::string kernel_src = read_file("ssm_conv.cl"); + const std::string kernel_src = read_file("softmax_f16.cl"); #endif - cl_program prog = + backend_ctx->program_softmax_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32_4", &err), err)); - CL_CHECK(clReleaseProgram(prog)); + CL_CHECK((backend_ctx->kernel_soft_max_f16 = clCreateKernel(backend_ctx->program_softmax_f16, "kernel_soft_max_f16", &err), err)); GGML_LOG_CONT("."); } - // mul_mv_id_q4_0_f32_8x_flat + // softmax_4_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_id_q4_0_f32_8x_flat.cl.h" + #include "softmax_4_f32.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_id_q4_0_f32_8x_flat.cl"); + const std::string kernel_src = read_file("softmax_4_f32.cl"); #endif - backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat = + backend_ctx->program_softmax_4_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat, "kernel_mul_mv_id_q4_0_f32_8x_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max_4 = clCreateKernel(backend_ctx->program_softmax_4_f32, "kernel_soft_max_4", &err), err)); GGML_LOG_CONT("."); } - // mul_mv_id_q8_0_f32 + // softmax_4_f16 { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_id_q8_0_f32.cl.h" + #include "softmax_4_f16.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_id_q8_0_f32.cl"); + const std::string kernel_src = read_file("softmax_4_f16.cl"); #endif - backend_ctx->program_mul_mv_id_q8_0_f32 = + backend_ctx->program_softmax_4_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mv_id_q8_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_q8_0_f32, "kernel_mul_mv_id_q8_0_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_soft_max_4_f16 = clCreateKernel(backend_ctx->program_softmax_4_f16, "kernel_soft_max_4_f16", &err), err)); GGML_LOG_CONT("."); } - // mul_mv_id_q8_0_f32_flat + // div { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_id_q8_0_f32_flat.cl.h" + #include "div.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_id_q8_0_f32_flat.cl"); + const std::string kernel_src = read_file("div.cl"); #endif - backend_ctx->program_mul_mv_id_q8_0_f32_flat = + std::string compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable -cl-finite-math-only "; + + backend_ctx->program_div = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mv_id_q8_0_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q8_0_f32_flat, "kernel_mul_mv_id_q8_0_f32_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_div = clCreateKernel(backend_ctx->program_div, "kernel_div", &err), err)); + CL_CHECK((backend_ctx->kernel_div_row = clCreateKernel(backend_ctx->program_div, "kernel_div_row", &err), err)); + CL_CHECK((backend_ctx->kernel_div_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_div_row_f16 = clCreateKernel(backend_ctx->program_div, "kernel_div_row_f16", &err), err)); GGML_LOG_CONT("."); } - // mul_mv_id_mxfp4_f32 + // sqr { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_id_mxfp4_f32.cl.h" + #include "sqr.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32.cl"); + const std::string kernel_src = read_file("sqr.cl"); #endif - backend_ctx->program_mul_mv_id_mxfp4_f32 = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32, "kernel_mul_mv_id_mxfp4_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f32 = clCreateKernel(prog, "kernel_sqr_cont_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f32_4 = clCreateKernel(prog, "kernel_sqr_cont_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f16 = clCreateKernel(prog, "kernel_sqr_cont_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_sqr_cont_f16_4 = clCreateKernel(prog, "kernel_sqr_cont_f16_4", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // mul_mv_id_mxfp4_f32_flat + // sqrt { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "mul_mv_id_mxfp4_f32_flat.cl.h" + #include "sqrt.cl.h" }; #else - const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32_flat.cl"); + const std::string kernel_src = read_file("sqrt.cl"); #endif - backend_ctx->program_mul_mv_id_mxfp4_f32_flat = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32_flat, "kernel_mul_mv_id_mxfp4_f32_flat", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f32 = clCreateKernel(prog, "kernel_sqrt_cont_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f32_4 = clCreateKernel(prog, "kernel_sqrt_cont_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f16 = clCreateKernel(prog, "kernel_sqrt_cont_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_sqrt_cont_f16_4 = clCreateKernel(prog, "kernel_sqrt_cont_f16_4", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // Adreno kernels -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - // transpose + // mean { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "transpose.cl.h" + #include "mean.cl.h" }; #else - const std::string kernel_src = read_file("transpose.cl"); + const std::string kernel_src = read_file("mean.cl"); #endif - backend_ctx->program_transpose = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err)); - CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err)); - CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err)); - CL_CHECK((backend_ctx->kernel_transpose_16_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_buf", &err), err)); - CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err)); + CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mean_f32_4 = clCreateKernel(prog, "kernel_mean_f32_4", &err), err)); + + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } - // gemv_noshuffle_general + // sub { - std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemv_general { - #include "gemv_noshuffle_general.cl.h" + const std::string kernel_src { + #include "sub.cl.h" }; #else - const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_general.cl"); + const std::string kernel_src = read_file("sub.cl"); #endif + backend_ctx->program_sub = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - backend_ctx->program_CL_gemv_general = build_program_from_source( - backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general = clCreateKernel(backend_ctx->program_CL_gemv_general, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_sub = clCreateKernel(backend_ctx->program_sub, "kernel_sub", &err), err)); + CL_CHECK((backend_ctx->kernel_sub_row = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row", &err), err)); + CL_CHECK((backend_ctx->kernel_sub_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_sub_row_f16 = clCreateKernel(backend_ctx->program_sub, "kernel_sub_row_f16", &err), err)); GGML_LOG_CONT("."); } - // gemv_noshuffle + // sum_rows { - // Gemv 2048, 16384 - std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=2048 " - " -DBLOCK_STRIDE_A=16384 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemv { - #include "gemv_noshuffle.cl.h" + const std::string kernel_src { + #include "sum_rows.cl.h" }; #else - const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle.cl"); + const std::string kernel_src = read_file("sum_rows.cl"); #endif + backend_ctx->program_sum_rows_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - backend_ctx->program_CL_gemv_4096_1_4096 = build_program_from_source( - backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_4096, "kernel_gemv_noshuffle", &err), err)); - GGML_LOG_CONT("."); - - // Gemv 2048, 16384 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=2048 " - " -DBLOCK_STRIDE_A=16384 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_4096_1_11008 = build_program_from_source( - backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008 = clCreateKernel(backend_ctx->program_CL_gemv_4096_1_11008, "kernel_gemv_noshuffle", &err), err)); - GGML_LOG_CONT("."); - - // Gemv 5504, 44032 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=5504 " - " -DBLOCK_STRIDE_A=44032 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } - - backend_ctx->program_CL_gemv_11008_1_4096 = build_program_from_source( - backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_11008_1_4096, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sum_rows_f32_4 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32_4", &err), err)); GGML_LOG_CONT("."); + } - // Gemv 16000, 128000 - CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -DLINE_STRIDE_A=16000 " - " -DBLOCK_STRIDE_A=128000 " - " -DSIMDGROUP_WIDTH=" + - std::to_string(backend_ctx->adreno_wave_size); - - if (backend_ctx->has_vector_subgroup_broadcast) { - CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; - } + // cumsum + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "cumsum.cl.h" + }; +#else + const std::string kernel_src = read_file("cumsum.cl"); +#endif + cl_program prog; + prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - backend_ctx->program_CL_gemv_32000_1_4096 = build_program_from_source( - backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096 = clCreateKernel(backend_ctx->program_CL_gemv_32000_1_4096, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_cumsum_blk = clCreateKernel(prog, "kernel_cumsum_blk", &err), err)); + CL_CHECK((backend_ctx->kernel_cumsum_add = clCreateKernel(prog, "kernel_cumsum_add", &err), err)); GGML_LOG_CONT("."); + CL_CHECK(clReleaseProgram(prog)); } - // mul_mat_Ab_Bi_8x4 + // sigmoid { #ifdef GGML_OPENCL_EMBED_KERNELS - const std::string kernel_src_CL_gemm { - #include "mul_mat_Ab_Bi_8x4.cl.h" + const std::string kernel_src { + #include "sigmoid.cl.h" }; #else - const std::string kernel_src_CL_gemm = read_file("mul_mat_Ab_Bi_8x4.cl"); + const std::string kernel_src = read_file("sigmoid.cl"); #endif - backend_ctx->program_CL_gemm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_Ab_Bi_8x4 = clCreateKernel(backend_ctx->program_CL_gemm, "kernel_mul_mat_Ab_Bi_8x4", &err), err)); + backend_ctx->program_sigmoid = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_sigmoid_f32 = clCreateKernel(backend_ctx->program_sigmoid, "kernel_sigmoid_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sigmoid_f16 = clCreateKernel(backend_ctx->program_sigmoid, "kernel_sigmoid_f16", &err), err)); GGML_LOG_CONT("."); } - std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std + - " -cl-mad-enable " - " -cl-fast-relaxed-math"; - - // gemv_moe_mxfp4_f32 + // group_norm { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "gemv_moe_mxfp4_f32.cl.h" + #include "group_norm.cl.h" }; #else - const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl"); + const std::string kernel_src = read_file("group_norm.cl"); #endif - backend_ctx->program_gemv_moe_mxfp4_f32 = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + backend_ctx->program_group_norm = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_group_norm = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_group_norm_mul_add = clCreateKernel(backend_ctx->program_group_norm, "kernel_group_norm_mul_add", &err), err)); GGML_LOG_CONT("."); } - // gemm_moe_mxfp4_f32 + // repeat { #ifdef GGML_OPENCL_EMBED_KERNELS const std::string kernel_src { - #include "gemm_moe_mxfp4_f32.cl.h" + #include "repeat.cl.h" }; #else - const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl"); + const std::string kernel_src = read_file("repeat.cl"); #endif - backend_ctx->program_gemm_moe_mxfp4_f32 = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - - CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err)); + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_repeat_f32 = clCreateKernel(prog, "kernel_repeat_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } -#endif // GGML_OPENCL_USE_ADRENO_KERNELS - GGML_LOG_CONT("\n"); -} - -// XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { -// XXX static bool initialized = false; -// XXX static ggml_backend_opencl_context *backend_ctx = nullptr; - -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev); - -namespace /* anonymous */ { -extern struct ggml_backend_device_i ggml_backend_opencl_device_i; -} - -// Look for available and suitable devices. -static std::vector<ggml_backend_device> ggml_opencl_probe_devices(ggml_backend_reg * reg) { - std::vector<ggml_backend_device> found_devices; -#ifdef GGML_OPENCL_PROFILING - GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n"); + // pad + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "pad.cl.h" + }; +#else + const std::string kernel_src = read_file("pad.cl"); #endif + if (!kernel_src.empty()) { + backend_ctx->program_pad = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_pad = clCreateKernel(backend_ctx->program_pad, "kernel_pad", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: pad kernel source not found or empty. Pad operations will not be available.\n"); + backend_ctx->program_pad = nullptr; + backend_ctx->kernel_pad = nullptr; + } + } - struct cl_device; - struct cl_platform { - cl_platform_id id; - unsigned number; - char name[128]; - char vendor[128]; - struct cl_device * devices; - unsigned n_devices; - struct cl_device * default_device; - }; - - struct cl_device { - struct cl_platform * platform; - cl_device_id id; - unsigned number; - cl_device_type type; - char name[128]; - char version[128]; - }; + // tanh + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "tanh.cl.h" + }; +#else + const std::string kernel_src = read_file("tanh.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_tanh_f32 = clCreateKernel(prog, "kernel_tanh_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f32_4 = clCreateKernel(prog, "kernel_tanh_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f32_nc = clCreateKernel(prog, "kernel_tanh_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16 = clCreateKernel(prog, "kernel_tanh_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16_4 = clCreateKernel(prog, "kernel_tanh_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_tanh_f16_nc = clCreateKernel(prog, "kernel_tanh_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - enum { NPLAT = 16, NDEV = 16 }; + // neg + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "neg.cl.h" + }; +#else + const std::string kernel_src = read_file("neg.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_neg_f32 = clCreateKernel(prog, "kernel_neg_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f32_4 = clCreateKernel(prog, "kernel_neg_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f32_nc = clCreateKernel(prog, "kernel_neg_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f16 = clCreateKernel(prog, "kernel_neg_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f16_4 = clCreateKernel(prog, "kernel_neg_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f16_nc = clCreateKernel(prog, "kernel_neg_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - struct cl_platform platforms[NPLAT]; - unsigned n_platforms = 0; - struct cl_device devices[NDEV]; - unsigned n_devices = 0; - struct cl_device * default_device = NULL; - unsigned default_platform_number = 0; + // exp + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "exp.cl.h" + }; +#else + const std::string kernel_src = read_file("exp.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_exp_f32 = clCreateKernel(prog, "kernel_exp_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f32_4 = clCreateKernel(prog, "kernel_exp_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f32_nc = clCreateKernel(prog, "kernel_exp_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f16 = clCreateKernel(prog, "kernel_exp_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f16_4 = clCreateKernel(prog, "kernel_exp_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f16_nc = clCreateKernel(prog, "kernel_exp_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - cl_platform_id platform_ids[NPLAT]; - if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) { - GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n"); - return found_devices; + // expm1 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "expm1.cl.h" + }; +#else + const std::string kernel_src = read_file("expm1.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_expm1_f32 = clCreateKernel(prog, "kernel_expm1_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f32_4 = clCreateKernel(prog, "kernel_expm1_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f32_nc = clCreateKernel(prog, "kernel_expm1_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f16 = clCreateKernel(prog, "kernel_expm1_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f16_4 = clCreateKernel(prog, "kernel_expm1_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f16_nc = clCreateKernel(prog, "kernel_expm1_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - for (unsigned i = 0; i < n_platforms; i++) { - struct cl_platform * p = &platforms[i]; - p->number = i; - p->id = platform_ids[i]; - CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL)); - CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL)); + // softplus + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "softplus.cl.h" + }; +#else + const std::string kernel_src = read_file("softplus.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_softplus_f32 = clCreateKernel(prog, "kernel_softplus_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f32_4 = clCreateKernel(prog, "kernel_softplus_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f32_nc = clCreateKernel(prog, "kernel_softplus_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f16 = clCreateKernel(prog, "kernel_softplus_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f16_4 = clCreateKernel(prog, "kernel_softplus_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f16_nc = clCreateKernel(prog, "kernel_softplus_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - cl_device_id device_ids[NDEV]; - cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices); - if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) { - p->n_devices = 0; + // upscale + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "upscale.cl.h" + }; +#else + const std::string kernel_src = read_file("upscale.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_upscale = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_upscale = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale", &err), err)); + if (backend_ctx->program_upscale) { + cl_int err_bilinear; + backend_ctx->kernel_upscale_bilinear = clCreateKernel(backend_ctx->program_upscale, "kernel_upscale_bilinear", &err_bilinear); + if (err_bilinear != CL_SUCCESS) { + GGML_LOG_WARN("ggml_opencl: kernel_upscale_bilinear not found in upscale.cl. Bilinear upscale will not be available. Error: %d\n", err_bilinear); + backend_ctx->kernel_upscale_bilinear = nullptr; + } + } else { + backend_ctx->kernel_upscale_bilinear = nullptr; + } + GGML_LOG_CONT("."); } else { - CL_CHECK(clGetDeviceIDsError); + GGML_LOG_WARN("ggml_opencl: upscale kernel source not found or empty. Upscale operations will not be available.\n"); + backend_ctx->program_upscale = nullptr; + backend_ctx->kernel_upscale = nullptr; + backend_ctx->kernel_upscale_bilinear = nullptr; } - p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL; - p->default_device = NULL; + } - for (unsigned j = 0; j < p->n_devices; j++) { - struct cl_device * d = &devices[n_devices]; - d->number = n_devices++; - d->id = device_ids[j]; - d->platform = p; - CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL)); - CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL)); - CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_VERSION, sizeof(d->version), &d->version, NULL)); + // concat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "concat.cl.h" + }; +#else + const std::string kernel_src = read_file("concat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_concat_f32 = clCreateKernel(prog, "kernel_concat_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_concat_f32_pack = clCreateKernel(prog, "kernel_concat_f32_pack", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) { - p->default_device = d; - } - } + // timestep_embedding + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "tsembd.cl.h" + }; +#else - if (default_device == NULL && p->default_device != NULL) { - default_device = p->default_device; - default_platform_number = i; + const std::string kernel_src = read_file("tsembd.cl"); +#endif + if (!kernel_src.empty()) { + backend_ctx->program_tsembd = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_timestep_embedding = clCreateKernel(backend_ctx->program_tsembd, "kernel_timestep_embedding", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: timestep_embedding kernel source not found or empty. This op will not be available.\n"); + backend_ctx->program_tsembd = nullptr; + backend_ctx->kernel_timestep_embedding = nullptr; } } - if (n_devices == 0) { - GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n"); - return found_devices; - } - - char * user_platform_string = getenv("GGML_OPENCL_PLATFORM"); - char * user_device_string = getenv("GGML_OPENCL_DEVICE"); - int user_platform_number = -1; - int user_device_number = -1; - cl_device * candidate_devices = nullptr; - unsigned n_candidate_devices = 0; + // set_rows + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "set_rows.cl.h" + }; +#else + const std::string kernel_src = read_file("set_rows.cl"); +#endif + backend_ctx->program_set_rows = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - unsigned n; - if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) { - user_platform_number = (int)n; - } - if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) { - user_device_number = (int)n; + CL_CHECK((backend_ctx->kernel_set_rows_f32_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i64", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_f32_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i32", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_f16_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i64", &err), err)); + CL_CHECK((backend_ctx->kernel_set_rows_f16_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i32", &err), err)); + GGML_LOG_CONT("."); } - if (user_platform_number != -1 && user_device_number != -1) { - cl_platform* platform = &platforms[user_platform_number]; - if ((unsigned)user_device_number >= platform->n_devices) { - GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number); - exit(1); - } - default_device = &platform->devices[user_device_number]; - candidate_devices = platform->devices; - n_candidate_devices = platform->n_devices; - } else { - // Choose a platform by matching a substring. - if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) { - for (unsigned i = 0; i < n_platforms; i++) { - struct cl_platform * p = &platforms[i]; - if (strstr(p->name, user_platform_string) != NULL || - strstr(p->vendor, user_platform_string) != NULL) { - user_platform_number = (int)i; - break; + + // conv2d + { + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "conv2d.cl.h" + }; + const std::string kernel_src_f16_f32 { + #include "conv2d_f16_f32.cl.h" + }; + #else + const std::string kernel_src = read_file("conv2d.cl"); + const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl"); + #endif + if (!kernel_src.empty()) { + backend_ctx->program_conv_2d_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str()); + CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + backend_ctx->program_conv_2d_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n"); + backend_ctx->program_conv_2d_f16 = nullptr; + backend_ctx->kernel_conv_2d_f16 = nullptr; + backend_ctx->program_conv_2d_f32 = nullptr; + backend_ctx->kernel_conv_2d_f32 = nullptr; } - } - if (user_platform_number == -1) { - GGML_LOG_ERROR("ggml_opencl: no platform matching '%s' was found.\n", user_platform_string); - exit(1); - } - } + if (!kernel_src_f16_f32.empty()) { + backend_ctx->program_conv_2d_f16_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. This op will not be available.\n"); + backend_ctx->program_conv_2d_f16_f32 = nullptr; + backend_ctx->kernel_conv_2d_f16_f32 = nullptr; + } + } - int platform_idx = user_platform_number != -1 ? user_platform_number : default_platform_number; - struct cl_platform * p = &platforms[platform_idx]; - candidate_devices = p->devices; - n_candidate_devices = p->n_devices; - default_device = p->default_device; - if (n_candidate_devices == 0) { - GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name); + // ssm_conv + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "ssm_conv.cl.h" + }; +#else + const std::string kernel_src = read_file("ssm_conv.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32_4", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gated_delta_net: one kernel per (S_V, KDA, tgpp) triple. + { + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gated_delta_net.cl.h" + }; + #else + const std::string kernel_src = read_file("gated_delta_net.cl"); + #endif + + const int gdn_sizes[4] = { 16, 32, 64, 128 }; + const int sg_size = backend_ctx->gpu_family == GPU_FAMILY::ADRENO ? 64 : backend_ctx->gpu_family == GPU_FAMILY::INTEL ? 32 : -1; + if (sg_size < 0) { + GGML_LOG_ERROR("Unsupported GPU Family: only Adreno and Intel are supported.\n"); exit(1); } - if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) { - for (unsigned i = 0; i < n_candidate_devices; i++) { - struct cl_device * d = &candidate_devices[i]; - if (strstr(d->name, user_device_string) != NULL) { - user_device_number = d->number; - break; - } + for (int si = 0; si < 4; si++) { + const int S_V = gdn_sizes[si]; + + // MUST match the dispatcher heuristic in ggml_cl_gated_delta_net exactly. + int lanes_per_column; + if (S_V >= 128) { + lanes_per_column = 8; + } else { + lanes_per_column = std::min(S_V, sg_size); } - if (user_device_number == -1) { - GGML_LOG_ERROR("ggml_opencl: no device matching '%s' was found.\n", user_device_string); - exit(1); + + // Round LANES_PER_COLUMN down until it is: + // * power-of-two + // * divides both S_V and sg_size + while (lanes_per_column > 1 && + (((lanes_per_column & (lanes_per_column - 1)) != 0) || + (S_V % lanes_per_column) != 0 || + (sg_size % lanes_per_column) != 0)) { + lanes_per_column >>= 1; } - } - if (user_device_number != -1) { - candidate_devices = &devices[user_device_number]; - n_candidate_devices = 1; - default_device = &candidate_devices[0]; - } - GGML_ASSERT(n_candidate_devices > 0); + GGML_ASSERT(lanes_per_column >= 1); + GGML_ASSERT(((lanes_per_column & (lanes_per_column - 1)) == 0)); + GGML_ASSERT((S_V % lanes_per_column) == 0); + GGML_ASSERT((sg_size % lanes_per_column) == 0); - if (default_device == NULL) { - default_device = &candidate_devices[0]; + const bool is_partial_reduce = (lanes_per_column != 1) && (lanes_per_column < sg_size); + int use_qcom_shuffle = 0; + if (is_partial_reduce) { + if (backend_ctx->has_qcom_subgroup_shuffle) { + use_qcom_shuffle = 1; + } + } + for (int kda = 0; kda < 2; kda++) { + for (int tgpp = 0; tgpp < 2; tgpp++) { + const int cpl = (tgpp == 0) ? 1 : 4; + const int spw = (tgpp == 0) ? 1 : 1; + + std::string opts = compile_opts; + opts += " -DS_V=" + std::to_string(S_V); + opts += " -DKDA=" + std::to_string(kda); + opts += " -DSUBGROUP_SIZE=" + std::to_string(sg_size); + opts += " -DLANES_PER_COLUMN=" + std::to_string(lanes_per_column); + opts += " -DCOLS_PER_LANE_GROUP=" + std::to_string(cpl); + opts += " -DUSE_QCOM_SUBGROUP_SHUFFLE=" + std::to_string(use_qcom_shuffle); + + // Since spw=1 is found to be optimal, SUBGROUPS_PER_WG > 1 code in + // the kernel is removed. If you want to experiment with spw > 1, + // Please remember to implement code to handle it. + opts += " -DSUBGROUPS_PER_WG=" + std::to_string(spw); + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), opts); + + CL_CHECK((backend_ctx->kernel_gated_delta_net_f32[si][kda][tgpp] = + clCreateKernel(prog, "kernel_gated_delta_net", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + } + } } + GGML_LOG_CONT("."); } - GGML_ASSERT(n_candidate_devices != 0 && candidate_devices); + // mul_mv_id_q4_0_f32_8x_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_id_q4_0_f32_8x_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_id_q4_0_f32_8x_flat.cl"); +#endif + backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - // Put the default device in front. - for (unsigned i = 1; i < n_candidate_devices; i++) { - if (&candidate_devices[i] == default_device) { - std::swap(candidate_devices[0], candidate_devices[i]); - default_device = &candidate_devices[0]; - break; - } + CL_CHECK((backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q4_0_f32_8x_flat, "kernel_mul_mv_id_q4_0_f32_8x_flat", &err), err)); + GGML_LOG_CONT("."); } - GGML_LOG_INFO("ggml_opencl: selected platform: '%s'\n", default_device->platform->name); + // mul_mv_id_q8_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_id_q8_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_id_q8_0_f32.cl"); +#endif + backend_ctx->program_mul_mv_id_q8_0_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - std::vector<cl_device_id> device_ids; - for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) { - device_ids.push_back(dev->id); + CL_CHECK((backend_ctx->kernel_mul_mv_id_q8_0_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_q8_0_f32, "kernel_mul_mv_id_q8_0_f32", &err), err)); + GGML_LOG_CONT("."); } - cl_int err; - cl_context shared_context; - cl_context_properties properties[] = { (intptr_t) CL_CONTEXT_PLATFORM, (intptr_t) default_device->platform->id, 0 }; - - CL_CHECK( - (shared_context = clCreateContext(properties, device_ids.size(), device_ids.data(), NULL, NULL, &err), err)); + // mul_mv_id_q8_0_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_id_q8_0_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_id_q8_0_f32_flat.cl"); +#endif + backend_ctx->program_mul_mv_id_q8_0_f32_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) { - GGML_LOG_INFO("\nggml_opencl: device: '%s (%s)'\n", dev->name, dev->version); + CL_CHECK((backend_ctx->kernel_mul_mv_id_q8_0_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_q8_0_f32_flat, "kernel_mul_mv_id_q8_0_f32_flat", &err), err)); + GGML_LOG_CONT("."); + } - auto dev_ctx = std::unique_ptr<ggml_backend_opencl_device_context>(new ggml_backend_opencl_device_context{ - /*.platform =*/dev->platform->id, - /*.platform_nane =*/dev->platform->name, - /*.device =*/dev->id, - /*.device_name =*/dev->name, - /*.device_type =*/dev->type, - /*.device_version =*/dev->version, - /*.backend_ctx =*/nullptr, - /*.buffer_type =*/{}, - /*.context =*/shared_context, - }); + // mul_mv_id_mxfp4_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_id_mxfp4_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32.cl"); +#endif + backend_ctx->program_mul_mv_id_mxfp4_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - found_devices.push_back(ggml_backend_device{ - /* .iface = */ ggml_backend_opencl_device_i, - /* .reg = */ reg, - /* .context = */ dev_ctx.get(), - }); + CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32 = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32, "kernel_mul_mv_id_mxfp4_f32", &err), err)); + GGML_LOG_CONT("."); + } - if (!ggml_cl2_init(&found_devices.back())) { - found_devices.pop_back(); - GGML_LOG_INFO("ggml_opencl: drop unsupported device.\n"); - continue; - } + // mul_mv_id_mxfp4_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_id_mxfp4_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_id_mxfp4_f32_flat.cl"); +#endif + backend_ctx->program_mul_mv_id_mxfp4_f32_flat = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - dev_ctx.release(); + CL_CHECK((backend_ctx->kernel_mul_mv_id_mxfp4_f32_flat = clCreateKernel(backend_ctx->program_mul_mv_id_mxfp4_f32_flat, "kernel_mul_mv_id_mxfp4_f32_flat", &err), err)); + GGML_LOG_CONT("."); } - if (found_devices.size()) { - auto * dev_ctx = static_cast<ggml_backend_opencl_device_context *>(found_devices.front().context); - GGML_LOG_INFO("ggml_opencl: default device: '%s (%s)'\n", dev_ctx->device_name.c_str(), - dev_ctx->device_version.c_str()); + // Adreno kernels +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // transpose + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "transpose.cl.h" + }; +#else + const std::string kernel_src = read_file("transpose.cl"); +#endif + backend_ctx->program_transpose = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - if (dev_ctx->device_type != CL_DEVICE_TYPE_GPU) { - GGML_LOG_WARN("ggml_opencl: warning, the default device is not a GPU: '%s'.\n", - dev_ctx->device_name.c_str()); - } + CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_8_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_8_buf", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_16_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_buf", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_32_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_buf", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err)); + GGML_LOG_CONT("."); } - return found_devices; -} + // gemv_noshuffle_general + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } -// Initialize device if it is supported (returns nullptr if it is not). -static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { - GGML_ASSERT(dev); - GGML_ASSERT(dev->context); +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "gemv_noshuffle_q4_0_f32.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_q4_0_f32.cl"); +#endif - ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; - GGML_ASSERT(dev_ctx->platform); - GGML_ASSERT(dev_ctx->device); + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - if (dev_ctx->backend_ctx) { - return dev_ctx->backend_ctx; + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - auto backend_ctx = std::make_unique<ggml_backend_opencl_context>(); - backend_ctx->device = dev_ctx->device; - backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; - - // ref_count get increased in ggml_backend_opencl_device_init - // This function is also used to retrieve backend context, so we don't want - // to increase ref_count for each call. We only want to increase ref_count - // when the associated device is initialized - backend_ctx->ref_count = 0; - - if (strstr(dev_ctx->device_name.c_str(), "Adreno") || - strstr(dev_ctx->device_name.c_str(), "Qualcomm") || - strstr(dev_ctx->device_version.c_str(), "Adreno")) { - backend_ctx->gpu_family = GPU_FAMILY::ADRENO; - // Usually device version contains the detailed device name - backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); - if (backend_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { - backend_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); + // gemv_noshuffle + { + // Gemv 2048, 16384 + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } - // Use wave size of 64 for all Adreno GPUs. - backend_ctx->adreno_wave_size = 64; - } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { - backend_ctx->gpu_family = GPU_FAMILY::INTEL; - } else { - GGML_LOG_ERROR("Unsupported GPU: %s\n", dev_ctx->device_name.c_str()); - backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; - return nullptr; - } - -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { - GGML_LOG_ERROR("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " - "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); - return nullptr; - } +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv { + #include "gemv_noshuffle_q4_0_f32_spec.cl.h" + }; +#else + const std::string kernel_src_CL_gemv = read_file("gemv_noshuffle_q4_0_f32_spec.cl"); #endif - // Populate backend device name - backend_ctx->device_name = dev_ctx->device_name; + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_4096 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); - // A local ref of cl_device_id for convenience - cl_device_id device = backend_ctx->device; + // Gemv 2048, 16384 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=2048 " + " -DBLOCK_STRIDE_A=16384 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } - ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); + prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_11008 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); - // Check device OpenCL version, OpenCL 2.0 or above is required - ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device); - if (opencl_c_version.major < 2) { - GGML_LOG_ERROR("ggml_opencl: OpenCL 2.0 or above is required\n"); - return nullptr; - } + // Gemv 5504, 44032 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=5504 " + " -DBLOCK_STRIDE_A=44032 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } - // Check driver version - size_t driver_version_str_size; - clGetDeviceInfo(device, CL_DRIVER_VERSION, 0, NULL, &driver_version_str_size); - char *driver_version = (char *)alloca(driver_version_str_size + 1); - clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL); - driver_version[driver_version_str_size] = '\0'; - GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", driver_version); - backend_ctx->driver_version = driver_version; + prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_11008_1_4096 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); - backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); - backend_ctx->has_vector_subgroup_broadcast = - (backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) || - (backend_ctx->adreno_cl_compiler_version.type == DX && backend_ctx->adreno_cl_compiler_version.major >= 17); - GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", - backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); + // Gemv 16000, 128000 + CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DLINE_STRIDE_A=16000 " + " -DBLOCK_STRIDE_A=128000 " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); - size_t ext_str_size; - clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); - char *ext_buffer = (char *)alloca(ext_str_size + 1); - clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); - ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated - // Check if ext_buffer contains cl_khr_fp16 - backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; - GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", backend_ctx->fp16_support ? "true" : "false"); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } - // fp16 is required - if (!backend_ctx->fp16_support) { - GGML_LOG_ERROR("ggml_opencl: device does not support FP16\n"); - return nullptr; + prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_0_f32_32000_1_4096 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes - // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) - if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && - strstr(ext_buffer, "cl_intel_subgroups") == NULL) { - GGML_LOG_ERROR("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " - "(note that subgroups is an optional feature in OpenCL 3.0)\n"); - return nullptr; + // mul_mat_Ab_Bi_8x4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemm { + #include "gemm_noshuffle_q4_0_f32.cl.h" + }; +#else + const std::string kernel_src_CL_gemm = read_file("gemm_noshuffle_q4_0_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_CL_gemm.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - cl_uint base_align_in_bits; - CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); - GGML_ASSERT(base_align_in_bits % 8u == 0); - backend_ctx->alignment = base_align_in_bits / 8u; - GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", backend_ctx->alignment); + // gemm_noshuffle_q4_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q4_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q4_1_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_1_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL); - GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", backend_ctx->max_alloc_size/1024/1024); + // gemv_noshuffle_q4_1_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q4_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q4_1_f32.cl"); +#endif - clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL); - GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", backend_ctx->max_workgroup_size); + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); - // Check SVM. - cl_device_svm_capabilities svm_caps; - CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &svm_caps, 0)); - GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", - svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", - svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", - svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); - GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", - svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_1_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - if (opencl_c_version.major >= 3) { - // Assume it is not available for 3.0, since it is optional in 3.0. - // If compiling against 3.0, then we can query. - backend_ctx->non_uniform_workgroups = false; -#if CL_TARGET_OPENCL_VERSION >= 300 - CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool), - &backend_ctx->non_uniform_workgroups, 0)); + // gemm_noshuffle_q5_0_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_0_f32.cl"); #endif - } else { - GGML_ASSERT(opencl_c_version.major == 2); - // Non-uniform workgroup sizes is mandatory feature in v2.x. - backend_ctx->non_uniform_workgroups = true; + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - // Print out configurations -#ifdef GGML_OPENCL_SOA_Q - GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); -#endif // GGML_OPENCL_SOA_Q + // gemv_noshuffle_q5_0_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); -#endif // GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_0_f32.cl"); +#endif + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - cl_int err; + // gemm_noshuffle_q5_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_1_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_1_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - // A local ref of cl_context for convenience - cl_context context = backend_ctx->context = dev_ctx->context; + // gemv_noshuffle_q5_1_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } - //CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err), - // (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err : - // (queue = clCreateCommandQueue(context, device, 0, &err), err) - //))); - cl_command_queue_properties command_queue_props = 0; -#ifdef GGML_OPENCL_PROFILING - command_queue_props |= CL_QUEUE_PROFILING_ENABLE; +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_1_f32.cl"); #endif - CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); - - // Load kernels - load_cl_kernels(backend_ctx.get(), opencl_c_version); + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_1_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - // Allocate intermediate buffers and images - size_t required_A_q_d_bytes = 311164928; - size_t required_A_s_d_bytes = 38895616; - size_t required_B_d_bytes = 45088768; + // gemm_noshuffle_iq4_nl_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_iq4_nl_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - // Ensure buffer sizes do not exceed the maximum allocation size - size_t max_A_q_d_bytes = MIN(required_A_q_d_bytes, backend_ctx->max_alloc_size); - size_t max_A_s_d_bytes = MIN(required_A_s_d_bytes, backend_ctx->max_alloc_size); - size_t max_B_d_bytes = MIN(required_B_d_bytes, backend_ctx->max_alloc_size); - if (required_A_q_d_bytes > backend_ctx->max_alloc_size) { - GGML_LOG_WARN("ggml_opencl: A_q_d buffer size reduced from %zu to %zu due to device limitations.\n", - required_A_q_d_bytes, max_A_q_d_bytes); - } - if (required_A_s_d_bytes > backend_ctx->max_alloc_size) { - GGML_LOG_WARN("ggml_opencl: A_s_d buffer size reduced from %zu to %zu due to device limitations.\n", - required_A_s_d_bytes, max_A_s_d_bytes); - } - if (required_B_d_bytes > backend_ctx->max_alloc_size) { - GGML_LOG_WARN("ggml_opencl: B_d buffer size reduced from %zu to %zu due to device limitations.\n", - required_B_d_bytes, max_B_d_bytes); - } + // gemv_noshuffle_iq4_nl_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; + } - backend_ctx->prealloc_quant_trans.allocate(context, max_A_q_d_bytes); - backend_ctx->prealloc_scales_trans.allocate(context, max_A_s_d_bytes); - backend_ctx->prealloc_act_trans.allocate(context, max_B_d_bytes); -#endif // GGML_OPENCL_USE_ADRENO_KERNELS +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_iq4_nl_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_iq4_nl_f32.cl"); +#endif - backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr; + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); - dev_ctx->backend_ctx = backend_ctx.release(); - return dev_ctx->backend_ctx; -} + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_iq4_nl_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } -static void ggml_cl2_free(ggml_backend_t backend) { - ggml_backend_opencl_context * ctx = (ggml_backend_opencl_context *) backend->context; - ctx->free(); + // mul_mm_q8_0_f32_8x4 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q8_0_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q8_0_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q8_0_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q8_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - // The CL context is shared by all backends, release it if all backends have been released - bool should_release_opencl = true; - for (auto device : g_ggml_backend_opencl_devices) { - ggml_backend_opencl_device_context * ctx_dev = (ggml_backend_opencl_device_context *) device.context; - if (ctx_dev->backend_ctx->ref_count > 0) { - should_release_opencl = false; + // gemv_noshuffle_general_q8_0_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -DSIMDGROUP_WIDTH=" + + std::to_string(backend_ctx->adreno_wave_size); + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; } - } - if (should_release_opencl) { - CL_CHECK(clReleaseContext(ctx->context)); - } -} +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src_CL_gemv_general { + #include "gemv_noshuffle_q8_0_f32.cl.h" + }; +#else + const std::string kernel_src_CL_gemv_general = read_file("gemv_noshuffle_q8_0_f32.cl"); +#endif -//------------------------------------------------------------------------------ -// Tensor extra management -//------------------------------------------------------------------------------ -struct ggml_tensor_extra_cl { - // The buffer object that holds the data. - cl_mem data_device; - // The offset into the buffer object. This is primarily for scratch buffer - // and view operation. - // NB: this offset no longer includes view offset (view_offs). Whenever this - // offset is used, view_offs should be considered. - cl_ulong offset; - // The actual size of the cl_mem object. This is needed when returning the - // block to the pool. - size_t actual_size; + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - void reset() { - data_device = nullptr; - offset = 0; - actual_size = 0; + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q8_0_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } -}; - -// Additional tensor extra structs for quantized tensors. -// These tensors are loaded from files and should not be allocated in scratch -- -// they should always be allocated from the pool. Hence, they do not have an -// `offset`, which indicate their locations in the scratch buffer. -struct ggml_tensor_extra_cl_q4_0 { - // Quantized values. - cl_mem q = nullptr; - // Quantized values in image1d_buffer_t. - cl_mem q_img = nullptr; - // Scales. - cl_mem d = nullptr; - // Scales in image1d_buffer_t. - cl_mem d_img = nullptr; - // Size of quantized values. - size_t size_q = 0; - // Size of scales. - size_t size_d = 0; - ~ggml_tensor_extra_cl_q4_0() { - reset(); + // gemm_noshuffle_q4_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q4_k_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - void reset() { - // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. - // They must be properly released so that the original buffer can be - // properly released to avoid memory leak. - if (q != nullptr) { - CL_CHECK(clReleaseMemObject(q)); - q = nullptr; - } - if (d != nullptr) { - CL_CHECK(clReleaseMemObject(d)); - d = nullptr; + // gemv_noshuffle_q4_k_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } - // Currently, q_img and d_img are only initialized when SMALL_ALLOC is - // enabled. They point to the images in ggml_backend_opencl_buffer_context. - // So, there is no need to release them here. - // TODO: initialize them for non SMALL_PATH path, or remove them. - q_img = nullptr; - d_img = nullptr; - size_q = 0; - size_d = 0; - } -}; -struct ggml_tensor_extra_cl_mxfp4 { - // Quantized values. - cl_mem q = nullptr; - // Quantized values in image1d_buffer_t. - cl_mem q_img = nullptr; - // Scales in E8M0. - cl_mem e = nullptr; - // Scales in image1d_buffer_t. - cl_mem e_img = nullptr; - // Size of quantized values. - size_t size_q = 0; - // Size of scales. - size_t size_e = 0; +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q4_k_f32.cl"); +#endif - ~ggml_tensor_extra_cl_mxfp4() { - reset(); - } + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); - void reset() { - // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. - // They must be properly released so that the original buffer can be - // properly released to avoid memory leak. - if (q != nullptr) { - CL_CHECK(clReleaseMemObject(q)); - q = nullptr; - } - if (e != nullptr) { - CL_CHECK(clReleaseMemObject(e)); - e = nullptr; - } - if (q != nullptr) { - CL_CHECK(clReleaseMemObject(q_img)); - q = nullptr; - } - // Currently, q_img and d_img are not used. They can be image1d_buffer_t - // that wraps around q and d to utilize image access path. - q_img = nullptr; - e_img = nullptr; - size_q = 0; - size_e = 0; + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } -}; - -struct ggml_tensor_extra_cl_q8_0 { - cl_mem q = nullptr; - cl_mem q_img = nullptr; - cl_mem d = nullptr; - cl_mem d_img = nullptr; + std::string CL_moe_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable " + " -cl-fast-relaxed-math"; - size_t size_q = 0; - size_t size_d = 0; + // gemv_moe_q4_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_1_f32_ns.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - ~ggml_tensor_extra_cl_q8_0() { - reset(); + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - void reset() { - // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. - // They must be properly released so that the original buffer can be - // properly released to avoid memory leak. - if (q != nullptr) { - CL_CHECK(clReleaseMemObject(q)); - q = nullptr; - } - if (d != nullptr) { - CL_CHECK(clReleaseMemObject(d)); - d = nullptr; - } - // Currently, q_img and d_img are not used. They can be image1d_buffer_t - // that wraps around q and d to utilize image access path. - q_img = nullptr; - d_img = nullptr; - size_q = 0; - size_d = 0; - } -}; + // gemm_moe_q4_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_1_f32_ns.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); -//------------------------------------------------------------------------------ -// Backend API -//------------------------------------------------------------------------------ + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_1_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } -// -// backend -// -static const char * ggml_backend_opencl_name(ggml_backend_t backend) { - return "OpenCL"; + // gemv_moe_mxfp4_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_mxfp4_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_mxfp4_f32.cl"); +#endif + backend_ctx->program_gemv_moe_mxfp4_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - UNUSED(backend); -} + CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemv_moe_mxfp4_f32, "kernel_gemv_moe_mxfp4_f32", &err), err)); + GGML_LOG_CONT("."); + } -static void ggml_backend_opencl_free(ggml_backend_t backend) { - ggml_cl2_free(backend); -} - -static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - GGML_UNUSED(backend); - GGML_UNUSED(tensor); - GGML_UNUSED(data); - GGML_UNUSED(offset); - GGML_UNUSED(size); -} + // gemm_moe_mxfp4_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_mxfp4_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_mxfp4_f32.cl"); +#endif + backend_ctx->program_gemm_moe_mxfp4_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); -static void ggml_backend_opencl_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_UNUSED(backend); - GGML_UNUSED(tensor); - GGML_UNUSED(data); - GGML_UNUSED(offset); - GGML_UNUSED(size); -} + CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32 = clCreateKernel(backend_ctx->program_gemm_moe_mxfp4_f32, "kernel_gemm_moe_mxfp4_f32", &err), err)); + GGML_LOG_CONT("."); + } -static bool ggml_backend_opencl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { - GGML_UNUSED(backend); - GGML_UNUSED(src); - GGML_UNUSED(dst); - return false; -} + // gemv_moe_q4_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); -static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { - auto * backend_ctx = static_cast<ggml_backend_opencl_context *>(backend->context); + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_0_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - cl_event evt; - CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, 0, nullptr, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clReleaseEvent(evt)); -} + // gemm_moe_q4_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); -// Syncronizes the 'backend_ctx's device with others so that commands -// enqueued to it won't start until commands in the other devices have -// completed. -static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { - if (g_ggml_backend_opencl_devices.size() < 2) - return; // No other devices to synchronize with. + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_0_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - std::vector<cl_event> events; - events.reserve(g_ggml_backend_opencl_devices.size()); + // gemv_moe_q5_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q5_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q5_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) { - auto * other_backend_ctx = ggml_cl2_init(&backend_dev); - if (backend_ctx != other_backend_ctx) { - cl_event ev; - CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev)); - CL_CHECK(clFlush(other_backend_ctx->queue)); - events.push_back(ev); - } + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_0_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, events.size(), events.data(), nullptr)); - for (auto ev : events) { - CL_CHECK(clReleaseEvent(ev)); + // gemm_moe_q5_0_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q5_0_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q5_0_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_0_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_0_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } -} -static void sync_with_other_backends(ggml_backend_t backend) { - auto * backend_ctx = static_cast<ggml_backend_opencl_context *>(backend->context); - sync_with_other_backends(backend_ctx); -} + // gemv_moe_q5_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q5_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q5_1_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); -static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) { - if (!ggml_can_fuse(cgraph, node_idx, ops)) { - return false; + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_1_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { - const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; - const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + // gemm_moe_q5_1_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q5_1_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q5_1_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_1_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_1_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - // rms_norm only supports f32 - if (mul->src[0]->type != GGML_TYPE_F32 || - mul->src[1]->type != GGML_TYPE_F32 || - mul->type != GGML_TYPE_F32) { - return false; - } + // gemv_moe_q4_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q4_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q4_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - // if rms_norm is the B operand, then we don't handle broadcast - if (rms_norm == mul->src[1] && - !ggml_are_same_shape(mul->src[0], rms_norm)) { - return false; - } + CL_CHECK((backend_ctx->kernel_gemv_moe_q4_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q4_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - // rms_norm assumes contiguous rows - if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { - return false; - } - } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) { - const ggml_tensor *norm = cgraph->nodes[node_idx]; - const ggml_tensor *mul = cgraph->nodes[node_idx+1]; - const ggml_tensor *add = cgraph->nodes[node_idx+2]; - const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0]; - const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0]; + // gemm_moe_q4_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q4_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q4_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - // norm fusion only supports F32 - if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) { - return false; - } + CL_CHECK((backend_ctx->kernel_gemm_moe_q4_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q4_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - if (norm->src[0]->ne[0] % 4 != 0) { - return false; - } + // gemv_moe_q5_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q5_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q5_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) { - return false; - } - } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) { - const ggml_tensor *gn = cgraph->nodes[node_idx]; - const ggml_tensor *mul = cgraph->nodes[node_idx+1]; - const ggml_tensor *add = cgraph->nodes[node_idx+2]; - const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0]; - const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0]; + CL_CHECK((backend_ctx->kernel_gemv_moe_q5_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q5_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) { - return false; - } + // gemm_moe_q5_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q5_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q5_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) { - return false; - } + CL_CHECK((backend_ctx->kernel_gemm_moe_q5_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q5_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } - return true; -} - -static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor); -static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); -static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); + // gemv_moe_q6_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_q6_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_q6_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); -static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + CL_CHECK((backend_ctx->kernel_gemv_moe_q6_k_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_q6_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - for (int i = 0; i < cgraph->n_nodes; i++) { - ggml_tensor * node = cgraph->nodes[i]; + // gemm_moe_q6_k_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_q6_k_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_q6_k_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); - // NOTE: this may oversynchronize by synchronizing with - // backends/devices which don't compute 'cgraph's - // dependencies. - sync_with_other_backends(backend); + CL_CHECK((backend_ctx->kernel_gemm_moe_q6_k_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_q6_k_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } - if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { - continue; + // gemv_moe_mxfp4_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_moe_mxfp4_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_moe_mxfp4_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemv_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_moe_mxfp4_f32_ns + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_moe_mxfp4_f32_ns.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_moe_mxfp4_f32_ns.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_moe_mxfp4_f32_ns = clCreateKernel(prog, "kernel_gemm_moe_mxfp4_f32_ns", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // moe_reorder_b + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "moe_reorder_b.cl.h" + }; +#else + const std::string kernel_src = read_file("moe_reorder_b.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_moe_reorder_b = clCreateKernel(prog, "kernel_moe_reorder_b", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // moe_sort_by_expert + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "moe_sort_by_expert.cl.h" + }; +#else + const std::string kernel_src = read_file("moe_sort_by_expert.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_moe_histogram = clCreateKernel(prog, "kernel_moe_histogram", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_scan = clCreateKernel(prog, "kernel_moe_scan", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_fill = clCreateKernel(prog, "kernel_moe_fill", &err), err)); + CL_CHECK((backend_ctx->kernel_moe_scatter = clCreateKernel(prog, "kernel_moe_scatter", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q6_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q6_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q6_k_f32.cl"); +#endif + + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; } - if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) { - ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); - i += 2; - continue; + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q6_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q6_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q6_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_moe_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q6_K_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q6_K_f32", &err), err)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q5_k_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAST "; } - if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) { - ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); - i += 2; - continue; + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q5_k_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q5_k_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q5_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemm_noshuffle_q5_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q5_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q5_k_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q5_k_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q5_k_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_CONT("\n"); + backend_ctx->kernels_loaded = true; +} + +// XXX static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { +// XXX static bool initialized = false; +// XXX static ggml_backend_opencl_context *backend_ctx = nullptr; + +static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev); +static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev); + +namespace /* anonymous */ { +extern struct ggml_backend_device_i ggml_backend_opencl_device_i; +} + +// Look for available and suitable devices. +static std::vector<ggml_backend_device> ggml_opencl_probe_devices(ggml_backend_reg * reg) { + std::vector<ggml_backend_device> found_devices; + +#ifdef GGML_OPENCL_PROFILING + GGML_LOG_INFO("ggml_opencl: OpenCL profiling enabled\n"); +#endif + + struct cl_device; + struct cl_platform { + cl_platform_id id; + unsigned number; + char name[128]; + char vendor[128]; + struct cl_device * devices; + unsigned n_devices; + struct cl_device * default_device; + }; + + struct cl_device { + struct cl_platform * platform; + cl_device_id id; + unsigned number; + cl_device_type type; + char name[128]; + char version[128]; + }; + + enum { NPLAT = 16, NDEV = 16 }; + + struct cl_platform platforms[NPLAT]; + unsigned n_platforms = 0; + struct cl_device devices[NDEV]; + unsigned n_devices = 0; + struct cl_device * default_device = NULL; + unsigned default_platform_number = 0; + + cl_platform_id platform_ids[NPLAT]; + if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) { + GGML_LOG_ERROR("ggml_opencl: platform IDs not available.\n"); + return found_devices; + } + + for (unsigned i = 0; i < n_platforms; i++) { + struct cl_platform * p = &platforms[i]; + p->number = i; + p->id = platform_ids[i]; + CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_NAME, sizeof(p->name), &p->name, NULL)); + CL_CHECK(clGetPlatformInfo(p->id, CL_PLATFORM_VENDOR, sizeof(p->vendor), &p->vendor, NULL)); + + cl_device_id device_ids[NDEV]; + cl_int clGetDeviceIDsError = clGetDeviceIDs(p->id, CL_DEVICE_TYPE_ALL, NDEV, device_ids, &p->n_devices); + if (clGetDeviceIDsError == CL_DEVICE_NOT_FOUND) { + p->n_devices = 0; + } else { + CL_CHECK(clGetDeviceIDsError); } - if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { - ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]); - i++; - continue; + p->devices = p->n_devices > 0 ? &devices[n_devices] : NULL; + p->default_device = NULL; + + for (unsigned j = 0; j < p->n_devices; j++) { + struct cl_device * d = &devices[n_devices]; + d->number = n_devices++; + d->id = device_ids[j]; + d->platform = p; + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_NAME, sizeof(d->name), &d->name, NULL)); + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_TYPE, sizeof(d->type), &d->type, NULL)); + CL_CHECK(clGetDeviceInfo(d->id, CL_DEVICE_VERSION, sizeof(d->version), &d->version, NULL)); + + if (p->default_device == NULL && d->type == CL_DEVICE_TYPE_GPU) { + p->default_device = d; + } } - bool ok = ggml_cl_compute_forward(backend, node); - if (!ok) { - GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + if (default_device == NULL && p->default_device != NULL) { + default_device = p->default_device; + default_platform_number = i; } - GGML_ASSERT(ok); } - return GGML_STATUS_SUCCESS; -} + if (n_devices == 0) { + GGML_LOG_ERROR("ggml_opencl: could find any OpenCL devices.\n"); + return found_devices; + } -static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; - ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; + char * user_platform_string = getenv("GGML_OPENCL_PLATFORM"); + char * user_device_string = getenv("GGML_OPENCL_DEVICE"); + int user_platform_number = -1; + int user_device_number = -1; + cl_device * candidate_devices = nullptr; + unsigned n_candidate_devices = 0; - switch (op->op) { - case GGML_OP_NONE: - return true; - case GGML_OP_GET_ROWS: - switch (op->src[0]->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - return true; - case GGML_TYPE_Q4_0: -#ifdef GGML_OPENCL_SOA_Q - // We do not support flattened Q4_0 (and possibly other Q's) - return false; -#else // GGML_OPENCL_SOA_Q - return true; -#endif // GGML_OPENCL_SOA_Q - default: - return false; - } - case GGML_OP_SET_ROWS: - { - // TODO: add support - // ref: https://github.com/ggml-org/llama.cpp/pull/14274 -#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") - if (op->src[0]->type != GGML_TYPE_F32) { - return false; - } - switch (op->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - return (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); - default: - return false; - } - } - case GGML_OP_CPY: - case GGML_OP_DUP: - case GGML_OP_CONT: - switch (op->src[0]->type) { - case GGML_TYPE_F32: - switch (op->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - return true; - default: - return false; - } - case GGML_TYPE_F16: - switch (op->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - return true; - default: - return false; - } - default: - return false; - } - case GGML_OP_SCALE: - return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); - case GGML_OP_ADD: - if (op->type == GGML_TYPE_F16) { - const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32; - const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32; - if (src0_ok && src1_ok) { - return true; + unsigned n; + if (user_platform_string != NULL && sscanf(user_platform_string, " %u", &n) == 1 && n < n_platforms) { + user_platform_number = (int)n; + } + if (user_device_string != NULL && sscanf(user_device_string, " %u", &n) == 1 && n < n_devices) { + user_device_number = (int)n; + } + if (user_platform_number != -1 && user_device_number != -1) { + cl_platform* platform = &platforms[user_platform_number]; + if ((unsigned)user_device_number >= platform->n_devices) { + GGML_LOG_ERROR("ggml_opencl: invalid device number %d\n", user_device_number); + exit(1); + } + default_device = &platform->devices[user_device_number]; + candidate_devices = platform->devices; + n_candidate_devices = platform->n_devices; + } else { + // Choose a platform by matching a substring. + if (user_platform_number == -1 && user_platform_string != NULL && user_platform_string[0] != 0) { + for (unsigned i = 0; i < n_platforms; i++) { + struct cl_platform * p = &platforms[i]; + if (strstr(p->name, user_platform_string) != NULL || + strstr(p->vendor, user_platform_string) != NULL) { + user_platform_number = (int)i; + break; } } - case GGML_OP_MUL: - case GGML_OP_DIV: - case GGML_OP_SUB: - return (op->src[0]->type == op->src[1]->type) && - (op->src[0]->type == op->type) && - (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); - case GGML_OP_ADD_ID: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_SQR: - case GGML_OP_SQRT: - return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && - ggml_is_contiguous(op->src[0]); - case GGML_OP_UNARY: - switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_GELU: - case GGML_UNARY_OP_SILU: - case GGML_UNARY_OP_RELU: - case GGML_UNARY_OP_GELU_ERF: - case GGML_UNARY_OP_GELU_QUICK: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; - case GGML_UNARY_OP_SIGMOID: - return ggml_is_contiguous(op->src[0]); - case GGML_UNARY_OP_TANH: - return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || - (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); - case GGML_UNARY_OP_EXPM1: - return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || - (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); - case GGML_UNARY_OP_SOFTPLUS: - return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || - (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); - default: - return false; - } - case GGML_OP_GLU: - switch (ggml_get_glu_op(op)) { - case GGML_GLU_OP_GEGLU: - case GGML_GLU_OP_REGLU: - case GGML_GLU_OP_SWIGLU: - case GGML_GLU_OP_SWIGLU_OAI: - case GGML_GLU_OP_GEGLU_ERF: - case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); - default: - return false; - } - case GGML_OP_FILL: - return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op); - case GGML_OP_CLAMP: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_SOFT_MAX: - case GGML_OP_NORM: - return true; - case GGML_OP_RMS_NORM: - return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]); - case GGML_OP_REPEAT: - return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded - case GGML_OP_PAD: - // TODO: add circular padding support for opencl, see https://github.com/ggml-org/llama.cpp/pull/16985 - if (ggml_get_op_params_i32(op, 8) != 0) { - return false; + if (user_platform_number == -1) { + GGML_LOG_ERROR("ggml_opencl: no platform matching '%s' was found.\n", user_platform_string); + exit(1); } - return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; - case GGML_OP_UPSCALE: { - ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF); - const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS); - return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias; } - case GGML_OP_CONV_2D: - return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || - (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || - (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); - case GGML_OP_SSM_CONV: - return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); - case GGML_OP_CONCAT: - return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; - case GGML_OP_TIMESTEP_EMBEDDING: - return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; - case GGML_OP_GROUP_NORM: - return ggml_is_contiguous(op->src[0]); - case GGML_OP_MUL_MAT: - if (op->src[0]->type == GGML_TYPE_F16) { - return true; - } else if (op->src[0]->type == GGML_TYPE_F32) { - return op->src[1]->type == GGML_TYPE_F32; - } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 || - op->src[0]->type == GGML_TYPE_Q6_K) { - return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); - } else if (op->src[0]->type == GGML_TYPE_Q8_0) { - return op->src[1]->type == GGML_TYPE_F32; - } - return false; - case GGML_OP_MUL_MAT_ID: - if (op->src[0]->type == GGML_TYPE_Q4_0 || - op->src[0]->type == GGML_TYPE_Q8_0 || - op->src[0]->type == GGML_TYPE_MXFP4) { - if (op->src[1]->type == GGML_TYPE_F32) { - return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); - } - } - return false; - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - return true; - case GGML_OP_DIAG_MASK_INF: - return op->ne[3] == 1; - case GGML_OP_ROPE: { - const int mode = ((const int32_t *) op->op_params)[2]; - const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; - const bool is_vision = mode == GGML_ROPE_TYPE_VISION; - if (is_mrope && !is_vision) { - if (op->src[0]->type == GGML_TYPE_F32 || - op->src[0]->type == GGML_TYPE_F16) { - return true; + + int platform_idx = user_platform_number != -1 ? user_platform_number : default_platform_number; + struct cl_platform * p = &platforms[platform_idx]; + candidate_devices = p->devices; + n_candidate_devices = p->n_devices; + default_device = p->default_device; + if (n_candidate_devices == 0) { + GGML_LOG_ERROR("ggml_opencl: selected platform '%s' does not have any devices.\n", p->name); + exit(1); + } + + if (user_device_number == -1 && user_device_string != NULL && user_device_string[0] != 0) { + for (unsigned i = 0; i < n_candidate_devices; i++) { + struct cl_device * d = &candidate_devices[i]; + if (strstr(d->name, user_device_string) != NULL) { + user_device_number = d->number; + break; } - return false; } - if (is_vision) { - if (op->src[0]->type == GGML_TYPE_F32 || - op->src[0]->type == GGML_TYPE_F16) { - return true; - } - return false; + if (user_device_number == -1) { + GGML_LOG_ERROR("ggml_opencl: no device matching '%s' was found.\n", user_device_string); + exit(1); } - return true; } - case GGML_OP_IM2COL: - return true; - case GGML_OP_ARGSORT: { - cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32; - int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + if (user_device_number != -1) { + candidate_devices = &devices[user_device_number]; + n_candidate_devices = 1; + default_device = &candidate_devices[0]; + } - int cols = 1; - while (cols < op->ne[0]) { - cols *= 2; - } + GGML_ASSERT(n_candidate_devices > 0); - return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32; + if (default_device == NULL) { + default_device = &candidate_devices[0]; } - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); - case GGML_OP_FLASH_ATTN_EXT: - { - const ggml_tensor * q = op->src[0]; - const ggml_tensor * k = op->src[1]; - const ggml_tensor * v = op->src[2]; + } + + GGML_ASSERT(n_candidate_devices != 0 && candidate_devices); + + // Put the default device in front. + for (unsigned i = 1; i < n_candidate_devices; i++) { + if (&candidate_devices[i] == default_device) { + std::swap(candidate_devices[0], candidate_devices[i]); + default_device = &candidate_devices[0]; + break; + } + } + + GGML_LOG_INFO("ggml_opencl: selected platform: '%s'\n", default_device->platform->name); + + std::vector<cl_device_id> device_ids; + for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) { + device_ids.push_back(dev->id); + } + + cl_int err; + cl_context shared_context; + cl_context_properties properties[] = { (intptr_t) CL_CONTEXT_PLATFORM, (intptr_t) default_device->platform->id, 0 }; + + CL_CHECK( + (shared_context = clCreateContext(properties, device_ids.size(), device_ids.data(), NULL, NULL, &err), err)); + + for (auto dev = candidate_devices, dev_end = candidate_devices + n_candidate_devices; dev != dev_end; dev++) { + GGML_LOG_INFO("\nggml_opencl: device: '%s (%s)'\n", dev->name, dev->version); + + auto dev_ctx = std::unique_ptr<ggml_backend_opencl_device_context>(new ggml_backend_opencl_device_context{ + /*.platform =*/dev->platform->id, + /*.platform_nane =*/dev->platform->name, + /*.device =*/dev->id, + /*.device_name =*/dev->name, + /*.device_type =*/dev->type, + /*.device_version =*/dev->version, + /*.backend_ctx =*/nullptr, + /*.buffer_type =*/{}, + /*.context =*/shared_context, + }); + + found_devices.push_back(ggml_backend_device{ + /* .iface = */ ggml_backend_opencl_device_i, + /* .reg = */ reg, + /* .context = */ dev_ctx.get(), + }); + + if (!ggml_opencl_is_device_supported(&found_devices.back())) { + found_devices.pop_back(); + GGML_LOG_WARN("ggml_opencl: drop unsupported device '%s'.\n", dev->name); + continue; + } + + g_ggml_backend_opencl_dev_ctxs.push_back(std::move(dev_ctx)); + } + + if (found_devices.size()) { + auto * dev_ctx = static_cast<ggml_backend_opencl_device_context *>(found_devices.front().context); + GGML_LOG_INFO("ggml_opencl: default device: '%s (%s)'\n", dev_ctx->device_name.c_str(), + dev_ctx->device_version.c_str()); + + if (dev_ctx->device_type != CL_DEVICE_TYPE_GPU) { + GGML_LOG_WARN("ggml_opencl: warning, the default device is not a GPU: '%s'.\n", + dev_ctx->device_name.c_str()); + } + } + + return found_devices; +} + +static void ggml_opencl_print_backend_info(ggml_backend_opencl_device_context * dev_ctx) { + GGML_ASSERT(dev_ctx); + GGML_ASSERT(dev_ctx->backend_ctx); + + auto * backend_ctx = dev_ctx->backend_ctx; + + GGML_LOG_INFO("ggml_opencl: OpenCL driver: %s\n", + backend_ctx->driver_version.c_str()); + GGML_LOG_INFO("ggml_opencl: vector subgroup broadcast support: %s\n", + backend_ctx->has_vector_subgroup_broadcast ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: device FP16 support: %s\n", + backend_ctx->fp16_support ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: mem base addr align: %u\n", + backend_ctx->alignment); + GGML_LOG_INFO("ggml_opencl: global mem size: %zu MB\n", + backend_ctx->global_mem_size/1024/1024); + GGML_LOG_INFO("ggml_opencl: max mem alloc size: %zu MB\n", + backend_ctx->max_alloc_size/1024/1024); + GGML_LOG_INFO("ggml_opencl: device max image buffer size (pixels): %lu\n", + backend_ctx->image_max_buffer_size); + GGML_LOG_INFO("ggml_opencl: device max image2d size: %lu x %lu\n", + backend_ctx->image2d_max_width, backend_ctx->image2d_max_height); + GGML_LOG_INFO("ggml_opencl: device max workgroup size: %lu\n", + backend_ctx->max_workgroup_size); + GGML_LOG_INFO("ggml_opencl: SVM coarse grain buffer support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_COARSE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain buffer support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_FINE_GRAIN_BUFFER ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM fine grain system support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_FINE_GRAIN_SYSTEM ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: SVM atomics support: %s\n", + backend_ctx->svm_caps & CL_DEVICE_SVM_ATOMICS ? "true" : "false"); + GGML_LOG_INFO("ggml_opencl: cl_qcom_subgroup_shuffle support: %s\n", + backend_ctx->has_qcom_subgroup_shuffle ? "true" : "false"); + + // Print out configurations +#ifdef GGML_OPENCL_SOA_Q + GGML_LOG_INFO("ggml_opencl: flattening quantized weights representation as struct of arrays (GGML_OPENCL_SOA_Q)\n"); +#endif // GGML_OPENCL_SOA_Q + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_LOG_INFO("ggml_opencl: using kernels optimized for Adreno (GGML_OPENCL_USE_ADRENO_KERNELS)\n"); + if (backend_ctx->adreno_xmem_gemm_enabled) { + GGML_LOG_INFO("ggml_opencl: Adreno xmem F16xF32 GEMM enabled (temporary weight prepack)\n"); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + if (backend_ctx->adreno_use_large_buffer) { + if (!backend_ctx->adreno_has_large_buffer) { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer requested but not supported by driver, will use regular buffer\n"); + backend_ctx->adreno_use_large_buffer = false; + } else { + GGML_LOG_INFO("ggml_opencl: Adreno large buffer enabled\n"); + } + } + + if (dev_ctx->opfilter) { + // for information only, the actual regex object is created in ggml_opencl_is_device_supported + GGML_LOG_INFO("ggml_opencl: opfilter regex = \"%s\"\n", dev_ctx->opfilter_str.c_str()); + } +} + +// check if device should be accepted +static bool ggml_opencl_is_device_supported(ggml_backend_dev_t dev) { + GGML_ASSERT(dev); + GGML_ASSERT(dev->context); + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + GGML_ASSERT(dev_ctx->platform); + GGML_ASSERT(dev_ctx->device); + + if (strstr(dev_ctx->device_name.c_str(), "Adreno") || + strstr(dev_ctx->device_name.c_str(), "Qualcomm") || + strstr(dev_ctx->device_version.c_str(), "Adreno")) { + dev_ctx->gpu_family = GPU_FAMILY::ADRENO; + + // Usually device version contains the detailed device name + dev_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_version.c_str()); + if (dev_ctx->adreno_gen == ADRENO_GPU_GEN::ADRENO_UNKNOWN) { + dev_ctx->adreno_gen = get_adreno_gpu_gen(dev_ctx->device_name.c_str()); + } + } else if (strstr(dev_ctx->device_name.c_str(), "Intel")) { + dev_ctx->gpu_family = GPU_FAMILY::INTEL; + } else { + GGML_LOG_WARN("ggml_opencl: unsupported GPU '%s'.\n", dev_ctx->device_name.c_str()); + dev_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + return false; + } + + ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); + + // Check device OpenCL version, OpenCL 2.0 or above is required + ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, dev_ctx->device); + if (opencl_c_version.major < 2) { + GGML_LOG_WARN("ggml_opencl: OpenCL 2.0 or above is required\n"); + return false; + } + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (dev_ctx->gpu_family != GPU_FAMILY::ADRENO) { + GGML_LOG_WARN("ggml_opencl: Adreno-specific kernels should not be enabled for non-Adreno GPUs; " + "run on an Adreno GPU or recompile with CMake option `-DGGML_OPENCL_USE_ADRENO_KERNELS=OFF`\n"); + return false; + } +#endif + + size_t ext_str_size; + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); + + char *ext_buffer = (char *)alloca(ext_str_size + 1); + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); + ext_buffer[ext_str_size] = '\0'; + + // Check if ext_buffer contains cl_khr_fp16 + bool fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; + if (!fp16_support) { + GGML_LOG_WARN("ggml_opencl: device does not support FP16\n"); + return false; + } + + // If OpenCL 3.0 is supported, then check for cl_khr_subgroups, which becomes + // optional in OpenCL 3.0 (cl_khr_subgroup is mandatory in OpenCL 2.x) + if (opencl_c_version.major == 3 && strstr(ext_buffer, "cl_khr_subgroups") == NULL && + strstr(ext_buffer, "cl_intel_subgroups") == NULL) { + GGML_LOG_WARN("ggml_opencl: device does not support subgroups (cl_khr_subgroups or cl_intel_subgroups) " + "(note that subgroups is an optional feature in OpenCL 3.0)\n"); + return false; + } + + clGetDeviceInfo(dev_ctx->device, CL_DEVICE_GLOBAL_MEM_SIZE, sizeof(size_t), &dev_ctx->global_mem_size, NULL); + + const char * str_opfilter = getenv("GGML_OPENCL_OPFILTER"); + if (str_opfilter) { + dev_ctx->opfilter_str = str_opfilter; + dev_ctx->opfilter = new std::regex(str_opfilter, std::regex_constants::icase); + } + + return true; +} + +// Initialize device if it is supported (returns nullptr if it is not). +static ggml_backend_opencl_context * ggml_cl_init(ggml_backend_dev_t dev) { + GGML_ASSERT(dev); + GGML_ASSERT(dev->context); + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + GGML_ASSERT(dev_ctx->platform); + GGML_ASSERT(dev_ctx->device); + + if (dev_ctx->backend_ctx) { + return dev_ctx->backend_ctx; + } + + auto backend_ctx = std::make_unique<ggml_backend_opencl_context>(); + backend_ctx->device = dev_ctx->device; + backend_ctx->gpu_family = GPU_FAMILY::UNKNOWN; + + // ref_count get increased in ggml_backend_opencl_device_init + // This function is also used to retrieve backend context, so we don't want + // to increase ref_count for each call. We only want to increase ref_count + // when the associated device is initialized + backend_ctx->ref_count = 0; + + backend_ctx->gpu_family = dev_ctx->gpu_family; + backend_ctx->adreno_gen = dev_ctx->adreno_gen; + if (backend_ctx->gpu_family == GPU_FAMILY::ADRENO) { + // Use wave size of 64 for all Adreno GPUs. + backend_ctx->adreno_wave_size = 64; + } + + // Populate backend device name + backend_ctx->device_name = dev_ctx->device_name; + + // A local ref of cl_device_id for convenience + cl_device_id device = backend_ctx->device; + + ggml_cl_version platform_version = get_opencl_platform_version(dev_ctx->platform); + ggml_cl_version opencl_c_version = get_opencl_c_version(platform_version, device); + + backend_ctx->platform_version = platform_version; + backend_ctx->opencl_c_version = opencl_c_version; + + // Check driver version + size_t driver_version_str_size; + clGetDeviceInfo(device, CL_DRIVER_VERSION, 0, NULL, &driver_version_str_size); + char *driver_version = (char *)alloca(driver_version_str_size + 1); + clGetDeviceInfo(device, CL_DRIVER_VERSION, driver_version_str_size, driver_version, NULL); + driver_version[driver_version_str_size] = '\0'; + backend_ctx->driver_version = driver_version; + + backend_ctx->adreno_cl_compiler_version = get_adreno_cl_compiler_version(driver_version); + backend_ctx->has_vector_subgroup_broadcast = + (backend_ctx->adreno_cl_compiler_version.type == E031 && backend_ctx->adreno_cl_compiler_version.major >= 47) || + (backend_ctx->adreno_cl_compiler_version.type == DX && backend_ctx->adreno_cl_compiler_version.major >= 17); + + size_t ext_str_size; + clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, 0, NULL, &ext_str_size); + char *ext_buffer = (char *)alloca(ext_str_size + 1); + clGetDeviceInfo(device, CL_DEVICE_EXTENSIONS, ext_str_size, ext_buffer, NULL); + ext_buffer[ext_str_size] = '\0'; // ensure it is null terminated + + // check support for qcom_subgroup_shuffle + if (strstr(ext_buffer, "cl_qcom_subgroup_shuffle") != NULL) { + backend_ctx->has_qcom_subgroup_shuffle = true; + } + + // Check if ext_buffer contains cl_khr_fp16 + backend_ctx->fp16_support = strstr(ext_buffer, "cl_khr_fp16") != NULL; + + // check Adreno large buffer support + backend_ctx->adreno_has_large_buffer = strstr(ext_buffer, "cl_qcom_large_buffer") != NULL; + + cl_uint base_align_in_bits; + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &base_align_in_bits, NULL)); + GGML_ASSERT(base_align_in_bits % 8u == 0); + backend_ctx->alignment = base_align_in_bits / 8u; + + backend_ctx->global_mem_size = dev_ctx->global_mem_size; + + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, sizeof(size_t), &backend_ctx->max_alloc_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE_MAX_BUFFER_SIZE, sizeof(size_t), &backend_ctx->image_max_buffer_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_WIDTH, sizeof(size_t), &backend_ctx->image2d_max_width, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_IMAGE2D_MAX_HEIGHT, sizeof(size_t), &backend_ctx->image2d_max_height, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(size_t), &backend_ctx->max_workgroup_size, NULL)); + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_SVM_CAPABILITIES, sizeof(cl_device_svm_capabilities), &backend_ctx->svm_caps, 0)); + + if (opencl_c_version.major >= 3) { + // Assume it is not available for 3.0, since it is optional in 3.0. + // If compiling against 3.0, then we can query. + backend_ctx->non_uniform_workgroups = false; +#if CL_TARGET_OPENCL_VERSION >= 300 + CL_CHECK(clGetDeviceInfo(device, CL_DEVICE_NON_UNIFORM_WORK_GROUP_SUPPORT, sizeof(cl_bool), + &backend_ctx->non_uniform_workgroups, 0)); +#endif + } else { + GGML_ASSERT(opencl_c_version.major == 2); + // Non-uniform workgroup sizes is mandatory feature in v2.x. + backend_ctx->non_uniform_workgroups = true; + } + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // determine whether to use Adreno xmem GEMM + backend_ctx->adreno_xmem_gemm_enabled = getenv("GGML_OPENCL_ADRENO_XMEM_GEMM") != nullptr && + backend_ctx->gpu_family == GPU_FAMILY::ADRENO; +#endif + + // determine whether to use large buffer for Adreno + backend_ctx->adreno_use_large_buffer = getenv("GGML_OPENCL_ADRENO_USE_LARGE_BUFFER") != nullptr && + backend_ctx->gpu_family == GPU_FAMILY::ADRENO; + + cl_int err; + + // A local ref of cl_context for convenience + cl_context context = backend_ctx->context = dev_ctx->context; + + //CL_CHECK((queue = clCreateCommandQueue(context, device, CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE, &err), + // (err != CL_INVALID_QUEUE_PROPERTIES && err != CL_INVALID_VALUE ? err : + // (queue = clCreateCommandQueue(context, device, 0, &err), err) + //))); + cl_command_queue_properties command_queue_props = 0; +#ifdef GGML_OPENCL_PROFILING + command_queue_props |= CL_QUEUE_PROFILING_ENABLE; +#endif + CL_CHECK((backend_ctx->queue = clCreateCommandQueue(context, device, command_queue_props, &err), err)); + + // delay kernel loading until the first buffer is created + // load_cl_kernels(backend_ctx.get()); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Allocate intermediate buffers and images + size_t required_A_q_d_bytes = 311164928; + size_t required_A_s_d_bytes = 38895616; + size_t required_B_d_bytes = 45088768; + + // Ensure buffer sizes do not exceed the maximum allocation size + size_t max_A_q_d_bytes = MIN(required_A_q_d_bytes, backend_ctx->max_alloc_size); + size_t max_A_s_d_bytes = MIN(required_A_s_d_bytes, backend_ctx->max_alloc_size); + size_t max_B_d_bytes = MIN(required_B_d_bytes, backend_ctx->max_alloc_size); + if (required_A_q_d_bytes > backend_ctx->max_alloc_size) { + GGML_LOG_WARN("ggml_opencl: A_q_d buffer size reduced from %zu to %zu due to device limitations.\n", + required_A_q_d_bytes, max_A_q_d_bytes); + } + if (required_A_s_d_bytes > backend_ctx->max_alloc_size) { + GGML_LOG_WARN("ggml_opencl: A_s_d buffer size reduced from %zu to %zu due to device limitations.\n", + required_A_s_d_bytes, max_A_s_d_bytes); + } + if (required_B_d_bytes > backend_ctx->max_alloc_size) { + GGML_LOG_WARN("ggml_opencl: B_d buffer size reduced from %zu to %zu due to device limitations.\n", + required_B_d_bytes, max_B_d_bytes); + } + + backend_ctx->prealloc_quant_trans.allocate(context, max_A_q_d_bytes); + backend_ctx->prealloc_scales_trans.allocate(context, max_A_s_d_bytes); + backend_ctx->prealloc_act_trans.allocate(context, max_B_d_bytes); +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr; + + dev_ctx->backend_ctx = backend_ctx.release(); + return dev_ctx->backend_ctx; +} + +static void ggml_cl_free(ggml_backend_t backend) { + ggml_backend_opencl_context * ctx = (ggml_backend_opencl_context *) backend->context; + ctx->free(); +} + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS +static void transpose_2d( + ggml_backend_opencl_context * backend_ctx, + cl_kernel kernel, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + static ggml_cl_buffer buf; + + cl_event evt; + cl_int err; + + buf.allocate(backend_ctx->context, size); + + cl_mem trans; + cl_buffer_region region; + + region.origin = 0; + region.size = size; + CL_CHECK((trans = clCreateSubBuffer( + buf.buffer, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &src)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &stride)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &rows)); + + size_t local_size[3] = {64, 1, 1}; + size_t global_size[3] = {(size_t)stride, (size_t)rows, 1};; + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, + global_size, local_size, 0, NULL, NULL)); + + if (blocking) { + CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseEvent(evt)); + } else { + CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, NULL)); + } + + CL_CHECK(clReleaseMemObject(trans)); +} + +static void transpose_2d_as_8b( + ggml_backend_opencl_context * backend_ctx, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + transpose_2d(backend_ctx, backend_ctx->kernel_transpose_8_buf, + src, dst, size, stride, rows, blocking); +} + +static void transpose_2d_as_16b( + ggml_backend_opencl_context * backend_ctx, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + transpose_2d(backend_ctx, backend_ctx->kernel_transpose_16_buf, + src, dst, size, stride, rows, blocking); +} + +static void transpose_2d_as_32b( + ggml_backend_opencl_context * backend_ctx, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + transpose_2d(backend_ctx, backend_ctx->kernel_transpose_32_buf, + src, dst, size, stride, rows, blocking); +} +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +//------------------------------------------------------------------------------ +// Tensor extra management +//------------------------------------------------------------------------------ +struct ggml_tensor_extra_cl { + // The buffer object that holds the data. + cl_mem data_device; + // The offset into the buffer object. This is primarily for scratch buffer + // and view operation. + // NB: this offset no longer includes view offset (view_offs). Whenever this + // offset is used, view_offs should be considered. + cl_ulong offset; + // The actual size of the cl_mem object. This is needed when returning the + // block to the pool. + size_t actual_size; + + void reset() { + data_device = nullptr; + offset = 0; + actual_size = 0; + } +}; + +// Additional tensor extra structs for quantized tensors. +// These tensors are loaded from files and should not be allocated in scratch -- +// they should always be allocated from the pool. Hence, they do not have an +// `offset`, which indicate their locations in the scratch buffer. +struct ggml_tensor_extra_cl_q4_0 { + // Quantized values. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Size of quantized values. + size_t size_q = 0; + // Size of scales. + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q4_0() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } + // Currently, q_img and d_img are only initialized when SMALL_ALLOC is + // enabled. They point to the images in ggml_backend_opencl_buffer_context. + // So, there is no need to release them here. + // TODO: initialize them for non SMALL_PATH path, or remove them. + d_img = nullptr; + size_q = 0; + size_d = 0; + } +}; + +struct ggml_tensor_extra_cl_q4_1 { + // Quantized values. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Min + cl_mem m = nullptr; + // Min in image1d_buffer_t. + cl_mem m_img = nullptr; + // Size of quantized values. + size_t size_q = 0; + // Size of scales. + size_t size_d = 0; + // Size of min values. + size_t size_m = 0; + + ~ggml_tensor_extra_cl_q4_1() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (m != nullptr) { + CL_CHECK(clReleaseMemObject(m)); + m = nullptr; + } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } + // Currently, q_img and d_img are only initialized when SMALL_ALLOC is + // enabled. They point to the images in ggml_backend_opencl_buffer_context. + // So, there is no need to release them here. + // TODO: initialize them for non SMALL_PATH path, or remove them. + d_img = nullptr; + m_img = nullptr; + size_q = 0; + size_d = 0; + size_m = 0; + } +}; + +struct ggml_tensor_extra_cl_q5_0 { + // Quantized values. + cl_mem qs = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem qs_img = nullptr; + // 5-th bit values. + cl_mem qh = nullptr; + // 5-th bit values in image1d_buffer_t. + cl_mem qh_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Size of quantized values. + size_t size_qs = 0; + // Size of 5-th bit values. + size_t size_qh = 0; + // Size of scales. + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q5_0() { + reset(); + } + + void reset() { + if (qs != nullptr) { + CL_CHECK(clReleaseMemObject(qs)); + qs = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (qs_img != nullptr) { + CL_CHECK(clReleaseMemObject(qs_img)); + qs_img = nullptr; + } + + qh_img = nullptr; + d_img = nullptr; + size_qs = 0; + size_qh = 0; + size_d = 0; + } +}; + +struct ggml_tensor_extra_cl_q5_1 { + // Quantized values. + cl_mem qs = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem qs_img = nullptr; + // 5-th bit values. + cl_mem qh = nullptr; + // 5-th bit values in image1d_buffer_t. + cl_mem qh_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Min + cl_mem m = nullptr; + // Min in image1d_buffer_t. + cl_mem m_img = nullptr; + // Size of quantized values. + size_t size_qs = 0; + // Size of 5-th bit values. + size_t size_qh = 0; + // Size of scales. + size_t size_d = 0; + // Size of min values. + size_t size_m = 0; + + ~ggml_tensor_extra_cl_q5_1() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (qs != nullptr) { + CL_CHECK(clReleaseMemObject(qs)); + qs = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (m != nullptr) { + CL_CHECK(clReleaseMemObject(m)); + m = nullptr; + } + if (qs_img != nullptr) { + CL_CHECK(clReleaseMemObject(qs_img)); + qs_img = nullptr; + } + // qh_img, d_img, and m_img are not currently allocated separately. + // TODO: initialize them for non SMALL_PATH path, or remove them. + qh_img = nullptr; + d_img = nullptr; + m_img = nullptr; + size_qs = 0; + size_qh = 0; + size_d = 0; + size_m = 0; + } +}; + +struct ggml_tensor_extra_cl_mxfp4 { + // Quantized values. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales in E8M0. + cl_mem e = nullptr; + // Scales in image1d_buffer_t. + cl_mem e_img = nullptr; + // Size of quantized values. + size_t size_q = 0; + // Size of scales. + size_t size_e = 0; + + ~ggml_tensor_extra_cl_mxfp4() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (e != nullptr) { + CL_CHECK(clReleaseMemObject(e)); + e = nullptr; + } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } + // Currently, e_img is not used. They can be image1d_buffer_t + // that wraps around q and d to utilize image access path. + e_img = nullptr; + size_q = 0; + size_e = 0; + } +}; + +struct ggml_tensor_extra_cl_q8_0 { + cl_mem q = nullptr; + cl_mem q_img = nullptr; + + cl_mem d = nullptr; + cl_mem d_img = nullptr; + + size_t size_q = 0; + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q8_0() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + // Currently, q_img and d_img are not used. They can be image1d_buffer_t + // that wraps around q and d to utilize image access path. + q_img = nullptr; + d_img = nullptr; + size_q = 0; + size_d = 0; + } +}; + +struct ggml_tensor_extra_cl_iq4_nl { + cl_mem q = nullptr; + cl_mem q_img = nullptr; + + cl_mem d = nullptr; + cl_mem d_img = nullptr; + + size_t size_q = 0; + size_t size_d = 0; + + ~ggml_tensor_extra_cl_iq4_nl() { + reset(); + } + + void reset() { + if (q != nullptr) { CL_CHECK(clReleaseMemObject(q)); q = nullptr; } + if (d != nullptr) { CL_CHECK(clReleaseMemObject(d)); d = nullptr; } + q_img = nullptr; + d_img = nullptr; + size_q = 0; + size_d = 0; + } +}; + +struct ggml_tensor_extra_cl_q4_K { + // Quantized values + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales for each super block. + cl_mem s = nullptr; + // Scales + cl_mem d = nullptr; + // Min + cl_mem dm = nullptr; + + ~ggml_tensor_extra_cl_q4_K() { + reset(); + } + + void reset() { + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (dm != nullptr) { + CL_CHECK(clReleaseMemObject(dm)); + dm = nullptr; + } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } + } +}; + +struct ggml_tensor_extra_cl_q5_K { + // Lower 4 bits of quantized weights. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Upper 1 bit of quantized weights. + cl_mem qh = nullptr; + // Scales for each block. + cl_mem s = nullptr; + // Scales for each super block. + cl_mem d = nullptr; + // Min for each super block. + cl_mem dm = nullptr; + + size_t size_q = 0; + size_t size_qh = 0; + size_t size_s = 0; + size_t size_d = 0; + size_t size_dm = 0; + + ~ggml_tensor_extra_cl_q5_K() { + reset(); + } + + void reset() { + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (dm != nullptr) { + CL_CHECK(clReleaseMemObject(dm)); + dm = nullptr; + } + if (q_img != nullptr) { + CL_CHECK(clReleaseMemObject(q_img)); + q_img = nullptr; + } + + size_q = 0; + size_qh = 0; + size_s = 0; + size_d = 0; + size_dm = 0; + } +}; + +struct ggml_tensor_extra_cl_q6_K { + // Lower 4 bits of quantized weights. + cl_mem ql = nullptr; + // Lower 4 bits as image1d_buffer_t + cl_mem ql_img = nullptr; + // Upper 2 bits of quantized weights. + cl_mem qh = nullptr; + // Scales for each block. + cl_mem s = nullptr; + // Scales for each super block. + cl_mem d = nullptr; + + size_t size_ql = 0; + size_t size_qh = 0; + size_t size_s = 0; + size_t size_d = 0; + + ~ggml_tensor_extra_cl_q6_K() { + reset(); + } + + void reset() { + if (ql != nullptr) { + CL_CHECK(clReleaseMemObject(ql)); + ql = nullptr; + } + if (qh != nullptr) { + CL_CHECK(clReleaseMemObject(qh)); + qh = nullptr; + } + if (s != nullptr) { + CL_CHECK(clReleaseMemObject(s)); + s = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (ql_img != nullptr) { + CL_CHECK(clReleaseMemObject(ql_img)); + ql_img = nullptr; + } + + size_ql = 0; + size_qh = 0; + size_s = 0; + size_d = 0; + } +}; + +//------------------------------------------------------------------------------ +// Backend API +//------------------------------------------------------------------------------ + +// +// backend +// +static const char * ggml_backend_opencl_name(ggml_backend_t backend) { + return "OpenCL"; + + UNUSED(backend); +} + +static void ggml_backend_opencl_free(ggml_backend_t backend) { + ggml_cl_free(backend); +} + +static void ggml_backend_opencl_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + GGML_UNUSED(backend); + GGML_UNUSED(tensor); + GGML_UNUSED(data); + GGML_UNUSED(offset); + GGML_UNUSED(size); +} + +static void ggml_backend_opencl_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_UNUSED(backend); + GGML_UNUSED(tensor); + GGML_UNUSED(data); + GGML_UNUSED(offset); + GGML_UNUSED(size); +} + +static bool ggml_backend_opencl_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { + GGML_UNUSED(backend); + GGML_UNUSED(src); + GGML_UNUSED(dst); + return false; +} + +static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { + auto * backend_ctx = static_cast<ggml_backend_opencl_context *>(backend->context); + + cl_event evt; + CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, 0, nullptr, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseEvent(evt)); +} + +// Synchronizes the 'backend_ctx's device with others so that commands +// enqueued to it won't start until commands in the other devices have +// completed. +static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { + if (g_ggml_backend_opencl_devices.size() < 2) { + return; // No other devices to synchronize with. + } + + std::vector<cl_event> events; + events.reserve(g_ggml_backend_opencl_devices.size()); + + for (ggml_backend_device & backend_dev : g_ggml_backend_opencl_devices) { + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) backend_dev.context; + auto * other_backend_ctx = dev_ctx->backend_ctx; + + if (backend_ctx != other_backend_ctx) { + cl_event ev; + CL_CHECK(clEnqueueMarkerWithWaitList(other_backend_ctx->queue, 0, nullptr, &ev)); + CL_CHECK(clFlush(other_backend_ctx->queue)); + events.push_back(ev); + } + } + + CL_CHECK(clEnqueueBarrierWithWaitList(backend_ctx->queue, events.size(), events.data(), nullptr)); + for (auto ev : events) { + CL_CHECK(clReleaseEvent(ev)); + } +} + +static void sync_with_other_backends(ggml_backend_t backend) { + auto * backend_ctx = static_cast<ggml_backend_opencl_context *>(backend->context); + sync_with_other_backends(backend_ctx); +} + +static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && + !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + + // rms_norm assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) { + const ggml_tensor *norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + const ggml_tensor *add = cgraph->nodes[node_idx+2]; + const ggml_tensor *w = mul->src[0] == norm ? mul->src[1] : mul->src[0]; + const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0]; + + // norm fusion only supports F32 + if (norm->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) { + return false; + } + + if (norm->src[0]->ne[0] % 4 != 0) { + return false; + } + + if (!ggml_is_contiguous(norm->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) { + return false; + } + } else if (ops.size() == 3 && ops.begin()[0] == GGML_OP_GROUP_NORM && ops.begin()[1] == GGML_OP_MUL && ops.begin()[2] == GGML_OP_ADD) { + const ggml_tensor *gn = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + const ggml_tensor *add = cgraph->nodes[node_idx+2]; + const ggml_tensor *w = mul->src[0] == gn ? mul->src[1] : mul->src[0]; + const ggml_tensor *b = add->src[0] == mul ? add->src[1] : add->src[0]; + + if (gn->src[0]->type != GGML_TYPE_F32 || w->type != GGML_TYPE_F32 || b->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous(gn->src[0]) || !ggml_is_contiguous(w) || !ggml_is_contiguous(b)) { + return false; + } + } + + return true; +} + +static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor); +static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); +static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor); + +static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + // NOTE: this may oversynchronize by synchronizing with + // backends/devices which don't compute 'cgraph's + // dependencies. + sync_with_other_backends(backend); + + if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { + continue; + } + + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) { + ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_GROUP_NORM, GGML_OP_MUL, GGML_OP_ADD })) { + ggml_opencl_op_group_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]); + i++; + continue; + } + + bool ok = ggml_cl_compute_forward(backend, node); + if (!ok) { + GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); + } + GGML_ASSERT(ok); + } + + return GGML_STATUS_SUCCESS; +} + +// The optimized gemm and gemv kernels are used for large matrices without batch. +// tensor is the quantized weights matrix. +inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + int64_t threshold_ne0 = 512; + int64_t threshold_ne1 = 512; + if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && + backend_ctx->adreno_cl_compiler_version.type != DX) { + threshold_ne0 = 128; + threshold_ne1 = 128; + } + return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && + tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + GGML_UNUSED(backend_ctx); + int ne01 = tensor->ne[1]; + return (((strstr(tensor->name, "ffn") != NULL) && (strstr(tensor->name, "exps") != NULL)) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 32 == 0); +} + +inline bool enable_adreno_trans_weight(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { + + bool adreno_kernel = use_adreno_kernels(backend_ctx, tensor); + + size_t elem_num = tensor->ne[0] * tensor->ne[1] * tensor->ne[2] * tensor->ne[3]; + + return ((elem_num < 128 * 1024 * 1024) && adreno_kernel); // max element num: 2**27 +} + +static inline bool use_flat_gemv_for_large_m_q4_K(const ggml_tensor *tensor) { + // gemv_noshuffle variant perf drops for large M, use flat variant for large M. + // threshold is well above typical hidden/FFN dims, but below typical vocab sizes. + // note that this forces large M weights to use LM GEMM. + return tensor->ne[1] >= 32768 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static inline bool use_flat_gemv_for_large_m_q6_K(const ggml_tensor *tensor) { + // gemv_noshuffle variant perf drops for large M, use flat variant for large M. + // threshold is well above typical hidden/FFN dims, but below typical vocab sizes. + // q6_K flat gemv is worse for smaller K; 2048 seems to be a reasonable threshold. + // note that this forces large M weights to use LM GEMM. + return tensor->ne[1] >= 32768 && tensor->ne[0] >= 2048 && tensor->ne[2] == 1 && tensor->ne[3] == 1; +} + +static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *)dev->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; + + // reject ops that match the opfilter regex + if (dev_ctx->opfilter && std::regex_match(std::string(ggml_op_desc(op)), *dev_ctx->opfilter)) { + return false; + } + + switch (op->op) { + case GGML_OP_NONE: + return true; + case GGML_OP_GET_ROWS: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + return true; + case GGML_TYPE_Q4_0: +#ifdef GGML_OPENCL_SOA_Q + // We do not support flattened Q4_0 (and possibly other Q's) + return false; +#else // GGML_OPENCL_SOA_Q + return true; +#endif // GGML_OPENCL_SOA_Q + default: + return false; + } + case GGML_OP_SET_ROWS: + { + // TODO: add support + // ref: https://github.com/ggml-org/llama.cpp/pull/14274 +#pragma message("TODO: implement BF16, Q4_0, Q4_1, Q5_0, Q5_1, Q8_0, IQ4_NL support (https://github.com/ggml-org/llama.cpp/pull/14661)") + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); + default: + return false; + } + } + case GGML_OP_CPY: + case GGML_OP_DUP: + case GGML_OP_CONT: + switch (op->src[0]->type) { + case GGML_TYPE_F32: + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + case GGML_TYPE_F16: + switch (op->type) { + case GGML_TYPE_F16: + case GGML_TYPE_F32: + return true; + default: + return false; + } + case GGML_TYPE_I32: + switch (op->type) { + case GGML_TYPE_I32: + return true; + default: + return false; + } + default: + return false; + } + case GGML_OP_SET: { + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32) && + op->type == op->src[0]->type && + op->type == op->src[1]->type; + } + case GGML_OP_SCALE: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); + case GGML_OP_ADD: + if (op->type == GGML_TYPE_F16) { + const bool src0_ok = op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32; + const bool src1_ok = op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32; + if (src0_ok && src1_ok) { + return true; + } + } + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_SUB: + return (op->src[0]->type == op->src[1]->type) && + (op->src[0]->type == op->type) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); + case GGML_OP_ADD_ID: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SQR: + case GGML_OP_SQRT: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + ggml_is_contiguous(op->src[0]); + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_UNARY_OP_SIGMOID: + return ggml_is_contiguous(op->src[0]); + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + case GGML_UNARY_OP_EXPM1: + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + case GGML_UNARY_OP_SOFTPLUS: + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; + default: + return false; + } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous_1(op->src[0]) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); + default: + return false; + } + case GGML_OP_TRI: + return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op); + case GGML_OP_FILL: + return op->type == GGML_TYPE_F32 && ggml_is_contiguous(op); + case GGML_OP_CLAMP: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SOFT_MAX: + case GGML_OP_NORM: + return true; + case GGML_OP_RMS_NORM: + return op->ne[0] % 4 == 0 && ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_L2_NORM: + return ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_REPEAT: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded + case GGML_OP_PAD: + // TODO: add circular padding support for opencl, see https://github.com/ggml-org/llama.cpp/pull/16985 + if (ggml_get_op_params_i32(op, 8) != 0) { + return false; + } + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_UPSCALE: { + ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF); + const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS); + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && + (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias; + } + case GGML_OP_CONV_2D: + return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || + (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || + (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); + case GGML_OP_SSM_CONV: + return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); + case GGML_OP_GATED_DELTA_NET: + { + // Match the Vulkan backend: only F32 -> F32, S_v in {16, 32, 64, 128}. + if (op->src[0]->type != GGML_TYPE_F32 || op->type != GGML_TYPE_F32) { + return false; + } + const int64_t S_v = op->src[2]->ne[0]; + return S_v == 16 || S_v == 32 || S_v == 64 || S_v == 128; + } + case GGML_OP_CONCAT: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_TIMESTEP_EMBEDDING: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_GROUP_NORM: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_MUL_MAT: + if (op->src[0]->type == GGML_TYPE_F16) { + return true; + } else if (op->src[0]->type == GGML_TYPE_BF16) { + return true; + } else if (op->src[0]->type == GGML_TYPE_F32) { + return op->src[1]->type == GGML_TYPE_F32; + } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_Q5_0 || op->src[0]->type == GGML_TYPE_Q5_1 || + op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_IQ4_NL || + op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || + op->src[0]->type == GGML_TYPE_Q6_K) { + return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + } else if (op->src[0]->type == GGML_TYPE_Q8_0) { + return op->src[1]->type == GGML_TYPE_F32; + } + return false; + case GGML_OP_MUL_MAT_ID: + if (op->src[0]->type == GGML_TYPE_Q4_0 || + op->src[0]->type == GGML_TYPE_Q8_0 || + op->src[0]->type == GGML_TYPE_MXFP4) { + if (op->src[1]->type == GGML_TYPE_F32) { + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); + } + } + // q4_0, q8_0 and mxfp4 have general MUL_MAT_ID support, + // the quantizations here currently do not - they are only supported by Adreno with certain shapes + if (op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_Q5_0 || + op->src[0]->type == GGML_TYPE_Q5_1 || + op->src[0]->type == GGML_TYPE_Q4_K || + op->src[0]->type == GGML_TYPE_Q5_K || + op->src[0]->type == GGML_TYPE_Q6_K) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (op->src[1]->type == GGML_TYPE_F32) { + return use_adreno_moe_kernels(backend_ctx, op->src[0]) + && ggml_is_contiguous(op->src[0]) + && ggml_is_contiguous(op->src[1]); + } +#endif + return false; + } + return false; + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + return true; + case GGML_OP_DIAG: + return true; + case GGML_OP_DIAG_MASK_INF: + return op->ne[3] == 1; + case GGML_OP_ROPE: { + const int mode = ((const int32_t *) op->op_params)[2]; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + if (is_mrope && !is_vision) { + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16) { + return true; + } + return false; + } + if (is_vision) { + if (op->src[0]->type == GGML_TYPE_F32 || + op->src[0]->type == GGML_TYPE_F16) { + return true; + } + return false; + } + return true; + } + case GGML_OP_SOLVE_TRI: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); + case GGML_OP_IM2COL: + return true; + case GGML_OP_ARGSORT: { + load_cl_kernels_argsort(backend_ctx); + + cl_kernel kernel = backend_ctx->kernel_argsort_f32_i32; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + + int cols = 1; + while (cols < op->ne[0]) { + cols *= 2; + } + + return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32; + } + case GGML_OP_SUM_ROWS: + case GGML_OP_CUMSUM: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); + case GGML_OP_MEAN: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_FLASH_ATTN_EXT: + { + load_cl_kernels_flash_attn(backend_ctx); + + const ggml_tensor * q = op->src[0]; + const ggml_tensor * k = op->src[1]; + const ggml_tensor * v = op->src[2]; + + const int dk = q->ne[0]; + const int dv = v->ne[0]; + + const struct { int dk; int dv; } supported_dims[] = { + { 40, 40}, { 64, 64}, { 80, 80}, { 96, 96}, + {112, 112}, {128, 128}, {192, 128}, + {192, 192}, {256, 256}, + }; + + bool dims_supported = false; + for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) { + if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) { + dims_supported = true; + break; + } + } + if (!dims_supported) { + return false; + } + + const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 && + v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 && + v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16; + const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && + v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32; + + return is_f32_f32 || is_f16_f16 || is_f32_f16; + } + default: + return false; + } +} + +// Forward declaration - implementation appears later in the file. +static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type); + +static ggml_guid_t ggml_backend_opencl_guid() { + static ggml_guid guid = { 0xde, 0xe0, 0x70, 0xa2, 0x73, 0x4e, 0x4d, 0xbc, 0xb0, 0xc7, 0x4f, 0xd4, 0x6d, 0x4e, 0x90, 0xfe }; + return &guid; +} + +static ggml_backend_i ggml_backend_opencl_i = { + /* .get_name = */ ggml_backend_opencl_name, + /* .free = */ ggml_backend_opencl_free, + /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ + /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ + /* .synchronize = */ ggml_backend_opencl_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_opencl_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, +}; + +ggml_backend_t ggml_backend_opencl_init(void) { + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0); + ggml_backend_opencl_context *backend_ctx = ggml_cl_init(dev); + + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_opencl_guid(), + /* .iface = */ ggml_backend_opencl_i, + /* .device = */ dev, + /* .context = */ backend_ctx + }; + + return backend; +} + +bool ggml_backend_is_opencl(ggml_backend_t backend) { + return backend && backend->iface.get_name == ggml_backend_opencl_name; +} + +// +// buffer +// +struct ggml_backend_opencl_buffer_context { + // A buffer context can hold multiple cl_mem objects. This is for flattening + // quantized weights and should be used with GGML_OPENCL_SMALL_ALLOC where + // each tensor is allocated a separate buffer. When flattening is enabled + // with small allocation, each tensor is backed by two cl_mem objects (for + // quants and scales) packed into a backend_opencl_buffer. + ggml_backend_opencl_buffer_context(cl_mem buf) + : name("OpenCL") { + buffer.push_back(buf); + } + + ~ggml_backend_opencl_buffer_context() { + for (cl_mem buf : buffer) { + CL_CHECK(clReleaseMemObject(buf)); + } + for (cl_mem im : img) { + CL_CHECK(clReleaseMemObject(im)); + } + + // Delete all extras to trigger their destructors + for (ggml_tensor_extra_cl * e : temp_tensor_extras) { + delete e; + } + for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0) { + delete e; + } + for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1) { + delete e; + } + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q5_0 * e : temp_tensor_extras_q5_0) { + delete e; + } + for (ggml_tensor_extra_cl_q5_0 * e : temp_tensor_extras_q5_0_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q5_1 * e : temp_tensor_extras_q5_1) { + delete e; + } + for (ggml_tensor_extra_cl_q5_1 * e : temp_tensor_extras_q5_1_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) { + delete e; + } + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0) { + delete e; + } + for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl) { + delete e; + } + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K) { + delete e; + } + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K) { + delete e; + } + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { + delete e; + } + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K) { + delete e; + } + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K_in_use) { + delete e; + } + } + + ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() { + ggml_tensor_extra_cl * extra; + if (temp_tensor_extras.empty()) { + extra = new ggml_tensor_extra_cl(); + } else { + extra = temp_tensor_extras.back(); + temp_tensor_extras.pop_back(); + } + + temp_tensor_extras_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q4_0 * ggml_opencl_alloc_temp_tensor_extra_q4_0() { + ggml_tensor_extra_cl_q4_0 * extra; + if (temp_tensor_extras_q4_0.empty()) { + extra = new ggml_tensor_extra_cl_q4_0(); + } else { + extra = temp_tensor_extras_q4_0.back(); + temp_tensor_extras_q4_0.pop_back(); + } + + temp_tensor_extras_q4_0_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q4_1 * ggml_opencl_alloc_temp_tensor_extra_q4_1() { + ggml_tensor_extra_cl_q4_1 * extra; + if (temp_tensor_extras_q4_1.empty()) { + extra = new ggml_tensor_extra_cl_q4_1(); + } else { + extra = temp_tensor_extras_q4_1.back(); + temp_tensor_extras_q4_1.pop_back(); + } + + temp_tensor_extras_q4_1_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q5_0 * ggml_opencl_alloc_temp_tensor_extra_q5_0() { + ggml_tensor_extra_cl_q5_0 * extra; + if (temp_tensor_extras_q5_0.empty()) { + extra = new ggml_tensor_extra_cl_q5_0(); + } else { + extra = temp_tensor_extras_q5_0.back(); + temp_tensor_extras_q5_0.pop_back(); + } + + temp_tensor_extras_q5_0_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q5_1 * ggml_opencl_alloc_temp_tensor_extra_q5_1() { + ggml_tensor_extra_cl_q5_1 * extra; + if (temp_tensor_extras_q5_1.empty()) { + extra = new ggml_tensor_extra_cl_q5_1(); + } else { + extra = temp_tensor_extras_q5_1.back(); + temp_tensor_extras_q5_1.pop_back(); + } + + temp_tensor_extras_q5_1_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() { + ggml_tensor_extra_cl_mxfp4 * extra; + if (temp_tensor_extras_mxfp4.empty()) { + extra = new ggml_tensor_extra_cl_mxfp4(); + } else { + extra = temp_tensor_extras_mxfp4.back(); + temp_tensor_extras_mxfp4.pop_back(); + } + + temp_tensor_extras_mxfp4_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q8_0 * ggml_opencl_alloc_temp_tensor_extra_q8_0() { + ggml_tensor_extra_cl_q8_0 * extra; + if (temp_tensor_extras_q8_0.empty()) { + extra = new ggml_tensor_extra_cl_q8_0(); + } else { + extra = temp_tensor_extras_q8_0.back(); + temp_tensor_extras_q8_0.pop_back(); + } + + temp_tensor_extras_q8_0_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_iq4_nl * ggml_opencl_alloc_temp_tensor_extra_iq4_nl() { + ggml_tensor_extra_cl_iq4_nl * extra; + if (temp_tensor_extras_iq4_nl.empty()) { + extra = new ggml_tensor_extra_cl_iq4_nl(); + } else { + extra = temp_tensor_extras_iq4_nl.back(); + temp_tensor_extras_iq4_nl.pop_back(); + } + + temp_tensor_extras_iq4_nl_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q4_K * ggml_opencl_alloc_temp_tensor_extra_q4_K() { + ggml_tensor_extra_cl_q4_K * extra; + if (temp_tensor_extras_q4_K.empty()) { + extra = new ggml_tensor_extra_cl_q4_K(); + } else { + extra = temp_tensor_extras_q4_K.back(); + temp_tensor_extras_q4_K.pop_back(); + } + + temp_tensor_extras_q4_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q5_K * ggml_opencl_alloc_temp_tensor_extra_q5_K() { + ggml_tensor_extra_cl_q5_K * extra; + if (temp_tensor_extras_q5_K.empty()) { + extra = new ggml_tensor_extra_cl_q5_K(); + } else { + extra = temp_tensor_extras_q5_K.back(); + temp_tensor_extras_q5_K.pop_back(); + } + + temp_tensor_extras_q5_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + ggml_tensor_extra_cl_q6_K * ggml_opencl_alloc_temp_tensor_extra_q6_K() { + ggml_tensor_extra_cl_q6_K * extra; + if (temp_tensor_extras_q6_K.empty()) { + extra = new ggml_tensor_extra_cl_q6_K(); + } else { + extra = temp_tensor_extras_q6_K.back(); + temp_tensor_extras_q6_K.pop_back(); + } + + temp_tensor_extras_q6_K_in_use.push_back(extra); + + extra->reset(); + return extra; + } + + void reset() { + for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { + temp_tensor_extras.push_back(e); + } + temp_tensor_extras_in_use.clear(); + + for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { + temp_tensor_extras_q4_0.push_back(e); + } + temp_tensor_extras_q4_0_in_use.clear(); + + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { + temp_tensor_extras_q4_1.push_back(e); + } + temp_tensor_extras_q4_1_in_use.clear(); + + for (ggml_tensor_extra_cl_q5_0 * e : temp_tensor_extras_q5_0_in_use) { + temp_tensor_extras_q5_0.push_back(e); + } + temp_tensor_extras_q5_0_in_use.clear(); + + for (ggml_tensor_extra_cl_q5_1 * e : temp_tensor_extras_q5_1_in_use) { + temp_tensor_extras_q5_1.push_back(e); + } + temp_tensor_extras_q5_1_in_use.clear(); + + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { + temp_tensor_extras_mxfp4.push_back(e); + } + temp_tensor_extras_mxfp4_in_use.clear(); + + for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { + temp_tensor_extras_q8_0.push_back(e); + } + temp_tensor_extras_q8_0_in_use.clear(); + + for (ggml_tensor_extra_cl_iq4_nl * e : temp_tensor_extras_iq4_nl_in_use) { + temp_tensor_extras_iq4_nl.push_back(e); + } + temp_tensor_extras_iq4_nl_in_use.clear(); + + for (ggml_tensor_extra_cl_q4_K * e : temp_tensor_extras_q4_K_in_use) { + temp_tensor_extras_q4_K.push_back(e); + } + temp_tensor_extras_q4_K_in_use.clear(); + + for (ggml_tensor_extra_cl_q5_K * e : temp_tensor_extras_q5_K_in_use) { + temp_tensor_extras_q5_K.push_back(e); + } + temp_tensor_extras_q5_K_in_use.clear(); + + for (ggml_tensor_extra_cl_q6_K * e : temp_tensor_extras_q6_K_in_use) { + temp_tensor_extras_q6_K.push_back(e); + } + temp_tensor_extras_q6_K_in_use.clear(); + } + + // Pools for extras. Available extras are in `temp_tensor_extras`. Extras + // being used are in `temp_tensor_extras_in_use`. At the first run, new + // extras get created and put in `in_use`. When the buffer is reset via + // the `reset` callback, all extras in `in_use` get moved to available extras + // for reuse. + std::vector<ggml_tensor_extra_cl *> temp_tensor_extras; + std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use; + std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0; + std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use; + std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1; + std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1_in_use; + std::vector<ggml_tensor_extra_cl_q5_0 *> temp_tensor_extras_q5_0; + std::vector<ggml_tensor_extra_cl_q5_0 *> temp_tensor_extras_q5_0_in_use; + std::vector<ggml_tensor_extra_cl_q5_1 *> temp_tensor_extras_q5_1; + std::vector<ggml_tensor_extra_cl_q5_1 *> temp_tensor_extras_q5_1_in_use; + std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4; + std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use; + std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0; + std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use; + std::vector<ggml_tensor_extra_cl_iq4_nl *> temp_tensor_extras_iq4_nl; + std::vector<ggml_tensor_extra_cl_iq4_nl *> temp_tensor_extras_iq4_nl_in_use; + std::vector<ggml_tensor_extra_cl_q4_K *> temp_tensor_extras_q4_K; + std::vector<ggml_tensor_extra_cl_q4_K *> temp_tensor_extras_q4_K_in_use; + std::vector<ggml_tensor_extra_cl_q5_K *> temp_tensor_extras_q5_K; + std::vector<ggml_tensor_extra_cl_q5_K *> temp_tensor_extras_q5_K_in_use; + std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K; + std::vector<ggml_tensor_extra_cl_q6_K *> temp_tensor_extras_q6_K_in_use; + + // The buffer_context is initially created by ggml_backend_buft_alloc_buffer + // before any tensor is initialized (at the beginning of alloc_tensor_range). + // Hence, there is always a buffer object in this vector. When each tensor is + // being initialized, this original buffer object will be released if both + // flattening and small allocation are enabled, and additional buffer + // objects will be created in init_tensor to represent flattened quantized + // weights. + std::vector<cl_mem> buffer; + // These are image1d_buffer_t objects that wrap around the quants and scales. + // For Q4_0 quantization, there should be two of them - one for quants and + // one for scales. They should be populated only when flattening and small + // allocation are enabled. + std::vector<cl_mem> img; + std::string name; +}; + +static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + delete ctx; +} + +static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + return (void *) (uintptr_t) dev_ctx->backend_ctx->alignment; +} + +static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + + ggml_tensor_extra_cl * view_extra = (ggml_tensor_extra_cl *) tensor->view_src->extra; + GGML_ASSERT(view_extra && "view_extra is nullptr?"); + + // Reuse extra of the parent tensor. The offset of this view tensor + // becomes `extra->offset + view_offs` and needs to be calculated when + // it is used. This changes is needed because of the change to + // ggml_alloc.c in https://github.com/ggml-org/llama.cpp/pull/7640. + // `buffer` passed in here will always be `tensor->buffer`. It is OK + // to allocate extras from the same buffer context for ordinary + // intermediate tensors. But for views into kv cache tensors, doing so + // would mess up the extras used by kv cache. + // Before #7640, `buffer` is for intermediate tensors, which is always + // different from that of kv cache tensors. + // + // NB: now extra->offset no longer accounts for view_offs. + // NB: this should not apply to weight tensors (for end-to-end runs, but + // may apply for test-backend-ops). + // FIXME: if any unexpected results are seen, double check the offset - + // there could be other places that need fix. + tensor->extra = view_extra; + } else { + { + size_t offset = (char *) tensor->data - (char *) ggml_backend_opencl_buffer_get_base(buffer); + + ggml_tensor_extra_cl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra(); + extra->offset = offset; + extra->data_device = ctx->buffer[0]; + extra->actual_size = ggml_nbytes(tensor); + + tensor->extra = extra; + } + } + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; + + cl_context context = backend_ctx->context; + cl_command_queue queue = backend_ctx->queue; + +#ifdef GGML_OPENCL_SOA_Q + // We separate the quantized bits and scale from block_q4_0 by using an + // additional kernel, where each thread handles a block. We first read the + // original weights into a temporary buffer, then create two separate + // buffers for quantized bits and scales, which are then populated by the + // conversion kernel. + if (tensor->type == GGML_TYPE_Q4_0) { + // Tensors should have been preallocated, therefore they should + // already have ggml_tensor_extra_cl as extra. + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_0(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + // We consider the specified offset arg as always, although For weights + // the offset arg should be 0 (we do not assert this). + //GGML_ASSERT(offset == 0); + + // We create subbuffers from the original tensor buffer for scales and + // quants - i.e., scales and quants are aliases into the buffer object + // that backs the original tensor. This is a cleaner way to adapt to the + // new memory management. + // In the old code, we allocate new buffers for scales and quants + // respectively, which could still be done but would result in double + // allocation; properly deallocating the preallocated buffer that backs + // the tensors is tricky and would leak the backend specific information + // into the general backend code. + // Does this create misaligned subbuffers (alignment is 1024) in certain + // cases ? + cl_buffer_region region; + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, then quants. + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q4_0 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; + + // The optimized kernels need weights in natural order, so unshuffle. + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; + } +#else + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + // transpose the weights and scales +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Only do transpose for large, non batched matrix + // TODO: use preallocated images instead of sub-buffer then image + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; + } + if (tensor->type == GGML_TYPE_Q4_1) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_1(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_m + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, mins, then quants. + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for mins. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_m; + extra->m = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_m, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q4_1 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + // normal q4_1 repack +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; + + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle; + } +#else + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + // Transpose m as ushort + transpose_2d_as_16b(backend_ctx, extra->m, extra->m, size_m, K/32, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; + } + if (tensor->type == GGML_TYPE_Q5_0) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q5_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_0(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_qs = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(int32_t); + GGML_ASSERT(size_d + size_qs + size_qh == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for qh. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_qh; + extra->qh = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for qs. + region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + region.size = size_qs; + extra->qs = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q5_0 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_qs = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_qs = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->qs } + }; + extra->qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_qs, &img_desc_qs, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose qs as ushort + transpose_2d_as_16b(backend_ctx, extra->qs, extra->qs, size_qs, K/4, M); + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_0; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + return; + } + if (tensor->type == GGML_TYPE_Q5_1) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q5_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_1(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_qs = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(int32_t); + GGML_ASSERT(size_d + size_m + size_qs + size_qh == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, mins, then quants. + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for mins. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_m; + extra->m = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for qh. + region.origin = align_to(previous_origin + size_m, backend_ctx->alignment); + region.size = size_qh; + extra->qh = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for qs. + region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + region.size = size_qs; + extra->qs = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe q5_1 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_qs = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_qs = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->qs } + }; + extra->qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_qs, &img_desc_qs, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose qs as ushort + transpose_2d_as_16b(backend_ctx, extra->qs, extra->qs, size_qs, K/4, M); + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + // Transpose m as ushort + transpose_2d_as_16b(backend_ctx, extra->m, extra->m, size_m, K/32, M); + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_1; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64) * 64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + return; + } + if (tensor->type == GGML_TYPE_MXFP4) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4(); + + size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, then quants. + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_e; + extra->e = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_e, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno moe mxfp4 kernel needs special transpose and unshuffling + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + tensor->extra = extra; + + // Create image for Q + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); + + size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[3] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for Q + cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(ggml_nelements(tensor)/32*2), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + tensor->extra = extra; + + return; + } + if (tensor->type == GGML_TYPE_Q8_0) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q8_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q8_0(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)*sizeof(char)); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, then quants. + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q8_0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + // Transpose the weights and scales +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (enable_adreno_trans_weight(backend_ctx, tensor)) { + + int M = tensor->ne[1]; // ne01 + int K = tensor->ne[0]; // ne00 + + GGML_ASSERT(K % 32 == 0); + GGML_ASSERT(M % 4 == 0); + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); + + transpose_2d_as_32b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + } // end transpose +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + return; + } + if (tensor->type == GGML_TYPE_IQ4_NL) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tensors in OpenCL backend should have been allocated and initialized"); + + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_iq4_nl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_iq4_nl(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)/2); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_iq4_nl; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_iq4_nl_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_iq4_nl; + #endif + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + // Transpose q as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + } +#endif + return; + } + if (tensor->type == GGML_TYPE_Q4_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_K(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(3 * ggml_blck_size(tensor->type) / 64); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_dm + size_s + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for d. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for mins. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_dm; + extra->dm = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for s. + region.origin = align_to(previous_origin + size_dm, backend_ctx->alignment); + region.size = size_s; + extra->s = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 256), static_cast<size_t>(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + CL_CHECK(err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_K_noshuffle; + } +#else + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_K; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q, d, dm as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/256, M); + transpose_2d_as_16b(backend_ctx, extra->dm, extra->dm, size_dm, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; + } + if (tensor->type == GGML_TYPE_Q5_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q5_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q5_K(); + + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/8; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(3*ggml_blck_size(tensor->type)/64); + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_q + size_qh + size_s + size_d + size_dm == ggml_nbytes(tensor) && + "Incorrect tensor size"); + + cl_int err; + cl_mem data_device; + CL_CHECK((data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err), err)); + CL_CHECK(clEnqueueWriteBuffer(queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // Create subbuffer for d. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for dm. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_dm; + extra->dm = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for s. + region.origin = align_to(previous_origin + size_dm, backend_ctx->alignment); + region.size = size_s; + extra->s = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for q (lower 4 bits) + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for qh (upper 1 bit) + region.origin = align_to(previous_origin + size_q, backend_ctx->alignment); + region.size = size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_k_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 256), static_cast<size_t>(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + cl_image_format img_format_q = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_q = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->q } + }; + extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); + CL_CHECK(err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q5_K_noshuffle; + } +#else + cl_kernel kernel = backend_ctx->kernel_convert_block_q5_K; +#endif + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + extra->size_q = size_q; + extra->size_qh = size_qh; + extra->size_s = size_s; + extra->size_d = size_d; + extra->size_dm = size_dm; + + tensor->extra = extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q, d, dm as ushort, qh as uchar + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + transpose_2d_as_8b (backend_ctx, extra->qh, extra->qh, size_qh, K/8, M); + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/256, M); + transpose_2d_as_16b(backend_ctx, extra->dm, extra->dm, size_dm, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; + } + if (tensor->type == GGML_TYPE_Q6_K) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q6_K * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q6_K(); + + size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) && + "Incorrect tensor size"); + + cl_int err; + cl_mem data_device; + CL_CHECK((data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err), err)); + CL_CHECK(clEnqueueWriteBuffer(queue, data_device, CL_TRUE, 0, ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + // Adreno MoE Q6_K kernel needs special transposed layout + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + size_t moe_size_ql = (size_t)(ggml_nelements(tensor) / 8) * sizeof(uint32_t); // 4 bits per element + size_t moe_size_qh = (size_t)(ggml_nelements(tensor) / 16) * sizeof(uint32_t); // 2 bits per element + size_t moe_size_s = size_s; + size_t moe_size_d = size_d; + + // Subbuffer for ql + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = moe_size_ql; + CL_CHECK((extra->ql = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + auto previous_origin = region.origin; + + // Subbuffer for qh + region.origin = align_to(previous_origin + moe_size_ql, backend_ctx->alignment); + region.size = moe_size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for scales + region.origin = align_to(previous_origin + moe_size_qh, backend_ctx->alignment); + region.size = moe_size_s; + CL_CHECK((extra->s = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for d + region.origin = align_to(previous_origin + moe_size_s, backend_ctx->alignment); + region.size = moe_size_d; + CL_CHECK((extra->d = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q6_k_trans4_ns; + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 256), static_cast<size_t>(ne02)}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + // Create image for ql + cl_image_format img_format_ql = {CL_R, CL_UNSIGNED_INT32}; + cl_image_desc img_desc_ql = { + CL_MEM_OBJECT_IMAGE1D_BUFFER, + static_cast<size_t>(ggml_nelements(tensor) / 8), + 0, 0, 0, 0, 0, 0, 0, + { extra->ql } + }; + extra->ql_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_ql, &img_desc_ql, NULL, &err); + tensor->extra = extra; + + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + // Subbuffer for ql + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_ql; + CL_CHECK((extra->ql = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + auto previous_origin = region.origin; + + // Subbuffer for qh + region.origin = align_to(previous_origin + size_ql, backend_ctx->alignment); + region.size = size_qh; + CL_CHECK((extra->qh = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Subbuffer for scales + region.origin = align_to(previous_origin + size_qh, backend_ctx->alignment); + region.size = size_s; + CL_CHECK((extra->s = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Create subbuffer for d. + region.origin = align_to(previous_origin + size_s, backend_ctx->alignment); + region.size = size_d; + CL_CHECK((extra->d = clCreateSubBuffer(extra_orig->data_device, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + previous_origin = region.origin; + + // Flatten the weights + cl_kernel kernel; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + kernel = backend_ctx->kernel_convert_block_q6_K; + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { + kernel = backend_ctx->kernel_convert_block_q6_K_noshuffle; + } +#else + kernel = backend_ctx->kernel_convert_block_q6_K; +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_uchar mask = 0xff; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n_blk, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + extra->size_ql = size_ql; + extra->size_qh = size_qh; + extra->size_s = size_s; + extra->size_d = size_d; + + tensor->extra = extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { + cl_int M = tensor->ne[1]; // ne01 + cl_int K = tensor->ne[0]; // ne00 + + // Transpose ql as ushort + transpose_2d_as_16b(backend_ctx, + extra->ql, extra->ql, size_ql, K/4, M); + + // Transpose qh as uchar + transpose_2d_as_8b(backend_ctx, + extra->qh, extra->qh, size_qh, K/4, M); + + // Transpose s as ushort + transpose_2d_as_16b(backend_ctx, + extra->s, extra->s, size_s, K/16/2, M); + + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, + extra->d, extra->d, size_d, K/256, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + return; + } +#endif // GGML_OPENCL_SOA_Q + + // convert bf16 to f16 and store as f16 in device buffer + if (tensor->type == GGML_TYPE_BF16) { + GGML_ASSERT(offset % sizeof(ggml_fp16_t) == 0 && size % sizeof(ggml_fp16_t) == 0 + && "Offset and size must be multiples of 2 for bf16 tensors"); + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + cl_ulong n_elements = size / sizeof(ggml_fp16_t); + cl_ulong off_dst = (extra->offset + offset) / sizeof(ggml_fp16_t); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, + size, const_cast<void *>(data), &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_bf16_to_f16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_elements)); + + size_t global_work_size[] = { (size_t)CEIL_DIV(n_elements, 64)*64, 1, 1 }; + size_t local_work_size[] = { 64, 1, 1 }; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + CL_CHECK(clReleaseEvent(evt)); + + return; + } + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); + + CL_CHECK(clEnqueueWriteBuffer( + queue, extra->data_device, CL_TRUE, extra->offset + offset, + size, data, 0, NULL, NULL)); + + GGML_UNUSED(buffer); +} + +static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + GGML_ASSERT(tensor->extra); + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context *backend_ctx = dev_ctx->backend_ctx; + + cl_context context = backend_ctx->context; + cl_command_queue queue = backend_ctx->queue; + + // Make sure all previously submitted commands in other devices are finished. + sync_with_other_backends(backend_ctx); + +#ifdef GGML_OPENCL_SOA_Q + // In end-to-end runs, get_tensor is usually used to get back the logits, + // where we can simply do clEnqueueReadBuffer since they are f32. + // However, in test-backend-ops, the GPU graph is copied to the CPU backend, + // which requires reading back quantized weight tensors. + // To properly support this, we need to restore block_q4_0 struct arrays + // from the flattened buffers. + if (tensor->type == GGML_TYPE_Q4_0) { + ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0_trans4_ns; + + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (use_adreno_kernels(backend_ctx, tensor)) { + ggml_cl_buffer buf_trans_q; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; // ne01 + cl_int K = tensor->ne[0]; // ne00 + + GGML_ASSERT(K % 32 == 0); + GGML_ASSERT(M % 4 == 0); + + size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } +#endif + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_Q4_1) { + ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (use_adreno_kernels(backend_ctx, tensor)) { + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_m; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + + GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0); + + size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + size_t size_m = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + GGML_ASSERT(size_d + size_q + size_m == ggml_nbytes(tensor) && "Incorrect tensor size"); + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_m.allocate(backend_ctx->context, size_m); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // transpose q, d, m back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + transpose_2d_as_16b(backend_ctx, extra->m, buf_trans_m.buffer, size_m, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_m.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } +#endif + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_Q5_0) { + ggml_tensor_extra_cl_q5_0 * extra = (ggml_tensor_extra_cl_q5_0 *)tensor->extra; + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + // TODO: use ggml_cl_buffer to manage this temporary buffer + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (use_adreno_kernels(backend_ctx, tensor)) { + ggml_cl_buffer buf_trans_qs; + ggml_cl_buffer buf_trans_qh; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_unpacked; - const int dk = q->ne[0]; - const int dv = v->ne[0]; + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; - const struct { int dk; int dv; } supported_dims[] = { - { 40, 40}, { 64, 64}, { 80, 80}, { 96, 96}, - {112, 112}, {128, 128}, {192, 128}, - {192, 192}, {256, 256}, - }; + GGML_ASSERT(K % 32 == 0); - bool dims_supported = false; - for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) { - if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) { - dims_supported = true; - break; - } - } - if (!dims_supported) { - return false; - } + size_t size_qs = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_qh = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(int32_t); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); - const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 && - v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; - const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 && - v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16; - const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 && - v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32; + buf_trans_qs.allocate(backend_ctx->context, size_qs); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); - return is_f32_f32 || is_f16_f16 || is_f32_f16; - } - default: - return false; - } -} + transpose_2d_as_16b(backend_ctx, extra->qs, buf_trans_qs.buffer, size_qs, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); -// Forward declaration - implementation appears later in the file. -static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; -static ggml_guid_t ggml_backend_opencl_guid() { - static ggml_guid guid = { 0xde, 0xe0, 0x70, 0xa2, 0x73, 0x4e, 0x4d, 0xbc, 0xb0, 0xc7, 0x4f, 0xd4, 0x6d, 0x4e, 0x90, 0xfe }; - return &guid; -} + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; -static ggml_backend_i ggml_backend_opencl_i = { - /* .get_name = */ ggml_backend_opencl_name, - /* .free = */ ggml_backend_opencl_free, - /* .set_tensor_async = */ NULL, /* ggml_backend_opencl_set_tensor_async */ - /* .get_tensor_async = */ NULL, /* ggml_backend_opencl_get_tensor_async */ - /* .cpy_tensor_async = */ NULL, /* ggml_backend_opencl_cpy_tensor_async */ - /* .synchronize = */ ggml_backend_opencl_synchronize, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_opencl_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .graph_optimize = */ NULL, -}; + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_qs.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_F0)); -ggml_backend_t ggml_backend_opencl_init(void) { - ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_opencl_reg(), 0); - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS - ggml_backend_t backend = new ggml_backend { - /* .guid = */ ggml_backend_opencl_guid(), - /* .iface = */ ggml_backend_opencl_i, - /* .device = */ dev, - /* .context = */ backend_ctx - }; + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); - return backend; -} + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); -bool ggml_backend_is_opencl(ggml_backend_t backend) { - return backend && backend->iface.get_name == ggml_backend_opencl_name; -} + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; -// -// buffer -// -struct ggml_backend_opencl_buffer_context { - // A buffer context can hold multiple cl_mem objects. This is for flattening - // quantized weights and should be used with GGML_OPENCL_SMALL_ALLOC where - // each tensor is allocated a separate buffer. When flattening is enabled - // with small allocation, each tensor is backed by two cl_mem objects (for - // quants and scales) packed into a backend_opencl_buffer. - ggml_backend_opencl_buffer_context(cl_mem buf) - : name("OpenCL") { - buffer.push_back(buf); + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; } + if (tensor->type == GGML_TYPE_Q5_1) { + ggml_tensor_extra_cl_q5_1 * extra = (ggml_tensor_extra_cl_q5_1 *)tensor->extra; - ~ggml_backend_opencl_buffer_context() { - for (cl_mem buf : buffer) { - CL_CHECK(clReleaseMemObject(buf)); - } - for (cl_mem im : img) { - CL_CHECK(clReleaseMemObject(im)); - } +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + // TODO: use ggml_cl_buffer to manage this temporary buffer + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); - // Delete all extras to trigger their destructors - for (ggml_tensor_extra_cl * e : temp_tensor_extras) { - delete e; - } - for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { - delete e; - } - for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0) { - delete e; - } - for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { - delete e; - } - for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4) { - delete e; - } - for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { - delete e; - } - for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0) { - delete e; + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; } - for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { - delete e; + + if (use_adreno_kernels(backend_ctx, tensor)) { + ggml_cl_buffer buf_trans_qs; + ggml_cl_buffer buf_trans_qh; + ggml_cl_buffer buf_trans_d; + ggml_cl_buffer buf_trans_m; + ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); + + size_t size_qs = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_qh = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(int32_t); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + size_t size_m = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + + buf_trans_qs.allocate(backend_ctx->context, size_qs); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_m.allocate(backend_ctx->context, size_m); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // Transpose back: from col-major to row-major + transpose_2d_as_16b(backend_ctx, extra->qs, buf_trans_qs.buffer, size_qs, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + transpose_2d_as_16b(backend_ctx, extra->m, buf_trans_m.buffer, size_m, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_qs.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_m.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; } + if (tensor->type == GGML_TYPE_MXFP4) { + ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; - ggml_tensor_extra_cl * ggml_opencl_alloc_temp_tensor_extra() { - ggml_tensor_extra_cl * extra; - if (temp_tensor_extras.empty()) { - extra = new ggml_tensor_extra_cl(); - } else { - extra = temp_tensor_extras.back(); - temp_tensor_extras.pop_back(); + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans4_ns; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; + size_t local_work_size[3] = {64, 2, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; } - temp_tensor_extras_in_use.push_back(extra); +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); - extra->reset(); - return extra; + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; } + if (tensor->type == GGML_TYPE_Q8_0) { + ggml_tensor_extra_cl_q8_0 * extra = (ggml_tensor_extra_cl_q8_0 *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (enable_adreno_trans_weight(backend_ctx, tensor)) { + cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0_trans; + + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); - ggml_tensor_extra_cl_q4_0 * ggml_opencl_alloc_temp_tensor_extra_q4_0() { - ggml_tensor_extra_cl_q4_0 * extra; - if (temp_tensor_extras_q4_0.empty()) { - extra = new ggml_tensor_extra_cl_q4_0(); - } else { - extra = temp_tensor_extras_q4_0.back(); - temp_tensor_extras_q4_0.pop_back(); - } + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); - temp_tensor_extras_q4_0_in_use.push_back(extra); + size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), 1, 1}; + size_t local_work_size[3] = {64, 1, 1}; - extra->reset(); - return extra; - } + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); - ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() { - ggml_tensor_extra_cl_mxfp4 * extra; - if (temp_tensor_extras_mxfp4.empty()) { - extra = new ggml_tensor_extra_cl_mxfp4(); - } else { - extra = temp_tensor_extras_mxfp4.back(); - temp_tensor_extras_mxfp4.pop_back(); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; } +#endif + cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); - temp_tensor_extras_mxfp4_in_use.push_back(extra); + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; - extra->reset(); - return extra; + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; } + if (tensor->type == GGML_TYPE_IQ4_NL) { + ggml_tensor_extra_cl_iq4_nl * extra = (ggml_tensor_extra_cl_iq4_nl *)tensor->extra; - ggml_tensor_extra_cl_q8_0 * ggml_opencl_alloc_temp_tensor_extra_q8_0() { - ggml_tensor_extra_cl_q8_0 * extra; - if (temp_tensor_extras_q8_0.empty()) { - extra = new ggml_tensor_extra_cl_q8_0(); - } else { - extra = temp_tensor_extras_q8_0.back(); - temp_tensor_extras_q8_0.pop_back(); - } + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); - temp_tensor_extras_q8_0_in_use.push_back(extra); +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; - extra->reset(); - return extra; - } + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + GGML_ASSERT(K % 32 == 0); - void reset() { - for (ggml_tensor_extra_cl * e : temp_tensor_extras_in_use) { - temp_tensor_extras.push_back(e); - } - temp_tensor_extras_in_use.clear(); + size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*(ggml_blck_size(tensor->type)/2); + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); - for (ggml_tensor_extra_cl_q4_0 * e : temp_tensor_extras_q4_0_in_use) { - temp_tensor_extras_q4_0.push_back(e); - } - temp_tensor_extras_q4_0_in_use.clear(); + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); - for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { - temp_tensor_extras_mxfp4.push_back(e); - } - temp_tensor_extras_mxfp4_in_use.clear(); + // transpose q, d back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); - for (ggml_tensor_extra_cl_q8_0 * e : temp_tensor_extras_q8_0_in_use) { - temp_tensor_extras_q8_0.push_back(e); - } - temp_tensor_extras_q8_0_in_use.clear(); - } + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; - // Pools for extras. Available extras are in `temp_tensor_extras`. Extras - // being used are in `temp_tensor_extras_in_use`. At the first run, new - // extras get created and put in `in_use`. When the buffer is reset via - // the `reset` callback, all extras in `in_use` get moved to available extras - // for reuse. - std::vector<ggml_tensor_extra_cl *> temp_tensor_extras; - std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use; - std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0; - std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use; - std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4; - std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use; - std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0; - std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0_in_use; + cl_kernel kernel = backend_ctx->kernel_restore_block_iq4_nl_noshuffle; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); - // The buffer_context is initially created by ggml_backend_buft_alloc_buffer - // before any tensor is initialized (at the beginning of alloc_tensor_range). - // Hence, there is alway a buffer object in this vector. When each tensor is - // being initialized, this original buffer object will be released if both - // flattening and small allocation are enabled, and additional buffer - // objects will be created in init_tensor to represent flattened quantized - // weights. - std::vector<cl_mem> buffer; - // These are image1d_buffer_t objects that wrap around the quants and scales. - // For Q4_0 quantization, there should be two of them - one for quants and - // one for scales. They should be populated only when flattening and small - // allocation are enabled. - std::vector<cl_mem> img; - std::string name; -}; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &n_blk)); -static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - delete ctx; -} + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; -static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer->buft->device); - return (void *) (uintptr_t) backend_ctx->alignment; -} + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } +#endif + cl_kernel kernel = backend_ctx->kernel_restore_block_iq4_nl; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); -static enum ggml_status ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { - ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_blk)); - ggml_cl2_init(buffer->buft->device); + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; - if (tensor->view_src != nullptr) { - GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_Q4_K) { + ggml_tensor_extra_cl_q4_K * extra = (ggml_tensor_extra_cl_q4_K *)tensor->extra; - ggml_tensor_extra_cl * view_extra = (ggml_tensor_extra_cl *) tensor->view_src->extra; - GGML_ASSERT(view_extra && "view_extra is nullptr?"); + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); - // Reuse extra of the parent tensor. The offset of this view tensor - // becomes `extra->offset + view_offs` and needs to be calculated when - // it is used. This changes is needed because of the change to - // ggml_alloc.c in https://github.com/ggerganov/llama.cpp/pull/7640. - // `buffer` passed in here will always be `tensor->buffer`. It is OK - // to allocate extras from the same buffer context for ordinary - // intermediate tensors. But for views into kv cache tensors, doing so - // would mess up the extras used by kv cache. - // Before #7640, `buffer` is for intermediate tensors, which is always - // different from that of kv cache tensors. - // - // NB: now extra->offset no longer accounts for view_offs. - // NB: this should not apply to weight tensors (for end-to-end runs, but - // may apply for test-backend-ops). - // FIXME: if any unexpected results are seen, double check the offset - - // there could be other places that need fix. - tensor->extra = view_extra; - } else { - { - size_t offset = (char *) tensor->data - (char *) ggml_backend_opencl_buffer_get_base(buffer); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; - ggml_tensor_extra_cl * extra = ctx->ggml_opencl_alloc_temp_tensor_extra(); - extra->offset = offset; - extra->data_device = ctx->buffer[0]; - extra->actual_size = ggml_nbytes(tensor); +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); - tensor->extra = extra; - } - } - return GGML_STATUS_SUCCESS; -} + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_k_trans4_ns; -// The optimized gemm and gemv kernels are used for large matrices without batch. -// tensor is the quantized weights matrix. -inline bool use_adreno_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - int64_t threshold_ne0 = 512; - int64_t threshold_ne1 = 512; - if (!backend_ctx->adreno_cl_compiler_version.newer_than_or_same(E031, 38, 11, 0) && - backend_ctx->adreno_cl_compiler_version.type != DX) { - threshold_ne0 = 128; - threshold_ne1 = 128; - } - return tensor->ne[0] >= threshold_ne0 && tensor->ne[1] >= threshold_ne1 && - tensor->ne[2] == 1 && tensor->ne[3] == 1; -} + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 256), static_cast<size_t>(ne02)}; + size_t local_work_size[] = {64, 1, 1}; -inline bool use_adreno_moe_kernels(const ggml_backend_opencl_context *backend_ctx, const ggml_tensor *tensor) { - GGML_UNUSED(backend_ctx); - int ne01 = tensor->ne[1]; - return ((strstr(tensor->name, "ffn") != NULL) || (strstr(tensor->name, "as") != NULL)) && (ne01 % 64 == 0); -} + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q4_K(tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_dm = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_trans_dm; + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_dm.allocate(backend_ctx->context, size_dm); + + // Transpose q, d, dm back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + transpose_2d_as_16b(backend_ctx, extra->dm, buf_trans_dm.buffer, size_dm, M, K/256); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_dm.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); -static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; - cl_context context = backend_ctx->context; - cl_command_queue queue = backend_ctx->queue; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS -#ifdef GGML_OPENCL_SOA_Q - // We separate the quantized bits and scale from block_q4_0 by using an - // additional kernel, where each thread handles a block. We first read the - // original weights into a temporary buffer, then create two separate - // buffers for quantized bits and scales, which are then populated by the - // conversion kernel. - if (tensor->type == GGML_TYPE_Q4_0) { - // Tensors should have been preallocated, therefore they should - // already have ggml_tensor_extra_cl as extra. - ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; - GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_F0)); - // Allocate the new extra and create aliases from the original. - ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_tensor_extra_cl_q4_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_0(); + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; - size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); - size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; - GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_Q5_K) { + ggml_tensor_extra_cl_q5_K * extra = (ggml_tensor_extra_cl_q5_K *)tensor->extra; cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); - CL_CHECK(clEnqueueWriteBuffer( - queue, data_device, CL_TRUE, 0, - ggml_nbytes(tensor), data, 0, NULL, NULL)); - // We consider the specified offset arg as always, although For weights - // the offset arg should be 0 (we do not assert this). - //GGML_ASSERT(offset == 0); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; - // We create subbuffers from the original tensor buffer for scales and - // quants - i.e., scales and quants are aliases into the buffer obejct - // that backs the original tensor. This is a cleaner way to adapt to the - // new memory management. - // In the old code, we allocate new buffers for scales and quants - // respectively, which could still be done but would result in double - // allocation; properly deallocating the preallocated buffer that backs - // the tensors is tricky and would leak the backend specific information - // into the general backend code. - // Does this create misaligned subbuffers (alignment is 1024) in certain - // cases ? - cl_buffer_region region; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_k_trans4_ns; - // The original tensor memory is divided into scales and quants, i.e., - // we first store scales, then quants. - // Create subbuffer for scales. - region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); - region.size = size_d; - extra->d = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); - auto previous_origin = region.origin; + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 256), static_cast<size_t>(ne02)}; + size_t local_work_size[] = {64, 1, 1}; - // Create subbuffer for quants. - region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); - region.size = size_q; - extra->q = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (use_adreno_kernels(backend_ctx, tensor)) { + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + size_t size_q = extra->size_q; + size_t size_qh = extra->size_qh; + size_t size_d = extra->size_d; + size_t size_dm = extra->size_dm; + + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_qh; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_trans_dm; + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_trans_dm.allocate(backend_ctx->context, size_dm); + + // Reverse transpose q, qh, d, dm + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_8b (backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/8); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); + transpose_2d_as_16b(backend_ctx, extra->dm, buf_trans_dm.buffer, size_dm, M, K/256); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_trans_dm.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); - //cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; - // The optimized kernels need weights in natural order, so unshuffle. - if (use_adreno_kernels(backend_ctx, tensor)) { - kernel = backend_ctx->kernel_convert_block_q4_0_noshuffle; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; } - #else - cl_kernel kernel = backend_ctx->kernel_convert_block_q4_0; - #endif // GGML_OPENCL_USE_ADRENO_KERNELS - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + cl_kernel kernel = backend_ctx->kernel_restore_block_q5_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_F0)); size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_Q6_K) { + ggml_tensor_extra_cl_q6_K * extra = (ggml_tensor_extra_cl_q6_K *)tensor->extra; - tensor->extra = extra; - - // transpose the weights and scales - #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - // Only do transpose for large, non batched matrix - // TODO: use preallocated images instead of sub-buffer then image - if (use_adreno_kernels(backend_ctx, tensor)) { - // <----------------------------------------------------------------------------------> // - // start transpose - // <----------------------------------------------------------------------------------> // - int M = tensor->ne[1]; // ne01 - int K = tensor->ne[0]; // ne00 - - //For matrix-vector multiplication kernel, we assume K is a multiple of 32 - GGML_ASSERT(K % 32 == 0); - //For transpose kernels, we assume K is a multiple of 4 (satisfied by prior assert), and M is a multiple of 4 - GGML_ASSERT(M % 4 == 0); +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, tensor)) { + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); - // transpose is out of place, so we need to allocate transposed buffers - // <----------------------------------------------------------------------------------> // - // use sub_buffer of max buffer size instead + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_k_trans4_ns; - size_t q_size_bytes = K * M / 8 * sizeof(float); - backend_ctx->prealloc_quant_trans.allocate(context, q_size_bytes); + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; - cl_buffer_region region; - region.origin = 0; - region.size = q_size_bytes; - cl_mem qT_d = clCreateSubBuffer( - backend_ctx->prealloc_quant_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); + int ne00 = tensor->ne[0]; + int ne01 = tensor->ne[1]; + int ne02 = tensor->ne[2]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_uchar), &mask_F0)); + + size_t global_work_size[] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 256), static_cast<size_t>(ne02)}; + size_t local_work_size[] = {64, 1, 1}; - bool K_tile_trans = true; - if ((K / 32) % 4 != 0){ - K_tile_trans =false; + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; } + if (use_adreno_kernels(backend_ctx, tensor) && !use_flat_gemv_for_large_m_q6_K(tensor)) { + static ggml_cl_buffer buf_trans_ql; + static ggml_cl_buffer buf_trans_qh; + static ggml_cl_buffer buf_trans_s; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; - size_t d_size_bytes = M * (K / 32) * 2; - backend_ctx->prealloc_scales_trans.allocate(context, d_size_bytes); - - region.origin = 0; - region.size = d_size_bytes; - cl_mem dT_d = clCreateSubBuffer( - backend_ctx->prealloc_scales_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &err); - CL_CHECK(err); + cl_int M = tensor->ne[1]; // ne01 + cl_int K = tensor->ne[0]; // ne00 - // <----------------------------------------------------------------------------------> // + GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0); + size_t size_ql = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + size_t size_qh = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/4; + size_t size_s = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/16; + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + GGML_ASSERT(size_ql + size_qh + size_s + size_d == ggml_nbytes(tensor) && "Incorrect tensor size"); - // create images from the buffers - // <----------------------------------------------------------------------------------> // - cl_mem q_d_image1D; - cl_mem d_d_image1D; - cl_mem qT_d_image1D; - cl_mem dT_d_image1D; + buf_trans_ql.allocate(backend_ctx->context, size_ql); + buf_trans_qh.allocate(backend_ctx->context, size_qh); + buf_trans_s.allocate(backend_ctx->context, size_s); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); - cl_image_format img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - cl_image_desc img_desc_1d; + // transpose ql, qh, s and d back + transpose_2d_as_16b(backend_ctx, extra->ql, buf_trans_ql.buffer, size_ql, M, K/4); + transpose_2d_as_8b(backend_ctx, extra->qh, buf_trans_qh.buffer, size_qh, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->s, buf_trans_s.buffer, size_s, M, K/16/2); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/256); - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = extra->q; - q_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); + // unpack + cl_uchar mask = 0xFF; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_ql.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_qh.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_s.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 4 / 4; - img_desc_1d.buffer = qT_d; - qT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - if (K_tile_trans) { - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - img_desc_1d.image_width = M * K / 32 / 4; - } else { - img_fmt_1d = { CL_R, CL_HALF_FLOAT }; - img_desc_1d.image_width = M * K / 32; + return; } - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.buffer = extra->d; - d_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); - CL_CHECK(err); +#endif // GGML_OPENCL_USE_ADRENO_KERNELS - img_fmt_1d = { CL_RGBA, CL_HALF_FLOAT }; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 32 / 4; - img_desc_1d.buffer = dT_d; - dT_d_image1D = clCreateImage(context, 0, &img_fmt_1d, &img_desc_1d, NULL, &err); + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); CL_CHECK(err); - // <----------------------------------------------------------------------------------> // - - // set up and call the transpose kernels - // <----------------------------------------------------------------------------------> // - // weights - int height_q = M / 4; - int width_q = K / 4 / 4; - kernel = backend_ctx->kernel_transpose_16; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_q)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_q)); - - size_t local_size_q[3] = {4, 16, 1}; - size_t global_size_q[3] = {static_cast<size_t>(width_q), static_cast<size_t>(height_q), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_q, local_size_q, 0, NULL, &evt)); + + cl_uchar mask = 0xFF; + cl_ulong n_blk = ggml_nelements(tensor)/ggml_blck_size(tensor->type); + cl_kernel kernel = backend_ctx->kernel_restore_block_q6_K; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &n_blk)); + + size_t global_work_size[] = {(size_t)n_blk, 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } +#endif // GGML_OPENCL_SOA_Q - // scales - int height_s = M / 4; - int width_s = K / 32 / 4; + if (tensor->type == GGML_TYPE_BF16) { + GGML_ASSERT(offset % sizeof(ggml_fp16_t) == 0 && size % sizeof(ggml_fp16_t) == 0 + && "Offset and size must be multiples of 2 for bf16 tensors"); - kernel = backend_ctx->kernel_transpose_16; - if (!K_tile_trans) { - kernel = backend_ctx->kernel_transpose_16_4x1; - width_s = K / 32; - } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &d_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &dT_d_image1D)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_s)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_s)); + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); - size_t local_size_s[3] = {4, 16, 1}; - size_t global_size_s[3] = {static_cast<size_t>(width_s), static_cast<size_t>(height_s), 1}; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_size_s, local_size_s, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - // <----------------------------------------------------------------------------------> // + cl_ulong n_elements = size / sizeof(ggml_fp16_t); + cl_ulong off_src = (extra->offset + tensor->view_offs + offset) / sizeof(ggml_fp16_t); - // copy transposed buffer contents to original buffers - // <----------------------------------------------------------------------------------> // - // weights - CL_CHECK(clEnqueueCopyBuffer(queue, qT_d, extra->q, 0, 0, q_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err); + CL_CHECK(err); - // scales - CL_CHECK(clEnqueueCopyBuffer(queue, dT_d, extra->d, 0, 0, d_size_bytes, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - // <----------------------------------------------------------------------------------> // + cl_kernel kernel = backend_ctx->kernel_convert_f16_to_bf16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &n_elements)); - // deallocate transpose buffers - // <----------------------------------------------------------------------------------> // - CL_CHECK(clReleaseMemObject(qT_d)); - CL_CHECK(clReleaseMemObject(dT_d)); + size_t global_work_size[] = { (size_t)CEIL_DIV(n_elements, 64)*64, 1, 1 }; + size_t local_work_size[] = { 64, 1, 1 }; - // deallocate temporary images - CL_CHECK(clReleaseMemObject(q_d_image1D)); - CL_CHECK(clReleaseMemObject(d_d_image1D)); - CL_CHECK(clReleaseMemObject(qT_d_image1D)); - CL_CHECK(clReleaseMemObject(dT_d_image1D)); - // <----------------------------------------------------------------------------------> // - // end transpose - // <----------------------------------------------------------------------------------> // - } - #endif // GGML_OPENCL_USE_ADRENO_KERNELS + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseEvent(evt)); - return; + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, 0, size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; } - if (tensor->type == GGML_TYPE_MXFP4) { - ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; - GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); - // Allocate the new extra and create aliases from the original. - ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_tensor_extra_cl_mxfp4 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_mxfp4(); + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; - size_t size_e = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(char); - size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; - GGML_ASSERT(size_e + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + CL_CHECK(clEnqueueReadBuffer( + queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset, + size, data, 0, NULL, NULL)); - cl_int err; - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); - CL_CHECK(clEnqueueWriteBuffer( - queue, data_device, CL_TRUE, 0, - ggml_nbytes(tensor), data, 0, NULL, NULL)); + GGML_UNUSED(buffer); +} - // The original tensor memory is divided into scales and quants, i.e., - // we first store scales, then quants. - cl_buffer_region region; +static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer->buft->device->context; + ggml_backend_opencl_context * backend_ctx = dev_ctx->backend_ctx; - // Create subbuffer for scales. - region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); - region.size = size_e; - extra->e = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); - auto previous_origin = region.origin; + cl_command_queue queue = backend_ctx->queue; - // Create subbuffer for quants. - region.origin = align_to(previous_origin + size_e, backend_ctx->alignment); - region.size = size_q; - extra->q = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + for (cl_mem buf : ctx->buffer) { + CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL)); + } + CL_CHECK(clFinish(queue)); +} -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_moe_kernels(backend_ctx, tensor)) { - cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4_trans; +static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) { + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ctx->reset(); +} - int ne00 = tensor->ne[0]; - int ne01 = tensor->ne[1]; - int ne02 = tensor->ne[2]; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); +static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = { + /* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer, + /* .get_base = */ ggml_backend_opencl_buffer_get_base, + /* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ NULL, + /* .clear = */ ggml_backend_opencl_buffer_clear, + /* .reset = */ ggml_backend_opencl_buffer_reset, +}; + +// +// buffer type +// + +static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type) { + return "OpenCL"; + + GGML_UNUSED(buffer_type); +} - size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; - size_t local_work_size[3] = {64, 2, 1}; +static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) { + ggml_backend_opencl_context *backend_ctx = ggml_cl_init(buffer_type->device); + load_cl_kernels(backend_ctx); - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clReleaseMemObject(data_device)); - tensor->extra = extra; + // clCreateBuffer returns -61 for size 0 + size = std::max(size, (size_t)1); - return; - } -#endif - cl_kernel kernel = backend_ctx->kernel_convert_block_mxfp4; + cl_int err; + cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err); + if (err != CL_SUCCESS && backend_ctx->adreno_use_large_buffer) { + cl_mem_properties props[] = { 0x41A6 /* CL_LARGE_BUFFER_QCOM */, 1, 0 }; + mem = clCreateBufferWithProperties(backend_ctx->context, props, CL_MEM_READ_WRITE, size, NULL, &err); + } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->e)); + if (err != CL_SUCCESS) { + GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0); + return nullptr; + } - size_t global_work_size[3] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[3] = {64, 1, 1}; + ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context(mem); - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clReleaseMemObject(data_device)); + return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size); +} - // Create image for Q - cl_image_format img_format_q = {CL_RG, CL_UNSIGNED_INT32}; - cl_image_desc img_desc_q = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast<size_t>(ggml_nelements(tensor)/32*2), - 0, 0, 0, 0, 0, 0, 0, - { extra->q } - }; - extra->q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_format_q, &img_desc_q, NULL, &err); - tensor->extra = extra; +static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer_type->device->context; + return dev_ctx->backend_ctx->alignment; +} - return; +static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { + static size_t max_size = -1; + if (max_size == (size_t)-1) { + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) buffer_type->device->context; + max_size = dev_ctx->backend_ctx->max_alloc_size; } - if (tensor->type == GGML_TYPE_Q8_0) { - ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; - GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + return max_size; +} - // Allocate the new extra and create aliases from the original. - ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ggml_tensor_extra_cl_q8_0 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q8_0(); +static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { + return ggml_backend_is_opencl(backend); - size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); - size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*(ggml_blck_size(tensor->type)*sizeof(char)); - GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + UNUSED(buft); +} - cl_int err; - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); - CL_CHECK(clEnqueueWriteBuffer( - queue, data_device, CL_TRUE, 0, - ggml_nbytes(tensor), data, 0, NULL, NULL)); +static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = { + /* .get_name = */ ggml_backend_opencl_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_opencl_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_opencl_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_opencl_buffer_type_get_max_size, + /* .get_alloc_size = */ NULL, + /* .is_host = */ NULL, +}; - // The original tensor memory is divided into scales and quants, i.e., - // we first store scales, then quants. - cl_buffer_region region; +// +// backend device +// - // Create subbuffer for scales. - region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); - region.size = size_d; - extra->d = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); - auto previous_origin = region.origin; +static const char * ggml_backend_opencl_device_get_name(ggml_backend_dev_t dev) { + return "GPUOpenCL"; - // Create subbuffer for quants. - region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); - region.size = size_q; - extra->q = clCreateSubBuffer( - extra_orig->data_device, CL_MEM_READ_WRITE, - CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); - CL_CHECK(err); + GGML_UNUSED(dev); +} - cl_kernel kernel = backend_ctx->kernel_convert_block_q8_0; +static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + return dev_ctx->device_name.c_str(); +} - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); +static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + static const size_t opencl_extra_margin = 1024ull*1024ull*1024ull; - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clReleaseMemObject(data_device)); + // OpenCL does not provide reliable currently-free device memory. + // Use total/global memory as a best-effort upper bound. + // Improved safety: Reduce by a 1GiB extra margin for common --fit + *total = dev_ctx->global_mem_size; + *free = *total > opencl_extra_margin ? *total - opencl_extra_margin : 0; +} - tensor->extra = extra; +static enum ggml_backend_dev_type ggml_backend_opencl_device_get_type(ggml_backend_dev_t dev) { + return GGML_BACKEND_DEVICE_TYPE_GPU; - return; - } -#endif // GGML_OPENCL_SOA_Q + GGML_UNUSED(dev); +} - ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; - GGML_ASSERT(extra); +static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_opencl_device_get_name(dev); + props->description = ggml_backend_opencl_device_get_description(dev); + props->type = ggml_backend_opencl_device_get_type(dev); + ggml_backend_opencl_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = ggml_backend_dev_caps { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} - CL_CHECK(clEnqueueWriteBuffer( - queue, extra->data_device, CL_TRUE, extra->offset + offset, - size, data, 0, NULL, NULL)); +static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) { + ggml_backend_opencl_context * backend_ctx = ggml_cl_init(dev); + // Getting a new reference to the backend, increase ref_count + backend_ctx->ref_count++; - GGML_UNUSED(buffer); + ggml_backend_t backend = new ggml_backend { + /* .guid = */ ggml_backend_opencl_guid(), + /* .interface = */ ggml_backend_opencl_i, + /* .device = */ dev, + /* .context = */ backend_ctx, + }; + + ggml_backend_opencl_device_context * dev_ctx = (ggml_backend_opencl_device_context *) dev->context; + ggml_opencl_print_backend_info(dev_ctx); + return backend; + + GGML_UNUSED(params); } -static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { - GGML_ASSERT(tensor->extra); +static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) { + auto * dev_ctx = static_cast<ggml_backend_opencl_device_context *>(dev->context); - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer->buft->device); + dev_ctx->buffer_type = ggml_backend_buffer_type{ + /* .iface = */ ggml_backend_opencl_buffer_type_interface, + /* .device = */ dev, + /* .context = */ nullptr, + }; - cl_context context = backend_ctx->context; - cl_command_queue queue = backend_ctx->queue; + return &dev_ctx->buffer_type; +} - // Make sure all previously submitted commands in other devices are finished. - sync_with_other_backends(backend_ctx); +static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + GGML_UNUSED(dev); + GGML_UNUSED(ptr); + GGML_UNUSED(size); + GGML_UNUSED(max_tensor_size); + return nullptr; +} -#ifdef GGML_OPENCL_SOA_Q - // In end-to-end runs, get_tensor is usually used to get back the logits, - // where we can simply do clEnqueueReadBuffer since they are f32. - // However, in test-backend-ops, the GPU graph is copied to the CPU backend, - // which requires reading back quantized weight tensors. - // To properly support this, we need to restore block_q4_0 struct arrays - // from the flattened buffers. - if (tensor->type == GGML_TYPE_Q4_0) { - ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *)tensor->extra; +static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { + ggml_cl_init(dev); + return ggml_opencl_supports_op(dev, op); +} -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_kernels(backend_ctx, tensor)) { - cl_int err; - cl_kernel kernel; +static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + // Check 'dev' and 'buffer_type' are not objects belonging to this backend. + if (dev->iface.get_name != ggml_backend_opencl_device_get_name || + buft->iface.get_name != ggml_backend_opencl_buffer_type_get_name) { + return false; + } - cl_int M = tensor->ne[1]; // ne01 - cl_int K = tensor->ne[0]; // ne00 + // Check cl_context is the same. clEnqueue* commands may not use + // buffers from another cl_context. + ggml_backend_opencl_context * backend_ctx0 = ggml_cl_init(dev); + ggml_backend_opencl_context * backend_ctx1 = ggml_cl_init(buft->device); + return backend_ctx0->context == backend_ctx1->context; +} - GGML_ASSERT(K % 32 == 0); - GGML_ASSERT(M % 4 == 0); +namespace /* anonymous */ { +struct ggml_backend_device_i ggml_backend_opencl_device_i = { + /* .get_name = */ ggml_backend_opencl_device_get_name, + /* .get_description = */ ggml_backend_opencl_device_get_description, + /* .get_memory = */ ggml_backend_opencl_device_get_memory, + /* .get_type = */ ggml_backend_opencl_device_get_type, + /* .get_props = */ ggml_backend_opencl_device_get_props, + /* .init_backend = */ ggml_backend_opencl_device_init, + /* .get_buffer_type = */ ggml_backend_opencl_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_opencl_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_opencl_device_supports_op, + /* .supports_buft = */ ggml_backend_opencl_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; +} - size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; - size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); - GGML_ASSERT(size_d + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); +// Backend registry + +static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) { + return "OpenCL"; + + GGML_UNUSED(reg); +} - cl_mem buf_trans_q; - cl_mem buf_trans_d; +static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) { + return g_ggml_backend_opencl_devices.size(); - CL_CHECK((buf_trans_q = clCreateBuffer(context, CL_MEM_READ_WRITE, - size_q, NULL, &err), err)); - CL_CHECK((buf_trans_d = clCreateBuffer(context, CL_MEM_READ_WRITE, - size_d, NULL, &err), err)); + GGML_UNUSED(reg); +} - kernel = backend_ctx->kernel_transpose_16_buf; +static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) { + GGML_ASSERT(index < ggml_backend_opencl_reg_device_count(reg)); - // transpose q back - cl_int stride_k_q = K/4; - size_t local_size_q[3] = {64, 1, 1}; - size_t global_size_q[3] = {(size_t)M, (size_t)stride_k_q, 1}; + return &g_ggml_backend_opencl_devices[index]; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_q)); + GGML_UNUSED(reg); + GGML_UNUSED(index); +} - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_size_q, local_size_q, 0, NULL, NULL)); +static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = { + /* .get_name = */ ggml_backend_opencl_reg_get_name, + /* .device_count = */ ggml_backend_opencl_reg_device_count, + /* .device_get = */ ggml_backend_opencl_reg_device_get, + /* .get_proc_address = */ NULL, +}; - // transpose scales back - cl_int stride_k_d = K/32; - size_t local_size_d[3] = {64, 1, 1}; - size_t global_size_d[3] = {(size_t)M, (size_t)stride_k_d, 1}; +ggml_backend_reg_t ggml_backend_opencl_reg(void) { + static std::mutex mutex; + static ggml_backend_reg reg; + static bool initialized = false; + std::lock_guard<std::mutex> lock(mutex); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &M)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &stride_k_d)); + if (initialized) { + return ® + } + initialized = true; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_size_d, local_size_d, 0, NULL, NULL)); + g_ggml_backend_opencl_devices = ggml_opencl_probe_devices(®); - // unpack - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); + reg = ggml_backend_reg{ + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_opencl_reg_i, + /* .context = */ NULL, + }; - cl_uchar mask_0F = 0x0F; - cl_uchar mask_F0 = 0xF0; + return ® +} - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {1, 1, 1}; +GGML_BACKEND_DL_IMPL(ggml_backend_opencl_reg) - kernel = backend_ctx->kernel_restore_block_q4_0_noshuffle; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_uchar), &mask_0F)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_F0)); +//------------------------------------------------------------------------------ +// Debugging utils +//------------------------------------------------------------------------------ +#if 0 +#define QK4_0 32 +typedef struct { + ggml_fp16_t d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, + "wrong q4_0 block size/padding"); - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_work_size, local_work_size, 0, NULL, NULL)); +#define QK_MXFP4 32 - // read back to host - CL_CHECK(clEnqueueReadBuffer( - queue, data_device, CL_TRUE, offset, - size, data, 0, NULL, NULL)); +#include <math.h> +#ifdef __cplusplus +#include "half.hpp" +#endif - CL_CHECK(clReleaseMemObject(data_device)); - CL_CHECK(clReleaseMemObject(buf_trans_q)); - CL_CHECK(clReleaseMemObject(buf_trans_d)); +static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tensor) { + void * buf = malloc(ggml_nbytes(tensor)); - return; - } + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + cl_command_queue queue = backend_ctx->queue; +#ifdef GGML_OPENCL_SOA_Q + void * buf_q; + void * buf_d; #endif - cl_int err; - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); + // Make sure everything is done. + CL_CHECK(clFinish(queue)); - cl_kernel kernel = backend_ctx->kernel_restore_block_q4_0; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); +#ifdef GGML_OPENCL_SOA_Q + if (tensor->type == GGML_TYPE_Q4_0) { + ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *) tensor->extra; + GGML_ASSERT(extra); - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {1, 1, 1}; + size_t size_q = ggml_nelements(tensor)/QK4_0 * QK4_0/2; + size_t size_d = ggml_nelements(tensor)/QK4_0 * sizeof(ggml_fp16_t); + GGML_ASSERT(size_q + size_d == ggml_nbytes(tensor)); + buf_q = malloc(size_q); + buf_d = malloc(size_d); - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clEnqueueReadBuffer( - queue, data_device, CL_TRUE, offset, - size, data, 0, NULL, NULL)); - CL_CHECK(clReleaseMemObject(data_device)); - return; + CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); } else if (tensor->type == GGML_TYPE_MXFP4) { - ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; + ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra; + GGML_ASSERT(extra); - cl_int err; - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); + size_t size_q = ggml_nelements(tensor)/QK_MXFP4 * QK_MXFP4/2; + size_t size_e = ggml_nelements(tensor)/QK_MXFP4 * sizeof(char); + GGML_ASSERT(size_q + size_e == ggml_nbytes(tensor)); + buf_q = malloc(size_q); + buf_d = malloc(size_e); -#ifdef GGML_OPENCL_USE_ADRENO_KERNELS - if (use_adreno_moe_kernels(backend_ctx, tensor)) { - cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4_trans; + CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, extra->e, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); + } else { + // Read out the tensor from GPU memory. + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); - int ne00 = tensor->ne[0]; - int ne01 = tensor->ne[1]; - int ne02 = tensor->ne[2]; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE, + extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); + } +#else + // Read out the tensor from GPU memory. + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + GGML_ASSERT(extra); - size_t global_work_size[3] = {static_cast<size_t>(((ne01 + 63) / 64) * 64), static_cast<size_t>(ne00 / 32), static_cast<size_t>(ne02)}; - size_t local_work_size[3] = {64, 2, 1}; + CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE, + extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL)); + CL_CHECK(clFinish(queue)); +#endif // GGML_OPENCL_SOA_Q - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clEnqueueReadBuffer( - queue, data_device, CL_TRUE, offset, - size, data, 0, NULL, NULL)); - CL_CHECK(clReleaseMemObject(data_device)); - return; + // Open file and dump. + char fname[512]; + snprintf(fname, sizeof(fname), "./tensor-dumps/%s.txt", tensor->name); + FILE * f = fopen(fname, "w"); + if (!f) { + printf("Failed to open %s\n", fname); + return; + } + + if (tensor->type == GGML_TYPE_F32) { + float * data = (float *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%f\n", data[i]); + } + } else if (tensor->type == GGML_TYPE_I32) { + int * data = (int *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%d\n", data[i]); + } + } else if (tensor->type == GGML_TYPE_F16) { +#ifdef __cplusplus + half_float::half * data = (half_float::half *) buf; + for (int i = 0; i < ggml_nelements(tensor); ++i) { + if (std::isnan(data[i])) { + printf("NaN found: %s\n", tensor->name); + break; + } + fprintf(f, "%f\n", float(data[i])); } #endif - cl_kernel kernel = backend_ctx->kernel_restore_block_mxfp4; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->e)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); - - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {1, 1, 1}; + } else if (tensor->type == GGML_TYPE_Q4_0) { +#ifdef GGML_OPENCL_SOA_Q + ggml_fp16_t * data_d = (ggml_fp16_t *)buf_d; + unsigned char * data_q = (unsigned char *)buf_q; - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clEnqueueReadBuffer( - queue, data_device, CL_TRUE, offset, - size, data, 0, NULL, NULL)); - CL_CHECK(clReleaseMemObject(data_device)); - return; + for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) { + fprintf(f, "%04x, ", data_d[i]); + for (int k = 0; k < QK4_0/2; ++k) { + fprintf(f, "%02x, ", data_q[k]); + } + fprintf(f, "\n"); + data_q += QK4_0/2; + } + free(buf_d); + free(buf_q); +#else + block_q4_0 * data = (block_q4_0 *) buf; + for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) { + fprintf(f, "%04x, ", data[i].d); + for (int k = 0; k < QK4_0/2; ++k) { + fprintf(f, "%02x, ", data[i].qs[k]); + } + fprintf(f, "\n"); + } +#endif // GGML_OPENCL_SOA_Q } - if (tensor->type == GGML_TYPE_Q8_0) { - ggml_tensor_extra_cl_q8_0 * extra = (ggml_tensor_extra_cl_q8_0 *)tensor->extra; + free(buf); + fflush(f); + fclose(f); +} +#else +#define dump_tensor(tensor) +#endif - cl_int err; - cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, - ggml_nbytes(tensor), NULL, &err); - CL_CHECK(err); +//------------------------------------------------------------------------------ +// Ops +//------------------------------------------------------------------------------ - cl_kernel kernel = backend_ctx->kernel_restore_block_q8_0; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &data_device)); +static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && + src1->type == GGML_TYPE_F32 && + dst->type == GGML_TYPE_F32 && + (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); +} + +// Copy a noncontiguous tensor to contiguous tensor. ne[] remains the same but +// nb[] is recalculated such that tensor is contiguous. +static void ggml_cl_copy_to_contiguous(ggml_backend_t backend, const ggml_tensor * src, cl_mem dst, + cl_ulong &nb0, cl_ulong &nb1, cl_ulong &nb2, cl_ulong &nb3) { + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; - size_t local_work_size[] = {1, 1, 1}; + const int tensor_type_size = ggml_type_size(src->type); - cl_event evt; - CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, - global_work_size, local_work_size, 0, NULL, &evt)); - CL_CHECK(clWaitForEvents(1, &evt)); - CL_CHECK(clEnqueueReadBuffer( - queue, data_device, CL_TRUE, offset, - size, data, 0, NULL, NULL)); - CL_CHECK(clReleaseMemObject(data_device)); - return; - } -#endif // GGML_OPENCL_SOA_Q + const int ne00 = src->ne[0]; + const int ne01 = src->ne[1]; + const int ne02 = src->ne[2]; + const int ne03 = src->ne[3]; - ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; + const cl_ulong nb00 = src->nb[0]; + const cl_ulong nb01 = src->nb[1]; + const cl_ulong nb02 = src->nb[2]; + const cl_ulong nb03 = src->nb[3]; - CL_CHECK(clEnqueueReadBuffer( - queue, extra->data_device, CL_TRUE, extra->offset + tensor->view_offs + offset, - size, data, 0, NULL, NULL)); + const int ne0 = src->ne[0]; + const int ne1 = src->ne[1]; + const int ne2 = src->ne[2]; + const int ne3 = src->ne[3]; - GGML_UNUSED(buffer); -} + nb0 = tensor_type_size; + nb1 = tensor_type_size*ne00; + nb2 = tensor_type_size*ne00*ne01; + nb3 = tensor_type_size*ne00*ne01*ne02; -static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { - ggml_backend_dev_t dev = buffer->buft->device; - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(dev); - cl_command_queue queue = backend_ctx->queue; + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra; - ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - for (cl_mem buf : ctx->buffer) { - CL_CHECK(clEnqueueFillBuffer(queue, buf, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL)); + cl_ulong offset0 = extra->offset + src->view_offs; + cl_ulong offsetd = 0; + + cl_kernel kernel; + + switch (src->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_cpy_f32_f32; + break; + case GGML_TYPE_F16: + case GGML_TYPE_BF16: // stored as f16 on device + kernel = backend_ctx->kernel_cpy_f16_f16; + break; + default: + GGML_ASSERT(false && "not implemented"); } - CL_CHECK(clFinish(queue)); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &dst)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb3)); + + const int nth = MIN(64, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src); } -static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) { - ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; - ctx->reset(); +static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + UNUSED(backend); + UNUSED(src0); + UNUSED(src1); + UNUSED(dst); } -static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = { - /* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer, - /* .get_base = */ ggml_backend_opencl_buffer_get_base, - /* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor, - /* .memset_tensor = */ NULL, - /* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor, - /* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor, - /* .cpy_tensor = */ NULL, - /* .clear = */ ggml_backend_opencl_buffer_clear, - /* .reset = */ ggml_backend_opencl_buffer_reset, -}; +static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); -// -// buffer type -// + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); -static const char * ggml_backend_opencl_buffer_type_get_name(ggml_backend_buffer_type_t buffer_type) { - return "OpenCL"; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - GGML_UNUSED(buffer_type); -} + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; -static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) { - ggml_backend_opencl_context *backend_ctx = ggml_cl2_init(buffer_type->device); + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - // clCreateBuffer returns -61 for size 0 - size = std::max(size, (size_t)1); + cl_kernel kernel; - cl_int err; - cl_mem mem = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, size, NULL, &err); - if (err != CL_SUCCESS) { - GGML_LOG_INFO("%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0); - return nullptr; + switch (src0->type) { + case GGML_TYPE_F32: + kernel = backend_ctx->kernel_get_rows_f32; + break; + case GGML_TYPE_F16: + kernel = backend_ctx->kernel_get_rows_f16; + break; + case GGML_TYPE_Q4_0: + kernel = backend_ctx->kernel_get_rows_q4_0; + break; + default: + GGML_ASSERT(false && "not implemented"); } - ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context(mem); - - return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size); -} + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3)); -static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - return backend_ctx->alignment; -} + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + int nth = 1; + while (nth < ne00 && 2*nth <= max_workgroup_size) { + nth *= 2; + } -static size_t ggml_backend_opencl_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - static size_t max_size = -1; - if (max_size == (size_t)-1) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(buffer_type->device); - max_size = backend_ctx->max_alloc_size; + int nchunks = 1; + if (src0->type == GGML_TYPE_F32) { + const int chunk_target = nth * 4; + nchunks = (ne00 + chunk_target - 1) / chunk_target; + nchunks = MAX(1, MIN(nchunks, 64)); } - return max_size; -} -static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend) { - return ggml_backend_is_opencl(backend); + size_t global_work_size[] = {(size_t)ne10*nth*nchunks, (size_t)ne11, (size_t)ne12}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - UNUSED(buft); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } -static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = { - /* .get_name = */ ggml_backend_opencl_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_opencl_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_opencl_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_opencl_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, - /* .is_host = */ NULL, -}; +static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32); -// -// backend device -// + // ne0 = ne00 + // ne2 = ne02 + // ne3 = ne03 -static const char * ggml_backend_opencl_device_get_name(ggml_backend_dev_t dev) { - return "GPUOpenCL"; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; - GGML_UNUSED(dev); -} + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; -static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_t dev) { - ggml_backend_opencl_device_context *dev_ctx = (ggml_backend_opencl_device_context *) dev->context; - return dev_ctx->device_name.c_str(); -} + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; -static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - *free = 0; - *total = 0; + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; - GGML_UNUSED(dev); -} + const int ne0 = dst->ne[0]; -static enum ggml_backend_dev_type ggml_backend_opencl_device_get_type(ggml_backend_dev_t dev) { - return GGML_BACKEND_DEVICE_TYPE_GPU; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - GGML_UNUSED(dev); -} + const int nblk0 = ne0/ggml_blck_size(dst->type); -static void ggml_backend_opencl_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { - props->name = ggml_backend_opencl_device_get_name(dev); - props->description = ggml_backend_opencl_device_get_description(dev); - props->type = ggml_backend_opencl_device_get_type(dev); - ggml_backend_opencl_device_get_memory(dev, &props->memory_free, &props->memory_total); - props->caps = ggml_backend_dev_caps { - /* .async = */ false, - /* .host_buffer = */ false, - /* .buffer_from_host_ptr = */ false, - /* .events = */ false, - }; -} + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; -static ggml_backend_t ggml_backend_opencl_device_init(ggml_backend_dev_t dev, const char * params) { - ggml_backend_opencl_context * backend_ctx = ggml_cl2_init(dev); - // Getting a new reference to the backend, increase ref_count - backend_ctx->ref_count++; + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - ggml_backend_t backend = new ggml_backend { - /* .guid = */ ggml_backend_opencl_guid(), - /* .interface = */ ggml_backend_opencl_i, - /* .device = */ dev, - /* .context = */ backend_ctx, - }; + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - return backend; + cl_kernel kernel; + + switch (dst->type) { + case GGML_TYPE_F32: + if (src1->type == GGML_TYPE_I64) { + kernel = backend_ctx->kernel_set_rows_f32_i64; + } else { + kernel = backend_ctx->kernel_set_rows_f32_i32; + } + break; + case GGML_TYPE_F16: + if (src1->type == GGML_TYPE_I64) { + kernel = backend_ctx->kernel_set_rows_f16_i64; + } else { + kernel = backend_ctx->kernel_set_rows_f16_i32; + } + break; + default: + GGML_ABORT("not implemented"); + } + + fastdiv_vals ne11_ = init_fastdiv_values(ne11); + fastdiv_vals ne12_ = init_fastdiv_values(ne12); - GGML_UNUSED(params); -} + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(fastdiv_vals), &ne11_)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(fastdiv_vals), &ne12_)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &nblk0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb3)); -static ggml_backend_buffer_type_t ggml_backend_opencl_device_get_buffer_type(ggml_backend_dev_t dev) { - auto * dev_ctx = static_cast<ggml_backend_opencl_device_context *>(dev->context); + int nth0 = 64; + if (backend_ctx->gpu_family == INTEL) { + nth0 = 32; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + } - dev_ctx->buffer_type = ggml_backend_buffer_type{ - /* .iface = */ ggml_backend_opencl_buffer_type_interface, - /* .device = */ dev, - /* .context = */ nullptr, - }; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + while (nth0 < nblk0 && nth0 < max_workgroup_size) { + nth0 *= 2; + } - return &dev_ctx->buffer_type; -} + int rows_per_workgroup = 1; + if (nth0 > nblk0) { + rows_per_workgroup = nth0 / nblk0; + nth0 = nblk0; + } -static ggml_backend_buffer_t ggml_backend_opencl_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { - GGML_UNUSED(dev); - GGML_UNUSED(ptr); - GGML_UNUSED(size); - GGML_UNUSED(max_tensor_size); - return nullptr; -} + size_t global_work_size[] = { + (size_t)(ne01 + rows_per_workgroup - 1)/rows_per_workgroup*nth0, + (size_t)ne02*rows_per_workgroup, + (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth0, (size_t)rows_per_workgroup, 1}; -static bool ggml_backend_opencl_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { - return ggml_opencl_supports_op(dev, op); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } -static bool ggml_backend_opencl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { - // Check 'dev' and 'buffer_type' are not objects belonging to this backend. - if (dev->iface.get_name != ggml_backend_opencl_device_get_name || - buft->iface.get_name != ggml_backend_opencl_buffer_type_get_name) { - return false; - } +static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); - // Check cl_context is the same. clEnqueue* commands may not use - // buffers from another cl_context. - ggml_backend_opencl_context * backend_ctx0 = ggml_cl2_init(dev); - ggml_backend_opencl_context * backend_ctx1 = ggml_cl2_init(buft->device); - return backend_ctx0->context == backend_ctx1->context; -} + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; -namespace /* anonymous */ { -struct ggml_backend_device_i ggml_backend_opencl_device_i = { - /* .get_name = */ ggml_backend_opencl_device_get_name, - /* .get_description = */ ggml_backend_opencl_device_get_description, - /* .get_memory = */ ggml_backend_opencl_device_get_memory, - /* .get_type = */ ggml_backend_opencl_device_get_type, - /* .get_props = */ ggml_backend_opencl_device_get_props, - /* .init_backend = */ ggml_backend_opencl_device_init, - /* .get_buffer_type = */ ggml_backend_opencl_device_get_buffer_type, - /* .get_host_buffer_type = */ NULL, - /* .buffer_from_host_ptr = */ ggml_backend_opencl_device_buffer_from_ptr, - /* .supports_op = */ ggml_backend_opencl_device_supports_op, - /* .supports_buft = */ ggml_backend_opencl_device_supports_buft, - /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, -}; -} + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; -// Backend registry + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; -static const char * ggml_backend_opencl_reg_get_name(ggml_backend_reg_t reg) { - return "OpenCL"; + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; - GGML_UNUSED(reg); -} + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; -static size_t ggml_backend_opencl_reg_device_count(ggml_backend_reg_t reg) { - return g_ggml_backend_opencl_devices.size(); + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - GGML_UNUSED(reg); -} + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; -static ggml_backend_dev_t ggml_backend_opencl_reg_device_get(ggml_backend_reg_t reg, size_t index) { - GGML_ASSERT(index < ggml_backend_opencl_reg_device_count(reg)); + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - return &g_ggml_backend_opencl_devices[index]; + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - GGML_UNUSED(reg); - GGML_UNUSED(index); -} + cl_kernel kernel; -static struct ggml_backend_reg_i ggml_backend_opencl_reg_i = { - /* .get_name = */ ggml_backend_opencl_reg_get_name, - /* .device_count = */ ggml_backend_opencl_reg_device_count, - /* .device_get = */ ggml_backend_opencl_reg_device_get, - /* .get_proc_address = */ NULL, -}; + const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0; -ggml_backend_reg_t ggml_backend_opencl_reg(void) { - static std::mutex mutex; - static ggml_backend_reg reg; - static bool initialized = false; - std::lock_guard<std::mutex> lock(mutex); + if (bcast_row) { + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ne11 == 1); + } - if (initialized) { - return ® + if (dst->type == GGML_TYPE_F32) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32); + if (bcast_row) { + kernel = backend_ctx->kernel_add_row; + const int ne = ne00 / 4; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + kernel = backend_ctx->kernel_add; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); + } + } else if (dst->type == GGML_TYPE_F16) { + GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); + const int type_src0 = (src0->type == GGML_TYPE_F32); + const int type_src1 = (src1->type == GGML_TYPE_F32); + if (bcast_row) { + kernel = backend_ctx->kernel_add_row_f16; + const int ne = ne00 / 4; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &type_src0)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &type_src1)); + } else { + kernel = backend_ctx->kernel_add_f16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &type_src0)); + CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int), &type_src1)); + } + } else { + GGML_ASSERT(false && "unsupported data types for add"); } - initialized = true; - - g_ggml_backend_opencl_devices = ggml_opencl_probe_devices(®); - - reg = ggml_backend_reg{ - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_opencl_reg_i, - /* .context = */ NULL, - }; - - return ® -} - -GGML_BACKEND_DL_IMPL(ggml_backend_opencl_reg) -//------------------------------------------------------------------------------ -// Debugging utils -//------------------------------------------------------------------------------ -#if 0 -#define QK4_0 32 -typedef struct { - ggml_fp16_t d; // delta - uint8_t qs[QK4_0 / 2]; // nibbles / quants -} block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, - "wrong q4_0 block size/padding"); + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; -#include <math.h> -#ifdef __cplusplus -#include "half.hpp" -#endif + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } -static void dump_tensor(ggml_backend_t backend, const struct ggml_tensor * tensor) { - void * buf = malloc(ggml_nbytes(tensor)); + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst); + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - cl_command_queue queue = backend_ctx->queue; -#ifdef GGML_OPENCL_SOA_Q - void * buf_q; - void * buf_d; -#endif + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } +} - // Make sure everything is done. - CL_CHECK(clFinish(queue)); +static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); -#ifdef GGML_OPENCL_SOA_Q - if (tensor->type == GGML_TYPE_Q4_0) { - ggml_tensor_extra_cl_q4_0 * extra = (ggml_tensor_extra_cl_q4_0 *) tensor->extra; - GGML_ASSERT(extra); + const ggml_tensor * src2 = dst->src[2]; + GGML_ASSERT(src2); + GGML_ASSERT(src2->extra); - size_t size_q = ggml_nelements(tensor)/QK4_0 * QK4_0/2; - size_t size_d = ggml_nelements(tensor)/QK4_0 * sizeof(ggml_fp16_t); - GGML_ASSERT(size_q + size_d == ggml_nbytes(tensor)); - buf_q = malloc(size_q); - buf_d = malloc(size_d); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(src2->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); - CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); - CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_d, buf_d, 0, NULL, NULL)); - CL_CHECK(clFinish(queue)); - } else if (tensor->type == GGML_TYPE_MXFP4) { - ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *) tensor->extra; - GGML_ASSERT(extra); + GGML_ASSERT(ggml_is_contiguous_rows(src0)); - size_t size_q = ggml_nelements(tensor)/QK_MXFP4 * QK_MXFP4/2; - size_t size_e = ggml_nelements(tensor)/QK_MXFP4 * sizeof(char); - GGML_ASSERT(size_q + size_e == ggml_nbytes(tensor)); - buf_q = malloc(size_q); - buf_d = malloc(size_e); + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; - CL_CHECK(clEnqueueReadBuffer(queue, extra->q, CL_TRUE, 0, size_q, buf_q, 0, NULL, NULL)); - CL_CHECK(clEnqueueReadBuffer(queue, extra->d, CL_TRUE, 0, size_e, buf_d, 0, NULL, NULL)); - CL_CHECK(clFinish(queue)); - } else { - // Read out the tensor from GPU memory. - ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; - GGML_ASSERT(extra); + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; - CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE, - extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL)); - CL_CHECK(clFinish(queue)); - } -#else - // Read out the tensor from GPU memory. - ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *) tensor->extra; - GGML_ASSERT(extra); + const cl_ulong nb11 = src1->nb[1]; - CL_CHECK(clEnqueueReadBuffer(queue, extra->data_device, CL_TRUE, - extra->offset, ggml_nbytes(tensor), buf, 0, NULL, NULL)); - CL_CHECK(clFinish(queue)); -#endif // GGML_OPENCL_SOA_Q + const cl_ulong nb21 = src2->nb[1]; - // Open file and dump. - char fname[512]; - snprintf(fname, sizeof(fname), "./tensor-dumps/%s.txt", tensor->name); - FILE * f = fopen(fname, "w"); - if (!f) { - printf("Failed to open %s\n", fname); - return; - } + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; - if (tensor->type == GGML_TYPE_F32) { - float * data = (float *) buf; - for (int i = 0; i < ggml_nelements(tensor); ++i) { - if (isnan(data[i])) { - printf("NaN found: %s\n", tensor->name); - break; - } - fprintf(f, "%f\n", data[i]); - } - } else if (tensor->type == GGML_TYPE_I32) { - int * data = (int *) buf; - for (int i = 0; i < ggml_nelements(tensor); ++i) { - if (isnan(data[i])) { - printf("NaN found: %s\n", tensor->name); - break; - } - fprintf(f, "%d\n", data[i]); - } - } else if (tensor->type == GGML_TYPE_F16) { -#ifdef __cplusplus - half_float::half * data = (half_float::half *) buf; - for (int i = 0; i < ggml_nelements(tensor); ++i) { - if (std::isnan(data[i])) { - printf("NaN found: %s\n", tensor->name); - break; - } - fprintf(f, "%f\n", float(data[i])); - } -#endif - } else if (tensor->type == GGML_TYPE_Q4_0) { -#ifdef GGML_OPENCL_SOA_Q - ggml_fp16_t * data_d = (ggml_fp16_t *)buf_d; - unsigned char * data_q = (unsigned char *)buf_q; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) { - fprintf(f, "%04x, ", data_d[i]); - for (int k = 0; k < QK4_0/2; ++k) { - fprintf(f, "%02x, ", data_q[k]); - } - fprintf(f, "\n"); - data_q += QK4_0/2; - } - free(buf_d); - free(buf_q); -#else - block_q4_0 * data = (block_q4_0 *) buf; - for (int i = 0; i < ggml_nelements(tensor)/QK4_0; ++i) { - fprintf(f, "%04x, ", data[i].d); - for (int k = 0; k < QK4_0/2; ++k) { - fprintf(f, "%02x, ", data[i].qs[k]); - } - fprintf(f, "\n"); - } -#endif // GGML_OPENCL_SOA_Q - } - free(buf); - fflush(f); - fclose(f); -} -#else -#define dump_tensor(tensor) -#endif + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; -//------------------------------------------------------------------------------ -// Ops -//------------------------------------------------------------------------------ + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; -static bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - const int64_t ne10 = src1->ne[0]; + cl_kernel kernel = backend_ctx->kernel_add_id; - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); - // TODO: find the optimal values for these - return (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && - src1->type == GGML_TYPE_F32 && - dst->type == GGML_TYPE_F32 && - (ne0 >= 32 && ne1 >= 32 && ne10 >= 32); -} + int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel)); + size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 }; + size_t local_work_size[] = { (size_t)nth, 1, 1 }; -static void ggml_cl_nop(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - UNUSED(backend); - UNUSED(src0); - UNUSED(src1); - UNUSED(dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } -static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(src1); @@ -4681,16 +9184,36 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const int ne00 = src0->ne[0]; + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; - const int ne10 = src1->ne[0]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; UNUSED(ne13); + const cl_ulong nb10 = src1->nb[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; const cl_ulong nb11 = src1->nb[1]; const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13); + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + + const cl_ulong nb0 = dst->nb[0]; const cl_ulong nb1 = dst->nb[1]; const cl_ulong nb2 = dst->nb[2]; const cl_ulong nb3 = dst->nb[3]; @@ -4705,82 +9228,129 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; + bool bcast_row = false; cl_kernel kernel; - switch (src0->type) { - case GGML_TYPE_F32: - kernel = backend_ctx->kernel_get_rows_f32; - break; - case GGML_TYPE_F16: - kernel = backend_ctx->kernel_get_rows_f16; - break; - case GGML_TYPE_Q4_0: - kernel = backend_ctx->kernel_get_rows_q4_0; - break; - default: - GGML_ASSERT(false && "not implemented"); + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row + GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_mul_row; + } else { + kernel = backend_ctx->kernel_mul_row_f16; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_mul; + } else { + kernel = backend_ctx->kernel_mul_f16; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3)); + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; - size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12}; - size_t local_work_size[] = {64, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } } -static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(src1); GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32); - // ne0 = ne00 - // ne2 = ne02 - // ne3 = ne03 + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; const cl_ulong nb10 = src1->nb[0]; const cl_ulong nb11 = src1->nb[1]; const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; - const int ne0 = dst->ne[0]; + const int ne0 = dst->ne[0]; + const cl_ulong nb0 = dst->nb[0]; const cl_ulong nb1 = dst->nb[1]; const cl_ulong nb2 = dst->nb[2]; const cl_ulong nb3 = dst->nb[3]; - const int nblk0 = ne0/ggml_blck_size(dst->type); - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; @@ -4791,78 +9361,79 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; + bool bcast_row = false; cl_kernel kernel; - switch (dst->type) { - case GGML_TYPE_F32: - if (src1->type == GGML_TYPE_I64) { - kernel = backend_ctx->kernel_set_rows_f32_i64; - } else { - kernel = backend_ctx->kernel_set_rows_f32_i32; - } - break; - case GGML_TYPE_F16: - if (src1->type == GGML_TYPE_I64) { - kernel = backend_ctx->kernel_set_rows_f16_i64; - } else { - kernel = backend_ctx->kernel_set_rows_f16_i32; - } - break; - default: - GGML_ABORT("not implemented"); - } + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { + GGML_ASSERT(ggml_is_contiguous(src0)); - fastdiv_vals ne11_ = init_fastdiv_values(ne11); - fastdiv_vals ne12_ = init_fastdiv_values(ne12); + // src1 is a row + GGML_ASSERT(ne11 == 1); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(fastdiv_vals), &ne11_)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(fastdiv_vals), &ne12_)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &nblk0)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb3)); + bcast_row = true; + int ne = ne00 / 4; - int nth0 = 64; - if (backend_ctx->gpu_family == INTEL) { - nth0 = 32; - } else if (backend_ctx->gpu_family == ADRENO) { - nth0 = 64; - } + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_div_row; + } else { + kernel = backend_ctx->kernel_div_row_f16; + } - int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); - while (nth0 < nblk0 && nth0 < max_workgroup_size) { - nth0 *= 2; - } + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_div; + } else { + kernel = backend_ctx->kernel_div_f16; + } - int rows_per_workgroup = 1; - if (nth0 > nblk0) { - rows_per_workgroup = nth0 / nblk0; - nth0 = nblk0; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); } - size_t global_work_size[] = { - (size_t)(ne01 + rows_per_workgroup - 1)/rows_per_workgroup*nth0, - (size_t)ne02*rows_per_workgroup, - (size_t)ne03}; - size_t local_work_size[] = {(size_t)nth0, (size_t)rows_per_workgroup, 1}; + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } } -static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(src1); @@ -4870,6 +9441,10 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const GGML_ASSERT(dst); GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == src1->type); + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; @@ -4891,9 +9466,6 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const const cl_ulong nb13 = src1->nb[3]; const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; const cl_ulong nb0 = dst->nb[0]; const cl_ulong nb1 = dst->nb[1]; @@ -4910,588 +9482,612 @@ static void ggml_cl_add(ggml_backend_t backend, const ggml_tensor * src0, const cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; + bool bcast_row = false; cl_kernel kernel; - const bool bcast_row = ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0; - - if (bcast_row) { + if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { GGML_ASSERT(ggml_is_contiguous(src0)); + + // src1 is a row GGML_ASSERT(ne11 == 1); + + bcast_row = true; + int ne = ne00 / 4; + + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sub_row; + } else { + kernel = backend_ctx->kernel_sub_row_f16; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sub; + } else { + kernel = backend_ctx->kernel_sub_f16; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); } - if (dst->type == GGML_TYPE_F32) { - GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32); - if (bcast_row) { - kernel = backend_ctx->kernel_add_row; - const int ne = ne00 / 4; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + if (bcast_row) { + int n = ggml_nelements(dst)/4; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else { + unsigned int nth = MIN(64, ne0); + size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } +} + +static void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + // Currently assumes src0 is contiguous + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqr_cont_f32_4; } else { - kernel = backend_ctx->kernel_add; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); - CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); - CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); - CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); + kernel = backend_ctx->kernel_sqr_cont_f16_4; } - } else if (dst->type == GGML_TYPE_F16) { - GGML_ASSERT(src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); - const int type_src0 = (src0->type == GGML_TYPE_F32); - const int type_src1 = (src1->type == GGML_TYPE_F32); - if (bcast_row) { - kernel = backend_ctx->kernel_add_row_f16; - const int ne = ne00 / 4; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &type_src0)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &type_src1)); + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqr_cont_f32; } else { - kernel = backend_ctx->kernel_add_f16; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); - CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); - CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); - CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); - CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &type_src0)); - CL_CHECK(clSetKernelArg(kernel, 31, sizeof(int), &type_src1)); + kernel = backend_ctx->kernel_sqr_cont_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + +static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + // Currently assumes src0 is contiguous + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqrt_cont_f32_4; + } else { + kernel = backend_ctx->kernel_sqrt_cont_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sqrt_cont_f32; + } else { + kernel = backend_ctx->kernel_sqrt_cont_f16; } - } else { - GGML_ASSERT(false && "unsupported data types for add"); } - if (bcast_row) { - int n = ggml_nelements(dst)/4; - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; - - size_t * local_work_size_ptr = local_work_size; - if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; - } + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size_ptr, dst); - } else { - unsigned int nth = MIN(64, ne0); - size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {nth, 1, 1}; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } -static void ggml_cl_add_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); - GGML_ASSERT(src1); - GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); - const ggml_tensor * src2 = dst->src[2]; - GGML_ASSERT(src2); - GGML_ASSERT(src2->extra); + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(src2->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - GGML_ASSERT(ggml_is_contiguous_rows(src0)); + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; - const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - const cl_ulong nb21 = src2->nb[1]; + cl_kernel kernel; - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; + const bool is_c4 = ne00 % 4 == 0; + if (is_c4) { + kernel = backend_ctx->kernel_mean_f32_4; + } else { + kernel = backend_ctx->kernel_mean_f32; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); + + size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + +static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; - ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offset1 = extra1->offset + src1->view_offs; - cl_ulong offset2 = extra2->offset + src2->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel = backend_ctx->kernel_add_id; + int ne01 = src0->ne[1]; + cl_ulong nb00 = src0->nb[0]; + cl_ulong nb01 = src0->nb[1]; + cl_ulong nb02 = src0->nb[2]; + + int ne10 = src1->ne[0]; + cl_ulong nb11 = src1->nb[1]; + + int ne1 = dst->ne[1]; + int ne2 = dst->ne[2]; + cl_ulong nb0 = dst->nb[0]; + cl_ulong nb1 = dst->nb[1]; + cl_ulong nb2 = dst->nb[2]; + + cl_kernel kernel = backend_ctx->kernel_ssm_conv_f32_f32; + + if (ne10 % 4 == 0) { + kernel = backend_ctx->kernel_ssm_conv_f32_f32_4; + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb21)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); - int nth = MIN(ne00, (int) backend_ctx->get_kernel_workgroup_size(kernel)); - size_t global_work_size[] = { (size_t)ne01*nth, (size_t)ne02, 1 }; - size_t local_work_size[] = { (size_t)nth, 1, 1 }; + size_t global_work_size[] = {(size_t)ne01, (size_t)ne1, (size_t)ne2}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + +static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } -static void ggml_cl_mul(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); - GGML_ASSERT(src1); - GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_ASSERT(src0->type == src1->type); - GGML_ASSERT(src0->type == dst->type); - GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + UNUSED(src1); - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - const cl_ulong nb00 = src0->nb[0]; - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb03 = src0->nb[3]; + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; UNUSED(ne13); + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_erf_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu_erf; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + +static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); + + if (n % 4 == 0) { + kernel = backend_ctx->kernel_gelu_quick_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_gelu_quick; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - const cl_ulong nb10 = src1->nb[0]; - const cl_ulong nb11 = src1->nb[1]; - const cl_ulong nb12 = src1->nb[2]; - const cl_ulong nb13 = src1->nb[3]; UNUSED(nb13); + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} - const cl_ulong nb0 = dst->nb[0]; - const cl_ulong nb1 = dst->nb[1]; - const cl_ulong nb2 = dst->nb[2]; - const cl_ulong nb3 = dst->nb[3]; +static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - bool bcast_row = false; cl_kernel kernel; - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - bcast_row = true; - int ne = ne00 / 4; - - if (src0->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_mul_row; - } else { - kernel = backend_ctx->kernel_mul_row_f16; - } + int n = ggml_nelements(dst); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + if (n % 4 == 0) { + kernel = backend_ctx->kernel_silu_4; + n /= 4; } else { - if (src0->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_mul; - } else { - kernel = backend_ctx->kernel_mul_f16; - } - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne2)); - CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne3)); - CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb0)); - CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nb3)); + kernel = backend_ctx->kernel_silu; } - if (bcast_row) { - int n = ggml_nelements(dst)/4; - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; - - size_t * local_work_size_ptr = local_work_size; - if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. - } + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); - } else { - unsigned int nth = MIN(64, ne0); - size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {nth, 1, 1}; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } -static void ggml_cl_div(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); - GGML_ASSERT(src1); - GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_ASSERT(src0->type == src1->type); - GGML_ASSERT(src0->type == dst->type); - GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + UNUSED(src1); - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - const cl_ulong nb00 = src0->nb[0]; - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb03 = src0->nb[3]; + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - const cl_ulong nb10 = src1->nb[0]; - const cl_ulong nb11 = src1->nb[1]; - const cl_ulong nb12 = src1->nb[2]; - const cl_ulong nb13 = src1->nb[3]; + cl_kernel kernel = backend_ctx->kernel_relu; - const int ne0 = dst->ne[0]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - const cl_ulong nb0 = dst->nb[0]; - const cl_ulong nb1 = dst->nb[1]; - const cl_ulong nb2 = dst->nb[2]; - const cl_ulong nb3 = dst->nb[3]; + const int64_t n = ggml_nelements(dst); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + +static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - bool bcast_row = false; cl_kernel kernel; - - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - bcast_row = true; - int ne = ne00 / 4; - - if (src0->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_div_row; - } else { - kernel = backend_ctx->kernel_div_row_f16; - } - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_sigmoid_f32; + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_sigmoid_f16; } else { - if (src0->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_div; - } else { - kernel = backend_ctx->kernel_div_f16; - } - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + GGML_ASSERT(false && "Unsupported data types for sigmoid (input and output must be both f32 or f16)"); } - if (bcast_row) { - int n = ggml_nelements(dst)/4; - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - } else { - unsigned int nth = MIN(64, ne0); - size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {nth, 1, 1}; + const int64_t n = ggml_nelements(dst); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } -static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); - GGML_ASSERT(src1); - GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_ASSERT(src0->type == src1->type); - GGML_ASSERT(src0->type == dst->type); - GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const cl_ulong nb00 = src0->nb[0]; - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb03 = src0->nb[3]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; - - const cl_ulong nb10 = src1->nb[0]; - const cl_ulong nb11 = src1->nb[1]; - const cl_ulong nb12 = src1->nb[2]; - const cl_ulong nb13 = src1->nb[3]; - - const int ne0 = dst->ne[0]; - - const cl_ulong nb0 = dst->nb[0]; - const cl_ulong nb1 = dst->nb[1]; - const cl_ulong nb2 = dst->nb[2]; - const cl_ulong nb3 = dst->nb[3]; + UNUSED(src1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - bool bcast_row = false; - cl_kernel kernel; + const int tri_type = ggml_get_op_params_i32(dst, 0); + const int64_t n = ggml_nelements(dst); + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; - if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(src0)); + cl_kernel kernel = backend_ctx->kernel_tri; - // src1 is a row - GGML_ASSERT(ne11 == 1); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &tri_type)); - bcast_row = true; - int ne = ne00 / 4; + size_t local_work_size[1] = { 256 }; + size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] }; - if (src0->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_sub_row; - } else { - kernel = backend_ctx->kernel_sub_row_f16; - } + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); +} - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne)); - } else { - if (src0->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_sub; - } else { - kernel = backend_ctx->kernel_sub_f16; - } +static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); - } + UNUSED(src0); + UNUSED(src1); - if (bcast_row) { - int n = ggml_nelements(dst)/4; - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - } else { - unsigned int nth = MIN(64, ne0); - size_t global_work_size[] = {ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {nth, 1, 1}; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + cl_ulong offsetd = extrad->offset + dst->view_offs; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - } + float v = 0.0f; + memcpy(&v, ((int32_t *) dst->op_params), sizeof(float)); + + const int64_t n = ggml_nelements(dst); + + cl_kernel kernel = backend_ctx->kernel_fill; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(float), &v)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(float), &n)); + + size_t local_work_size[1] = { 256 }; + size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); } -static void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); + UNUSED(src1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -5502,46 +10098,39 @@ static void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel; + float min; + float max; + memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); + memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); - // Currently assumes src0 is contiguous - int n = ggml_nelements(dst); - if (n % 4 == 0) { - if (src0->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_sqr_cont_f32_4; - } else { - kernel = backend_ctx->kernel_sqr_cont_f16_4; - } - n /= 4; - } else { - if (src0->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_sqr_cont_f32; - } else { - kernel = backend_ctx->kernel_sqr_cont_f16; - } - } + cl_kernel kernel = backend_ctx->kernel_clamp; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &min)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &max)); + + const int64_t n = ggml_nelements(dst); size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; size_t * local_work_size_ptr = local_work_size; if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; + local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. } backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } -static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); + UNUSED(src1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -5552,59 +10141,153 @@ static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); - // Currently assumes src0 is contiguous - int n = ggml_nelements(dst); - if (n % 4 == 0) { - if (src0->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_sqrt_cont_f32_4; - } else { - kernel = backend_ctx->kernel_sqrt_cont_f16_4; - } - n /= 4; + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + const int nth = MIN(64, ne00); + + cl_kernel kernel = backend_ctx->kernel_norm; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL)); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + +static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + //ggml_backend_opencl_device_context * dev_ctx = + // (ggml_backend_opencl_device_context *)backend->device->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + const int ne00 = src0 ? src0->ne[0] : 0; + const int ne01 = src0 ? src0->ne[1] : 0; + const int ne02 = src0 ? src0->ne[2] : 0; + const int ne03 = src0 ? src0->ne[3] : 0; + + const cl_ulong nb01 = src0 ? src0->nb[1] : 0; + const cl_ulong nb02 = src0 ? src0->nb[2] : 0; + const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + + GGML_ASSERT(ne00 % 4 == 0); + + const int nth = MIN(64, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_rms_norm; + + // Note, this kernel declares local memory in kernel args and the size + // depends on subgroup size. + // Note, this requires OpenCL 2.1 and above + // For now we use fixed subgroup size to simplify support for OpenCL 2.0. + size_t sgs; + //CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, + // CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, + // sizeof(local_work_size), local_work_size, + // sizeof(size_t), &sgs, NULL)); + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; } else { - if (src0->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_sqrt_cont_f32; - } else { - kernel = backend_ctx->kernel_sqrt_cont_f16; - } + GGML_ASSERT(false && "Unsupported GPU"); } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); + // This is local memory - the size depends on subgroup size. + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL)); - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} - size_t * local_work_size_ptr = local_work_size; - if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; - } +static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) { + GGML_ASSERT(mul_tensor); + GGML_ASSERT(rms_norm_tensor); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); -} + // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm) + const ggml_tensor * src0 = rms_norm_tensor->src[0]; + const ggml_tensor * src1; + if (mul_tensor->src[0] == rms_norm_tensor) { + src1 = mul_tensor->src[1]; + } else if (mul_tensor->src[1] == rms_norm_tensor) { + src1 = mul_tensor->src[0]; + } else { + GGML_ASSERT(false && "Invalid args for rms_norm and mul"); + } + const ggml_tensor * dst = mul_tensor; -static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_UNUSED(src1); - - GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(ggml_is_contiguous(src0)); - - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + float eps; + memcpy(&eps, rms_norm_tensor->op_params, sizeof(float)); + const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; @@ -5614,98 +10297,207 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; - const cl_ulong nb1 = dst->nb[1]; - const cl_ulong nb2 = dst->nb[2]; - const cl_ulong nb3 = dst->nb[3]; + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; - cl_kernel kernel = backend_ctx->kernel_mean_f32; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {(size_t)64, 1, 1}; + GGML_ASSERT(ne00 % 4 == 0); + + size_t sgs; + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } + + cl_kernel kernel = backend_ctx->kernel_rms_norm_mul; + + int nth = sgs; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + while (nth < ne00 && nth < max_workgroup_size) { + nth *= 2; + } + nth = MIN(nth, max_workgroup_size); + nth = MIN(nth, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs, NULL)); backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } -static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0); - GGML_ASSERT(src0->extra); - GGML_ASSERT(src1); - GGML_ASSERT(src1->extra); - GGML_ASSERT(dst); - GGML_ASSERT(dst->extra); +static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) { + GGML_ASSERT(norm_tensor && mul_tensor && add_tensor); - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + const ggml_tensor * src0 = norm_tensor->src[0]; + const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0]; + const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0]; + const ggml_tensor * dst = add_tensor; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - int ne01 = src0->ne[1]; - cl_ulong nb00 = src0->nb[0]; - cl_ulong nb01 = src0->nb[1]; - cl_ulong nb02 = src0->nb[2]; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - int ne10 = src1->ne[0]; - cl_ulong nb11 = src1->nb[1]; + float eps; + memcpy(&eps, norm_tensor->op_params, sizeof(float)); - int ne1 = dst->ne[1]; - int ne2 = dst->ne[2]; - cl_ulong nb0 = dst->nb[0]; - cl_ulong nb1 = dst->nb[1]; - cl_ulong nb2 = dst->nb[2]; + const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3]; + const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3]; + const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3]; + const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3]; + const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3]; + const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3]; + const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3]; + + size_t sgs; + if (backend_ctx->gpu_family == ADRENO) sgs = 64; + else if (backend_ctx->gpu_family == INTEL) sgs = 32; + else GGML_ASSERT(false && "Unsupported GPU"); + + cl_kernel kernel = backend_ctx->kernel_norm_mul_add; + + int nth = sgs; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2; + nth = MIN(nth, max_workgroup_size); + nth = MIN(nth, ne00/4); + + size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t lws[] = {(size_t)nth, 1, 1}; + size_t num_subgroups = (nth + sgs - 1) / sgs; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1)); + CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2)); + CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3)); + CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst); +} + +static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) { + GGML_ASSERT(gn_tensor && mul_tensor && add_tensor); + + const ggml_tensor * src0 = gn_tensor->src[0]; + const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0]; + const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0]; + const ggml_tensor * dst = add_tensor; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel = backend_ctx->kernel_ssm_conv_f32_f32; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - if (ne10 % 4 == 0) { - kernel = backend_ctx->kernel_ssm_conv_f32_f32_4; - } + int groups; + float eps; + memcpy(&groups, gn_tensor->op_params, sizeof(int)); + memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float)); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb0)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); + cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + int ne = ggml_nelements(src0); + int group_size = ne / groups; - size_t global_work_size[] = {(size_t)ne01, (size_t)ne1, (size_t)ne2}; - size_t local_work_size[] = {64, 1, 1}; + size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) }; + size_t gws[] = { (size_t)groups * lws[0] }; - size_t * local_work_size_ptr = local_work_size; - if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; - } + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps)); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst); } -static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); @@ -5721,29 +10513,41 @@ static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel; + int32_t n_groups = ((const int32_t *) dst->op_params)[0]; + int32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + n_groups - 1) / n_groups); + float eps = ((const float *) dst->op_params)[1]; - int n = ggml_nelements(dst); + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne = ne00*ne01*ne02; - if (n % 4 == 0) { - kernel = backend_ctx->kernel_gelu_4; - n /= 4; + cl_kernel kernel = backend_ctx->kernel_group_norm; + + size_t sgs = 64; + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; } else { - kernel = backend_ctx->kernel_gelu; + GGML_ASSERT(false && "Unsupported GPU"); } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &group_size)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + size_t global_work_size[] = {(size_t)n_groups*sgs, 1, 1}; + size_t local_work_size[] = {(size_t)sgs, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } -static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_l2_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); @@ -5759,29 +10563,49 @@ static void ggml_cl_gelu_erf(ggml_backend_t backend, const ggml_tensor * src0, c cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel; + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); - int n = ggml_nelements(dst); + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); - if (n % 4 == 0) { - kernel = backend_ctx->kernel_gelu_erf_4; - n /= 4; + size_t sgs; + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; } else { - kernel = backend_ctx->kernel_gelu_erf; + GGML_ASSERT(false && "Unsupported GPU"); } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + cl_kernel kernel = backend_ctx->kernel_l2_norm_f32; - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + int nth = sgs; + while (nth < ne00 && nth < (int)backend_ctx->get_kernel_workgroup_size(kernel)) { + nth *= 2; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL)); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } -static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); @@ -5797,29 +10621,87 @@ static void ggml_cl_gelu_quick(ggml_backend_t backend, const ggml_tensor * src0, cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + cl_kernel kernel; - int n = ggml_nelements(dst); + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32_4; + } else { + kernel = backend_ctx->kernel_tanh_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32; + } else { + kernel = backend_ctx->kernel_tanh_f16; + } + } - if (n % 4 == 0) { - kernel = backend_ctx->kernel_gelu_quick_4; - n /= 4; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } else { - kernel = backend_ctx->kernel_gelu_quick; - } + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_tanh_f32_nc; + } else { + kernel = backend_ctx->kernel_tanh_f16_nc; + } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } } -static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_neg(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); @@ -5835,34 +10717,73 @@ static void ggml_cl_silu(ggml_backend_t backend, const ggml_tensor * src0, const cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + cl_kernel kernel; - int n = ggml_nelements(dst); + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_neg_f32_4; + } else { + kernel = backend_ctx->kernel_neg_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_neg_f32; + } else { + kernel = backend_ctx->kernel_neg_f16; + } + } - if (n % 4 == 0) { - kernel = backend_ctx->kernel_silu_4; - n /= 4; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &n)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else { - kernel = backend_ctx->kernel_silu; - } + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_neg_f32_nc; + } else { + kernel = backend_ctx->kernel_neg_f16_nc; + } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; - size_t * local_work_size_ptr = local_work_size; - if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } -static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_exp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); @@ -5878,27 +10799,73 @@ static void ggml_cl_relu(ggml_backend_t backend, const ggml_tensor * src0, const cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_kernel kernel = backend_ctx->kernel_relu; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + cl_kernel kernel; - const int64_t n = ggml_nelements(dst); + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_exp_f32_4; + } else { + kernel = backend_ctx->kernel_exp_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_exp_f32; + } else { + kernel = backend_ctx->kernel_exp_f16; + } + } - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &n)); - size_t * local_work_size_ptr = local_work_size; - if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. + size_t global_work_size[] = {(size_t)CEIL_DIV(n, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_exp_f32_nc; + } else { + kernel = backend_ctx->kernel_exp_f16_nc; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } -static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); @@ -5914,1664 +10881,2875 @@ static void ggml_cl_sigmoid(ggml_backend_t backend, const ggml_tensor * src0, co cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + cl_kernel kernel; - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_sigmoid_f32; - } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_sigmoid_f16; - } else { - GGML_ASSERT(false && "Unsupported data types for sigmoid (input and output must be both f32 or f16)"); - } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_expm1_f32_4; + } else { + kernel = backend_ctx->kernel_expm1_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_expm1_f32; + } else { + kernel = backend_ctx->kernel_expm1_f16; + } + } - const int64_t n = ggml_nelements(dst); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; - size_t * local_work_size_ptr = local_work_size; - if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. - } + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_expm1_f32_nc; + } else { + kernel = backend_ctx->kernel_expm1_f16_nc; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } } -static void ggml_cl_fill(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - UNUSED(src0); UNUSED(src1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - float v = 0.0f; - memcpy(&v, ((int32_t *) dst->op_params), sizeof(float)); + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; - const int64_t n = ggml_nelements(dst); + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; - cl_kernel kernel = backend_ctx->kernel_fill; + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(float), &v)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(float), &n)); + cl_kernel kernel; - size_t local_work_size[1] = { 256 }; - size_t global_work_size[1] = { ((size_t)n + local_work_size[0] - 1) / local_work_size[0] * local_work_size[0] }; + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_softplus_f32_4; + } else { + kernel = backend_ctx->kernel_softplus_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_softplus_f32; + } else { + kernel = backend_ctx->kernel_softplus_f16; + } + } - backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_softplus_f32_nc; + } else { + kernel = backend_ctx->kernel_softplus_f16_nc; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } } -static void ggml_cl_clamp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); + GGML_ASSERT(dst->type == src0->type); - UNUSED(src1); + UNUSED(src1_shape_def); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - float min; - float max; - memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float)); - memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float)); + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; - cl_kernel kernel = backend_ctx->kernel_clamp; + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &min)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &max)); + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; - const int64_t n = ggml_nelements(dst); + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - size_t global_work_size[] = {(size_t)n, 1, 1}; - size_t local_work_size[] = {64, 1, 1}; + cl_kernel kernel = backend_ctx->kernel_repeat_f32; - size_t * local_work_size_ptr = local_work_size; - if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; // Let driver choose the work-group sizes. - } + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3)); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + int nth = 64; + + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } -static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - - UNUSED(src1); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + if (backend_ctx->kernel_pad == nullptr) { + GGML_LOG_WARN("%s: pad kernel not available, skipping OpenCL execution.\n", __func__); + return; + } - cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; + const int s_ne0 = src0->ne[0]; + const int s_ne1 = src0->ne[1]; + const int s_ne2 = src0->ne[2]; + const int s_ne3 = src0->ne[3]; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + const int s_nb0 = src0->nb[0]; + const int s_nb1 = src0->nb[1]; + const int s_nb2 = src0->nb[2]; + const int s_nb3 = src0->nb[3]; - const int nth = MIN(64, ne00); + const int d_ne0 = dst->ne[0]; + const int d_ne1 = dst->ne[1]; + const int d_ne2 = dst->ne[2]; + const int d_ne3 = dst->ne[3]; + + const int d_nb0 = dst->nb[0]; + const int d_nb1 = dst->nb[1]; + const int d_nb2 = dst->nb[2]; + const int d_nb3 = dst->nb[3]; + + const int lp0 = ((const int*)(dst->op_params))[0]; + const int rp0 = ((const int*)(dst->op_params))[1]; + const int lp1 = ((const int*)(dst->op_params))[2]; + const int rp1 = ((const int*)(dst->op_params))[3]; + const int lp2 = ((const int*)(dst->op_params))[4]; + const int rp2 = ((const int*)(dst->op_params))[5]; + const int lp3 = ((const int*)(dst->op_params))[6]; + const int rp3 = ((const int*)(dst->op_params))[7]; + + cl_kernel kernel = backend_ctx->kernel_pad; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3)); - cl_kernel kernel = backend_ctx->kernel_norm; + size_t lws0 = 64; + size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL)); + size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 }; + size_t local_work_size[] = { lws0, 1, 1 }; - size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {(size_t)nth, 1, 1}; + size_t * local_work_size_ptr = local_work_size; + if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); } -static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - - UNUSED(src1); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - //ggml_backend_opencl_device_context * dev_ctx = - // (ggml_backend_opencl_device_context *)backend->device->context; + const int mode_flags = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0); + const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF); + cl_kernel kernel = nullptr; - ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + if (mode == GGML_SCALE_MODE_NEAREST) { + kernel = backend_ctx->kernel_upscale; + if (kernel == nullptr) { + GGML_LOG_WARN("%s: nearest upscale kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + kernel = backend_ctx->kernel_upscale_bilinear; + if (kernel == nullptr) { + GGML_LOG_WARN("%s: bilinear upscale kernel not available, skipping OpenCL execution.\n", __func__); + return; + } + } else { + GGML_LOG_WARN("%s: unsupported upscale mode %d, skipping OpenCL execution.\n", __func__, mode); + return; + } - cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; - float eps; - memcpy(&eps, dst->op_params, sizeof(float)); + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; - GGML_ASSERT(ne00 % 4 == 0); + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; - const int nth = MIN(64, ne00); + float sf0 = (float)ne0 / ne00; + float sf1 = (float)ne1 / ne01; + float sf2 = (float)ne2 / ne02; + float sf3 = (float)ne3 / ne03; - size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {(size_t)nth, 1, 1}; + float pixel_offset = 0.5f; - cl_kernel kernel = backend_ctx->kernel_rms_norm; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03)); - // Note, this kernel declares local memory in kernel args and the size - // depends on subgroup size. - // Note, this requires OpenCL 2.1 and above - // For now we use fixed subgroup size to simplify support for OpenCL 2.0. - size_t sgs; - //CL_CHECK(clGetKernelSubGroupInfo(kernel, dev_ctx->device, - // CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, - // sizeof(local_work_size), local_work_size, - // sizeof(size_t), &sgs, NULL)); - if (backend_ctx->gpu_family == ADRENO) { - sgs = 64; - } else if (backend_ctx->gpu_family == INTEL) { - sgs = 32; - } else { - GGML_ASSERT(false && "Unsupported GPU"); - } + if (mode == GGML_SCALE_MODE_NEAREST) { + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3)); + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; + sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; + pixel_offset = 0.0f; + } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); - // This is local memory - the size depends on subgroup size. - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth/sgs, NULL)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne3)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &pixel_offset)); + } - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); -} -static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) { - GGML_ASSERT(mul_tensor); - GGML_ASSERT(rms_norm_tensor); + size_t dst_total_elements = (size_t)ne0 * ne1 * ne2 * ne3; + if (dst_total_elements == 0) { + return; + } + size_t global_work_size[] = { dst_total_elements, 1, 1 }; + size_t local_work_size_pref = 256; + size_t local_work_size[] = { MIN(local_work_size_pref, dst_total_elements), 1, 1}; - // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm) - const ggml_tensor * src0 = rms_norm_tensor->src[0]; - const ggml_tensor * src1; - if (mul_tensor->src[0] == rms_norm_tensor) { - src1 = mul_tensor->src[1]; - } else if (mul_tensor->src[1] == rms_norm_tensor) { - src1 = mul_tensor->src[0]; - } else { - GGML_ASSERT(false && "Invalid args for rms_norm and mul"); + size_t * local_work_size_ptr = local_work_size; + if (dst_total_elements % local_work_size[0] != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; } - const ggml_tensor * dst = mul_tensor; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); +} + +static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(src1); GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offset1 = extra1->offset + src0->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; - - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - - float eps; - memcpy(&eps, rms_norm_tensor->op_params, sizeof(float)); + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3]; + const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + const int ne3 = dst->ne[3]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + const cl_int dim = ((const int32_t *) dst->op_params)[0]; + GGML_ASSERT(dim >= 0 && dim <= 3); + + int nth = MIN(64, ne0); + + const bool concat_pack = (dim == 0 && ne0 < 32); + cl_kernel kernel = concat_pack ? backend_ctx->kernel_concat_f32_pack + : backend_ctx->kernel_concat_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_int), &dim)); + + if (concat_pack) { + // packed kernel needs the dst dims to unflatten its 1-D row index. + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne2)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &ne3)); + + const int maxwg = (int)backend_ctx->get_kernel_workgroup_size(kernel); + const int base = MIN(64, maxwg); + const int tpr = MIN(ne0, base); // threads per row + const int rpw = MAX(1, base / tpr); // rows per workgroup + const int lsz = tpr * rpw; + const int nrows = ne1*ne2*ne3; + const int nwg = (nrows + rpw - 1) / rpw; + size_t global_work_size[] = {(size_t)nwg*lsz, 1, 1}; + size_t local_work_size[] = {(size_t)lsz, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, dst); + } else { + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } +} + +static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - const cl_ulong nb11 = src1->nb[1]; - const cl_ulong nb12 = src1->nb[2]; - const cl_ulong nb13 = src1->nb[3]; + if (backend_ctx->kernel_timestep_embedding == nullptr) { + GGML_LOG_WARN("%s: timestep_embedding kernel not available, skipping OpenCL execution.\n", __func__); + return; + } - const cl_ulong nb1 = dst->nb[1]; - const cl_ulong nb2 = dst->nb[2]; - const cl_ulong nb3 = dst->nb[3]; + ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; - GGML_ASSERT(ne00 % 4 == 0); + cl_ulong off_src0 = extra_src0->offset + src0->view_offs; + cl_ulong off_dst = extra_dst->offset + dst->view_offs; - size_t sgs; - if (backend_ctx->gpu_family == ADRENO) { - sgs = 64; - } else if (backend_ctx->gpu_family == INTEL) { - sgs = 32; - } else { - GGML_ASSERT(false && "Unsupported GPU"); - } + const int logical_dim = dst->op_params[0]; + const int max_period = dst->op_params[1]; + const int dst_nb1_bytes = dst->nb[1]; - cl_kernel kernel = backend_ctx->kernel_rms_norm_mul; + cl_kernel kernel = backend_ctx->kernel_timestep_embedding; - int nth = sgs; - int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); - while (nth < ne00 && nth < max_workgroup_size) { - nth *= 2; - } - nth = MIN(nth, max_workgroup_size); - nth = MIN(nth, ne00); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &dst_nb1_bytes)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &logical_dim)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &max_period)); - size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {(size_t)nth, 1, 1}; + size_t gws0 = (size_t)(((logical_dim + 1) / 2) + 1); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps)); - CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*sgs, NULL)); + size_t gws1 = (size_t)src0->ne[0]; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + size_t global_work_size[] = {gws0, gws1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); } -static void ggml_opencl_op_norm_fused(ggml_backend_t backend, ggml_tensor * norm_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) { - GGML_ASSERT(norm_tensor && mul_tensor && add_tensor); +static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) { + const ggml_tensor * v = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + GGML_ASSERT(q->extra); + GGML_ASSERT(k->extra); + GGML_ASSERT(v->extra); + GGML_ASSERT(dst->extra); + if (mask) { + GGML_ASSERT(mask->extra); + } + if (sinks) { + GGML_ASSERT(sinks->extra); + } - const ggml_tensor * src0 = norm_tensor->src[0]; - const ggml_tensor * src1 = mul_tensor->src[0] == norm_tensor ? mul_tensor->src[1] : mul_tensor->src[0]; - const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0]; - const ggml_tensor * dst = add_tensor; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; - ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; - ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + const int n_q = q->ne[1]; + const int n_kv = k->ne[1]; + const int d_head_q = q->ne[0]; + const int d_head_v = v->ne[0]; + const int n_head = q->ne[2]; + const int n_head_kv = k->ne[2]; + const int n_batch = q->ne[3]; - cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offset1 = extra1->offset + src1->view_offs; - cl_ulong offset2 = extra2->offset + src2->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; + cl_kernel kernel = NULL; - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + const bool is_f16 = q->type == GGML_TYPE_F16; + const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16; + const std::pair<int, int> dk_dv = {d_head_q, d_head_v}; - float eps; - memcpy(&eps, norm_tensor->op_params, sizeof(float)); + if (n_q == 1) { + if (is_mixed) { + kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv); + } else if (is_f16) { + kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv); + } else { + kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv); + } + } else { + if (is_mixed) { + kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv); + } else if (is_f16) { + kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv); + } else { + kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv); + } + } + GGML_ASSERT(kernel != NULL); - const int ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3]; - const cl_ulong nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3]; - const int ne10 = src1->ne[0], ne11 = src1->ne[1], ne12 = src1->ne[2], ne13 = src1->ne[3]; - const cl_ulong nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3]; - const int ne20 = src2->ne[0], ne21 = src2->ne[1], ne22 = src2->ne[2], ne23 = src2->ne[3]; - const cl_ulong nb21 = src2->nb[1], nb22 = src2->nb[2], nb23 = src2->nb[3]; - const cl_ulong nbd1 = dst->nb[1], nbd2 = dst->nb[2], nbd3 = dst->nb[3]; + ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra; + ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra; + ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra; + ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL; + ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL; - size_t sgs; - if (backend_ctx->gpu_family == ADRENO) sgs = 64; - else if (backend_ctx->gpu_family == INTEL) sgs = 32; - else GGML_ASSERT(false && "Unsupported GPU"); + cl_ulong offset_q = extra_q->offset + q->view_offs; + cl_ulong offset_k = extra_k->offset + k->view_offs; + cl_ulong offset_v = extra_v->offset + v->view_offs; + cl_ulong offset_o = extra_o->offset + dst->view_offs; + cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL; + cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0; + cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL; + cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0; - cl_kernel kernel = backend_ctx->kernel_norm_mul_add; + const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3]; + const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3]; + const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3]; + const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3]; + const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0; + const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0; + const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0; + const int mask_ne2 = mask ? mask->ne[2] : 0; + const int mask_ne3 = mask ? mask->ne[3] : 0; - int nth = sgs; - int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); - while (nth < ne00/4 && nth < max_workgroup_size) nth *= 2; - nth = MIN(nth, max_workgroup_size); - nth = MIN(nth, ne00/4); + float scale, max_bias, logit_softcap; + const float * params = (const float *)dst->op_params; + scale = params[0]; + max_bias = params[1]; + logit_softcap = params[2]; - size_t gws[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t lws[] = {(size_t)nth, 1, 1}; - size_t num_subgroups = (nth + sgs - 1) / sgs; + const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne20)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &ne21)); - CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &ne22)); - CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &ne23)); - CL_CHECK(clSetKernelArg(kernel, 26, sizeof(cl_ulong), &nb21)); - CL_CHECK(clSetKernelArg(kernel, 27, sizeof(cl_ulong), &nb22)); - CL_CHECK(clSetKernelArg(kernel, 28, sizeof(cl_ulong), &nb23)); - CL_CHECK(clSetKernelArg(kernel, 29, sizeof(cl_ulong), &nbd1)); - CL_CHECK(clSetKernelArg(kernel, 30, sizeof(cl_ulong), &nbd2)); - CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_ulong), &nbd3)); - CL_CHECK(clSetKernelArg(kernel, 32, sizeof(float), &eps)); - CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_float2) * num_subgroups, NULL)); + const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0; + const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f; + const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1)); + CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val)); + CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap)); + CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv)); + CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer)); + CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask)); + CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1)); + CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2)); + CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3)); + CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2)); + CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3)); + CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer)); + CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks)); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, gws, lws, dst); + if (n_q == 1) { + const size_t wg_size = 64; + size_t local_work_size[] = { wg_size, 1 }; + size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); + } else { + const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv); + const size_t wg_size = block_m; + size_t local_work_size[] = { wg_size, 1 }; + size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); + } } -static void ggml_opencl_op_group_norm_fused(ggml_backend_t backend, ggml_tensor * gn_tensor, ggml_tensor * mul_tensor, ggml_tensor * add_tensor) { - GGML_ASSERT(gn_tensor && mul_tensor && add_tensor); - - const ggml_tensor * src0 = gn_tensor->src[0]; - const ggml_tensor * src1 = mul_tensor->src[0] == gn_tensor ? mul_tensor->src[1] : mul_tensor->src[0]; - const ggml_tensor * src2 = add_tensor->src[0] == mul_tensor ? add_tensor->src[1] : add_tensor->src[0]; - const ggml_tensor * dst = add_tensor; +static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; - ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offset1 = extra1->offset + src1->view_offs; - cl_ulong offset2 = extra2->offset + src2->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + const int M = src0->ne[1]; + const int N = src1->ne[1]; + const int K = src0->ne[0]; - int groups; - float eps; - memcpy(&groups, gn_tensor->op_params, sizeof(int)); - memcpy(&eps, (char *)gn_tensor->op_params + sizeof(int), sizeof(float)); + cl_kernel kernel = backend_ctx->kernel_mul_mat_f16_f32_tiled; - cl_kernel kernel = backend_ctx->kernel_group_norm_mul_add; - int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); - int ne = ggml_nelements(src0); - int group_size = ne / groups; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); - size_t lws[] = { (size_t)MIN(max_workgroup_size, group_size) }; - size_t gws[] = { (size_t)groups * lws[0] }; + // Tiling parameters. These need to be tuned for optimal performance. + // They must match the #defines in the kernel mul_mat_f16_f32.cl. + // + // OPWM / OPWN: Output tile size per Work-Group. A work-group computes a tile of size OPWM x OPWN. + // TPWM / TPWN: Threads per Work-group. This is the work-group size. + // OPTM / OPTN: Output elements per Thread. Each thread computes OPTM x OPTN elements. + // + // The following relationships must hold: + // OPWM = TPWM * OPTM + // OPWN = TPWN * OPTN + // + const int OPWM = 64; + const int OPWN = 64; + const int TPWM = 16; + const int TPWN = 8; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &group_size)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &eps)); + size_t local_work_size[2] = { TPWM, TPWN }; + size_t global_work_size[2] = { + (size_t) ((M + OPWM - 1) / OPWM) * TPWM, + (size_t) ((N + OPWN - 1) / OPWN) * TPWN, + }; - backend_ctx->enqueue_ndrange_kernel(kernel, 1, gws, lws, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); } -static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0); - GGML_ASSERT(src0->extra); - GGML_ASSERT(dst); - GGML_ASSERT(dst->extra); +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS +static bool ggml_cl_can_use_adreno_xmem_gemm_f16_f32( + const ggml_backend_opencl_context * backend_ctx, + const ggml_tensor * src0, + const ggml_tensor * src1, + const ggml_tensor * dst) { + if (!backend_ctx->adreno_xmem_gemm_enabled) { + return false; + } + if (backend_ctx->gpu_family != GPU_FAMILY::ADRENO) { + return false; + } + if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_BF16) || + src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + if (src0->ne[2] != 1 || src0->ne[3] != 1 || + src1->ne[2] != 1 || src1->ne[3] != 1 || + dst->ne[2] != 1 || dst->ne[3] != 1) { + return false; + } + const int K = src0->ne[0]; + const int M = src0->ne[1]; + const int N = src1->ne[1]; + if (src1->ne[0] != K || dst->ne[0] != M || dst->ne[1] != N) { + return false; + } + if (N <= 1 || M < 64 || N < 16 || K < 64) { + return false; + } + if ((K % 8) != 0) { + return false; + } + const int kpack = K / 4; + const int npack = CEIL_DIV(M, 4); + if (static_cast<size_t>(N) > backend_ctx->image2d_max_width || + static_cast<size_t>(kpack) > backend_ctx->image2d_max_height) { + return false; + } + if (static_cast<size_t>(N) > backend_ctx->image2d_max_width || + static_cast<size_t>(npack) > backend_ctx->image2d_max_height) { + return false; + } + return true; +} - UNUSED(src1); +static void ggml_cl_mul_mat_f16_f32_adreno_xmem( + ggml_backend_t backend, + const ggml_tensor * src0, + const ggml_tensor * src1, + ggml_tensor * dst) { + ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + const cl_ulong offset0 = extra0->offset + src0->view_offs; + const cl_ulong offset1 = extra1->offset + src1->view_offs; + const cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int K = src0->ne[0]; + const int M = src0->ne[1]; + const int N = src1->ne[1]; + const int kpack = K / 4; + const int npack = CEIL_DIV(M, 4); + const int os = 8; + + const size_t xmem_bytes = 6144; + const size_t weight_bytes = static_cast<size_t>(kpack) * static_cast<size_t>(npack) * 4u * sizeof(cl_half4); + + backend_ctx->prealloc_adreno_xmem_const.allocate(backend_ctx->context, xmem_bytes); + + cl_int err = CL_SUCCESS; + cl_image_format fmt = {}; + fmt.image_channel_order = CL_RGBA; + fmt.image_channel_data_type = CL_HALF_FLOAT; + + cl_image_desc desc_src = {}; + desc_src.image_type = CL_MEM_OBJECT_IMAGE2D; + desc_src.image_width = static_cast<size_t>(N); + desc_src.image_height = static_cast<size_t>(kpack); + cl_mem src_img = clCreateImage(backend_ctx->context, CL_MEM_READ_WRITE, &fmt, &desc_src, nullptr, &err); + CL_CHECK(err); + + cl_image_desc desc_dst = {}; + desc_dst.image_type = CL_MEM_OBJECT_IMAGE2D; + desc_dst.image_width = static_cast<size_t>(N); + desc_dst.image_height = static_cast<size_t>(npack); + cl_mem dst_img = clCreateImage(backend_ctx->context, CL_MEM_READ_WRITE, &fmt, &desc_dst, nullptr, &err); + CL_CHECK(err); + + cl_mem weights = clCreateBuffer(backend_ctx->context, CL_MEM_READ_WRITE, weight_bytes, nullptr, &err); + CL_CHECK(err); + + cl_kernel prepack = backend_ctx->kernel_adreno_xmem_prepack_weight_f16; + CL_CHECK(clSetKernelArg(prepack, 0, sizeof(cl_mem), &weights)); + CL_CHECK(clSetKernelArg(prepack, 1, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(prepack, 2, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(prepack, 3, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(prepack, 4, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(prepack, 5, sizeof(int), &kpack)); + CL_CHECK(clSetKernelArg(prepack, 6, sizeof(int), &npack)); + CL_CHECK(clSetKernelArg(prepack, 7, sizeof(int), &os)); + size_t lws = 256; + size_t max_wg = backend_ctx->get_kernel_workgroup_size(prepack); + if (lws > max_wg) { + lws = max_wg; + } + size_t gws = CEIL_DIV(static_cast<size_t>(kpack) * static_cast<size_t>(npack), lws) * lws; + backend_ctx->enqueue_ndrange_kernel(prepack, 1, &gws, &lws, dst); + + cl_kernel pack_src = backend_ctx->kernel_adreno_xmem_pack_src_f32; + CL_CHECK(clSetKernelArg(pack_src, 0, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(pack_src, 1, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(pack_src, 2, sizeof(cl_mem), &src_img)); + CL_CHECK(clSetKernelArg(pack_src, 3, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(pack_src, 4, sizeof(int), &N)); + size_t pack_src_lws[2] = { 16, 16 }; + size_t pack_src_gws[2] = { + CEIL_DIV(static_cast<size_t>(N), pack_src_lws[0])*pack_src_lws[0], + CEIL_DIV(static_cast<size_t>(kpack), pack_src_lws[1])*pack_src_lws[1] + }; + backend_ctx->enqueue_ndrange_kernel(pack_src, 2, pack_src_gws, pack_src_lws, dst); + + cl_kernel gemm = backend_ctx->kernel_gemm_xmem_f16_f32_os8; + CL_CHECK(clSetKernelArg(gemm, 0, sizeof(cl_mem), &weights)); + CL_CHECK(clSetKernelArg(gemm, 1, sizeof(cl_mem), &backend_ctx->prealloc_adreno_xmem_const.buffer)); + CL_CHECK(clSetKernelArg(gemm, 2, sizeof(cl_mem), &src_img)); + CL_CHECK(clSetKernelArg(gemm, 3, sizeof(cl_mem), &dst_img)); + CL_CHECK(clSetKernelArg(gemm, 4, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(gemm, 5, sizeof(int), &npack)); + CL_CHECK(clSetKernelArg(gemm, 6, sizeof(int), &kpack)); + const size_t z_values = CEIL_DIV(static_cast<size_t>(npack), static_cast<size_t>(os)); + size_t gemm_lws[3] = { 64, 1, 1 }; + size_t gemm_gws[3] = { + z_values*gemm_lws[0], + CEIL_DIV(static_cast<size_t>(N), gemm_lws[0]), + 1 + }; + backend_ctx->enqueue_ndrange_kernel(gemm, 3, gemm_gws, gemm_lws, dst); + + cl_kernel store_dst = backend_ctx->kernel_adreno_xmem_store_dst_f32; + CL_CHECK(clSetKernelArg(store_dst, 0, sizeof(cl_mem), &dst_img)); + CL_CHECK(clSetKernelArg(store_dst, 1, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(store_dst, 2, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(store_dst, 3, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(store_dst, 4, sizeof(int), &N)); + size_t store_lws[2] = { 16, 16 }; + size_t store_gws[2] = { + CEIL_DIV(static_cast<size_t>(N), store_lws[0])*store_lws[0], + CEIL_DIV(static_cast<size_t>(npack), store_lws[1])*store_lws[1] + }; + backend_ctx->enqueue_ndrange_kernel(store_dst, 2, store_gws, store_lws, dst); + + CL_CHECK(clReleaseMemObject(weights)); + CL_CHECK(clReleaseMemObject(dst_img)); + CL_CHECK(clReleaseMemObject(src_img)); +} +#endif // GGML_OPENCL_USE_ADRENO_KERNELS +static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_TENSOR_BINARY_OP_LOCALS; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - int32_t n_groups = ((const int32_t *) dst->op_params)[0]; - int32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + n_groups - 1) / n_groups); - float eps = ((const float *) dst->op_params)[1]; + const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13; + const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1; - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne = ne00*ne01*ne02; + const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1]; + const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3]; + const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5]; - cl_kernel kernel = backend_ctx->kernel_group_norm; + const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type); + const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type); + const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type); - size_t sgs = 64; - if (backend_ctx->gpu_family == ADRENO) { - sgs = 64; - } else if (backend_ctx->gpu_family == INTEL) { - sgs = 32; + const int64_t NPQ = (int64_t)N * OW * OH; + + const uint32_t BS_K = 64; + const uint32_t BS_NPQ = 64; + const uint32_t BS_CRS = 16; + const uint32_t VEC_SIZE = 4; + + const uint32_t TS_K = 4; + const uint32_t TS_NPQ = 8; + + const uint32_t WG_K = BS_K / TS_K; + const uint32_t WG_NPQ = BS_NPQ / TS_NPQ; + + auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; }; + const uint32_t NB_K = splitWork(Cout, BS_K); + const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ); + + cl_kernel kernel; + size_t shmem_size; + + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_conv_2d_f16; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4)); + } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_conv_2d_f32; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_conv_2d_f16_f32; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); } else { - GGML_ASSERT(false && "Unsupported GPU"); + GGML_ASSERT(false && "Unsupported data type combination for conv2d"); } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &group_size)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(float), &eps)); + cl_uint idx = 0; + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3)); - size_t global_work_size[] = {(size_t)n_groups*sgs, 1, 1}; - size_t local_work_size[] = {(size_t)sgs, 1, 1}; + size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 }; + size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 }; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); } -static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0); - GGML_ASSERT(src0->extra); - GGML_ASSERT(dst); - GGML_ASSERT(dst->extra); - - UNUSED(src1); - +static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - cl_ulong offset0_abs = extra0->offset + src0->view_offs; - cl_ulong offsetd_abs = extrad->offset + dst->view_offs; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + + const cl_ulong nb10 = src1->nb[0]; + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 == ne10); + + cl_kernel kernel; + cl_context context = backend_ctx->context; + + cl_int status; + cl_image_format img_fmt_1d; + cl_image_desc img_desc_1d; + cl_buffer_region region; + cl_mem A_image1d; + cl_mem A_sub_buffer; + cl_mem B_sub_buffer; + cl_mem D_image1d; + cl_mem D_sub_buffer; + + int M = ne01; + int N = ne1; + int K = ne00; - cl_kernel kernel; - if (dst->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_tanh_f32_nd; - } else if (dst->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_tanh_f16_nd; + if (nb01 > nb02) { + // KQ + kernel = backend_ctx->kernel_mul_mm_f16_f32_kq; } else { - GGML_ASSERT(false && "Unsupported type for ggml_cl_tanh"); + // KQV + kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv; } - GGML_ASSERT(kernel != nullptr); - - const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; const int ne03 = src0->ne[3]; - const cl_ulong nb00 = src0->nb[0]; const cl_ulong nb01 = src0->nb[1]; const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; + // create sub-buffer for A + // <--------------------------------------------> // + extra0 = src0->view_src ? (ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra; - const int ne10 = dst->ne[0]; const int ne11 = dst->ne[1]; const int ne12 = dst->ne[2]; const int ne13 = dst->ne[3]; - const cl_ulong nb10 = dst->nb[0]; const cl_ulong nb11 = dst->nb[1]; const cl_ulong nb12 = dst->nb[2]; const cl_ulong nb13 = dst->nb[3]; + region.origin = (extra0->offset); + if (nb01 > nb02) { + // KQ + region.size = nb01 * ne01; + } else { + // KQV + region.size = nb02 * ne02; + } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); + A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); + // <--------------------------------------------> // - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); + // create sub-buffer for B + // <--------------------------------------------> // + region.origin = (extra1->offset); + region.size = nb10 * ne10 * ne11 * ne12; + B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + // <--------------------------------------------> // - size_t global_work_size[3]; - if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements - return; + img_fmt_1d = {CL_RGBA, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + if (nb01 > nb02) { + img_desc_1d.image_width = (nb01 * ne01 / 4)/4; + } + else { + img_desc_1d.image_width = (nb02 * ne02 / 4)/4; } - global_work_size[0] = (size_t)ne10; - global_work_size[1] = (size_t)ne11; - global_work_size[2] = (size_t)ne12; + img_desc_1d.buffer = A_sub_buffer; + A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + + // create sub-buffer for output C + // <--------------------------------------------> // + region.origin = (extrad->offset); + region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes + D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + // <--------------------------------------------> // - size_t lws0 = 16, lws1 = 4, lws2 = 1; - if (ne10 < 16) lws0 = ne10; - if (ne11 < 4) lws1 = ne11; - if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; + // create image for C output + // <--------------------------------------------> // + img_fmt_1d = {CL_R, CL_FLOAT}; + memset(&img_desc_1d, 0, sizeof(img_desc_1d)); + img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4; + img_desc_1d.buffer = D_sub_buffer; + D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); + CL_CHECK(status); + // <--------------------------------------------> // - while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; + int offset_src0 = 0; + int offset_src1 = 0; + // set kernel args + // <--------------------------------------------> // + cl_uint k_arg = 0; + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src0)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_sub_buffer)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src1)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &D_image1d)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &nb01)); - size_t local_work_size[] = {lws0, lws1, lws2}; + size_t global_work_size[3] = {64, static_cast<size_t>(((M+63)/64)), static_cast<size_t>(((N+31)/32)*ne12)}; + size_t local_work_size[3] = {64, 1, 2}; - size_t* local_work_size_ptr = local_work_size; - if (!backend_ctx->non_uniform_workgroups) { - if (global_work_size[0] % local_work_size[0] != 0 || - global_work_size[1] % local_work_size[1] != 0 || - global_work_size[2] % local_work_size[2] != 0) { - local_work_size_ptr = NULL; - } - } - if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + // deallocate sub buffers and images + // <--------------------------------------------> // + CL_CHECK(clReleaseMemObject(A_image1d)); + CL_CHECK(clReleaseMemObject(D_image1d)); + CL_CHECK(clReleaseMemObject(A_sub_buffer)); + CL_CHECK(clReleaseMemObject(B_sub_buffer)); + CL_CHECK(clReleaseMemObject(D_sub_buffer)); } -static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mul_mat_q4_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - UNUSED(src1); - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; - cl_ulong offset0_abs = extra0->offset + src0->view_offs; - cl_ulong offsetd_abs = extrad->offset + dst->view_offs; - - cl_kernel kernel; - if (dst->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_expm1_f32_nd; - } else if (dst->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_expm1_f16_nd; - } else { - GGML_ASSERT(false && "Unsupported type for ggml_cl_expm1"); - } - GGML_ASSERT(kernel != nullptr); + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - const cl_ulong nb00 = src0->nb[0]; - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb03 = src0->nb[3]; + const int ne10 = src1->ne[0]; + const int ne12 = src1->ne[2]; - const int ne10 = dst->ne[0]; - const int ne11 = dst->ne[1]; - const int ne12 = dst->ne[2]; - const int ne13 = dst->ne[3]; + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; - const cl_ulong nb10 = dst->nb[0]; - const cl_ulong nb11 = dst->nb[1]; - const cl_ulong nb12 = dst->nb[2]; - const cl_ulong nb13 = dst->nb[3]; + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); + cl_context context = backend_ctx->context; + cl_kernel kernel; - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); + int M = ne01; + int N = ne1; + int K = ne00; - size_t global_work_size[3]; - if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements - return; - } - global_work_size[0] = (size_t)ne10; - global_work_size[1] = (size_t)ne11; - global_work_size[2] = (size_t)ne12; + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q4_0->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32; + if (M == 4096 && K == 4096) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_4096; + } else if (M == 4096 && K == 11008) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_4096_1_11008; + } else if (M == 11008 && K == 4096) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_11008_1_4096; + } else if (M == 32000 && K == 4096) { + kernel = backend_ctx->kernel_gemv_noshuffle_q4_0_f32_32000_1_4096; + } - size_t lws0 = 16, lws1 = 4, lws2 = 1; - if (ne10 < 16) lws0 = ne10; - if (ne11 < 4) lws1 = ne11; - if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; + int r2 = 1; + int r3 = 1; - while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; - size_t local_work_size[] = {lws0, lws1, lws2}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - size_t* local_work_size_ptr = local_work_size; - if (!backend_ctx->non_uniform_workgroups) { - if (global_work_size[0] % local_work_size[0] != 0 || - global_work_size[1] % local_work_size[1] != 0 || - global_work_size[2] % local_work_size[2] != 0) { - local_work_size_ptr = NULL; + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; } - } - if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for output + region.origin = extrad->offset; // Specify the starting offset (in bytes) + region.size = M * N * sizeof(float); // Specify the size of the sub-buffer + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_work_size_t[0]=4; + local_work_size_t[1]=8; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_work_size_t[0]=2; + local_work_size_t[1]=8; + } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_work_size_t[0]=1; + local_work_size_t[1]=8; + } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_work_size_t[0]=2; + local_work_size_t[1]=8; + } + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q4_0_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &d_sub_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 1; + local_work_size[1] = 128; + } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) { + local_work_size[0] = 2; + local_work_size[1] = 64; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif } -static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - UNUSED(src1); - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; - cl_ulong offset0_abs = extra0->offset + src0->view_offs; - cl_ulong offsetd_abs = extrad->offset + dst->view_offs; - - cl_kernel kernel; - if (dst->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_softplus_f32_nd; - } else if (dst->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_softplus_f16_nd; - } else { - GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus"); - } - GGML_ASSERT(kernel != nullptr); - - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - - const cl_ulong nb00 = src0->nb[0]; - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb03 = src0->nb[3]; - - const int ne10 = dst->ne[0]; - const int ne11 = dst->ne[1]; - const int ne12 = dst->ne[2]; - const int ne13 = dst->ne[3]; - - const cl_ulong nb10 = dst->nb[0]; - const cl_ulong nb11 = dst->nb[1]; - const cl_ulong nb12 = dst->nb[2]; - const cl_ulong nb13 = dst->nb[3]; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); - - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); - - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); - - size_t global_work_size[3]; - if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements - return; - } - global_work_size[0] = (size_t)ne10; - global_work_size[1] = (size_t)ne11; - global_work_size[2] = (size_t)ne12; - - size_t lws0 = 16, lws1 = 4, lws2 = 1; - if (ne10 < 16) lws0 = ne10; - if (ne11 < 4) lws1 = ne11; - if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; - - while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; - size_t local_work_size[] = {lws0, lws1, lws2}; + const int ne1 = dst->ne[1]; - size_t* local_work_size_ptr = local_work_size; - if (!backend_ctx->non_uniform_workgroups) { - if (global_work_size[0] % local_work_size[0] != 0 || - global_work_size[1] % local_work_size[1] != 0 || - global_work_size[2] % local_work_size[2] != 0) { - local_work_size_ptr = NULL; - } - } - if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); -} + cl_context context = backend_ctx->context; + cl_kernel kernel; -static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) { - GGML_ASSERT(src0); - GGML_ASSERT(src0->extra); - GGML_ASSERT(dst); - GGML_ASSERT(dst->extra); - GGML_ASSERT(dst->type == src0->type); + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; - UNUSED(src1_shape_def); + int M = ne01; + int N = ne1; + int K = ne00; - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q4_1->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q4_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); - if (backend_ctx->kernel_repeat == nullptr) { - GGML_LOG_WARN("%s: repeat kernel not available, skipping OpenCL execution.\n", __func__); - return; - } + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; - ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - cl_ulong off_src0 = extra_src0->offset + src0->view_offs; - cl_ulong off_dst = extra_dst->offset + dst->view_offs; + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; - const int src0_ne0 = src0->ne[0]; const int src0_ne1 = src0->ne[1]; const int src0_ne2 = src0->ne[2]; const int src0_ne3 = src0->ne[3]; - const cl_ulong src0_nb0 = src0->nb[0]; const cl_ulong src0_nb1 = src0->nb[1]; const cl_ulong src0_nb2 = src0->nb[2]; const cl_ulong src0_nb3 = src0->nb[3]; + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } - const int dst_ne0 = dst->ne[0]; const int dst_ne1 = dst->ne[1]; const int dst_ne2 = dst->ne[2]; const int dst_ne3 = dst->ne[3]; - const cl_ulong dst_nb0 = dst->nb[0]; const cl_ulong dst_nb1 = dst->nb[1]; const cl_ulong dst_nb2 = dst->nb[2]; const cl_ulong dst_nb3 = dst->nb[3]; + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q4_1_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne1)); - cl_kernel kernel = backend_ctx->kernel_repeat; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &src0_ne0)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &src0_ne1)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &src0_ne2)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &src0_ne3)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &src0_nb0)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &src0_nb1)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &src0_nb2)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &src0_nb3)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &dst_ne0)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &dst_ne1)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &dst_ne2)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dst_ne3)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &dst_nb0)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &dst_nb1)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &dst_nb2)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &dst_nb3)); - - size_t gws0 = dst_ne1 > 0 ? (size_t)dst_ne1 : 1; - size_t gws1 = dst_ne2 > 0 ? (size_t)dst_ne2 : 1; - size_t gws2 = dst_ne3 > 0 ? (size_t)dst_ne3 : 1; - - size_t global_work_size[] = { gws0, gws1, gws2 }; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif } -static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_cl_mul_mat_q5_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - if (backend_ctx->kernel_pad == nullptr) { - GGML_LOG_WARN("%s: pad kernel not available, skipping OpenCL execution.\n", __func__); - return; - } + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; - ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - cl_ulong off_src0 = extra_src0->offset + src0->view_offs; - cl_ulong off_dst = extra_dst->offset + dst->view_offs; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; - const int s_ne0 = src0->ne[0]; - const int s_ne1 = src0->ne[1]; - const int s_ne2 = src0->ne[2]; - const int s_ne3 = src0->ne[3]; + const int ne1 = dst->ne[1]; - const int s_nb0 = src0->nb[0]; - const int s_nb1 = src0->nb[1]; - const int s_nb2 = src0->nb[2]; - const int s_nb3 = src0->nb[3]; + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - const int d_ne0 = dst->ne[0]; - const int d_ne1 = dst->ne[1]; - const int d_ne2 = dst->ne[2]; - const int d_ne3 = dst->ne[3]; + cl_context context = backend_ctx->context; + cl_kernel kernel; - const int d_nb0 = dst->nb[0]; - const int d_nb1 = dst->nb[1]; - const int d_nb2 = dst->nb[2]; - const int d_nb3 = dst->nb[3]; + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; - const int lp0 = ((const int*)(dst->op_params))[0]; - const int rp0 = ((const int*)(dst->op_params))[1]; - const int lp1 = ((const int*)(dst->op_params))[2]; - const int rp1 = ((const int*)(dst->op_params))[3]; - const int lp2 = ((const int*)(dst->op_params))[4]; - const int rp2 = ((const int*)(dst->op_params))[5]; - const int lp3 = ((const int*)(dst->op_params))[6]; - const int rp3 = ((const int*)(dst->op_params))[7]; + int M = ne01; + int N = ne1; + int K = ne00; - cl_kernel kernel = backend_ctx->kernel_pad; + if (ne1 == 1) { + cl_mem qs_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for qs + img_fmt = { CL_R, CL_UNSIGNED_INT32 }; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q5_0->qs; + CL_CHECK((qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q5_0_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &qs_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1)); - CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2)); - CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2)); - CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3)); - CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3)); + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; - size_t lws0 = 64; - size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 }; - size_t local_work_size[] = { lws0, 1, 1 }; + CL_CHECK(clReleaseMemObject(qs_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } - size_t * local_work_size_ptr = local_work_size; - if (d_ne0 % lws0 != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; - } + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for output + region.origin = extrad->offset; + region.size = M * N * sizeof(float); + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q5_0_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &d_sub_buf)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif } -static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_cl_mul_mat_q5_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - const int mode_flags = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0); - const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF); - cl_kernel kernel = nullptr; - - if (mode == GGML_SCALE_MODE_NEAREST) { - kernel = backend_ctx->kernel_upscale; - if (kernel == nullptr) { - GGML_LOG_WARN("%s: nearest upscale kernel not available, skipping OpenCL execution.\n", __func__); - return; - } - } else if (mode == GGML_SCALE_MODE_BILINEAR) { - kernel = backend_ctx->kernel_upscale_bilinear; - if (kernel == nullptr) { - GGML_LOG_WARN("%s: bilinear upscale kernel not available, skipping OpenCL execution.\n", __func__); - return; - } - } else { - GGML_LOG_WARN("%s: unsupported upscale mode %d, skipping OpenCL execution.\n", __func__, mode); - return; - } - - ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; - - cl_ulong off_src0 = extra_src0->offset + src0->view_offs; - cl_ulong off_dst = extra_dst->offset + dst->view_offs; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; - const cl_ulong nb00 = src0->nb[0]; - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb03 = src0->nb[3]; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; - const int ne0 = dst->ne[0]; const int ne1 = dst->ne[1]; - const int ne2 = dst->ne[2]; - const int ne3 = dst->ne[3]; - float sf0 = (float)ne0 / ne00; - float sf1 = (float)ne1 / ne01; - float sf2 = (float)ne2 / ne02; - float sf3 = (float)ne3 / ne03; + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - float pixel_offset = 0.5f; + cl_context context = backend_ctx->context; + cl_kernel kernel; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03)); + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; - if (mode == GGML_SCALE_MODE_NEAREST) { - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne2)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne3)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3)); - } else if (mode == GGML_SCALE_MODE_BILINEAR) { - if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { - sf0 = ne0 > 1 && ne00 > 1 ? (float)(ne0 - 1) / (ne00 - 1) : sf0; - sf1 = ne1 > 1 && ne01 > 1 ? (float)(ne1 - 1) / (ne01 - 1) : sf1; - pixel_offset = 0.0f; - } + int M = ne01; + int N = ne1; + int K = ne00; - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne2)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne3)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &pixel_offset)); - } + if (ne1 == 1) { + cl_mem qs_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for qs + img_fmt = { CL_R, CL_UNSIGNED_INT32 }; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q5_1->qs; + CL_CHECK((qs_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q5_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &qs_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - size_t dst_total_elements = (size_t)ne0 * ne1 * ne2 * ne3; - if (dst_total_elements == 0) { - return; - } - size_t global_work_size[] = { dst_total_elements, 1, 1 }; - size_t local_work_size_pref = 256; - size_t local_work_size[] = { MIN(local_work_size_pref, dst_total_elements), 1, 1}; + CL_CHECK(clReleaseMemObject(qs_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + cl_mem d_sub_buf = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } - size_t * local_work_size_ptr = local_work_size; - if (dst_total_elements % local_work_size[0] != 0 && !backend_ctx->non_uniform_workgroups) { - local_work_size_ptr = nullptr; - } + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for output + region.origin = extrad->offset; + region.size = M * N * sizeof(float); + CL_CHECK((d_sub_buf = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q5_1_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &d_sub_buf)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(d_sub_buf)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif } -static void ggml_cl_concat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mul_mat_iq4_nl_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); GGML_ASSERT(src1); GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - cl_command_queue queue = backend_ctx->queue; - if (backend_ctx->kernel_concat_f32_contiguous == nullptr || backend_ctx->kernel_concat_f32_non_contiguous == nullptr) { - GGML_LOG_WARN("%s: concat kernels not available, skipping OpenCL execution.\n", __func__); - return; - } + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - ggml_tensor_extra_cl * extra0_cl = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra1_cl = (ggml_tensor_extra_cl *)src1->extra; - ggml_tensor_extra_cl * extrad_cl = (ggml_tensor_extra_cl *)dst->extra; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; - cl_ulong off_src0 = extra0_cl->offset + src0->view_offs; - cl_ulong off_src1 = extra1_cl->offset + src1->view_offs; - cl_ulong off_dst = extrad_cl->offset + dst->view_offs; + const int ne1 = dst->ne[1]; - const int32_t dim = ((const int32_t *) dst->op_params)[0]; - GGML_ASSERT(dim >= 0 && dim <= 3); + GGML_ASSERT(ne00 % 32 == 0); - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { - if (dim == 3) { + cl_context context = backend_ctx->context; + cl_kernel kernel; - size_t nbytes_src0 = ggml_nbytes(src0); - size_t nbytes_src1 = ggml_nbytes(src1); + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; - CL_CHECK(clEnqueueCopyBuffer(queue, extra0_cl->data_device, extrad_cl->data_device, - off_src0, off_dst, nbytes_src0, 0, NULL, NULL)); - CL_CHECK(clEnqueueCopyBuffer(queue, extra1_cl->data_device, extrad_cl->data_device, - off_src1, off_dst + nbytes_src0, nbytes_src1, 0, NULL, NULL)); - } else { + int M = ne01; + int N = ne1; + int K = ne00; - cl_kernel kernel = backend_ctx->kernel_concat_f32_contiguous; - size_t global_work_size[3]; - - for (int i3 = 0; i3 < dst->ne[3]; ++i3) { - cl_ulong current_off_src0 = off_src0 + (i3 * src0->nb[3]); - cl_ulong current_off_src1 = off_src1 + (i3 * src1->nb[3]); - cl_ulong current_off_dst = off_dst + (i3 * dst->nb[3]); - - int d_ne00 = src0->ne[0]; int d_ne01 = src0->ne[1]; int d_ne02 = src0->ne[2]; - int d_ne10 = src1->ne[0]; int d_ne11 = src1->ne[1]; int d_ne12 = src1->ne[2]; - int d_ne0 = dst->ne[0]; int d_ne1 = dst->ne[1]; int d_ne2 = dst->ne[2]; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), ¤t_off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), ¤t_off_src1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), ¤t_off_dst)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &d_ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne10)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &d_ne11)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &d_ne12)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &dim)); - - global_work_size[0] = d_ne0; - global_work_size[1] = d_ne1; - global_work_size[2] = d_ne2; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); - } - } - } else { - cl_kernel kernel = backend_ctx->kernel_concat_f32_non_contiguous; + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_iq4_nl->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); - cl_long ne00 = src0->ne[0], ne01 = src0->ne[1], ne02 = src0->ne[2], ne03 = src0->ne[3]; - cl_ulong nb00 = src0->nb[0], nb01 = src0->nb[1], nb02 = src0->nb[2], nb03 = src0->nb[3]; + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - cl_ulong nb10 = src1->nb[0], nb11 = src1->nb[1], nb12 = src1->nb[2], nb13 = src1->nb[3]; + kernel = backend_ctx->kernel_gemv_noshuffle_iq4_nl_f32; - cl_long d_ne0 = dst->ne[0], d_ne1 = dst->ne[1], d_ne2 = dst->ne[2], d_ne3 = dst->ne[3]; - cl_ulong d_nb0 = dst->nb[0], d_nb1 = dst->nb[1], d_nb2 = dst->nb[2], d_nb3 = dst->nb[3]; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_src1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad_cl->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &off_dst)); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_long), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_long), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_long), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_long), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_long), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_long), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_long), &d_ne2)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_long), &d_ne3)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &d_nb0)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &d_nb1)); - CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &d_nb2)); - CL_CHECK(clSetKernelArg(kernel, 25, sizeof(cl_ulong), &d_nb3)); - CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &dim)); + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_iq4_nl_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; - size_t global_work_size_nc[] = { d_ne1 > 0 ? (size_t)d_ne1 : 1, - d_ne2 > 0 ? (size_t)d_ne2 : 1, - d_ne3 > 0 ? (size_t)d_ne3 : 1 }; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_nc, NULL, dst); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif } -static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor * src0, ggml_tensor * dst) { +static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - - if (backend_ctx->kernel_timestep_embedding == nullptr) { - GGML_LOG_WARN("%s: timestep_embedding kernel not available, skipping OpenCL execution.\n", __func__); - return; - } - - ggml_tensor_extra_cl * extra_src0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *)dst->extra; - - cl_ulong off_src0 = extra_src0->offset + src0->view_offs; - cl_ulong off_dst = extra_dst->offset + dst->view_offs; - - const int logical_dim = dst->op_params[0]; - const int max_period = dst->op_params[1]; - const int dst_nb1_bytes = dst->nb[1]; - - cl_kernel kernel = backend_ctx->kernel_timestep_embedding; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &dst_nb1_bytes)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &logical_dim)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &max_period)); - - size_t gws0 = (size_t)(((logical_dim + 1) / 2) + 1); - size_t gws1 = (size_t)src0->ne[0]; - - size_t global_work_size[] = {gws0, gws1, 1}; - - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst); -} - -static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) { - const ggml_tensor * v = dst->src[2]; - const ggml_tensor * mask = dst->src[3]; - const ggml_tensor * sinks = dst->src[4]; - GGML_ASSERT(q->extra); - GGML_ASSERT(k->extra); - GGML_ASSERT(v->extra); - GGML_ASSERT(dst->extra); - if (mask) { - GGML_ASSERT(mask->extra); - } - if (sinks) { - GGML_ASSERT(sinks->extra); - } + GGML_ASSERT(src0->type == GGML_TYPE_Q8_0); + GGML_ASSERT(src1->type == GGML_TYPE_F32); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - const int n_q = q->ne[1]; - const int n_kv = k->ne[1]; - const int d_head_q = q->ne[0]; - const int d_head_v = v->ne[0]; - const int n_head = q->ne[2]; - const int n_head_kv = k->ne[2]; - const int n_batch = q->ne[3]; - - cl_kernel kernel = NULL; - - const bool is_f16 = q->type == GGML_TYPE_F16; - const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16; - const std::pair<int, int> dk_dv = {d_head_q, d_head_v}; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; - if (n_q == 1) { - if (is_mixed) { - kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv); - } else if (is_f16) { - kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv); - } else { - kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv); - } - } else { - if (is_mixed) { - kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv); - } else if (is_f16) { - kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv); - } else { - kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv); - } - } - GGML_ASSERT(kernel != NULL); + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra; - ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra; - ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra; - ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra; - ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL; - ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL; + GGML_ASSERT(src1->view_offs == 0); + GGML_ASSERT(dst->view_offs == 0); - cl_ulong offset_q = extra_q->offset + q->view_offs; - cl_ulong offset_k = extra_k->offset + k->view_offs; - cl_ulong offset_v = extra_v->offset + v->view_offs; - cl_ulong offset_o = extra_o->offset + dst->view_offs; - cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL; - cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0; - cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL; - cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; - const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3]; - const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3]; - const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3]; - const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3]; - const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0; - const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0; - const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0; - const int mask_ne2 = mask ? mask->ne[2] : 0; - const int mask_ne3 = mask ? mask->ne[3] : 0; + const int ne10 = src1->ne[0]; + const int ne12 = src1->ne[2]; - float scale, max_bias, logit_softcap; - const float * params = (const float *)dst->op_params; - scale = params[0]; - max_bias = params[1]; - logit_softcap = params[2]; + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; - const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv); + GGML_ASSERT(ne00 == ne10); + GGML_ASSERT((ne00 % 32) == 0); + GGML_ASSERT(ne0 == ne01); - const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0; - const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f; - const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f); + cl_context context = backend_ctx->context; + cl_kernel kernel; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3)); - CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias)); - CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0)); - CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1)); - CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val)); - CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap)); - CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv)); - CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer)); - CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask)); - CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1)); - CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2)); - CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3)); - CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2)); - CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3)); - CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer)); - CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks)); + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; - if (n_q == 1) { - const size_t wg_size = 64; - size_t local_work_size[] = { wg_size, 1 }; - size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) }; - backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); - } else { - const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv); - const size_t wg_size = block_m; - size_t local_work_size[] = { wg_size, 1 }; - size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) }; - backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); - } -} + int M = ne01; + int N = ne1; + int K = ne00; -static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; - ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; - ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 4; + img_desc.buffer = extra0_q8_0->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offset1 = extra1->offset + src1->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; + // create a sub_buffer for B + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); - const int M = src0->ne[1]; - const int N = src1->ne[1]; - const int K = src0->ne[0]; + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); - cl_kernel kernel = backend_ctx->kernel_mul_mat_f16_f32_tiled; + kernel = backend_ctx->kernel_gemv_noshuffle_q8_0_f32; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(int), &M)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &N)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &K)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); + int r2 = 1; + int r3 = 1; - // Tiling parameters. These need to be tuned for optimal performance. - // They must match the #defines in the kernel mul_mat_f16_f32.cl. - // - // OPWM / OPWN: Output tile size per Work-Group. A work-group computes a tile of size OPWM x OPWN. - // TPWM / TPWN: Threads per Work-group. This is the work-group size. - // OPTM / OPTN: Output elements per Thread. Each thread computes OPTM x OPTN elements. - // - // The following relationships must hold: - // OPWM = TPWM * OPTM - // OPWN = TPWN * OPTN - // - const int OPWM = 64; - const int OPWN = 64; - const int TPWM = 16; - const int TPWN = 8; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &extra1->offset)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &extrad->offset)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); - size_t local_work_size[2] = { TPWM, TPWN }; - size_t global_work_size[2] = { - (size_t) ((M + OPWM - 1) / OPWM) * TPWM, - (size_t) ((N + OPWN - 1) / OPWN) * TPWN, - }; + size_t wavesize = backend_ctx->adreno_wave_size; + size_t local_work_size[] = { wavesize, 4, 1 }; + size_t global_work_size[] = { CEIL_DIV(M, wavesize)*wavesize, 4, 1 }; - backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q8_0_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &K)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &M)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = { (size_t)CEIL_DIV(N, 8), (size_t)CEIL_DIV(M, 4), 1 }; + size_t local_work_size[] = { 2, 128, 1 }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_img_trans)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif } -static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_TENSOR_BINARY_OP_LOCALS; +static void ggml_cl_mul_mat_q4_k_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_k = (ggml_tensor_extra_cl_q4_K *)src0->extra; - cl_ulong offset0 = extra0->offset + src0->view_offs; cl_ulong offset1 = extra1->offset + src1->view_offs; cl_ulong offsetd = extrad->offset + dst->view_offs; - const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13; - const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1; - - const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1]; - const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3]; - const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5]; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; - const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type); - const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type); - const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type); + const int ne1 = dst->ne[1]; - const int64_t NPQ = (int64_t)N * OW * OH; + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - const uint32_t BS_K = 64; - const uint32_t BS_NPQ = 64; - const uint32_t BS_CRS = 16; - const uint32_t VEC_SIZE = 4; + cl_context context = backend_ctx->context; + cl_kernel kernel; - const uint32_t TS_K = 4; - const uint32_t TS_NPQ = 8; + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; - const uint32_t WG_K = BS_K / TS_K; - const uint32_t WG_NPQ = BS_NPQ / TS_NPQ; + int M = ne01; + int N = ne1; + int K = ne00; - auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; }; - const uint32_t NB_K = splitWork(Cout, BS_K); - const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ); + cl_uchar mask_d6 = 0x3F; + cl_uchar mask_d4 = 0x0F; + cl_uchar mask_hi2 = 0xC0; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q4_k->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q4_k_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->s)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_hi2)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; - cl_kernel kernel; - size_t shmem_size; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_conv_2d_f16; - shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4)); - } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_conv_2d_f32; - shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); - } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_conv_2d_f16_f32; - shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); } else { - GGML_ASSERT(false && "Unsupported data type combination for conv2d"); - } - cl_uint idx = 0; - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0)); - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL)); - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N)); - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H)); - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH)); - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1)); - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1)); - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03)); - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13)); - CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3)); + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; - size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 }; - size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 }; + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } - backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q4_k_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_k->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_k->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_k->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_uchar), &mask_hi2)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif } -static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +static void ggml_cl_mul_mat_q6_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - const cl_ulong nb10 = src1->nb[0]; + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; + const int ne1 = dst->ne[1]; - GGML_ASSERT(ne00 == ne10); + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - cl_kernel kernel; cl_context context = backend_ctx->context; + cl_kernel kernel; - cl_int status; - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; - cl_buffer_region region; - cl_mem A_image1d; - cl_mem A_sub_buffer; - cl_mem B_sub_buffer; - cl_mem D_image1d; - cl_mem D_sub_buffer; + cl_int err; + cl_buffer_region region; + cl_image_format img_fmt; + cl_image_desc img_desc; + + // subbuffer and image for activation + if (ne1 == 1) { + cl_mem ql_img = nullptr; + cl_mem qh_img = nullptr; + cl_mem b_sub_buffer = nullptr; + cl_mem b_img = nullptr; + + // image for ql + img_fmt.image_channel_order = CL_R; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne01 * ne00 / 8; + img_desc.buffer = extra0_q6_K->ql; + CL_CHECK((ql_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // image for qh + img_fmt.image_channel_order = CL_R; + img_fmt.image_channel_data_type = CL_HALF_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne01 * ne00 / 8; + img_desc.buffer = extra0_q6_K->qh; + CL_CHECK((qh_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + region.origin = offset1; + region.size = ne00 * ne1 * sizeof(float); + CL_CHECK((b_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * ne1 / 4; + img_desc.buffer = b_sub_buffer; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q6_K_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &ql_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qh_img)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; - int M = ne01; - int N = ne1; - int K = ne00; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - if (nb01 > nb02) { - // KQ - kernel = backend_ctx->kernel_mul_mm_f16_f32_kq; + CL_CHECK(clReleaseMemObject(ql_img)); + CL_CHECK(clReleaseMemObject(qh_img)); + CL_CHECK(clReleaseMemObject(b_sub_buffer)); + CL_CHECK(clReleaseMemObject(b_img)); } else { - // KQV - kernel = backend_ctx->kernel_mul_mm_f16_f32_kqv; - } - // create sub-buffer for A - // <--------------------------------------------> // - extra0 = src0->view_src ? (ggml_tensor_extra_cl *)src0->view_src->extra : (ggml_tensor_extra_cl *)src0->extra; + cl_mem b_sub_buf; + cl_mem b_buf_trans; + cl_mem b_img; + cl_mem b_img_trans; + + // subbuffer for activation + region.origin = offset1; + region.size = ne00 * ne1 * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activation + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * ne1 / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = ne1 % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } - region.origin = (extra0->offset); - if (nb01 > nb02) { - // KQ - region.size = nb01 * ne01; - } else { - // KQV - region.size = nb02 * ne02; + // subbuffer for transposed activation + region.origin = 0; + region.size = ne00 * (ne1 + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activation + img_fmt.image_channel_order = CL_RGBA; + img_fmt.image_channel_data_type = CL_HALF_FLOAT; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = ne00 * (ne1 + padding) / 4; + img_desc.buffer = b_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activation + int height_B = ne1/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = ne00/4; + int padded_height_B = (ne1 + padding) / 4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_size_t[2] = { 1, 16 }; + size_t global_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q6_K_f32; + int padded_N = ne1 + padding; + + cl_ushort mask_f000 = 0xF000; + cl_uchar mask_c0 = 0xC0; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ushort),&mask_f000)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_c0)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {2, 128, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img_trans)); } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} - A_sub_buffer = clCreateSubBuffer((extra0->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); +static void ggml_cl_mul_mat_q5_K_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); - // <--------------------------------------------> // + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; - // create sub-buffer for B - // <--------------------------------------------> // - region.origin = (extra1->offset); - region.size = nb10 * ne10 * ne11 * ne12; - B_sub_buffer = clCreateSubBuffer((extra1->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - // <--------------------------------------------> // + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_k = (ggml_tensor_extra_cl_q5_K *)src0->extra; - img_fmt_1d = {CL_RGBA, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - if (nb01 > nb02) { - img_desc_1d.image_width = (nb01 * ne01 / 4)/4; - } - else { - img_desc_1d.image_width = (nb02 * ne02 / 4)/4; - } - img_desc_1d.buffer = A_sub_buffer; - A_image1d = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; - // create sub-buffer for output C - // <--------------------------------------------> // - region.origin = (extrad->offset); - region.size = ne0 * ne1 * dst->ne[2] * dst->nb[0]; // size of C in bytes - D_sub_buffer = clCreateSubBuffer((extrad->data_device), 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - // <--------------------------------------------> // + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne1 = dst->ne[1]; - // create image for C output - // <--------------------------------------------> // - img_fmt_1d = {CL_R, CL_FLOAT}; - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4; - img_desc_1d.buffer = D_sub_buffer; - D_image1d = clCreateImage(context, CL_MEM_WRITE_ONLY, &img_fmt_1d, &img_desc_1d, NULL, &status); - CL_CHECK(status); - // <--------------------------------------------> // + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); - int offset_src0 = 0; - int offset_src1 = 0; + cl_context context = backend_ctx->context; + cl_kernel kernel; - // set kernel args - // <--------------------------------------------> // - cl_uint k_arg = 0; - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src0)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_sub_buffer)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &offset_src1)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &D_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &extrad->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &M)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &K)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &N)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &nb01)); + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; - size_t global_work_size[3] = {64, static_cast<size_t>(((M+63)/64)), static_cast<size_t>(((N+31)/32)*ne12)}; - size_t local_work_size[3] = {64, 1, 2}; + cl_uchar mask_d6 = 0x3F; + cl_uchar mask_d4 = 0x0F; + cl_uchar mask_hi2 = 0xC0; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem qh_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q (CL_R, CL_UNSIGNED_INT32): width = M*K/2/4 + img_fmt = {CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q5_k->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // image for qh (CL_R, CL_HALF_FLOAT): width = M*K/16 + img_fmt = {CL_R, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 16; + img_desc.buffer = extra0_q5_k->qh; + CL_CHECK((qh_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations (CL_RGBA, CL_FLOAT): width = K*N/4 + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q5_k_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &qh_img)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_k->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_k->s)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_hi2)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - // deallocate sub buffers and images - // <--------------------------------------------> // - CL_CHECK(clReleaseMemObject(A_image1d)); - CL_CHECK(clReleaseMemObject(D_image1d)); - CL_CHECK(clReleaseMemObject(A_sub_buffer)); - CL_CHECK(clReleaseMemObject(B_sub_buffer)); - CL_CHECK(clReleaseMemObject(D_sub_buffer)); + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(qh_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0) { + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float) / 2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N / 4; + if (height_B == 0) height_B = 1; + int width_B = K / 4; + int padded_height_B = (N + padding) / 4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = {1, 16}; + size_t global_work_size_t[2] = {(size_t)width_B, (size_t)padded_height_B}; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q5_k_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_k->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_k->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_k->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_k->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_k->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_uchar), &mask_d6)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_uchar), &mask_d4)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_uchar), &mask_hi2)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif } static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -7582,8 +13760,9 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + // bf16 is stored as f16 on device + const enum ggml_type src0t = (src0->type == GGML_TYPE_BF16) ? GGML_TYPE_F16 : src0->type; + const enum ggml_type src1t = src1->type; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -7597,32 +13776,23 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; + ggml_tensor_extra_cl_iq4_nl * extra0_iq4_nl = (ggml_tensor_extra_cl_iq4_nl *)src0->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; #endif - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; - - const cl_ulong nb00 = src0 ? src0->nb[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; - - const int ne10 = src1 ? src1->ne[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const int ne12 = src1 ? src1->ne[2] : 0; - const int ne13 = src1 ? src1->ne[3] : 0; - - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb12 = src1 ? src1->nb[2] : 0; - const cl_ulong nb13 = src1 ? src1->nb[3] : 0; - - const int ne0 = dst ? dst->ne[0] : 0; - const int ne1 = dst ? dst->ne[1] : 0; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); int r2 = ne12/ne02; int r3 = ne13/ne03; @@ -7638,12 +13808,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co cl_kernel kernel; #ifdef GGML_OPENCL_USE_ADRENO_KERNELS - cl_context context = backend_ctx->context; - if(src0t == GGML_TYPE_F16 && src1t == GGML_TYPE_F32){ - if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0) { + if (ne01 >= 64 && ne1 >= 32 && ne00 >= 16 && (ne12 % ne02) == 0 && + // dst is wrapped with image1d_buffer, the size limit applies, also src0 + (ne0 * ne1 * dst->ne[2] * dst->nb[0] / 4 <= backend_ctx->image_max_buffer_size)) { // For KQ if (ggml_is_permuted(src0) && ggml_is_permuted(src1) && + ((nb01 * ne01 / 4)/4 <= backend_ctx->image_max_buffer_size) && nb00 <= nb02 && nb02 <= nb01 && nb01 <= nb03 && @@ -7654,7 +13825,8 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co return; } // For KQV - if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { + if (!ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && + ((nb02 * ne02 / 4)/4 <= backend_ctx->image_max_buffer_size)) { ggml_cl_mul_mat_kq_kqv_adreno(backend, src0, src1, dst); return; } @@ -7662,320 +13834,402 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } if (ne01 && ne1 && use_adreno_kernels(backend_ctx, src0)) { + // NOTE: Kernels using image1d_buffer_t (e.g., src0_q) would normally require + // a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that + // limit, so the check is omitted. - // init CL objects - // <--------------------------------------------> // - cl_int status; - cl_image_format img_fmt_1d; - cl_image_desc img_desc_1d; - cl_buffer_region region; - cl_mem A_image1d = nullptr; - cl_mem B_image1d = nullptr; - cl_mem B_sub_buffer = nullptr; - cl_mem C_d = nullptr; - // for B transpose - cl_mem B_d = nullptr; - cl_mem B_d_input_image = nullptr; - // <--------------------------------------------> // + // q4_0 x fp32 + if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q4_0_f32_adreno(backend, src0, src1, dst); + return; + } - // define matrix dimensions - // <--------------------------------------------> // - int M = ne01; - int N = ne1; - int K = ne00; - int padding; - // <--------------------------------------------> // + // q4_1 x fp32 + if (src0t == GGML_TYPE_Q4_1 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q4_1_f32_adreno(backend, src0, src1, dst); + return; + } - // q4_0 x fp32 - if(src0t == GGML_TYPE_Q4_0 && src1t == GGML_TYPE_F32) { - // TODO: remove duplicate definitions of image description + format -- move to top + // q5_0 x fp32 + if (src0t == GGML_TYPE_Q5_0 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_0_f32_adreno(backend, src0, src1, dst); + return; + } - // create an image for A - // <--------------------------------------------> // - if (N == 1) { - img_fmt_1d = { CL_R, CL_UNSIGNED_INT32}; - } else { - img_fmt_1d = { CL_R, CL_FLOAT}; - } - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.image_width = M * K / 2 / 4; // Divide by 4 for char -> float - img_desc_1d.buffer = extra0_q4_0->q; - A_image1d = clCreateImage( - context, - CL_MEM_READ_ONLY, - &img_fmt_1d, - &img_desc_1d, - NULL, - &status); - CL_CHECK(status); - // <--------------------------------------------> // + // q5_1 x fp32 + if (src0t == GGML_TYPE_Q5_1 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_1_f32_adreno(backend, src0, src1, dst); + return; + } + // iq4_nl x fp32 + if (src0t == GGML_TYPE_IQ4_NL && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_iq4_nl_f32_adreno(backend, src0, src1, dst); + return; + } - // create a sub_buffer for B - // <--------------------------------------------> // - region.origin = (extra1->offset); - region.size = K * N * sizeof(float); - B_sub_buffer = clCreateSubBuffer( - extra1->data_device, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &status); - CL_CHECK(status); - // <--------------------------------------------> // - - // transpose activation for Skyler's gemm - if (N != 1) { - //how many extra elements beyond multiple of 8 - int extra_elements = N % 8; - - //how much padding to add - padding = 0; - if (extra_elements > 0){ - padding = 8 - extra_elements; - } - - // Specify the starting offset (in bytes) - region.origin = 0; - // Specify the size of the sub-buffer (divide by 2 for FP16) - region.size = K * (N + padding) * sizeof(float)/2; - backend_ctx->prealloc_act_trans.allocate(context, region.size); - - B_d = clCreateSubBuffer( - backend_ctx->prealloc_act_trans.buffer, - 0, - CL_BUFFER_CREATE_TYPE_REGION, - ®ion, - &status); - CL_CHECK(status); - - cl_image_format image_format_B_d_input = { CL_RGBA, CL_FLOAT }; - cl_image_desc image_desc_B_d_input = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast<size_t>(K * N / 4), - 0, 0, 0, 0, 0, 0, 0, { B_sub_buffer } - }; - B_d_input_image = clCreateImage( - context, - 0, - &image_format_B_d_input, - &image_desc_B_d_input, - NULL, - &status); - CL_CHECK(status); - - cl_image_format image_format_B_d_output = { CL_RGBA, CL_HALF_FLOAT }; //(CL_HALF_FLOAT for FP16) - cl_image_desc image_desc_B_d_output = { - CL_MEM_OBJECT_IMAGE1D_BUFFER, - static_cast<size_t>(K * (N + padding)/4), - 0, 0, 0, 0, 0, 0, 0, { B_d } - }; - B_image1d = clCreateImage( - context, - 0, - &image_format_B_d_output, - &image_desc_B_d_output, - NULL, - &status); - CL_CHECK(status); - - int height_B = N/4; - if (height_B == 0) { - height_B = 1; - } - int width_B = K/4; - int padded_height_B = (N + padding)/4; - - kernel = backend_ctx->kernel_transpose_32_16; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &B_d_input_image)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); - - size_t local_size_t[2] = { 1, 16 }; - //WGS tuning - if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { - local_size_t[0]=4; - local_size_t[1]=8; - } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { - local_size_t[0]=2; - local_size_t[1]=8; - } else if(ne0 == 4096 && ne1 == 128 && ne10 == 11008) { - local_size_t[0]=1; - local_size_t[1]=8; - } else if(ne0 == 32000 && ne1 == 128 && ne10 == 4096) { - local_size_t[0]=2; - local_size_t[1]=8; - } - - size_t global_size_t[2] = { - static_cast<size_t>(width_B), - static_cast<size_t>(padded_height_B) - }; + // q8_0 x fp32 + if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && + enable_adreno_trans_weight(backend_ctx, src0)) { + ggml_cl_mul_mat_q8_0_f32_adreno(backend, src0, src1, dst); + return; + } - backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_size_t, local_size_t, dst); - } else { - // no need to transpose B in other cases - // create an image for B from sub_buffer - // <--------------------------------------------> // - img_fmt_1d = {CL_RGBA, CL_FLOAT}; - - memset(&img_desc_1d, 0, sizeof(img_desc_1d)); - img_desc_1d.image_width = K * N / 4; - img_desc_1d.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; - img_desc_1d.buffer = B_sub_buffer; - B_image1d = clCreateImage( - context, - CL_MEM_READ_ONLY, - &img_fmt_1d, - &img_desc_1d, - NULL, - &status); - CL_CHECK(status); - // <--------------------------------------------> // - } - - // choose gemm or gemv kernel - // <--------------------------------------------> // - if (N == 1) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_general; - if (M == 4096 && K == 4096) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; - } else if (M == 4096 && K == 11008) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_11008; - } else if (M == 11008 && K == 4096) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; - } else if (M == 32000 && K == 4096) { - kernel = backend_ctx->CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; - } - } else { - kernel = backend_ctx->CL_mul_mat_Ab_Bi_8x4; - } - // <--------------------------------------------> // - - // set kernel args - // <--------------------------------------------> // - cl_uint k_arg = 0; - - if (N == 1) { - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &A_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extra0_q4_0->d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &B_image1d)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extra1->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(cl_ulong), &extrad->offset)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r2)); - CL_CHECK(clSetKernelArg(kernel, k_arg++, sizeof(int), &r3)); - } else { - region.origin = extrad->offset; // Specify the starting offset (in bytes) - region.size = M * N * sizeof(float); // Specify the size of the sub-buffer - C_d = clCreateSubBuffer(extrad->data_device, CL_MEM_WRITE_ONLY, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - int padded_N = ne1 + padding; - - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); //A_q_dextra0_q4_0->q - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); //A_s_d - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &B_image1d)); //B_d - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &C_d)); //C_d - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); //M - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &padded_N)); //N with padding - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); //K - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne1)); //N without padding - } - // <--------------------------------------------> // - - // choose workgroup size - // <--------------------------------------------> // - size_t global_work_size[3] = { - 64, static_cast<size_t>((M+63)/64), static_cast<size_t>((N+31)/32)}; - size_t local_work_size[3] = {64, 2, 4}; - - global_work_size[0] = (size_t)(ceil((float)ne1/8)); - global_work_size[1] = (size_t)(ne01/4); - global_work_size[2] = (size_t)(1); - - local_work_size[0] = (size_t)(1); //4x32 for FP32 - local_work_size[1] = (size_t)(128); - local_work_size[2] = (size_t)(1); - - //WGS tuning - if (ne0 == 4096 && ne1 == 128 && ne10 == 4096) { - local_work_size[0] = 1; - local_work_size[1] = 128; - } else if (ne0 == 11008 && ne1 == 128 && ne10 == 4096) { - local_work_size[0] = 2; - local_work_size[1] = 64; - } else if (ne0 == 4096 && ne1 == 128 && ne10 == 11008) { - local_work_size[0] = 2; - local_work_size[1] = 64; - } else if (ne0 == 32000 && ne1 == 128 && ne10 == 4096) { - local_work_size[0] = 2; - local_work_size[1] = 64; + // q4_k x fp32 + if (src0t == GGML_TYPE_Q4_K && src1t == GGML_TYPE_F32 && !use_flat_gemv_for_large_m_q4_K(src0)) { + ggml_cl_mul_mat_q4_k_f32_adreno(backend, src0, src1, dst); + return; + } + + // q6_K x fp32 + if (src0t == GGML_TYPE_Q6_K && src1t == GGML_TYPE_F32 && !use_flat_gemv_for_large_m_q6_K(src0)) { + ggml_cl_mul_mat_q6_K_f32_adreno(backend, src0, src1, dst); + return; } - if (N == 1) { - size_t wavesize = backend_ctx->adreno_wave_size; - local_work_size[0] = wavesize; // localsize - local_work_size[1] = 4; // reduce factor - local_work_size[2] = 1; + // q5_K x fp32 + if (src0t == GGML_TYPE_Q5_K && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q5_K_f32_adreno(backend, src0, src1, dst); + return; + } + } // if (ne01 && ne1) +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + + // GEMM using local memory + // Current BK = 16, so ne00 % 16 == 0 + if (src1t == GGML_TYPE_F32 && + ne00 % 16 == 0 && + ne11 > 1) { + switch(src0t) { + case GGML_TYPE_F32: { + kernel = backend_ctx->kernel_mul_mm_f32_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + cl_mem mem_src0 = extra0->data_device; + cl_mem mem_src1 = extra1->data_device; + + cl_ulong nb00_cont = nb00; + cl_ulong nb01_cont = nb01; + cl_ulong nb02_cont = nb02; + cl_ulong nb03_cont = nb03; + + cl_ulong nb10_cont = nb10; + cl_ulong nb11_cont = nb11; + cl_ulong nb12_cont = nb12; + cl_ulong nb13_cont = nb13; + + cl_ulong offset0_cont = offset0; + cl_ulong offset1_cont = offset1; + + if (!ggml_is_contiguous(src0)) { + backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0)); + ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer, + nb00_cont, nb01_cont, nb02_cont, nb03_cont); + mem_src0 = backend_ctx->prealloc_src0.buffer; + offset0_cont = 0; + } + + if (!ggml_is_contiguous(src1)) { + backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1)); + ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer, + nb10_cont, nb11_cont, nb12_cont, nb13_cont); + mem_src1 = backend_ctx->prealloc_src1.buffer; + offset1_cont = 0; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_F16: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (ggml_cl_can_use_adreno_xmem_gemm_f16_f32(backend_ctx, src0, src1, dst)) { + ggml_cl_mul_mat_f16_f32_adreno_xmem(backend, src0, src1, dst); + return; + } +#endif + kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + cl_mem mem_src0 = extra0->data_device; + cl_mem mem_src1 = extra1->data_device; + + cl_ulong nb00_cont = nb00; + cl_ulong nb01_cont = nb01; + cl_ulong nb02_cont = nb02; + cl_ulong nb03_cont = nb03; + + cl_ulong nb10_cont = nb10; + cl_ulong nb11_cont = nb11; + cl_ulong nb12_cont = nb12; + cl_ulong nb13_cont = nb13; + + cl_ulong offset0_cont = offset0; + cl_ulong offset1_cont = offset1; + + if (!ggml_is_contiguous(src0)) { + backend_ctx->prealloc_src0.allocate(backend_ctx->context, ggml_nbytes(src0)); + ggml_cl_copy_to_contiguous(backend, src0, backend_ctx->prealloc_src0.buffer, + nb00_cont, nb01_cont, nb02_cont, nb03_cont); + mem_src0 = backend_ctx->prealloc_src0.buffer; + offset0_cont = 0; + } + + if (!ggml_is_contiguous(src1)) { + backend_ctx->prealloc_src1.allocate(backend_ctx->context, ggml_nbytes(src1)); + ggml_cl_copy_to_contiguous(backend, src1, backend_ctx->prealloc_src1.buffer, + nb10_cont, nb11_cont, nb12_cont, nb13_cont); + mem_src1 = backend_ctx->prealloc_src1.buffer; + offset1_cont = 0; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &mem_src0)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_cont)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &mem_src1)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1_cont)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q4_0: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q4_1: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q5_0: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q5_1: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } - global_work_size[0] = (((M / 2) + wavesize - 1) / wavesize) * wavesize; - global_work_size[1] = 4; // reduce factor - global_work_size[2] = 1; - } - // <--------------------------------------------> // + kernel = backend_ctx->kernel_mul_mm_q5_1_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) - // enqueue kernel with profiling - // <--------------------------------------------> // - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - // <--------------------------------------------> // + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; - // deallocate sub buffers and images - // <--------------------------------------------> // - CL_CHECK(clReleaseMemObject(A_image1d)); - CL_CHECK(clReleaseMemObject(B_sub_buffer)); - CL_CHECK(clReleaseMemObject(B_image1d)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); - if (N != 1) { - CL_CHECK(clReleaseMemObject(B_d)); - CL_CHECK(clReleaseMemObject(B_d_input_image)); - CL_CHECK(clReleaseMemObject(C_d)); - } - // <--------------------------------------------> // + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; - return; - } - } // if (ne01 && ne1) -#endif // GGML_OPENCL_USE_ADRENO_KERNELS + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q8_0: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } - // GEMM using local memory - // Current BK = 16, so ne00 % 16 == 0 - if (ggml_is_contiguous(src0) && - ggml_is_contiguous(src1) && - src1t == GGML_TYPE_F32 && - ne00 % 16 == 0 && - ne11 > 1) { - switch(src0t) { - case GGML_TYPE_F32: { - kernel = backend_ctx->kernel_mul_mm_f32_f32_l4_lm; + kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm; nth0 = 128; // calculated as (BM*BN)/(TM*TN) int batch_stride_a = ne00*ne01; int batch_stride_b = ne10*ne11; int batch_stride_d = ne0*ne1; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); @@ -8001,16 +14255,23 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } - case GGML_TYPE_F16: { - kernel = backend_ctx->kernel_mul_mm_f16_f32_l4_lm; + case GGML_TYPE_IQ4_NL: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_iq4_nl_f32_l4_lm; nth0 = 128; // calculated as (BM*BN)/(TM*TN) int batch_stride_a = ne00*ne01; int batch_stride_b = ne10*ne11; int batch_stride_d = ne0*ne1; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); @@ -8036,36 +14297,131 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } - case GGML_TYPE_Q8_0: { + case GGML_TYPE_Q4_K: { if (ne11 < 32) { break; } - kernel = backend_ctx->kernel_mul_mm_q8_0_f32_l4_lm; + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_k_f32_l4_lm; nth0 = 128; // calculated as (BM*BN)/(TM*TN) int batch_stride_a = ne00*ne01; int batch_stride_b = ne10*ne11; int batch_stride_d = ne0*ne1; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q8_0->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q8_0->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q5_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q5_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q6_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; @@ -8213,20 +14569,167 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co GGML_ASSERT(false && "TODO: Unknown GPU"); } - if (src1t == GGML_TYPE_F32) { - if (ne11 * ne12 < 4) { - kernel = backend_ctx->kernel_mul_mat_f16_f32_1row; - } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { - kernel = backend_ctx->kernel_mul_mat_f16_f32_l4; - nrows = ne11; - } else { - kernel = backend_ctx->kernel_mul_mat_f16_f32; - nrows = 4; - } - } else { - kernel = backend_ctx->kernel_mul_mat_f16_f16; - nrows = 4; - } + if (src1t == GGML_TYPE_F32) { + if (ne11 * ne12 < 4) { + kernel = backend_ctx->kernel_mul_mat_f16_f32_1row; + } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { + kernel = backend_ctx->kernel_mul_mat_f16_f32_l4; + nrows = ne11; + } else { + kernel = backend_ctx->kernel_mul_mat_f16_f32; + nrows = 4; + } + } else { + kernel = backend_ctx->kernel_mul_mat_f16_f16; + nrows = 4; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3)); + break; + case GGML_TYPE_Q4_0: + // This should have been satisfied. + GGML_ASSERT(ne11 == ne1); + GGML_ASSERT(ne01 == ne0); + +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat; + ndst =8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#else // GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + // Use 1D local size. Each workgroup is a SIMD group. Each SIMD + // group produces N_DST (4 for Q4_0 kernel) values in the result. + // The number of workgroups on dim 0 (the leading dimension) is + // the nearest multiple of 4 that covers ne0 (equals ne01). + nth0 = 16; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + + kernel = backend_ctx->kernel_mul_mat_q4_0_f32_v; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + case GGML_TYPE_Q4_1: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q4_1_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q4_1_f32; CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -8237,46 +14740,64 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &r2)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r3)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q break; - case GGML_TYPE_Q4_0: - // This should have been satisfied. - GGML_ASSERT(ne11 == ne1); - GGML_ASSERT(ne01 == ne0); - + } + case GGML_TYPE_Q5_0: { #ifdef GGML_OPENCL_SOA_Q if (backend_ctx->gpu_family == INTEL) { nth0 = 16; nth1 = 1; - - kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat; - ndst = 8; + ndst = 4; } else if (backend_ctx->gpu_family == ADRENO) { nth0 = 64; nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } - kernel = backend_ctx->kernel_mul_mat_q4_0_f32_8x_flat; - ndst =8; + kernel = backend_ctx->kernel_mul_mv_q5_0_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; } else { GGML_ASSERT(false && "TODO: Unknown GPU"); } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + kernel = backend_ctx->kernel_mul_mv_q5_0_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); @@ -8290,27 +14811,57 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); -#else // GGML_OPENCL_SOA_Q +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_Q5_1: { +#ifdef GGML_OPENCL_SOA_Q if (backend_ctx->gpu_family == INTEL) { - // Use 1D local size. Each workgroup is a SIMD group. Each SIMD - // group produces N_DST (4 for Q4_0 kernel) values in the result. - // The number of workgroups on dim 0 (the leading dimension) is - // the nearest multiple of 4 that covers ne0 (equals ne01). nth0 = 16; nth1 = 1; - - kernel = backend_ctx->kernel_mul_mat_q4_0_f32; ndst = 4; } else if (backend_ctx->gpu_family == ADRENO) { nth0 = 64; nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } - kernel = backend_ctx->kernel_mul_mat_q4_0_f32_v; + kernel = backend_ctx->kernel_mul_mv_q5_1_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; ndst = 4; } else { GGML_ASSERT(false && "TODO: Unknown GPU"); } + kernel = backend_ctx->kernel_mul_mv_q5_1_f32; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); @@ -8328,7 +14879,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; - case GGML_TYPE_Q4_1: + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat; @@ -8368,29 +14919,244 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); #else - kernel = backend_ctx->kernel_mul_mv_q8_0_f32; + kernel = backend_ctx->kernel_mul_mv_q8_0_f32; + + // nth0 - subgroup size + // nth1 - number of subgroups per workgroup + // ndst - number of output values per workgroup = output per subgroup * number of subgroups + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 2; + ndst = nth1*4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = nth1*4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_IQ4_NL: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_iq4_nl_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 8; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 8; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_iq4_nl->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_iq4_nl->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#else + kernel = backend_ctx->kernel_mul_mv_iq4_nl_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q4_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 16; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); +#else + kernel = backend_ctx->kernel_mul_mv_q4_K_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } + case GGML_TYPE_Q5_K: { +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q5_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 16; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &r3)); +#else + kernel = backend_ctx->kernel_mul_mv_q5_K_f32; - // nth0 - subgroup size - // nth1 - number of subgroups per workgroup - // ndst - number of output values per workgroup = output per subgroup * number of subgroups if (backend_ctx->gpu_family == INTEL) { nth0 = 16; - nth1 = 2; - ndst = nth1*4; + nth1 = 1; + ndst = 4; } else if (backend_ctx->gpu_family == ADRENO) { nth0 = 64; - nth1 = 2; - ndst = nth1*4; + nth1 = 1; + ndst = 4; } else { GGML_ASSERT(false && "TODO: Unknown GPU"); } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); @@ -8407,19 +15173,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #endif // GGML_OPENCL_SOA_Q break; } - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: +#ifdef GGML_OPENCL_SOA_Q + kernel = backend_ctx->kernel_mul_mv_q6_K_f32_flat; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 2; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = 16; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r3)); +#else kernel = backend_ctx->kernel_mul_mv_q6_K_f32; if (backend_ctx->gpu_family == INTEL) { - nth0 = 2; - nth1 = 16; + nth0 = 16; + nth1 = 2; + ndst = 1; } else if (backend_ctx->gpu_family == ADRENO) { - nth0 = 2; - nth1 = 64; + nth0 = 64; + nth1 = 2; + ndst = 1; } else { GGML_ASSERT(false && "TODO: Unknown GPU"); } @@ -8439,6 +15236,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q break; case GGML_TYPE_MXFP4: { #ifdef GGML_OPENCL_SOA_Q @@ -8483,197 +15281,1042 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co kernel = backend_ctx->kernel_mul_mv_mxfp4_f32; if (backend_ctx->gpu_family == INTEL) { - nth0 = 16; - nth1 = 2; - ndst = nth1*2; + nth0 = 16; + nth1 = 2; + ndst = nth1*2; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 2; + ndst = nth1*2; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float)*nth0,nullptr)); +#endif + break; + } + default: + GGML_ASSERT(false && "not implemented"); + } + + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || + src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || + src0t == GGML_TYPE_Q5_1 || + src0t == GGML_TYPE_Q8_0 || + src0t == GGML_TYPE_IQ4_NL || + src0t == GGML_TYPE_Q2_K) { + // Each SIMD group produces N_DST values in the result. Assuming each + // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will + // produce N_DST*N_SIMDGROUP values in the result. Hence, the grid size + // (number of workgroups) will be a nearest multiple of + // N_DST*N_SIMDGROUP to cover the size of the dimension. Below, 4 is + // N_DST*N_SIMDGROUP (see the kernel for Q4_0 matmul). + size_t global_work_size[] = {(size_t)(ne01 + ndst-1)/ndst*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else if (src0t == GGML_TYPE_Q4_K) { + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else if (src0t == GGML_TYPE_Q3_K) { + GGML_ASSERT(false && "not implemented"); + } else if (src0t == GGML_TYPE_Q5_K) { + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else if (src0t == GGML_TYPE_Q6_K) { + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else { + int64_t ny = (ne11 + nrows - 1)/nrows; + + size_t global_work_size[] = {(size_t)ne01*nth0, (size_t)ny*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } +} + +static void moe_router_reoerder(ggml_backend_t backend, const ggml_tensor * src, int ne20) { + cl_int err; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra = (ggml_tensor_extra_cl *)src->extra; + cl_ulong offset = extra->offset + src->view_offs; + + const int ne21 = src->ne[1]; + const int nb21 = src->nb[1]; + const int ne02 = nb21 / src->nb[0]; + const int n_tile_size = 32; + const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + + cl_buffer_region region; + region.origin = offset; + region.size = nb21 * ne21; + cl_mem original_router_buf = clCreateSubBuffer(extra->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_post_router.allocate(backend_ctx->context, sizeof(int) * max_post_router_tile * n_tile_size); + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + cl_mem post_router_buf = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_emap.allocate(backend_ctx->context, sizeof(short) * max_post_router_tile); + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + cl_mem emap_buf = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_hist.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem hist_buf = clCreateSubBuffer(backend_ctx->prealloc_hist.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_tile_offset.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem tile_offset_buf = clCreateSubBuffer(backend_ctx->prealloc_tile_offset.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_slot_counter.allocate(backend_ctx->context, sizeof(int) * ne02); + region.origin = 0; + region.size = sizeof(int) * ne02; + cl_mem slot_counter_buf = clCreateSubBuffer(backend_ctx->prealloc_slot_counter.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + backend_ctx->prealloc_total_tiles.allocate(backend_ctx->context, sizeof(int)); + region.origin = 0; + region.size = sizeof(int); + cl_mem total_tiles_buf = clCreateSubBuffer(backend_ctx->prealloc_total_tiles.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + // Histogram + cl_kernel kernel = backend_ctx->kernel_moe_histogram; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &hist_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne02)); + + size_t histogram_global_size[] = {(size_t)(((ne21 + 63) / 64) * 64), static_cast<size_t>(ne20), 1}; + size_t histogram_local_size[] = {64, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src); + + // Scan + kernel = backend_ctx->kernel_moe_scan; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &hist_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &tile_offset_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &total_tiles_buf)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &slot_counter_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &n_tile_size)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02)); + + size_t scan_global_size[] = {1}; + size_t scan_local_size[] = {1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 1, scan_global_size, scan_local_size, src); + + // Fill + kernel = backend_ctx->kernel_moe_fill; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &post_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &total_tiles_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &n_tile_size)); + + size_t fill_global_size[] = {(size_t)(((max_post_router_tile + 63) / 64) * 64), n_tile_size, 1}; + size_t fill_local_size[] = {64, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, fill_global_size, fill_local_size, src); + + // Scatter + kernel = backend_ctx->kernel_moe_scatter; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &original_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &post_router_buf)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &emap_buf)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tile_offset_buf)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &slot_counter_buf)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, histogram_global_size, histogram_local_size, src); + + CL_CHECK(clReleaseMemObject(original_router_buf)); + CL_CHECK(clReleaseMemObject(hist_buf)); + CL_CHECK(clReleaseMemObject(tile_offset_buf)); + CL_CHECK(clReleaseMemObject(total_tiles_buf)); + CL_CHECK(clReleaseMemObject(slot_counter_buf)); + CL_CHECK(clReleaseMemObject(post_router_buf)); + CL_CHECK(clReleaseMemObject(emap_buf)); +} + +static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const ggml_tensor * src2 = dst->src[2]; + GGML_ASSERT(src2); + GGML_ASSERT(src2->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offset2 = extra2->offset + src2->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_UNUSED(offset0); + +#ifdef GGML_OPENCL_SOA_Q + ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; + ggml_tensor_extra_cl_q5_0 * extra0_q5_0 = (ggml_tensor_extra_cl_q5_0 *)src0->extra; + ggml_tensor_extra_cl_q5_1 * extra0_q5_1 = (ggml_tensor_extra_cl_q5_1 *)src0->extra; + ggml_tensor_extra_cl_q4_K * extra0_q4_K = (ggml_tensor_extra_cl_q4_K *)src0->extra; + ggml_tensor_extra_cl_q5_K * extra0_q5_K = (ggml_tensor_extra_cl_q5_K *)src0->extra; + ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; + ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; + ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; +#endif + + // TODO: general MoE for the following types + (void)extra0_q4_1; + (void)extra0_q5_0; + (void)extra0_q5_1; + (void)extra0_q4_K; + (void)extra0_q5_K; + (void)extra0_q6_K; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const int ne20 = src2->ne[0]; + const int ne21 = src2->ne[1]; + + const cl_ulong nb21 = src2->nb[1]; + const cl_ulong nb20 = src2->nb[0]; + + UNUSED(nb20); + + const int ne0 = dst->ne[0]; + const int ne1 = dst->ne[1]; + const int ne2 = dst->ne[2]; + + GGML_UNUSED(ne2); + + const int r2 = ne12/ne02; + const int r3 = ne13/ne03; + const int dst_rows = ne20*ne21; // ne20 = n_used_experts, ne21 = n_rows + + GGML_ASSERT(ne00 == ne10); + + int sgs = 32; // subgroup size + int nsg = 1; // number of subgroups + int nrows = 1; // number of row in src1 + int ndst = 4; // number of values produced by each subgroup + + const int n_tile_size = 32; + const int max_post_router_tile = (ne20 * ne21 / n_tile_size) + ne02; + + GGML_UNUSED(max_post_router_tile); + + cl_kernel kernel; + + // subgroup mat vec + switch (src0->type) { + case GGML_TYPE_Q4_0: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_0_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64); + global_size[1] = 4; + global_size[2] = static_cast<size_t>(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_0_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast<size_t>((ne01 + 63) / 64); + global_size[2] = static_cast<size_t>(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } // fallback to generic Q4_0 MoE kernel + +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + kernel = backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat; + + if (backend_ctx->gpu_family == INTEL) { + sgs = 16; + nsg = 1; + ndst = 8; } else if (backend_ctx->gpu_family == ADRENO) { - nth0 = 64; - nth1 = 2; - ndst = nth1*2; + sgs = 64; + nsg = 1; + ndst = 8; } else { GGML_ASSERT(false && "TODO: Unknown GPU"); } - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb13)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &r2)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r3)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float)*nth0,nullptr)); -#endif + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne20)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne21)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb21)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &r3)); + break; } - default: - GGML_ASSERT(false && "not implemented"); - } + case GGML_TYPE_Q4_1: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_MXFP4 || - src0t == GGML_TYPE_Q4_1 || - src0t == GGML_TYPE_Q8_0 || - src0t == GGML_TYPE_Q2_K) { - // Each SIMD group produces N_DST values in the result. Assuming each - // workgroup has N_SIMDGROUP SIMD groups, then each workgroup will - // produce N_DST*N_SIMDGROUP values in the result. Hence, the grid size - // (number of workgroups) will be a nearest multiple of - // N_DST*N_SIMDGROUP to cover the size of the dimension. Below, 4 is - // N_DST*N_SIMDGROUP (see the kernel for Q4_0 matmul). - size_t global_work_size[] = {(size_t)(ne01 + ndst-1)/ndst*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; - size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - } else if (src0t == GGML_TYPE_Q4_K) { - GGML_ASSERT(false && "not implemented"); - } else if (src0t == GGML_TYPE_Q3_K) { - GGML_ASSERT(false && "not implemented"); - } else if (src0t == GGML_TYPE_Q5_K) { - GGML_ASSERT(false && "not implemented"); - } else if (src0t == GGML_TYPE_Q6_K) { - size_t global_work_size[] = {(size_t)(ne01+1)/2*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; - size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_1_f32_ns; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - } else { - int64_t ny = (ne11 + nrows - 1)/nrows; + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; - size_t global_work_size[] = {(size_t)ne01*nth0, (size_t)ny*nth1, (size_t)ne12*ne13}; - size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); - } -} + // set thread grid + global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64); + global_size[1] = 4; + global_size[2] = static_cast<size_t>(ne20); + local_size[1] = 4; -static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(src0); - GGML_ASSERT(src0->extra); - GGML_ASSERT(src1); - GGML_ASSERT(src1->extra); - GGML_ASSERT(dst); - GGML_ASSERT(dst->extra); + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - const ggml_tensor * src2 = dst->src[2]; - GGML_ASSERT(src2); - GGML_ASSERT(src2->extra); + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); - ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); - ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; - ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; - ggml_tensor_extra_cl * extra2 = (ggml_tensor_extra_cl *)src2->extra; - ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); - cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offset1 = extra1->offset + src1->view_offs; - cl_ulong offset2 = extra2->offset + src2->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); - GGML_UNUSED(offset0); + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_1_f32_ns; -#ifdef GGML_OPENCL_SOA_Q - ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; - ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; - ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; -#endif + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } - const int ne00 = src0->ne[0]; - const int ne01 = src0->ne[1]; - const int ne02 = src0->ne[2]; - const int ne03 = src0->ne[3]; + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; - const cl_ulong nb00 = src0->nb[0]; - const cl_ulong nb01 = src0->nb[1]; - const cl_ulong nb02 = src0->nb[2]; - const cl_ulong nb03 = src0->nb[3]; + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - const int ne10 = src1->ne[0]; - const int ne11 = src1->ne[1]; - const int ne12 = src1->ne[2]; - const int ne13 = src1->ne[3]; + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - const cl_ulong nb11 = src1->nb[1]; - const cl_ulong nb12 = src1->nb[2]; - const cl_ulong nb13 = src1->nb[3]; + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - const int ne20 = src2->ne[0]; - const int ne21 = src2->ne[1]; + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); - const cl_ulong nb21 = src2->nb[1]; - const cl_ulong nb20 = src2->nb[0]; + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast<size_t>((ne01 + 63) / 64); + global_size[2] = static_cast<size_t>(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_0: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q5_0_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64); + global_size[1] = 4; + global_size[2] = static_cast<size_t>(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qs)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q5_0_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qs_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_0->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast<size_t>((ne01 + 63) / 64); + global_size[2] = static_cast<size_t>(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_1: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q5_1_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64); + global_size[1] = 4; + global_size[2] = static_cast<size_t>(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qs)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); - UNUSED(nb20); + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q5_1_f32_ns; - const int ne0 = dst->ne[0]; - const int ne1 = dst->ne[1]; + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } - const int r2 = ne12/ne02; - const int r3 = ne13/ne03; - const int dst_rows = ne20*ne21; // ne20 = n_used_experts, ne21 = n_rows + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; - GGML_ASSERT(ne00 == ne10); + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - int sgs = 32; // subgroup size - int nsg = 1; // number of subgroups - int nrows = 1; // number of row in src1 - int ndst = 4; // number of values produced by each subgroup + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - cl_kernel kernel; + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - // subgroup mat vec - switch (src0->type) { - case GGML_TYPE_Q4_0: { - kernel = backend_ctx->kernel_mul_mv_id_q4_0_f32_8x_flat; + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); - if (backend_ctx->gpu_family == INTEL) { - sgs = 16; - nsg = 1; - ndst = 8; - } else if (backend_ctx->gpu_family == ADRENO) { - sgs = 64; - nsg = 1; - ndst = 8; - } else { - GGML_ASSERT(false && "TODO: Unknown GPU"); - } + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra2->data_device)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset2)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb12)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &ne20)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &ne21)); - CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb21)); - CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &ne0)); - CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &ne1)); - CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &r2)); - CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &r3)); + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qs_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_1->m)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); - break; + // set thread grid + global_size[1] = static_cast<size_t>((ne01 + 63) / 64); + global_size[2] = static_cast<size_t>(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q @@ -8751,7 +16394,7 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, #endif // GGML_OPENCL_SOA_Q break; } - case GGML_TYPE_MXFP4: { + case GGML_TYPE_Q4_K: { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS if (use_adreno_moe_kernels(backend_ctx, src0)) { cl_int status; @@ -8759,11 +16402,183 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, size_t local_size[3] = {64, 2, 1}; size_t global_size[3] = {64, 2, 1}; - cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q4_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64); + global_size[1] = 4; + global_size[2] = static_cast<size_t>(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q4_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q4_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast<size_t>((ne01 + 63) / 64); + global_size[2] = static_cast<size_t>(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q5_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; - int tile_size = 320; if (ne12 == 1) { // for gemv - kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32; + kernel = backend_ctx->kernel_gemv_moe_q5_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; // create a sub_buffer for src2 cl_buffer_region region; @@ -8773,82 +16588,511 @@ static void ggml_cl_mul_mat_id(ggml_backend_t backend, const ggml_tensor * src0, CL_CHECK(status); // set thread grid - global_size[0] = static_cast<size_t>(ne01); + global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64); global_size[1] = 4; global_size[2] = static_cast<size_t>(ne20); local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + } else { // for gemm - kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32; - - // preprocess router table - int num_tiles_per_expert = (ne01 + tile_size - 1) / tile_size; - void * host_src2_reorder = malloc(ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short)); - void * host_src2 = malloc(ne21 * nb21); - CL_CHECK(clEnqueueReadBuffer(backend_ctx->queue, extra2->data_device, CL_TRUE, offset2, ne21 * nb21, host_src2, 0, NULL, NULL)); - int total_experts = nb21 / nb20; - int out_idx = 0; - for (int i_expert = 0; i_expert < ne02; i_expert++) { - for (int i_tile = 0; i_tile < num_tiles_per_expert; i_tile++) { - for (int j = 0; j < ne21; j++) { - for (int i = 0; i < ne20; i++) { - int expert = ((int *)host_src2)[j * total_experts + i]; - if (i_expert == expert) { - ((short *)host_src2_reorder)[out_idx] = static_cast<short>(expert); - ((short *)host_src2_reorder)[out_idx + 1] = static_cast<short>(j * ne11 + (i % ne11)); - ((short *)host_src2_reorder)[out_idx + 2] = static_cast<short>(j * ne20 + i); - ((short *)host_src2_reorder)[out_idx + 3] = static_cast<short>(i_tile); - out_idx += 4; - } - } - } - } + kernel = backend_ctx->kernel_gemm_moe_q5_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; } - buf_src2 = clCreateBuffer(backend_ctx->context, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, ne20 * ne21 * 4 * num_tiles_per_expert * sizeof(short), host_src2_reorder, &status); + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q5_K->dm)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); // set thread grid - global_size[0] = static_cast<size_t>(tile_size); - global_size[2] = static_cast<size_t>(ne20 * ne21 * num_tiles_per_expert); + global_size[1] = static_cast<size_t>((ne01 + 63) / 64); + global_size[2] = static_cast<size_t>(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_Q6_K: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_q6_k_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64); + global_size[1] = 4; + global_size[2] = static_cast<size_t>(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); - // create a sub_buffer for src1 - cl_buffer_region region; - region.origin = offset1; - region.size = ne10 * ne11 * ne12 * sizeof(float); - src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); - CL_CHECK(status); - - // create image for src1 - cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; - cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; - buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); - CL_CHECK(status); - - // Set kernel args - int arg_idx = 0; - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); - if (ne12 == 1) { + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); - } else { - CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &tile_size)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_q6_k_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->ql_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + + // set thread grid + global_size[1] = static_cast<size_t>((ne01 + 63) / 64); + global_size[2] = static_cast<size_t>(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); } + return; + } +#endif //GGML_OPENCL_USE_ADRENO_KERNELS + } + case GGML_TYPE_MXFP4: { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_moe_kernels(backend_ctx, src0)) { + cl_int status; + + size_t local_size[3] = {64, 2, 1}; + size_t global_size[3] = {64, 2, 1}; + + if (ne12 == 1) { // for gemv + kernel = backend_ctx->kernel_gemv_moe_mxfp4_f32_ns; + + cl_mem src1_sub_buffer, buf_src1_image, buf_src2; + + // create a sub_buffer for src2 + cl_buffer_region region; + region.origin = offset2; + region.size = ne20 * ne21 * sizeof(int); + buf_src2 = clCreateSubBuffer(extra2->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // set thread grid + global_size[0] = static_cast<size_t>(((ne01 + 63) / 64) * 64); + global_size[1] = 4; + global_size[2] = static_cast<size_t>(ne20); + local_size[1] = 4; + + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + src1_sub_buffer = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // create image for src1 + cl_image_format image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + cl_image_desc image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne10 * ne11 * ne12 / 4), 0,0,0,0,0,0,0, {src1_sub_buffer}}; + buf_src1_image = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src1_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne11)); + + // launch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + // deallocate sub buffers and images + CL_CHECK(clReleaseMemObject(src1_sub_buffer)); + CL_CHECK(clReleaseMemObject(buf_src1_image)); + CL_CHECK(clReleaseMemObject(buf_src2)); + + } else { // for gemm + kernel = backend_ctx->kernel_gemm_moe_mxfp4_f32_ns; + + // Reorder router if called from test-backend-ops or when new router is generated. + // Otherwise reuse the reordered result from previous mul_mat_id call. + if ((strstr(src0->name, "as") != NULL) || backend_ctx->toggle_reorder) { + moe_router_reoerder(backend, src2, ne20); + backend_ctx->toggle_reorder = false; + } + + cl_mem sub_buf_src1_pre, buf_src1_reordered, image_src1_reordered, sub_buf_dst, buf_dst_image; + cl_mem buf_src2, buf_src2_emap; + + cl_buffer_region region; + region.origin = 0; + region.size = sizeof(int) * max_post_router_tile * n_tile_size; + GGML_ASSERT(backend_ctx->prealloc_post_router.buffer); + buf_src2 = clCreateSubBuffer(backend_ctx->prealloc_post_router.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + region.origin = 0; + region.size = sizeof(short) * max_post_router_tile; + buf_src2_emap = clCreateSubBuffer(backend_ctx->prealloc_emap.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Reorder activations + // create a sub_buffer for src1 + region.origin = offset1; + region.size = ne10 * ne11 * ne12 * sizeof(float); + sub_buf_src1_pre = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &status); + CL_CHECK(status); + + // Create image for reordered src1 + // Use pre-allocated placeholder + region.origin = 0; + region.size = ne00 * max_post_router_tile * n_tile_size * sizeof(float); + backend_ctx->prealloc_act_trans.allocate(backend_ctx->context, region.size); + buf_src1_reordered = clCreateSubBuffer( + backend_ctx->prealloc_act_trans.buffer, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + cl_image_format image_format_buf_src1; + cl_image_desc image_desc_buf_src1; + image_format_buf_src1 = {CL_RGBA, CL_FLOAT}; + image_desc_buf_src1 = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne00 * max_post_router_tile * n_tile_size / 4), 0,0,0,0,0,0,0, {buf_src1_reordered}}; + image_src1_reordered = clCreateImage(backend_ctx->context, CL_MEM_READ_ONLY, &image_format_buf_src1, &image_desc_buf_src1, NULL, &status); + CL_CHECK(status); + + unsigned short map_ratio = ne20 / ne11; + GGML_ASSERT(((map_ratio == 1) || (map_ratio == ne20)) && "Map ratio not supported\n"); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 0, sizeof(cl_mem), &sub_buf_src1_pre)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 1, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 2, sizeof(cl_mem), &buf_src1_reordered)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 3, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 4, sizeof(unsigned int), &ne00)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 5, sizeof(unsigned short), &map_ratio)); + CL_CHECK(clSetKernelArg(backend_ctx->kernel_moe_reorder_b, 6, sizeof(unsigned int), &n_tile_size)); + + size_t reorder_b_local_size[3] = {256, 1, 1}; + size_t reorder_b_global_size[3] = {static_cast<size_t>(((ne00 / 4) + 255) / 256 * 256), static_cast<size_t>(max_post_router_tile * n_tile_size), 1}; + + // Dispatch reorder kernel + backend_ctx->enqueue_ndrange_kernel(backend_ctx->kernel_moe_reorder_b, 3, reorder_b_global_size, reorder_b_local_size, dst); + + // MoE kernel prepare + // Create sub buffer for dst + region.origin = offsetd; + region.size = ne0 * ne1 * ne2 * sizeof(float); + sub_buf_dst = clCreateSubBuffer( + extrad->data_device, + 0, + CL_BUFFER_CREATE_TYPE_REGION, + ®ion, + &status); + CL_CHECK(status); + // Create image for dst + cl_image_format image_format_buf_dst = {CL_R, CL_FLOAT}; + cl_image_desc image_desc_buf_dst = {CL_MEM_OBJECT_IMAGE1D_BUFFER, static_cast<size_t>(ne0 * ne1 * ne2), 0,0,0,0,0,0,0, {sub_buf_dst}}; + buf_dst_image = clCreateImage(backend_ctx->context, CL_MEM_WRITE_ONLY, &image_format_buf_dst, &image_desc_buf_dst, NULL, &status); + CL_CHECK(status); - // launch kernel - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + // Set kernel args + int arg_idx = 0; + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->q_img)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &extra0_mxfp4->e)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &image_src1_reordered)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_src2_emap)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &buf_dst_image)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(cl_mem), &(backend_ctx->prealloc_total_tiles.buffer))); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, arg_idx++, sizeof(int), &ne01)); - // deallocate sub buffers and images - CL_CHECK(clReleaseMemObject(src1_sub_buffer)); - CL_CHECK(clReleaseMemObject(buf_src1_image)); - CL_CHECK(clReleaseMemObject(buf_src2)); + // set thread grid + global_size[1] = static_cast<size_t>((ne01 + 63) / 64); + global_size[2] = static_cast<size_t>(max_post_router_tile); + local_size[1] = 1; + local_size[2] = 1; + + // Dispatch kernel + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_size, local_size, dst); + + clReleaseMemObject(sub_buf_src1_pre); + clReleaseMemObject(buf_src1_reordered); + clReleaseMemObject(image_src1_reordered); + clReleaseMemObject(buf_src2); + clReleaseMemObject(buf_src2_emap); + clReleaseMemObject(sub_buf_dst); + clReleaseMemObject(buf_dst_image); + } return; - } // else fallback to generic kernel + } // fallback to generic MoE mxfp4 kernel #endif // GGML_OPENCL_USE_ADRENO_KERNELS #ifdef GGML_OPENCL_SOA_Q @@ -8970,10 +17214,19 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - cl_ulong offset0 = extra0->offset + src0->view_offs; - cl_ulong offsetd = extrad->offset + dst->view_offs; + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel; + + int n = ggml_nelements(dst); - cl_kernel kernel = backend_ctx->kernel_scale; + if (n % 4 == 0) { + kernel = backend_ctx->kernel_scale_f32_4; + n /= 4; + } else { + kernel = backend_ctx->kernel_scale_f32; + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -8982,8 +17235,6 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias)); - int n = ggml_nelements(dst)/4; - size_t global_work_size[] = {(size_t)n, 1, 1}; size_t local_work_size[] = {64, 1, 1}; @@ -9005,28 +17256,13 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst. UNUSED(dst); - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; - - const cl_ulong nb00 = src0 ? src0->nb[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; - - const int ne10 = src1 ? src1->ne[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const int ne12 = src1 ? src1->ne[2] : 0; - const int ne13 = src1 ? src1->ne[3] : 0; - - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb12 = src1 ? src1->nb[2] : 0; - const cl_ulong nb13 = src1 ? src1->nb[3] : 0; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type src0t = src0->type; + const enum ggml_type src1t = src1->type; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -9045,7 +17281,8 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const kernel = backend_ctx->kernel_cpy_f32_f16; break; case GGML_TYPE_F32: - kernel = backend_ctx->kernel_cpy_f32_f32; + kernel = ne00 < 32 ? backend_ctx->kernel_cpy_f32_f32_pack + : backend_ctx->kernel_cpy_f32_f32; break; default: GGML_ASSERT(false && "not implemented"); @@ -9063,6 +17300,15 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const GGML_ASSERT(false && "not implemented"); } break; + case GGML_TYPE_I32: + switch (src1t) { + case GGML_TYPE_I32: + kernel = backend_ctx->kernel_cpy_i32_i32; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + break; default: GGML_ASSERT(false && "not implemented"); } @@ -9088,12 +17334,27 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); - const int nth = MIN(64, ne00); + if (kernel == backend_ctx->kernel_cpy_f32_f32_pack) { + const int maxwg = (int)backend_ctx->get_kernel_workgroup_size(kernel); + const int base = MIN(64, maxwg); + const int tpr = MIN(ne00, base); // threads per row + const int rpw = MAX(1, base / tpr); // rows per workgroup + const int lsz = tpr * rpw; // <= base <= maxwg + const int nrows = ne01*ne02*ne03; + const int nwg = (nrows + rpw - 1) / rpw; - size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; - size_t local_work_size[] = {(size_t)nth, 1, 1}; + size_t global_work_size[] = {(size_t)nwg*lsz, 1, 1}; + size_t local_work_size[] = {(size_t)lsz, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 1, global_work_size, local_work_size, src1); + } else { + const int nth = MIN(64, ne00); - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src1); + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, src1); + } } static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -9101,6 +17362,89 @@ static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const UNUSED(src1); } +static void ggml_cl_set(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32) && + src1->type == src0->type && dst->type == src0->type); + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const cl_ulong pnb1 = ((const int32_t *)dst->op_params)[0]; + const cl_ulong pnb2 = ((const int32_t *)dst->op_params)[1]; + const cl_ulong pnb3 = ((const int32_t *)dst->op_params)[2]; + const cl_ulong offs = ((const int32_t *)dst->op_params)[3]; + const bool inplace = (bool)((const int32_t *)dst->op_params)[4]; + + cl_kernel kernel = nullptr; + + // for inplace case, dst is a view of src0 and is updated on top of it + // so for non-inplace case, copy src0 to dst first + if (!inplace) { + ggml_cl_cpy(backend, src0, dst, nullptr); + } + + // then copy src1 to dst with specified offset + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_cpy_f32_f32; + } else if (src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + kernel = backend_ctx->kernel_cpy_i32_i32; + } else { + GGML_ASSERT(false && "not implemented"); + } + + offsetd += offs; + cl_ulong nb = ggml_element_size(dst); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &pnb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &pnb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &pnb3)); + + int max_local_size = backend_ctx->get_kernel_workgroup_size(kernel); + + const int nth = MIN(max_local_size, ne00); + + size_t global_work_size[] = {(size_t)ne11*nth, (size_t)ne12, (size_t)ne13}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -9163,6 +17507,49 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr } } +static void ggml_cl_diag(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + cl_kernel kernel = backend_ctx->kernel_diag_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -9474,6 +17861,72 @@ static void ggml_cl_rope(ggml_backend_t backend, const ggml_tensor * src0, const backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_solve_tri(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + cl_kernel kernel = backend_ctx->kernel_solve_tri_f32; + GGML_ASSERT(kernel != nullptr); + + const int n = src0->ne[0]; + const int k = src1->ne[0]; + + const cl_ulong nb00 = src0->nb[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const cl_ulong nb10 = src1->nb[0]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &n)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &k)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong),&nb10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong),&nb11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong),&nb12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong),&nb13)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb3)); + + size_t global_work_size[3]= { (size_t)k, (size_t)dst->ne[2], (size_t)dst->ne[3]}; + size_t local_work_size[] = {16, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_im2col(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src1); @@ -9601,6 +18054,13 @@ static void ggml_cl_argsort(ggml_backend_t backend, const ggml_tensor * src0, co size_t local_work_size[] = {(size_t)ne00_padded, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + const int ne21 = dst->ne[1]; + if ((strstr(src0->name, "_moe") != NULL) && (ne21 != 1)) { + backend_ctx->toggle_reorder = true; + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS } static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -9611,7 +18071,6 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c GGML_UNUSED(src1); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(ggml_is_contiguous(src0)); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -9634,7 +18093,14 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c const cl_ulong nb2 = dst->nb[2]; const cl_ulong nb3 = dst->nb[3]; - cl_kernel kernel = backend_ctx->kernel_sum_rows_f32; + cl_kernel kernel; + + const bool is_c4 = ne00 % 4 == 0; + if (is_c4) { + kernel = backend_ctx->kernel_sum_rows_f32_4; + } else { + kernel = backend_ctx->kernel_sum_rows_f32; + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -9651,12 +18117,124 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); - size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_cumsum(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + GGML_UNUSED(src1); + + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + GGML_ASSERT(ggml_is_contiguous(src0)); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + + cl_kernel kernel = backend_ctx->kernel_cumsum_blk; + + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + int nth = 1; + while (nth < ne00 && 2*nth <= max_workgroup_size) { + nth *= 2; + } + + GGML_ASSERT(ne00 <= nth*nth); + + const int net0 = CEIL_DIV(ne00, nth); + const int net1 = ne01; + const int net2 = ne02; + const int net3 = ne03; + + const cl_ulong nbt0 = sizeof(float); + const cl_ulong nbt1 = net0*nbt0; + const cl_ulong nbt2 = net1*nbt1; + const cl_ulong nbt3 = net2*nbt2; + + static ggml_cl_buffer tmp_buffer; + tmp_buffer.allocate(backend_ctx->context, net0*ne01*ne02*ne03*sizeof(float)); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &tmp_buffer.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &net0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &net1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &net2)); + + size_t global_work_size[] = { (size_t)(nth*net0*ne01), (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = { (size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + if(ne00 > nth) { + // if a single workgroup cannot handle an entire row, each workgroup + // computes a partial sum and stores to dst, tmp_buffer contains the sum + // of the each workgroup; cumsum this buffer and add to the partial sums in dst + cl_ulong offsett = 0; + kernel = backend_ctx->kernel_cumsum_blk; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &tmp_buffer.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offsett)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &tmp_buffer.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &tmp_buffer.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offsett)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &net0)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nbt0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nbt1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nbt2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nbt3)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &net0)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &net1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &net2)); + + size_t global_work_size_1[] = { (size_t)net1*nth, (size_t)net2, (size_t)net3}; + size_t local_work_size_1[] = { (size_t)nth, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_1, local_work_size_1, dst); + + kernel = backend_ctx->kernel_cumsum_add; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &tmp_buffer.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &nbt0)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &nbt1)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &nbt2)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &nbt3)); + + size_t global_work_size_2[] = { (size_t)(nth*net0*ne01), (size_t)ne02, (size_t)ne03}; + size_t local_work_size_2[] = { (size_t)nth, 1, 1}; + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size_2, local_work_size_2, dst); + } +} + static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -9767,6 +18345,185 @@ static void ggml_cl_glu(ggml_backend_t backend, const ggml_tensor * src0, const backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } +static void ggml_cl_gated_delta_net(ggml_backend_t backend, ggml_tensor * dst) { + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + const ggml_tensor * src_q = dst->src[0]; + const ggml_tensor * src_k = dst->src[1]; + const ggml_tensor * src_v = dst->src[2]; + const ggml_tensor * src_g = dst->src[3]; + const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; + + GGML_ASSERT(src_q && src_q->extra); + GGML_ASSERT(src_k && src_k->extra); + GGML_ASSERT(src_v && src_v->extra); + GGML_ASSERT(src_g && src_g->extra); + GGML_ASSERT(src_beta && src_beta->extra); + GGML_ASSERT(src_state && src_state->extra); + + ggml_backend_opencl_context * backend_ctx = (ggml_backend_opencl_context *) backend->context; + + const cl_uint S_v = (cl_uint) src_v->ne[0]; + const cl_uint H_v = (cl_uint) src_v->ne[1]; + const cl_uint n_tokens = (cl_uint) src_v->ne[2]; + const cl_uint n_seqs = (cl_uint) src_v->ne[3]; + const cl_uint K = (cl_uint) ggml_get_op_params_i32(dst, 0); + + int si; + switch (S_v) { + case 16: si = 0; break; + case 32: si = 1; break; + case 64: si = 2; break; + case 128: si = 3; break; + default: + GGML_ASSERT(false && "ggml_cl_gated_delta_net: unsupported S_v"); + } + + const int kda = (src_g->ne[0] == (int64_t) S_v) ? 1 : 0; + + // TODO: Optimize when S_v!=128. Not necessary for now as Qwen3.5/6 are all S_v=128 + // token generation mode (tgpp=0): + // process 1 token at a time, so columns per lane (cpl) == 1 + // prompt processing mode (tgpp=1): + // cpl=4 to process 4 tokens for single-token. 4 is chosen for Adreno 750 as per + // work-item/thread has at most 128 registers. + // All Qwen3.5/6 models are S_v == 128, so LANES_PER_COLUMN == 8 + // such that ROWS_PER_LANE = 128/8 = 16 + // Variables in the kernel: + // k_reg, q_reg, g_exp are all 16 floats + // s_shard has cpl*ROWS_PER_LANE = 4*16 = 64 floats + // Total 112 registers used. + // subgroups_per_workgroup (spw) can be set to 1,2,4,8,16 for tg and 1,2,4 for pp + // for S_v=128. + // Empirically found that when spw=1, we get the best performance for both tg and pp + const int tgpp = (n_tokens == 1) ? 0 : 1; + const int cpl = (tgpp == 0) ? 1 : 4; + // spw needs adjustment when S_v != 128 + const int spw = (tgpp == 0) ? 1 : 1; + + cl_kernel kernel = backend_ctx->kernel_gated_delta_net_f32[si][kda][tgpp]; + GGML_ASSERT(kernel != nullptr); + + const cl_uint s_off = S_v * H_v * n_tokens * n_seqs; + + const cl_uint sq1 = (cl_uint)(src_q->nb[1] / sizeof(float)); + const cl_uint sq2 = (cl_uint)(src_q->nb[2] / sizeof(float)); + const cl_uint sq3 = (cl_uint)(src_q->nb[3] / sizeof(float)); + const cl_uint sv1 = (cl_uint)(src_v->nb[1] / sizeof(float)); + const cl_uint sv2 = (cl_uint)(src_v->nb[2] / sizeof(float)); + const cl_uint sv3 = (cl_uint)(src_v->nb[3] / sizeof(float)); + const cl_uint sb1 = (cl_uint)(src_beta->nb[1] / sizeof(float)); + const cl_uint sb2 = (cl_uint)(src_beta->nb[2] / sizeof(float)); + const cl_uint sb3 = (cl_uint)(src_beta->nb[3] / sizeof(float)); + + const cl_uint H_k = (cl_uint) src_q->ne[1]; + const cl_uint rq3 = (cl_uint)(src_v->ne[3] / src_q->ne[3]); + + const float scale = 1.0f / sqrtf((float) S_v); + + ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *) src_q->extra; + ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *) src_k->extra; + ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *) src_v->extra; + ggml_tensor_extra_cl * extra_g = (ggml_tensor_extra_cl *) src_g->extra; + ggml_tensor_extra_cl * extra_beta = (ggml_tensor_extra_cl *) src_beta->extra; + ggml_tensor_extra_cl * extra_state = (ggml_tensor_extra_cl *) src_state->extra; + ggml_tensor_extra_cl * extra_dst = (ggml_tensor_extra_cl *) dst->extra; + + const cl_ulong off_q = extra_q->offset + src_q->view_offs; + const cl_ulong off_k = extra_k->offset + src_k->view_offs; + const cl_ulong off_v = extra_v->offset + src_v->view_offs; + const cl_ulong off_g = extra_g->offset + src_g->view_offs; + const cl_ulong off_beta = extra_beta->offset + src_beta->view_offs; + const cl_ulong off_state = extra_state->offset + src_state->view_offs; + const cl_ulong off_dst = extra_dst->offset + dst->view_offs; + + int idx = 0; + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_q->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_q)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_k->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_k)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_v->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_v)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_g->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_g)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_beta->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_beta)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_state->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_state)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H_v)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &n_tokens)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &n_seqs)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s_off)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sq3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sv3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb2)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &sb3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H_k)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &rq3)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(float), &scale)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &K)); + + // Subgroup size is 64 for Adreno and 32 for Intel + const int sg_size = backend_ctx->gpu_family == GPU_FAMILY::ADRENO ? 64 : backend_ctx->gpu_family == GPU_FAMILY::INTEL ? 32 : -1; + if (sg_size < 0) { + GGML_LOG_ERROR("Unsupported GPU Family: only Adreno and Intel are supported.\n"); + exit(1); + } + + // For the subgroup-shuffle kernel, we can safely prefer 8 lanes/column for S_v>=128 + // For the subgroup-shuffle kernel: + // S_v >= 128 -> prefer 8 lanes/column (good occupancy & register pressure tradeoff) + // else -> min(S_v, subgroup_size) + int lanes_per_column; + if ((int)S_v >= 128) { + lanes_per_column = 8; + } else { + lanes_per_column = std::min((int)S_v, sg_size); + } + + // Max workgroup size for Adreno 750 is 1024 + const int wg_size = sg_size * spw; + + // Ensure lanes_per_column is a power-of-two and divides both S_v and subgroup_size. + // (Required for lane-group shuffle-xor reduction correctness.) + while (lanes_per_column > 1 && + (((lanes_per_column & (lanes_per_column - 1)) != 0) || + (((int)S_v % lanes_per_column) != 0) || + (sg_size % lanes_per_column) != 0)) { + lanes_per_column >>= 1; + } + GGML_ASSERT(lanes_per_column >= 1); + GGML_ASSERT(((lanes_per_column & (lanes_per_column - 1)) == 0)); + GGML_ASSERT(((int)S_v % lanes_per_column) == 0); + GGML_ASSERT((sg_size % lanes_per_column) == 0); + + const int cols_per_wg = spw * (sg_size / lanes_per_column) * cpl; + GGML_ASSERT(cols_per_wg > 0); + GGML_ASSERT(((int)S_v % cols_per_wg) == 0); + + size_t global_work_size[3]; + size_t local_work_size[3]; + + global_work_size[0] = (size_t) H_v * (size_t) wg_size; + global_work_size[1] = (size_t) n_seqs; + global_work_size[2] = (size_t) S_v / (size_t) cols_per_wg; + + local_work_size[0] = (size_t) wg_size; + local_work_size[1] = 1; + local_work_size[2] = 1; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + //------------------------------------------------------------------------------ // Op offloading //------------------------------------------------------------------------------ @@ -9802,6 +18559,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_cpy; break; + case GGML_OP_SET: + if (!any_on_device) { + return false; + } + func = ggml_cl_set; + break; case GGML_OP_DUP: case GGML_OP_CONT: if (!any_on_device) { @@ -9901,6 +18664,18 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_tanh; break; + case GGML_UNARY_OP_NEG: + if (!any_on_device) { + return false; + } + func = ggml_cl_neg; + break; + case GGML_UNARY_OP_EXP: + if (!any_on_device) { + return false; + } + func = ggml_cl_exp; + break; case GGML_UNARY_OP_EXPM1: if (!any_on_device) { return false; @@ -9922,6 +18697,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_glu; break; + case GGML_OP_TRI: + if (!any_on_device) { + return false; + } + func = ggml_cl_tri; + break; case GGML_OP_FILL: if (!any_on_device) { return false; @@ -9946,14 +18727,20 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_rms_norm; break; + case GGML_OP_L2_NORM: + if (!any_on_device) { + return false; + } + func = ggml_cl_l2_norm; + break; case GGML_OP_GROUP_NORM: if (!any_on_device) { return false; } func = ggml_cl_group_norm; break; - case GGML_OP_REPEAT: - if (!any_on_device) { + case GGML_OP_REPEAT: + if (!any_on_device) { return false; } func = ggml_cl_repeat; @@ -9982,6 +18769,14 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_ssm_conv; break; + case GGML_OP_GATED_DELTA_NET: + if (!any_on_device) { + return false; + } + // GDN has 6 source tensors, so it cannot use the standard + // (src0, src1, dst) func signature. Dispatch directly and return. + ggml_cl_gated_delta_net(backend, tensor); + return true; case GGML_OP_CONCAT: if (!any_on_device) { return false; @@ -10021,6 +18816,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_nop; break; + case GGML_OP_DIAG: + if (!any_on_device) { + return false; + } + func = ggml_cl_diag; + break; case GGML_OP_DIAG_MASK_INF: if (!any_on_device) { return false; @@ -10039,6 +18840,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_rope; break; + case GGML_OP_SOLVE_TRI: + if (!any_on_device) { + return false; + } + func = ggml_cl_solve_tri; + break; case GGML_OP_IM2COL: if (!any_on_device) { return false; @@ -10057,6 +18864,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_sum_rows; break; + case GGML_OP_CUMSUM: + if (!any_on_device) { + return false; + } + func = ggml_cl_cumsum; + break; case GGML_OP_FLASH_ATTN_EXT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/concat.cl b/ggml/src/ggml-opencl/kernels/concat.cl index 132758469c6..2fbd7851d3d 100644 --- a/ggml/src/ggml-opencl/kernels/concat.cl +++ b/ggml/src/ggml-opencl/kernels/concat.cl @@ -1,109 +1,118 @@ -kernel void kernel_concat_f32_contiguous( - global const char * p_src0, ulong off_src0, - global const char * p_src1, ulong off_src1, - global char * p_dst, ulong off_dst, - int d_ne00, int d_ne01, int d_ne02, // src0->ne[0..2] for the slice - int d_ne10, int d_ne11, int d_ne12, // src1->ne[0..2] for the slice (d_ne1X must match d_ne0X on non-concat axes) - int d_ne0, int d_ne1, int d_ne2, // dst->ne[0..2] for the slice - int dim +kernel void kernel_concat_f32( + global const char * src0, + ulong offset0, + global const char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int dim ) { - global const float * src0 = (global const float*)((global char*)p_src0 + off_src0); - global const float * src1 = (global const float*)((global char*)p_src1 + off_src1); - global float * dst = (global float*)((global char*)p_dst + off_dst); + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; - int i0 = get_global_id(0); // Index along dst's 0th dimension - int i1 = get_global_id(1); // Index along dst's 1st dimension - int i2 = get_global_id(2); // Index along dst's 2nd dimension + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - if (i0 >= d_ne0 || i1 >= d_ne1 || i2 >= d_ne2) { - return; - } + int o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - ulong dst_idx = (ulong)i2 * d_ne0 * d_ne1 + (ulong)i1 * d_ne0 + i0; - ulong src_idx; + global const float * x; - if (dim == 0) { - if (i0 < d_ne00) { // Data from src0 - src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; - dst[dst_idx] = src0[src_idx]; - } else { // Data from src1 - src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + (i0 - d_ne00); - dst[dst_idx] = src1[src_idx]; - } - } else if (dim == 1) { - if (i1 < d_ne01) { // Data from src0 - src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; - dst[dst_idx] = src0[src_idx]; - } else { // Data from src1 - src_idx = (ulong)i2 * d_ne10 * d_ne11 + (ulong)(i1 - d_ne01) * d_ne10 + i0; - dst[dst_idx] = src1[src_idx]; + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (global const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + } else { + x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); } - } else if (dim == 2) { - if (i2 < d_ne02) { // Data from src0 - src_idx = (ulong)i2 * d_ne00 * d_ne01 + (ulong)i1 * d_ne00 + i0; - dst[dst_idx] = src0[src_idx]; - } else { // Data from src1 - src_idx = (ulong)(i2 - d_ne02) * d_ne10 * d_ne11 + (ulong)i1 * d_ne10 + i0; - dst[dst_idx] = src1[src_idx]; - } + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; } } -kernel void kernel_concat_f32_non_contiguous( - global const char * p_src0, ulong off_src0, - global const char * p_src1, ulong off_src1, - global char * p_dst, ulong off_dst, - - long ne00, long ne01, long ne02, long ne03, - ulong nb00, ulong nb01, ulong nb02, ulong nb03, - - ulong nb10, ulong nb11, ulong nb12, ulong nb13, // Strides for src1 - - long d_ne0, long d_ne1, long d_ne2, long d_ne3, - ulong d_nb0, ulong d_nb1, ulong d_nb2, ulong d_nb3, - int dim +kernel void kernel_concat_f32_pack( + global const char * src0, + ulong offset0, + global const char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3, + int dim, + int ne1, + int ne2, + int ne3 ) { - global const char * src0_base = p_src0 + off_src0; - global const char * src1_base = p_src1 + off_src1; - global char * dst_base = p_dst + off_dst; - - long current_i1 = get_global_id(0); // Index for dst_dim_1 - long current_i2 = get_global_id(1); // Index for dst_dim_2 - long current_i3 = get_global_id(2); // Index for dst_dim_3 - - if (current_i1 >= d_ne1 || current_i2 >= d_ne2 || current_i3 >= d_ne3) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int lsz = get_local_size(0); + int tpr = min(ne0, lsz); // threads per row + int rpw = lsz / tpr; // rows per workgroup + int lid = get_local_id(0); + int row = get_group_id(0)*rpw + lid / tpr; + int lane = lid - (lid / tpr) * tpr; + + int nrows = ne1*ne2*ne3; + if (row >= nrows) { return; } - global const float * x_val_ptr; - global float * y_val_ptr; + int i1 = row % ne1; + int t = row / ne1; + int i2 = t % ne2; + int i3 = t / ne2; - for (long current_i0 = 0; current_i0 < d_ne0; ++current_i0) { - bool use_src0; - long s_i0 = current_i0, s_i1 = current_i1, s_i2 = current_i2, s_i3 = current_i3; + int o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); - if (dim == 0) { - use_src0 = (current_i0 < ne00); - if (!use_src0) { s_i0 = current_i0 - ne00; } - } else if (dim == 1) { - use_src0 = (current_i1 < ne01); - if (!use_src0) { s_i1 = current_i1 - ne01; } - } else if (dim == 2) { - use_src0 = (current_i2 < ne02); - if (!use_src0) { s_i2 = current_i2 - ne02; } - } else { // dim == 3 - use_src0 = (current_i3 < ne03); - if (!use_src0) { s_i3 = current_i3 - ne03; } - } - - if (use_src0) { - x_val_ptr = (global const float *)(src0_base + (ulong)s_i3*nb03 + (ulong)s_i2*nb02 + (ulong)s_i1*nb01 + (ulong)s_i0*nb00); + for (int i0 = lane; i0 < ne0; i0 += tpr) { + global const float * x; + if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { + x = (global const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); } else { - x_val_ptr = (global const float *)(src1_base + (ulong)s_i3*nb13 + (ulong)s_i2*nb12 + (ulong)s_i1*nb11 + (ulong)s_i0*nb10); + x = (global const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); } - y_val_ptr = (global float *)(dst_base + (ulong)current_i3*d_nb3 + (ulong)current_i2*d_nb2 + (ulong)current_i1*d_nb1 + (ulong)current_i0*d_nb0); - *y_val_ptr = *x_val_ptr; + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = *x; } } diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl index 9369351a60c..adbd2e766d2 100644 --- a/ggml/src/ggml-opencl/kernels/cpy.cl +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -182,3 +182,107 @@ kernel void kernel_cpy_f32_f32( dst_data[i00] = src[0]; } } + +kernel void kernel_cpy_f32_f32_pack( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int lsz = get_local_size(0); + int tpr = min(ne00, lsz); // threads per row + int rpw = lsz / tpr; // rows per workgroup + int lid = get_local_id(0); + int row = get_group_id(0)*rpw + lid / tpr; + int lane = lid - (lid / tpr) * tpr; + + int nrows = ne01*ne02*ne03; + if (row >= nrows) { + return; + } + + int i01 = row % ne01; + int t = row / ne01; + int i02 = t % ne02; + int i03 = t / ne02; + + // linear index of the first element of this row, unflattened over dst dims + long n = (long)row * ne00; + int i3 = (int)(n / ((long)ne2*ne1*ne0)); + long rm = n - (long)i3*ne2*ne1*ne0; + int i2 = (int)(rm / ((long)ne1*ne0)); + rm -= (long)i2*ne1*ne0; + int i1 = (int)(rm / ne0); + int i0 = (int)(rm - (long)i1*ne0); + + global float * dst_data = (global float *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = lane; i00 < ne00; i00 += tpr) { + global const float * src = (global float *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + dst_data[i00] = src[0]; + } +} + +kernel void kernel_cpy_i32_i32( + global int * src0, + ulong offset0, + global int * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global int*)((global char*)src0 + offset0); + dst = (global int*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global int * dst_data = (global int *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const int * src = (global int *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/cumsum.cl b/ggml/src/ggml-opencl/kernels/cumsum.cl new file mode 100644 index 00000000000..edfb74b7058 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/cumsum.cl @@ -0,0 +1,139 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +// max workgroup size is usually 1024, this covers various subgroups sizes +#define MAX_SUBGROUPS 128 + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_cumsum_blk( + global char * src0, + ulong offset0, + global char * tmp, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + uint net0, + uint net1, + uint net2 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int nth = get_local_size(0); + const int tid = get_local_id(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + const int ib = i1 / ne01; + const int i00 = ib * nth; + const int i01 = i1 % ne01; + const int i02 = i2; + const int i03 = i3; + + global const float * src0_row = (global const float *)(src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * tmp_row = (global float *)tmp + net0 * i01 + net0 * net1 * i02 + net0 * net1 * net2 * i03; + global float * dst_row = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + __local float partial[MAX_SUBGROUPS]; + + float v = 0.0f; + if (i00 + tid < ne00) { + v = src0_row[i00 + tid]; + } + + float s = sub_group_scan_inclusive_add(v); + if (sg_lid == sg_size - 1) { + partial[sg_id] = s; + } + barrier(CLK_LOCAL_MEM_FENCE); + + // NB: subgroup size should be larger than number of subgroups + // assuming max workgroup size of 1024, subgroup size should be >= 32 + if (sg_id == 0) { + float x = 0.0f; + if (sg_lid < get_num_sub_groups()) { + x = partial[sg_lid]; + } + float ex = sub_group_scan_exclusive_add(x); + if (sg_lid < get_num_sub_groups()) { + partial[sg_lid] = ex; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + + s += partial[sg_id]; + + if (i00 + tid < ne00) { + dst_row[i00 + tid] = s; + } + if (ne00 > nth && tid == nth - 1) { + tmp_row[ib] = s; + } +} + +kernel void kernel_cumsum_add( + global char * tmp, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + uint nbt0, + uint nbt1, + uint nbt2, + uint nbt3 +) { + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int nth = get_local_size(0); + const int tid = get_local_id(0); + + const int ib = i1 / ne01; + if (ib == 0) { + return; + } + const int i00 = ib * nth; + const int i01 = i1 % ne01; + const int i02 = i2; + const int i03 = i3; + + global float * tmp_row = (global float *)(tmp + nbt1 * i01 + nbt2 * i02 + nbt3 * i03); + global float * dst_row = (global float *)dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + if (i00 + tid < ne00) { + dst_row[i00 + tid] += tmp_row[ib - 1]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 513a4d3e28f..226b127ab3b 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -28,6 +28,7 @@ #define QK8_0 32 #define QR8_0 1 #define QK_K 256 +#define K_SCALE_SIZE (3 * QK_K / 64) #define K_QUANTS_PER_ITERATION 2 typedef char int8_t; @@ -46,6 +47,118 @@ struct block_q4_0 uint8_t qs[QK4_0 / 2]; }; +//------------------------------------------------------------------------------ +// block_q4_1 +//------------------------------------------------------------------------------ +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +//------------------------------------------------------------------------------ +// block_q5_0 +//------------------------------------------------------------------------------ +struct block_q5_0 { + half d; // delta + uchar qh[4]; // 5-th bit of quants + uchar qs[QK5_0 / 2]; // nibbles / quants +}; + +//------------------------------------------------------------------------------ +// block_q5_1 +//------------------------------------------------------------------------------ +struct block_q5_1 { + half d; // delta + half m; // min + uchar qh[4]; // 5-th bit of quants + uchar qs[QK5_1 / 2]; // nibbles / quants +}; + +//------------------------------------------------------------------------------ +// block_q4_k +//------------------------------------------------------------------------------ +struct block_q4_K { + half d; // delta + half dm; // min + uchar s[K_SCALE_SIZE]; + uchar q[QK_K / 2]; // nibbles / quants +}; + +//------------------------------------------------------------------------------ +// block_q5_k +//------------------------------------------------------------------------------ +struct block_q5_K { + half d; // delta + half dm; // min + uchar s[K_SCALE_SIZE]; + uchar qh[QK_K / 8]; + uchar qs[QK_K / 2]; // nibbles / quants +}; + +//------------------------------------------------------------------------------ +// block_q6_K +//------------------------------------------------------------------------------ +struct block_q6_K { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + half d; // super-block scale +}; + +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +#define QK4_NL 32 + +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + +//------------------------------------------------------------------------------ +// bf16 to f16 +//------------------------------------------------------------------------------ +kernel void kernel_convert_bf16_to_f16( + global const ushort * src, + global half * dst, + ulong off_dst, + ulong n +) { + uint i = get_global_id(0); + if (i >= n) { + return; + } + + dst[i + off_dst] = (half) as_float((uint) src[i] << 16); +} + +//------------------------------------------------------------------------------ +// f16 to bf16 +//------------------------------------------------------------------------------ +kernel void kernel_convert_f16_to_bf16( + global const half * src, + ulong off_src, + global ushort * dst, + ulong n +) { + uint i = get_global_id(0); + if (i >= n) { + return; + } + + float f = (float) src[i + off_src]; + uint bits = as_uint(f); + if ((bits & 0x7fffffffu) > 0x7f800000u) { + // nan to quiet nan + dst[i] = (ushort)((bits >> 16) | 0x40u); + } else { + uint rounded = bits + 0x7fffu + ((bits >> 16) & 1u); + dst[i] = (ushort)(rounded >> 16); + } +} + //------------------------------------------------------------------------------ // kernel_convert_block_q4_0 // Convert the block_q4_0 format to 2 separate arrays (AOS -> SOA). @@ -138,76 +251,248 @@ kernel void kernel_restore_block_q4_0_noshuffle( } } -//------------------------------------------------------------------------------ -// block_mxfp4 -//------------------------------------------------------------------------------ -#define QK_MXFP4 32 -struct block_mxfp4 { - uchar e; // E8M0 - uchar qs[QK_MXFP4 / 2]; -}; +kernel void kernel_convert_block_q4_0_trans4_ns( + global struct block_q4_0 * src0, + __global uint * dst_q, + __global half * dst_d, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK4_0; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q4_0 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_0 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK4_0 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q4_0_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global struct block_q4_0 * dst0, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK4_0; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_0 * b = dst0 + dst_blk_offset; + b->d = src_d[src_d_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_0 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK4_0 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} //------------------------------------------------------------------------------ -// kernel_convert_block_mxfp4 -// Convert the block_mxfp4 format to 2 separate arrays (AOS -> SOA). +// kernel_convert_block_q4_1 +// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA). // This kernel does not deshuffle the bits. //------------------------------------------------------------------------------ -kernel void kernel_convert_block_mxfp4( - global struct block_mxfp4 * src0, +kernel void kernel_convert_block_q4_1( + global struct block_q4_1 * src0, global uchar * dst_q, - global uchar * dst_e + global half * dst_d, + global half * dst_m ) { - global struct block_mxfp4 * b = (global struct block_mxfp4 *) src0 + get_global_id(0); - global uchar * q = (global uchar *) dst_q + QK_MXFP4 / 2 * get_global_id(0); - global uchar * e = (global uchar *) dst_e + get_global_id(0); + global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); - *e = b->e; + *d = b->d; + *m = b->m; - for (int i = 0; i < QK_MXFP4 / 2; ++i) { + for (int i = 0; i < QK4_1/2; ++i) { q[i] = b->qs[i]; } } -kernel void kernel_convert_block_mxfp4_trans( - global struct block_mxfp4 * src0, - __global uint4 * dst_q, - __global uchar * dst_e, +kernel void kernel_restore_block_q4_1( + global uchar * src_q, + global half * src_d, + global half * src_m, + global struct block_q4_1 * dst +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + for (int i = 0; i < QK4_1/2; ++i) { + b->qs[i] = q[i]; + } +} + +kernel void kernel_convert_block_q4_1_noshuffle( + global struct block_q4_1 * src0, + global uchar * dst_q, + global half * dst_d, + global half * dst_m +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + for (int i = 0; i < QK4_1/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK4_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} + +kernel void kernel_restore_block_q4_1_noshuffle( + global uchar * src_q, + global half * src_d, + global half * src_m, + global struct block_q4_1 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + for (int i = 0; i < QK4_1/4; ++i) { + uchar x0 = q[i + 0 ] ; + uchar x1 = q[i + QK4_1/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + +kernel void kernel_convert_block_q4_1_trans4_ns( + __global struct block_q4_1 * src0, + __global uint * dst_q, + __global half * dst_d, + __global half * dst_m, uint ne00, uint ne01 ) { - int i00 = get_global_id(1); + uint i00 = get_global_id(1); uint i01 = get_global_id(0); uint i02 = get_global_id(2); - uint ne00_blk = ne00 / QK_MXFP4; + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK4_1; uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; - global struct block_mxfp4 * b = src0 + src_blk_offset; + global struct block_q4_1 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + dst_m[dst_blk_offset] = b->m; - dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0]; - dst_e[dst_blk_offset] = b->e; -} + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; -kernel void kernel_restore_block_mxfp4( - global uchar * src_q, - global half * src_e, - global struct block_mxfp4 * dst -) { - global struct block_mxfp4 * b = (global struct block_mxfp4 *) dst + get_global_id(0); - global uchar * q = (global uchar *) src_q + QK_MXFP4 / 2 * get_global_id(0); - global uchar * e = (global uchar *) src_e + get_global_id(0); + ushort8 post_block = (ushort8)(0); - b->e = *e; - for (int i = 0; i < QK_MXFP4 / 2; ++i) { - b->qs[i] = q[i]; + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_1 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK4_1 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; } -kernel void kernel_restore_block_mxfp4_trans( - __global uint4 * src_q, - __global uchar * src_e, - global struct block_mxfp4 * dst, +kernel void kernel_restore_block_q4_1_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global half * src_m, + __global struct block_q4_1 * dst0, uint ne00, uint ne01 ) { @@ -215,51 +500,1677 @@ kernel void kernel_restore_block_mxfp4_trans( uint i01 = get_global_id(0); uint i02 = get_global_id(2); - uint ne00_blk = ne00 / QK_MXFP4; - uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK4_1; uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_dm_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; - global struct block_mxfp4 * b = dst + dst_blk_offset; + __global struct block_q4_1 * b = dst0 + dst_blk_offset; + b->d = src_d[src_dm_offset]; + b->m = src_m[src_dm_offset]; - ((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset]; - b->e = src_e[src_blk_offset]; + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK4_0 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK4_0 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; } //------------------------------------------------------------------------------ -// block_q8_0 +// kernel_convert_block_q5_0 +// Convert the block_q5_0 format to 3 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. //------------------------------------------------------------------------------ -typedef struct { - half d; // delta - char qs[QK8_0]; // quants -} block_q8_0; +kernel void kernel_convert_block_q5_0( + global struct block_q5_0 * src0, + global uchar * dst_qs, + global uint * dst_qh, + global half * dst_d, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } -kernel void kernel_convert_block_q8_0( - global block_q8_0 * src0, + global struct block_q5_0 * b = (global struct block_q5_0 *) src0 + get_global_id(0); + global uchar * qs = (global uchar *) dst_qs + (QK5_0/2)*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_0/2; ++i) { + qs[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q5_0( + global uchar * src_qs, + global uint * src_qh, + global half * src_d, + global struct block_q5_0 * dst +) { + global struct block_q5_0 * b = (global struct block_q5_0 *) dst + get_global_id(0); + global uchar * qs = (global uchar *) src_qs + (QK5_0/2)*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + *((global uint *)(b->qh)) = *qh; + for (int i = 0; i < QK5_0/2; ++i) { + b->qs[i] = qs[i]; + } +} + +kernel void kernel_convert_block_q5_0_noshuffle( + global struct block_q5_0 * src0, global uchar * dst_q, + global uint * dst_qh, global half * dst_d ) { - global block_q8_0 * b = (global block_q8_0 *) src0 + get_global_id(0); - global uchar * q = (global uchar *) dst_q + QK8_0*get_global_id(0); - global half * d = (global half *) dst_d + get_global_id(0); + global struct block_q5_0 * b = (global struct block_q5_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK5_0/2*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); *d = b->d; + *qh = *((global uint *)(b->qh)); - for (int i = 0; i < QK8_0; ++i) { - q[i] = b->qs[i]; + for (int i = 0; i < QK5_0/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK5_0/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif } } -kernel void kernel_restore_block_q8_0( +kernel void kernel_restore_block_q5_0_noshuffle( global uchar * src_q, + global uint * src_qh, global half * src_d, - global block_q8_0 * dst + global struct block_q5_0 * dst, + uchar mask_0F, + uchar mask_F0 ) { - global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0); - global uchar * q = (global uchar *) src_q + QK8_0*get_global_id(0); - global half * d = (global half *) src_d + get_global_id(0); + global struct block_q5_0 * b = (global struct block_q5_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK5_0/2*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); b->d = *d; - for (int i = 0; i < QK8_0; ++i) { - b->qs[i] = q[i]; + *((global uint *)(b->qh)) = *qh; + + for (int i = 0; i < QK5_0/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK5_0/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + +kernel void kernel_convert_block_q5_0_trans4_ns( + __global struct block_q5_0 * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK5_0; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q5_0 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + + dst_qh[dst_blk_offset] = ((global uint *)(&(b->qh[0])))[0]; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_0 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK5_0 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_qs[offset] = q_block.x; + dst_qs[offset + ne01] = q_block.y; + dst_qs[offset + ne01 * 2] = q_block.z; + dst_qs[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q5_0_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global struct block_q5_0 * dst0, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK5_0; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_0 * b = dst0 + dst_blk_offset; + b->d = src_d[src_blk_offset]; + + ((__global uint *)(&(b->qh[0])))[0] = src_qh[src_blk_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_qs[src_q_offset]; + q_block.y = src_qs[src_q_offset + ne01]; + q_block.z = src_qs[src_q_offset + ne01 * 2]; + q_block.w = src_qs[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_0 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK5_0 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_1 +// Convert the block_q5_1 format to 4 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_1( + global struct block_q5_1 * src0, + global uchar * dst_qs, + global uint * dst_qh, + global half * dst_d, + global half * dst_m, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + + global struct block_q5_1 * b = (global struct block_q5_1 *) src0 + get_global_id(0); + global uchar * qs = (global uchar *) dst_qs + (QK5_1/2)*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_1/2; ++i) { + qs[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q5_1( + global uchar * src_qs, + global uint * src_qh, + global half * src_d, + global half * src_m, + global struct block_q5_1 * dst +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) dst + get_global_id(0); + global uchar * qs = (global uchar *) src_qs + (QK5_1/2)*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + *((global uint *)(b->qh)) = *qh; + for (int i = 0; i < QK5_1/2; ++i) { + b->qs[i] = qs[i]; + } +} + +kernel void kernel_convert_block_q5_1_noshuffle( + global struct block_q5_1 * src0, + global uchar * dst_q, + global uint * dst_qh, + global half * dst_d, + global half * dst_m +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK5_1/2*get_global_id(0); + global uint * qh = (global uint *) dst_qh + get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + *qh = *((global uint *)(b->qh)); + + for (int i = 0; i < QK5_1/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK5_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} + +kernel void kernel_restore_block_q5_1_noshuffle( + global uchar * src_q, + global uint * src_qh, + global half * src_d, + global half * src_m, + global struct block_q5_1 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_1 * b = (global struct block_q5_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK5_1/2*get_global_id(0); + global uint * qh = (global uint *) src_qh + get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + *((global uint *)(b->qh)) = *qh; + + for (int i = 0; i < QK5_1/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK5_1/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + +kernel void kernel_convert_block_q5_1_trans4_ns( + __global struct block_q5_1 * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + __global half * dst_m, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK5_1; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_q5_1 * b = src0 + src_blk_offset; + dst_d[dst_blk_offset] = b->d; + dst_m[dst_blk_offset] = b->m; + + dst_qh[dst_blk_offset] = ((global uint *)(&(b->qh[0])))[0]; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_1 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK5_1 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_qs[offset] = q_block.x; + dst_qs[offset + ne01] = q_block.y; + dst_qs[offset + ne01 * 2] = q_block.z; + dst_qs[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_q5_1_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global half * src_m, + __global struct block_q5_1 * dst0, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK5_1; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_1 * b = dst0 + dst_blk_offset; + b->d = src_d[src_blk_offset]; + b->m = src_m[src_blk_offset]; + + ((__global uint *)(&(b->qh[0])))[0] = src_qh[src_blk_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_qs[src_q_offset]; + q_block.y = src_qs[src_q_offset + ne01]; + q_block.z = src_qs[src_q_offset + ne01 * 2]; + q_block.w = src_qs[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK5_1 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK5_1 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + +kernel void kernel_convert_block_q4_k_trans4_ns( + __global struct block_q4_K * src0, + __global uint * dst_q, + __global half * dst_d, + __global half * dst_dm, + __global uchar * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q4_K * b = src0 + src_blk_offset; + + dst_d [dst_blk_offset] = b->d; + dst_dm[dst_blk_offset] = b->dm; + + uint4 qv[8]; + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->q[i*32 + 2*j]; + uchar x1 = b->q[i*32 + 2*j + 1]; + + qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qv[p]; + dst_q[base + (p * 4 + 0) * ne01] = v.x; + dst_q[base + (p * 4 + 1) * ne01] = v.y; + dst_q[base + (p * 4 + 2) * ne01] = v.z; + dst_q[base + (p * 4 + 3) * ne01] = v.w; + } + + __global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + #pragma unroll + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s_dst[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q4_k_trans4_ns( + __global uint * src_q, + __global half * src_d, + __global half * src_dm, + __global uchar * src_s, + __global struct block_q4_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q4_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + b->dm = src_dm[src_blk_offset]; + + __global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s_src[i]; + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + uint4 qv[8]; + for (int p = 0; p < 8; ++p) { + qv[p].x = src_q[base + (p * 4 + 0) * ne01]; + qv[p].y = src_q[base + (p * 4 + 1) * ne01]; + qv[p].z = src_q[base + (p * 4 + 2) * ne01]; + qv[p].w = src_q[base + (p * 4 + 3) * ne01]; + } + + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = qv_bytes[i*32 + j]; + uchar hi = qv_bytes[i*32 + j + 16]; + b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } +} + +kernel void kernel_convert_block_q5_k_trans4_ns( + __global struct block_q5_K * src0, + __global uint * dst_qs, + __global uint * dst_qh, + __global half * dst_d, + __global half * dst_dm, + __global uchar * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q5_K * b = src0 + src_blk_offset; + + dst_d [dst_blk_offset] = b->d; + dst_dm[dst_blk_offset] = b->dm; + + for (int k = 0; k < 8; k++) { + uchar b0 = 0, b1 = 0, b2 = 0, b3 = 0; + for (int bit = 0; bit < 8; bit++) { + b0 |= (uchar)(((b->qh[bit] >> k) & 1) << bit); + b1 |= (uchar)(((b->qh[8 + bit] >> k) & 1) << bit); + b2 |= (uchar)(((b->qh[16 + bit] >> k) & 1) << bit); + b3 |= (uchar)(((b->qh[24 + bit] >> k) & 1) << bit); + } + uint packed = (uint)b0 | ((uint)b1 << 8) | ((uint)b2 << 16) | ((uint)b3 << 24); + dst_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01] = packed; + } + + uint4 qv[8]; + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->qs[i*32 + 2*j]; + uchar x1 = b->qs[i*32 + 2*j + 1]; + + qv_bytes[i*32 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qv_bytes[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qv[p]; + dst_qs[base + (p * 4 + 0) * ne01] = v.x; + dst_qs[base + (p * 4 + 1) * ne01] = v.y; + dst_qs[base + (p * 4 + 2) * ne01] = v.z; + dst_qs[base + (p * 4 + 3) * ne01] = v.w; + } + + __global uchar * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + #pragma unroll + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s_dst[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q5_k_trans4_ns( + __global uint * src_qs, + __global uint * src_qh, + __global half * src_d, + __global half * src_dm, + __global uchar * src_s, + __global struct block_q5_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q5_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + b->dm = src_dm[src_blk_offset]; + + for (int j = 0; j < 32; j++) b->qh[j] = 0; + for (int k = 0; k < 8; k++) { + uint packed = src_qh[i01 + (i00 * 8 + k) * ne01 + i02 * ne00_blk * 8 * ne01]; + uchar b0 = (uchar)(packed & 0xFF); + uchar b1 = (uchar)((packed >> 8) & 0xFF); + uchar b2 = (uchar)((packed >> 16) & 0xFF); + uchar b3 = (uchar)((packed >> 24) & 0xFF); + for (int bit = 0; bit < 8; bit++) { + b->qh[bit] |= (uchar)(((b0 >> bit) & 1) << k); + b->qh[8 + bit] |= (uchar)(((b1 >> bit) & 1) << k); + b->qh[16 + bit] |= (uchar)(((b2 >> bit) & 1) << k); + b->qh[24 + bit] |= (uchar)(((b3 >> bit) & 1) << k); + } + } + + __global uchar * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * K_SCALE_SIZE + i00 * K_SCALE_SIZE; + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s_src[i]; + } + + uint base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + uint4 qv[8]; + for (int p = 0; p < 8; ++p) { + qv[p].x = src_qs[base + (p * 4 + 0) * ne01]; + qv[p].y = src_qs[base + (p * 4 + 1) * ne01]; + qv[p].z = src_qs[base + (p * 4 + 2) * ne01]; + qv[p].w = src_qs[base + (p * 4 + 3) * ne01]; + } + + uchar * qv_bytes = (uchar *)qv; + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = qv_bytes[i*32 + j]; + uchar hi = qv_bytes[i*32 + j + 16]; + b->qs[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } +} + +kernel void kernel_convert_block_q6_k_trans4_ns( + __global struct block_q6_K * src0, + __global uint * dst_ql, + __global uint * dst_qh, + __global half * dst_d, + __global char * dst_s, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_q6_K * b = src0 + src_blk_offset; + + dst_d[dst_blk_offset] = b->d; + + uint4 qlv[8]; + uchar * qlv_bytes = (uchar *)qlv; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->ql[i*64 + 2*j]; + uchar x1 = b->ql[i*64 + 2*j + 1]; + uchar x2 = b->ql[i*64 + 32 + 2*j]; + uchar x3 = b->ql[i*64 + 32 + 2*j + 1]; + qlv_bytes[i*64 + j ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + qlv_bytes[i*64 + j + 16] = convert_uchar(x2 & mask_0F) | convert_uchar((x3 & mask_0F) << 4); + qlv_bytes[i*64 + j + 32] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + qlv_bytes[i*64 + j + 48] = convert_uchar((x2 & mask_F0) >> 4) | convert_uchar(x3 & mask_F0); + } + } + + uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + + #pragma unroll + for (int p = 0; p < 8; ++p) { + uint4 v = qlv[p]; + dst_ql[ql_base + (p * 4 + 0) * ne01] = v.x; + dst_ql[ql_base + (p * 4 + 1) * ne01] = v.y; + dst_ql[ql_base + (p * 4 + 2) * ne01] = v.z; + dst_ql[ql_base + (p * 4 + 3) * ne01] = v.w; + } + + uint qhv[16] = {0}; + + for (int n = 0; n < 2; ++n) { + for (int l = 0; l < 32; ++l) { + uchar h = b->qh[n*32 + l]; + int u = l / 16; + int bit_pos = (l % 16) * 2; + qhv[(n*4 + 0)*2 + u] |= ((uint)((h >> 0) & 0x03)) << bit_pos; + qhv[(n*4 + 1)*2 + u] |= ((uint)((h >> 2) & 0x03)) << bit_pos; + qhv[(n*4 + 2)*2 + u] |= ((uint)((h >> 4) & 0x03)) << bit_pos; + qhv[(n*4 + 3)*2 + u] |= ((uint)((h >> 6) & 0x03)) << bit_pos; + } + } + + uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01; + + for (int p = 0; p < 16; ++p) { + dst_qh[qh_base + p * ne01] = qhv[p]; + } + + __global char * s_dst = dst_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16; + #pragma unroll + for (int i = 0; i < 16; ++i) { + s_dst[i] = b->scales[i]; + } +} + +kernel void kernel_restore_block_q6_k_trans4_ns( + __global uint * src_ql, + __global uint * src_qh, + __global half * src_d, + __global char * src_s, + __global struct block_q6_K * dst0, + uint ne00, + uint ne01, + uchar mask_0F, + uchar mask_F0 +) { + uint i00 = get_global_id(1); // block index along K + uint i01 = get_global_id(0); // row index + uint i02 = get_global_id(2); // batch index + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_K; + + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + __global struct block_q6_K * b = dst0 + dst_blk_offset; + + b->d = src_d[src_blk_offset]; + + uint ql_base = i02 * ne00_blk * ne01 * 32 + i00 * ne01 * 32 + i01; + uint4 qlv[8]; + for (int p = 0; p < 8; ++p) { + qlv[p].x = src_ql[ql_base + (p * 4 + 0) * ne01]; + qlv[p].y = src_ql[ql_base + (p * 4 + 1) * ne01]; + qlv[p].z = src_ql[ql_base + (p * 4 + 2) * ne01]; + qlv[p].w = src_ql[ql_base + (p * 4 + 3) * ne01]; + } + + uchar * qlv_bytes = (uchar *)qlv; + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo_02 = qlv_bytes[i*64 + j]; + uchar lo_13 = qlv_bytes[i*64 + j + 16]; + uchar hi_02 = qlv_bytes[i*64 + j + 32]; + uchar hi_13 = qlv_bytes[i*64 + j + 48]; + b->ql[i*64 + 2*j] = convert_uchar((lo_02 & mask_0F) | ((hi_02 & mask_0F) << 4)); + b->ql[i*64 + 2*j + 1] = convert_uchar(((lo_02 & mask_F0) >> 4) | (hi_02 & mask_F0)); + b->ql[i*64 + 32 + 2*j] = convert_uchar((lo_13 & mask_0F) | ((hi_13 & mask_0F) << 4)); + b->ql[i*64 + 32 + 2*j + 1] = convert_uchar(((lo_13 & mask_F0) >> 4) | (hi_13 & mask_F0)); + } + } + + uint qh_base = i02 * ne00_blk * ne01 * 16 + i00 * ne01 * 16 + i01; + uint qhv[16]; + for (int p = 0; p < 16; ++p) { + qhv[p] = src_qh[qh_base + p * ne01]; + } + + for (int n = 0; n < 2; ++n) { + for (int l = 0; l < 32; ++l) { + int u = l / 16; + int bit_pos = (l % 16) * 2; + uchar v0 = (uchar)((qhv[(n*4 + 0)*2 + u] >> bit_pos) & 0x03); + uchar v1 = (uchar)((qhv[(n*4 + 1)*2 + u] >> bit_pos) & 0x03); + uchar v2 = (uchar)((qhv[(n*4 + 2)*2 + u] >> bit_pos) & 0x03); + uchar v3 = (uchar)((qhv[(n*4 + 3)*2 + u] >> bit_pos) & 0x03); + b->qh[n*32 + l] = v0 | (v1 << 2) | (v2 << 4) | (v3 << 6); + } + } + + __global char * s_src = src_s + (i02 * ne01 + i01) * ne00_blk * 16 + i00 * 16; + for (int i = 0; i < 16; ++i) { + b->scales[i] = s_src[i]; + } +} + +//------------------------------------------------------------------------------ +// block_mxfp4 +//------------------------------------------------------------------------------ +#define QK_MXFP4 32 +struct block_mxfp4 { + uchar e; // E8M0 + uchar qs[QK_MXFP4 / 2]; +}; + +//------------------------------------------------------------------------------ +// kernel_convert_block_mxfp4 +// Convert the block_mxfp4 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_mxfp4( + global struct block_mxfp4 * src0, + global uchar * dst_q, + global uchar * dst_e +) { + global struct block_mxfp4 * b = (global struct block_mxfp4 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_MXFP4 / 2 * get_global_id(0); + global uchar * e = (global uchar *) dst_e + get_global_id(0); + + *e = b->e; + + for (int i = 0; i < QK_MXFP4 / 2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_convert_block_mxfp4_trans( + global struct block_mxfp4 * src0, + __global uint4 * dst_q, + __global uchar * dst_e, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_MXFP4; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_mxfp4 * b = src0 + src_blk_offset; + + dst_q[dst_blk_offset] = ((global uint4 *)(&(b->qs[0])))[0]; + dst_e[dst_blk_offset] = b->e; +} + +kernel void kernel_restore_block_mxfp4( + global uchar * src_q, + global half * src_e, + global struct block_mxfp4 * dst +) { + global struct block_mxfp4 * b = (global struct block_mxfp4 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_MXFP4 / 2 * get_global_id(0); + global uchar * e = (global uchar *) src_e + get_global_id(0); + + b->e = *e; + for (int i = 0; i < QK_MXFP4 / 2; ++i) { + b->qs[i] = q[i]; + } +} + +kernel void kernel_restore_block_mxfp4_trans( + __global uint4 * src_q, + __global uchar * src_e, + global struct block_mxfp4 * dst, + uint ne00, + uint ne01 +) { + int i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + uint ne00_blk = ne00 / QK_MXFP4; + uint src_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + + global struct block_mxfp4 * b = dst + dst_blk_offset; + + ((global uint4 *)(&(b->qs[0])))[0] = src_q[src_blk_offset]; + b->e = src_e[src_blk_offset]; +} + +kernel void kernel_convert_block_mxfp4_trans4_ns( + global struct block_mxfp4 * src0, + __global uint * dst_q, + __global uchar * dst_e, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_MXFP4; + uint src_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint dst_blk_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + global struct block_mxfp4 * b = src0 + src_blk_offset; + dst_e[dst_blk_offset] = b->e; + + // extract quantization and unshuffle + ushort8 pre_block = ((global ushort8 *)(&(b->qs[0])))[0]; + + ushort8 post_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK_MXFP4 / 4; ++i) { + uchar x0 = pre_block_ptr[2*i + 0]; + uchar x1 = pre_block_ptr[2*i + 1]; + + post_block_ptr[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + post_block_ptr[i + QK_MXFP4 / 4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + uint4 q_block = as_uint4(post_block); + + uint offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + dst_q[offset] = q_block.x; + dst_q[offset + ne01] = q_block.y; + dst_q[offset + ne01 * 2] = q_block.z; + dst_q[offset + ne01 * 3] = q_block.w; +} + +kernel void kernel_restore_block_mxfp4_trans4_ns( + __global uint * src_q, + __global uchar * src_e, + __global struct block_mxfp4 * dst0, + uint ne00, + uint ne01 +) { + uint i00 = get_global_id(1); + uint i01 = get_global_id(0); + uint i02 = get_global_id(2); + + if (i01 >= ne01) { + return; + } + + uint ne00_blk = ne00 / QK_MXFP4; + uint dst_blk_offset = i00 + i01 * ne00_blk + i02 * ne00_blk * ne01; + uint src_d_offset = i01 + i00 * ne01 + i02 * ne00_blk * ne01; + + __global struct block_mxfp4 * b = dst0 + dst_blk_offset; + b->e = src_e[src_d_offset]; + + // collect transposed quantization parts for a block + uint src_q_offset = i02 * ne00_blk * ne01 * 4 + i00 * ne01 * 4 + i01; + uint4 q_block; + q_block.x = src_q[src_q_offset]; + q_block.y = src_q[src_q_offset + ne01]; + q_block.z = src_q[src_q_offset + ne01 * 2]; + q_block.w = src_q[src_q_offset + ne01 * 3]; + + ushort8 post_block = as_ushort8(q_block); + ushort8 pre_block = (ushort8)(0); + + uchar * pre_block_ptr = (uchar *)(&pre_block); + uchar * post_block_ptr = (uchar *)(&post_block); + + for (int i = 0; i < QK_MXFP4 / 4; ++i) { + uchar x0 = post_block_ptr[i + 0]; + uchar x1 = post_block_ptr[i + QK_MXFP4 / 4]; + + pre_block_ptr[2 * i + 0] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + pre_block_ptr[2 * i + 1] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + } + + ((__global ushort8 *)(&(b->qs[0])))[0] = pre_block; +} + + +//------------------------------------------------------------------------------ +// block_q8_0 +//------------------------------------------------------------------------------ +typedef struct { + half d; // delta + char qs[QK8_0]; // quants +} block_q8_0; + +kernel void kernel_convert_block_q8_0( + global block_q8_0 * src0, + global uchar * dst_q, + global half * dst_d +) { + global block_q8_0 * b = (global block_q8_0 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK8_0*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK8_0; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q8_0( + global uchar * src_q, + global half * src_d, + global block_q8_0 * dst +) { + global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK8_0*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK8_0; ++i) { + b->qs[i] = q[i]; + } +} + +kernel void kernel_restore_block_q8_0_trans( + global uchar * src_q, + global half * src_d, + global block_q8_0 * dst, + uint ne00, + uint ne01 +){ + uint num_blk_per_row = ne00 / QK8_0; + + global block_q8_0 * b = (global block_q8_0 *) dst + get_global_id(0) * num_blk_per_row; + global uchar * q = (global uchar *) src_q + get_global_id(0) * 4; // 4 8-bit packed + global half * d = (global half *) src_d + get_global_id(0); + + for (uint blk = 0; blk < num_blk_per_row; blk++) { + b->d = *d; + + for (uint i = 0; i < QK8_0; i+=4) { + b->qs[i] = q[0]; + b->qs[i+1] = q[1]; + b->qs[i+2] = q[2]; + b->qs[i+3] = q[3]; + + q += 4 * ne01; // M stride + } + + d += ne01; + + b++; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_K +// Convert the block_q4_K format to 4 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +// Each thread processes a super block. +// Mask args are just to keep the signature consistent with the no-shuffle +// version and they are not used in this kernel. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_K( + global struct block_q4_K * src0, + global uchar * dst_q, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K/2; ++i) { + q[i] = b->q[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +// Restore block_q4_K from flattened arrays. +// Each thread processes a super block. +// Mask args are just to keep the signature consistent with the no-shuffle ones. +kernel void kernel_restore_block_q4_K( + global uchar * src_q, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q4_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K/2; ++i) { + b->q[i] = q[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + +kernel void kernel_convert_block_q4_K_noshuffle( + global struct block_q4_K * src0, + global uchar * dst_q, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2 * get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->q[i*32 + 2*j]; + uchar x1 = b->q[i*32 + 2*j + 1]; + q[i*32 + j] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q4_K_noshuffle( + global uchar * src_q, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q4_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_K * b = (global struct block_q4_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2 * get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = q[i*32 + j]; + uchar hi = q[i*32 + j + 16]; + b->q[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->q[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q5_K +// Convert the block_q5_K format to 5 separate arrays (AOS -> SOA). +// Each thread processes a super block. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q5_K( + global struct block_q5_K * src0, + global uchar * dst_q, + global uchar * dst_qh, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/8*get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K/2; ++i) { + q[i] = b->qs[i]; + } + for (int i = 0; i < QK_K/8; ++i) { + qh[i] = b->qh[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +// Restore block_q5_K from flattened arrays. +// Each thread processes a super block. +kernel void kernel_restore_block_q5_K( + global uchar * src_q, + global uchar * src_qh, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q5_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/8*get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K/2; ++i) { + b->qs[i] = q[i]; + } + for (int i = 0; i < QK_K/8; ++i) { + b->qh[i] = qh[i]; + } + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + +kernel void kernel_convert_block_q5_K_noshuffle( + global struct block_q5_K * src0, + global uchar * dst_q, + global uchar * dst_qh, + global uchar * dst_s, + global half * dst_d, + global half * dst_dm, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK_K/2 * get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/8 * get_global_id(0); + global uchar * s = (global uchar *) dst_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * dm = (global half *) dst_dm + get_global_id(0); + + *d = b->d; + *dm = b->dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar x0 = b->qs[i*32 + 2*j]; + uchar x1 = b->qs[i*32 + 2*j + 1]; + q[i*32 + j] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i*32 + j + 16] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } + } + + for (int l = 0; l < QK_K/8; ++l) { + uchar x0 = 0; + for (int i = 0; i < 8; ++i) { + x0 |= ((b->qh[(l%4)*8+i] >> (l/4)) & 0x01) << i; + } + qh[l] = x0; + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + s[i] = b->s[i]; + } +} + +kernel void kernel_restore_block_q5_K_noshuffle( + global uchar * src_q, + global uchar * src_qh, + global uchar * src_s, + global half * src_d, + global half * src_dm, + global struct block_q5_K * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q5_K * b = (global struct block_q5_K *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK_K/2 * get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/8 * get_global_id(0); + global uchar * s = (global uchar *) src_s + K_SCALE_SIZE * get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * dm = (global half *) src_dm + get_global_id(0); + + b->d = *d; + b->dm = *dm; + + for (int i = 0; i < QK_K / 64; ++i) { + for (int j = 0; j < 16; ++j) { + uchar lo = q[i*32 + j]; + uchar hi = q[i*32 + j + 16]; + b->qs[i*32 + 2*j] = convert_uchar((lo & mask_0F) | ((hi & mask_0F) << 4)); + b->qs[i*32 + 2*j + 1] = convert_uchar(((lo & mask_F0) >> 4) | (hi & mask_F0)); + } + } + + for (int g = 0; g < 4; ++g) { + for (int i = 0; i < 8; ++i) { + uchar x0 = 0; + for (int k = 0; k < 8; ++k) { + x0 |= ((qh[4*k+g] >> i) & 0x01) << k; + } + b->qh[g*8+i] = x0; + } + } + + for (int i = 0; i < K_SCALE_SIZE; ++i) { + b->s[i] = s[i]; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_q6_K +// Convert the block_q6_K format to 3 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +// Each thread processes a super block. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q6_K( + global struct block_q6_K * src0, + global uchar * dst_ql, + global uchar * dst_qh, + global char * dst_s, + global half * dst_d, + uchar mask_lsb_8, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0); + global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) dst_s + QK_K/16*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK_K/2; ++i) { + ql[i] = b->ql[i]; + } + for (int i = 0; i < QK_K/4; ++i) { + qh[i] = b->qh[i]; + } + for (int i = 0; i < QK_K/16; ++i) { + s[i] = b->scales[i]; + } +} + +// Restore block_q6_K from flattened arrays. +// Each thread processes a super block. +kernel void kernel_restore_block_q6_K( + global uchar * dst_ql, + global uchar * dst_qh, + global char * dst_s, + global half * dst_d, + global struct block_q6_K * dst, + uchar mask_lsb_8, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0); + global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) dst_s + QK_K/16*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK_K/2; ++i) { + b->ql[i] = ql[i]; + } + for (int i = 0; i < QK_K/4; ++i) { + b->qh[i] = qh[i]; + } + for (int i = 0; i < QK_K/16; ++i) { + b->scales[i] = s[i]; + } +} + +kernel void kernel_convert_block_q6_K_noshuffle( + global struct block_q6_K * src0, + global uchar * dst_ql, + global uchar * dst_qh, + global char * dst_s, + global half * dst_d, + uchar mask_lsb_8, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_q6_K * b = (global struct block_q6_K *) src0 + get_global_id(0); + global uchar * ql = (global uchar *) dst_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) dst_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) dst_s + QK_K/16*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK_K/2/4; ++i) { + uchar x0 = b->ql[i*2 + 0] & mask_lsb_8; + uchar x1 = b->ql[i*2 + 1] & mask_lsb_8; + ql[i + 0] = (x0 & 0x0F) | ((x1 & 0x0F) << 4); + ql[i + 32] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0); + + uchar x2 = b->ql[i*2 + 0 + 64] & mask_lsb_8; + uchar x3 = b->ql[i*2 + 1 + 64] & mask_lsb_8; + ql[i + 64] = (x2 & 0x0F) | ((x3 & 0x0F) << 4); + ql[i + 96] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0); + } + + for (int i = 0; i < QK_K/4/8; ++i) { + uchar x0 = b->qh[i*4 + 0] & mask_lsb_8; + uchar x1 = b->qh[i*4 + 1] & mask_lsb_8; + uchar x2 = b->qh[i*4 + 2] & mask_lsb_8; + uchar x3 = b->qh[i*4 + 3] & mask_lsb_8; + qh[i + 0] = (x0 & 0x03) | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6); + qh[i + 8] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C) | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4); + qh[i + 16] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30) | ((x3 & 0x30) << 2); + qh[i + 24] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0); + + uchar x4 = b->qh[i*4 + 0 + 32] & mask_lsb_8; + uchar x5 = b->qh[i*4 + 1 + 32] & mask_lsb_8; + uchar x6 = b->qh[i*4 + 2 + 32] & mask_lsb_8; + uchar x7 = b->qh[i*4 + 3 + 32] & mask_lsb_8; + qh[i + 32] = (x4 & 0x03) | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6); + qh[i + 40] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C) | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4); + qh[i + 48] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30) | ((x7 & 0x30) << 2); + qh[i + 56] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0); + } + + for (int i = 0; i < QK_K/16; ++i) { + s[i] = b->scales[i]; + } +} + +kernel void kernel_restore_block_q6_K_noshuffle( + global uchar * src_ql, + global uchar * src_qh, + global char * src_s, + global half * src_d, + global struct block_q6_K * dst, + uchar mask_lsb_8, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_q6_K * b = (global struct block_q6_K *) dst + get_global_id(0); + global uchar * ql = (global uchar *) src_ql + QK_K/2*get_global_id(0); + global uchar * qh = (global uchar *) src_qh + QK_K/4*get_global_id(0); + global char * s = (global char *) src_s + QK_K/16*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK_K/2/4; ++i) { + uchar x0 = ql[i + 0] & mask_lsb_8; + uchar x1 = ql[i + 32] & mask_lsb_8; + b->ql[i*2 + 0] = (x0 & 0x0F) | ((x1 & 0x0F) << 4); + b->ql[i*2 + 1] = ((x0 & 0xF0) >> 4) | (x1 & 0xF0); + + uchar x2 = ql[i + 64] & mask_lsb_8; + uchar x3 = ql[i + 96] & mask_lsb_8; + b->ql[i*2 + 0 + 64] = (x2 & 0x0F) | ((x3 & 0x0F) << 4); + b->ql[i*2 + 1 + 64] = ((x2 & 0xF0) >> 4) | (x3 & 0xF0); + } + + for (int i = 0; i < QK_K/4/8; ++i) { + uchar x0 = qh[i + 0] & mask_lsb_8; + uchar x1 = qh[i + 8] & mask_lsb_8; + uchar x2 = qh[i + 16] & mask_lsb_8; + uchar x3 = qh[i + 24] & mask_lsb_8; + b->qh[i*4 + 0] = (x0 & 0x03) | ((x1 & 0x03) << 2) | ((x2 & 0x03) << 4) | ((x3 & 0x03) << 6); + b->qh[i*4 + 1] = ((x0 & 0x0C) >> 2) | (x1 & 0x0C) | ((x2 & 0x0C) << 2) | ((x3 & 0x0C) << 4); + b->qh[i*4 + 2] = ((x0 & 0x30) >> 4) | ((x1 & 0x30) >> 2) | (x2 & 0x30) | ((x3 & 0x30) << 2); + b->qh[i*4 + 3] = ((x0 & 0xC0) >> 6) | ((x1 & 0xC0) >> 4) | ((x2 & 0xC0) >> 2) | (x3 & 0xC0); + + uchar x4 = qh[i + 0 + 32] & mask_lsb_8; + uchar x5 = qh[i + 8 + 32] & mask_lsb_8; + uchar x6 = qh[i + 16 + 32] & mask_lsb_8; + uchar x7 = qh[i + 24 + 32] & mask_lsb_8; + b->qh[i*4 + 0 + 32] = (x4 & 0x03) | ((x5 & 0x03) << 2) | ((x6 & 0x03) << 4) | ((x7 & 0x03) << 6); + b->qh[i*4 + 1 + 32] = ((x4 & 0x0C) >> 2) | (x5 & 0x0C) | ((x6 & 0x0C) << 2) | ((x7 & 0x0C) << 4); + b->qh[i*4 + 2 + 32] = ((x4 & 0x30) >> 4) | ((x5 & 0x30) >> 2) | (x6 & 0x30) | ((x7 & 0x30) << 2); + b->qh[i*4 + 3 + 32] = ((x4 & 0xC0) >> 6) | ((x5 & 0xC0) >> 4) | ((x6 & 0xC0) >> 2) | (x7 & 0xC0); + } + + for (int i = 0; i < QK_K/16; ++i) { + b->scales[i] = s[i]; + } +} + +//------------------------------------------------------------------------------ +// kernel_convert_block_iq4_nl +// Convert the block_iq4_nl format to 2 separate arrays (AOS -> SOA). +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_iq4_nl( + global struct block_iq4_nl * src0, + global uchar * dst_q, + global half * dst_d, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + + for (int i = 0; i < QK4_NL/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_iq4_nl( + global uchar * src_q, + global half * src_d, + global struct block_iq4_nl * dst, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + + for (int i = 0; i < QK4_NL/2; ++i) { + b->qs[i] = q[i]; + } +} + +kernel void kernel_convert_block_iq4_nl_noshuffle( + global struct block_iq4_nl * src0, + global uchar * dst_q, + global half * dst_d, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + + *d = b->d; + for (int i = 0; i < QK4_NL/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & mask_0F) | convert_uchar((x1 & mask_0F) << 4); + q[i + QK4_NL/4] = convert_uchar((x0 & mask_F0) >> 4) | convert_uchar(x1 & mask_F0); + } +} + +kernel void kernel_restore_block_iq4_nl_noshuffle( + global uchar * src_q, + global half * src_d, + global struct block_iq4_nl * dst, + uchar mask_0F, + uchar mask_F0, + ulong n_blk +) { + if (get_global_id(0) >= n_blk) { + return; + } + global struct block_iq4_nl * b = (global struct block_iq4_nl *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_NL/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + + b->d = *d; + for (int i = 0; i < QK4_NL/4; ++i) { + uchar x0 = q[i + 0 ]; + uchar x1 = q[i + QK4_NL/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); } } diff --git a/ggml/src/ggml-opencl/kernels/diag.cl b/ggml/src/ggml-opencl/kernels/diag.cl new file mode 100644 index 00000000000..884efa08fdd --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/diag.cl @@ -0,0 +1,27 @@ +kernel void kernel_diag_f32( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + ulong nb0, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + global const float * src0_ptr = (global const float *)(src0 + i2*nb02 + i3*nb03); + global float * dst_ptr = (global float *)(dst + i1*nb01 + i2*nb2 + i3*nb3); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f; + } +} diff --git a/ggml/src/ggml-opencl/kernels/exp.cl b/ggml/src/ggml-opencl/kernels/exp.cl new file mode 100644 index 00000000000..a2458b6579c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/exp.cl @@ -0,0 +1,125 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_exp_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = exp(*x); + } +} + +kernel void kernel_exp_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = exp(*x); + } +} diff --git a/ggml/src/ggml-opencl/kernels/expm1.cl b/ggml/src/ggml-opencl/kernels/expm1.cl index 126298a2cdb..05442ac2043 100644 --- a/ggml/src/ggml-opencl/kernels/expm1.cl +++ b/ggml/src/ggml-opencl/kernels/expm1.cl @@ -3,80 +3,111 @@ //------------------------------------------------------------------------------ // expm1 //------------------------------------------------------------------------------ -kernel void kernel_expm1_f32_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, + +kernel void kernel_expm1_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f; +} + +kernel void kernel_expm1_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f; +} + +kernel void kernel_expm1_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h; +} + +kernel void kernel_expm1_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h; +} + +kernel void kernel_expm1_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = exp(*src_val_ptr) - 1; - } + *y = exp(*x) - 1.0f; } } -kernel void kernel_expm1_f16_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, +kernel void kernel_expm1_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = exp(*src_val_ptr) - 1; - } + *y = exp(*x) - 1.0f; } } diff --git a/ggml/src/ggml-opencl/kernels/gated_delta_net.cl b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl new file mode 100644 index 00000000000..319c9829529 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gated_delta_net.cl @@ -0,0 +1,249 @@ +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifndef S_V +#define S_V 128 +#endif +#ifndef KDA +#define KDA 0 +#endif +#ifndef SUBGROUP_SIZE +#define SUBGROUP_SIZE 64 +#endif +#ifndef LANES_PER_COLUMN +#define LANES_PER_COLUMN 8 +#endif +#ifndef COLS_PER_LANE_GROUP +#define COLS_PER_LANE_GROUP 1 +#endif +#ifndef SUBGROUPS_PER_WG +#define SUBGROUPS_PER_WG 1 +#endif +#ifndef USE_QCOM_SUBGROUP_SHUFFLE +#define USE_QCOM_SUBGROUP_SHUFFLE 0 +#endif + +#define WG_SIZE (SUBGROUP_SIZE * SUBGROUPS_PER_WG) +#define LANE_GROUPS_PER_SG (SUBGROUP_SIZE / LANES_PER_COLUMN) +#define COLS_PER_SG (LANE_GROUPS_PER_SG * COLS_PER_LANE_GROUP) +#define COLS_PER_WG (SUBGROUPS_PER_WG * COLS_PER_SG) +#define ROWS_PER_LANE (S_V / LANES_PER_COLUMN) + +#if USE_QCOM_SUBGROUP_SHUFFLE +#pragma OPENCL EXTENSION cl_qcom_subgroup_shuffle : enable +#endif + +// XOR-based parallel sum +// This does a reduction across groups of LANES_PER_COLUMN +static inline float reduce_add_shmem(float partial, __local float * temp, uint lane) { +#if USE_QCOM_SUBGROUP_SHUFFLE + #pragma unroll + for (uint s = LANES_PER_COLUMN / 2u; s > 0u; s >>= 1u) { + partial += qcom_sub_group_shuffle_xor(partial, s, CLK_SUB_GROUP_SHUFFLE_WIDTH_WAVE_SIZE_QCOM, partial); + } + return partial; +#else + temp[lane] = partial; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + #pragma unroll + for (uint s = LANES_PER_COLUMN / 2u; s > 0u; s >>= 1u) { + float other = temp[lane ^ s]; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + temp[lane] += other; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + } + const float result = temp[lane]; + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + return result; +#endif +} + +#define REDUCE_PARTIAL(partial, temp_ptr, lid) \ + ((LANES_PER_COLUMN == 1u) ? (partial) : reduce_add_shmem((partial), (temp_ptr), (lid))) + +// force compiler to optimize kernel for a specific fixed work-group size +__attribute__((reqd_work_group_size(WG_SIZE, 1, 1))) +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gated_delta_net( + global const char * q_buf, ulong off_q, + global const char * k_buf, ulong off_k, + global const char * v_buf, ulong off_v, + global const char * g_buf, ulong off_g, + global const char * beta_buf, ulong off_beta, + global const char * state_buf, ulong off_state, + global char * dst_buf, ulong off_dst, + uint H_v, + uint n_tokens, + uint n_seqs, + uint s_off, + uint sq1, uint sq2, uint sq3, + uint sv1, uint sv2, uint sv3, + uint sb1, uint sb2, uint sb3, + uint H_k, + uint rq3, + float scale, + uint K) { + + global const float * data_q = (global const float *)(q_buf + off_q); + global const float * data_k = (global const float *)(k_buf + off_k); + global const float * data_v = (global const float *)(v_buf + off_v); + global const float * data_g = (global const float *)(g_buf + off_g); + global const float * data_beta = (global const float *)(beta_buf + off_beta); + global const float * data_state = (global const float *)(state_buf + off_state); + global float * data_dst = (global float *)(dst_buf + off_dst); + + const uint head_id = get_group_id(0); + const uint seq_id = get_group_id(1); + const uint tid = (uint)get_local_id(0); + + const uint sg_id = get_sub_group_id(); // subgroup id + const uint sg_lid = get_sub_group_local_id(); // subgroup lane id + + const uint lane = sg_lid % LANES_PER_COLUMN; + const uint lane_group = sg_lid / LANES_PER_COLUMN; + const uint wg_col_base = get_group_id(2) * COLS_PER_WG; + const uint sg_col_base = wg_col_base + sg_id * COLS_PER_SG; + + const uint iq1 = head_id % H_k; // head index for Q and K + const uint iq3 = seq_id / rq3; // seq index for Q and K + + const uint state_size = S_V * S_V; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + const uint state_base = (seq_id * H_v + head_id) * state_size; + const uint q_off_base = iq3 * sq3 + iq1 * sq1; + const uint v_off_base = seq_id * sv3 + head_id * sv1; + const uint gb_off_base = seq_id * sb3 + head_id * sb1; + const uint state_out_base = (seq_id * H_v + head_id) * state_size; + const uint state_size_per_snap = state_size * H_v * n_seqs; + + __local float reduce_temp[WG_SIZE]; + __local float * temp_ptr = reduce_temp + sg_id * SUBGROUP_SIZE; + + float s_shard[COLS_PER_LANE_GROUP][ROWS_PER_LANE]; + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[cg][r] = data_state[state_base + col * S_V + r * LANES_PER_COLUMN + lane]; + } + } + + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. + uint attn_off = (seq_id * n_tokens * H_v + head_id) * S_V; + + for (uint t = 0; t < n_tokens; t++) { + const uint q_off = q_off_base + t * sq2; + const uint k_off = q_off; + const uint v_off = v_off_base + t * sv2; + const uint gb_off = gb_off_base + t * sb2; + const float beta_val = data_beta[gb_off]; + + float k_reg[ROWS_PER_LANE]; + float q_reg[ROWS_PER_LANE]; +#if KDA + float g_exp[ROWS_PER_LANE]; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = data_k[k_off + i]; + q_reg[r] = data_q[q_off + i]; + g_exp[r] = exp(data_g[gb_off * S_V + i]); + } +#else + const float g_val = exp(data_g[gb_off]); + + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = data_k[k_off + i]; + q_reg[r] = data_q[q_off + i]; + } +#endif + + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + float v_val = data_v[v_off + col]; + + float kv_shard = 0.0f; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { +#if KDA + float gs = g_exp[r] * s_shard[cg][r]; + kv_shard += gs * k_reg[r]; +#else + kv_shard += s_shard[cg][r] * k_reg[r]; +#endif + } + +#if !KDA + kv_shard *= g_val; // Applied once instead of ROWS_PER_LANE times +#endif + + const float kv_col = REDUCE_PARTIAL(kv_shard, temp_ptr, sg_lid); + + const float delta_col = (v_val - kv_col) * beta_val; + + float attn_partial = 0.0f; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { +#if KDA + float gs = g_exp[r] * s_shard[cg][r]; +#else + float gs = g_val * s_shard[cg][r]; +#endif + s_shard[cg][r] = gs + k_reg[r] * delta_col; + attn_partial += s_shard[cg][r] * q_reg[r]; + } + const float attn_col = REDUCE_PARTIAL(attn_partial, temp_ptr, sg_lid); + + if (lane == 0) { + data_dst[attn_off + col] = attn_col * scale; + } + } + attn_off += S_V * H_v; + + if (K > 1u) { + const int target_slot = (int)n_tokens - 1 - (int)t; + if (target_slot >= 0 && target_slot < (int)K) { + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + const uint slot_base = s_off + (uint)target_slot * state_size_per_snap + state_out_base; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[cg][r]; + } + } + } + } + } + + if (K == 1u) { + #pragma unroll + for (uint cg = 0; cg < COLS_PER_LANE_GROUP; cg++) { + const uint col = sg_col_base + cg * LANE_GROUPS_PER_SG + lane_group; + #pragma unroll + for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[cg][r]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl new file mode 100644 index 00000000000..02cdbdd9fb1 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl @@ -0,0 +1,306 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +static inline half e8m0_to_fp16(uchar x) { + ushort bits; + bits = (ushort)(x) - (ushort)(112); + bits = ((bits & 0x00E0) != 0) ? 0x7C00 : (bits << 10); + return as_half(bits); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_mxfp4_f32_ns( + __read_only image1d_buffer_t src0_q, + __global uchar * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale for current mxfp4 block + uint s_offset = s_sub_offset + get_global_id(0); + float s = e8m0_to_fp32(src0_d[s_offset]); + + // Load 16 fp4 (64-bits) in transposed layout + uint2 mxfp4x16; + mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s; + reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s; + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 fp4 (64-bits) in transposed layout + mxfp4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + mxfp4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + reg_a.lo = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.lo)) * s; + reg_a.hi = mxfp4_to_fp16_packed8(as_ushort2(mxfp4x16.hi)) * s; + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl new file mode 100644 index 00000000000..d403ed0cab1 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl @@ -0,0 +1,256 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q4_0(q4, a_f16, scale) \ + a_f16.s0 = (half)((q4.s0 & 0x000F) - 8) * scale; \ + a_f16.s1 = (half)(((q4.s0 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.s2 = (half)(((q4.s0 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.s3 = (half)(((q4.s0 & 0xF000) >> 12) - 8) * scale; \ + a_f16.s4 = (half)((q4.s1 & 0x000F) - 8) * scale; \ + a_f16.s5 = (half)(((q4.s1 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.s6 = (half)(((q4.s1 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.s7 = (half)(((q4.s1 & 0xF000) >> 12) - 8) * scale; \ + a_f16.s8 = (half)((q4.s2 & 0x000F) - 8) * scale; \ + a_f16.s9 = (half)(((q4.s2 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.sa = (half)(((q4.s2 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.sb = (half)(((q4.s2 & 0xF000) >> 12) - 8) * scale; \ + a_f16.sc = (half)((q4.s3 & 0x000F) - 8) * scale; \ + a_f16.sd = (half)(((q4.s3 & 0x00F0) >> 4) - 8) * scale; \ + a_f16.se = (half)(((q4.s3 & 0x0F00) >> 8) - 8) * scale; \ + a_f16.sf = (half)(((q4.s3 & 0xF000) >> 12) - 8) * scale; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q4_0_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale for current Q4_0 block + uint s_offset = s_sub_offset + get_global_id(0); + half s = src0_d[s_offset]; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_0(as_ushort4(q4x16), reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 q (64-bits) in transposed layout + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_0(as_ushort4(q4x16), reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl new file mode 100644 index 00000000000..b2bddf3f73a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl @@ -0,0 +1,258 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q4_1(q4, a_f16, scale, m) \ + a_f16.s0 = (half)(q4.s0 & 0x000F) * scale + m; \ + a_f16.s1 = (half)((q4.s0 & 0x00F0) >> 4) * scale + m; \ + a_f16.s2 = (half)((q4.s0 & 0x0F00) >> 8) * scale + m; \ + a_f16.s3 = (half)((q4.s0 & 0xF000) >> 12) * scale + m; \ + a_f16.s4 = (half)(q4.s1 & 0x000F) * scale + m; \ + a_f16.s5 = (half)((q4.s1 & 0x00F0) >> 4) * scale + m; \ + a_f16.s6 = (half)((q4.s1 & 0x0F00) >> 8) * scale + m; \ + a_f16.s7 = (half)((q4.s1 & 0xF000) >> 12) * scale + m; \ + a_f16.s8 = (half)(q4.s2 & 0x000F) * scale + m; \ + a_f16.s9 = (half)((q4.s2 & 0x00F0) >> 4) * scale + m; \ + a_f16.sa = (half)((q4.s2 & 0x0F00) >> 8) * scale + m; \ + a_f16.sb = (half)((q4.s2 & 0xF000) >> 12) * scale + m; \ + a_f16.sc = (half)(q4.s3 & 0x000F) * scale + m; \ + a_f16.sd = (half)((q4.s3 & 0x00F0) >> 4) * scale + m; \ + a_f16.se = (half)((q4.s3 & 0x0F00) >> 8) * scale + m; \ + a_f16.sf = (half)((q4.s3 & 0xF000) >> 12) * scale + m; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q4_1_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale and m for current Q4_1 block + uint sm_offset = s_sub_offset + get_global_id(0); + half s = src0_d[sm_offset]; + half m = src0_m[sm_offset]; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 q (64-bits) in transposed layout + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_1(as_ushort4(q4x16), reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl new file mode 100644 index 00000000000..ab8228d18ca --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl @@ -0,0 +1,283 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +#define dequantize_q4_k(q4, a_f16, scale, minv) \ + a_f16.s0 = (half)((float)(q4.s0 & 0x000F) * scale - minv); \ + a_f16.s1 = (half)((float)((q4.s0 & 0x00F0) >> 4) * scale - minv); \ + a_f16.s2 = (half)((float)((q4.s0 & 0x0F00) >> 8) * scale - minv); \ + a_f16.s3 = (half)((float)((q4.s0 & 0xF000) >> 12) * scale - minv); \ + a_f16.s4 = (half)((float)(q4.s1 & 0x000F) * scale - minv); \ + a_f16.s5 = (half)((float)((q4.s1 & 0x00F0) >> 4) * scale - minv); \ + a_f16.s6 = (half)((float)((q4.s1 & 0x0F00) >> 8) * scale - minv); \ + a_f16.s7 = (half)((float)((q4.s1 & 0xF000) >> 12) * scale - minv); \ + a_f16.s8 = (half)((float)(q4.s2 & 0x000F) * scale - minv); \ + a_f16.s9 = (half)((float)((q4.s2 & 0x00F0) >> 4) * scale - minv); \ + a_f16.sa = (half)((float)((q4.s2 & 0x0F00) >> 8) * scale - minv); \ + a_f16.sb = (half)((float)((q4.s2 & 0xF000) >> 12) * scale - minv); \ + a_f16.sc = (half)((float)(q4.s3 & 0x000F) * scale - minv); \ + a_f16.sd = (half)((float)((q4.s3 & 0x00F0) >> 4) * scale - minv); \ + a_f16.se = (half)((float)((q4.s3 & 0x0F00) >> 8) * scale - minv); \ + a_f16.sf = (half)((float)((q4.s3 & 0xF000) >> 12) * scale - minv); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q4_k_f32_ns( + __read_only image1d_buffer_t src0_q, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * K_SCALE_SIZE; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; + uint sb = sub / 8; + uint j = sub % 8; + + // Load d and dm for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + half dm_val = src0_dm[d_offset]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = (float)dm_val * (float)mn; + + // First sub-block (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half (next 16 elements, same sub-block scale) + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q4_k(as_ushort4(q4x16), reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl new file mode 100644 index 00000000000..d1a35d58bb2 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl @@ -0,0 +1,260 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q5_0(qs5x16, qh5x16, a_f16, scale) \ + a_f16.s0 = (half)((( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) - 16) * scale; \ + a_f16.s1 = (half)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) - 16) * scale; \ + a_f16.s2 = (half)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) - 16) * scale; \ + a_f16.s3 = (half)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) - 16) * scale; \ + a_f16.s4 = (half)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) - 16) * scale; \ + a_f16.s5 = (half)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) - 16) * scale; \ + a_f16.s6 = (half)((((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) - 16) * scale; \ + a_f16.s7 = (half)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) - 16) * scale; \ + a_f16.s8 = (half)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) - 16) * scale; \ + a_f16.s9 = (half)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) - 16) * scale; \ + a_f16.sa = (half)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) - 16) * scale; \ + a_f16.sb = (half)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) - 16) * scale; \ + a_f16.sc = (half)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) - 16) * scale; \ + a_f16.sd = (half)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) - 16) * scale; \ + a_f16.se = (half)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) - 16) * scale; \ + a_f16.sf = (half)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) - 16) * scale; \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q5_0_f32_ns( + __read_only image1d_buffer_t src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale for current Q5_0 block + uint blk_offset = s_sub_offset + get_global_id(0); + half s = src0_d[blk_offset]; + + // Load 32 qh (5-th bit of each Q5) for the entire block + uchar4 qhx32 = as_uchar4(src0_qh[blk_offset]); + + // Load 16 qs (half block) in transposed layout + uint2 qsx16; + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_0(as_ushort4(qsx16), qhx32.lo, reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 qs in transposed layout + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_0(as_ushort4(qsx16), qhx32.hi, reg_a, s); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl new file mode 100644 index 00000000000..90d345ecf51 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl @@ -0,0 +1,262 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 + + +#define dequantize_q5_1(qs5x16, qh5x16, a_f16, scale, m) \ + a_f16.s0 = (half)((( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) * scale + m); \ + a_f16.s1 = (half)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) * scale + m); \ + a_f16.s2 = (half)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) * scale + m); \ + a_f16.s3 = (half)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) * scale + m); \ + a_f16.s4 = (half)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) * scale + m); \ + a_f16.s5 = (half)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) * scale + m); \ + a_f16.s6 = (half)((((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.s7 = (half)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) * scale + m); \ + a_f16.s8 = (half)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) * scale + m); \ + a_f16.s9 = (half)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) * scale + m); \ + a_f16.sa = (half)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) * scale + m); \ + a_f16.sb = (half)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) * scale + m); \ + a_f16.sc = (half)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) * scale + m); \ + a_f16.sd = (half)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) * scale + m); \ + a_f16.se = (half)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.sf = (half)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) * scale + m); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) // 1=force single 2=force pair +kernel void kernel_gemm_moe_q5_1_f32_ns( + __read_only image1d_buffer_t src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + // Loop along K axis, 32 elements (one block) for each iteration, divided into 2 sub-blocks + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + // First sub-block + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint s_sub_offset = row + ((ne01 * step) >> 5) + ((expert_id * ne00 * ne01) >> 5); + uint b_sub_offset = col * ne00 + step; + + // Load scale and m for current Q5_1 block + uint blk_offset = s_sub_offset + get_global_id(0); + half s = src0_d[blk_offset]; + half m = src0_m[blk_offset]; + + // Load 32 qh (5-th bit of each Q5) for the entire block + uchar4 qhx32 = as_uchar4(src0_qh[blk_offset]); + + // Load 16 qs (half block) in transposed layout + uint2 qsx16; + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_1(as_ushort4(qsx16), qhx32.lo, reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 8 elements reduction for better precision + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Repeat for second sub-block + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + // Load next 16 qs in transposed layout + qsx16.x = read_imageui(src0_qs, q_sub_offset + sub_block_id_m).x; + qsx16.y = read_imageui(src0_qs, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B, each fiber out of 64 in a sub-group loads 8 elements + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + // Convert to half and store to LM to share within the subgroup + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_1(as_ushort4(qsx16), qhx32.hi, reg_a, s, m); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + // 32 16x16 fp16 dot product with 3-levels reduction for better precision + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load poster router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile, override correct result in the end + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl new file mode 100644 index 00000000000..13c26f6f3b6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl @@ -0,0 +1,288 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +#define dequantize_q5_k(qs5x16, qh5x16, a_f16, scale, m) \ + a_f16.s0 = (half)((float)(( qs5x16.s0 & 0x000F) | (( qh5x16.s0 & 0x01) << 4)) * scale + m); \ + a_f16.s1 = (half)((float)((((qs5x16.s0 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 1) & 0x01) << 4)) * scale + m)); \ + a_f16.s2 = (half)((float)((((qs5x16.s0 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 2) & 0x01) << 4)) * scale + m)); \ + a_f16.s3 = (half)((float)((((qs5x16.s0 & 0xF000) >> 12) | (((qh5x16.s0 >> 3) & 0x01) << 4)) * scale + m)); \ + a_f16.s4 = (half)((float)((( qs5x16.s1 & 0x000F) | (((qh5x16.s0 >> 4) & 0x01) << 4)) * scale + m)); \ + a_f16.s5 = (half)((float)((((qs5x16.s1 & 0x00F0) >> 4 ) | (((qh5x16.s0 >> 5) & 0x01) << 4)) * scale + m)); \ + a_f16.s6 = (half)((float)(((qs5x16.s1 & 0x0F00) >> 8 ) | (((qh5x16.s0 >> 6) & 0x01) << 4)) * scale + m); \ + a_f16.s7 = (half)((float)((((qs5x16.s1 & 0xF000) >> 12) | (((qh5x16.s0 >> 7) & 0x01) << 4)) * scale + m)); \ + a_f16.s8 = (half)((float)((( qs5x16.s2 & 0x000F) | (( qh5x16.s1 & 0x01) << 4)) * scale + m)); \ + a_f16.s9 = (half)((float)((((qs5x16.s2 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 1) & 0x01) << 4)) * scale + m)); \ + a_f16.sa = (half)((float)((((qs5x16.s2 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 2) & 0x01) << 4)) * scale + m)); \ + a_f16.sb = (half)((float)((((qs5x16.s2 & 0xF000) >> 12) | (((qh5x16.s1 >> 3) & 0x01) << 4)) * scale + m)); \ + a_f16.sc = (half)((float)((( qs5x16.s3 & 0x000F) | (((qh5x16.s1 >> 4) & 0x01) << 4)) * scale + m)); \ + a_f16.sd = (half)((float)((((qs5x16.s3 & 0x00F0) >> 4 ) | (((qh5x16.s1 >> 5) & 0x01) << 4)) * scale + m)); \ + a_f16.se = (half)((float)((((qs5x16.s3 & 0x0F00) >> 8 ) | (((qh5x16.s1 >> 6) & 0x01) << 4)) * scale + m)); \ + a_f16.sf = (half)((float)((((qs5x16.s3 & 0xF000) >> 12) | (((qh5x16.s1 >> 7) & 0x01) << 4)) * scale + m)); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q5_k_f32_ns( + __read_only image1d_buffer_t src0_q, + __global uint * src0_qh, + __global uchar * src0_s, + __global half * src0_d, + __global half * src0_dm, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * K_SCALE_SIZE; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; + uint sb = sub / 8; + uint j = sub % 8; + + // Load d and dm for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + half dm_val = src0_dm[d_offset]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = -(float)dm_val * (float)mn; + + // qh is stored at sub-block granularity + uint qh_offset = row + sub * ne01 + expert_id * num_superblocks * 8 * ne01 + get_global_id(0); + uchar4 qhx32 = as_uchar4(src0_qh[qh_offset]); + + // First sub-block (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 q (64-bits) in transposed layout + uint2 q4x16; + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantization + dequantize_q5_k(as_ushort4(q4x16), qhx32.lo, reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_q, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_q, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q5_k(as_ushort4(q4x16), qhx32.hi, reg_a, scale, minv); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl new file mode 100644 index 00000000000..85ccebec78c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl @@ -0,0 +1,267 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load: enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load: enable +#pragma OPENCL EXTENSION cl_qcom_extra_vector_types : enable + +#define TILESIZE_K 16 +#define TILESIZE_M 64 +#define TILESIZE_N 32 +#define QK_K 256 + +#define dequantize_q6_k(qs16, qh16, a_f16, scale) \ + a_f16.s0 = (half)(((float)(( qs16.s0 & 0x000F) | ((uint)(( qh16 ) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s1 = (half)(((float)((( qs16.s0 >> 4) & 0x000F) | ((uint)(( qh16 >> 2) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s2 = (half)(((float)((( qs16.s0 >> 8) & 0x000F) | ((uint)(( qh16 >> 4) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s3 = (half)(((float)((( qs16.s0 >>12) & 0x000F) | ((uint)(( qh16 >> 6) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s4 = (half)(((float)(( qs16.s1 & 0x000F) | ((uint)(( qh16 >> 8) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s5 = (half)(((float)((( qs16.s1 >> 4) & 0x000F) | ((uint)(( qh16 >> 10) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s6 = (half)(((float)((( qs16.s1 >> 8) & 0x000F) | ((uint)(( qh16 >> 12) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s7 = (half)(((float)((( qs16.s1 >>12) & 0x000F) | ((uint)(( qh16 >> 14) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s8 = (half)(((float)(( qs16.s2 & 0x000F) | ((uint)(( qh16 >> 16) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.s9 = (half)(((float)((( qs16.s2 >> 4) & 0x000F) | ((uint)(( qh16 >> 18) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sa = (half)(((float)((( qs16.s2 >> 8) & 0x000F) | ((uint)(( qh16 >> 20) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sb = (half)(((float)((( qs16.s2 >>12) & 0x000F) | ((uint)(( qh16 >> 22) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sc = (half)(((float)(( qs16.s3 & 0x000F) | ((uint)(( qh16 >> 24) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sd = (half)(((float)((( qs16.s3 >> 4) & 0x000F) | ((uint)(( qh16 >> 26) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.se = (half)(((float)((( qs16.s3 >> 8) & 0x000F) | ((uint)(( qh16 >> 28) & 0x3) << 4)) - 32.f) * scale); \ + a_f16.sf = (half)(((float)((( qs16.s3 >>12) & 0x000F) | ((uint)(( qh16 >> 30) & 0x3) << 4)) - 32.f) * scale); \ + + +#define dotx16_reduce8(a_reg, b_lm, c_reg, lm_offset) \ + acc.s0 = dot(a_reg.s0123, b_lm[lm_offset + 0]); \ + acc.s1 = dot(a_reg.s0123, b_lm[lm_offset + 1]); \ + acc.s2 = dot(a_reg.s0123, b_lm[lm_offset + 2]); \ + acc.s3 = dot(a_reg.s0123, b_lm[lm_offset + 3]); \ + acc.s4 = dot(a_reg.s0123, b_lm[lm_offset + 4]); \ + acc.s5 = dot(a_reg.s0123, b_lm[lm_offset + 5]); \ + acc.s6 = dot(a_reg.s0123, b_lm[lm_offset + 6]); \ + acc.s7 = dot(a_reg.s0123, b_lm[lm_offset + 7]); \ + acc.s8 = dot(a_reg.s0123, b_lm[lm_offset + 8]); \ + acc.s9 = dot(a_reg.s0123, b_lm[lm_offset + 9]); \ + acc.sa = dot(a_reg.s0123, b_lm[lm_offset + 10]); \ + acc.sb = dot(a_reg.s0123, b_lm[lm_offset + 11]); \ + acc.sc = dot(a_reg.s0123, b_lm[lm_offset + 12]); \ + acc.sd = dot(a_reg.s0123, b_lm[lm_offset + 13]); \ + acc.se = dot(a_reg.s0123, b_lm[lm_offset + 14]); \ + acc.sf = dot(a_reg.s0123, b_lm[lm_offset + 15]); \ + acc.s0 += dot(a_reg.s4567, b_lm[lm_offset + 32]); \ + acc.s1 += dot(a_reg.s4567, b_lm[lm_offset + 33]); \ + acc.s2 += dot(a_reg.s4567, b_lm[lm_offset + 34]); \ + acc.s3 += dot(a_reg.s4567, b_lm[lm_offset + 35]); \ + acc.s4 += dot(a_reg.s4567, b_lm[lm_offset + 36]); \ + acc.s5 += dot(a_reg.s4567, b_lm[lm_offset + 37]); \ + acc.s6 += dot(a_reg.s4567, b_lm[lm_offset + 38]); \ + acc.s7 += dot(a_reg.s4567, b_lm[lm_offset + 39]); \ + acc.s8 += dot(a_reg.s4567, b_lm[lm_offset + 40]); \ + acc.s9 += dot(a_reg.s4567, b_lm[lm_offset + 41]); \ + acc.sa += dot(a_reg.s4567, b_lm[lm_offset + 42]); \ + acc.sb += dot(a_reg.s4567, b_lm[lm_offset + 43]); \ + acc.sc += dot(a_reg.s4567, b_lm[lm_offset + 44]); \ + acc.sd += dot(a_reg.s4567, b_lm[lm_offset + 45]); \ + acc.se += dot(a_reg.s4567, b_lm[lm_offset + 46]); \ + acc.sf += dot(a_reg.s4567, b_lm[lm_offset + 47]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + acc.s0 = dot(a_reg.s89ab, b_lm[lm_offset + 64]); \ + acc.s1 = dot(a_reg.s89ab, b_lm[lm_offset + 65]); \ + acc.s2 = dot(a_reg.s89ab, b_lm[lm_offset + 66]); \ + acc.s3 = dot(a_reg.s89ab, b_lm[lm_offset + 67]); \ + acc.s4 = dot(a_reg.s89ab, b_lm[lm_offset + 68]); \ + acc.s5 = dot(a_reg.s89ab, b_lm[lm_offset + 69]); \ + acc.s6 = dot(a_reg.s89ab, b_lm[lm_offset + 70]); \ + acc.s7 = dot(a_reg.s89ab, b_lm[lm_offset + 71]); \ + acc.s8 = dot(a_reg.s89ab, b_lm[lm_offset + 72]); \ + acc.s9 = dot(a_reg.s89ab, b_lm[lm_offset + 73]); \ + acc.sa = dot(a_reg.s89ab, b_lm[lm_offset + 74]); \ + acc.sb = dot(a_reg.s89ab, b_lm[lm_offset + 75]); \ + acc.sc = dot(a_reg.s89ab, b_lm[lm_offset + 76]); \ + acc.sd = dot(a_reg.s89ab, b_lm[lm_offset + 77]); \ + acc.se = dot(a_reg.s89ab, b_lm[lm_offset + 78]); \ + acc.sf = dot(a_reg.s89ab, b_lm[lm_offset + 79]); \ + acc.s0 += dot(a_reg.scdef, b_lm[lm_offset + 96]); \ + acc.s1 += dot(a_reg.scdef, b_lm[lm_offset + 97]); \ + acc.s2 += dot(a_reg.scdef, b_lm[lm_offset + 98]); \ + acc.s3 += dot(a_reg.scdef, b_lm[lm_offset + 99]); \ + acc.s4 += dot(a_reg.scdef, b_lm[lm_offset + 100]); \ + acc.s5 += dot(a_reg.scdef, b_lm[lm_offset + 101]); \ + acc.s6 += dot(a_reg.scdef, b_lm[lm_offset + 102]); \ + acc.s7 += dot(a_reg.scdef, b_lm[lm_offset + 103]); \ + acc.s8 += dot(a_reg.scdef, b_lm[lm_offset + 104]); \ + acc.s9 += dot(a_reg.scdef, b_lm[lm_offset + 105]); \ + acc.sa += dot(a_reg.scdef, b_lm[lm_offset + 106]); \ + acc.sb += dot(a_reg.scdef, b_lm[lm_offset + 107]); \ + acc.sc += dot(a_reg.scdef, b_lm[lm_offset + 108]); \ + acc.sd += dot(a_reg.scdef, b_lm[lm_offset + 109]); \ + acc.se += dot(a_reg.scdef, b_lm[lm_offset + 110]); \ + acc.sf += dot(a_reg.scdef, b_lm[lm_offset + 111]); \ + c_reg.lo += convert_float8(acc.lo); \ + c_reg.hi += convert_float8(acc.hi); \ + + +__attribute__((qcom_wave_pair_mode(1))) +kernel void kernel_gemm_moe_q6_k_f32_ns( + __read_only image1d_buffer_t src0_ql, + __global uint * src0_qh, + __global char * src0_s, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global ushort * src2_emap, + __write_only image1d_buffer_t dst, + __global int * total_tiles, + uint ne00, + uint ne01 +) { + uint block_id_m = get_global_id(1); // m_tile + uint block_id_n = get_global_id(2); // n_tile + + // Boundary check + if (block_id_n >= total_tiles[0]) { + return; + } + + __private half16 reg_a; + __private float32 reg_c = (float32)(0); + __local half4 shared_b[128]; + + const ushort expert_id = src2_emap[block_id_n]; + + const uint row = block_id_m * TILESIZE_M; + const uint col = block_id_n * TILESIZE_N; + + uint sub_block_id_m = get_local_id(0); + uint2 b_global_offset; + b_global_offset.x = ((sub_block_id_m & 3) << 2) + (sub_block_id_m >> 2) * ne00; + b_global_offset.y = b_global_offset.x + (16 * ne00); + uint2 b_local_offset; + b_local_offset.x = (sub_block_id_m & 3) * 32 + (sub_block_id_m >> 2); + b_local_offset.y = b_local_offset.x + 16; + + uint num_superblocks = ne00 / QK_K; + uint scales_per_row = num_superblocks * 16; + uint row_idx = row + get_global_id(0); + + // Loop along K axis, 32 elements per iteration (one sub-block), divided into 2 halves of 16 + for (uint step = 0; step < ne00; step += TILESIZE_K * 2) { + uint sub = step / 32; // 32-element group index + uint sb = sub / 8; // super-block index + uint j = sub % 8; // group within super-block + + // Load d for super-block + uint d_offset = row + sb * ne01 + expert_id * num_superblocks * ne01 + get_global_id(0); + half d_val = src0_d[d_offset]; + + // Load sub-block scales + global const char * sc = src0_s + (expert_id * ne01 + row_idx) * scales_per_row + sb * 16; + float scale0 = (float)d_val * (float)sc[j * 2]; + float scale1 = (float)d_val * (float)sc[j * 2 + 1]; + + uint qh_base = row + (sub * 2) * ne01 + expert_id * (num_superblocks * 16) * ne01 + get_global_id(0); + uint qh_first16 = src0_qh[qh_base]; + uint qh_second16 = src0_qh[qh_base + ne01]; + + // First half (16 elements) + uint q_sub_offset = row + ((ne01 * step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + uint b_sub_offset = col * ne00 + step; + + // Load 16 ql nibbles (2 uints) from image + uint2 q4x16; + q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x; + + // Load 16x32 floats from matrix B + float8 bx8_f32; + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + half8 bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + // Dequantize first 16 elements (scale0) + dequantize_q6_k(as_ushort4(q4x16), qh_first16, reg_a, scale0); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + half16 acc; + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + + // Second half + uint half_step = step + TILESIZE_K; + q_sub_offset = row + ((ne01 * half_step) >> 3) + ((expert_id * ne00 * ne01) >> 3); + b_sub_offset = col * ne00 + half_step; + + q4x16.x = read_imageui(src0_ql, q_sub_offset + sub_block_id_m).x; + q4x16.y = read_imageui(src0_ql, q_sub_offset + sub_block_id_m + ne01).x; + + bx8_f32.lo = read_imagef(src1, (b_sub_offset + b_global_offset.x) / 4); + bx8_f32.hi = read_imagef(src1, (b_sub_offset + b_global_offset.y) / 4); + bx8_f16 = convert_half8(bx8_f32); + shared_b[b_local_offset.x] = bx8_f16.lo; + shared_b[b_local_offset.y] = bx8_f16.hi; + + dequantize_q6_k(as_ushort4(q4x16), qh_second16, reg_a, scale1); + + sub_group_barrier(CLK_LOCAL_MEM_FENCE); + + dotx16_reduce8(reg_a, shared_b, reg_c.lo, 0); + dotx16_reduce8(reg_a, shared_b, reg_c.hi, 16); + } + + if ((get_global_id(0) + block_id_m * TILESIZE_M) >= ne01) { + return; + } + + // Load post router and share in LM + __local uint out_idx[TILESIZE_N]; + + if (get_local_id(0) < TILESIZE_N) { + uint idx = src2[block_id_n * TILESIZE_N + get_local_id(0)]; + if (idx == 0xFFFFFFFF) { + idx = src2[block_id_n * TILESIZE_N + 0]; + } + out_idx[get_local_id(0)] = idx * ne01; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // Scatter results back to original position in output grid + uint m_offset = row + get_local_id(0); + + write_imagef(dst, out_idx[1] + m_offset, (reg_c.s1)); + write_imagef(dst, out_idx[2] + m_offset, (reg_c.s2)); + write_imagef(dst, out_idx[3] + m_offset, (reg_c.s3)); + write_imagef(dst, out_idx[4] + m_offset, (reg_c.s4)); + write_imagef(dst, out_idx[5] + m_offset, (reg_c.s5)); + write_imagef(dst, out_idx[6] + m_offset, (reg_c.s6)); + write_imagef(dst, out_idx[7] + m_offset, (reg_c.s7)); + write_imagef(dst, out_idx[8] + m_offset, (reg_c.s8)); + write_imagef(dst, out_idx[9] + m_offset, (reg_c.s9)); + write_imagef(dst, out_idx[10] + m_offset, (reg_c.sa)); + write_imagef(dst, out_idx[11] + m_offset, (reg_c.sb)); + write_imagef(dst, out_idx[12] + m_offset, (reg_c.sc)); + write_imagef(dst, out_idx[13] + m_offset, (reg_c.sd)); + write_imagef(dst, out_idx[14] + m_offset, (reg_c.se)); + write_imagef(dst, out_idx[15] + m_offset, (reg_c.sf)); + write_imagef(dst, out_idx[16] + m_offset, (reg_c.sg)); + write_imagef(dst, out_idx[17] + m_offset, (reg_c.sh)); + write_imagef(dst, out_idx[18] + m_offset, (reg_c.si)); + write_imagef(dst, out_idx[19] + m_offset, (reg_c.sj)); + write_imagef(dst, out_idx[20] + m_offset, (reg_c.sk)); + write_imagef(dst, out_idx[21] + m_offset, (reg_c.sl)); + write_imagef(dst, out_idx[22] + m_offset, (reg_c.sm)); + write_imagef(dst, out_idx[23] + m_offset, (reg_c.sn)); + write_imagef(dst, out_idx[24] + m_offset, (reg_c.so)); + write_imagef(dst, out_idx[25] + m_offset, (reg_c.sp)); + write_imagef(dst, out_idx[26] + m_offset, (reg_c.sq)); + write_imagef(dst, out_idx[27] + m_offset, (reg_c.sr)); + write_imagef(dst, out_idx[28] + m_offset, (reg_c.ss)); + write_imagef(dst, out_idx[29] + m_offset, (reg_c.st)); + write_imagef(dst, out_idx[30] + m_offset, (reg_c.su)); + write_imagef(dst, out_idx[31] + m_offset, (reg_c.sv)); + + // Store zero padding parts to the index of first output in tile + barrier(CLK_GLOBAL_MEM_FENCE); + write_imagef(dst, out_idx[0] + m_offset, (reg_c.s0)); +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl new file mode 100644 index 00000000000..6869d822862 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl @@ -0,0 +1,150 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +constant half kvalues_iq4nl[16] = { + (half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f, + (half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f, + (half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f, + (half) 53.f, (half) 69.f, (half) 89.f, (half)113.f +}; + +// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16 +constant uint iq4nl_packed[8] = { + 0xD680D7F0u, // idx 0,1: -127, -104 + 0xD410D530u, // idx 2,3: -83, -65 + 0xD060D220u, // idx 4,5: -49, -35 + 0xC900CD80u, // idx 6,7: -22, -10 + 0x4A803C00u, // idx 8,9: 1, 13 + 0x50C04E40u, // idx 10,11: 25, 38 + 0x545052A0u, // idx 12,13: 53, 69 + 0x57105590u // idx 14,15: 89, 113 +}; + +// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half +#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4))) + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_iq4_nl_f32( + global const ushort * src0_q, + global const half * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding +) { + dst = (global float *)((global char *)dst + offsetd); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_q + gx_2; + global const half * scale_ptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1); + + ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); + + half4 scale = vload4(0, scale_ptr + (i/32)*(m)); + + // j=0 + dequantized_weights.s0 = IQ4_NL_DEQUANT(bits4.s0 & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT(bits4.s1 & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT(bits4.s2 & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT(bits4.s3 & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 4) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 4) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 4) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 4) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 8) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 8) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 8) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 8) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1); + dequantized_weights.s0 = IQ4_NL_DEQUANT((bits4.s0 >> 12) & 0x000Fu) * scale.s0; + dequantized_weights.s1 = IQ4_NL_DEQUANT((bits4.s1 >> 12) & 0x000Fu) * scale.s1; + dequantized_weights.s2 = IQ4_NL_DEQUANT((bits4.s2 >> 12) & 0x000Fu) * scale.s2; + dequantized_weights.s3 = IQ4_NL_DEQUANT((bits4.s3 >> 12) & 0x000Fu) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_0_f32.cl similarity index 99% rename from ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl rename to ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_0_f32.cl index ecb577b9933..159378049fb 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mat_Ab_Bi_8x4.cl +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_0_f32.cl @@ -17,7 +17,7 @@ REQD_SUBGROUP_SIZE_128 #endif -kernel void kernel_mul_mat_Ab_Bi_8x4( +kernel void kernel_gemm_noshuffle_q4_0_f32( global const ushort * src0_q, // quantized A global const half * src0_d, // A scales __read_only image1d_buffer_t src1, // B (1d image) diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl new file mode 100644 index 00000000000..5c4d5cc8e2c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl @@ -0,0 +1,132 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q4_1_f32( + global const ushort * src0_q, + global const half * src0_d, + global const half * src0_m, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding +) { + dst = (global float *)((global char *)dst + offsetd); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort* weight_ptr = src0_q + gx_2; + global const half* scale_ptr = src0_d + gx_2; + global const half* min_ptr = src0_m + gx_2; + + for(int i = 0; i < k; i += 4) { + B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1); + + ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); + + half4 scale = vload4(0, scale_ptr + (i/32)*(m)); + half4 minv = vload4(0, min_ptr + (i/32)*(m)); + + // j=0 + dequantized_weights.s0 = (bits4.s0 & (0x000F)) * scale.s0 + minv.s0; + dequantized_weights.s1 = (bits4.s1 & (0x000F)) * scale.s1 + minv.s1; + dequantized_weights.s2 = (bits4.s2 & (0x000F)) * scale.s2 + minv.s2; + dequantized_weights.s3 = (bits4.s3 & (0x000F)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1); + dequantized_weights.s0 = ((bits4.s0 & (0x00F0)) >> 4) * scale.s0 + minv.s0; + dequantized_weights.s1 = ((bits4.s1 & (0x00F0)) >> 4) * scale.s1 + minv.s1; + dequantized_weights.s2 = ((bits4.s2 & (0x00F0)) >> 4) * scale.s2 + minv.s2; + dequantized_weights.s3 = ((bits4.s3 & (0x00F0)) >> 4) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1); + dequantized_weights.s0 = ((bits4.s0 & (0x0F00)) >> 8) * scale.s0 + minv.s0; + dequantized_weights.s1 = ((bits4.s1 & (0x0F00)) >> 8) * scale.s1 + minv.s1; + dequantized_weights.s2 = ((bits4.s2 & (0x0F00)) >> 8) * scale.s2 + minv.s2; + dequantized_weights.s3 = ((bits4.s3 & (0x0F00)) >> 8) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1); + dequantized_weights.s0 = ((bits4.s0 & (0xF000)) >> 12) * scale.s0 + minv.s0; + dequantized_weights.s1 = ((bits4.s1 & (0xF000)) >> 12) * scale.s1 + minv.s1; + dequantized_weights.s2 = ((bits4.s2 & (0xF000)) >> 12) * scale.s2 + minv.s2; + dequantized_weights.s3 = ((bits4.s3 & (0xF000)) >> 12) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl new file mode 100644 index 00000000000..99fd1fd7bf1 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl @@ -0,0 +1,172 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q4_k_f32( + global const ushort * src0_q, + global const uchar * src0_s, + global const half * src0_d, + global const half * src0_dm, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + dst = (global float *)((global char *)dst + offsetd); + int n_4 = n >> 2; + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + int num_blocks_K = k / QK_K; + + global const ushort * weight_ptr = src0_q + gx_2; + global const half * d_ptr = src0_d + gx_2; + global const half * dm_ptr = src0_dm + gx_2; + + for (int i = 0; i < k; i += 32) { + int sb_idx = i / QK_K; + int sub_idx = (i / 32) % 8; + + half4 d = vload4(0, d_ptr + sb_idx * m); + half4 dm = vload4(0, dm_ptr + sb_idx * m); + + global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + + uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3; + get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2); + + half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3))); + half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3))); + + for (int l = 0; l < 32; l += 4) { + int ki = i + l; + ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m); + + // j=0 + B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4); + dequantized_weights.s0 = (bits4.s0 & 0x000F) * scale.s0 - mval.s0; + dequantized_weights.s1 = (bits4.s1 & 0x000F) * scale.s1 - mval.s1; + dequantized_weights.s2 = (bits4.s2 & 0x000F) * scale.s2 - mval.s2; + dequantized_weights.s3 = (bits4.s3 & 0x000F) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x00F0) >> 4) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x00F0) >> 4) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x00F0) >> 4) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x00F0) >> 4) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x0F00) >> 8) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x0F00) >> 8) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x0F00) >> 8) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x0F00) >> 8) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0xF000) >> 12) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0xF000) >> 12) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0xF000) >> 12) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0xF000) >> 12) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + } + + int idx = (gy<<3)*m + (gx<<2); + + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl new file mode 100644 index 00000000000..1d6bd48005e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl @@ -0,0 +1,131 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q5_0_f32( + global const ushort * src0_qs, // quantized A + global const uchar * src0_qh, // 5th bits + global const half * src0_d, // A scales + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_qs + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * scale_ptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + + B.s0123 = read_imageh(src1, gy*2 + i*n_4); + B.s4567 = read_imageh(src1, gy*2 + i*n_4 + 1); + + ushort4 bits4 = vload4(0, weight_ptr + (i >> 2)*m); + uchar4 bits1 = vload4(0, qh_ptr + (i >> 3)*m); + uchar4 qh = bits1 >> (uchar4)(i & 4); + + half4 scale = vload4(0, scale_ptr + (i >> 5)*m); + + // j=0 + dequantized_weights.s0 = (convert_half((bits4.s0 & 0x000F) | ((qh.s0 & 0x01) << 4)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half((bits4.s1 & 0x000F) | ((qh.s1 & 0x01) << 4)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half((bits4.s2 & 0x000F) | ((qh.s2 & 0x01) << 4)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half((bits4.s3 & 0x000F) | ((qh.s3 & 0x01) << 4)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0x00F0) >> 4) | ((qh.s0 & 0x02) << 3)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0x00F0) >> 4) | ((qh.s1 & 0x02) << 3)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0x00F0) >> 4) | ((qh.s2 & 0x02) << 3)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0x00F0) >> 4) | ((qh.s3 & 0x02) << 3)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0x0F00) >> 8) | ((qh.s0 & 0x04) << 2)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0x0F00) >> 8) | ((qh.s1 & 0x04) << 2)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0x0F00) >> 8) | ((qh.s2 & 0x04) << 2)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0x0F00) >> 8) | ((qh.s3 & 0x04) << 2)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*n_4 + 1); + dequantized_weights.s0 = (convert_half(((bits4.s0 & 0xF000) >> 12) | ((qh.s0 & 0x08) << 1)) - 16.0h) * scale.s0; + dequantized_weights.s1 = (convert_half(((bits4.s1 & 0xF000) >> 12) | ((qh.s1 & 0x08) << 1)) - 16.0h) * scale.s1; + dequantized_weights.s2 = (convert_half(((bits4.s2 & 0xF000) >> 12) | ((qh.s2 & 0x08) << 1)) - 16.0h) * scale.s2; + dequantized_weights.s3 = (convert_half(((bits4.s3 & 0xF000) >> 12) | ((qh.s3 & 0x08) << 1)) - 16.0h) * scale.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl new file mode 100644 index 00000000000..94b4ef6cacc --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl @@ -0,0 +1,134 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q5_1_f32( + global const ushort * src0_qs, // quantized A + global const uchar * src0_qh, // 5th bits + global const half * src0_d, // A scales + global const half * src0_m, // A mins + __read_only image1d_buffer_t src1, // B (1d image) + global float * dst, // C + int m, // M + int n, // N with padding + int k, // K + int n_no_padding // N without padding +) { + + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * weight_ptr = src0_qs + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * scale_ptr = src0_d + gx_2; + global const half * min_ptr = src0_m + gx_2; + + for (int i = 0; i < k; i += 4) { + + B.s0123 = read_imageh(src1, gy*2 + i*n_4); + B.s4567 = read_imageh(src1, gy*2 + i*n_4 + 1); + + ushort4 bits4 = vload4(0, weight_ptr + (i >> 2)*m); + uchar4 bits1 = vload4(0, qh_ptr + (i >> 3)*m); + uchar4 qh = bits1 >> (uchar4)(i & 4); + + half4 scale = vload4(0, scale_ptr + (i >> 5)*m); + half4 minv = vload4(0, min_ptr + (i >> 5)*m); + + // j=0 + dequantized_weights.s0 = convert_half((bits4.s0 & 0x000F) | ((qh.s0 & 0x01) << 4)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half((bits4.s1 & 0x000F) | ((qh.s1 & 0x01) << 4)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half((bits4.s2 & 0x000F) | ((qh.s2 & 0x01) << 4)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half((bits4.s3 & 0x000F) | ((qh.s3 & 0x01) << 4)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0x00F0) >> 4) | ((qh.s0 & 0x02) << 3)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0x00F0) >> 4) | ((qh.s1 & 0x02) << 3)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0x00F0) >> 4) | ((qh.s2 & 0x02) << 3)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0x00F0) >> 4) | ((qh.s3 & 0x02) << 3)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0x0F00) >> 8) | ((qh.s0 & 0x04) << 2)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0x0F00) >> 8) | ((qh.s1 & 0x04) << 2)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0x0F00) >> 8) | ((qh.s2 & 0x04) << 2)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0x0F00) >> 8) | ((qh.s3 & 0x04) << 2)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*n_4); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*n_4 + 1); + dequantized_weights.s0 = convert_half(((bits4.s0 & 0xF000) >> 12) | ((qh.s0 & 0x08) << 1)) * scale.s0 + minv.s0; + dequantized_weights.s1 = convert_half(((bits4.s1 & 0xF000) >> 12) | ((qh.s1 & 0x08) << 1)) * scale.s1 + minv.s1; + dequantized_weights.s2 = convert_half(((bits4.s2 & 0xF000) >> 12) | ((qh.s2 & 0x08) << 1)) * scale.s2 + minv.s2; + dequantized_weights.s3 = convert_half(((bits4.s3 & 0xF000) >> 12) | ((qh.s3 & 0x08) << 1)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl new file mode 100644 index 00000000000..058c0f7edc6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl @@ -0,0 +1,176 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif +#define QK_K 256 +#define K_SCALE_SIZE 12 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q5_k_f32( + global const ushort * src0_q, + global const uchar * src0_qh, + global const uchar * src0_s, + global const half * src0_d, + global const half * src0_dm, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + dst = (global float *)((global char *)dst + offsetd); + int n_4 = n >> 2; + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + int num_blocks_K = k / QK_K; + + global const ushort * weight_ptr = src0_q + gx_2; + global const uchar * qh_ptr = src0_qh + gx_2; + global const half * d_ptr = src0_d + gx_2; + global const half * dm_ptr = src0_dm + gx_2; + + for (int i = 0; i < k; i += 32) { + int sb_idx = i / QK_K; + int sub_idx = (i / 32) % 8; + + half4 d = vload4(0, d_ptr + sb_idx * m); + half4 dm = vload4(0, dm_ptr + sb_idx * m); + + global const uchar * sc0 = src0_s + (gx_2+0) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc1 = src0_s + (gx_2+1) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc2 = src0_s + (gx_2+2) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + global const uchar * sc3 = src0_s + (gx_2+3) * num_blocks_K * K_SCALE_SIZE + sb_idx * K_SCALE_SIZE; + + uchar sv0, mn0, sv1, mn1, sv2, mn2, sv3, mn3; + get_scale_min_k4(sub_idx, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc2, &sv2, &mn2, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(sub_idx, sc3, &sv3, &mn3, mask_d6, mask_d4, mask_hi2); + + half4 scale = convert_half4(convert_float4(d) * convert_float4((uchar4)(sv0, sv1, sv2, sv3))); + half4 mval = convert_half4(convert_float4(dm) * convert_float4((uchar4)(mn0, mn1, mn2, mn3))); + + for (int l = 0; l < 32; l += 4) { + int ki = i + l; + ushort4 bits4 = vload4(0, weight_ptr + (ki/4) * m); + uchar4 qh_bits = vload4(0, qh_ptr + (ki/8) * m); + int qh_shift = ki % 8; + + // j=0 + B.s0123 = read_imageh(src1, gy*2 + (ki+0) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+0) * n_4); + dequantized_weights.s0 = ((bits4.s0 & 0x000F) | (((qh_bits.s0 >> (qh_shift+0)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = ((bits4.s1 & 0x000F) | (((qh_bits.s1 >> (qh_shift+0)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = ((bits4.s2 & 0x000F) | (((qh_bits.s2 >> (qh_shift+0)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = ((bits4.s3 & 0x000F) | (((qh_bits.s3 >> (qh_shift+0)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (ki+1) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+1) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0x00F0) >> 4) | (((qh_bits.s0 >> (qh_shift+1)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0x00F0) >> 4) | (((qh_bits.s1 >> (qh_shift+1)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0x00F0) >> 4) | (((qh_bits.s2 >> (qh_shift+1)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0x00F0) >> 4) | (((qh_bits.s3 >> (qh_shift+1)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (ki+2) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+2) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0x0F00) >> 8) | (((qh_bits.s0 >> (qh_shift+2)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0x0F00) >> 8) | (((qh_bits.s1 >> (qh_shift+2)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0x0F00) >> 8) | (((qh_bits.s2 >> (qh_shift+2)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0x0F00) >> 8) | (((qh_bits.s3 >> (qh_shift+2)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (ki+3) * n_4); + B.s4567 = read_imageh(src1, gy*2+1 + (ki+3) * n_4); + dequantized_weights.s0 = (((bits4.s0 & 0xF000) >> 12) | (((qh_bits.s0 >> (qh_shift+3)) & 1) << 4)) * scale.s0 - mval.s0; + dequantized_weights.s1 = (((bits4.s1 & 0xF000) >> 12) | (((qh_bits.s1 >> (qh_shift+3)) & 1) << 4)) * scale.s1 - mval.s1; + dequantized_weights.s2 = (((bits4.s2 & 0xF000) >> 12) | (((qh_bits.s2 >> (qh_shift+3)) & 1) << 4)) * scale.s2 - mval.s2; + dequantized_weights.s3 = (((bits4.s3 & 0xF000) >> 12) | (((qh_bits.s3 >> (qh_shift+3)) & 1) << 4)) * scale.s3 - mval.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + } + + int idx = (gy<<3)*m + (gx<<2); + + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if (idx+3 < m*n_no_padding) { + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl new file mode 100644 index 00000000000..3a9c624508a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl @@ -0,0 +1,140 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif +kernel void kernel_gemm_noshuffle_q6_K_f32( + global const ushort * src0_ql, + global const uchar * src0_qh, + global const ushort * src0_s, + global const half * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding, + ushort mask_f000, + uchar mask_c0 +) { + dst = (global float *)( (global char *)dst + offsetd ); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); // n + int gx = get_global_id(1); // m + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort * ptr_ql = src0_ql + gx_2; + global const uchar * ptr_qh = src0_qh + gx_2; + global const ushort * ptr_s = src0_s + gx_2; + global const half * ptr_d = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + // load 4x elements (ushort) of ql on M, each ushort contains 4 weights + // 4x ushort correspons to 4 rows on M + ushort4 bits4 = vload4(0, ptr_ql + (i/4)*m); // ql packed in 4s in ushort + uchar4 bits2 = vload4(0, ptr_qh + (i/4)*m); // qh packed in 4s in uchar + + // load 4 consecutive scales + char8 scale_s_8 = as_char8(vload4(0, ptr_s + (i/16/2)*m)); // 1 char scale every 16 elements, packed in 2s + char4 scale_s = ((i/16) % 2) == 0 ? scale_s_8.s0246 : scale_s_8.s1357; // transposed as ushort, 2 blocks + half4 scale_d = vload4(0, ptr_d + (i/256)*m); // 1 half scale every 256 elements + + // j=0 + // load 2x 4 elements of activations on N, corresponding to 8 rows on N + B.s0123 = read_imageh(src1, gy*2 + (i + 0)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 0)*n_4 + 1); + dequantized_weights.s0 = (convert_half((bits4.s0 & 0x000F) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((bits4.s1 & 0x000F) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((bits4.s2 & 0x000F) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((bits4.s3 & 0x000F) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i + 1)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 1)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i + 2)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 2)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & 0x0F00) >> 8) | (bits2.s0 & 0x30))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & 0x0F00) >> 8) | (bits2.s1 & 0x30))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & 0x0F00) >> 8) | (bits2.s2 & 0x30))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & 0x0F00) >> 8) | (bits2.s3 & 0x30))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i + 3)*n_4 + 0); + B.s4567 = read_imageh(src1, gy*2 + (i + 3)*n_4 + 1); + dequantized_weights.s0 = (convert_half((((bits4.s0 & mask_f000) >> 12) | ((bits2.s0 & mask_c0) >> 2))) - 32.f) * scale_s.s0 * scale_d.s0; + dequantized_weights.s1 = (convert_half((((bits4.s1 & mask_f000) >> 12) | ((bits2.s1 & mask_c0) >> 2))) - 32.f) * scale_s.s1 * scale_d.s1; + dequantized_weights.s2 = (convert_half((((bits4.s2 & mask_f000) >> 12) | ((bits2.s2 & mask_c0) >> 2))) - 32.f) * scale_s.s2 * scale_d.s2; + dequantized_weights.s3 = (convert_half((((bits4.s3 & mask_f000) >> 12) | ((bits2.s3 & mask_c0) >> 2))) - 32.f) * scale_s.s3 * scale_d.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl new file mode 100644 index 00000000000..7f06a22a2cb --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl @@ -0,0 +1,129 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q8_0_f32( + global const uint * src0_q, + global const half * src0_d, + __read_only image1d_buffer_t src1, + global float * dst, + int k, + int m, + int n, + int n_no_padding, + ulong offsetd +) { + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + dst = (global float *)((global char*)dst + offsetd); + + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 deq; + + __global const uint* wptr = src0_q + gx_2; + __global const half* sptr = src0_d + gx_2; + + for (int i = 0; i < k; i += 4) { + uint4 pack4 = vload4(0, wptr + (i / 4) * m); + half4 scale = vload4(0, sptr + (i / 32) * m); + + char4 p0 = as_char4(pack4.s0); + char4 p1 = as_char4(pack4.s1); + char4 p2 = as_char4(pack4.s2); + char4 p3 = as_char4(pack4.s3); + + // ------------------- j = 0 (k = i+0) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 0) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 0) * n_4 + 1); + + half4 wj0 = convert_half4((char4)(p0.s0, p1.s0, p2.s0, p3.s0)) * scale; + + c0 += B * wj0.s0; + c1 += B * wj0.s1; + c2 += B * wj0.s2; + c3 += B * wj0.s3; + + // ------------------- j = 1 (k = i+1) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 1) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 1) * n_4 + 1); + + half4 wj1 = convert_half4((char4)(p0.s1, p1.s1, p2.s1, p3.s1)) * scale; + + c0 += B * wj1.s0; + c1 += B * wj1.s1; + c2 += B * wj1.s2; + c3 += B * wj1.s3; + + // ------------------- j = 2 (k = i+2) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 2) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 2) * n_4 + 1); + + half4 wj2 = convert_half4((char4)(p0.s2, p1.s2, p2.s2, p3.s2)) * scale; + + c0 += B * wj2.s0; + c1 += B * wj2.s1; + c2 += B * wj2.s2; + c3 += B * wj2.s3; + + // ------------------- j = 3 (k = i+3) ------------------- + B.s0123 = read_imageh(src1, gy * 2 + (i + 3) * n_4); + B.s4567 = read_imageh(src1, gy * 2 + (i + 3) * n_4 + 1); + + half4 wj3 = convert_half4((char4)(p0.s3, p1.s3, p2.s3, p3.s3)) * scale; + + c0 += B * wj3.s0; + c1 += B * wj3.s1; + c2 += B * wj3.s2; + c3 += B * wj3.s3; + } + + int idx = (gy << 3) * m + (gx << 2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl b/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl new file mode 100644 index 00000000000..df9d9aed067 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl @@ -0,0 +1,233 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_uniform_load : enable +#pragma OPENCL EXTENSION cl_qcom_subgroup_constant_load : enable + +__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; + +__kernel void adreno_xmem_pack_src_f32( + __global const void * src_void, + ulong offset, + __write_only image2d_t src_img, + int K, + int N) { + const int x = get_global_id(0); + const int y = get_global_id(1); + const int kpack = K / 4; + + if (x >= N || y >= kpack) { + return; + } + + __global const float * src = (__global const float *)((__global const char *)src_void + offset); + const int base = x*K + y*4; + const half4 v = (half4)((half)src[base + 0], (half)src[base + 1], (half)src[base + 2], (half)src[base + 3]); + write_imageh(src_img, (int2)(x, y), v); +} + +__kernel void adreno_xmem_prepack_weight_f16( + __global half4 * dst, + __global const void * src_void, + ulong offset, + int K, + int M, + int kpack, + int npack, + int os) { + const int linear = get_global_id(0); + const int total = kpack*npack; + if (linear >= total) { + return; + } + + __global const half * src = (__global const half *)((__global const char *)src_void + offset); + + const int dst_ogroup = linear % os; + const int dst_o_sp_i = linear / os; + const int dst_i = dst_o_sp_i % kpack; + const int dst_o = dst_o_sp_i / kpack; + const int o_slice = dst_o*os + dst_ogroup; + const int k_base = dst_i*4; + + half4 w0 = (half4)(0.0h); + half4 w1 = (half4)(0.0h); + half4 w2 = (half4)(0.0h); + half4 w3 = (half4)(0.0h); + + const int o0 = o_slice*4 + 0; + const int o1 = o_slice*4 + 1; + const int o2 = o_slice*4 + 2; + const int o3 = o_slice*4 + 3; + + if (k_base + 0 < K) { + if (o0 < M) w0.s0 = src[o0*K + k_base + 0]; + if (o1 < M) w0.s1 = src[o1*K + k_base + 0]; + if (o2 < M) w0.s2 = src[o2*K + k_base + 0]; + if (o3 < M) w0.s3 = src[o3*K + k_base + 0]; + } + if (k_base + 1 < K) { + if (o0 < M) w1.s0 = src[o0*K + k_base + 1]; + if (o1 < M) w1.s1 = src[o1*K + k_base + 1]; + if (o2 < M) w1.s2 = src[o2*K + k_base + 1]; + if (o3 < M) w1.s3 = src[o3*K + k_base + 1]; + } + if (k_base + 2 < K) { + if (o0 < M) w2.s0 = src[o0*K + k_base + 2]; + if (o1 < M) w2.s1 = src[o1*K + k_base + 2]; + if (o2 < M) w2.s2 = src[o2*K + k_base + 2]; + if (o3 < M) w2.s3 = src[o3*K + k_base + 2]; + } + if (k_base + 3 < K) { + if (o0 < M) w3.s0 = src[o0*K + k_base + 3]; + if (o1 < M) w3.s1 = src[o1*K + k_base + 3]; + if (o2 < M) w3.s2 = src[o2*K + k_base + 3]; + if (o3 < M) w3.s3 = src[o3*K + k_base + 3]; + } + + dst[linear*4 + 0] = w0; + dst[linear*4 + 1] = w1; + dst[linear*4 + 2] = w2; + dst[linear*4 + 3] = w3; +} + +__attribute__((qcom_max_concurrent_subgroups(12))) +__kernel void kernel_gemm_xmem_f16_f32_os8( + __constant half8 * weights_buffer __attribute__((sub_group_uniform)), + __constant half8 * xmem_buffer __attribute__((max_constant_size((6144)))), + __read_only image2d_t src_img, + __write_only image2d_t dst_img, + int N, + int npack, + int kpack) { + const int X = get_group_id(1)*get_local_size(0) + get_local_id(0); + const int Z = get_group_id(0)*get_local_size(2) + get_local_id(2); + + if (X >= N || Z*8 >= npack) { + return; + } + + half4 r0 = (half4)(0.0h); + half4 r1 = (half4)(0.0h); + half4 r2 = (half4)(0.0h); + half4 r3 = (half4)(0.0h); + half4 r4 = (half4)(0.0h); + half4 r5 = (half4)(0.0h); + half4 r6 = (half4)(0.0h); + half4 r7 = (half4)(0.0h); + + int f_offset = Z*kpack*32; + int subgroup_id = (int)(0x1F & qcom_get_physical_sub_group_id()); + subgroup_id = subgroup_id % 12; + const int c_offset = subgroup_id*32; + __constant half16 * weights_cache = (__constant half16 *)&xmem_buffer[c_offset]; + + int coord_s = 0; + do { + const half4 src0 = read_imageh(src_img, smp_zero, (int2)(X, coord_s)); + coord_s++; + const half4 src1 = read_imageh(src_img, smp_zero, (int2)(X, coord_s)); + coord_s++; + + qcom_sub_group_constant_load8(xmem_buffer, weights_buffer, c_offset, f_offset >> 1, 32); + f_offset += 64; + qcom_sub_group_sync(QCOM_CLK_CONST_LOAD_SYNC); + + r0 += src0.x * weights_cache[0].s0123; + r0 += src0.y * weights_cache[0].s4567; + r0 += src0.z * weights_cache[0].s89ab; + r0 += src0.w * weights_cache[0].scdef; + r1 += src0.x * weights_cache[1].s0123; + r1 += src0.y * weights_cache[1].s4567; + r1 += src0.z * weights_cache[1].s89ab; + r1 += src0.w * weights_cache[1].scdef; + r2 += src0.x * weights_cache[2].s0123; + r2 += src0.y * weights_cache[2].s4567; + r2 += src0.z * weights_cache[2].s89ab; + r2 += src0.w * weights_cache[2].scdef; + r3 += src0.x * weights_cache[3].s0123; + r3 += src0.y * weights_cache[3].s4567; + r3 += src0.z * weights_cache[3].s89ab; + r3 += src0.w * weights_cache[3].scdef; + r4 += src0.x * weights_cache[4].s0123; + r4 += src0.y * weights_cache[4].s4567; + r4 += src0.z * weights_cache[4].s89ab; + r4 += src0.w * weights_cache[4].scdef; + r5 += src0.x * weights_cache[5].s0123; + r5 += src0.y * weights_cache[5].s4567; + r5 += src0.z * weights_cache[5].s89ab; + r5 += src0.w * weights_cache[5].scdef; + r6 += src0.x * weights_cache[6].s0123; + r6 += src0.y * weights_cache[6].s4567; + r6 += src0.z * weights_cache[6].s89ab; + r6 += src0.w * weights_cache[6].scdef; + r7 += src0.x * weights_cache[7].s0123; + r7 += src0.y * weights_cache[7].s4567; + r7 += src0.z * weights_cache[7].s89ab; + r7 += src0.w * weights_cache[7].scdef; + + r0 += src1.x * weights_cache[8].s0123; + r0 += src1.y * weights_cache[8].s4567; + r0 += src1.z * weights_cache[8].s89ab; + r0 += src1.w * weights_cache[8].scdef; + r1 += src1.x * weights_cache[9].s0123; + r1 += src1.y * weights_cache[9].s4567; + r1 += src1.z * weights_cache[9].s89ab; + r1 += src1.w * weights_cache[9].scdef; + r2 += src1.x * weights_cache[10].s0123; + r2 += src1.y * weights_cache[10].s4567; + r2 += src1.z * weights_cache[10].s89ab; + r2 += src1.w * weights_cache[10].scdef; + r3 += src1.x * weights_cache[11].s0123; + r3 += src1.y * weights_cache[11].s4567; + r3 += src1.z * weights_cache[11].s89ab; + r3 += src1.w * weights_cache[11].scdef; + r4 += src1.x * weights_cache[12].s0123; + r4 += src1.y * weights_cache[12].s4567; + r4 += src1.z * weights_cache[12].s89ab; + r4 += src1.w * weights_cache[12].scdef; + r5 += src1.x * weights_cache[13].s0123; + r5 += src1.y * weights_cache[13].s4567; + r5 += src1.z * weights_cache[13].s89ab; + r5 += src1.w * weights_cache[13].scdef; + r6 += src1.x * weights_cache[14].s0123; + r6 += src1.y * weights_cache[14].s4567; + r6 += src1.z * weights_cache[14].s89ab; + r6 += src1.w * weights_cache[14].scdef; + r7 += src1.x * weights_cache[15].s0123; + r7 += src1.y * weights_cache[15].s4567; + r7 += src1.z * weights_cache[15].s89ab; + r7 += src1.w * weights_cache[15].scdef; + } while (coord_s < kpack); + + int coord_s_out = Z*8; + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r0); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r1); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r2); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r3); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r4); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r5); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r6); coord_s_out++; } + if (coord_s_out < npack) { write_imageh(dst_img, (int2)(X, coord_s_out), r7); } +} + +__kernel void adreno_xmem_store_dst_f32( + __read_only image2d_t dst_img, + __global void * dst_void, + ulong offset, + int M, + int N) { + const int x = get_global_id(0); + const int y = get_global_id(1); + const int npack = (M + 3) / 4; + + if (x >= N || y >= npack) { + return; + } + + __global float * dst = (__global float *)((__global char *)dst_void + offset); + const half4 hv = read_imageh(dst_img, smp_zero, (int2)(x, y)); + const int m = y*4; + if (m + 0 < M) dst[x*M + m + 0] = (float)hv.s0; + if (m + 1 < M) dst[x*M + m + 1] = (float)hv.s1; + if (m + 2 < M) dst[x*M + m + 2] = (float)hv.s2; + if (m + 3 < M) dst[x*M + m + 3] = (float)hv.s3; +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl new file mode 100644 index 00000000000..75129e20c65 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl @@ -0,0 +1,165 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_MXFP4 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline half8 mxfp4_to_fp16_packed8(ushort2 fp4x8) { + ushort2 fp16_packed_a_0, fp16_packed_b_0, bias_a, bias_b, sign_a, sign_b; + fp16_packed_a_0.lo = (fp4x8.s0 << 9) & 0x0E00; + fp16_packed_a_0.hi = (fp4x8.s0 << 5) & 0x0E00; + fp16_packed_b_0.lo = (fp4x8.s0 << 1) & 0x0E00; + fp16_packed_b_0.hi = (fp4x8.s0 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_0.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_0.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_0.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_0.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_0.lo = (fp16_packed_a_0.lo != 0x0200) ? fp16_packed_a_0.lo : 0x0; + fp16_packed_a_0.hi = (fp16_packed_a_0.hi != 0x0200) ? fp16_packed_a_0.hi : 0x0; + fp16_packed_b_0.lo = (fp16_packed_b_0.lo != 0x0200) ? fp16_packed_b_0.lo : 0x0; + fp16_packed_b_0.hi = (fp16_packed_b_0.hi != 0x0200) ? fp16_packed_b_0.hi : 0x0; + + sign_a.lo = (fp4x8.s0 << 12) & 0x8000; + sign_a.hi = (fp4x8.s0 << 8) & 0x8000; + sign_b.lo = (fp4x8.s0 << 4) & 0x8000; + sign_b.hi = fp4x8.s0 & 0x8000; + + fp16_packed_a_0 = sign_a + bias_a + fp16_packed_a_0; + fp16_packed_b_0 = sign_b + bias_b + fp16_packed_b_0; + + ushort2 fp16_packed_a_1, fp16_packed_b_1; + fp16_packed_a_1.lo = (fp4x8.s1 << 9) & 0x0E00; + fp16_packed_a_1.hi = (fp4x8.s1 << 5) & 0x0E00; + fp16_packed_b_1.lo = (fp4x8.s1 << 1) & 0x0E00; + fp16_packed_b_1.hi = (fp4x8.s1 >> 3) & 0x0E00; + + bias_a.lo = (fp16_packed_a_1.lo != 0) ? 0x3800 : 0x0; + bias_a.hi = (fp16_packed_a_1.hi != 0) ? 0x3800 : 0x0; + bias_b.lo = (fp16_packed_b_1.lo != 0) ? 0x3800 : 0x0; + bias_b.hi = (fp16_packed_b_1.hi != 0) ? 0x3800 : 0x0; + + fp16_packed_a_1.lo = (fp16_packed_a_1.lo != 0x0200) ? fp16_packed_a_1.lo : 0x0; + fp16_packed_a_1.hi = (fp16_packed_a_1.hi != 0x0200) ? fp16_packed_a_1.hi : 0x0; + fp16_packed_b_1.lo = (fp16_packed_b_1.lo != 0x0200) ? fp16_packed_b_1.lo : 0x0; + fp16_packed_b_1.hi = (fp16_packed_b_1.hi != 0x0200) ? fp16_packed_b_1.hi : 0x0; + + sign_a.lo = (fp4x8.s1 << 12) & 0x8000; + sign_a.hi = (fp4x8.s1 << 8) & 0x8000; + sign_b.lo = (fp4x8.s1 << 4) & 0x8000; + sign_b.hi = fp4x8.s1 & 0x8000; + + fp16_packed_a_1 = sign_a + bias_a + fp16_packed_a_1; + fp16_packed_b_1 = sign_b + bias_b + fp16_packed_b_1; + + return as_half8((ushort8)(fp16_packed_a_0, fp16_packed_b_0, fp16_packed_a_1, fp16_packed_b_1)); +} + +static inline float e8m0_to_fp32(uchar x) { + int bits; + bits = (x == 0) ? 0x00400000 : ((uint) x << 23); + return as_float(bits); +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_mxfp4_f32_ns( + __global uint * src0_q, + __global uchar * src0_e, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_MXFP4); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half8 fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s0)); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + + fp16x8 = mxfp4_to_fp16_packed8(as_ushort2(regQ.s3)); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * convert_float4(fp16x8.lo); + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * convert_float4(fp16x8.hi); + + uchar regE = src0_e[ib00 * ne01 + i01 + expert_offset]; + sum += e8m0_to_fp32(regE) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl new file mode 100644 index 00000000000..2d28db63ec5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl @@ -0,0 +1,120 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q4_0 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q4_0_to_fp32_packed8(ushort2 q4x8) { + float8 fp32x8; + fp32x8.s0 = (float)((q4x8.s0 & 0x000F) - 8); + fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) - 8); + fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) - 8); + fp32x8.s3 = (float)(((q4x8.s0 & 0xF000) >> 12) - 8); + fp32x8.s4 = (float)((q4x8.s1 & 0x000F) - 8); + fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) - 8); + fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) - 8); + fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) - 8); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_0_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_0); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + float8 fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s0)); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s1)); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s2)); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_0_to_fp32_packed8(as_ushort2(regQ.s3)); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + sum += (float)(regS) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl new file mode 100644 index 00000000000..b98bdc0f12e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl @@ -0,0 +1,123 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q4_1 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q4_1_to_fp32_packed8(ushort2 q4x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((q4x8.s0 & 0x000F) * s + m); + fp32x8.s1 = (float)(((q4x8.s0 & 0x00F0) >> 4) * s + m); + fp32x8.s2 = (float)(((q4x8.s0 & 0x0F00) >> 8) * s + m); + fp32x8.s3 = (float)(((q4x8.s0 & 0xF000) >> 12) * s + m); + fp32x8.s4 = (float)((q4x8.s1 & 0x000F) * s + m); + fp32x8.s5 = (float)(((q4x8.s1 & 0x00F0) >> 4) * s + m); + fp32x8.s6 = (float)(((q4x8.s1 & 0x0F00) >> 8) * s + m); + fp32x8.s7 = (float)(((q4x8.s1 & 0xF000) >> 12) * s + m); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_1_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q4_1); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_q[block_offset]; + regQ.s1 = src0_q[block_offset + ne01]; + regQ.s2 = src0_q[block_offset + ne01 * 2]; + regQ.s3 = src0_q[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + half regM = src0_m[ib00 * ne01 + i01 + expert_offset]; + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s0), regS, regM); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s1), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s2), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q4_1_to_fp32_packed8(as_ushort2(regQ.s3), regS, regM); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl new file mode 100644 index 00000000000..12464e9826e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl @@ -0,0 +1,155 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define K_SCALE_SIZE 12 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +static inline float8 q4_k_to_fp32_packed8(ushort2 q4x8, float scale, float minv) { + float8 fp32x8; + fp32x8.s0 = (q4x8.s0 & 0x000F) * scale - minv; + fp32x8.s1 = ((q4x8.s0 & 0x00F0) >> 4) * scale - minv; + fp32x8.s2 = ((q4x8.s0 & 0x0F00) >> 8) * scale - minv; + fp32x8.s3 = ((q4x8.s0 & 0xF000) >> 12) * scale - minv; + fp32x8.s4 = (q4x8.s1 & 0x000F) * scale - minv; + fp32x8.s5 = ((q4x8.s1 & 0x00F0) >> 4) * scale - minv; + fp32x8.s6 = ((q4x8.s1 & 0x0F00) >> 8) * scale - minv; + fp32x8.s7 = ((q4x8.s1 & 0xF000) >> 12) * scale - minv; + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q4_k_f32_ns( + __global uint * src0_q, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; + int scales_per_row = num_superblocks * K_SCALE_SIZE; + + // Expert offsets in the transposed noshuffle layout + uint expert_q_offset = expert_id * (ne00 / 8) * ne01; + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; + uint j = ib % 8; + + // Load d and dmin for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01]; + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = (float)dm_val * (float)mn; + + // Load 4 uints of quants (32 nibbles = 32 elements) + uint q_base = expert_q_offset + ib * ne01 * 4 + i01; + + uint4 regQ; + regQ.s0 = src0_q[q_base]; + regQ.s1 = src0_q[q_base + ne01]; + regQ.s2 = src0_q[q_base + ne01 * 2]; + regQ.s3 = src0_q[q_base + ne01 * 3]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s0), scale, minv); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s1), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s2), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q4_k_to_fp32_packed8(as_ushort2(regQ.s3), scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl new file mode 100644 index 00000000000..b43613638a8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl @@ -0,0 +1,123 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q5_0 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q5_0_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) - 16); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) - 16); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) - 16); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) - 16); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) - 16); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) - 16); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) - 16); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) - 16); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_0_f32_ns( + __global uint * src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + uint ne00, + uint ne01, + uint ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q5_0); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_qs[block_offset]; + regQ.s1 = src0_qs[block_offset + ne01]; + regQ.s2 = src0_qs[block_offset + ne01 * 2]; + regQ.s3 = src0_qs[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + uchar4 regQh = as_uchar4(src0_qh[ib00 * ne01 + i01 + expert_offset]); + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_0_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += (float)(regS) * ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl new file mode 100644 index 00000000000..7a666006e68 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl @@ -0,0 +1,125 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_Q5_1 32 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q5_1_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) * s + m); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) * s + m); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) * s + m); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) * s + m); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) * s + m); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) * s + m); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) * s + m); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) * s + m); + return fp32x8; +} + + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_1_f32_ns( + __global uint * src0_qs, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_m, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + uint ne00, + uint ne01, + uint ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + uint expert_offset = expert_id * ne00 * ne01 / 32; + + __private float sum = 0.0f; // each thread calculate partial sum of one output + + // loop along ne00 in block granularity, skip 4 blocks every iter + for (uint ib00 = sgid; ib00 < (ne00 / QK_Q5_1); ib00 += N_SIMDGROUP) { + + // load one block of q + uint4 regQ; + uint block_offset = expert_offset * 4 + ib00 * ne01 * 4 + i01; + + regQ.s0 = src0_qs[block_offset]; + regQ.s1 = src0_qs[block_offset + ne01]; + regQ.s2 = src0_qs[block_offset + ne01 * 2]; + regQ.s3 = src0_qs[block_offset + ne01 * 3]; + + uint offset = i11 * ne00 / 4 + ib00 * 8; + + uchar4 regQh = as_uchar4(src0_qh[ib00 * ne01 + i01 + expert_offset]); + half regM = src0_m[ib00 * ne01 + i01 + expert_offset]; + half regS = src0_d[ib00 * ne01 + i01 + expert_offset]; + + float8 fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0, regS, regM); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1, regS, regM); + + shared_y4 = read_imagef(src1, (offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 3)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2, regS, regM); + + shared_y4 = read_imagef(src1, (offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 5)); + acc += shared_y4 * fp32x8.hi; + + + fp32x8 = q5_1_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3, regS, regM); + + shared_y4 = read_imagef(src1, (offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl new file mode 100644 index 00000000000..7d868d7abd9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl @@ -0,0 +1,160 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define K_SCALE_SIZE 12 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m +) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j+4] & 63; + } else { + *d = (q[j+4] & 0x0F) | ((q[j-4] & 0xC0) >> 2); + *m = ((q[j+4] >> 4) & 0x0F) | ((q[j] & 0xC0) >> 2); + } +} + +static inline float8 q5_k_to_fp32_packed8(ushort2 qs5x8, uchar qh5x8, half s, half m) { + float8 fp32x8; + fp32x8.s0 = (float)((( qs5x8.s0 & 0x000F) | (( qh5x8 & 0x01) << 4)) * s + m); + fp32x8.s1 = (float)((((qs5x8.s0 & 0x00F0) >> 4 ) | (((qh5x8 >> 1) & 0x01) << 4)) * s + m); + fp32x8.s2 = (float)((((qs5x8.s0 & 0x0F00) >> 8 ) | (((qh5x8 >> 2) & 0x01) << 4)) * s + m); + fp32x8.s3 = (float)((((qs5x8.s0 & 0xF000) >> 12) | (((qh5x8 >> 3) & 0x01) << 4)) * s + m); + fp32x8.s4 = (float)((( qs5x8.s1 & 0x000F) | (((qh5x8 >> 4) & 0x01) << 4)) * s + m); + fp32x8.s5 = (float)((((qs5x8.s1 & 0x00F0) >> 4 ) | (((qh5x8 >> 5) & 0x01) << 4)) * s + m); + fp32x8.s6 = (float)((((qs5x8.s1 & 0x0F00) >> 8 ) | (((qh5x8 >> 6) & 0x01) << 4)) * s + m); + fp32x8.s7 = (float)((((qs5x8.s1 & 0xF000) >> 12) | (((qh5x8 >> 7) & 0x01) << 4)) * s + m); + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q5_k_f32_ns( + __global uint * src0_q, + __global uint * src0_qh, + __global half * src0_d, + __global half * src0_dm, + __global uchar * src0_s, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; + int scales_per_row = num_superblocks * K_SCALE_SIZE; + + // Expert offsets in the transposed noshuffle layout + uint expert_q_offset = expert_id * (ne00 / 8) * ne01; + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; + uint j = ib % 8; + + // Load d and dmin for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + half dm_val = src0_dm[expert_d_offset + sb * ne01 + i01]; + + // sub_block index = sb * 8 + j + uint expert_qh_offset = expert_id * num_superblocks * 8 * ne01; + uchar4 regQh = as_uchar4(src0_qh[expert_qh_offset + (sb * 8 + j) * ne01 + i01]); + + // Load sub-block scale and min + global const uchar * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * K_SCALE_SIZE; + uchar sv, mn; + get_scale_min_k4(j, sc, &sv, &mn); + + float scale = (float)d_val * (float)sv; + float minv = -(float)dm_val * (float)mn; + + // Load 4 uints of quants (32 nibbles = 32 elements) + uint q_base = expert_q_offset + ib * ne01 * 4 + i01; + + uint4 regQ; + regQ.s0 = src0_q[q_base]; + regQ.s1 = src0_q[q_base + ne01]; + regQ.s2 = src0_q[q_base + ne01 * 2]; + regQ.s3 = src0_q[q_base + ne01 * 3]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s0), regQh.s0, scale, minv); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s1), regQh.s1, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s2), regQh.s2, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q5_k_to_fp32_packed8(as_ushort2(regQ.s3), regQh.s3, scale, minv); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl new file mode 100644 index 00000000000..c166bad5ba5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl @@ -0,0 +1,141 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#define QK_K 256 +#define N_SIMDGROUP 4 +#define SIMDGROUP_WIDTH 64 + +static inline float8 q6_k_to_fp32_packed8(ushort2 ql8, ushort qh8, float d_scale) { + float8 fp32x8; + fp32x8.s0 = ((float)(( ql8.s0 & 0x000F) | ((uint)((qh8 ) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s1 = ((float)((( ql8.s0 >> 4) & 0x000F) | ((uint)((qh8 >> 2) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s2 = ((float)((( ql8.s0 >> 8) & 0x000F) | ((uint)((qh8 >> 4) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s3 = ((float)((( ql8.s0 >> 12)& 0x000F) | ((uint)((qh8 >> 6) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s4 = ((float)(( ql8.s1 & 0x000F) | ((uint)((qh8 >> 8) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s5 = ((float)((( ql8.s1 >> 4) & 0x000F) | ((uint)((qh8 >>10) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s6 = ((float)((( ql8.s1 >> 8) & 0x000F) | ((uint)((qh8 >>12) & 0x3) << 4)) - 32.f) * d_scale; + fp32x8.s7 = ((float)((( ql8.s1 >> 12)& 0x000F) | ((uint)((qh8 >>14) & 0x3) << 4)) - 32.f) * d_scale; + return fp32x8; +} + +__attribute__((qcom_reqd_sub_group_size("half"))) +__kernel void kernel_gemv_moe_q6_k_f32_ns( + __global uint * src0_ql, + __global uint * src0_qh, + __global char * src0_s, + __global half * src0_d, + __read_only image1d_buffer_t src1, + __global uint * src2, + __global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne11 +) { + uint i01 = get_global_id(0); + uint i20 = get_global_id(2); + uint sgid = get_local_id(1); + uint slid = get_sub_group_local_id(); + + if (i01 >= ne01) { + return; + } + + uint i11 = i20 % ne11; + + uint expert_id = src2[i20]; + + int num_superblocks = ne00 / QK_K; + int num_subblocks = ne00 / 32; // 8 sub-blocks of 32 per super-block + int scales_per_row = num_superblocks * 16; + + // Expert offsets in the transposed noshuffle layout + uint expert_ql_offset = expert_id * (ne00 / 8) * ne01; // 32 uints per super-block + uint expert_qh_offset = expert_id * (ne00 / 16) * ne01; // 16 uints per super-block + uint expert_d_offset = expert_id * num_superblocks * ne01; + + __private float sum = 0.0f; + + // Loop over sub-blocks of 32 elements, N_SIMDGROUP sub-blocks per iter + for (uint ib = sgid; ib < num_subblocks; ib += N_SIMDGROUP) { + uint sb = ib / 8; // super-block index + uint j = ib % 8; // 32-element group within super-block + + // Load d for this super-block + half d_val = src0_d[expert_d_offset + sb * ne01 + i01]; + + // Load 2 sub-block scales + global const char * sc = src0_s + (expert_id * ne01 + i01) * scales_per_row + sb * 16; + float scale0 = (float)d_val * (float)sc[j * 2]; + float scale1 = (float)d_val * (float)sc[j * 2 + 1]; + + // Load 4 uints of ql + uint ql_base = expert_ql_offset + (ib * 4) * ne01 + i01; + uint4 regQL; + regQL.s0 = src0_ql[ql_base]; + regQL.s1 = src0_ql[ql_base + ne01]; + regQL.s2 = src0_ql[ql_base + ne01 * 2]; + regQL.s3 = src0_ql[ql_base + ne01 * 3]; + + // Load 2 uints of qh + uint qh_base = expert_qh_offset + (ib * 2) * ne01 + i01; + uint2 regQH; + regQH.s0 = src0_qh[qh_base]; + regQH.s1 = src0_qh[qh_base + ne01]; + + // Load activations: 32 floats = 8 float4s + uint y_offset = i11 * ne00 / 4 + ib * 8; + + float8 fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s0), (ushort)(regQH.s0 & 0xFFFF), scale0); + + float4 shared_y4; + shared_y4 = read_imagef(src1, (y_offset + 0)); + float4 acc = shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 1)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s1), (ushort)(regQH.s0 >> 16), scale0); + + shared_y4 = read_imagef(src1, (y_offset + 2)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 3)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s2), (ushort)(regQH.s1 & 0xFFFF), scale1); + + shared_y4 = read_imagef(src1, (y_offset + 4)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 5)); + acc += shared_y4 * fp32x8.hi; + + fp32x8 = q6_k_to_fp32_packed8(as_ushort2(regQL.s3), (ushort)(regQH.s1 >> 16), scale1); + + shared_y4 = read_imagef(src1, (y_offset + 6)); + acc += shared_y4 * fp32x8.lo; + + shared_y4 = read_imagef(src1, (y_offset + 7)); + acc += shared_y4 * fp32x8.hi; + + sum += ((acc.s0 + acc.s1) + (acc.s2 + acc.s3)); + } + + // reduction in local memory, assumes #subgroups=4 + __local float reduceLM[SIMDGROUP_WIDTH * (N_SIMDGROUP - 1)]; + if (sgid == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = sum; + if (sgid == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = sum; + if (sgid == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = sum; + barrier(CLK_LOCAL_MEM_FENCE); + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (sgid == 0) sum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 output per thread in subgroup 0 + if (sgid == 0) { + dst = dst + (offsetd >> 2); + dst[i01 + i20 * ne01] = sum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl new file mode 100644 index 00000000000..9386bf25a6f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl @@ -0,0 +1,302 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK4_NL 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +constant half kvalues_iq4nl[16] = { + (half)-127.f, (half)-104.f, (half)-83.f, (half)-65.f, + (half) -49.f, (half) -35.f, (half)-22.f, (half)-10.f, + (half) 1.f, (half) 13.f, (half) 25.f, (half) 38.f, + (half) 53.f, (half) 69.f, (half) 89.f, (half)113.f +}; + +// Packed LUT: 2 FP16 values per uint, 8 unique constant loads instead of 16 +constant uint iq4nl_packed[8] = { + 0xD680D7F0u, // idx 0,1: -127, -104 + 0xD410D530u, // idx 2,3: -83, -65 + 0xD060D220u, // idx 4,5: -49, -35 + 0xC900CD80u, // idx 6,7: -22, -10 + 0x4A803C00u, // idx 8,9: 1, 13 + 0x50C04E40u, // idx 10,11: 25, 38 + 0x545052A0u, // idx 12,13: 53, 69 + 0x57105590u // idx 14,15: 89, 113 +}; + +// Packed dequant: 1 uint constant load (8-way divergence) + shift + as_half +#define IQ4_NL_DEQUANT(nibble) as_half((ushort)(iq4nl_packed[(nibble) >> 1] >> (((nibble) & 1u) << 4))) + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s0 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s0 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s2 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s2 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s1 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s1 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s3 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s3 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s4 & 0x000F)) * scale.s0 * shared_y.s0; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x00F0) >> 4)) * scale.s0 * shared_y.s1; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0x0F00) >> 8)) * scale.s0 * shared_y.s2; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s4 & 0xF000) >> 12)) * scale.s0 * shared_y.s3; \ + total_sums.s0 += IQ4_NL_DEQUANT((bits4.s6 & 0x000F)) * scale.s0 * shared_y.s4; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x00F0) >> 4)) * scale.s0 * shared_y.s5; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0x0F00) >> 8)) * scale.s0 * shared_y.s6; \ + total_sums.s0 += IQ4_NL_DEQUANT(((bits4.s6 & 0xF000) >> 12)) * scale.s0 * shared_y.s7; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s5 & 0x000F)) * scale.s1 * shared_y.s0; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x00F0) >> 4)) * scale.s1 * shared_y.s1; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0x0F00) >> 8)) * scale.s1 * shared_y.s2; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s5 & 0xF000) >> 12)) * scale.s1 * shared_y.s3; \ + total_sums.s1 += IQ4_NL_DEQUANT((bits4.s7 & 0x000F)) * scale.s1 * shared_y.s4; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x00F0) >> 4)) * scale.s1 * shared_y.s5; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0x0F00) >> 8)) * scale.s1 * shared_y.s6; \ + total_sums.s1 += IQ4_NL_DEQUANT(((bits4.s7 & 0xF000) >> 12)) * scale.s1 * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_iq4_nl_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_NL); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32.cl similarity index 98% rename from ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32.cl index 469d3edef00..10683206919 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32.cl @@ -191,7 +191,7 @@ #ifdef ADRENO_GPU REQD_SUBGROUP_SIZE_64 #endif -__kernel void kernel_gemv_noshuffle( +__kernel void kernel_gemv_noshuffle_q4_0_f32( __read_only image1d_buffer_t src0_q, // quantized A global half2 * src0_d, // A scales __read_only image1d_buffer_t src1, // B @@ -238,21 +238,21 @@ __kernel void kernel_gemv_noshuffle( regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST } // reduction in local memory, assumes #wave=4 diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32_spec.cl similarity index 98% rename from ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl rename to ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32_spec.cl index ee5c79f000d..571a375da7f 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_0_f32_spec.cl @@ -191,7 +191,7 @@ #ifdef ADRENO_GPU REQD_SUBGROUP_SIZE_64 #endif -__kernel void kernel_gemv_noshuffle( +__kernel void kernel_gemv_noshuffle_q4_0_f32( __read_only image1d_buffer_t src0_q, // quantized A global half2 * src0_d, // A scales __read_only image1d_buffer_t src1, // B @@ -232,21 +232,21 @@ __kernel void kernel_gemv_noshuffle( regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; -#ifdef VECTOR_SUB_GROUP_BROADCAT +#ifdef VECTOR_SUB_GROUP_BROADCAST dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regB); #else dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regB); -#endif // VECTOR_SUB_GROUP_BROADCAT +#endif // VECTOR_SUB_GROUP_BROADCAST } // reduction in local memory, assumes #wave=4 diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl new file mode 100644 index 00000000000..fdc1472454f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl @@ -0,0 +1,283 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK4_0 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q4_1_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + global half2 * src0_m, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_0); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + regM = src0_m[gid + k * LINE_STRIDE_A]; // each fiber loads min of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl new file mode 100644 index 00000000000..dd1e2b55c0b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl @@ -0,0 +1,318 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK_K 256 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 - minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q4_k_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + global half2 * src0_m, + global uchar * src0_s, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + uint scales_per_row = (K / QK_K) * 12; + + private uint4 regA; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) { + uint sb = k / 8; + uint j = k % 8; + + half2 d = src0_d[gid + sb * LINE_STRIDE_A]; + half2 dm = src0_m[gid + sb * LINE_STRIDE_A]; + + global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12; + global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12; + + uchar sv0, mn0, sv1, mn1; + get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + + regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1))); + regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1))); + + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl new file mode 100644 index 00000000000..c228f717a94 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl @@ -0,0 +1,291 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK5_0 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) - 16) * scale.s0 * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) - 16) * scale.s1 * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle_q5_0_f32( + __read_only image1d_buffer_t src0_qs, // quantized A + global ushort * src0_qh, // 5th bits + global half2 * src0_d, // A scales + __read_only image1d_buffer_t src1, // B activations + global float * dst, + ulong offsetd, + int ne00, // K + int ne01) // M +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / QK5_0); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; + + ushort4 qh_raw; + qh_raw.s0 = src0_qh[gid + (4*k + 0) * LINE_STRIDE_A]; + qh_raw.s1 = src0_qh[gid + (4*k + 1) * LINE_STRIDE_A]; + qh_raw.s2 = src0_qh[gid + (4*k + 2) * LINE_STRIDE_A]; + qh_raw.s3 = src0_qh[gid + (4*k + 3) * LINE_STRIDE_A]; + + uchar8 raw = as_uchar8(qh_raw); + uchar8 qh_bytes = (uchar8)(raw.s0, raw.s2, raw.s4, raw.s6, + raw.s1, raw.s3, raw.s5, raw.s7); + + // Load activations + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#else + dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_0_sgbroadcast_8_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#else + dequantizeBlockAccum_ns_q5_0_sgbroadcast_1_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl new file mode 100644 index 00000000000..daf1308ea4b --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl @@ -0,0 +1,294 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK5_1 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s0 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s4 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s1 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s5 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | (((bits1.s2 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | (((bits1.s6 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | (((bits1.s3 ) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | (((bits1.s7 ) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 + minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle_q5_1_f32( + __read_only image1d_buffer_t src0_qs, // quantized A + global ushort * src0_qh, // 5th bits + global half2 * src0_d, // A scales + global half2 * src0_m, // A mins + __read_only image1d_buffer_t src1, // B activations + global float * dst, + ulong offsetd, + int ne00, // K + int ne01) // M +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + __private uint4 regA; + __private half2 regS; + __private half2 regM; + __private float8 regB; + + __private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / QK5_1); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; + regM = src0_m[gid + k * LINE_STRIDE_A]; + + ushort4 qh_raw; + qh_raw.s0 = src0_qh[gid + (4*k + 0) * LINE_STRIDE_A]; + qh_raw.s1 = src0_qh[gid + (4*k + 1) * LINE_STRIDE_A]; + qh_raw.s2 = src0_qh[gid + (4*k + 2) * LINE_STRIDE_A]; + qh_raw.s3 = src0_qh[gid + (4*k + 3) * LINE_STRIDE_A]; + + uchar8 raw = as_uchar8(qh_raw); + uchar8 qh_bytes = (uchar8)(raw.s0, raw.s2, raw.s4, raw.s6, + raw.s1, raw.s3, raw.s5, raw.s7); + + // Load activations + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#else + dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_hi(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_qs, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_q5_1_sgbroadcast_8_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#else + dequantizeBlockAccum_ns_q5_1_sgbroadcast_1_lo(totalSum, as_ushort8(regA), qh_bytes, regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl new file mode 100644 index 00000000000..c40db166638 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl @@ -0,0 +1,326 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK_K 256 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +inline void get_scale_min_k4( + int j, + global const uchar * q, + uchar * d, + uchar * m, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2 +) { + if (j < 4) { + *d = q[j] & mask_d6; + *m = q[j+4] & mask_d6; + } else { + *d = (q[j+4] & mask_d4) | ((q[j-4] & mask_hi2) >> 2); + *m = ((q[j+4] >> 4) & mask_d4) | ((q[j] & mask_hi2) >> 2); + } +} + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, bits1, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, bits1, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s0 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s0 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s0 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s0 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s0 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s0 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s0 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s0 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s1 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s1 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s1 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s1 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s1 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s1 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s1 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s1 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s2 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s2 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s2 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s2 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s2 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s2 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s2 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s2 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s3 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s3 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s3 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s3 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s3 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s3 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s3 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s3 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, bits1, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += (((bits4.s0 & 0x000F) | ((bits1.s4 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s0 & 0x00F0) >> 4) | (((bits1.s4 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s0 & 0x0F00) >> 8) | (((bits1.s4 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s0 & 0xF000) >> 12) | (((bits1.s4 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s2 & 0x000F) | (((bits1.s4 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s2 & 0x00F0) >> 4) | (((bits1.s4 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s2 & 0x0F00) >> 8) | (((bits1.s4 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s2 & 0xF000) >> 12) | (((bits1.s4 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s1 & 0x000F) | ((bits1.s5 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s1 & 0x00F0) >> 4) | (((bits1.s5 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s1 & 0x0F00) >> 8) | (((bits1.s5 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s1 & 0xF000) >> 12) | (((bits1.s5 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s3 & 0x000F) | (((bits1.s5 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s3 & 0x00F0) >> 4) | (((bits1.s5 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s3 & 0x0F00) >> 8) | (((bits1.s5 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s3 & 0xF000) >> 12) | (((bits1.s5 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += (((bits4.s4 & 0x000F) | ((bits1.s6 & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s0; \ + total_sums.s0 += ((((bits4.s4 & 0x00F0) >> 4) | (((bits1.s6 >> 1) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s1; \ + total_sums.s0 += ((((bits4.s4 & 0x0F00) >> 8) | (((bits1.s6 >> 2) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s2; \ + total_sums.s0 += ((((bits4.s4 & 0xF000) >> 12) | (((bits1.s6 >> 3) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s3; \ + total_sums.s0 += (((bits4.s6 & 0x000F) | (((bits1.s6 >> 4) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s4; \ + total_sums.s0 += ((((bits4.s6 & 0x00F0) >> 4) | (((bits1.s6 >> 5) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s5; \ + total_sums.s0 += ((((bits4.s6 & 0x0F00) >> 8) | (((bits1.s6 >> 6) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s6; \ + total_sums.s0 += ((((bits4.s6 & 0xF000) >> 12) | (((bits1.s6 >> 7) & 0x01) << 4)) * scale.s0 - minv.s0) * shared_y.s7; \ + total_sums.s1 += (((bits4.s5 & 0x000F) | ((bits1.s7 & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s0; \ + total_sums.s1 += ((((bits4.s5 & 0x00F0) >> 4) | (((bits1.s7 >> 1) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s1; \ + total_sums.s1 += ((((bits4.s5 & 0x0F00) >> 8) | (((bits1.s7 >> 2) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s2; \ + total_sums.s1 += ((((bits4.s5 & 0xF000) >> 12) | (((bits1.s7 >> 3) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s3; \ + total_sums.s1 += (((bits4.s7 & 0x000F) | (((bits1.s7 >> 4) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s4; \ + total_sums.s1 += ((((bits4.s7 & 0x00F0) >> 4) | (((bits1.s7 >> 5) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s5; \ + total_sums.s1 += ((((bits4.s7 & 0x0F00) >> 8) | (((bits1.s7 >> 6) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s6; \ + total_sums.s1 += ((((bits4.s7 & 0xF000) >> 12) | (((bits1.s7 >> 7) & 0x01) << 4)) * scale.s1 - minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q5_k_f32( + read_only image1d_buffer_t src0_q, + read_only image1d_buffer_t src0_qh, + global half2 * src0_d, + global half2 * src0_m, + global uchar * src0_s, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + uchar mask_d6, + uchar mask_d4, + uchar mask_hi2) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + uint LINE_STRIDE_A_QH = M / 2; + uint BLOCK_STRIDE_A_QH = NSUBGROUPS * M / 2; + uint scales_per_row = (K / QK_K) * 12; + + private uint4 regA; + private ushort4 regH; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + for (uint k = groupId; k < (K / 32); k += NSUBGROUPS) { + uint sb = k / 8; + uint j = k % 8; + + half2 d = src0_d[gid + sb * LINE_STRIDE_A]; + half2 dm = src0_m[gid + sb * LINE_STRIDE_A]; + + global const uchar * sc0 = src0_s + 2 * gid * scales_per_row + sb * 12; + global const uchar * sc1 = src0_s + (2 * gid + 1) * scales_per_row + sb * 12; + + uchar sv0, mn0, sv1, mn1; + get_scale_min_k4(j, sc0, &sv0, &mn0, mask_d6, mask_d4, mask_hi2); + get_scale_min_k4(j, sc1, &sv1, &mn1, mask_d6, mask_d4, mask_hi2); + + regS = convert_half2(convert_float2(d) * convert_float2((uchar2)(sv0, sv1))); + regM = convert_half2(convert_float2(dm) * convert_float2((uchar2)(mn0, mn1))); + + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + regH.s0 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 0)).x); + regH.s1 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 1)).x); + regH.s2 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 2)).x); + regH.s3 = as_ushort(read_imageh(src0_qh, (gid + k * BLOCK_STRIDE_A_QH + LINE_STRIDE_A_QH * 3)).x); + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAST + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), as_uchar8(regH), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAST + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl new file mode 100644 index 00000000000..6f89cf968b9 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl @@ -0,0 +1,293 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantize_block_acc_bcast_8_hi(total_sum, bits4, bits2, scale_d, scale_s, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y.s7; \ + +#define dequantize_block_acc_bcast_8_lo(total_sum, bits4, bits2, scale_d, scale_s, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s0; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s1; \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s2; \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s3; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s4; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s5; \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s6; \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y.s7; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s0; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s1; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s2; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s3; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s4; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s5; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s6; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y.s7; \ + +#define dequantize_block_acc_bcast_1_hi(total_sum, bits4, bits2, scale_d, scale_s, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s0 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s2 * scale_d.s1 * shared_y; \ + +#define dequantize_block_acc_bcast_1_lo(total_sum, bits4, bits2, scale_d, scale_s, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x000F) ) | ((bits2.s0 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x000F) ) | ((bits2.s1 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x00F0) >> 4) | ((bits2.s0 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x00F0) >> 4) | ((bits2.s1 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0x0F00) >> 8) | ((bits2.s0 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0x0F00) >> 8) | ((bits2.s1 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sum.s0 += ((float)(((bits4.s0 & 0xF000) >> 12) | ((bits2.s0 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s1 & 0xF000) >> 12) | ((bits2.s1 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x000F) ) | ((bits2.s2 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x000F) ) | ((bits2.s3 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x00F0) >> 4) | ((bits2.s2 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x00F0) >> 4) | ((bits2.s3 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0x0F00) >> 8) | ((bits2.s2 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0x0F00) >> 8) | ((bits2.s3 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sum.s0 += ((float)(((bits4.s2 & 0xF000) >> 12) | ((bits2.s2 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s3 & 0xF000) >> 12) | ((bits2.s3 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x000F) ) | ((bits2.s4 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x000F) ) | ((bits2.s5 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x00F0) >> 4) | ((bits2.s4 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x00F0) >> 4) | ((bits2.s5 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0x0F00) >> 8) | ((bits2.s4 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0x0F00) >> 8) | ((bits2.s5 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sum.s0 += ((float)(((bits4.s4 & 0xF000) >> 12) | ((bits2.s4 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s5 & 0xF000) >> 12) | ((bits2.s5 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x000F) ) | ((bits2.s6 & 0x03) << 4)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x000F) ) | ((bits2.s7 & 0x03) << 4)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x00F0) >> 4) | ((bits2.s6 & 0x0C) << 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x00F0) >> 4) | ((bits2.s7 & 0x0C) << 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0x0F00) >> 8) | ((bits2.s6 & 0x30) )) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0x0F00) >> 8) | ((bits2.s7 & 0x30) )) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sum.s0 += ((float)(((bits4.s6 & 0xF000) >> 12) | ((bits2.s6 & 0xC0) >> 2)) - 32.f) * scale_s.s1 * scale_d.s0 * shared_y; \ + total_sum.s1 += ((float)(((bits4.s7 & 0xF000) >> 12) | ((bits2.s7 & 0xC0) >> 2)) - 32.f) * scale_s.s3 * scale_d.s1 * shared_y; \ + +#if defined(ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q6_K_f32( + read_only image1d_buffer_t src0_ql, + read_only image1d_buffer_t src0_qh, + global half2 * src0_s, + global half2 * src0_d, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01 +) { + int grp = get_local_id(1); + int gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + int nb = ne00 / 32; + + uint4 reg_a_l; + ushort4 reg_a_h; + half2 reg_d; + char4 reg_s; + float8 reg_b; + + float2 total_sum = 0.0f; + + int line_stride_a = ne01 / 2; + int block_stride_a = NSUBGROUPS * ne01; + + for (int k = grp; k < nb; k += NSUBGROUPS) { + reg_d = src0_d[gid + k/8 * line_stride_a]; + reg_s = as_char4(src0_s[gid + k * line_stride_a]); + + if (slid < 4) { + reg_b.s0123 = read_imagef(src1, 0 + slid*2 + k*8); + reg_b.s4567 = read_imagef(src1, 1 + slid*2 + k*8); + } + + reg_a_l.s0 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*0).x; + reg_a_l.s1 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*1).x; + reg_a_l.s2 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*2).x; + reg_a_l.s3 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*3).x; + + reg_a_h.s0 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*0).x); + reg_a_h.s1 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*1).x); + reg_a_h.s2 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*2).x); + reg_a_h.s3 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*3).x); + +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantize_block_acc_bcast_8_hi(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#else + dequantize_block_acc_bcast_1_hi(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#endif // VECTOR_SUB_GROUP_BROADCAT + + reg_a_l.s0 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*4).x; + reg_a_l.s1 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*5).x; + reg_a_l.s2 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*6).x; + reg_a_l.s3 = read_imageui(src0_ql, gid + k*block_stride_a + line_stride_a*7).x; + + reg_a_h.s0 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*4).x); + reg_a_h.s1 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*5).x); + reg_a_h.s2 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*6).x); + reg_a_h.s3 = as_ushort(read_imageh(src0_qh, gid + k*block_stride_a + line_stride_a*7).x); + +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantize_block_acc_bcast_8_lo(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#else + dequantize_block_acc_bcast_1_lo(total_sum, as_ushort8(reg_a_l), as_uchar8(reg_a_h), reg_d, reg_s, reg_b); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + local float2 reduce_lm[SUBGROUP_SIZE * 3]; + if (grp == 1) { + reduce_lm[SUBGROUP_SIZE*0 + slid] = total_sum; + } + if (grp == 2) { + reduce_lm[SUBGROUP_SIZE*1 + slid] = total_sum; + } + if (grp == 3) { + reduce_lm[SUBGROUP_SIZE*2 + slid] = total_sum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*0 + slid]; + } + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*1 + slid]; + } + if (grp == 0) { + total_sum += reduce_lm[SUBGROUP_SIZE*2 + slid]; + } + + if (grp == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(total_sum, 0, &(dst[gid * 2])); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl new file mode 100644 index 00000000000..9703b693e56 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl @@ -0,0 +1,195 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK8_0 32 +#define N_SIMDGROUP 4 + +#define dequantizeBlockAccum_ns_sgbroadcast_1(total_sums, bits8, scale, y) \ + float shared_y; \ + char elem; \ + \ + shared_y = sub_group_broadcast(y.s0, 0); \ + elem = (char)(bits8.s0 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + elem = (char)((bits8.s0 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + elem = (char)((bits8.s0 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + elem = (char)((bits8.s0 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 0); \ + elem = (char)(bits8.s1 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + elem = (char)((bits8.s1 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + elem = (char)((bits8.s1 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + elem = (char)((bits8.s1 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s0, 1); \ + elem = (char)(bits8.s2 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + elem = (char)((bits8.s2 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + elem = (char)((bits8.s2 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + elem = (char)((bits8.s2 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 1); \ + elem = (char)(bits8.s3 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + elem = (char)((bits8.s3 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + elem = (char)((bits8.s3 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + elem = (char)((bits8.s3 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s0, 2); \ + elem = (char)(bits8.s4 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + elem = (char)((bits8.s4 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + elem = (char)((bits8.s4 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + elem = (char)((bits8.s4 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 2); \ + elem = (char)(bits8.s5 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + elem = (char)((bits8.s5 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + elem = (char)((bits8.s5 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + elem = (char)((bits8.s5 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s0, 3); \ + elem = (char)(bits8.s6 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + elem = (char)((bits8.s6 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + elem = (char)((bits8.s6 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + elem = (char)((bits8.s6 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + \ + shared_y = sub_group_broadcast(y.s4, 3); \ + elem = (char)(bits8.s7 & 0x000000FF); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + elem = (char)((bits8.s7 & 0x0000FF00) >> 8); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + elem = (char)((bits8.s7 & 0x00FF0000) >> 16); \ + total_sums += convert_int(elem) * scale * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + elem = (char)((bits8.s7 & 0xFF000000) >> 24); \ + total_sums += convert_int(elem) * scale * shared_y; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +__kernel void kernel_gemv_noshuffle_q8_0_f32( + __read_only image1d_buffer_t src0_q, // quantized A + global half * src0_d, // A scales + __read_only image1d_buffer_t src1, // B + ulong offset1, // offset to B (0) + global float * dst, // C + ulong offsetd, // offset to C + int ne00, // K + int ne01, // M + int ne02, // 1 + int ne10, // K + int ne12, // 1 + int ne0, // M + int ne1, // N + int r2, // 1 + int r3) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M; + uint BLOCK_STRIDE_A = 8 * M; // 32 / 4 = 8 + + __private uint8 regA; + __private half regS; + __private float8 regB; + + __private float totalSum = (float)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + #pragma unroll 1 /* tell compiler not to unroll */ + for (uint k = groupId; k < (K / QK8_0); k += N_SIMDGROUP) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of one rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load weights for one block in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; + regA.s4 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s5 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; + + dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB); + } + + // reduction in local memory, assumes #wave=4 + __local float reduceLM[SIMDGROUP_WIDTH * 3]; + if (groupId == 1) reduceLM[SIMDGROUP_WIDTH * 0 + slid] = totalSum; + if (groupId == 2) reduceLM[SIMDGROUP_WIDTH * 1 + slid] = totalSum; + if (groupId == 3) reduceLM[SIMDGROUP_WIDTH * 2 + slid] = totalSum; + barrier(CLK_LOCAL_MEM_FENCE); + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 0 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 1 + slid]; + if (groupId == 0) totalSum += reduceLM[SIMDGROUP_WIDTH * 2 + slid]; + + // 1 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + dst[gid] = totalSum; + } +} diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl index c2962edc983..9ae4fff09fc 100644 --- a/ggml/src/ggml-opencl/kernels/get_rows.cl +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -82,21 +82,27 @@ kernel void kernel_get_rows_f32( src1 = (global int*)((global char*)src1 + offset1); dst = (global float*)((global char*)dst + offsetd); - int i10 = get_group_id(0); - int i11 = get_group_id(1); - int i12 = get_group_id(2); + int nchunks = get_num_groups(0) / ne10; + int g = get_group_id(0); + int i10 = g / nchunks; + int chunk = g - i10 * nchunks; + int i11 = get_group_id(1); + int i12 = get_group_id(2); int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; int i03 = i12; - for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - if (ind >= ne00) { - return; - } - ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; + global float * dst_row = (global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1); + global float * src_row = (global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03); + + int span = (ne00 + nchunks - 1) / nchunks; + int start = chunk * span; + int end = min(start + span, ne00); + + for (int ind = start + get_local_id(0); ind < end; ind += get_local_size(0)) { + dst_row[ind] = src_row[ind]; } } diff --git a/ggml/src/ggml-opencl/kernels/l2_norm.cl b/ggml/src/ggml-opencl/kernels/l2_norm.cl new file mode 100644 index 00000000000..fb95355a679 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/l2_norm.cl @@ -0,0 +1,71 @@ +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_l2_norm_f32( + global void * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + float eps, + local float * sum +) { + src0 = (global void*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float * x = (global float *) ((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float * y = (global float *) (dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + + float sumf = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + sumf += x[i00] * x[i00]; + } + sumf = sub_group_reduce_add(sumf); + + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + // broadcast + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + const float scale = 1.0f/max(sqrt(sum[0]), eps); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + y[i00] = x[i00] * scale; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mean.cl b/ggml/src/ggml-opencl/kernels/mean.cl index 5c3e8bcd863..7c7e0a587ee 100644 --- a/ggml/src/ggml-opencl/kernels/mean.cl +++ b/ggml/src/ggml-opencl/kernels/mean.cl @@ -1,8 +1,13 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +// Most devices have max workgroup size of 1024, so this is enough for subgroup +// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes +#define MAX_SUBGROUPS 64 kernel void kernel_mean_f32( - global float * src0, + global char * src0, ulong offset0, - global float * dst, + global char * dst, ulong offsetd, int ne00, int ne01, @@ -15,25 +20,121 @@ kernel void kernel_mean_f32( ulong nb2, ulong nb3 ) { - src0 = (global float *)((global char *)src0 + offset0); - dst = (global float *)((global char *)dst + offsetd); + src0 = src0 + offset0; + dst = dst + offsetd; - int i3 = get_global_id(2); - int i2 = get_global_id(1); - int i1 = get_global_id(0); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { return; } - global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; + } - float row_sum = 0; + global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); - for (int i0 = 0; i0 < ne00; i0++) { - row_sum += src_row[i0]; + float sumf = 0.0f; + + for (int i0 = lid; i0 < ne00; i0 += lsize) { + sumf += src_row[i0]; } - dst_row[0] = row_sum / ne00; + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf / ne00; + } +} + +kernel void kernel_mean_f32_4( + global char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; + } + + global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); + + float4 sum_vec = (float4)0.0f; + + for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) { + sum_vec += src_row[i0]; + } + + float sumf = dot(sum_vec, (float4)(1.0f)); + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf / ne00; + } } diff --git a/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl b/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl new file mode 100644 index 00000000000..e6295c81648 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl @@ -0,0 +1,30 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define QK4_0 32 + +kernel void kernel_moe_reorder_b( + global float4 * src, + global uint * router, + global float4 * dst, + global int * total_tiles, + uint K, + ushort map_ratio, + uint tile_size +) { + uint k_4 = get_global_id(0); + uint post_router_idx = get_global_id(1); + + if ((k_4 >= (K / 4)) || (post_router_idx >= total_tiles[0] * tile_size)) { + return; + } + + uint router_idx = router[post_router_idx]; + + float4 out = (float4)(0); + if (router_idx != 0xFFFFFFFF) { + ushort activation_idx = router_idx / map_ratio; + out = src[activation_idx * K / 4 + k_4]; + } + + dst[post_router_idx * K / 4 + k_4] = out; +} diff --git a/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl b/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl new file mode 100644 index 00000000000..d9703429b11 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl @@ -0,0 +1,82 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +__kernel void kernel_moe_histogram( + __global const int * input, + __global int * hist, + uint N, + uint topK, + uint n_experts +) { + uint n = get_global_id(0); + uint k = get_global_id(1); + + if (n >= N || k >= topK) { + return; + } + + int expert_id = input[n * n_experts + k]; + atomic_inc(&hist[expert_id]); +} + +__kernel void kernel_moe_scan( + __global int * hist, + __global int * tile_offset, + __global int * total_tiles, + __global int * slot_counter, + int tile_size, + uint n_experts +) { + int offset = 0; + for (int v = 0; v < n_experts; v++) { + int count = hist[v]; + int tiles = (count + tile_size - 1) / tile_size; + tile_offset[v] = offset; + offset += tiles; + hist[v] = 0; + slot_counter[v] = 0; + } + + *total_tiles = offset; +} + +__kernel void kernel_moe_scatter( + __global const int * input, + __global int * post_router, + __global ushort * emap, + __global const int * tile_offset, + __global int * slot_counter, + int N, + int topK, + uint n_experts +) { + uint n = get_global_id(0); + uint k = get_global_id(1); + + if (n >= N || k >= topK) { + return; + } + + int val = input[n * n_experts + k]; + + int local_slot = atomic_inc(&slot_counter[val]); + + int tile_idx = tile_offset[val] + (local_slot / 32); + int lane = local_slot % 32; + int out_pos = tile_idx * 32 + lane; + + post_router[out_pos] = n * topK + k; + emap[tile_idx] = val; +} + +__kernel void kernel_moe_fill( + __global int * post_router, + __global int * total_tiles, + int tile_size +) { + int tile_id = get_global_id(0); + int vec_id_in_tile = get_global_id(1); + + if (tile_id < total_tiles[0]) { + post_router[tile_id * tile_size + vec_id_in_tile] = 0xFFFFFFFF; + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl new file mode 100644 index 00000000000..11ff7f8d9dc --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl @@ -0,0 +1,171 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +kernel void kernel_mul_mm_iq4_nl_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + // IQ4_NL: use lookup table instead of linear (nibble - 8) + float4 v1 = (float4)(kvalues_iq4nl[(q.s0 )&0x0F], kvalues_iq4nl[(q.s1 )&0x0F], + kvalues_iq4nl[(q.s2 )&0x0F], kvalues_iq4nl[(q.s3 )&0x0F])*d; + float4 v2 = (float4)(kvalues_iq4nl[(q.s0>>4)&0x0F], kvalues_iq4nl[(q.s1>>4)&0x0F], + kvalues_iq4nl[(q.s2>>4)&0x0F], kvalues_iq4nl[(q.s3>>4)&0x0F])*d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl new file mode 100644 index 00000000000..4100e3080a2 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl @@ -0,0 +1,163 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_0_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d; + float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl new file mode 100644 index 00000000000..d0d2f08361e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl @@ -0,0 +1,165 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_1_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global half * src0_m, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + float m = (float)src0_m[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m; + float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl new file mode 100644 index 00000000000..2235b1ae838 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl @@ -0,0 +1,179 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_k_f32_l4_lm( + global uchar4 * src0_q, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 64; + int iqs = (idx % 64) * 2; + + int n = iqs / 32; + int b = (iqs % 32) / 16; + int is = 2 * n + b; + int qsi = n * 32 + (iqs % 16) * 2; + + char * scales = src0_s + ib * 12; + + int scidx0 = (is < 4) ? is : (is + 4); + int scidx1 = (is < 4) ? is : (is - 4); + int scidxmask1 = (is < 4) ? 0x30 : 0xC0; + int scidxshift1 = (is < 4) ? 0 : 2; + int mbidx0 = is + 4; + int mbidx1 = (is < 4) ? is + 4 : is; + int mbidxmask0 = (is < 4) ? 0xF : 0xF0; + int mbidxshift0 = (is < 4) ? 0 : 4; + int mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + int mbidxshift1 = (is < 4) ? 0 : 2; + + uchar sc = (scales[scidx0] & 0xF) | ((scales[scidx1] & scidxmask1) >> scidxshift1); + uchar mbyte = ((scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((scales[mbidx1] & mbidxmask1) >> mbidxshift1); + + float d = (float)src0_d[ib] * (float)sc; + float m = -(float)src0_dm[ib] * (float)mbyte; + + global uchar4 * qs = src0_q + ib*32 + (qsi >> 2); + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 >> (b * 4))&0x0F, (q.s1 >> (b * 4))&0x0F, (q.s2 >> (b * 4))&0x0F, (q.s3 >> (b * 4))&0x0F)))*d + m; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v1.s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl new file mode 100644 index 00000000000..1e980a478a8 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl @@ -0,0 +1,173 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_0_f32_l4_lm( + global uchar4 * src0_qs, + global uint * src0_qh, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + uint qh_val = src0_qh[ib]; + + global uchar4 * qs_ptr = src0_qs + ib*4 + iqs; + uchar4 q = *qs_ptr; + + uint qh_lo = qh_val >> (iqs * 4); + uint qh_hi = qh_val >> (iqs * 4 + 16); + + uchar4 b_lo = (uchar4)((uchar)qh_lo, (uchar)(qh_lo >> 1), (uchar)(qh_lo >> 2), (uchar)(qh_lo >> 3)) & (uchar)1; + uchar4 b_hi = (uchar4)((uchar)qh_hi, (uchar)(qh_hi >> 1), (uchar)(qh_hi >> 2), (uchar)(qh_hi >> 3)) & (uchar)1; + + float4 v1 = (convert_float4((q & (uchar)0x0F) | (b_lo << (uchar)4)) - 16.0f) * d; + float4 v2 = (convert_float4((q >> (uchar)4) | (b_hi << (uchar)4)) - 16.0f) * d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl new file mode 100644 index 00000000000..ba06be54697 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl @@ -0,0 +1,175 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_1_f32_l4_lm( + global uchar4 * src0_qs, + global uint * src0_qh, + global half * src0_d, + global half * src0_m, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + float m = (float)src0_m[ib]; + uint qh_val = src0_qh[ib]; + + global uchar4 * qs = src0_qs + ib*4 + iqs; + uchar4 q = *qs; + + uint qh_lo = qh_val >> (iqs * 4); + uint qh_hi = qh_val >> (iqs * 4 + 16); + + uchar4 b_lo = (uchar4)((uchar)qh_lo, (uchar)(qh_lo >> 1), (uchar)(qh_lo >> 2), (uchar)(qh_lo >> 3)) & (uchar)1; + uchar4 b_hi = (uchar4)((uchar)qh_hi, (uchar)(qh_hi >> 1), (uchar)(qh_hi >> 2), (uchar)(qh_hi >> 3)) & (uchar)1; + + float4 v1 = convert_float4((q & (uchar)0x0F) | (b_lo << (uchar)4)) * d + m; + float4 v2 = convert_float4((q >> (uchar)4) | (b_hi << (uchar)4)) * d + m; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl new file mode 100644 index 00000000000..8e191f57e83 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl @@ -0,0 +1,192 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 4 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q5_k_f32_l4_lm( + global uchar4 * src0_q, + global uchar * src0_qh, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 64; + int iqs = (idx % 64) * 2; + + int n = iqs / 32; + int b = (iqs % 32) / 16; + int is = 2 * n + b; + int qsi = n * 32 + (iqs % 16) * 2; + + global uchar * scales = src0_s + ib * 12; + + int scidx0 = (is < 4) ? is : (is + 4); + int scidx1 = (is < 4) ? is : (is - 4); + int scidxmask1 = (is < 4) ? 0x30 : 0xC0; + int scidxshift1 = (is < 4) ? 0 : 2; + int mbidx0 = is + 4; + int mbidx1 = (is < 4) ? is + 4 : is; + int mbidxmask0 = (is < 4) ? 0xF : 0xF0; + int mbidxshift0 = (is < 4) ? 0 : 4; + int mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + int mbidxshift1 = (is < 4) ? 0 : 2; + + uchar sc = (scales[scidx0] & 0xF) | ((scales[scidx1] & scidxmask1) >> scidxshift1); + uchar mbyte = ((scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((scales[mbidx1] & mbidxmask1) >> mbidxshift1); + + float d = (float)src0_d[ib] * (float)sc; + float m = -(float)src0_dm[ib] * (float)mbyte; + + int qh_base = (iqs % 16) * 2; + int bit_pos = 2*n + b; + uchar h0 = (src0_qh[ib*32 + qh_base + 0] >> bit_pos) & 1; + uchar h1 = (src0_qh[ib*32 + qh_base + 1] >> bit_pos) & 1; + uchar h2 = (src0_qh[ib*32 + qh_base + 2] >> bit_pos) & 1; + uchar h3 = (src0_qh[ib*32 + qh_base + 3] >> bit_pos) & 1; + + global uchar4 * qs = src0_q + ib*32 + (qsi >> 2); + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)( + ((q.s0 >> (b * 4))&0x0F) | (h0 << 4), + ((q.s1 >> (b * 4))&0x0F) | (h1 << 4), + ((q.s2 >> (b * 4))&0x0F) | (h2 << 4), + ((q.s3 >> (b * 4))&0x0F) | (h3 << 4) + )))*d + m; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = v1.s3; + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 3) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl new file mode 100644 index 00000000000..3602c92fef4 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl @@ -0,0 +1,158 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 2 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q6_k_f32_l4_lm( + global uchar * src0_ql, + global uchar * src0_qh, + global char * src0_s, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + + int ib = idx / 128; // 2 values per idx + int iqs = idx % 128; // 0..127 + + int n = iqs / 64; // 0,1 + int b = (iqs % 64) / 32; // 0,1 + int is_b = (iqs % 16) / 8; // 0,1 + int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + int is = 8 * n + qhshift + is_b; // 0..15 + int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is]; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32); + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32); + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl new file mode 100644 index 00000000000..a6a325cd729 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl @@ -0,0 +1,164 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_NL 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + +//------------------------------------------------------------------------------ +// mul_vec_q_n_f32 +//------------------------------------------------------------------------------ +// Compute inner product between half a block of iq4_nl and 16 floats (yl). +// il indicates where the quants begin (0 or 8). +inline float block_iq4_nl_dot_y( + global struct block_iq4_nl * qb_curr, + private float * yl, + int il +) { + float d = qb_curr->d; + float acc = 0.f; + global uchar * qs = qb_curr->qs + il; + for (int i = 0; i < 8; ++i) { + acc += yl[i] * kvalues_iq4nl[qs[i] & 0x0F]; + acc += yl[i+8] * kvalues_iq4nl[qs[i] >> 4]; + } + return d * acc; +} + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup group works on 4 rows +#define N_SUBGROUP 1 // number of subgroups in a thread group +#define N_SUBGROUP_SIZE 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SUBGROUP 1 +#define N_SUBGROUP_SIZE 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + + const ulong nb = ne00/QK4_NL; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SUBGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_iq4_nl * x = (global struct block_iq4_nl *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; // src1 vector cache + float sumf[N_DST]={0.f}; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_NL + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += N_SUBGROUP_SIZE/2) { + for (int i = 0; i < 8; ++i) { + yl[i] = yb[i]; + yl[i+8] = yb[i+16]; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_iq4_nl_dot_y(x+ib+row*nb, yl, il); + } + + yb += QK4_NL * (N_SUBGROUP_SIZE/2); + } + + float tot[N_DST] = { + sub_group_reduce_add(sumf[0]), sub_group_reduce_add(sumf[1]), + sub_group_reduce_add(sumf[2]), sub_group_reduce_add(sumf[3])}; + for (int row = 0; row < N_DST; ++row) { + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot[row]; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_iq4_nl_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl new file mode 100644 index 00000000000..8c5b3f52e42 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl @@ -0,0 +1,202 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_NL 32 + +typedef char int8_t; +typedef uchar uint8_t; +typedef short int16_t; +typedef ushort uint16_t; +typedef int int32_t; +typedef uint uint32_t; + +constant float kvalues_iq4nl[16] = { + -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, + 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f +}; + +//------------------------------------------------------------------------------ +// block_iq4_nl +//------------------------------------------------------------------------------ +struct block_iq4_nl +{ + half d; + uint8_t qs[QK4_NL / 2]; +}; + +// Compute dot product between half a block of iq4_nl quants and activations. +// x points to the quant bytes, dh points to the scale. +// yl has 16 activation values: [0..7] for low nibbles, [8..15] for high nibbles. +// il indicates offset into the quant bytes (0 or 8). +inline float block_iq4_nl_dot_y_flat( + global uchar * x, + global half * dh, + private float * yl, + int il +) { + float d = *dh; + global uchar * qs = x + il; + float acc = 0.f; + for (int i = 0; i < 8; ++i) { + acc += yl[i] * kvalues_iq4nl[qs[i] & 0x0F]; + acc += yl[i+8] * kvalues_iq4nl[qs[i] >> 4]; + } + return d * acc; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 8 // each subgroup works on 8 rows +#define N_SUBGROUP 1 // number of subgroups in a thread group +#define N_SUBGROUP_SIZE 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 8 +#define N_SUBGROUP 1 +#define N_SUBGROUP_SIZE 64 +#endif + +inline void mul_vec_q_n_f32_8x_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_NL; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SUBGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + // The number of scales is the same as the number of blocks. + ulong offset0_d = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + // Each block contains QK4_NL/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_NL/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_d; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float yl[16]; + float8 sumf = 0.f; + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix*QK4_NL + il; + + for (int ib = ix; ib < nb; ib += N_SUBGROUP_SIZE/2) { + for (int i = 0; i < 8; ++i) { + yl[i] = yb[i]; + yl[i+8] = yb[i+16]; + } + + sumf.s0 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 0*nb*QK4_NL/2, d + ib + 0*nb, yl, il); + sumf.s1 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 1*nb*QK4_NL/2, d + ib + 1*nb, yl, il); + sumf.s2 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 2*nb*QK4_NL/2, d + ib + 2*nb, yl, il); + sumf.s3 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 3*nb*QK4_NL/2, d + ib + 3*nb, yl, il); + + sumf.s4 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 4*nb*QK4_NL/2, d + ib + 4*nb, yl, il); + sumf.s5 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 5*nb*QK4_NL/2, d + ib + 5*nb, yl, il); + sumf.s6 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 6*nb*QK4_NL/2, d + ib + 6*nb, yl, il); + sumf.s7 += block_iq4_nl_dot_y_flat(x + ib*QK4_NL/2 + 7*nb*QK4_NL/2, d + ib + 7*nb, yl, il); + + yb += QK4_NL * (N_SUBGROUP_SIZE/2); + } + + float8 tot = (float8)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3), + sub_group_reduce_add(sumf.s4), sub_group_reduce_add(sumf.s5), + sub_group_reduce_add(sumf.s6), sub_group_reduce_add(sumf.s7) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + + if (first_row + 4 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 4] = tot.s4; + } + if (first_row + 5 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 5] = tot.s5; + } + if (first_row + 6 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 6] = tot.s6; + } + if (first_row + 7 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 7] = tot.s7; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_iq4_nl_f32_flat( + global uchar * src0_q, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_8x_flat(src0_q, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl new file mode 100644 index 00000000000..6fe828f20e7 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl @@ -0,0 +1,219 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_1 32 + +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +inline float block_q4_1_dot_y( + global const struct block_q4_1 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float m = qb_curr->m; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il); + + yb += QK4_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_1_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl new file mode 100644 index 00000000000..d7c4645d675 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl @@ -0,0 +1,229 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_1 32 + +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +inline float block_q4_1_dot_y_flat( + global const uchar * x, + global const half * dh, + global const half * mh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + float m = *mh; + global const ushort * qs = ((global const ushort *) x + il/2); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_q, + global void * src0_d, + global void * src0_m, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + // The number of scales/mins is the same as the number of blocks. + ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)); + // Each block contains QK4_1/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_dm; + global half * m = (global half *) src0_m + offset0_dm; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il); + + yb += QK4_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_1_f32_flat( + global void * src0_q, + global void * src0_d, + global void * src0_m, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl new file mode 100644 index 00000000000..71ab9898213 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl @@ -0,0 +1,180 @@ +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define K_SCALE_SIZE 12 + +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qs[QK_K/2]; // 4-bit quants +} block_q4_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // number of rows each SIMD group works on +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_K_f32( + global char * src0, + int offset0, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; // super block index + int it = get_sub_group_local_id()%8; // block index (inside super block) + int iq = it/4; // 0 or 1 - first or second half of the super block + int ir = it%4; // 0...3 - block index in the half super block + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * sc = (global ushort *)x[ib].scales + iq; + global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir; + global half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F); + acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00); + acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0); + acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000); + acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F); + acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00); + acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0); + acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl new file mode 100644 index 00000000000..d92fb968904 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl @@ -0,0 +1,196 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define BLOCK_Q4K_SIZE 144 +#define K_SCALE_SIZE 12 + +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qs[QK_K/2]; // 4-bit quants +} block_q4_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // number of rows each SIMD group works on +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_K_f32_flat( + global uchar * src0_q, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; + int it = get_sub_group_local_id()%8; + int iq = it/4; + int ir = it%4; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = (first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/BLOCK_Q4K_SIZE; + uint blk = nb01 / BLOCK_Q4K_SIZE; + global uchar * blk_q = (global uchar *)src0_q + offset_src0*(QK_K/2); + global uchar * blk_s = (global uchar *)src0_s + offset_src0*K_SCALE_SIZE; + global half * blk_d = (global half *)src0_d + offset_src0; + global half * blk_dm = (global half *)src0_dm + offset_src0; + + int offset_src1 = r1*nb11 + (i12)*nb12 + (i13)*nb13; + global float * y = (global float *)(src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * q1 = (global ushort *)(blk_q + ib * (QK_K/2)) + (16 * iq + 4 * ir); + global ushort * sc = (global ushort *)(blk_s + ib * K_SCALE_SIZE) + iq; + global half * d = blk_d + ib; + global half * dm = blk_dm + ib; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F); + acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00); + acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0); + acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000); + acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F); + acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00); + acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0); + acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = *d; + float dmin = *dm; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += blk*64; + sc += blk*6; + d += blk; + dm += blk; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl new file mode 100644 index 00000000000..6d8c9e8f037 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl @@ -0,0 +1,241 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_0 32 + +struct block_q5_0 { + half d; + uchar qh[4]; + uchar qs[QK5_0 / 2]; +}; + +inline float block_q5_0_dot_y( + global const struct block_q5_0 * qb_curr, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = qb_curr->d; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *)((global const uchar *) qb_curr + 6 + il)); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *((global const uint *)((global const uchar *) qb_curr + 2)); + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum - 16.0f * sumy); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q5_0 * x = (global struct block_q5_0 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_0_dot_y(x+ib+0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_0_dot_y(x+ib+1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_0_dot_y(x+ib+2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_0_dot_y(x+ib+3*nb, sumy, yl, il, yb); + + yb += QK5_0 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_0_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl new file mode 100644 index 00000000000..34ec133d398 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl @@ -0,0 +1,243 @@ + +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_0 32 + +inline float block_q5_0_dot_y_flat( + global const uchar * x, + global const uint * qh_ptr, + global const half * dh, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = *dh; + global const ushort * qs = ((global const ushort *)(x + il)); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *qh_ptr; + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum - 16.0f * sumy); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_0; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + ulong offset0_qs = offset0 * (QK5_0/2); + + global uchar * x = (global uchar *) src0_qs + offset0_qs; + global uint * qh = (global uint *) src0_qh + offset0; + global half * d = (global half *) src0_d + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_0 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 0*nb*(QK5_0/2), qh + ib + 0*nb, d + ib + 0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 1*nb*(QK5_0/2), qh + ib + 1*nb, d + ib + 1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 2*nb*(QK5_0/2), qh + ib + 2*nb, d + ib + 2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_0_dot_y_flat(x + ib*(QK5_0/2) + 3*nb*(QK5_0/2), qh + ib + 3*nb, d + ib + 3*nb, sumy, yl, il, yb); + + yb += QK5_0 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_0_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_qs, src0_qh, src0_d, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl new file mode 100644 index 00000000000..1480f675038 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl @@ -0,0 +1,243 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_1 32 + +struct block_q5_1 { + half d; + half m; + uchar qh[4]; + uchar qs[QK5_1 / 2]; +}; + +inline float block_q5_1_dot_y( + global const struct block_q5_1 * qb_curr, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = qb_curr->d; + float m = qb_curr->m; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *)((global const uchar *) qb_curr + 8 + il)); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *((global const uint *)((global const uchar *) qb_curr + 4)); + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q5_1 * x = (global struct block_q5_1 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_1_dot_y(x+ib+0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_1_dot_y(x+ib+1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_1_dot_y(x+ib+2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_1_dot_y(x+ib+3*nb, sumy, yl, il, yb); + + yb += QK5_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_1_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl new file mode 100644 index 00000000000..57c2f140958 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl @@ -0,0 +1,247 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK5_1 32 + +inline float block_q5_1_dot_y_flat( + global const uchar * x, + global const uint * qh_ptr, + global const half * dh, + global const half * mh, + float sumy, + float16 yl, + int il, + global const float * yb +) { + float d = *dh; + float m = *mh; + global const ushort * qs = ((global const ushort *)(x + il)); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + uint qh_val = *qh_ptr; + uchar qh_lo = (uchar)((qh_val >> il) & 0xFF); + uchar qh_hi = (uchar)((qh_val >> (il + 16)) & 0xFF); + + float qh_sum = 0.0f; + qh_sum += yb[0] * (float)((qh_lo >> 0) & 1); + qh_sum += yb[1] * (float)((qh_lo >> 1) & 1); + qh_sum += yb[2] * (float)((qh_lo >> 2) & 1); + qh_sum += yb[3] * (float)((qh_lo >> 3) & 1); + qh_sum += yb[4] * (float)((qh_lo >> 4) & 1); + qh_sum += yb[5] * (float)((qh_lo >> 5) & 1); + qh_sum += yb[6] * (float)((qh_lo >> 6) & 1); + qh_sum += yb[7] * (float)((qh_lo >> 7) & 1); + qh_sum += yb[16] * (float)((qh_hi >> 0) & 1); + qh_sum += yb[17] * (float)((qh_hi >> 1) & 1); + qh_sum += yb[18] * (float)((qh_hi >> 2) & 1); + qh_sum += yb[19] * (float)((qh_hi >> 3) & 1); + qh_sum += yb[20] * (float)((qh_hi >> 4) & 1); + qh_sum += yb[21] * (float)((qh_hi >> 5) & 1); + qh_sum += yb[22] * (float)((qh_hi >> 6) & 1); + qh_sum += yb[23] * (float)((qh_hi >> 7) & 1); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3 + 16.0f * qh_sum) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global void * src0_m, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK5_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + ulong offset0_qs = offset0 * (QK5_1/2); + + global uchar * x = (global uchar *) src0_qs + offset0_qs; + global uint * qh = (global uint *) src0_qh + offset0; + global half * d = (global half *) src0_d + offset0; + global half * ms = (global half *) src0_m + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK5_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 0*nb*(QK5_1/2), qh + ib + 0*nb, d + ib + 0*nb, ms + ib + 0*nb, sumy, yl, il, yb); + sumf.s1 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 1*nb*(QK5_1/2), qh + ib + 1*nb, d + ib + 1*nb, ms + ib + 1*nb, sumy, yl, il, yb); + sumf.s2 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 2*nb*(QK5_1/2), qh + ib + 2*nb, d + ib + 2*nb, ms + ib + 2*nb, sumy, yl, il, yb); + sumf.s3 += block_q5_1_dot_y_flat(x + ib*(QK5_1/2) + 3*nb*(QK5_1/2), qh + ib + 3*nb, d + ib + 3*nb, ms + ib + 3*nb, sumy, yl, il, yb); + + yb += QK5_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_1_f32_flat( + global void * src0_qs, + global void * src0_qh, + global void * src0_d, + global void * src0_m, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_qs, src0_qh, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl new file mode 100644 index 00000000000..b2058abc1b6 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl @@ -0,0 +1,187 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK_K 256 +#define K_SCALE_SIZE 12 + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qh[QK_K/8]; // quants, high bit (1 bit per value, packed 8 per byte) + uchar qs[QK_K/2]; // quants, low 4 bits (2 values per byte) +} block_q5_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined(ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_K_f32( + global char * src0, + int offset0, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; // super block index + int it = get_sub_group_local_id()%8; // block index (inside super block) + int iq = it/4; // 0 or 1 - first or second half of the super block + int ir = it%4; // 0...3 - block index in the half super block + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global block_q5_K * x = (global block_q5_K *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uchar u1_lo = (uchar)(1 << (2*iq)); + uchar u2_lo = (uchar)(2 << (2*iq)); + uchar u1_hi = (uchar)(1 << (2*iq + 4)); + uchar u2_hi = (uchar)(2 << (2*iq + 4)); + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * sc = (global ushort *)x[ib].scales + iq; + global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir; + global uchar * qh = x[ib].qh + 8 * ir; + global half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f : 0.f)); + acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f : 0.f)); + acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f : 0.f)); + acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f: 0.f)); + acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f : 0.f)); + acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f : 0.f)); + acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f : 0.f)); + acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f: 0.f)); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; + qh += nb01; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl new file mode 100644 index 00000000000..e353a72be70 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl @@ -0,0 +1,203 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q5_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define BLOCK_Q5K_SIZE 176 +#define K_SCALE_SIZE 12 + +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qh[QK_K/8]; // quants, high bit (1 bit per value, packed 8 per byte) + uchar qs[QK_K/2]; // quants, low 4 bits (2 values per byte) +} block_q5_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 16 +#elif defined(ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q5_K_f32_flat( + global uchar * src0_q, + global uchar * src0_qh, + global uchar * src0_s, + global half * src0_d, + global half * src0_dm, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; + int it = get_sub_group_local_id()%8; + int iq = it/4; + int ir = it%4; + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = (first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03)/BLOCK_Q5K_SIZE; + uint blk = nb01 / BLOCK_Q5K_SIZE; + global uchar * blk_q = (global uchar *)src0_q + offset_src0*(QK_K/2); + global uchar * blk_qh = (global uchar *)src0_qh + offset_src0*(QK_K/8); + global uchar * blk_s = (global uchar *)src0_s + offset_src0*K_SCALE_SIZE; + global half * blk_d = (global half *)src0_d + offset_src0; + global half * blk_dm = (global half *)src0_dm + offset_src0; + + int offset_src1 = r1*nb11 + (i12)*nb12 + (i13)*nb13; + global float * y = (global float *)(src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + uchar u1_lo = (uchar)(1 << (2*iq)); + uchar u2_lo = (uchar)(2 << (2*iq)); + uchar u1_hi = (uchar)(1 << (2*iq + 4)); + uchar u2_hi = (uchar)(2 << (2*iq + 4)); + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * q1 = (global ushort *)(blk_q + ib * (QK_K/2)) + (16 * iq + 4 * ir); + global uchar * qh = (global uchar *)(blk_qh + ib * (QK_K/8)) + 8 * ir; + global ushort * sc = (global ushort *)(blk_s + ib * K_SCALE_SIZE) + iq; + global half * d = blk_d + ib; + global half * dm = blk_dm + ib; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * ((q1[i/2] & 0x000F) + (qh[i+0] & u1_lo ? 16.f : 0.f)); + acc1.s1 += yl[i+1] * ((q1[i/2] & 0x0F00) + (qh[i+1] & u1_lo ? 16.f*256.f : 0.f)); + acc1.s2 += yl[i+8] * ((q1[i/2] & 0x00F0) + (qh[i+0] & u2_lo ? 16.f*16.f : 0.f)); + acc1.s3 += yl[i+9] * ((q1[i/2] & 0xF000) + (qh[i+1] & u2_lo ? 16.f*4096.f: 0.f)); + acc2.s0 += yh[i+0] * ((q2[i/2] & 0x000F) + (qh[i+0] & u1_hi ? 16.f : 0.f)); + acc2.s1 += yh[i+1] * ((q2[i/2] & 0x0F00) + (qh[i+1] & u1_hi ? 16.f*256.f : 0.f)); + acc2.s2 += yh[i+8] * ((q2[i/2] & 0x00F0) + (qh[i+0] & u2_hi ? 16.f*16.f : 0.f)); + acc2.s3 += yh[i+9] * ((q2[i/2] & 0xF000) + (qh[i+1] & u2_hi ? 16.f*4096.f: 0.f)); + } + + float dall = *d; + float dmin = *dm; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += blk*64; + qh += blk*32; + sc += blk*6; + d += blk; + dm += blk; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl similarity index 99% rename from ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl rename to ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl index 8a17b9aae63..819e5192e35 100644 --- a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k.cl +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32.cl @@ -111,6 +111,10 @@ kernel void kernel_mul_mv_q6_K_f32( int row = N_SIMDGROUP * r0 + get_sub_group_id(); + if (row >= ne01) { + return; + } + int i12 = im%ne12; int i13 = im/ne12; diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl new file mode 100644 index 00000000000..57b90c05ae5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl @@ -0,0 +1,178 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// kernel_mul_mv_q6_K_f32_flat +//------------------------------------------------------------------------------ +#define Q6_K_MASK1 0x03 +#define Q6_K_MASK2 0x0C +#define Q6_K_MASK3 0x30 +#define Q6_K_MASK4 0xC0 + +#define QK_K 256 + +inline float block_q_6_K_dot_y_flat( + global uchar * blk_ql, + global uchar * blk_qh, + global char * blk_scales, + global half * blk_d, + int ib, + int ip, + int is, + int l0, + float4 y0, + float4 y1, + float4 y2, + float4 y3 +) { + int q_offset_l = 64*ip + l0; + int q_offset_h = 32*ip + l0; + + global uchar * q1 = blk_ql + ib*128 + q_offset_l; + global uchar * q2 = q1 + QK_K/8; + global uchar * qh = blk_qh + ib*64 + q_offset_h; + global char * sc = blk_scales + ib*16 + is; + + float dall = blk_d[ib]; + + // Vectorized loads: 3 uchar4 weight loads instead of 12 scalar byte reads. + // q_offset_l/h are 4-aligned, so these are aligned vector loads. + uchar4 q1v = vload4(0, q1); + uchar4 q2v = vload4(0, q2); + uchar4 qhv = vload4(0, qh); + + int4 q1i = convert_int4(q1v); + int4 q2i = convert_int4(q2v); + int4 qhi = convert_int4(qhv); + + // Reconstruct the four 6-bit weight groups (low/high nibble of ql OR'd with the + // matching 2-bit plane of qh), same arithmetic as the scalar version, then dot() + // against the cached activation lanes. + float4 w0 = convert_float4((q1i & 0xF) | ((qhi & Q6_K_MASK1) << 4)) - 32.f; + float4 w1 = convert_float4((q2i & 0xF) | ((qhi & Q6_K_MASK2) << 2)) - 32.f; + float4 w2 = convert_float4((q1i >> 4) | ((qhi & Q6_K_MASK3) )) - 32.f; + float4 w3 = convert_float4((q2i >> 4) | ((qhi & Q6_K_MASK4) >> 2)) - 32.f; + + return dall * (dot(y0, w0) * sc[0] + dot(y1, w1) * sc[2] + + dot(y2, w2) * sc[4] + dot(y3, w3) * sc[6]); +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 16 +#elif defined (ADRENO_GPU) +#define N_DST 16 +#define N_SIMDGROUP 2 +#define N_SIMDWIDTH 64 +#endif + +#define BLOCK_STRIDE (N_SIMDWIDTH/16) // number of blocks each subgroup processes + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q6_K_f32_flat( + global uchar * src0_ql, + global uchar * src0_qh, + global char * src0_s, + global half * src0_d, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int i12 = im%ne12; + int i13 = im/ne12; + + int first_row = (N_SIMDGROUP * r0 + get_sub_group_id()) * N_DST; + + ulong offset_src0 = first_row*nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + ulong offset_src0_ql = offset_src0 * 128; + ulong offset_src0_qh = offset_src0 * 64; + ulong offset_src0_s = offset_src0 * 16; + ulong offset_src0_d = offset_src0; + + global uchar * blk_ql = (global uchar *) src0_ql + offset_src0_ql; + global uchar * blk_qh = (global uchar *) src0_qh + offset_src0_qh; + global char * blk_scales = (global char *) src0_s + offset_src0_s; + global half * blk_d = (global half *) src0_d + offset_src0_d; + global float * yy = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + int tid = get_sub_group_local_id()%(N_SIMDWIDTH/BLOCK_STRIDE); // within-super-block part, 0..15 + int ix = get_sub_group_local_id()/(N_SIMDWIDTH/BLOCK_STRIDE); // super-block selector, 0..BLOCK_STRIDE-1 + int ip = tid/8; // first or second half of (super) block (0 or 1) + int il = tid%8; // each half has 8 parts, one per scale + int n = 4; // 4 scales at a time (and 4 sums) + int l0 = n*il; // offset into half-block, 0..28 + int is = 8*ip + l0/16; // 0, 1, 8, 9 + + float sumf[N_DST]; + for (int row = 0; row < N_DST; row++) { + sumf[row] = 0.f; + } + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + global float * y = yy + ib * QK_K + 128*ip + l0; + float4 y0 = vload4(0, y + 0); + float4 y1 = vload4(0, y + 32); + float4 y2 = vload4(0, y + 64); + float4 y3 = vload4(0, y + 96); + + for (int row = 0; row < N_DST; row++) { + if (first_row + row < ne01) { + sumf[row] += block_q_6_K_dot_y_flat( + blk_ql + row*nb*128, blk_qh + row*nb*64, blk_scales + row*nb*16, blk_d + row*nb, + ib, ip, is, l0, y0, y1, y2, y3); + } + } + } + + for (int row = 0; row < N_DST; row++) { + float tot = sub_group_reduce_add(sumf[row]); + if (get_sub_group_local_id() == 0 && first_row + row < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/neg.cl b/ggml/src/ggml-opencl/kernels/neg.cl new file mode 100644 index 00000000000..a862d8bc585 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/neg.cl @@ -0,0 +1,125 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_neg_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = -*x; + } +} + +kernel void kernel_neg_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = -*x; + } +} diff --git a/ggml/src/ggml-opencl/kernels/repeat.cl b/ggml/src/ggml-opencl/kernels/repeat.cl index 079498f5ab9..53951a55434 100644 --- a/ggml/src/ggml-opencl/kernels/repeat.cl +++ b/ggml/src/ggml-opencl/kernels/repeat.cl @@ -1,39 +1,38 @@ -kernel void kernel_repeat( - global const char * src0_data_in, - global char * dst_data_in, - ulong src0_offset, - ulong dst_offset, - int src0_ne0, int src0_ne1, int src0_ne2, int src0_ne3, - ulong src0_nb0, ulong src0_nb1, ulong src0_nb2, ulong src0_nb3, - int dst_ne0, int dst_ne1, int dst_ne2, int dst_ne3, - ulong dst_nb0, ulong dst_nb1, ulong dst_nb2, ulong dst_nb3 +kernel void kernel_repeat_f32( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - global const char * src0_data = src0_data_in + src0_offset; - global char * dst_data = dst_data_in + dst_offset; + src0 = src0 + offset0; + dst = dst + offsetd; - const int d3 = get_global_id(2); - const int d2 = get_global_id(1); - const int d1 = get_global_id(0); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - if (d3 >= dst_ne3 || d2 >= dst_ne2 || d1 >= dst_ne1) { - return; - } - - const int s3 = d3 % src0_ne3; - const int s2 = d2 % src0_ne2; - const int s1 = d1 % src0_ne1; - - const global char * p_src0_slice = src0_data + (ulong)s3*src0_nb3 + (ulong)s2*src0_nb2 + (ulong)s1*src0_nb1; - global char * p_dst_slice = dst_data + (ulong)d3*dst_nb3 + (ulong)d2*dst_nb2 + (ulong)d1*dst_nb1; + const int i03 = i3%ne03; + const int i02 = i2%ne02; + const int i01 = i1%ne01; - for (int d0 = 0; d0 < dst_ne0; ++d0) { - // Determine source index for dimension 0 based on tiling/broadcasting. - const int s0 = d0 % src0_ne0; + global const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + global char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1; - const global char * restrict current_src_el_ptr = p_src0_slice + (ulong)s0*src0_nb0; - global char * restrict current_dst_el_ptr = p_dst_slice + (ulong)d0*dst_nb0; - for (int k = 0; k < src0_nb0; ++k) { - current_dst_el_ptr[k] = current_src_el_ptr[k]; - } + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + const int i00 = i0%ne00; + *((global float *)(dst_ptr + i0*nb0)) = *((global float *)(src0_ptr + i00*nb00)); } } diff --git a/ggml/src/ggml-opencl/kernels/scale.cl b/ggml/src/ggml-opencl/kernels/scale.cl index aeca8a456e4..17ed97f0d66 100644 --- a/ggml/src/ggml-opencl/kernels/scale.cl +++ b/ggml/src/ggml-opencl/kernels/scale.cl @@ -1,9 +1,19 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -//------------------------------------------------------------------------------ -// scale -//------------------------------------------------------------------------------ -kernel void kernel_scale( +kernel void kernel_scale_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + float scale, + float bias +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias; +} + +kernel void kernel_scale_f32_4( global float4 * src0, ulong offset0, global float4 * dst, diff --git a/ggml/src/ggml-opencl/kernels/softplus.cl b/ggml/src/ggml-opencl/kernels/softplus.cl index 033766e2e07..6f8b7474165 100644 --- a/ggml/src/ggml-opencl/kernels/softplus.cl +++ b/ggml/src/ggml-opencl/kernels/softplus.cl @@ -3,86 +3,114 @@ //------------------------------------------------------------------------------ // softplus //------------------------------------------------------------------------------ -inline float softplus_f32(float x){ - float ax = fabs(x); - float m = fmax(x, 0.0f); - return log1p(exp(-ax)) + m; + +kernel void kernel_softplus_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)])); +} + +kernel void kernel_softplus_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)])); +} + +kernel void kernel_softplus_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + const float x = convert_float(src0[get_global_id(0)]); + dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x))); +} + +kernel void kernel_softplus_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + const float4 x = convert_float4(src0[get_global_id(0)]); + dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x))); } -kernel void kernel_softplus_f32_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, +kernel void kernel_softplus_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = softplus_f32(*src_val_ptr); - } + *y = (*x > 20.0f) ? *x : log(1.0f + exp(*x)); } } -kernel void kernel_softplus_f16_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, +kernel void kernel_softplus_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * hy = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr))); - } + const float x = convert_float(*hx); + *hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x))); } } diff --git a/ggml/src/ggml-opencl/kernels/solve_tri.cl b/ggml/src/ggml-opencl/kernels/solve_tri.cl new file mode 100644 index 00000000000..80745fc7045 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/solve_tri.cl @@ -0,0 +1,51 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// solve_tri +//------------------------------------------------------------------------------ +kernel void kernel_solve_tri_f32( + global uchar * src0, + ulong offset0, + global uchar * src1, + ulong offset1, + global uchar * dst, + ulong offsetd, + int n, + int k, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb10, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + int col = get_global_id(0); + int i2 = get_global_id(1); + int i3 = get_global_id(2); + + global const uchar * Lb = src0 + offset0 + i2 * nb02 + i3 * nb03; + global const uchar * Bb = src1 + offset1 + i2 * nb12 + i3 * nb13; + global uchar * Xb = dst + offsetd + i2 * nb2 + i3 * nb3; + + for(int row = 0; row < n; ++row){ + global const float *pB = (global const float *)(Bb + row * nb11 + col * nb10); + + float sum = 0.0f; + for(int j = 0; j < row; ++j){ + global const float *pL = (global const float *)(Lb + row * nb01 + j * nb00); + global const float *pX = (global const float *)(Xb + j * nb1 + col * nb0); + sum += (*pL) * (*pX); + } + + global const float * pDiag = (global const float *)(Lb + row * nb01 + row *nb00); + global float * pOut = (global float *)(Xb + row * nb1 + col *nb0); + + *pOut = ((* pB) - sum) / (*pDiag); + } +} diff --git a/ggml/src/ggml-opencl/kernels/sum_rows.cl b/ggml/src/ggml-opencl/kernels/sum_rows.cl index c5f7c570f95..84630aa8a30 100644 --- a/ggml/src/ggml-opencl/kernels/sum_rows.cl +++ b/ggml/src/ggml-opencl/kernels/sum_rows.cl @@ -1,8 +1,13 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +// Most devices have max workgroup size of 1024, so this is enough for subgroup +// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes +#define MAX_SUBGROUPS 64 kernel void kernel_sum_rows_f32( - global float * src0, + global char * src0, ulong offset0, - global float * dst, + global char * dst, ulong offsetd, int ne00, int ne01, @@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32( ulong nb2, ulong nb3 ) { - src0 = (global float *)((global char *)src0 + offset0); - dst = (global float *)((global char *)dst + offsetd); + src0 = src0 + offset0; + dst = dst + offsetd; - int i3 = get_global_id(2); - int i2 = get_global_id(1); - int i1 = get_global_id(0); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { return; } - global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; + } - float row_sum = 0; + global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); - for (int i0 = 0; i0 < ne00; i0++) { - row_sum += src_row[i0]; + float sumf = 0.0f; + + for (int i0 = lid; i0 < ne00; i0 += lsize) { + sumf += src_row[i0]; } - dst_row[0] = row_sum; + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf; + } +} + +kernel void kernel_sum_rows_f32_4( + global char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; + } + + global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); + + float4 sum_vec = (float4)0.0f; + + for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) { + sum_vec += src_row[i0]; + } + + float sumf = dot(sum_vec, (float4)(1.0f)); + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf; + } } diff --git a/ggml/src/ggml-opencl/kernels/tanh.cl b/ggml/src/ggml-opencl/kernels/tanh.cl index d9da86b1489..2c4887ad3e0 100644 --- a/ggml/src/ggml-opencl/kernels/tanh.cl +++ b/ggml/src/ggml-opencl/kernels/tanh.cl @@ -1,63 +1,109 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -#ifdef cl_intel_required_subgroup_size -#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable -#define INTEL_GPU 1 -#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) -#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) -#elif defined(cl_qcom_reqd_sub_group_size) -#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable -#define ADRENO_GPU 1 -#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) -#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) -#endif - -kernel void kernel_tanh_f32_nd( - global void * p_src0_base, ulong off_src0_abs, - global void * p_dst_base, ulong off_dst_abs, - int ne00, int ne01, int ne02, int ne03, - ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, int ne11, int ne12, int ne13, - ulong nb10, ulong nb11, ulong nb12, ulong nb13 +kernel void kernel_tanh_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = tanh(src0[get_global_id(0)]); +} + +kernel void kernel_tanh_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = tanh(*src_val_ptr); - } + *y = tanh(*x); } } -kernel void kernel_tanh_f16_nd( - global void * p_src0_base, ulong off_src0_abs, - global void * p_dst_base, ulong off_dst_abs, - int ne00, int ne01, int ne02, int ne03, - ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, int ne11, int ne12, int ne13, - ulong nb10, ulong nb11, ulong nb12, ulong nb13 +kernel void kernel_tanh_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = tanh(*src_val_ptr); - } + *y = tanh(*x); } } diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl index 1279b6531b9..ad89bdcbdec 100644 --- a/ggml/src/ggml-opencl/kernels/transpose.cl +++ b/ggml/src/ggml-opencl/kernels/transpose.cl @@ -44,6 +44,19 @@ kernel void kernel_transpose_16_4x1( write_imageh(output, i * rows + j, (half4)(temp0, temp1, temp2, temp3)); } +// Transpose treating each element as 8-bit using buffer +kernel void kernel_transpose_8_buf( + global const uchar * input, + global uchar * output, + const int ldi, + const int ldo +) { + const int x = get_global_id(0); + const int y = get_global_id(1); + + output[x*ldo + y] = input[y*ldi + x]; +} + // Transpose treating each element as 16-bit using buffer kernel void kernel_transpose_16_buf( global const ushort * input, @@ -57,6 +70,19 @@ kernel void kernel_transpose_16_buf( output[x*ldo + y] = input[y*ldi + x]; } +// Transpose treating each element as 32-bit using buffer +kernel void kernel_transpose_32_buf( + global const uint * input, + global uint * output, + const int ldi, + const int ldo +) { + const int x = get_global_id(0); + const int y = get_global_id(1); + + output[x*ldo + y] = input[y*ldi + x]; +} + // 32-bit transpose, loading/storing a 4x4 tile of elements kernel void kernel_transpose_32( __read_only image1d_buffer_t input, diff --git a/ggml/src/ggml-opencl/kernels/tri.cl b/ggml/src/ggml-opencl/kernels/tri.cl new file mode 100644 index 00000000000..35cdd543bc5 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/tri.cl @@ -0,0 +1,32 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +//------------------------------------------------------------------------------ +// tri +//------------------------------------------------------------------------------ +__kernel void kernel_tri_f32( + global float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int n, + int ne0, + int ne1, + int tri_type +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + int idx = get_global_id(0); + if (idx >= n) return; + + int i0 = idx % ne0; + int i1 = (idx / ne0) % ne1; + + int keep = 0; + if (tri_type == 0) keep = (i0 >= i1); + else if (tri_type == 1) keep = (i0 > i1); + else if (tri_type == 2) keep = (i0 <= i1); + else keep = (i0 < i1); + + dst[idx] = keep ? src0[idx] : 0.0f; +} diff --git a/ggml/src/ggml-openvino/.clang-format b/ggml/src/ggml-openvino/.clang-format new file mode 100644 index 00000000000..a2a24d7d33a --- /dev/null +++ b/ggml/src/ggml-openvino/.clang-format @@ -0,0 +1,154 @@ +--- +# Override root .clang-format +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +Cpp11BracedListStyle: true +SpacesInContainerLiterals: false +BreakBeforeBraces: Attach +AccessModifierOffset: -4 +IndentCaseBlocks: false +IndentCaseLabels: false + +Language: Cpp +AlignAfterOpenBracket: Align +AlignArrayOfStructures: Left +AlignConsecutiveBitFields: AcrossComments +AlignConsecutiveMacros: AcrossComments +# AlignConsecutiveShortCaseStatements: AcrossComments +AlignEscapedNewlines: Left # LeftWithLastLine +AlignOperands: Align +AlignTrailingComments: + Kind: Always + OverEmptyLines: 1 +AllowAllArgumentsOnNextLine: true +AllowAllParametersOfDeclarationOnNextLine: false +# AllowBreakBeforeNoexceptSpecifier: OnlyWithParen +AllowShortBlocksOnASingleLine: Never +AllowShortCaseLabelsOnASingleLine: false +AllowShortFunctionsOnASingleLine: Inline +AllowShortIfStatementsOnASingleLine: Never +AllowShortLambdasOnASingleLine: Inline +AllowShortLoopsOnASingleLine: false +AlwaysBreakBeforeMultilineStrings: true +# Treat CUDA keywords/attributes as "attribute macros" and avoid breaking lines inside them +AttributeMacros: + - __host__ + - __device__ + - __global__ + - __forceinline__ + - __launch_bounds__ +BinPackArguments: true +BinPackParameters: false # OnePerLine +BitFieldColonSpacing: Both +# BreakAdjacentStringLiterals: true +BreakAfterAttributes: Never +BreakBeforeBinaryOperators: None +BreakBeforeInlineASMColon: OnlyMultiline +BreakBeforeTernaryOperators: false +# BreakBinaryOperations: Never +BreakConstructorInitializers: AfterColon +# BreakFunctionDefinitionParameters: false +BreakInheritanceList: AfterComma +BreakStringLiterals: true +# BreakTemplateDeclarations: Yes +ColumnLimit: 120 +CommentPragmas: '^ IWYU pragma:' +CompactNamespaces: false +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +DerivePointerAlignment: false +DisableFormat: false +EmptyLineBeforeAccessModifier: Leave +EmptyLineAfterAccessModifier: Never +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +IncludeBlocks: Regroup +IncludeCategories: + - Regex: '".*"' + Priority: 1 + SortPriority: 0 + - Regex: '^<.*\.h>' + Priority: 2 + SortPriority: 0 + - Regex: '^<.*' + Priority: 3 + SortPriority: 0 + - Regex: '.*' + Priority: 4 + SortPriority: 0 +IncludeIsMainRegex: '([-_](test|unittest))?$' +IncludeIsMainSourceRegex: '' +IndentAccessModifiers: false +IndentExternBlock: NoIndent +IndentGotoLabels: false +IndentPPDirectives: AfterHash +IndentWidth: 4 +IndentWrappedFunctionNames: false +InsertBraces: true # NOTE: may lead to incorrect formatting +InsertNewlineAtEOF: true +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +LambdaBodyIndentation: Signature +LineEnding: LF +MacroBlockBegin: '' +MacroBlockEnd: '' +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 4 +ObjCSpaceAfterProperty: true +ObjCSpaceBeforeProtocolList: true +PPIndentWidth: -1 +PackConstructorInitializers: CurrentLine +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 1 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 200 +PointerAlignment: Middle +QualifierAlignment: Left +#QualifierOrder: ['static', 'inline', 'friend', 'constexpr', 'const', 'volatile', 'type', 'restrict'] +RawStringFormats: + - Language: Cpp + Delimiters: + - cc + - CC + - cpp + - Cpp + - CPP + - 'c++' + - 'C++' + CanonicalDelimiter: '' +ReferenceAlignment: Middle +ReflowComments: false # IndentOnly +SeparateDefinitionBlocks: Always +SortIncludes: CaseInsensitive +SortUsingDeclarations: LexicographicNumeric +SpaceAfterCStyleCast: true +SpaceAfterLogicalNot: false +SpaceAfterTemplateKeyword: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyBlock: false +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 2 +SpacesInAngles: Never +SpacesInLineCommentPrefix: + Minimum: 1 + Maximum: -1 +SpacesInParentheses: false +SpacesInSquareBrackets: false +SpaceBeforeSquareBrackets: false +Standard: c++17 +TabWidth: 4 +UseTab: Never +WhitespaceSensitiveMacros: ['STRINGIZE'] +... diff --git a/ggml/src/ggml-openvino/CMakeLists.txt b/ggml/src/ggml-openvino/CMakeLists.txt new file mode 100644 index 00000000000..175b585661d --- /dev/null +++ b/ggml/src/ggml-openvino/CMakeLists.txt @@ -0,0 +1,22 @@ +find_package(OpenVINO REQUIRED) +find_package(OpenCL REQUIRED) + +include("${OpenVINO_DIR}/../3rdparty/tbb/lib/cmake/TBB/TBBConfig.cmake") + +file(GLOB_RECURSE GGML_HEADERS_OPENVINO "*.h" "*.hpp") +file(GLOB_RECURSE GGML_SOURCES_OPENVINO "*.cpp") + +ggml_add_backend_library(ggml-openvino + ${GGML_SOURCES_OPENVINO} + ${GGML_HEADERS_OPENVINO} +) + +target_link_libraries(ggml-openvino PRIVATE openvino::runtime TBB::tbb OpenCL::OpenCL) + +if (GGML_OPENVINO) + if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64") + elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "AMD64") + else() + message(FATAL_ERROR "OpenVINO: OpenVINO toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}") + endif() +endif() diff --git a/ggml/src/ggml-openvino/ggml-decoder.cpp b/ggml/src/ggml-openvino/ggml-decoder.cpp new file mode 100644 index 00000000000..5095e799849 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-decoder.cpp @@ -0,0 +1,985 @@ +#include "ggml-decoder.h" + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-openvino-extra.h" +#include "ggml-openvino.h" +#include "ggml-quants.h" + +#include <ggml-impl.h> +#include <ggml.h> + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <execution> +#include <fstream> +#include <iomanip> +#include <map> +#include <memory> +#include <openvino/core/dimension.hpp> +#include <openvino/core/except.hpp> +#include <openvino/core/node.hpp> +#include <openvino/core/partial_shape.hpp> +#include <openvino/core/type/bfloat16.hpp> +#include <openvino/core/type/element_type.hpp> +#include <openvino/core/type/float16.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/parameter.hpp> +#include <openvino/runtime/tensor.hpp> +#include <optional> +#include <ostream> +#include <set> +#include <stdexcept> +#include <string> +#include <unordered_map> +#include <vector> + +GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, + ModelParams & model_params, + ComputeParams & compute_params, + std::map<std::string, std::shared_ptr<ov::Node>> & model_weights, + bool is_static, + bool is_stateful, + bool is_prefill, + int prefill_chunk_size) : + m_is_static(is_static), + m_is_stateful(is_stateful), + m_is_prefill(is_prefill), + m_naive(false), + m_prefill_chunk_size(prefill_chunk_size), + m_cgraph(cgraph), + m_model_weights(model_weights), + m_model_params(model_params), + m_compute_params(compute_params) { + if (auto * env = getenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); env && std::string(env) != "0") { +#ifdef _WIN32 + _putenv_s("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS", ""); +#else + unsetenv("GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS"); +#endif + print_tensor_address_map(cgraph); + } + + validate_cgraph(); + + set_input_output(); + compute_model_inputs(); + compute_model_outputs(); + + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node); + m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node); + } + + add_extra_inputs(); +} + +void GgmlOvDecoder::update_io(ggml_cgraph * cgraph) { + m_cgraph = cgraph; + m_model_inputs.clear(); + m_model_outputs.clear(); + m_node_info_list.clear(); + set_input_output(); + compute_model_inputs(); + compute_model_outputs(); +} + +GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights) { + m_cgraph = cgraph; + m_model_weights = model_weights; + m_naive = true; + set_input_output(); + compute_model_inputs(); + compute_model_outputs(); + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + m_node_info_list[node_n].node_op_case = compute_op_case(m_node_info_list[node_n].node); + m_node_info_list[node_n].node_op_type = compute_op_type(m_node_info_list[node_n].node); + } +} + +void GgmlOvDecoder::set_input_output() { + for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) { + auto node = m_cgraph->nodes[node_n]; + + NodeInfo current_node_info; + auto node_name = std::string(node->name); + auto node_output_name = node_name; + auto * node_output = node; + if (node->op == GGML_OP_SET_ROWS) { + // SET_ROWS updates the tensor in place. For later ov op that uses the + // the view_src of SET_ROWS, we need to make sure they get the updated tensor + // by putting the view_src name in the tensor_map in + // <openvino>/src/frontends/ggml/src/translate_session.cpp + node_output_name = std::string(node->view_src->name); + node_output = node->view_src; + } + + current_node_info.node = node; + current_node_info.node_name = node_name; + current_node_info.node_output = node_output; + current_node_info.node_output_name = node_output_name; + current_node_info.node_op_case = 0; + current_node_info.data_addr = node->data; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto * src = node->src[i]; + if (src == nullptr) { + continue; + } + auto src_name = std::string(src->name); + if (src->flags & GGML_TENSOR_FLAG_INPUT) { + src_name = get_graph_input_ov_name(src, node); + } + current_node_info.node_inputs[src_name] = src; + current_node_info.node_inputs_names.push_back(src_name); + } + + m_node_info_list.push_back(current_node_info); + } +} + +int GgmlOvDecoder::compute_op_case(const ggml_tensor * node) const { + int op_case = 0; + switch (node->op) { + case GGML_OP_RESHAPE: { + auto * src = node->src[0]; + if (src->op == GGML_OP_RESHAPE && src->src[0]->ne[0] == node->ne[0] && src->src[0]->ne[1] == node->ne[1]) { + op_case = 4; + } else if (node->ne[0] * node->ne[1] == src->ne[0]) { + op_case = 1; + } else if (src->ne[0] * src->ne[1] == node->ne[0]) { + op_case = 2; + if (src->ne[2] * src->ne[3] == node->ne[1]) { + op_case = 5; + } + } else if (src->ne[0] * src->ne[1] == node->ne[1]) { + op_case = 3; + } else if (src->ne[1] * src->ne[2] == node->ne[1]) { + op_case = 6; + } + break; + } + case GGML_OP_CONT: { + if (node->src[0]->op == GGML_OP_PERMUTE) { + op_case = 1; + } else if (node->src[0]->op == GGML_OP_TRANSPOSE) { + op_case = 2; + } else if (node->src[0]->op == GGML_OP_VIEW) { + op_case = 3; + } + break; + } + case GGML_OP_PERMUTE: { + if (node->src[0]->op != GGML_OP_VIEW) { + op_case = 1; + } else if (node->src[0]->src[0]->op == GGML_OP_NONE) { + // kv cache tensor + std::string src_name(node->view_src->name); + int layer = extract_layer_from_name(src_name); + if (!is_swa_layer(layer)) { + op_case = 2; + } else { + op_case = 3; + } + } else { + // rope'ed query tensor + op_case = 4; + } + break; + } + case GGML_OP_MUL_MAT: { + if (node->src[0]->op == GGML_OP_CONT && node->src[0]->src[0]->op == GGML_OP_TRANSPOSE) { + op_case = 2; + } else if (node->src[0]->op == GGML_OP_VIEW && node->src[1]->op == GGML_OP_VIEW) { + op_case = 3; + } + break; + } + case GGML_OP_GET_ROWS: { + if (node->src[1]->op == GGML_OP_VIEW) { + op_case = 2; + } + break; + } + case GGML_OP_ROPE: { + const int mode = node->op_params[2]; + switch (mode) { + case GGML_ROPE_TYPE_NEOX: { + op_case = 0x00010000; + break; + } + case GGML_ROPE_TYPE_IMROPE: { + op_case = 0x00020000; + break; + } + default: + op_case = 0x00000000; + break; + } + if (node->src[0]->op == GGML_OP_VIEW) { + op_case = (op_case | 0x00000002); + } + break; + } + case GGML_OP_VIEW: { + if (node->src[0]->op == GGML_OP_VIEW) { + auto * src = node->src[0]; + if (ggml_nelements(node) != ggml_nelements(src)) { + throw std::runtime_error("Unsupported VIEW case"); + } + op_case = 2; + } + { + auto * src = node->src[0]; + if ((ggml_nelements(node) != ggml_nelements(src)) && m_naive) { + // Compare each dimension of node and src, if only one dimension differs then op_case=3 + int diff_count = 0; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + if (node->ne[i] != src->ne[i]) { + diff_count++; + } + } + if (diff_count == 1) { + op_case = 3; + } + } + } + break; + } + default: + break; + } + return op_case; +} + +int extract_layer_from_name(const std::string & name) { + size_t pos1 = name.find("_l"); + assert(pos1 != std::string::npos); + pos1 += 2; + size_t pos2 = name.find(' ', pos1); + if (pos2 == std::string::npos) { + pos2 = name.length(); + } + std::string layer_str = name.substr(pos1, pos2 - pos1); + int layer = std::stoi(layer_str); + return layer; +} + +std::pair<ModelParams, ComputeParams> GgmlOvDecoder::compute_llm_params(ggml_cgraph * cgraph, bool is_static) { + ModelParams model_params; + ComputeParams compute_params; + for (int i = 0; i < cgraph->n_nodes; i++) { + auto * node = cgraph->nodes[i]; + std::string name = std::string(node->name); + if (node->op == GGML_OP_FLASH_ATTN_EXT) { + model_params.n_heads = node->src[0]->ne[2]; + model_params.n_heads_kv = node->src[1]->ne[2]; + model_params.head_size = node->src[0]->ne[0]; + compute_params.input_len = node->src[0]->ne[1]; + + auto * cache_k_perm = node->src[1]; + if (cache_k_perm->op == GGML_OP_CPY) { + cache_k_perm = cache_k_perm->src[0]; + } + assert(cache_k_perm->op == GGML_OP_PERMUTE); + auto * cache_k_view = cache_k_perm->src[0]; + assert(cache_k_view->op == GGML_OP_VIEW); + + auto * cache_k = cache_k_view->src[0]; + int layer = extract_layer_from_name(cache_k->name); + auto * mask = node->src[3]; + std::string mask_name(mask->name); + + model_params.kv_buffer_ctx_id = ggml_backend_openvino_buffer_get_ctx_id(cache_k->buffer); + if (mask_name.find("swa") != std::string::npos) { + model_params.swa_layers.push_back(layer); + model_params.ctx_per_seq_swa = cache_k->ne[1]; + } else { + model_params.ctx_per_seq = cache_k->ne[1]; + model_params.n_seq = cache_k->ne[2]; + } + + compute_params.n_seq_active = mask->ne[3]; + auto seq_size = cache_k->ne[0] * cache_k->ne[1] * ggml_type_size(cache_k->type); + size_t offset; + memcpy(&offset, cache_k_view->op_params, sizeof(size_t)); + compute_params.seq_active_start = offset / seq_size; + compute_params.token_len_per_seq = node->ne[2]; + + if (mask_name.find("swa") != std::string::npos) { + compute_params.attention_size_swa = mask->ne[0]; + } else { + compute_params.attention_size = mask->ne[0]; + } + if (is_static) { + compute_params.attention_size = model_params.ctx_per_seq; + compute_params.attention_size_swa = model_params.ctx_per_seq_swa; + compute_params.token_len_per_seq = 1; + } + break; + } + if (node->op == GGML_OP_ROPE) { + memcpy(model_params.rope_params, node->op_params, sizeof(int32_t) * 15); + } + } + auto * output_tensor = cgraph->nodes[cgraph->n_nodes - 1]; + compute_params.output_len = output_tensor->ne[1]; + // for NPU, output_len is always 1 except for llama-perplexity + if (is_static && compute_params.output_len == 0) { + compute_params.output_len = 1; + } + model_params.ctx = model_params.ctx_per_seq * model_params.n_seq; + model_params.ctx_swa = model_params.ctx_per_seq_swa * model_params.n_seq; + return {model_params, compute_params}; +} + +void GgmlOvDecoder::validate_cgraph() const { + if (m_model_params.n_seq > 1 && m_is_static == true) { + throw std::runtime_error("n_seq > 1 is not supported on NPU. Try setting -np 1."); + } +} + +ov::PartialShape GgmlOvDecoder::get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const { + if (m_naive) { + return input!= nullptr ? ov::PartialShape{get_shape(input)} : ov::PartialShape{get_shape(op)}; + } + auto name = std::string(input->name); + ov::PartialShape input_shape; + + if (is_inp_tok(input, op) || is_inp_pos(input, op)) { + // tokens or positions + int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1; + input_shape = ov::PartialShape{1, 1, 1, len}; + + } else if (is_output_idx(input, op)) { + // output index + input_shape = ov::PartialShape{1, 1, 1, m_is_static ? m_compute_params.output_len : -1}; + + } else if (is_inp_mask(input, op)) { + // mask + if (m_is_static) { + input_shape = ov::PartialShape{1, 1, m_is_prefill ? m_prefill_chunk_size : 1, m_model_params.ctx}; + } else if (m_is_stateful) { + input_shape = ov::PartialShape{1, 1, -1, -1}; + } else { + input_shape = ov::PartialShape{-1, 1, -1, -1}; + } + + } else if (is_kvcache(input, op)) { + // kvcache + input_shape = ov::PartialShape{get_shape(input)}; + if (!m_is_static) { + // do not fix ctx size to make llama-bench work across test params + input_shape[2] = -1; + } + if (is_stateful()) { + // Convert stateless KV cache layout [1, 1, seq, n_heads_kv * head_size] + // to stateful layout [1, seq, n_heads_kv, head_size]. + assert(input_shape.size() == 4 && input_shape[0] == 1 && input_shape[1] == 1 && + input_shape[2].is_dynamic() && + input_shape[3] == (m_model_params.n_heads_kv * m_model_params.head_size)); + input_shape = {input_shape[0], ov::Dimension::dynamic(), m_model_params.n_heads_kv, + m_model_params.head_size}; + } + + } else if (is_kv_idx(input, op)) { + // kv update index + int len = m_is_static ? (m_is_prefill ? m_prefill_chunk_size : 1) : -1; + input_shape = ov::PartialShape{1, 1, 1, len}; + + } else { + input_shape = ov::PartialShape{get_shape(input)}; + } + return input_shape; +} + +void GgmlOvDecoder::add_extra_inputs() { + // Extra inputs: + // 1. `attention_size`, used in FLASH_ATTN where the shape of the matmul's are 256 aligned, + // see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding. + // 2. `n_seq_active` and `seq_active_start`, used in FLASH_ATTN_EXT to indicate the active sequences in the batch + + auto create_1d_input = [this](const std::string & name, int64_t value) { + if (m_is_static) { + auto constant = + std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{value}); + constant->set_friendly_name(name); + m_model_extra_inputs[name] = constant; + } else { + auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64, ov::Shape{1}); + param_node->set_friendly_name(name); + param_node->output(0).get_tensor().set_names({name}); + m_model_extra_inputs[name] = param_node; + + auto tensor = std::make_shared<ov::Tensor>(ov::element::i64, ov::Shape{1}); + *tensor->data<int64_t>() = value; + m_model_extra_input_values[name] = tensor; + } + }; + + create_1d_input("attention_size", m_compute_params.attention_size); + if (m_compute_params.attention_size_swa != -1) { + create_1d_input("attention_size_swa", m_compute_params.attention_size_swa); + } + create_1d_input("n_seq_active", m_compute_params.n_seq_active); + create_1d_input("seq_active_start", m_compute_params.seq_active_start); + create_1d_input("seq_active_end", m_compute_params.seq_active_start + m_compute_params.n_seq_active); + create_1d_input("token_len_per_seq", m_compute_params.token_len_per_seq); + // create_1d_input("token_len", m_token_len_per_seq * m_n_seq_active); +} + +bool GgmlOvDecoder::node_is_used_as_src(const int node_idx) { + ggml_tensor * node = m_cgraph->nodes[node_idx]; + for (int i = node_idx; i < m_cgraph->n_nodes; i++) { + ggml_tensor * other_node = m_cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (other_node->src[j] == node) { + return true; + } + } + } + return false; +} + +void GgmlOvDecoder::compute_model_inputs() { + m_model_inputs.clear(); + m_inputs.clear(); + for (int i = 0; i < m_cgraph->n_nodes; i++) { + ggml_tensor * node = m_cgraph->nodes[i]; + // the node op is NONE means this node maybe as input of later nodes, we should add it to model inputs for this node. + if (node->op == GGML_OP_NONE && node_is_used_as_src(i)) { + std::string node_name(node->name); + if (m_model_weights.find(node_name) == m_model_weights.end()) { + m_inputs[node_name] = node; + auto param_node = + std::make_shared<ov::op::v0::Parameter>(get_ov_type(node), get_graph_input_shape(node, nullptr)); + param_node->set_friendly_name(node_name); + param_node->output(0).get_tensor().set_names({node_name}); + m_model_inputs[node_name] = param_node; + } + continue; + } + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto * src = node->src[i]; + if (src == nullptr) { + continue; + } + std::string src_name = std::string(src->name); + if (src->flags & GGML_TENSOR_FLAG_INPUT) { + src_name = get_graph_input_ov_name(src, node); + } + if (m_model_weights.find(src_name) != m_model_weights.end()) { + continue; + } + + bool is_intermediate_node = false; + for (const auto & node_info : m_node_info_list) { + if (node_info.node == src) { + is_intermediate_node = true; + break; + } + } + if (is_intermediate_node) { + continue; + } + if (m_model_inputs.find(src_name) != m_model_inputs.end()) { + continue; + } + + m_inputs[src_name] = src; + + ggml_backend_buffer * buffer = src->buffer; + // GGML_BACKEND_BUFFER_USAGE_ANY are kv caches + if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_ANY) { + if (auto it = std::find(m_model_params.kv_names.begin(), m_model_params.kv_names.end(), src_name); + it == m_model_params.kv_names.end()) { + m_model_params.kv_names.push_back(src_name); + } + } + ov::PartialShape param_shape = get_graph_input_shape(node, src); + auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type(src), param_shape); + param_node->set_friendly_name(src_name); + param_node->output(0).get_tensor().set_names({src_name}); + m_model_inputs[src_name] = param_node; + } + } +} + +void GgmlOvDecoder::compute_model_outputs() { + m_model_outputs.clear(); + m_model_output_names.clear(); + for (int node_n = 0; node_n < m_cgraph->n_nodes; node_n++) { + auto * cur_node = m_cgraph->nodes[node_n]; + // if the node op is NONE means this node is not used at all, we can skip it directly without adding to model outputs. + if (cur_node->op == GGML_OP_NONE) { + continue; + } + auto cur_node_use_count = m_cgraph->use_counts[ggml_hash_find(&m_cgraph->visited_hash_set, cur_node)]; + if (cur_node_use_count == 0) { + // The output of SET_ROWS is the view_src tensor, which is updated in place. We should use the view_src name as the output name to make sure it can be correctly matched with the later ops that use the view_src. + if (cur_node != nullptr && cur_node->op == GGML_OP_SET_ROWS) { + cur_node = cur_node->view_src; + } + } else { + int input_use_count = 0; + for (int i = 0; i < m_cgraph->n_nodes; i++) { + ggml_tensor * node = m_cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] != NULL && node->src[j] == cur_node) { + input_use_count++; + } + } + } + if (input_use_count == cur_node_use_count) { + cur_node = nullptr; + } + } + if (cur_node != nullptr) { + std::string node_output_name(cur_node->name); + m_model_outputs[node_output_name] = cur_node; + m_model_output_names.push_back(node_output_name); + } + } +} + +const ggml_tensor * GgmlOvDecoder::get_tensor_used_op(const ggml_tensor * tensor) const { + if (tensor == nullptr) { + return nullptr; + } + for (int i = 0; i < m_cgraph->n_nodes; i++) { + const auto * node = m_cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + if (node->src[j] == tensor) { + return node; + } + } + } + return nullptr; +} + +const ggml_tensor * GgmlOvDecoder::get_tensor_from_name(const std::string & name) const { + for (int i = 0; i < m_cgraph->n_nodes; i++) { + const auto * node = m_cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; j++) { + const auto * src = node->src[j]; + if (src == nullptr) { + break; + } + if (std::string(src->name) == name) { + return src; + } + } + } + return nullptr; +} + +std::map<std::string, std::string> GgmlOvDecoder::get_kv_param_res_names() const { + std::map<std::string, std::string> kv_param_res_names; + for (const auto & name : m_model_params.kv_names) { + kv_param_res_names[name] = name; + } + return kv_param_res_names; +} + +std::map<std::string, std::shared_ptr<ov::Node>> GgmlOvDecoder::create_weight_nodes(ggml_cgraph * cgraph, bool naive) { + std::map<std::string, std::shared_ptr<ov::Node>> model_weights; + auto * nodes = cgraph->nodes; + auto n_nodes = cgraph->n_nodes; + for (int node_i = 0; node_i < n_nodes; node_i++) { + auto * node = nodes[node_i]; + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto * src = node->src[i]; + if (src == nullptr) { + continue; + } + + std::string src_name(src->name); + if (is_rope_freqs_weight(src, node)) { + src_name = "rope_freqs.weight"; + } + if (!src->view_src) { + ggml_backend_buffer * buffer = src->buffer; + if (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS || ggml_is_quantized(src->type)) { + if (model_weights.find(src_name) == model_weights.end()) { + auto weight_node = create_weight_node(src, naive); + weight_node->set_friendly_name(src_name); + model_weights[src_name] = weight_node; + } + } + } + } + } + return model_weights; +} + +std::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor * tensor, bool naive) { + const bool is_ov_buffer = ggml_backend_buffer_is_openvino(tensor->buffer); + + // Check if we have a pre-built constant from the OpenVINO backend buffer + // This is set during ggml_backend_openvino_buffer_set_tensor + if (tensor->extra) { + OPENVINO_ASSERT(is_ov_buffer, "Unsupported weight tensor: " + std::string(tensor->name) + + " Possibly this is a cpu backend repacked quantized weights"); + // Cast to our extra base type and check the type + auto * extra_base = static_cast<ggml_openvino_extra_base *>(tensor->extra); + + if (extra_base->type == ggml_openvino_extra_base::Type::WEIGHT) { + // F16/F32/BF16 weight with shared-memory constant + auto * weight_extra = static_cast<ggml_openvino_weight_extra *>(tensor->extra); + if (weight_extra->weight_node) { + // GGML_LOG_DEBUG("%s: using pre-built weight node for %s\n", __func__, tensor->name); + return weight_extra->weight_node; + } + } else if (extra_base->type == ggml_openvino_extra_base::Type::QUANTIZED_WEIGHT) { + // Quantized weight with pre-extracted data + auto * quant_extra = static_cast<ggml_openvino_quantized_weight_extra *>(tensor->extra); + if (quant_extra->weight_node) { + // GGML_LOG_DEBUG("%s: using pre-extracted quantized weight node for %s\n", __func__, tensor->name); + return quant_extra->weight_node; + } + } + } + + // There are three cases where we need to create a new weight node: + // 1. weights are in openvino_host_buffer. Weight loading to host buffer will not trigger backend_buffer_set_tensor + // 2. weights are in cpu/cpu_mapped buffer. On token_embd.weight goes to case 1 or 2, depending on whether mmap or direct_io is used + // 3. test-backend-ops. buffers in test-backend-ops does not set USAGE_WEIGHT so backend_buffer_set_tensor will not create weight node + + // GGML_LOG_DEBUG("%s: creating new weight node for %s\n", __func__, tensor->name); + static const std::set<ggml_type> weight_types = {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, + GGML_TYPE_Q8_0, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, + GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K}; + if (weight_types.find(tensor->type) == weight_types.end()) { + throw std::runtime_error("Unexpected weight tensor type: " + std::string(tensor->name) + " with type " + + ggml_type_name(tensor->type)); + } + + OvWeight ov_weight; + if (ggml_is_quantized(tensor->type)) { + auto use_bias = naive; + if (is_ov_buffer) { + // For quantized weights, copy raw data to a temp buffer first because + // process_weight_tensor reads from data and writes extracted results + // (weights/scales/zp) to output_base_ptr — they would overlap if both + // point to tensor->data. + size_t raw_size = ggml_nbytes(tensor); + std::vector<uint8_t> tmp(raw_size); + memcpy(tmp.data(), tensor->data, raw_size); + ov_weight = process_weight_tensor(tensor, tmp.data(), tensor->data, use_bias); + } else { + ov_weight = process_weight_tensor(tensor, tensor->data, nullptr, use_bias); + } + } else { + // For non-quantized weights (F16/F32/BF16), data is already in tensor->data. + // process_weight_tensor will create an ov::Tensor wrapping tensor->data directly. + ov_weight = process_weight_tensor(tensor, tensor->data, tensor->data); + } + + ov_weight.weight_node->set_friendly_name(tensor->name); + if (!is_ov_buffer) { + return ov_weight.weight_node; + } + + ggml_openvino_extra_base * extra; + if (ov_weight.is_quantized()) { + extra = new ggml_openvino_quantized_weight_extra(std::move(ov_weight.weights), std::move(ov_weight.scales), + std::move(ov_weight.zp), ov_weight.weight_node); + } else { + extra = new ggml_openvino_weight_extra(std::move(ov_weight.weights), ov_weight.weight_node); + } + ggml_openvino_buffer_register_extra(tensor, extra); + + return ov_weight.weight_node; +} + +void GgmlOvDecoder::dump_cgraph(const ggml_cgraph * cgraph, std::string & filename) { + std::ofstream file(filename); + if (!file.is_open()) { + std::cerr << "Failed to open file" << std::endl; + return; + } + + file << "=== GRAPH ===\n"; + + // clang-format off + file << "n_nodes = " << cgraph->n_nodes << "\n"; + file << " " << std::setw(3) << "nodes" + << std::setw(15) << "shape" + << std::setw(20) << "op" + << std::setw(20) << "name" + << std::setw(3) << " " + << std::setw(62) << "stride" + << std::setw(20) << "buffer_type" + << "\n"; + for (int i = 0; i < cgraph->n_nodes; i++) { + ggml_tensor * node = cgraph->nodes[i]; + + // Get buffer type name + const char * buf_name = "none"; + ggml_backend_buffer_t buf = node->view_src ? node->view_src->buffer : node->buffer; + if (buf) { + buf_name = ggml_backend_buffer_name(buf); + } + + file << " - " << std::setw(3) << i << ": [ " + << std::setw(5) << node->ne[0] << ", " + << std::setw(5) << node->ne[1] << ", " + << std::setw(5) << node->ne[2] << ", " + << std::setw(5) << node->ne[3] << "] " + << std::left << std::setw(20) << ggml_op_name(node->op) << std::right << " " + << std::left << std::setw(45) << node->name << std::right + << std::setw(2) << "[ " + << std::setw(0) << node->nb[0] << ", " + << std::setw(5) << node->nb[1] << ", " + << std::setw(5) << node->nb[2] << ", " + << std::setw(5) << node->nb[3] << "] " + << std::right << std::setw(15) << buf_name << std::right + << "\n"; + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (auto* src = node->src[i]) { + // Get buffer type name for source + const char * src_buf_name = "none"; + ggml_backend_buffer_t src_buf = src->view_src ? src->view_src->buffer : src->buffer; + if (src_buf) { + src_buf_name = ggml_backend_buffer_name(src_buf); + } + + file << std::setw(10) << " [ " + << std::setw(5) << src->ne[0] << ", " + << std::setw(5) << src->ne[1] << ", " + << std::setw(5) << src->ne[2] << ", " + << std::setw(5) << src->ne[3] << "] " + << std::setw(12) + << i << ": " << std::left << std::setw(12) << ggml_op_name(src->op) << std::right; + file << std::left << std::setw(30) << src->name << std::right + << std::setw(16) << "[ " + << std::setw(0) << src->nb[0] << ", " + << std::setw(5) << src->nb[1] << ", " + << std::setw(5) << src->nb[2] << ", " + << std::setw(5) << src->nb[3] << "] " + << std::right << std::setw(15) << src_buf_name << std::right + << "\n"; + } + } + } + + file << "n_leafs = " << cgraph->n_leafs << "\n"; + for (int i = 0; i < cgraph->n_leafs; i++) { + ggml_tensor * node = cgraph->leafs[i]; + + // Get buffer type name for leaf + const char * leaf_buf_name = "none"; + ggml_backend_buffer_t leaf_buf = node->view_src ? node->view_src->buffer : node->buffer; + if (leaf_buf) { + leaf_buf_name = ggml_backend_buffer_name(leaf_buf); + } + + file << " - " << std::setw(3) << i << ": [ " + << std::setw(5) << node->ne[0] << ", " + << std::setw(5) << node->ne[1] << "] " + << std::setw(8) << ggml_op_name(node->op) << " " + << std::setw(16) << ggml_get_name(node) + << std::setw(20) << leaf_buf_name << "\n"; + } + // clang-format on + file << "========================================\n"; + + file.close(); +} + +void print_tensor_address_map(const ggml_cgraph * cgraph) { + std::map<void *, std::vector<std::string>> address_map; + for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { + auto * node = cgraph->nodes[node_n]; + if (node->data) { + auto it = address_map.find(node->data); + if (it == address_map.end()) { + address_map[node->data] = std::vector<std::string>(); + } + address_map[node->data].push_back(node->name); + } + } + for (const auto & pair : address_map) { + std::cout << "Address: " << pair.first << std::endl; + for (const auto & name : pair.second) { + std::cout << name << " ; "; + } + std::cout << std::endl << std::endl; + } +} + +ov::Shape GgmlOvDecoder::get_shape(const ggml_tensor * tensor) { + std::vector<size_t> shape; + for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) { + shape.push_back(static_cast<size_t>(tensor->ne[i])); + } + return shape; +} + +std::vector<size_t> GgmlOvDecoder::get_stride(const ggml_tensor * tensor) { + std::vector<size_t> stride; + for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) { + stride.push_back(static_cast<size_t>(tensor->nb[i])); + } + return stride; +} + +ov::element::Type GgmlOvDecoder::get_ov_type(const ggml_tensor * tensor) { + switch (tensor->type) { + case GGML_TYPE_F64: + return ov::element::f64; + case GGML_TYPE_F32: + return ov::element::f32; + case GGML_TYPE_F16: + return ov::element::f16; + case GGML_TYPE_BF16: + return ov::element::bf16; + case GGML_TYPE_I8: + return ov::element::i8; + case GGML_TYPE_I16: + return ov::element::i16; + case GGML_TYPE_I32: + return ov::element::i32; + case GGML_TYPE_I64: + return ov::element::i64; + default: + return ov::element::dynamic; + } +} + +ov::PartialShape GgmlOvDecoder::get_input_shape(int node_idx, const std::string & name) const { + return ov::PartialShape(get_shape(m_node_info_list[node_idx].node_inputs.at(name))); +} + +std::vector<size_t> GgmlOvDecoder::get_input_stride(int node_idx, const std::string & name) const { + return get_stride(m_node_info_list[node_idx].node_inputs.at(name)); +} + +ov::element::Type GgmlOvDecoder::get_input_type(int node_idx, const std::string & name) const { + return get_ov_type(m_node_info_list[node_idx].node_inputs.at(name)); +} + +size_t GgmlOvDecoder::get_input_size() const { + return m_model_inputs.size(); +} + +size_t GgmlOvDecoder::get_input_size(int node_idx) const { + return m_node_info_list[node_idx].node_inputs_names.size(); +} + +std::vector<std::string> GgmlOvDecoder::get_input_names(int node_idx) const { + return m_node_info_list[node_idx].node_inputs_names; +} + +ov::PartialShape GgmlOvDecoder::get_output_shape(int node_idx) const { + auto * ggml_tensor = m_node_info_list[node_idx].node_output; + return ov::PartialShape(get_shape(ggml_tensor)); +} + +ov::element::Type GgmlOvDecoder::get_output_type(const int node_idx) const { + return get_ov_type(m_node_info_list[node_idx].node); +} + +std::vector<std::string> GgmlOvDecoder::get_output_names(int node_idx) const { + return {m_node_info_list[node_idx].node_output_name}; +} + +const std::string & GgmlOvDecoder::get_op_name() const { + static const std::string unknown_name = "UNKNOWN_OP_NAME"; + return unknown_name; +} + +const std::string & GgmlOvDecoder::get_op_name(int node_idx) const { + return m_node_info_list[node_idx].node_name; +} + +int32_t * GgmlOvDecoder::get_input_op_params(int node_idx, const std::string & name) const { + return m_node_info_list[node_idx].node_inputs.at(name)->op_params; +} + +int32_t * GgmlOvDecoder::get_output_op_params(int node_idx) const { + return m_node_info_list[node_idx].node->op_params; +} + +void GgmlOvDecoder::visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const { + for (int node_idx = 0; node_idx < m_cgraph->n_nodes; node_idx++) { + if (m_cgraph->nodes[node_idx]->op == GGML_OP_NONE) { + continue; + } + node_visitor(std::make_shared<GgmlOvDecoder>(*this), node_idx); + } +} + +std::string GgmlOvDecoder::compute_op_type(const ggml_tensor * node) { + static const std::map<ggml_op, std::string> ops = { + {GGML_OP_NONE, "GGML_OP_NONE" }, + {GGML_OP_ACC, "GGML_OP_ACC" }, + {GGML_OP_ADD, "GGML_OP_ADD" }, + {GGML_OP_ADD1, "GGML_OP_ADD1" }, + {GGML_OP_CONT, "GGML_OP_CONT" }, + {GGML_OP_DIV, "GGML_OP_DIV" }, + {GGML_OP_DUP, "GGML_OP_DUP" }, + {GGML_OP_GET_ROWS, "GGML_OP_GET_ROWS" }, + {GGML_OP_MUL, "GGML_OP_MUL" }, + {GGML_OP_MUL_MAT, "GGML_OP_MUL_MAT" }, + {GGML_OP_PERMUTE, "GGML_OP_PERMUTE" }, + {GGML_OP_RESHAPE, "GGML_OP_RESHAPE" }, + {GGML_OP_RMS_NORM, "GGML_OP_RMS_NORM" }, + {GGML_OP_ROPE, "GGML_OP_ROPE" }, + {GGML_OP_SCALE, "GGML_OP_SCALE" }, + {GGML_OP_SOFT_MAX, "GGML_OP_SOFT_MAX" }, + {GGML_OP_SUB, "GGML_OP_SUB" }, + {GGML_OP_TRANSPOSE, "GGML_OP_TRANSPOSE" }, + {GGML_OP_VIEW, "GGML_OP_VIEW" }, + {GGML_OP_SET_ROWS, "GGML_OP_SET_ROWS" }, + {GGML_OP_CPY, "GGML_OP_CPY" }, + {GGML_OP_FLASH_ATTN_EXT, "GGML_OP_FLASH_ATTN_EXT"}, + }; + static const std::map<ggml_unary_op, std::string> unary_ops = { + {GGML_UNARY_OP_ABS, "GGML_UNARY_OP_ABS" }, + {GGML_UNARY_OP_SGN, "GGML_UNARY_OP_SGN" }, + {GGML_UNARY_OP_NEG, "GGML_UNARY_OP_NEG" }, + {GGML_UNARY_OP_STEP, "GGML_UNARY_OP_STEP" }, + {GGML_UNARY_OP_TANH, "GGML_UNARY_OP_TANH" }, + {GGML_UNARY_OP_ELU, "GGML_UNARY_OP_ELU" }, + {GGML_UNARY_OP_RELU, "GGML_UNARY_OP_RELU" }, + {GGML_UNARY_OP_SIGMOID, "GGML_UNARY_OP_SIGMOID" }, + {GGML_UNARY_OP_GELU, "GGML_UNARY_OP_GELU" }, + {GGML_UNARY_OP_GELU_QUICK, "GGML_UNARY_OP_GELU_QUICK" }, + {GGML_UNARY_OP_SILU, "GGML_UNARY_OP_SILU" }, + {GGML_UNARY_OP_HARDSWISH, "GGML_UNARY_OP_HARDSWISH" }, + {GGML_UNARY_OP_HARDSIGMOID, "GGML_UNARY_OP_HARDSIGMOID"}, + {GGML_UNARY_OP_EXP, "GGML_UNARY_OP_EXP" }, + {GGML_UNARY_OP_COUNT, "GGML_UNARY_OP_COUNT" } + }; + static const std::map<ggml_glu_op, std::string> glu_ops = { + {GGML_GLU_OP_SWIGLU, "GGML_GLU_OP_SWIGLU"}, + {GGML_GLU_OP_GEGLU, "GGML_GLU_OP_GEGLU" }, + {GGML_GLU_OP_REGLU, "GGML_GLU_OP_REGLU" } + }; + + switch (node->op) { + case GGML_OP_UNARY: + return unary_ops.at(ggml_get_unary_op(node)); + case GGML_OP_GLU: + return glu_ops.at(ggml_get_glu_op(node)); + default: + return ops.at(node->op); + } + static const std::string unknown_op = "UNKNOWN_GGML_OP"; + return unknown_op; +} + +const std::string & GgmlOvDecoder::get_op_type(int node_idx) const { + return m_node_info_list[node_idx].node_op_type; +} + +const std::string & GgmlOvDecoder::get_op_type() const { + static const std::string unknown_op = "UNKNOWN_GGML_OP"; + return unknown_op; +} diff --git a/ggml/src/ggml-openvino/ggml-decoder.h b/ggml/src/ggml-openvino/ggml-decoder.h new file mode 100644 index 00000000000..3ae25ddda32 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-decoder.h @@ -0,0 +1,294 @@ +#pragma once + +#include "ggml-quants.h" +#include "ggml.h" +#include "openvino/decoder.h" + +#include <cstdint> +#include <cstring> +#include <map> +#include <memory> +#include <openvino/core/partial_shape.hpp> +#include <optional> +#include <vector> + +struct ModelParams { + int ctx = -1; + int ctx_swa = -1; + int ctx_per_seq = -1; + int ctx_per_seq_swa = -1; + int n_seq = 1; + int n_heads = -1; + int n_heads_kv = -1; + int head_size = -1; + int32_t rope_params[15]; + std::vector<int> swa_layers; + + std::vector<std::string> kv_names; + size_t kv_buffer_ctx_id = 0; + + bool same_rope_params(const ModelParams & other) const { + return memcmp(rope_params, other.rope_params, sizeof(int32_t) * 15) == 0; + } + + bool can_reuse_dynamically(const ModelParams & other) const { return same_rope_params(other); } + + bool can_reuse_statically(const ModelParams & other) const { return same_rope_params(other) && ctx == other.ctx; } + + bool kv_buffer_changed(const ModelParams & other) const { return kv_buffer_ctx_id != other.kv_buffer_ctx_id; } +}; + +struct ComputeParams { + int n_seq_active = 1; + int seq_active_start = 0; + int attention_size = -1; + int attention_size_swa = -1; + int input_len = -1; + int token_len_per_seq = -1; + int past_kv_len = -1; + int output_len = 1; +}; + +class GgmlOvDecoder : public ov::frontend::ggml::GgmlDecoder { +public: + struct NodeInfo { + ggml_tensor * node; + std::string node_name; + std::string node_op_type; + std::map<std::string, ggml_tensor *> node_inputs; + std::vector<std::string> node_inputs_names; + ggml_tensor * node_output; + std::string node_output_name; + int node_op_case = 0; + void * data_addr; + }; + // Graph decoder + GgmlOvDecoder(ggml_cgraph * cgraph, + ModelParams & model_params, + ComputeParams & compute_params, + std::map<std::string, std::shared_ptr<ov::Node>> & model_weights, + bool is_static, + bool is_stateful = false, + bool is_prefill = false, + int prefill_chunk_size = 256); + + // Naive graph decoder + GgmlOvDecoder(ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights); + + virtual ov::Any get_attribute(const std::string & name) const override { + return nullptr; + GGML_UNUSED(name); + } + + virtual ov::PartialShape get_input_shape(int node_idx, const std::string & name) const override; + + virtual std::vector<size_t> get_input_stride(int node_idx, const std::string & name) const override; + + virtual ov::element::Type get_input_type(int node_idx, const std::string & name) const override; + + virtual size_t get_input_size() const override; + + virtual size_t get_input_size(int node_idx) const override; + + virtual void get_input_node(size_t input_port_idx, + std::string & producer_name, + std::string & producer_output_port_name, + size_t & producer_output_port_index) const override { + GGML_UNUSED(input_port_idx); + GGML_UNUSED(producer_name); + GGML_UNUSED(producer_output_port_name); + GGML_UNUSED(producer_output_port_index); + } + + virtual std::vector<std::string> get_input_names(int node_idx) const override; + + virtual ov::PartialShape get_output_shape(int node_idx) const override; + + virtual ov::element::Type get_output_type(int node_idx) const override; + + virtual int32_t * get_input_op_params(int node_idx, const std::string & name) const override; + + virtual int32_t * get_output_op_params(int node_idx) const override; + + virtual std::vector<std::string> get_output_names(int node_idx) const override; + + virtual const std::string & get_op_type() const override; + + virtual const std::string & get_op_type(int node_idx) const override; + + virtual const std::string & get_op_name() const override; + + virtual const std::string & get_op_name(int node_idx) const override; + + virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const override; + + ggml_tensor * get_input_ggml_tensor(const std::string & name) const { return m_inputs.at(name); } + + virtual int get_op_case(int node_idx) const override { return m_node_info_list[node_idx].node_op_case; } + + virtual const std::map<std::string, std::shared_ptr<ov::Node>> & get_model_inputs() const override { + return m_model_inputs; + } + + virtual const std::map<std::string, std::shared_ptr<ov::Node>> & get_model_extra_inputs() const override { + return m_model_extra_inputs; + } + + virtual const std::map<std::string, std::shared_ptr<ov::Tensor>> & get_model_extra_input_values() const { + return m_model_extra_input_values; + } + + virtual const std::map<std::string, std::shared_ptr<ov::Node>> & get_model_weights() const override { + return m_model_weights; + } + + virtual std::vector<std::string> get_model_output_names() const override { + return m_model_output_names; + } + + const std::map<std::string, ggml_tensor *> & get_model_outputs() const { return m_model_outputs; } + + virtual int get_ctx_size() const { return m_model_params.ctx; } + + virtual int get_ctx_swa_size() const { return m_model_params.ctx_swa; } + + virtual int get_ctx_per_seq() const { return m_model_params.ctx_per_seq; } + + virtual int get_ctx_per_seq_swa() const { return m_model_params.ctx_per_seq_swa; } + + virtual int get_n_seq() const { return m_model_params.n_seq; } + + virtual int is_swa_layer(int layer) const override { + return std::find(m_model_params.swa_layers.begin(), m_model_params.swa_layers.end(), layer) != + m_model_params.swa_layers.end(); + } + + int get_past_kv_len() const { return m_compute_params.past_kv_len; } + + int get_input_len() const { return m_compute_params.input_len; } + + virtual int32_t * get_rope_params() const override { return const_cast<int32_t *>(m_model_params.rope_params); } + + virtual std::map<std::string, std::string> get_kv_param_res_names() const override; + + virtual bool is_static() const override { return m_is_static; } + + virtual bool is_stateful() const override { return m_is_stateful; } + + ov::PartialShape get_graph_input_shape(const ggml_tensor * op, const ggml_tensor * input) const; + + static void dump_cgraph(const ggml_cgraph * cgraph, std::string & filename); + + static std::shared_ptr<ov::Node> create_weight_node(ggml_tensor * tensor, bool naive = false); + + static std::map<std::string, std::shared_ptr<ov::Node>> create_weight_nodes(ggml_cgraph * cgraph, + bool naive = false); + + const ggml_tensor * get_tensor_used_op(const ggml_tensor * tensor) const; + + const ggml_tensor * get_tensor_from_name(const std::string & name) const; + + void clear_model_weights() { m_model_weights.clear(); } + + static std::pair<ModelParams, ComputeParams> compute_llm_params(ggml_cgraph * cgraph, bool is_static); + + ModelParams get_model_params() const { return m_model_params; } + + ComputeParams get_compute_params() const { return m_compute_params; } + + void set_model_params(const ModelParams & model_params) { m_model_params = model_params; } + + void set_compute_params(const ComputeParams & compute_params) { m_compute_params = compute_params; } + + bool m_is_static = false; + bool m_is_stateful = false; + bool m_is_prefill = false; + bool m_naive = false; + int m_prefill_chunk_size = 0; + + static ov::Shape get_shape(const ggml_tensor * tensor); + static std::vector<size_t> get_stride(const ggml_tensor * tensor); + static ov::element::Type get_ov_type(const ggml_tensor * tensor); + static std::string compute_op_type(const ggml_tensor * node); + void add_extra_inputs(); + + void update_io(ggml_cgraph * cgraph); + + inline static bool is_inp_tok(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op == GGML_OP_NONE; + } + + inline static bool is_inp_pos(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_ROPE && tensor == op->src[1]; + } + + inline static bool is_inp_emb(const ggml_tensor * tensor, const ggml_tensor * op) { + return tensor->op == GGML_OP_GET_ROWS && op->op == GGML_OP_RMS_NORM; + } + + inline static bool is_inp_mask(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_CPY || (op->op == GGML_OP_FLASH_ATTN_EXT && tensor == op->src[3]); + } + + inline static bool is_rope_freqs_weight(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_ROPE && tensor == op->src[2]; + } + + inline static bool is_kvcache(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_SET_ROWS && op->src[2] == tensor; + } + + inline static bool is_kv_idx(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_SET_ROWS && op->src[1] == tensor; + } + + inline static bool is_output_idx(const ggml_tensor * tensor, const ggml_tensor * op) { + return op->op == GGML_OP_GET_ROWS && tensor == op->src[1] && op->src[0]->op != GGML_OP_NONE; + } + + static std::string get_graph_input_ov_name(const ggml_tensor * tensor, const ggml_tensor * op) { + if (is_inp_tok(tensor, op)) { + return "inp_tokens"; + } + if (is_inp_pos(tensor, op)) { + return "inp_pos"; + } + if (is_inp_emb(tensor, op)) { + return "embd"; + } + if (is_output_idx(tensor, op)) { + return "inp_out_ids"; + } + if (is_inp_mask(tensor, op)) { + return std::string(tensor->name).find("swa") == std::string::npos ? "self_kq_mask" : "self_kq_mask_swa"; + } + return tensor->name; + } + +private: + void set_input_output(); + int compute_op_case(const ggml_tensor * node) const; + bool node_is_used_as_src(const int node_idx); + void compute_model_inputs(); + void compute_model_outputs(); + + void validate_cgraph() const; + + ggml_cgraph * m_cgraph = nullptr; + std::map<std::string, ggml_tensor *> m_inputs; + + std::map<std::string, std::shared_ptr<ov::Node>> m_model_inputs; + std::map<std::string, std::shared_ptr<ov::Node>> m_model_extra_inputs; + std::map<std::string, std::shared_ptr<ov::Tensor>> m_model_extra_input_values; + std::map<std::string, std::shared_ptr<ov::Node>> m_model_weights; + std::map<std::string, ggml_tensor *> m_model_outputs; + std::vector<std::string> m_model_output_names; + std::vector<NodeInfo> m_node_info_list; + + ModelParams m_model_params; + ComputeParams m_compute_params; +}; + +void print_tensor_address_map(const ggml_cgraph * cgraph); + +int extract_layer_from_name(const std::string & name); diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.cpp b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp new file mode 100644 index 00000000000..4140136aca2 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.cpp @@ -0,0 +1,380 @@ +#include "ggml-openvino-extra.h" + +#include "ggml-impl.h" +#include "ggml.h" + +#include <cstring> +#include <openvino/runtime/intel_gpu/ocl/ocl.hpp> +#include <openvino/runtime/intel_npu/level_zero/level_zero.hpp> +#include <openvino/runtime/properties.hpp> +#include <optional> + +ov::Core & ov_singleton_core() { + static ov::Core core; + return core; +} + +// ===================================================== +// Device Configuration Implementations +// ===================================================== + +void ggml_openvino_device_config::init() { + if (initialized) { + return; + } + device_name = getenv("GGML_OPENVINO_DEVICE") ? getenv("GGML_OPENVINO_DEVICE") : "CPU"; + auto available_devices = ov_singleton_core().get_available_devices(); + if (std::find(available_devices.begin(), available_devices.end(), device_name) == available_devices.end()) { + GGML_LOG_WARN("GGML OpenVINO Backend: device %s is not available, fallback to CPU\n", device_name.c_str()); + device_name = "CPU"; + } + is_npu = (device_name == "NPU"); + + auto * cache_dir = getenv("GGML_OPENVINO_CACHE_DIR"); + if (device_name == "NPU") { + compile_config = { + {"NPU_COMPILER_DYNAMIC_QUANTIZATION", "YES" }, + {"NPU_USE_NPUW", "YES" }, + {"NPUW_DEVICES", "NPU" }, + {"NPUW_FOLD", "YES" }, + {"NPUW_WEIGHTS_BANK", "shared"}, + {"NPUW_FUNCALL_FOR_ALL", "YES" }, + {"NPUW_FUNCALL_ASYNC", "YES" }, + {"NPUW_DQ", "YES" }, + {"NPUW_DQ_FULL", "NO" }, + }; + if (cache_dir && strlen(cache_dir) > 0) { + compile_config["NPUW_CACHE_DIR"] = cache_dir; + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); + } + } else if (cache_dir && strlen(cache_dir) > 0) { + compile_config.insert(ov::cache_dir(cache_dir)); + compile_config.insert(ov::cache_mode(ov::CacheMode::OPTIMIZE_SIZE)); + } + + // Initialize remote context with queue sharing for GPU + if (device_name == "GPU") { + // Create OpenCL context and queue + cl_int err; + cl_platform_id platform; + err = clGetPlatformIDs(1, &platform, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to get OpenCL platform: %d\n", err); + return; + } + + cl_device_id cl_device; + err = clGetDeviceIDs(platform, CL_DEVICE_TYPE_GPU, 1, &cl_device, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to get OpenCL device: %d\n", err); + return; + } + + cl_context cl_ctx = clCreateContext(nullptr, 1, &cl_device, nullptr, nullptr, &err); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to create OpenCL context: %d\n", err); + return; + } + + cl_queue = clCreateCommandQueueWithProperties(cl_ctx, cl_device, nullptr, &err); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("Failed to create OpenCL command queue: %d\n", err); + clReleaseContext(cl_ctx); + return; + } + + // Create OpenVINO remote context with queue sharing + remote_context = ov::intel_gpu::ocl::ClContext(ov_singleton_core(), cl_queue); + + // Release the context (queue keeps a reference) + clReleaseContext(cl_ctx); + } else if (device_name == "NPU") { + // remote tensor is not used for NPU yet + // remote_context = ov_singleton_core().get_default_context(device_name); + } + + initialized = true; +} + +ggml_openvino_device_config::~ggml_openvino_device_config() { + if (cl_queue != nullptr) { + clReleaseCommandQueue(cl_queue); + cl_queue = nullptr; + } +} + +// Get the global device config singleton +ggml_openvino_device_config & ggml_openvino_get_device_config() { + static ggml_openvino_device_config config; + return config; +} + +// Initialize device config (call during backend init) +void ggml_openvino_init_device_config() { + ggml_openvino_get_device_config().init(); +} + +// Get the device name +const std::string & ggml_openvino_get_device_name() { + return ggml_openvino_get_device_config().device_name; +} + +// Check if running on NPU +bool ggml_openvino_is_npu() { + return ggml_openvino_get_device_config().is_npu; +} + +// Get the remote context for the current device (returns empty optional for CPU) +std::optional<ov::RemoteContext> ggml_openvino_get_remote_context() { + return ggml_openvino_get_device_config().remote_context; +} + +// Get the compile config for the current device +const ov::AnyMap & ggml_openvino_get_compile_config() { + return ggml_openvino_get_device_config().compile_config; +} + +// Get the OpenCL command queue for GPU operations +cl_command_queue ggml_openvino_get_cl_queue() { + return ggml_openvino_get_device_config().cl_queue; +} + +// Get the clEnqueueMemFillINTEL function pointer (lazy load) +clEnqueueMemFillINTEL_fn ggml_openvino_get_clEnqueueMemFillINTEL() { + static clEnqueueMemFillINTEL_fn fn = nullptr; + static bool loaded = false; + if (!loaded) { + loaded = true; + cl_platform_id platform; + if (clGetPlatformIDs(1, &platform, nullptr) == CL_SUCCESS) { + fn = (clEnqueueMemFillINTEL_fn) clGetExtensionFunctionAddressForPlatform(platform, "clEnqueueMemFillINTEL"); + } + } + return fn; +} + +// Get the clEnqueueMemcpyINTEL function pointer (lazy load) +clEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL() { + static clEnqueueMemcpyINTEL_fn fn = nullptr; + static bool loaded = false; + if (!loaded) { + loaded = true; + cl_platform_id platform; + if (clGetPlatformIDs(1, &platform, nullptr) == CL_SUCCESS) { + fn = (clEnqueueMemcpyINTEL_fn) clGetExtensionFunctionAddressForPlatform(platform, "clEnqueueMemcpyINTEL"); + } + } + return fn; +} + +// Get requantization type for a tensor type (returns nullopt if no requant needed) +std::optional<ExtraQuantType> ggml_openvino_get_requant_type(const ggml_tensor * tensor, bool no_requant) { + if (no_requant) { + return std::nullopt; + } + if (strncmp(tensor->name, "token_embd.weight", 17) == 0) { + return ((ggml_openvino_is_npu() && tensor->type == GGML_TYPE_Q6_K) ? ExtraQuantType::F16 : ExtraQuantType::Q8_0_C); + } + if (strncmp(tensor->name, "output.weight", 13) == 0) { + return ExtraQuantType::Q8_0_C; + } + if (ggml_openvino_is_npu()) { + return ExtraQuantType::Q4_0_128; + } + switch (tensor->type) { + case GGML_TYPE_Q6_K: + case GGML_TYPE_Q5_K: + return ExtraQuantType::Q8_0_C; + default: + return std::nullopt; + } +} + +// ===================================================== +// Extracted Layout Calculation +// ===================================================== + +ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor, bool use_bias) { + ggml_openvino_extracted_layout layout = {}; + layout.is_symmetric = false; + + if (!ggml_is_quantized(tensor->type)) { + return layout; + } + + // Only handle 2D weight tensors + if (tensor->ne[2] != 1 || tensor->ne[3] != 1) { + return layout; + } + + int64_t n_elements = ggml_nelements(tensor); + const size_t alignment = 64; // Good for SIMD + + // Check if requantization is needed (NPU-specific) + auto requant_type = ggml_openvino_get_requant_type(tensor, use_bias); + if (requant_type.has_value()) { + layout.is_requant = true; + layout.requant_type = requant_type; + + // Special case: requant to F16 - just store F16 weights, no scales/zp + if (requant_type.value() == ExtraQuantType::F16) { + layout.weights_size = n_elements * sizeof(uint16_t); // F16 = 2 bytes + layout.total_size = layout.weights_size; + layout.weights_offset = 0; + // No scales/zp for F16 + return layout; + } + + // Requant to different quantized format (e.g., Q4_0_128) + switch (requant_type.value()) { + case ExtraQuantType::Q4_0_128: + layout.is_u4 = true; + layout.weights_per_block = 128; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q4_0_C: + layout.is_u4 = true; + layout.weights_per_block = tensor->ne[0]; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q8_0_32: + layout.is_u4 = false; + layout.weights_per_block = 32; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q8_0_C: + layout.is_u4 = false; + layout.weights_per_block = tensor->ne[0]; + layout.is_symmetric = true; + break; + case ExtraQuantType::Q8_1_C: + layout.is_u4 = false; + layout.weights_per_block = tensor->ne[0]; + break; + default: + layout.weights_per_block = -1; + GGML_ABORT("Code of re-quantizing to channel-wise is not updated"); + break; + } + + if (layout.is_requant) { + // Calculate sizes for requantized format + layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; + int64_t n_blocks = n_elements / layout.weights_per_block; + layout.scales_size = n_blocks * sizeof(uint16_t); + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } + + layout.weights_offset = 0; + layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; + layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; + layout.total_size = layout.zp_offset + layout.zp_size; + layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor)); + return layout; + } + } + + // Normal extraction (no requant) - determine format based on tensor type + layout.is_u4 = false; + layout.weights_per_block = 32; + layout.is_symmetric = false; + + switch (tensor->type) { + case GGML_TYPE_Q4_0: + layout.is_u4 = true; + layout.is_symmetric = true; + break; + + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + layout.is_u4 = true; + break; + + case GGML_TYPE_Q8_0: + layout.is_symmetric = true; + break; + + case GGML_TYPE_Q6_K: + layout.weights_per_block = 16; + layout.is_symmetric = true; + break; + + case GGML_TYPE_Q5_K: + break; + + default: + // Unsupported quantization type + return layout; + } + + // Calculate sizes + // Weights: U4 = n_elements/2 bytes, U8 = n_elements bytes + layout.weights_size = layout.is_u4 ? (n_elements / 2) : n_elements; + + // Scales: F16 per block + int64_t n_blocks = n_elements / layout.weights_per_block; + layout.scales_size = n_blocks * sizeof(uint16_t); // F16 = 2 bytes + // For symmetric quantization, no zp needed (weights stored as signed) + if (layout.is_symmetric) { + layout.zp_size = 0; + } else { + layout.zp_size = layout.is_u4 ? ((n_blocks + 1) / 2) : n_blocks; + } + + // Layout in buffer: [weights | scales | zp] with alignment + layout.weights_offset = 0; + layout.scales_offset = ((layout.weights_size + alignment - 1) / alignment) * alignment; + layout.zp_offset = layout.scales_offset + ((layout.scales_size + alignment - 1) / alignment) * alignment; + layout.total_size = layout.zp_offset + layout.zp_size; + layout.total_size = std::max(layout.total_size, ggml_nbytes(tensor)); + + return layout; +} + +ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote) { + ov::Shape shape; + for (int i = GGML_MAX_DIMS - 1; i >= 0; --i) { + shape.push_back(static_cast<size_t>(tensor->ne[i])); + } + + ov::element::Type element_type; + switch (tensor->type) { + case GGML_TYPE_F32: + element_type = ov::element::f32; + break; + case GGML_TYPE_F16: + element_type = ov::element::f16; + break; + case GGML_TYPE_BF16: + element_type = ov::element::bf16; + break; + case GGML_TYPE_I32: + element_type = ov::element::i32; + break; + case GGML_TYPE_I64: + element_type = ov::element::i64; + break; + default: + // GGML_LOG_WARN("%s: unsupported tensor type for ov::Tensor: %s\n", __func__, ggml_type_name(tensor->type)); + return nullptr; + } + + const auto & device_name = ggml_openvino_get_device_name(); + auto remote_context = ggml_openvino_get_remote_context(); + + std::shared_ptr<ov::Tensor> ov_tensor; + if (is_remote) { + GGML_ASSERT(device_name == "GPU"); + auto gpu_context = remote_context->as<ov::intel_gpu::ocl::ClContext>(); + auto usm_tensor = gpu_context.create_tensor(element_type, shape, tensor->data); + ov_tensor = std::make_shared<ov::intel_gpu::ocl::USMTensor>(std::move(usm_tensor)); + } else { + ov_tensor = std::make_shared<ov::Tensor>(element_type, shape, tensor->data); + } + + return new ggml_openvino_tensor_extra(ov_tensor); +} diff --git a/ggml/src/ggml-openvino/ggml-openvino-extra.h b/ggml/src/ggml-openvino/ggml-openvino-extra.h new file mode 100644 index 00000000000..cd0baf4a681 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-openvino-extra.h @@ -0,0 +1,182 @@ +#pragma once + +#include "ggml.h" +#include "openvino/runtime/core.hpp" + +#define CL_TARGET_OPENCL_VERSION 300 +#include <CL/cl.h> + +#include <cstdlib> +#include <memory> +#include <openvino/core/node.hpp> +#include <openvino/runtime/remote_context.hpp> +#include <openvino/runtime/tensor.hpp> +#include <optional> +#include <string> + +// ExtraQuantType enum - defines requantization target formats +enum class ExtraQuantType { F16, Q4_0_C, Q8_1_C, Q4_0_128, Q8_0_C, Q8_0_32 }; + +ov::Core & ov_singleton_core(); + +// Get the remote context for the current device (returns empty optional for CPU) +std::optional<ov::RemoteContext> ggml_openvino_get_remote_context(); + +// Get the compile config for the current device +const ov::AnyMap & ggml_openvino_get_compile_config(); + +// Get the OpenCL command queue for GPU operations (returns nullptr for CPU/NPU) +cl_command_queue ggml_openvino_get_cl_queue(); + +// Intel USM extension function type +typedef cl_int(CL_API_CALL * clEnqueueMemFillINTEL_fn)(cl_command_queue queue, + void * dst_ptr, + const void * pattern, + size_t pattern_size, + size_t size, + cl_uint num_events_in_wait_list, + const cl_event * event_wait_list, + cl_event * event); + +typedef cl_int(CL_API_CALL * clEnqueueMemcpyINTEL_fn)(cl_command_queue queue, + cl_bool blocking, + void * dst_ptr, + const void * src_ptr, + size_t size, + cl_uint num_events_in_wait_list, + const cl_event * event_wait_list, + cl_event * event); + +// Get the clEnqueueMemFillINTEL function pointer (returns nullptr if not available) +clEnqueueMemFillINTEL_fn ggml_openvino_get_clEnqueueMemFillINTEL(); + +// Get the clEnqueueMemcpyINTEL function pointer (returns nullptr if not available) +clEnqueueMemcpyINTEL_fn ggml_openvino_get_clEnqueueMemcpyINTEL(); + +// ===================================================== +// Global Device Configuration (singleton) +// ===================================================== +// Initialized once during backend init from GGML_OPENVINO_DEVICE env var + +struct ggml_openvino_device_config { + std::string device_name = "CPU"; + bool is_npu = false; + bool initialized = false; + std::optional<ov::RemoteContext> remote_context; + ov::AnyMap compile_config; + cl_command_queue cl_queue = nullptr; + + void init(); + ~ggml_openvino_device_config(); +}; + +// Get the global device config singleton +ggml_openvino_device_config & ggml_openvino_get_device_config(); + +// Initialize device config (call during backend init) +void ggml_openvino_init_device_config(); + +// Get the device name +const std::string & ggml_openvino_get_device_name(); + +// Check if running on NPU +bool ggml_openvino_is_npu(); + +// Get requantization type for a tensor type (returns nullopt if no requant needed) +std::optional<ExtraQuantType> ggml_openvino_get_requant_type(const ggml_tensor * tensor, bool no_requant = false); + +// ===================================================== +// OpenVINO Tensor Extra Types +// ===================================================== +// These types are stored in tensor->extra by the OpenVINO backend buffer. +// They allow: +// 1. Pre-built ov::Constant nodes for weights (avoiding memcpy during graph construction) +// 2. ov::Tensor wrappers for KV cache / compute tensors (for direct use with infer_request) + +// Base class for OpenVINO tensor extra data +struct ggml_openvino_extra_base { + enum class Type { WEIGHT, QUANTIZED_WEIGHT, TENSOR }; + Type type; + virtual ~ggml_openvino_extra_base() = default; +protected: + explicit ggml_openvino_extra_base(Type t) : type(t) {} +}; + +// Extra data for F16/F32/BF16 weight tensors - stores the pre-built weight node +struct ggml_openvino_weight_extra : public ggml_openvino_extra_base { + ov::Tensor weights; // The underlying weight data tensor + std::shared_ptr<ov::Node> weight_node; // Pre-built OpenVINO weight node + + ggml_openvino_weight_extra(ov::Tensor w, std::shared_ptr<ov::Node> n) : + ggml_openvino_extra_base(Type::WEIGHT), + weights(std::move(w)), + weight_node(std::move(n)) {} +}; + +// Extra data for quantized weight tensors - stores extracted weights/scales/zp and weight node +struct ggml_openvino_quantized_weight_extra : public ggml_openvino_extra_base { + ov::Tensor weights; // U4 or U8 extracted weights + ov::Tensor scales; // F16 scales + ov::Tensor zp; // U4 or U8 zero points (same type as weights) + std::shared_ptr<ov::Node> weight_node; // Pre-built OpenVINO weight subgraph + + ggml_openvino_quantized_weight_extra(ov::Tensor w, ov::Tensor s, ov::Tensor z, std::shared_ptr<ov::Node> n) : + ggml_openvino_extra_base(Type::QUANTIZED_WEIGHT), + weights(std::move(w)), + scales(std::move(s)), + zp(std::move(z)), + weight_node(std::move(n)) {} +}; + +// Extra data for KV cache / compute tensors - stores ov::Tensor for infer_request +struct ggml_openvino_tensor_extra : public ggml_openvino_extra_base { + std::shared_ptr<ov::Tensor> tensor; // For direct use with infer_request + + explicit ggml_openvino_tensor_extra(std::shared_ptr<ov::Tensor> t) + : ggml_openvino_extra_base(Type::TENSOR), tensor(std::move(t)) {} +}; + +// ===================================================== +// Extracted Size Calculation for Quantized Tensors +// ===================================================== +// For quantized tensors, we need extra space to store extracted weights, scales, and zero points. +// Returns the total size needed in the buffer for extracted data. + +struct ggml_openvino_extracted_layout { + size_t total_size = 0; // Total bytes needed + size_t weights_offset = 0; // Offset to weights in buffer + size_t weights_size = 0; // Size of weights in bytes + size_t scales_offset = 0; // Offset to scales in buffer + size_t scales_size = 0; // Size of scales in bytes + size_t zp_offset = 0; // Offset to zero points in buffer + size_t zp_size = 0; // Size of zero points in bytes (U4 or U8) + bool is_u4; // true for U4 weights, false for U8 + int64_t weights_per_block; // weights per scale/zp block + bool is_symmetric; // true for symmetric quantization + + // Requantization info + bool is_requant = false; // true if this tensor needs requantization + std::optional<ExtraQuantType> requant_type; // target requant type if is_requant +}; + +// Calculate the buffer layout for extracted quantized data +ggml_openvino_extracted_layout ggml_openvino_get_extracted_layout(const ggml_tensor * tensor, bool use_bias = false); + +ggml_openvino_tensor_extra * ggml_openvino_create_tensor_extra(const ggml_tensor * tensor, bool is_remote); + +// Register an extra with the tensor's OpenVINO buffer context for proper lifetime management. +// This sets tensor->extra and tracks the extra in the buffer context for cleanup. +void ggml_openvino_buffer_register_extra(ggml_tensor * tensor, ggml_openvino_extra_base * extra); + +// ===================================================== +// OpenVINO Backend Context and Interface +// ===================================================== +struct ggml_backend_openvino_context { + int device = 0; + std::string name = "OpenVINO"; + std::string description = "OpenVINO Backend Context"; + + std::shared_ptr<void> runtime_context = nullptr; + + ggml_backend_openvino_context() = default; +}; diff --git a/ggml/src/ggml-openvino/ggml-openvino.cpp b/ggml/src/ggml-openvino/ggml-openvino.cpp new file mode 100644 index 00000000000..4f3ebf2536b --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-openvino.cpp @@ -0,0 +1,1132 @@ +#include "ggml-openvino.h" + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "ggml-openvino-extra.h" +#include "ggml-openvino/utils.h" +#include "ggml-quants.h" +#include "ggml.h" + +#include <atomic> +#include <cstdlib> +#include <cstdint> +#include <cstring> +#include <memory> +#include <mutex> +#include <openvino/core/type/element_type.hpp> +#include <openvino/openvino.hpp> +#include <openvino/runtime/allocator.hpp> +#include <openvino/runtime/intel_gpu/ocl/ocl.hpp> +#include <openvino/runtime/intel_npu/level_zero/level_zero.hpp> +#include <openvino/runtime/tensor.hpp> +#include <set> +#include <string> +#include <vector> + +#if defined(_WIN32) +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include <windows.h> +#else +# include <unistd.h> +#endif + +// ===================================================== +// OpenVINO Buffer Implementation using ov::Tensor +// ===================================================== +// +// Design: This implementation uses a hybrid approach: +// 1. For weight tensors: Store a pre-built ov::op::v0::Constant in tensor->extra +// - This avoids the memcpy during graph construction +// - For quantized weights, the constant is already converted to OpenVINO format +// 2. For KV cache / compute tensors: Store an ov::Tensor in tensor->extra +// - This can be directly passed to infer_request +// - Future: can be changed to ov::RemoteTensor for GPU/NPU +// +// This design is similar to: +// - CUDA split buffer: tensor->extra stores device pointers +// - CPU repack buffer: tensor->extra stores tensor_traits with repacked data +// ===================================================== + +// Buffer context that manages per-tensor allocations (no contiguous buffer for weights) +struct ggml_backend_openvino_buffer_context { + int device; + std::string name; + size_t id; + + // For non-weight buffers (KV cache, compute), we still use contiguous allocation + void * data; + size_t size; + bool is_remote; + + // Wrapping of the buffer + std::shared_ptr<ov::Tensor> ov_buffer; + + // Track all extras for cleanup + std::map<ggml_tensor *, ggml_openvino_extra_base *> tensor_extras; + + // Used for re-allocation on device for kvcache + void * data_prev; + + ggml_backend_openvino_buffer_context(int device, size_t size, bool is_remote = false) : + device(device), + name(std::string(GGML_OPENVINO_NAME) + std::to_string(device)), + id([]() { + static std::atomic<size_t> next_id{1}; + return next_id.fetch_add(1); + }()), + data(nullptr), + size(size), + is_remote(is_remote) { + if (size == 0) { + return; + } + + const auto & device_name = ggml_openvino_get_device_name(); + + if (is_remote) { + GGML_ASSERT(device_name == "GPU"); + auto remote_context = ggml_openvino_get_remote_context(); + auto gpu_context = remote_context->as<ov::intel_gpu::ocl::ClContext>(); + ov::intel_gpu::ocl::USMTensor usm_tensor = + gpu_context.create_usm_device_tensor(ov::element::u8, ov::Shape{size}); + data = usm_tensor.get(); + ov_buffer = std::make_shared<ov::intel_gpu::ocl::USMTensor>(std::move(usm_tensor)); + } else { + data = ggml_aligned_malloc(size); + GGML_ASSERT(data); + memset(data, 0, size); + ov_buffer = std::make_shared<ov::Tensor>(ov::element::u8, ov::Shape{size}, data); + } + + if (data == nullptr) { + GGML_LOG_ERROR("%s: failed to allocate %zu bytes\n", __func__, size); + return; + } + + if (reinterpret_cast<uintptr_t>(data) % TENSOR_ALIGNMENT != 0) { + GGML_LOG_ERROR("%s: %s buffer is not aligned to %d bytes\n", __func__, device_name.c_str(), + TENSOR_ALIGNMENT); + GGML_ABORT("fatal error"); + } + } + + ~ggml_backend_openvino_buffer_context() { + // Clean up all tensor extras + // GGML_LOG_DEBUG("Deleting OpenVINO buffer context #%zu for device %d, size %zu MB\n", id, device, + // size / 1024 / 1024); + for (auto & pair : tensor_extras) { + delete pair.second; + } + tensor_extras.clear(); + if (!is_remote && data != nullptr) { + ggml_aligned_free(data, size); + } + } +}; + +// Buffer type context (per-device) +struct ggml_backend_openvino_buffer_type_context { + int device; + std::string name; +}; + +// Buffer interface functions +static void ggml_backend_openvino_buffer_free_buffer(ggml_backend_buffer_t buffer) { + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + delete ctx; +} + +static void * ggml_backend_openvino_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + return ctx->data; +} + +static bool is_stateful_enabled() { + static const auto * stateful = getenv("GGML_OPENVINO_STATEFUL_EXECUTION"); + return stateful && *stateful != '\0' && strcmp(stateful, "0") != 0; +} + +static enum ggml_status ggml_backend_openvino_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + // Put kvcache on device memory for GPU (NPU memory is too small even for kvcache) + if (strncmp(tensor->name, "cache_", 6) == 0 && !ctx->is_remote && ggml_openvino_get_device_name() == "GPU" && + !is_stateful_enabled()) { + GGML_ASSERT(ctx->tensor_extras.empty()); + auto device = ctx->device; + auto size = ctx->size; + auto * data_prev = ctx->data; + delete ctx; + ctx = new ggml_backend_openvino_buffer_context(device, size, true); + buffer->context = ctx; + tensor->data = (char *) ctx->data + ((char *) tensor->data - (char *) data_prev); + } + + // Views share the extra from view_src + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + if (tensor->view_src->extra != nullptr) { + tensor->extra = tensor->view_src->extra; + } + return GGML_STATUS_SUCCESS; + } + + ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (tensor->data != nullptr && !ggml_is_quantized(tensor->type)) { + ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote); + if (extra != nullptr) { + auto it = ctx->tensor_extras.find(tensor); + if (it != ctx->tensor_extras.end()) { + delete it->second; + } + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; + } + } + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_openvino_buffer_memset_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + uint8_t value, + size_t offset, + size_t size) { + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); + GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (ctx->is_remote) { + // For remote (device) buffers, use OpenCL USM memfill + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL(); + if (queue != nullptr && mem_fill_fn != nullptr) { + uint8_t pattern = value; + cl_int err = mem_fill_fn(queue, (char *) tensor->data + offset, &pattern, sizeof(pattern), size, 0, nullptr, + nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemFillINTEL failed with error %d\n", __func__, err); + } + clFinish(queue); + } else { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer\n", __func__); + } + } else { + memset((char *) tensor->data + offset, value, size); + } +} + +static void ggml_backend_openvino_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); + GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + // Check if this is a weight buffer (usage is set BEFORE set_tensor is called, except in test-backend-ops) + bool is_weight_buffer = (buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + // Full tensor set: offset=0, full size, not a view + bool is_full_tensor_set = (offset == 0 && size == ggml_nbytes(tensor) && tensor->view_src == nullptr); + // 2D tensor (typical weight shape) + bool is_2d = (tensor->ne[2] == 1 && tensor->ne[3] == 1); + + if (is_weight_buffer && is_full_tensor_set && is_2d) { + try { + auto result = process_weight_tensor(tensor, data, tensor->data); + result.weight_node->set_friendly_name(tensor->name); + + // const auto & layout = result.layout; + ggml_openvino_extra_base * extra; + + // Quantized path with extracted weight/scale/zp tensors + if (result.is_quantized()) { + extra = new ggml_openvino_quantized_weight_extra(std::move(result.weights), std::move(result.scales), + std::move(result.zp), result.weight_node); + + // if (layout.is_requant) { + // GGML_LOG_DEBUG("%s: requantized %s to %s (u%d, block_size=%ld)\n", __func__, tensor->name, + // extra_quant_type_name(layout.requant_type.value()), layout.is_u4 ? 4 : 8, + // layout.weights_per_block); + // } else { + // int64_t n_blocks = ggml_nelements(tensor) / layout.weights_per_block; + // GGML_LOG_DEBUG("%s: extracted quantized weight node for %s (u%d, %zu weights, %ld blocks)\n", + // __func__, tensor->name, layout.is_u4 ? 4 : 8, layout.weights_size, n_blocks); + // } + } else { + // F16/F32/BF16 weight or F16-requant + extra = new ggml_openvino_weight_extra(std::move(result.weights), result.weight_node); + + // if (layout.total_size > 0) { + // GGML_LOG_DEBUG("%s: requantized %s to F16\n", __func__, tensor->name); + // } else { + // GGML_LOG_DEBUG("%s: created shared-memory weight node for %s\n", __func__, tensor->name); + // } + } + + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; + + } catch (const std::exception & e) { + GGML_LOG_ERROR("%s: failed to process weight tensor for %s: %s\n", __func__, tensor->name, e.what()); + memcpy((char *) tensor->data + offset, data, size); + } + } else { + // Non-weight tensor (KV cache, activations, etc.) - copy data. test-backend-ops also goes here + if (ctx->is_remote) { + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); + if (queue != nullptr && mem_cpy_fn != nullptr) { + cl_int err = + mem_cpy_fn(queue, CL_TRUE, (char *) tensor->data + offset, data, size, 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL failed with error %d\n", __func__, err); + } + } else { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__); + } + } else { + memcpy((char *) tensor->data + offset, data, size); + } + + ggml_openvino_tensor_extra * extra = ggml_openvino_create_tensor_extra(tensor, ctx->is_remote); + if (extra == nullptr) { + // GGML_LOG_ERROR("%s: failed to create tensor extra for %s\n", __func__, tensor->name); + return; + } + + auto it = ctx->tensor_extras.find(tensor); + if (it != ctx->tensor_extras.end()) { + delete it->second; + } + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; + } +} + +static void ggml_backend_openvino_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + // GGML_LOG_DEBUG("%s: buffer usage=%d, tensor name=%s\n", __func__, buffer->usage, tensor->name); + GGML_ASSERT(tensor != nullptr && tensor->data != nullptr); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (ctx->is_remote) { + // For remote (device) buffers, use OpenCL USM memcpy (device-to-host) + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); + if (queue != nullptr && mem_cpy_fn != nullptr) { + cl_int err = + mem_cpy_fn(queue, CL_TRUE, data, (const char *) tensor->data + offset, size, 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL failed with error %d\n", __func__, err); + } + } else { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__); + } + } else { + memcpy(data, (const char *) tensor->data + offset, size); + } +} + +static bool ggml_backend_openvino_buffer_cpy_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * src, + ggml_tensor * dst) { + // GGML_LOG_DEBUG("%s: src tensor name=%s, dst tensor name=%s\n", __func__, src->name, dst->name); + GGML_ASSERT(src != nullptr && dst != nullptr); + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + + if (ctx->is_remote) { + // For remote (device) buffers, use OpenCL USM memcpy + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_cpy_fn = ggml_openvino_get_clEnqueueMemcpyINTEL(); + if (queue == nullptr || mem_cpy_fn == nullptr) { + GGML_LOG_ERROR("%s: no OpenCL queue or clEnqueueMemcpyINTEL not available for GPU buffer\n", __func__); + return false; + } + // Can copy from host to device + if (ggml_backend_buffer_is_host(src->buffer)) { + cl_int err = mem_cpy_fn(queue, CL_TRUE, dst->data, src->data, ggml_nbytes(src), 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL (host-to-device) failed with error %d\n", __func__, err); + return false; + } + return true; + } + // Can also copy from device to device if both are OpenVINO remote buffers + if (ggml_backend_buffer_is_openvino(src->buffer)) { + ggml_backend_openvino_buffer_context * src_ctx = + (ggml_backend_openvino_buffer_context *) src->buffer->context; + if (src_ctx->is_remote) { + cl_int err = + mem_cpy_fn(queue, CL_TRUE, dst->data, src->data, ggml_nbytes(src), 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_ERROR("%s: clEnqueueMemcpyINTEL (device-to-device) failed with error %d\n", __func__, + err); + return false; + } + return true; + } + } + return false; + } + + // Host buffer - can copy from any host buffer + if (ggml_backend_buffer_is_host(src->buffer)) { + memcpy(dst->data, src->data, ggml_nbytes(src)); + return true; + } + return false; +} + +static void ggml_backend_openvino_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + GGML_ASSERT(ctx->data != nullptr); + if (ctx->is_remote) { + cl_command_queue queue = ggml_openvino_get_cl_queue(); + auto mem_fill_fn = ggml_openvino_get_clEnqueueMemFillINTEL(); + if (queue != nullptr && mem_fill_fn != nullptr) { + uint8_t pattern = value; + cl_int err = mem_fill_fn(queue, ctx->data, &pattern, sizeof(pattern), ctx->size, 0, nullptr, nullptr); + if (err != CL_SUCCESS) { + GGML_LOG_WARN("%s: clEnqueueMemFillINTEL failed with error %d\n", __func__, err); + } + clFinish(queue); + } else { + GGML_LOG_WARN("%s: no OpenCL queue or clEnqueueMemFillINTEL not available for GPU buffer clear\n", + __func__); + } + } else { + memset(ctx->data, value, ctx->size); + } +} + +static const ggml_backend_buffer_i ggml_backend_openvino_buffer_interface = { + /* .free_buffer = */ ggml_backend_openvino_buffer_free_buffer, + /* .get_base = */ ggml_backend_openvino_buffer_get_base, + /* .init_tensor = */ ggml_backend_openvino_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_openvino_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_openvino_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_openvino_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_openvino_buffer_cpy_tensor, + /* .clear = */ ggml_backend_openvino_buffer_clear, + /* .reset = */ NULL, +}; + +// Buffer type interface functions +static const char * ggml_backend_openvino_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *) buft->context; + return ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_openvino_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { + ggml_backend_openvino_buffer_type_context * buft_ctx = (ggml_backend_openvino_buffer_type_context *) buft->context; + + // Create buffer context with contiguous memory allocation + ggml_backend_openvino_buffer_context * ctx = new ggml_backend_openvino_buffer_context(buft_ctx->device, size); + + if (ctx->data == nullptr && size > 0) { + GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size); + delete ctx; + return nullptr; + } + + return ggml_backend_buffer_init(buft, ggml_backend_openvino_buffer_interface, ctx, size); +} + +static size_t ggml_backend_openvino_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + return TENSOR_ALIGNMENT; +} + +static size_t ggml_backend_openvino_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + return SIZE_MAX; +} + +static size_t ggml_backend_openvino_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, + const ggml_tensor * tensor) { + GGML_UNUSED(buft); + + // For quantized 2D tensors (weights), we need extra space for extracted data + if (ggml_is_quantized(tensor->type) && tensor->ne[2] == 1 && tensor->ne[3] == 1) { + ggml_openvino_extracted_layout layout = ggml_openvino_get_extracted_layout(tensor); + if (layout.total_size > 0) { + // GGML_LOG_DEBUG("%s: tensor %s needs %zu bytes (original %zu, extracted: weights=%zu scales=%zu zp=%zu)\n", + // __func__, tensor->name, layout.total_size, ggml_nbytes(tensor), layout.weights_size, + // layout.scales_size, layout.zp_size); + return layout.total_size; + } + } + + return ggml_nbytes(tensor); +} + +static const ggml_backend_buffer_type_i ggml_backend_openvino_buffer_type_interface = { + /* .get_name = */ ggml_backend_openvino_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_openvino_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_openvino_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_openvino_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_openvino_buffer_type_get_alloc_size, + /* .is_host = */ nullptr, +}; + +// Get buffer type for a specific device +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_buffer_type(int device) { + GGML_ASSERT(device >= 0 && device < ggml_backend_openvino_get_device_count()); + + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + static std::vector<ggml_backend_buffer_type> buffer_types; + static std::vector<ggml_backend_openvino_buffer_type_context> buffer_type_contexts; + + if (buffer_types.empty()) { + int device_count = ggml_backend_openvino_get_device_count(); + buffer_types.resize(device_count); + buffer_type_contexts.resize(device_count); + + for (int i = 0; i < device_count; i++) { + buffer_type_contexts[i].device = i; + buffer_type_contexts[i].name = std::string(GGML_OPENVINO_NAME) + std::to_string(i); + + buffer_types[i] = ggml_backend_buffer_type{ + /* .iface = */ ggml_backend_openvino_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), i), + /* .context = */ &buffer_type_contexts[i], + }; + } + } + + return &buffer_types[device]; +} + +// ===================================================== +// OpenVINO Host Buffer Implementation +// ===================================================== + +static const char * ggml_backend_openvino_host_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + ggml_backend_openvino_buffer_type_context * ctx = (ggml_backend_openvino_buffer_type_context *) buft->context; + static std::string name; + name = ctx->name + "_HOST"; + return name.c_str(); +} + +static bool ggml_backend_openvino_host_buffer_type_is_host(ggml_backend_buffer_type_t buft) { + GGML_UNUSED(buft); + return true; +} + +static const ggml_backend_buffer_type_i ggml_backend_openvino_host_buffer_type_interface = { + /* .get_name = */ ggml_backend_openvino_host_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_openvino_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_openvino_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_openvino_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_openvino_buffer_type_get_alloc_size, + /* .is_host = */ ggml_backend_openvino_host_buffer_type_is_host, +}; + +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_openvino_host_buffer_type(int device) { + GGML_ASSERT(device >= 0 && device < ggml_backend_openvino_get_device_count()); + + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + static std::vector<ggml_backend_buffer_type> buffer_types; + static std::vector<ggml_backend_openvino_buffer_type_context> buffer_type_contexts; + + if (buffer_types.empty()) { + int device_count = ggml_backend_openvino_get_device_count(); + buffer_types.resize(device_count); + buffer_type_contexts.resize(device_count); + + for (int i = 0; i < device_count; i++) { + buffer_type_contexts[i].device = i; + buffer_type_contexts[i].name = std::string(GGML_OPENVINO_NAME) + std::to_string(i); + + buffer_types[i] = ggml_backend_buffer_type{ + /* .iface = */ ggml_backend_openvino_host_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), i), + /* .context = */ &buffer_type_contexts[i], + }; + } + } + + return &buffer_types[device]; +} + +bool ggml_backend_buffer_is_openvino(ggml_backend_buffer_t buffer) { + return buffer->iface.free_buffer == ggml_backend_openvino_buffer_free_buffer; +} + +size_t ggml_backend_openvino_buffer_get_ctx_id(ggml_backend_buffer_t buffer) { + if (!ggml_backend_buffer_is_openvino(buffer)) { + return 0; + } + ggml_backend_openvino_buffer_context * ctx = (ggml_backend_openvino_buffer_context *) buffer->context; + return ctx->id; +} + +void ggml_openvino_buffer_register_extra(ggml_tensor * tensor, ggml_openvino_extra_base * extra) { + GGML_ASSERT(tensor != nullptr); + GGML_ASSERT(tensor->buffer != nullptr); + GGML_ASSERT(ggml_backend_buffer_is_openvino(tensor->buffer)); + + auto * ctx = static_cast<ggml_backend_openvino_buffer_context *>(tensor->buffer->context); + + auto it = ctx->tensor_extras.find(tensor); + if (it != ctx->tensor_extras.end()) { + delete it->second; + } + + ctx->tensor_extras[tensor] = extra; + tensor->extra = extra; +} + +bool ggml_backend_buft_is_openvino(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_openvino_buffer_type_get_name; +} + +bool ggml_backend_buft_is_openvino_host(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_openvino_host_buffer_type_get_name; +} + +static void ggml_backend_openvino_free(ggml_backend_t backend) { + ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context; + + if (ctx->runtime_context) { + auto r_ctx = std::static_pointer_cast<ov_runtime_context>(ctx->runtime_context); + if (--r_ctx->backend_count == 0) { + r_ctx->clear_caches(); + } + } + + delete ctx; + delete backend; +} + +static const char * ggml_backend_openvino_get_name(ggml_backend_t backend) { + return GGML_OPENVINO_NAME; + GGML_UNUSED(backend); +} + +static enum ggml_status ggml_backend_openvino_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + return ov_graph_compute(cgraph, backend); + GGML_UNUSED(backend); +} + +static const ggml_backend_i ggml_backend_openvino_interface = { + /* .get_name = */ ggml_backend_openvino_get_name, + /* .free = */ ggml_backend_openvino_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_openvino_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, +}; + +int ggml_backend_openvino_get_device_count() { + return 1; +} + +static ggml_guid_t ggml_backend_openvino_guid(void) { + static ggml_guid guid = {0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, + 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d}; + return &guid; +} + +static std::shared_ptr<ov_runtime_context> get_ov_runtime_context_ptr() { + static std::shared_ptr<ov_runtime_context> r_ctx = [] { + auto ctx = std::make_shared<ov_runtime_context>(); + ctx->device = ggml_openvino_get_device_name(); + ctx->stateful = is_stateful_enabled() && !ggml_openvino_is_npu(); + return ctx; + }(); + return r_ctx; +} + +// backend API +GGML_BACKEND_API ggml_backend_t ggml_backend_openvino_init(int device) { + if (device < 0 || device >= ggml_backend_openvino_get_device_count()) { + GGML_LOG_ERROR("%s: invalid device %d\n", __func__, device); + return nullptr; + } + + ggml_backend_openvino_context * ctx = new ggml_backend_openvino_context; + if (ctx == nullptr) { + GGML_LOG_ERROR("%s: failed to allocate context\n", __func__); + return nullptr; + } + + ctx->runtime_context = get_ov_runtime_context_ptr(); + if (ctx->runtime_context == nullptr) { + GGML_LOG_ERROR("%s: failed to allocate runtime context\n", __func__); + delete ctx; + return nullptr; + } + + std::shared_ptr<ov_runtime_context> r_ctx = std::static_pointer_cast<ov_runtime_context>(ctx->runtime_context); + r_ctx->backend_count++; + + ggml_backend_t openvino_backend = new ggml_backend{ + /* .guid = */ ggml_backend_openvino_guid(), + /* .interface = */ ggml_backend_openvino_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_openvino_reg(), device), + /* .context = */ ctx, + }; + + return openvino_backend; +} + +GGML_BACKEND_API bool ggml_backend_is_openvino(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_openvino_guid()); +} + +struct ggml_backend_openvino_device_context { + int device; + std::string name; + std::string description; +}; + +static const char * ggml_backend_openvino_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_openvino_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context; + return ctx->description.c_str(); +} + +static void ggml_backend_openvino_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { +#ifdef _WIN32 + MEMORYSTATUSEX status; + status.dwLength = sizeof(status); + GlobalMemoryStatusEx(&status); + *total = status.ullTotalPhys; + *free = status.ullAvailPhys; +#else + long pages = sysconf(_SC_PHYS_PAGES); + long page_size = sysconf(_SC_PAGE_SIZE); + *total = pages * page_size; + + // "free" system memory is ill-defined, for practical purposes assume that all of it is free: + *free = *total; +#endif // _WIN32 + + GGML_UNUSED(dev); +} + +static enum ggml_backend_dev_type ggml_backend_openvino_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_openvino_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_openvino_device_get_name(dev); + props->description = ggml_backend_openvino_device_get_description(dev); + props->type = ggml_backend_openvino_device_get_type(dev); + ggml_backend_openvino_device_get_memory(dev, &props->memory_free, &props->memory_total); + + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_backend_t ggml_backend_openvino_device_init(ggml_backend_dev_t dev, const char * params) { + GGML_UNUSED(params); + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context; + return ggml_backend_openvino_init(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context; + return ggml_backend_openvino_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_openvino_device_get_host_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_openvino_device_context * ctx = (ggml_backend_openvino_device_context *) dev->context; + return ggml_backend_openvino_host_buffer_type(ctx->device); +} + +static bool has_view_op_input(const ggml_tensor * op) { + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] == nullptr) { + break; + } + if (op->src[i]->op == GGML_OP_VIEW) { + return true; + } + } + return false; +} + +static bool is_supported_flash_attn_pattern(const ggml_tensor * op) { + // pattern of q,k,v should be q->op==PERMUTE, q->src[0]->op==VIEW, q->src[0]->src[0]->view_src==nullptr + for (int i = 0; i < 3; i++) { + const ggml_tensor * src = op->src[i]; + if (src->op != GGML_OP_PERMUTE || src->src[0] == nullptr || src->src[0]->op != GGML_OP_VIEW || + src->src[0]->src[0] == nullptr || src->src[0]->src[0]->view_src != nullptr) { + return false; + } + } + return true; +} + +static bool is_op_unsupported_case(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + case GGML_OP_SET_ROWS: { + if (op->ne[3] != 1) { + return true; + } + break; + } + case GGML_OP_ADD: + case GGML_OP_MUL: { + if (op->src[1]->op == GGML_OP_PERMUTE) { + return true; + } + for (int i = 0; i < 4; i++) { + if (op->src[0]->ne[i] != op->src[1]->ne[i] && (op->src[0]->ne[i] != 1 && op->src[1]->ne[i] != 1)) { + return true; + } + } + break; + } + case GGML_OP_SOFT_MAX: { + if (op->src[2] != nullptr) { + // GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with sinks\n"); + return true; + } + float scale = 1.0f; + float max_bias = 0.0f; + const auto * op_params = op->op_params; + memcpy(&scale, (const float *) op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) op_params + 1, sizeof(float)); + if (max_bias > 0) { + // GGML_LOG_WARN("OpenVINO backend does not support SOFT_MAX with max_bias > 0\n"); + return true; + } + break; + } + case GGML_OP_FLASH_ATTN_EXT: { + if (op->src[4] != nullptr) { + // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with sinks\n"); + return true; + } + if (!is_supported_flash_attn_pattern(op)) { + return true; + } + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + const auto * op_params = op->op_params; + memcpy(&scale, (const float *) op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) op_params + 2, sizeof(float)); + if (max_bias > 0) { + // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with max_bias > 0\n"); + return true; + } + if (logit_softcap != 0) { + // GGML_LOG_WARN("OpenVINO backend does not support FLASH_ATTN_EXT with logit_softcap != 0\n"); + return true; + } + break; + } + case GGML_OP_PERMUTE: { + if (op->type == GGML_TYPE_BF16) { + // err msg: [GPU] Could not find a suitable kernel for transpose + // GGML_LOG_WARN("OpenVINO backend does not support PERMUTE with BF16 type\n"); + return true; + } + break; + } + case GGML_OP_CPY: { + if (op->src[1] != op) { + // GGML_LOG_WARN("OpenVINO backend only supports CPY that is a cast\n"); + return true; + } + break; + } + case GGML_OP_MUL_MAT: { + if (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16) { + // Has accuracy issue, try enabling this and see `test-backend-ops -o "MUL_MAT"` + // GGML_LOG_WARN("OpenVINO backend does not support MUL_MAT with two F16 tensors\n"); + return true; + } + if (op->src[0]->ne[3] != op->src[1]->ne[3] && op->src[0]->ne[3] != 1 && op->src[1]->ne[3] != 1) { + return true; + } + if (op->src[0]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_PERMUTE) { + return true; + } + if (ggml_is_quantized(op->src[0]->type) && op->src[0]->ne[1] == 1) { + // MUL_MAT(type_a=q4_0,type_b=f32,m=1,n=2048,k=8192,bs=[1,1],nr=[1,1],per=[0,1,2,3],k_v=0,o=1) + // triggers a bug in ov matmul_shape_inference.hpp + return true; + } + if (op->src[0]->op == GGML_OP_VIEW && op->src[1]->op == GGML_OP_VIEW) { + return true; + } + break; + } + case GGML_OP_ROPE: { + const int32_t * op_params = op->op_params; + const int n_dims = op_params[1]; + const int mode = op_params[2]; + if (mode != GGML_ROPE_TYPE_NORMAL && mode != GGML_ROPE_TYPE_NEOX && mode != GGML_ROPE_TYPE_IMROPE) { + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with mode %d\n", mode); + return true; + } + if (n_dims != 0.0f && n_dims != op->src[0]->ne[0]) { + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with n_dims %d != src[0]->ne[0] %ld\n", n_dims, + // op->src[0]->ne[0]); + return true; + } + if (op->type != GGML_TYPE_F32) { + // GGML_LOG_WARN("OpenVINO backend does not support ROPE with type %s\n", ggml_type_name(op->type)); + return true; + } + if (op->src[0]->op == GGML_OP_VIEW) { + if (op->src[0]->view_src->ne[1] != op->src[0]->ne[2]) { + // GGML_LOG_WARN( + // "OpenVINO backend does not support ROPE with src[0]->view_src->ne[1] %ld != src[0]->ne[2] " + // "%ld\n", + // op->src[0]->view_src->ne[1], op->src[0]->ne[2]); + return true; + } + } + if (mode == GGML_ROPE_TYPE_IMROPE && + (op->src[2] != 0 || ((const float *) op_params)[6] != 1 || ((const float *) op_params)[7] != 0 || + ((const float *) op_params)[8] != 1)) { + // GGML_LOG_WARN("OpenVINO backend does not support IMROPE with freq_factors, freq_scale, ext_factor, and attn_factor\n"); + return true; + } + break; + } + default: + break; + } + if (op->op == GGML_OP_GET_ROWS) { + if (op->ne[0] == 256 && (op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q5_K)) { + // ERR = 0.000000306 > 0.000000100 GET_ROWS(type=q4_K,n=256,m=5,r=4,be1=1,be2=1,v=0) + // ERR = 0.000000197 > 0.000000100 GET_ROWS(type=q5_K,n=256,m=5,r=4,be1=1,be2=1,v=0) + return true; + } + } + return false; +} + +static bool ggml_backend_openvino_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + GGML_ASSERT(dev->reg != nullptr); + + static std::set<ggml_type> supported_types{GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_I64, + GGML_TYPE_I32, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_K, + GGML_TYPE_Q5_K, GGML_TYPE_Q8_0, GGML_TYPE_Q6_K}; + + static const std::set<ggml_op> supported_ops{GGML_OP_NONE, GGML_OP_ADD, GGML_OP_MUL, GGML_OP_MUL_MAT, GGML_OP_VIEW, + /*GGML_OP_CONT,*/ GGML_OP_RESHAPE, GGML_OP_PERMUTE, GGML_OP_TRANSPOSE, + GGML_OP_GET_ROWS, GGML_OP_ROPE, GGML_OP_RMS_NORM, GGML_OP_SCALE, + // softmax is not updated due to replaced by flash_attn_ext + // GGML_OP_SOFT_MAX, + GGML_OP_SET_ROWS, GGML_OP_FLASH_ATTN_EXT, GGML_OP_CPY}; + static const std::set<ggml_unary_op> supported_unary_ops{ + GGML_UNARY_OP_GELU, + GGML_UNARY_OP_SILU, + }; + static const std::set<ggml_glu_op> supported_glu_ops{ + GGML_GLU_OP_SWIGLU, + GGML_GLU_OP_GEGLU, + }; + + switch (op->op) { + case GGML_OP_UNARY: { + auto supported = supported_unary_ops.find(ggml_get_unary_op(op)) != supported_unary_ops.end(); + if (!supported) { + // GGML_LOG_WARN("OpenVINO backend does not support unary op %s\n", ggml_unary_op_name(ggml_get_unary_op(op))); + return false; + } + if (has_view_op_input(op)) { + // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n", + // ggml_unary_op_name(ggml_get_unary_op(op))); + return false; + } + break; + } + case GGML_OP_GLU: { + auto supported = supported_glu_ops.find(ggml_get_glu_op(op)) != supported_glu_ops.end(); + if (!supported) { + // GGML_LOG_WARN("OpenVINO backend does not support GLU op %s\n", ggml_glu_op_name(ggml_get_glu_op(op))); + return false; + } + if (has_view_op_input(op)) { + // GGML_LOG_WARN("OpenVINO backend does not support unary op %s with view input\n", + // ggml_glu_op_name(ggml_get_glu_op(op))); + return false; + } + if (op->src[1] == nullptr && op->src[0]->ne[0] % 2 != 0) { + // triggers bug in ov gpu + return false; + } + break; + } + default: { + auto supported = supported_ops.find(op->op) != supported_ops.end(); + if (!supported) { + // GGML_LOG_WARN("OpenVINO backend does not support op %s\n", ggml_op_name(op->op)); + return false; + } + static std::set<ggml_op> ops_not_support_view_input{ + GGML_OP_GET_ROWS, + GGML_OP_RMS_NORM, + }; + if (ops_not_support_view_input.find(op->op) != ops_not_support_view_input.end() && has_view_op_input(op)) { + // GGML_LOG_WARN("OpenVINO backend does not support op %s with view input\n", ggml_op_name(op->op)); + return false; + } + } + } + + if (supported_types.find(op->type) == supported_types.end()) { + // GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(op->type)); + return false; + } + for (int i = 0; i < GGML_MAX_SRC; i++) { + auto * src = op->src[i]; + if (src == nullptr) { + break; + } + if (supported_types.find(src->type) == supported_types.end()) { + // GGML_LOG_WARN("OpenVINO backend does not support tensor type %s\n", ggml_type_name(src->type)); + return false; + } + if (ggml_is_quantized(src->type) && src->ne[2] != 1) { + // GGML_LOG_WARN("OpenVINO backend does not support 3D quantized tensors\n"); + return false; + } + } + + if (is_op_unsupported_case(op)) { + return false; + } + return true; +} + +static bool ggml_backend_openvino_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + return ggml_backend_buft_is_openvino(buft) || ggml_backend_buft_is_host(buft); + GGML_UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_openvino_device_interface = { + /* .get_name = */ ggml_backend_openvino_device_get_name, + /* .get_description = */ ggml_backend_openvino_device_get_description, + /* .get_memory = */ ggml_backend_openvino_device_get_memory, + /* .get_type = */ ggml_backend_openvino_device_get_type, + /* .get_props = */ ggml_backend_openvino_device_get_props, + /* .init_backend = */ ggml_backend_openvino_device_init, + /* .get_buffer_type = */ ggml_backend_openvino_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_openvino_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_openvino_device_supports_op, + /* .supports_buft = */ ggml_backend_openvino_device_supports_buft, + /* .offload_op = */ NULL, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +struct ggml_backend_openvino_reg_context { + std::vector<ggml_backend_dev_t> devices; +}; + +static const char * ggml_backend_openvino_reg_get_name(ggml_backend_reg_t reg) { + return GGML_OPENVINO_NAME; + GGML_UNUSED(reg); +} + +static size_t ggml_backend_openvino_reg_get_device_count(ggml_backend_reg_t reg) { + GGML_UNUSED(reg); + return (size_t) ggml_backend_openvino_get_device_count(); +} + +static ggml_backend_dev_t ggml_backend_openvino_reg_get_device(ggml_backend_reg_t reg, size_t index) { + ggml_backend_openvino_reg_context * ctx = (ggml_backend_openvino_reg_context *) reg->context; + GGML_ASSERT(index < ctx->devices.size()); + return ctx->devices[index]; +} + +static const struct ggml_backend_reg_i ggml_backend_openvino_reg_interface = { + /* .get_name = */ ggml_backend_openvino_reg_get_name, + /* .get_device_count = */ ggml_backend_openvino_reg_get_device_count, + /* .get_device = */ ggml_backend_openvino_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +static void ggml_openvino_init() { + // Initialize device config singleton from env var + ggml_openvino_init_device_config(); + GGML_LOG_INFO("OpenVINO: using device %s\n", ggml_openvino_get_device_name().c_str()); +} + +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_openvino_reg(void) { + static ggml_backend_reg reg; + + static bool initialized = false; + { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + if (!initialized) { + ggml_openvino_init(); + + ggml_backend_openvino_reg_context * ctx = new ggml_backend_openvino_reg_context; + + for (int i = 0; i < ggml_backend_openvino_get_device_count(); i++) { + ggml_backend_openvino_device_context * dev_ctx = new ggml_backend_openvino_device_context; + dev_ctx->device = i; + dev_ctx->name = GGML_OPENVINO_NAME + std::to_string(i); + + dev_ctx->description = ov::get_openvino_version().description; + + ggml_backend_dev_t dev = + new ggml_backend_device{/* .interface = */ ggml_backend_openvino_device_interface, + /* .reg = */ ®, + /* .context = */ dev_ctx}; + ctx->devices.push_back(dev); + } + + reg = ggml_backend_reg{/* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_openvino_reg_interface, + /* .context = */ ctx}; + } + + initialized = true; + } + + return ® +} diff --git a/ggml/src/ggml-openvino/ggml-quants.cpp b/ggml/src/ggml-openvino/ggml-quants.cpp new file mode 100644 index 00000000000..57d66df4f01 --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-quants.cpp @@ -0,0 +1,956 @@ +#include "ggml-quants.h" + +#include "ggml-common.h" +#include "ggml-impl.h" +#include "ggml.h" + +#include <algorithm> +#include <cassert> +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <limits> +#include <memory> +#include <openvino/core/except.hpp> +#include <openvino/core/node.hpp> +#include <openvino/core/node_output.hpp> +#include <openvino/core/parallel.hpp> +#include <openvino/core/shape.hpp> +#include <openvino/core/type/element_type.hpp> +#include <openvino/core/type/element_type_traits.hpp> +#include <openvino/core/type/float16.hpp> +#include <openvino/op/add.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/reshape.hpp> +#include <openvino/op/subtract.hpp> +#include <openvino/op/util/attr_types.hpp> +#include <openvino/runtime/tensor.hpp> +#include <string> +#include <vector> + +void unpack_32_4(const uint8_t * data, uint8_t * dst) { + std::fill_n(dst, 16, 0); + for (int j = 0; j < 16; ++j) { + uint8_t x = (data[j] & 0x0F); + uint8_t y = (data[j] >> 4); + if (j % 2 != 0) { + x <<= 4; + y <<= 4; + } + dst[j / 2] |= x; + dst[8 + j / 2] |= y; // Last 16 weights are in the higher bits + } +} + +// Extracts (weight, scales, zp) from Q4_0 tensors. +// Data layout is: |16 bit scale|32 x 4bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i4 (value - 8). +void extract_q4_0_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr) { + const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights + + auto * data = static_cast<uint8_t *>(tensor->data); + auto * weights = static_cast<uint8_t *>(weights_arr.data()); + auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>(); + + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path + + if (!is_symmetric) { + auto * zp = static_cast<uint8_t *>(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); + // Pack two 4-bit zero points per byte + if (i % 2 == 0) { + zp[i / 2] = 8; // Lower nibble + } else { + zp[i / 2] |= (8 << 4); // Upper nibble + } + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + }); + } else { + // Symmetric: unpack as u4 then convert to i4 by subtracting 8 (XOR each nibble) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + scales[i] = ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block))); + unpack_32_4(data + i * bytes_per_block + 2, weights + i * 16); + // Convert u4 to i4: subtract 8 from each nibble. XOR 0x88 flips each nibble by 8. + for (int j = 0; j < 16; ++j) { + weights[i * 16 + j] ^= 0x88; + } + }); + } +} + +// Extracts (weight, scales, zp) from Q4_1 tensors. +// Data layout is: |16 bit scale|16 bit min|32 x 4bit weights|. +void extract_q4_1_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias) { + const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes min, 32x0.5 byte weights + + auto * data = static_cast<uint8_t *>(tensor->data); + auto * weights = static_cast<uint8_t *>(weights_arr.data()); + auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>(); + + if (use_bias) { + // Store bias (min) directly as f16 instead of computing u4 zero points + auto * bias = zp_arr.data<ov::element_type_traits<ov::element::f16>::value_type>(); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + float scale = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block)))); + float min = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block + 2)))); + scales[i] = ov::float16(scale); + bias[i] = ov::float16(min); // bias = min, dequant: w*s + bias + unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16); + }); + } else { + auto * zp = static_cast<uint8_t *>(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + float scale = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block)))); + float min = static_cast<float>(ov::float16::from_bits(*((uint16_t *) (data + i * bytes_per_block + 2)))); + scales[i] = ov::float16(scale); + // zp = -min / scale (bias = min, so zp = -bias/scale) + uint8_t zp_val = (scale != 0.0f) ? (uint8_t) std::round(-min / scale) : 0; + // Pack two 4-bit zero points per byte + if (i % 2 == 0) { + zp[i / 2] = zp_val & 0x0F; // Lower nibble + } else { + zp[i / 2] |= (zp_val << 4); // Upper nibble + } + unpack_32_4(data + i * bytes_per_block + 4, weights + i * 16); + }); + } +} + +// Extracts (weight, scales, zp) from Q8_0 tensors. +// Data layout is: |16 bit scale|32 x 8bit weights|. +// When zp_arr is empty (symmetric), weights are stored as signed i8 directly. +void extract_q8_0_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr) { + const uint64_t weights_per_block = 32; + const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights + + auto * data = static_cast<uint8_t *>(tensor->data); + auto * weights = static_cast<uint8_t *>(weights_arr.data()); + auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>(); + + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path + + if (!is_symmetric) { + auto * zp = static_cast<uint8_t *>(zp_arr.data()); + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); + zp[i] = 128; + for (size_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; + x ^= 1 << 7; // Convert int8 to uint8 by flipping sign bit + weights[i * weights_per_block + j] = x; + } + }); + } else { + // Symmetric: store original int8 values directly (no unsigned bias) + ov::parallel_for(scales_arr.get_size(), [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + scales[i] = ov::float16::from_bits(*(uint16_t *) block_data); + // Copy int8 weights as-is (the tensor element type is i8) + memcpy(weights + i * weights_per_block, block_data + 2, weights_per_block); + }); + } +} + +void unpack_256_4(const uint8_t * data, uint8_t * dst) { + // Initialize the output array with zeros + std::fill_n(dst, 128, 0); + + for (size_t i = 0; i < 4; ++i) { + for (int j = 0; j < 32; ++j) { + uint8_t x = (data[i * 32 + j] & 0x0F); + uint8_t y = (data[i * 32 + j] >> 4); + if (j % 2 != 0) { + x <<= 4; + y <<= 4; + } + dst[i * 32 + j / 2] |= x; + dst[i * 32 + 16 + j / 2] |= y; // Last 16 weights are in the higher bits + } + } +} + +void extract_q4_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias) { + const uint64_t bytes_per_block = 2 + 2 + 12 + 128; + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + + auto * data = static_cast<uint8_t *>(tensor->data); + auto * weights = static_cast<uint8_t *>(weights_arr.data()); + auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>(); + + // For bias path, zp_arr holds f16 bias values; for zp path, it holds packed u4 zero points + auto * zp_u4 = use_bias ? nullptr : static_cast<uint8_t *>(zp_arr.data()); + auto * bias_f16 = use_bias ? zp_arr.data<ov::element_type_traits<ov::element::f16>::value_type>() : nullptr; + + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + + // Extract scale factors and offsets + float scale_scales = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data))); + float scale_mins = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data + 1))); + + // Extract qs1 and qs2 + uint8_t * qs1 = block_data + 4; + + // Calculate scales + float scale_vals[8]; + scale_vals[0] = scale_scales * static_cast<float>((*(qs1) & 0b111111)); + scale_vals[1] = scale_scales * static_cast<float>((*(qs1 + 1) & 0b111111)); + scale_vals[2] = scale_scales * static_cast<float>((*(qs1 + 2) & 0b111111)); + scale_vals[3] = scale_scales * static_cast<float>((*(qs1 + 3) & 0b111111)); + scale_vals[4] = scale_scales * static_cast<float>((*(qs1 + 8) & 0b00001111) | ((*(qs1) >> 6) << 4)); + scale_vals[5] = scale_scales * static_cast<float>((*(qs1 + 9) & 0b00001111) | ((*(qs1 + 1) >> 6) << 4)); + scale_vals[6] = scale_scales * static_cast<float>((*(qs1 + 10) & 0b00001111) | ((*(qs1 + 2) >> 6) << 4)); + scale_vals[7] = scale_scales * static_cast<float>((*(qs1 + 11) & 0b00001111) | ((*(qs1 + 3) >> 6) << 4)); + + // Calculate min values (bias = -min) + float min_vals[8]; + min_vals[0] = scale_mins * static_cast<float>((*(qs1 + 4) & 0b111111)); + min_vals[1] = scale_mins * static_cast<float>((*(qs1 + 5) & 0b111111)); + min_vals[2] = scale_mins * static_cast<float>((*(qs1 + 6) & 0b111111)); + min_vals[3] = scale_mins * static_cast<float>((*(qs1 + 7) & 0b111111)); + min_vals[4] = scale_mins * static_cast<float>((*(qs1 + 8) >> 4) | ((*(qs1 + 4) >> 6) << 4)); + min_vals[5] = scale_mins * static_cast<float>((*(qs1 + 9) >> 4) | ((*(qs1 + 5) >> 6) << 4)); + min_vals[6] = scale_mins * static_cast<float>((*(qs1 + 10) >> 4) | ((*(qs1 + 6) >> 6) << 4)); + min_vals[7] = scale_mins * static_cast<float>((*(qs1 + 11) >> 4) | ((*(qs1 + 7) >> 6) << 4)); + + // Store scales and compute zero points or bias + for (int j = 0; j < 8; j++) { + scales[i * 8 + j] = ov::float16(scale_vals[j]); + if (use_bias) { + // Store bias = -min directly as f16, dequant: w*s + bias + bias_f16[i * 8 + j] = ov::float16(-min_vals[j]); + } else { + // zp = min / scale (since bias = -min and zp = -bias/scale) + uint8_t zp_val = (scale_vals[j] != 0.0f) ? (uint8_t) std::round(min_vals[j] / scale_vals[j]) : 0; + // Pack two 4-bit zero points per byte + size_t idx = i * 8 + j; + if (idx % 2 == 0) { + zp_u4[idx / 2] = zp_val & 0x0F; + } else { + zp_u4[idx / 2] |= (zp_val << 4); + } + } + } + unpack_256_4(block_data + 16, weights + i * 128); + }); +} + +void extract_q6_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr) { + const uint64_t bytes_per_block = 128 + 64 + 16 + 2; + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + + auto * data = static_cast<uint8_t *>(tensor->data); + auto * weights = static_cast<uint8_t *>(weights_arr.data()); + auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>(); + + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path + + if (!is_symmetric) { + auto * zp = static_cast<uint8_t *>(zp_arr.data()); + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast<float>(*((int8_t *) (block_data + 128 + 64 + j)))); + zp[j + i * 16] = 32; + } + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + for (int64_t j = 0; j < 32; ++j) { + weights[i * 256 + j] = (ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4); + weights[i * 256 + j + 32] = (ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4); + weights[i * 256 + j + 64] = (ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4); + weights[i * 256 + j + 96] = (ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4); + weights[i * 256 + j + 128] = (ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4); + weights[i * 256 + j + 160] = (ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4); + weights[i * 256 + j + 192] = (ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4); + weights[i * 256 + j + 224] = (ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4); + } + }); + } else { + // Symmetric: subtract 32 from each weight to store as signed i8 + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + float scale_factor = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data + 104))); + for (size_t j = 0; j < 16; j++) { + scales[j + i * 16] = + ov::float16(scale_factor * static_cast<float>(*((int8_t *) (block_data + 128 + 64 + j)))); + } + uint8_t * ql = block_data; + uint8_t * qh = block_data + 128; + auto * signed_weights = reinterpret_cast<int8_t *>(weights); + for (int64_t j = 0; j < 32; ++j) { + signed_weights[i * 256 + j] = static_cast<int8_t>((ql[j] & 0xF) | (((qh[j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 32] = + static_cast<int8_t>((ql[32 + j] & 0xF) | (((qh[j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 64] = static_cast<int8_t>((ql[j] >> 4) | (((qh[j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 96] = + static_cast<int8_t>((ql[32 + j] >> 4) | (((qh[j] >> 6) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 128] = + static_cast<int8_t>((ql[64 + j] & 0xF) | (((qh[32 + j] >> 0) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 160] = + static_cast<int8_t>((ql[96 + j] & 0xF) | (((qh[32 + j] >> 2) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 192] = + static_cast<int8_t>((ql[64 + j] >> 4) | (((qh[32 + j] >> 4) & 3) << 4)) - 32; + signed_weights[i * 256 + j + 224] = + static_cast<int8_t>((ql[96 + j] >> 4) | (((qh[32 + j] >> 6) & 3) << 4)) - 32; + } + }); + } +} + +static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t * d, uint8_t * m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +void extract_q5_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias) { + const uint64_t bytes_per_block = 4 + 12 + 32 + 128; + const uint64_t n_super_block = tensor->nb[3] / bytes_per_block; + + auto * data = static_cast<uint8_t *>(tensor->data); + auto * weights = static_cast<uint8_t *>(weights_arr.data()); + auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>(); + + // For bias path, zp_arr holds f16 bias values; for zp path, it holds u8 zero points + auto * zp_u8 = use_bias ? nullptr : static_cast<uint8_t *>(zp_arr.data()); + auto * bias_f16 = use_bias ? zp_arr.data<ov::element_type_traits<ov::element::f16>::value_type>() : nullptr; + + ov::parallel_for(n_super_block, [&](size_t i) { + uint8_t * block_data = data + i * bytes_per_block; + + const float d = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data))); + const float min_factor = static_cast<float>(ov::float16::from_bits(*((uint16_t *) block_data + 1))); + + const uint8_t * scales_data = block_data + 4; // 12 bytes of scales + const uint8_t * qh = block_data + 4 + 12; // 32 bytes of high bits + const uint8_t * ql = block_data + 4 + 12 + 32; // 128 bytes of low bits + + int is = 0; + uint8_t u1 = 1; + uint8_t u2 = 2; + + // Process 2 blocks in one iteration + for (int j = 0; j < 256; j += 64) { // 256 = QK_K, so 4 iterations of 64 + uint8_t sc; + uint8_t m; + + // Get scale and min for first 32 elements + get_scale_min_k4(is + 0, scales_data, &sc, &m); + const float d1 = d * sc; + const float m1 = min_factor * m; + + // Get scale and min for second 32 elements + get_scale_min_k4(is + 1, scales_data, &sc, &m); + const float d2 = d * sc; + const float m2 = min_factor * m; + + scales[i * 8 + is] = ov::float16(d1); + scales[i * 8 + is + 1] = ov::float16(d2); + if (use_bias) { + // Store bias = -min directly as f16, dequant: w*s + bias + bias_f16[i * 8 + is] = ov::float16(-m1); + bias_f16[i * 8 + is + 1] = ov::float16(-m2); + } else { + // zp = min / scale (since bias = -min and zp = -bias/scale) + zp_u8[i * 8 + is] = (d1 != 0.0f) ? (uint8_t) std::round(m1 / d1) : 0; + zp_u8[i * 8 + is + 1] = (d2 != 0.0f) ? (uint8_t) std::round(m2 / d2) : 0; + } + + // Extract weights for first 32 elements (matching deq formula exactly) + for (int l = 0; l < 32; ++l) { + weights[i * 256 + j + l] = (ql[l] & 0xF) + ((qh[l] & u1) ? 16 : 0); + } + + // Extract weights for second 32 elements + for (int l = 0; l < 32; ++l) { + weights[i * 256 + j + l + 32] = (ql[l] >> 4) + ((qh[l] & u2) ? 16 : 0); + } + + ql += 32; + is += 2; + u1 <<= 2; + u2 <<= 2; + } + }); +} + +// TODO Reorder for make_intX_weights + +ov::Output<ov::Node> make_int8_weights(ov::Tensor & weight, + ov::Tensor & scales, + ov::Tensor & zp, + size_t group_size, + bool use_bias) { + ov::Shape orig_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i8); // Symmetric: signed weights, no ZP + + // Expand dimensions for scales and zp/bias + auto scale_shape = scales.get_shape(); + + ov::Shape packed_shape = {orig_shape[0], orig_shape[1] / group_size, group_size}; + + if (packed_shape[1] == 1) { + // Requantized channel-wise case + packed_shape.erase(packed_shape.begin() + 1); + } else { + scale_shape.push_back(1); + scales.set_shape(scale_shape); + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); + zp_shape.push_back(1); + zp.set_shape(zp_shape); + } + } + + auto scales_f16 = std::make_shared<ov::op::v0::Constant>(scales); + + ov::Output<ov::Node> result; + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared<ov::op::v0::Constant>(ov::element::i8, packed_shape, + static_cast<uint8_t *>(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared<ov::op::v0::Convert>(weights_node, ov::element::f16); + result = std::make_shared<ov::op::v1::Multiply>(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Unsigned path + auto weights_node = std::make_shared<ov::op::v0::Constant>(ov::element::u8, packed_shape, + static_cast<uint8_t *>(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared<ov::op::v0::Convert>(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared<ov::op::v0::Constant>(zp); + auto w_s = + std::make_shared<ov::op::v1::Multiply>(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared<ov::op::v1::Add>(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_point = std::make_shared<ov::op::v0::Constant>(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_point, zp_value)) { + zero_point = ov::op::v0::Constant::create(zero_point->get_element_type(), {}, {zp_value}); + } + auto zero_point_f16 = std::make_shared<ov::op::v0::Convert>(zero_point, ov::element::f16); + auto w_zp = + std::make_shared<ov::op::v1::Subtract>(weights_f16, zero_point_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared<ov::op::v1::Multiply>(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); + } + } + + if (packed_shape.size() != 2) { + // If not requantized channel-wise case, reshape back to original shape + auto final_shape = + std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{orig_shape.size()}, orig_shape); + result = std::make_shared<ov::op::v1::Reshape>(result, final_shape, false); + } + + return std::make_shared<ov::op::v0::Convert>(result, ov::element::f32); +} + +ov::Output<ov::Node> make_int4_weights(ov::Tensor & weight, + ov::Tensor & scales, + ov::Tensor & zp, + size_t group_size, + bool use_bias) { + ov::Shape orig_weight_shape = weight.get_shape(); + bool is_signed = (weight.get_element_type() == ov::element::i4); // Symmetric: signed weights, no ZP + + // Expand dimensions for scales and zp/bias + ov::Shape scale_shape = scales.get_shape(); + + // Create INT4 weight tensor + ov::Shape packed_shape = {orig_weight_shape[0], orig_weight_shape[1] / group_size, group_size}; + + if (packed_shape[1] == 1) { + // Requantized channel-wise case + packed_shape.erase(packed_shape.begin() + 1); + } else { + scale_shape.push_back(1); + scales.set_shape(scale_shape); + if (!is_signed && zp.get_size() > 0) { + auto zp_shape = zp.get_shape(); + zp_shape.push_back(1); + zp.set_shape(zp_shape); + } + } + + auto scales_f16 = std::make_shared<ov::op::v0::Constant>(scales); + + ov::Output<ov::Node> result; + if (is_signed) { + // Signed path: q * s (no zero point subtraction needed) + auto weights_node = std::make_shared<ov::op::v0::Constant>(ov::element::i4, packed_shape, + static_cast<uint8_t *>(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared<ov::op::v0::Convert>(weights_node, ov::element::f16); + result = std::make_shared<ov::op::v1::Multiply>(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Unsigned path + auto weights_node = std::make_shared<ov::op::v0::Constant>(ov::element::u4, packed_shape, + static_cast<uint8_t *>(weight.data()), nullptr); + weights_node->get_rt_info()["__gguf_tensor_holder"] = weight; + auto weights_f16 = std::make_shared<ov::op::v0::Convert>(weights_node, ov::element::f16); + + if (use_bias && zp.get_size() > 0) { + // Bias path: w * s + b (zp tensor holds f16 bias values) + auto bias_f16 = std::make_shared<ov::op::v0::Constant>(zp); + auto w_s = + std::make_shared<ov::op::v1::Multiply>(weights_f16, scales_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared<ov::op::v1::Add>(w_s, bias_f16, ov::op::AutoBroadcastType::NUMPY); + } else { + // Zero point path: (w - zp) * s + auto zero_points_node = std::make_shared<ov::op::v0::Constant>(zp); + float zp_value; + if (ov::op::util::get_single_value(zero_points_node, zp_value)) { + zero_points_node = ov::op::v0::Constant::create(zero_points_node->get_element_type(), {}, {zp_value}); + } + auto zero_points_f16 = std::make_shared<ov::op::v0::Convert>(zero_points_node, ov::element::f16); + auto w_zp = + std::make_shared<ov::op::v1::Subtract>(weights_f16, zero_points_f16, ov::op::AutoBroadcastType::NUMPY); + result = std::make_shared<ov::op::v1::Multiply>(w_zp, scales_f16, ov::op::AutoBroadcastType::NUMPY); + } + } + + if (packed_shape.size() != 2) { + // If not requantized channel-wise case, reshape back to original shape + auto final_shape = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{orig_weight_shape.size()}, + orig_weight_shape); + result = std::make_shared<ov::op::v1::Reshape>(result, final_shape, false); + } + + return std::make_shared<ov::op::v0::Convert>(result, ov::element::f32); +} + +// Extract quantized weights from tensor and create weight subgraph +std::shared_ptr<ov::Node> extract_quantized_weights(const ggml_tensor * tensor, + const void * data, + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & zp, + bool use_bias) { + // Create a temporary tensor for extraction functions that read from tensor->data + ggml_tensor temp_tensor = *tensor; + temp_tensor.data = const_cast<void *>(data); + + // Determine block size based on tensor type + int64_t weights_per_block; + bool is_u4; + switch (tensor->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: + is_u4 = true; + weights_per_block = 32; + break; + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_K: + is_u4 = false; + weights_per_block = 32; + break; + case GGML_TYPE_Q6_K: + is_u4 = false; + weights_per_block = 16; + break; + default: + throw std::runtime_error("Unsupported quantized type for extraction: " + + std::string(ggml_type_name(tensor->type))); + } + + // Extract quantized data + switch (tensor->type) { + case GGML_TYPE_Q4_0: + extract_q4_0_data(&temp_tensor, weights, scales, zp); + break; + case GGML_TYPE_Q4_1: + extract_q4_1_data(&temp_tensor, weights, scales, zp, use_bias); + break; + case GGML_TYPE_Q4_K: + extract_q4_k_data(&temp_tensor, weights, scales, zp, use_bias); + break; + case GGML_TYPE_Q8_0: + extract_q8_0_data(&temp_tensor, weights, scales, zp); + break; + case GGML_TYPE_Q6_K: + extract_q6_k_data(&temp_tensor, weights, scales, zp); + break; + case GGML_TYPE_Q5_K: + extract_q5_k_data(&temp_tensor, weights, scales, zp, use_bias); + break; + default: + throw std::runtime_error("Unsupported quantized type: " + std::string(ggml_type_name(tensor->type))); + } + + // Create the OpenVINO weight subgraph + ov::Output<ov::Node> weight_node; + if (is_u4) { + weight_node = make_int4_weights(weights, scales, zp, weights_per_block, use_bias); + } else { + weight_node = make_int8_weights(weights, scales, zp, weights_per_block, use_bias); + } + + auto result = weight_node.get_node_shared_ptr(); + result->set_friendly_name(tensor->name); + return result; +} + +// Requantize weights to target format, writing to provided buffers +std::shared_ptr<ov::Node> requantize_to_buffers(const ggml_tensor * tensor, + const void * data, + ExtraQuantType requant_type, + int64_t block_size, + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & zp) { + int64_t n_elements = ggml_nelements(tensor); + + // First dequantize to F32 + std::vector<float> weights_f32(n_elements); + ggml_get_type_traits(tensor->type)->to_float(data, weights_f32.data(), n_elements); + + // Handle F16 case - just convert and create constant + if (requant_type == ExtraQuantType::F16) { + ggml_get_type_traits(GGML_TYPE_F16)->from_float_ref(weights_f32.data(), weights.data(), n_elements); + auto result = std::make_shared<ov::op::v0::Constant>(weights); + result->set_friendly_name(tensor->name); + return result; + } + + // Requantize to target quantized format + bool is_u4 = (requant_type == ExtraQuantType::Q4_0_C || requant_type == ExtraQuantType::Q4_0_128); + + if (is_u4) { + quantize_q4_0(weights_f32.data(), weights, scales, zp, n_elements, block_size); + } else if (requant_type == ExtraQuantType::Q8_1_C) { + quantize_q8_1(weights_f32.data(), weights, scales, zp, n_elements, block_size); + } else { + quantize_q8_0(weights_f32.data(), weights, scales, zp, n_elements, block_size); + } + + // Create the OpenVINO weight subgraph + ov::Output<ov::Node> weight_node; + if (is_u4) { + weight_node = make_int4_weights(weights, scales, zp, block_size); + } else { + weight_node = make_int8_weights(weights, scales, zp, block_size); + } + + auto result = weight_node.get_node_shared_ptr(); + result->set_friendly_name(tensor->name); + return result; +} + +OvWeight process_weight_tensor(const ggml_tensor * tensor, const void * data, void * output_base_ptr, bool use_bias) { + GGML_ASSERT(tensor != nullptr); + GGML_ASSERT(data != nullptr); + + OvWeight result; + + // Get 2D shape for weights [rows, cols] + ov::Shape node_shape = {static_cast<size_t>(tensor->ne[1]), static_cast<size_t>(tensor->ne[0])}; + + // Handle F16/F32/BF16 weights + if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_F16 || tensor->type == GGML_TYPE_BF16) { + ov::element::Type element_type; + switch (tensor->type) { + case GGML_TYPE_F32: + element_type = ov::element::f32; + break; + case GGML_TYPE_F16: + element_type = ov::element::f16; + break; + case GGML_TYPE_BF16: + element_type = ov::element::bf16; + break; + default: + OPENVINO_THROW("Unexpected tensor type in F16/F32/BF16 path"); + } + + if (output_base_ptr && output_base_ptr != data) { + // Using external buffer - copy data and create shared-memory constant + size_t tensor_bytes = ggml_nbytes(tensor); + memcpy(output_base_ptr, data, tensor_bytes); + result.weights = ov::Tensor(element_type, node_shape, output_base_ptr); + } else { + result.weights = ov::Tensor(element_type, node_shape, data); + } + result.weight_node = std::make_shared<ov::op::v0::Constant>(result.weights); + return result; + } + + // Handle quantized weights + if (!ggml_is_quantized(tensor->type)) { + OPENVINO_THROW("Unsupported weight tensor type: ", ggml_type_name(tensor->type)); + } + + result.layout = ggml_openvino_get_extracted_layout(tensor, use_bias); + const auto & layout = result.layout; + if (layout.total_size == 0) { + OPENVINO_THROW("Unsupported quantized type: ", ggml_type_name(tensor->type)); + } + + if (use_bias) { + OPENVINO_ASSERT(!layout.is_requant, + "use_bias is only used for test-backend-ops, which should not have requantization"); + // bias node will be created on the fly and not use backend buffer + output_base_ptr = nullptr; + } + + // F16 requant path - no separate scales/zp needed in result + if (layout.is_requant && layout.requant_type.has_value() && layout.requant_type.value() == ExtraQuantType::F16) { + if (output_base_ptr) { + result.weights = ov::Tensor(ov::element::f16, node_shape, + static_cast<uint8_t *>(output_base_ptr) + layout.weights_offset); + } else { + result.weights = ov::Tensor(ov::element::f16, node_shape); + } + ov::Tensor dummy_scales, dummy_zp; // Not used for F16 + result.weight_node = + requantize_to_buffers(tensor, data, ExtraQuantType::F16, 0, result.weights, dummy_scales, dummy_zp); + return result; + } + + // Quantized path (normal extraction or quantized requant) + // Create weight/scale/zp tensors - shared between both paths + // For symmetric quantization, use signed types (i4/i8) and no ZP tensor + ov::element::Type weight_type = layout.is_symmetric ? (layout.is_u4 ? ov::element::i4 : ov::element::i8) : + (layout.is_u4 ? ov::element::u4 : ov::element::u8); + ov::Shape scale_shape = {node_shape[0], node_shape[1] / layout.weights_per_block}; + + if (output_base_ptr) { + uint8_t * buf_base = static_cast<uint8_t *>(output_base_ptr); + result.weights = ov::Tensor(weight_type, node_shape, buf_base + layout.weights_offset); + result.scales = ov::Tensor(ov::element::f16, scale_shape, buf_base + layout.scales_offset); + if (!layout.is_symmetric) { + ov::element::Type zp_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape, buf_base + layout.zp_offset); + } + // else: result.zp remains default-constructed (empty) for symmetric + } else { + result.weights = ov::Tensor(weight_type, node_shape); + result.scales = ov::Tensor(ov::element::f16, scale_shape); + if (!layout.is_symmetric) { + if (use_bias) { + result.zp = ov::Tensor(ov::element::f16, scale_shape); + } else { + ov::element::Type zp_type = layout.is_u4 ? ov::element::u4 : ov::element::u8; + result.zp = ov::Tensor(zp_type, scale_shape); + } + } + // else: result.zp remains default-constructed (empty) for symmetric + } + + if (layout.is_requant && layout.requant_type.has_value()) { + result.weight_node = requantize_to_buffers(tensor, data, layout.requant_type.value(), layout.weights_per_block, + result.weights, result.scales, result.zp); + } else { + result.weight_node = + extract_quantized_weights(tensor, data, result.weights, result.scales, result.zp, use_bias); + } + + return result; +} + +void quantize_q4_0(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto * weights = static_cast<uint8_t *>(weights_arr.data()); + auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>(); + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i4); // Signed i4 path + + if (!is_symmetric) { + auto * zp = static_cast<uint8_t *>(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); + if (i % 2 == 0) { + zp[i / 2] = 8; + } else { + zp[i / 2] |= (8 << 4); + } + memset(weights + i * qk / 2, 8 | (8 << 4), qk / 2); + continue; + } + const float id = 1.0f / d; + scales[i] = ov::float16(d); + if (i % 2 == 0) { + zp[i / 2] = 8; + } else { + zp[i / 2] |= (8 << 4); + } + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + const uint8_t xi0 = MIN(15, (int8_t) (x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t) (x1 + 8.5f)); + weights[i * qk / 2 + j] = xi0 | (xi1 << 4); + } + } + } else { + // Symmetric: produce signed i4 values in [-8, 7] + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + float max = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + const float d = max / -8; + if (d == 0) { + scales[i] = ov::float16(1.0f); + // i4 value 0 packed: 0x00 + memset(weights + i * qk / 2, 0, qk / 2); + continue; + } + const float id = 1.0f / d; + scales[i] = ov::float16(d); + for (int j = 0; j < qk / 2; ++j) { + const float x0 = x[i * qk + 2 * j] * id; + const float x1 = x[i * qk + 2 * j + 1] * id; + // Signed i4: range [-8, 7]. Quantize as round(x*id), then pack as 4-bit two's complement. + int8_t si0 = (int8_t) std::max(-8, std::min(7, (int) roundf(x0))); + int8_t si1 = (int8_t) std::max(-8, std::min(7, (int) roundf(x1))); + weights[i * qk / 2 + j] = (si0 & 0x0F) | ((si1 & 0x0F) << 4); + } + } + } +} + +void quantize_q8_0(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto * weights = static_cast<uint8_t *>(weights_arr.data()); + auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>(); + bool is_symmetric = (weights_arr.get_element_type() == ov::element::i8); // Signed i8 path + + if (!is_symmetric) { + auto * zp = static_cast<uint8_t *>(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); + } + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + zp[i] = 128; + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + const int8_t xi0 = roundf(x0); + weights[i * qk + j] = (uint8_t) (xi0 + 128); + } + } + } else { + // Symmetric: store signed int8 values directly + auto * signed_weights = reinterpret_cast<int8_t *>(weights); + for (int i = 0; i < nb; i++) { + float amax = 0.0f; + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + amax = std::max(amax, fabsf(v)); + } + const float d = amax / 127.0f; + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + for (int j = 0; j < qk; ++j) { + const float x0 = x[i * qk + j] * id; + signed_weights[i * qk + j] = (int8_t) roundf(x0); + } + } + } +} + +void quantize_q8_1(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk) { + assert(k % qk == 0); + const int nb = k / qk; + + auto * weights = static_cast<uint8_t *>(weights_arr.data()); + auto * scales = scales_arr.data<ov::element_type_traits<ov::element::f16>::value_type>(); + auto * zp = static_cast<uint8_t *>(zp_arr.data()); + for (int i = 0; i < nb; i++) { + float min = std::numeric_limits<float>::max(); + float max = std::numeric_limits<float>::lowest(); + + for (int j = 0; j < qk; j++) { + const float v = x[i * qk + j]; + min = std::min(v, min); + max = std::max(v, max); + } + + const float d = (max - min) / ((1 << 8) - 1); + const float id = d ? 1.0f / d : 0.0f; + scales[i] = ov::float16(d); + // zp = -min / scale (Q8_1 is asymmetric) + zp[i] = (d != 0.0f) ? (uint8_t) std::round(-min / d) : 0; + + for (int j = 0; j < qk; ++j) { + const float x0 = (x[i * qk + j] - min) * id; + const uint8_t xi0 = roundf(x0); + weights[i * qk + j] = xi0; + } + } +} diff --git a/ggml/src/ggml-openvino/ggml-quants.h b/ggml/src/ggml-openvino/ggml-quants.h new file mode 100644 index 00000000000..e4a02297cae --- /dev/null +++ b/ggml/src/ggml-openvino/ggml-quants.h @@ -0,0 +1,153 @@ +#pragma once +#include "ggml-openvino-extra.h" // For ExtraQuantType +#include "ggml.h" + +#include <cstdint> +#include <openvino/op/constant.hpp> +#include <openvino/runtime/tensor.hpp> + +void unpack_32_4(const uint8_t* data, uint8_t* dst); + +void extract_q4_0_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); + +void extract_q4_1_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias = false); + +void extract_q8_0_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); + +void unpack_256_4(const uint8_t* data, uint8_t* dst); + +void extract_q4_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias = false); + +void extract_q5_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + bool use_bias = false); + +void extract_q6_k_data(const ggml_tensor * tensor, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr); + +static constexpr size_t GGML_QUANTIZATION_GROUP_SIZE = 32; + +ov::Output<ov::Node> make_int8_weights(ov::Tensor & weight, + ov::Tensor & scales, + ov::Tensor & zp, + size_t group_size = GGML_QUANTIZATION_GROUP_SIZE, + bool use_bias = false); + +ov::Output<ov::Node> make_int4_weights(ov::Tensor & weight, + ov::Tensor & scales, + ov::Tensor & zp, + size_t group_size = GGML_QUANTIZATION_GROUP_SIZE, + bool use_bias = false); + +// Extract quantized weights from tensor and create weight subgraph +// If weights/scales/zp are provided (non-empty), uses them as output buffers +// Otherwise allocates new ov::Tensors internally +// Returns the weight node (make_int4_weights or make_int8_weights result) +std::shared_ptr<ov::Node> extract_quantized_weights( + const ggml_tensor * tensor, + const void * data, // Source data pointer (may differ from tensor->data) + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & zp, + bool use_bias = false); // Use fp bias instead of quantized zero_point (for test-backend-ops) + +// Requantize weights from tensor to target format, writing to provided buffers +// For F16 target, only weights buffer is used (scales/zp ignored) +// Returns the weight node +std::shared_ptr<ov::Node> requantize_to_buffers(const ggml_tensor * tensor, + const void * data, // Source data pointer + ExtraQuantType requant_type, + int64_t block_size, + ov::Tensor & weights, + ov::Tensor & scales, + ov::Tensor & zp); + +inline const char * extra_quant_type_name(ExtraQuantType t) { + switch (t) { + case ExtraQuantType::F16: + return "F16"; + case ExtraQuantType::Q4_0_C: + return "Q4_0_C"; + case ExtraQuantType::Q4_0_128: + return "Q4_0_128"; + case ExtraQuantType::Q8_0_C: + return "Q8_0_C"; + case ExtraQuantType::Q8_0_32: + return "Q8_0_32"; + case ExtraQuantType::Q8_1_C: + return "Q8_1_C"; + default: + return "unknown"; + } +} + +// Result from process_weight_tensor containing the weight node and tensors. +// For quantized weights, also contains the extracted layout and scale/zp tensors. +struct OvWeight { + std::shared_ptr<ov::Node> weight_node; + ggml_openvino_extracted_layout layout; // Only meaningful for quantized (layout.total_size > 0) + ov::Tensor weights; + ov::Tensor scales; + ov::Tensor zp; + + bool is_quantized() const { return layout.scales_size > 0; } +}; + +// Process weight tensor and create an OpenVINO weight node +// Handles F16/F32/BF16 and quantized weights, with optional requantization +// If output_base_ptr is nullptr, allocates internal buffers (for decoder use) +// If output_base_ptr is provided, uses pre-allocated buffers at specified offsets (for backend buffer use) +// Returns OvWeight with the weight node and optional quantized tensors +OvWeight process_weight_tensor( + const ggml_tensor * tensor, + const void * data, // Source data pointer (may differ from tensor->data) + void * output_base_ptr = nullptr, // Base pointer for output buffers (or nullptr for internal allocation) + bool use_bias = false); // Use fp bias instead of quantized zero_point, only used in test-backend-ops + +void quantize_q4_0(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk); +void quantize_q8_1(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk); +void quantize_q8_0(const float * x, + ov::Tensor & weights_arr, + ov::Tensor & scales_arr, + ov::Tensor & zp_arr, + int64_t k, + int64_t qk); + +namespace ov { +namespace op { +namespace util { +// From <openvino>/src/common/transformations/include/transformations/utils/utils.hpp +bool get_single_value(const std::shared_ptr<ov::op::v0::Constant>& const_node, + float& value, + bool check_value_range = true); +} // namespace util +} // namespace op +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/decoder.h b/ggml/src/ggml-openvino/openvino/decoder.h new file mode 100644 index 00000000000..3b8da2be5d2 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/decoder.h @@ -0,0 +1,74 @@ +#pragma once + +#include <cstdint> +#include <map> +#include <openvino/core/node.hpp> +#include <openvino/frontend/decoder.hpp> +#include <string> + +namespace ov { +namespace frontend { +namespace ggml { + +class GgmlDecoder : public DecoderBase { +public: + virtual ov::Any get_attribute(const std::string& name) const = 0; + + virtual PartialShape get_input_shape(int node_idx, const std::string& name) const = 0; + + virtual std::vector<size_t> get_input_stride(int node_idx, const std::string& name) const = 0; + + virtual element::Type get_input_type(int node_idx, const std::string& name) const = 0; + + virtual size_t get_input_size() const = 0; + + virtual size_t get_input_size(int node_idx) const = 0; + + virtual void get_input_node(size_t input_port_idx, + std::string& producer_name, + std::string& producer_output_port_name, + size_t& producer_output_port_index) const = 0; + + virtual std::vector<std::string> get_input_names(int node_idx) const = 0; + + virtual PartialShape get_output_shape(int node_idx) const = 0; + + virtual element::Type get_output_type(const int node_idx) const = 0; + + virtual int32_t* get_input_op_params(int node_idx, const std::string& name) const = 0; + + virtual int32_t * get_output_op_params(int node_idx) const = 0; + + virtual std::vector<std::string> get_output_names(int node_idx) const = 0; + + virtual const std::string& get_op_type() const = 0; + + virtual const std::string& get_op_type(int node_idx) const = 0; + + virtual const std::string& get_op_name() const = 0; + + virtual const std::string& get_op_name(int node_idx) const = 0; + + virtual void visit_subgraph(std::function<void(std::shared_ptr<GgmlDecoder>, int node_idx)> node_visitor) const = 0; + + virtual int get_op_case(int node_idx) const = 0; + + virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_inputs() const = 0; + virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_extra_inputs() const = 0; + virtual const std::map<std::string, std::shared_ptr<ov::Node>>& get_model_weights() const = 0; + virtual std::vector<std::string> get_model_output_names() const = 0; + + virtual int32_t* get_rope_params() const = 0; + + virtual std::map<std::string, std::string> get_kv_param_res_names() const = 0; + + virtual bool is_static() const = 0; + + virtual bool is_stateful() const = 0; + + virtual int is_swa_layer(int layer) const = 0; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/frontend.cpp b/ggml/src/ggml-openvino/openvino/frontend.cpp new file mode 100644 index 00000000000..c2ba14e66e6 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/frontend.cpp @@ -0,0 +1,27 @@ +#include "frontend.h" + +#include "input_model.h" +#include "op_table.h" +#include "translate_session.h" + +namespace ov { +namespace frontend { +namespace ggml { + +FrontEnd::FrontEnd() {} + +std::shared_ptr<Model> FrontEnd::convert(const InputModel::Ptr & model, bool naive) { + auto ggml_model = std::dynamic_pointer_cast<ggml::InputModel>(model); + FRONT_END_GENERAL_CHECK(ggml_model, "Invalid input model"); + std::shared_ptr<Model> converted_model; + const auto & supported_ops = get_supported_ops(); + { + TranslateSession translate_session(model, supported_ops, naive); + converted_model = translate_session.get_converted_model(); + } + return converted_model; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/frontend.h b/ggml/src/ggml-openvino/openvino/frontend.h new file mode 100644 index 00000000000..f1c6f0c3e3c --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/frontend.h @@ -0,0 +1,23 @@ +// Copyright (C) 2018-2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <openvino/frontend/frontend.hpp> + +namespace ov { +namespace frontend { +namespace ggml { + +class FrontEnd { +public: + using Ptr = std::shared_ptr<FrontEnd>; + FrontEnd(); + + static std::shared_ptr<Model> convert(const InputModel::Ptr& model, bool naive = false); +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/input_model.cpp b/ggml/src/ggml-openvino/openvino/input_model.cpp new file mode 100644 index 00000000000..39b004c9317 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/input_model.cpp @@ -0,0 +1,17 @@ +#include "input_model.h" + +#include "decoder.h" + +namespace ov { +namespace frontend { +namespace ggml { + +InputModel::InputModel(const std::shared_ptr<GgmlDecoder> & gdecoder) : m_decoder(gdecoder) {} + +const std::shared_ptr<GgmlDecoder> & InputModel::get_model_decoder() const { + return m_decoder; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/input_model.h b/ggml/src/ggml-openvino/openvino/input_model.h new file mode 100644 index 00000000000..ce8434426c9 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/input_model.h @@ -0,0 +1,29 @@ +#pragma once + +#include <openvino/frontend/input_model.hpp> + +#include "decoder.h" + +namespace ov { +namespace frontend { +namespace ggml { + +class FrontEnd; +class GgmlDecoder; +using ov::frontend::ggml::GgmlDecoder; + +class InputModel : public ov::frontend::InputModel { + friend class ::ov::frontend::ggml::FrontEnd; + +public: + explicit InputModel(const std::shared_ptr<GgmlDecoder>& gdecoder); + + const std::shared_ptr<GgmlDecoder>& get_model_decoder() const; + +private: + std::shared_ptr<GgmlDecoder> m_decoder; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/node_context.h b/ggml/src/ggml-openvino/openvino/node_context.h new file mode 100644 index 00000000000..aa484128a95 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/node_context.h @@ -0,0 +1,112 @@ +#pragma once + +#include <cstdint> +#include <openvino/frontend/node_context.hpp> +#include <string> + +#include "decoder.h" + +namespace ov { +namespace frontend { +namespace ggml { + +class TranslateSession; + +typedef std::map<std::string, Output<Node>> TensorMap; + +class NodeContext : public frontend::NodeContext { +public: + NodeContext(const std::shared_ptr<GgmlDecoder>& decoder, + std::shared_ptr<TensorMap>& tensor_map, + int node_idx, + TranslateSession* translate_session = nullptr) + : ov::frontend::NodeContext(decoder->get_op_type(node_idx)), + m_decoder(decoder), + m_tensor_map(tensor_map), + m_node_idx(node_idx), + m_translate_session(translate_session) { + m_input_names = decoder->get_input_names(m_node_idx); + m_output_names = decoder->get_output_names(m_node_idx); + } + + TranslateSession* get_translate_session() const { + return m_translate_session; + } + + const std::vector<std::string>& get_input_names() const { return m_input_names; } + + size_t get_input_size() const override { + return m_decoder->get_input_size(m_node_idx); + } + + ov::element::Type get_input_type(size_t index) const { + return m_decoder->get_input_type(m_node_idx, m_input_names[index]); + } + + PartialShape get_input_shape(size_t input_index) const { + return m_decoder->get_input_shape(m_node_idx, m_input_names[input_index]); + } + + std::vector<size_t> get_input_stride(size_t index) const { + return m_decoder->get_input_stride(m_node_idx, m_input_names[index]); + } + + std::string get_output_name() const { return m_output_names[0]; } + + PartialShape get_output_shape() const { return m_decoder->get_output_shape(m_node_idx); } + + int32_t* get_input_op_params(size_t index) const { + return m_decoder->get_input_op_params(m_node_idx, m_input_names[index]); + } + + int32_t * get_output_op_params() const { return m_decoder->get_output_op_params(m_node_idx); } + + ov::element::Type get_output_type() const { + return m_decoder->get_output_type(m_node_idx); + } + + Output<Node> get_input(int idx) const override { + return m_tensor_map->at(m_input_names[idx]); + } + + Output<Node> get_input(const std::string& name) const override { + if (m_tensor_map->find(name) == m_tensor_map->end()) { + throw std::runtime_error("'" + name + "' not found in tensor map."); + } + return m_tensor_map->at(name); + } + + bool has_input(const std::string& name) const { + return m_tensor_map->find(name) != m_tensor_map->end(); + } + + const std::string& get_name() const override { + return m_decoder->get_op_name(m_node_idx); + } + + ov::Any get_attribute_as_any(const std::string& name) const override { + return m_decoder->get_attribute(name); + } + + int get_op_case() const { + return m_decoder->get_op_case(m_node_idx); + } + + bool is_static() const { return m_decoder->is_static(); } + + bool is_stateful() const { return m_decoder->is_stateful(); } + +private: + std::shared_ptr<GgmlDecoder> m_decoder; + std::shared_ptr<TensorMap>& m_tensor_map; + int m_node_idx; + TranslateSession* m_translate_session; + std::vector<std::string> m_input_names; + std::vector<std::string> m_output_names; +}; + +using CreatorFunction = std::function<ov::OutputVector(const ov::frontend::ggml::NodeContext&)>; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/cont.cpp b/ggml/src/ggml-openvino/openvino/op/cont.cpp new file mode 100644 index 00000000000..6160dd74444 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/cont.cpp @@ -0,0 +1,48 @@ + +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <climits> +#include <cstdint> +#include <memory> +#include <openvino/op/reshape.hpp> +#include <openvino/op/slice.hpp> +#include <vector> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_cont(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + int op_case = context.get_op_case(); + FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3, "Unsupported CONT case"); + + auto src_shape = context.get_input_shape(0).to_shape(); + auto dst_shape = context.get_output_shape().to_shape(); + ov::Output<Node> res; + + if (op_case == 1) { + // The input comes from a PERMUTE + throw std::runtime_error("Code of this case might be outdated"); + dst_shape[1] = -1; + res = std::make_shared<ov::op::v1::Reshape>( + context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {dst_shape.size()}, dst_shape), false); + } else if (op_case == 2) { + // The input comes from a TRANSPOSE + return {context.get_input(0)}; + } else { + // The input comes from a VIEW + res = process_view_input(context, 0); + } + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/cpy.cpp b/ggml/src/ggml-openvino/openvino/op/cpy.cpp new file mode 100644 index 00000000000..831117208be --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/cpy.cpp @@ -0,0 +1,21 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <memory> +#include <openvino/op/convert.hpp> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_cpy(const NodeContext & context) { + auto res = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_output_type()); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp new file mode 100644 index 00000000000..42602a730a4 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp @@ -0,0 +1,90 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <cstdint> +#include <memory> +#include <openvino/op/broadcast.hpp> +#include <openvino/op/concat.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/reshape.hpp> +#include <openvino/op/scaled_dot_product_attention.hpp> +#include <openvino/op/transpose.hpp> +#include <openvino/op/unsqueeze.hpp> +#include <string> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_flash_attn_ext(const NodeContext & context) { + num_inputs_check(context, 4, 4); + auto q_f32 = context.get_input(0); + auto k = context.get_input(1); + auto v = context.get_input(2); + auto mask = context.get_input(3); + + float * params = reinterpret_cast<float *>(context.get_output_op_params()); + float scale = params[0]; + // float max_bias = params[1]; + // float logit_softcap = params[2]; + + auto q = std::make_shared<ov::op::v0::Convert>(q_f32, ov::element::f16); + auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f16, ov::Shape{}, std::vector<float>{scale}); + + ov::Output<ov::Node> mask_sliced, res; + std::string mask_name = "KQ_mask_sliced"; + if (context.get_input_names()[3].find("swa") != std::string::npos) { + mask_name = "KQ_mask_swa_sliced"; + } + if (context.has_input(mask_name)) { + mask_sliced = context.get_input(mask_name); + } else { + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); + auto token_len = get_dimensions(q, {2}); + mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len, one, two); + } + + if (mask_sliced.get_element_type() != ov::element::f16) { + mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16); + } + + auto tile_kv = [&](int64_t num_heads, int64_t num_heads_kv, int64_t head_size, ov::Output<Node> kv) { + int64_t factor = num_heads / num_heads_kv; + if (factor > 1 && num_heads_kv > 1) { + ov::Output<ov::Node> kv_broadcast_shape, kv_unsqueezed, new_kv_shape; + auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2}); + kv_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(kv, unsqueeze_axes); + + kv_broadcast_shape = ov::op::v0::Constant::create( + ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1}); + new_kv_shape = + ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 0, num_heads, (int64_t) -1, head_size}); + + kv = std::make_shared<ov::op::v3::Broadcast>(kv_unsqueezed, kv_broadcast_shape, + ov::op::BroadcastType::BIDIRECTIONAL); + kv = std::make_shared<ov::op::v1::Reshape>(kv, new_kv_shape, true); + } + return kv; + }; + + auto q_shape = context.get_input_shape(0).to_shape(); + auto k_shape = context.get_input_shape(1).to_shape(); + k = tile_kv(q_shape[1], k_shape[1], q_shape[3], k); + v = tile_kv(q_shape[1], k_shape[1], q_shape[3], v); + + auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_sliced, scale_node, false); + res = std::make_shared<ov::op::v1::Transpose>(sdpa, + ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3})); + res = std::make_shared<ov::op::v0::Convert>(res, ov::element::f32); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/get_rows.cpp b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp new file mode 100644 index 00000000000..49f51b7ca3f --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/get_rows.cpp @@ -0,0 +1,69 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <openvino/core/node.hpp> +#include <openvino/core/node_output.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/gather.hpp> +#include <openvino/op/squeeze.hpp> +#include <openvino/op/unsqueeze.hpp> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_get_rows(const NodeContext & context) { + num_inputs_check(context, 2, 2); + + int op_case = context.get_op_case(); + + Output<Node> res; + auto data = context.get_input(0); + auto indices = context.get_input(1); + + if (op_case == 2) { + // The input comes from a VIEW + indices = process_view_input(context, 1); + } + + // data[1,b,x,y] ind[1,1,b,x'] test-backend-ops case + // data[x,y] ind[1,1,1,x'] normal case + indices = + std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); + if (data.get_partial_shape().rank() == 4) { + if (!(data.get_partial_shape()[1].is_dynamic()) && data.get_partial_shape()[1].get_length() == 1) { + // Work-around for a bug in ov cpu plugin for test-backend-ops + data = std::make_shared<ov::op::v0::Squeeze>(data, + ov::op::v0::Constant::create(ov::element::i64, {2}, {0, 1})); + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + res = std::make_shared<ov::op::v8::Gather>(data, indices, axis); + } else { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); + data = + std::make_shared<ov::op::v0::Squeeze>(data, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1); + } + } else if (context.is_stateful() && data.get_partial_shape().rank() == 3) { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {1}); + res = std::make_shared<ov::op::v8::Gather>(data, indices, axis, 1); + } else { + auto axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + res = std::make_shared<ov::op::v8::Gather>(data, indices, axis); + } + + if (res.get_element_type() != context.get_output_type()) { + res = std::make_shared<ov::op::v0::Convert>(res, context.get_output_type()); + } + if (!(context.is_stateful())) { + res = std::make_shared<ov::op::v0::Unsqueeze>(res, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + } + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp new file mode 100644 index 00000000000..d9fa4c24367 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp @@ -0,0 +1,61 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <memory> +#include <openvino/core/node_output.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/gelu.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/sigmoid.hpp> +#include <openvino/op/slice.hpp> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_glu_geglu(const NodeContext & context) { + num_inputs_check(context, 1, 2); + + ov::Output<ov::Node> src0; + ov::Output<ov::Node> src1; + if (context.get_input_size() == 2) { + src0 = context.get_input(0); + src1 = context.get_input(1); + } else { + // GGML splits along ne[0] (OV last axis) using floor division: nc = ne[0] / 2. + // Both halves are nc elements; if the dimension is odd, the last element is dropped. + // Use Slice instead of Split to handle odd dimensions correctly. + auto combined = context.get_input(0); + auto combined_shape = combined.get_partial_shape(); + int64_t last_dim_val = combined_shape[combined_shape.rank().get_length() - 1].get_length(); + int64_t nc = last_dim_val / 2; + + auto axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + auto step = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto start0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto stop0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc}); + auto start1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc}); + auto stop1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {2 * nc}); + + src0 = std::make_shared<ov::op::v8::Slice>(combined, start0, stop0, step, axis); + src1 = std::make_shared<ov::op::v8::Slice>(combined, start1, stop1, step, axis); + } + + int32_t * params = context.get_output_op_params(); + const int32_t swapped = params[1]; + if (swapped) { + std::swap(src0, src1); + } + + auto gelu = std::make_shared<ov::op::v7::Gelu>(src0); + auto res = std::make_shared<ov::op::v1::Multiply>(gelu, src1); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp new file mode 100644 index 00000000000..00ed7951a03 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp @@ -0,0 +1,62 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <cstdint> +#include <memory> +#include <openvino/core/node_output.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/sigmoid.hpp> +#include <openvino/op/slice.hpp> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_glu_swiglu(const NodeContext & context) { + num_inputs_check(context, 1, 2); + + ov::Output<ov::Node> src0; + ov::Output<ov::Node> src1; + if (context.get_input_size() == 2) { + src0 = context.get_input(0); + src1 = context.get_input(1); + } else { + // GGML splits along ne[0] (OV last axis) using floor division: nc = ne[0] / 2. + // Both halves are nc elements; if the dimension is odd, the last element is dropped. + // Use Slice instead of Split to handle odd dimensions correctly. + auto combined = context.get_input(0); + auto combined_shape = combined.get_partial_shape(); + int64_t last_dim_val = combined_shape[combined_shape.rank().get_length() - 1].get_length(); + int64_t nc = last_dim_val / 2; + + auto axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + auto step = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto start0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto stop0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc}); + auto start1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {nc}); + auto stop1 = ov::op::v0::Constant::create(ov::element::i64, {1}, {2 * nc}); + + src0 = std::make_shared<ov::op::v8::Slice>(combined, start0, stop0, step, axis); + src1 = std::make_shared<ov::op::v8::Slice>(combined, start1, stop1, step, axis); + } + + int32_t * params = context.get_output_op_params(); + const int32_t swapped = params[1]; + if (swapped) { + std::swap(src0, src1); + } + + auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(src0); + auto silu = std::make_shared<ov::op::v1::Multiply>(src0, sigmoid); + auto res = std::make_shared<ov::op::v1::Multiply>(silu, src1); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/mulmat.cpp b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp new file mode 100644 index 00000000000..38edec85ddf --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/mulmat.cpp @@ -0,0 +1,90 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <climits> +#include <cstdint> +#include <memory> +#include <openvino/core/node.hpp> +#include <openvino/core/node_output.hpp> +#include <openvino/op/broadcast.hpp> +#include <openvino/op/concat.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/matmul.hpp> +#include <openvino/op/reshape.hpp> +#include <openvino/op/slice.hpp> +#include <openvino/op/transpose.hpp> +#include <openvino/op/unsqueeze.hpp> +#include <openvino/op/util/op_types.hpp> +#include <vector> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_mulmat(const NodeContext & context) { + num_inputs_check(context, 2, 2); + + int op_case = context.get_op_case(); + + ov::Output<Node> res; + ov::Output<ov::Node> B = context.get_input(0); + ov::Output<ov::Node> A = context.get_input(1); + + bool transpose_b = true; + if (op_case == 2) { + B = B.get_node_shared_ptr()->input_value(0); + transpose_b = false; + } else if (op_case == 3) { + B = process_view_input(context, 0); + A = process_view_input(context, 1); + } + if (A.get_element_type() != B.get_element_type()) { + B = std::make_shared<ov::op::v0::Convert>(context.get_input(0), context.get_input_type(1)); + } + + auto B_shape = context.get_input_shape(0).to_shape(); + auto A_shape = context.get_input_shape(1).to_shape(); + int64_t A_batch = A_shape[1]; + int64_t B_batch = B_shape[1]; + + auto A_batch_larger = A_batch > B_batch; + auto batch_large = A_batch_larger ? A_batch : B_batch; + auto batch_small = A_batch_larger ? B_batch : A_batch; + + Output<Node> Z = A_batch_larger ? B : A; + int64_t factor = batch_large / batch_small; + if (factor > 1 && batch_small > 1) { + auto batch_large_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{batch_large}); + auto batch_small_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{batch_small}); + auto factor_node = ov::op::v0::Constant::create(ov::element::i64, {1}, std::vector<int64_t>{factor}); + + auto unsqueeze_axes = ov::op::v0::Constant::create(ov::element::i64, Shape{}, {2}); + auto Z_unsqueezed = std::make_shared<ov::op::v0::Unsqueeze>(Z, unsqueeze_axes); + + auto broadcast_shape = ov::op::v0::Constant::create( + ov::element::i64, {5}, {(int64_t) 1, (int64_t) 1, factor, (int64_t) 1, (int64_t) 1}); + auto new_Z_shape = ov::op::v0::Constant::create(ov::element::i64, {4}, + {(int64_t) 0, batch_large, (int64_t) -1, (int64_t) A_shape[3]}); + + auto Z_broadcasted = std::make_shared<ov::op::v3::Broadcast>(Z_unsqueezed, broadcast_shape, + ov::op::BroadcastType::BIDIRECTIONAL); + Z = std::make_shared<ov::op::v1::Reshape>(Z_broadcasted, new_Z_shape, true); + } + if (A_batch_larger) { + B = Z; + } else { + A = Z; + } + + res = std::make_shared<ov::op::v0::MatMul>(A, B, false, transpose_b); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/permute.cpp b/ggml/src/ggml-openvino/openvino/op/permute.cpp new file mode 100644 index 00000000000..4c800f9ee4f --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/permute.cpp @@ -0,0 +1,102 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <climits> +#include <cstdint> +#include <memory> +#include <openvino/core/node.hpp> +#include <openvino/op/add.hpp> +#include <openvino/op/concat.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/reshape.hpp> +#include <openvino/op/slice.hpp> +#include <openvino/op/transpose.hpp> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_permute(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + int op_case = context.get_op_case(); + FRONT_END_CHECK_IMPLEMENTED(op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4, + "Unsupported PERMUTE case"); + + ov::Output<Node> res; + auto src = context.get_input(0); + auto perm = ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 2, 1, 3}); + + if (op_case == 1 || context.is_stateful()) { + res = std::make_shared<ov::op::v1::Transpose>(src, perm); + } else if (op_case == 4) { + auto output_shape = context.get_output_shape().to_shape(); + auto n_heads = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[1]}); + auto head_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]}); + auto n_seq_active = context.has_input("n_seq_active") ? + context.get_input("n_seq_active") : + ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[0]}); + auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + + auto new_shape = + std::make_shared<ov::op::v0::Concat>(ov::OutputVector{n_seq_active, neg_one, n_heads, head_size}, 0); + + // // Alternative + // auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + // auto new_shape = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{n_seq_active, neg_one, zero, zero}, 0); + + auto reshaped = std::make_shared<ov::op::v1::Reshape>(src, new_shape, true); + res = std::make_shared<ov::op::v1::Transpose>(reshaped, perm); + } else { + auto cache_shape = src.get_partial_shape(); + auto output_shape = context.get_output_shape().to_shape(); + int64_t head_size = output_shape[3]; + int64_t n_heads = output_shape[1]; + int64_t ctx_per_seq = cache_shape[2].is_static() ? cache_shape[2].get_length() : -1; + int64_t n_seq = cache_shape[1].get_length(); + + Output<Node> attention_size; + if (!context.has_input("attention_size")) { + attention_size = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[2]}); + } else if (op_case == 2) { + attention_size = context.get_input("attention_size"); + } else { + attention_size = context.get_input("attention_size_swa"); + } + + Output<Node> seq_active_start; + Output<Node> seq_active_end; + if (context.has_input("seq_active_start")) { + seq_active_start = context.get_input("seq_active_start"); + seq_active_end = context.get_input("seq_active_end"); + } else { + int64_t n_seq_active = output_shape[0]; + size_t offset = *((size_t *) context.get_input_op_params(0)); + int64_t seq_active_start_val = offset / context.get_input_stride(0)[0]; + int64_t seq_active_end_val = seq_active_start_val + n_seq_active; + seq_active_start = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_start_val}); + seq_active_end = ov::op::v0::Constant::create(ov::element::i64, {1}, {seq_active_end_val}); + } + + // 1. reshape to [n_seq, ctx_per_seq, n_heads, head_size] + // 2. slice out the active sequences + // 3. slice out the attention part in each sequence + // 4. permute + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + + auto src_reshaped = std::make_shared<ov::op::v1::Reshape>( + src, ov::op::v0::Constant::create(ov::element::i64, {4}, {n_seq, ctx_per_seq, n_heads, head_size}), false); + auto slice1 = std::make_shared<ov::op::v8::Slice>(src_reshaped, seq_active_start, seq_active_end, one, zero); + auto slice2 = std::make_shared<ov::op::v8::Slice>(slice1, zero, attention_size, one, one); + res = std::make_shared<ov::op::v1::Transpose>(slice2, perm); + } + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/reshape.cpp b/ggml/src/ggml-openvino/openvino/op/reshape.cpp new file mode 100644 index 00000000000..efd9a5a860a --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/reshape.cpp @@ -0,0 +1,83 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <cstdint> +#include <memory> +#include <openvino/core/node.hpp> +#include <openvino/core/node_output.hpp> +#include <openvino/frontend/exception.hpp> +#include <openvino/op/concat.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/reshape.hpp> +#include <stdexcept> +#include <vector> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_reshape(const NodeContext & context) { + num_inputs_check(context, 1, 1); + if (context.get_input_shape(0) == context.get_output_shape()) { + return {context.get_input(0)}; + } + + int op_case = context.get_op_case(); + FRONT_END_CHECK_IMPLEMENTED( + op_case == 1 || op_case == 2 || op_case == 3 || op_case == 4 || op_case == 5 || op_case == 6, + "Unsupported RESHAPE case"); + + auto output_shape = context.get_output_shape().to_shape(); + std::shared_ptr<ov::Node> new_shape_node; + if (op_case == 1) { + if (context.is_stateful()) { + new_shape_node = ov::op::v0::Constant::create( + ov::element::i64, {3}, + std::vector<int64_t>{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + } else { + new_shape_node = ov::op::v0::Constant::create( + ov::element::i64, {4}, + std::vector<int64_t>{(int64_t) output_shape[0], -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + } + } else if (op_case == 2) { + new_shape_node = ov::op::v0::Constant::create( + ov::element::i64, {4}, + std::vector<int64_t>{(int64_t) output_shape[0], (int64_t) output_shape[1], -1, (int64_t) output_shape[3]}); + + } else if (op_case == 3) { + throw std::runtime_error("might be outdated RESHAPE case"); + new_shape_node = ov::op::v0::Constant::create( + ov::element::i64, {4}, std::vector<int64_t>{(int64_t) output_shape[0], (int64_t) output_shape[1], -1, 1}); + + } else if (op_case == 4) { + return {context.get_input(0).get_node_shared_ptr()->input_value(0)}; + + } else if (op_case == 5) { + if (context.is_stateful()) { + std::vector<int64_t> shape_vec = {1, -1, (int64_t) context.get_output_shape().to_shape()[3]}; + new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {3}, shape_vec); + } else { + std::vector<int64_t> shape_vec = {1, 1, -1, (int64_t) context.get_output_shape().to_shape()[3]}; + new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, shape_vec); + } + + // // Alternative + // auto token_len = context.get_input("token_len"); + // auto emb_size = + // ov::op::v0::Constant::create(ov::element::i64, {1}, {(int64_t) context.get_output_shape().to_shape()[3]}); + // auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + // new_shape_node = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{one, one, token_len, emb_size}, 0); + + } else if (op_case == 6) { + new_shape_node = ov::op::v0::Constant::create(ov::element::i64, {4}, context.get_output_shape().to_shape()); + } + auto res = std::make_shared<ov::op::v1::Reshape>(context.get_input(0), new_shape_node, false); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp new file mode 100644 index 00000000000..72cf92283e9 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp @@ -0,0 +1,46 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <memory> +#include <openvino/op/add.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/divide.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/power.hpp> +#include <openvino/op/reduce_mean.hpp> +#include <openvino/op/sqrt.hpp> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_rms_norm(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto input_node = context.get_input(0); + auto square = std::make_shared<ov::op::v1::Power>( + input_node, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {2.0f})); + + auto mean = std::make_shared<ov::op::v1::ReduceMean>( + square, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {-1}), true); + + float eps; + memcpy(&eps, context.get_output_op_params(), sizeof(float)); + + auto rms = std::make_shared<ov::op::v0::Sqrt>( + std::make_shared<ov::op::v1::Add>(mean, ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {eps}))); + + auto reciprocal = + std::make_shared<ov::op::v1::Divide>(ov::op::v0::Constant::create(ov::element::f32, ov::Shape{1}, {1.0f}), rms); + + auto res = std::make_shared<ov::op::v1::Multiply>(input_node, reciprocal); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/rope.cpp b/ggml/src/ggml-openvino/openvino/op/rope.cpp new file mode 100644 index 00000000000..a8db9b38930 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/rope.cpp @@ -0,0 +1,149 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <cstdint> +#include <memory> +#include <openvino/core/node.hpp> +#include <openvino/core/node_output.hpp> +#include <openvino/op/add.hpp> +#include <openvino/op/concat.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/cos.hpp> +#include <openvino/op/gather.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/reshape.hpp> +#include <openvino/op/shape_of.hpp> +#include <openvino/op/sin.hpp> +#include <openvino/op/slice.hpp> +#include <openvino/op/split.hpp> +#include <openvino/op/subtract.hpp> +#include <openvino/op/transpose.hpp> +#include <openvino/op/unsqueeze.hpp> +#include <vector> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_rope(const NodeContext & context) { + num_inputs_check(context, 2, 3); + + int op_case = context.get_op_case(); + + ov::Output<Node> res; + + auto data_node = context.get_input(0).get_node_shared_ptr(); + auto output_shape = context.get_output_shape().to_shape(); + int32_t * op_params = context.get_output_op_params(); + const int mode = (op_case & 0xFFFF0000) >> 16; + op_case = (op_case & 0x0000FFFF); + + constexpr int TYPE_NORMAL = 0; + constexpr int TYPE_NEOX = 1; + constexpr int TYPE_IMROPE = 2; + + Output<Node> cos_theta_node; + Output<Node> sin_theta_node; + if (context.has_input("rope_cos")) { + cos_theta_node = context.get_input("rope_cos"); + sin_theta_node = context.get_input("rope_sin"); + } else { + auto inp_pos = context.get_input(1).get_node_shared_ptr(); + std::shared_ptr<ov::Node> rope_freqs_weight; + if (context.get_input_size() == 3) { + rope_freqs_weight = context.get_input(2).get_node_shared_ptr(); + } + auto sin_cos = make_sin_cos(op_params, inp_pos, rope_freqs_weight, mode == TYPE_IMROPE); + sin_theta_node = sin_cos.first; + cos_theta_node = sin_cos.second; + } + + if (op_case == 2) { + // The input comes from a VIEW + int slice_len = output_shape[2] * output_shape[3]; + data_node = process_view_input(context, 0, slice_len).get_node_shared_ptr(); + if (context.is_stateful()) { + auto data_shape = ov::op::v0::Constant::create( + ov::element::i64, {3}, std::vector<int64_t>{-1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false); + } else { + auto data_shape = ov::op::v0::Constant::create( + ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + data_node = std::make_shared<ov::op::v1::Reshape>(data_node, data_shape, false); + } + } + + if (mode == TYPE_NORMAL) { + auto neg_one = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); + auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {output_shape[3]}); + Output<Node> even_slice; + Output<Node> odd_slice; + int32_t unsqueeze_dim = context.is_stateful() ? 3 : 4; + even_slice = std::make_shared<ov::op::v8::Slice>(data_node, zero, end, two, neg_one); + odd_slice = std::make_shared<ov::op::v8::Slice>(data_node, one, end, two, neg_one); + + Output<Node> first_half = + std::make_shared<ov::op::v1::Subtract>(std::make_shared<ov::op::v1::Multiply>(even_slice, cos_theta_node), + std::make_shared<ov::op::v1::Multiply>(odd_slice, sin_theta_node)); + Output<Node> second_half = + std::make_shared<ov::op::v1::Add>(std::make_shared<ov::op::v1::Multiply>(even_slice, sin_theta_node), + std::make_shared<ov::op::v1::Multiply>(odd_slice, cos_theta_node)); + + first_half = std::make_shared<ov::op::v0::Unsqueeze>(first_half, + ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim})); + second_half = std::make_shared<ov::op::v0::Unsqueeze>(second_half, + ov::op::v0::Constant::create(ov::element::i64, {1}, {unsqueeze_dim})); + auto stack = std::make_shared<ov::op::v0::Concat>(OutputVector{first_half, second_half}, unsqueeze_dim); + + auto data_shape = ov::op::v0::Constant::create( + ov::element::i64, {4}, std::vector<int64_t>{1, -1, (int64_t) output_shape[2], (int64_t) output_shape[3]}); + res = std::make_shared<ov::op::v1::Reshape>(stack, data_shape, false); + } else if (mode == TYPE_NEOX) { + auto data_split = std::make_shared<ov::op::v1::Split>( + data_node, ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2); + Output<Node> slice_data_node_0 = data_split->outputs()[0]; + Output<Node> slice_data_node_1 = data_split->outputs()[1]; + + auto first_half_node = std::make_shared<ov::op::v1::Subtract>( + std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, cos_theta_node), + std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, sin_theta_node)); + + auto second_half_node = std::make_shared<ov::op::v1::Add>( + std::make_shared<ov::op::v1::Multiply>(slice_data_node_0, sin_theta_node), + std::make_shared<ov::op::v1::Multiply>(slice_data_node_1, cos_theta_node)); + + res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{first_half_node, second_half_node}, -1); + } else if (mode == TYPE_IMROPE) { + int64_t n_dims = data_node->get_shape()[3]; + auto cos_sin_shape = std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{1,-1,1,(n_dims >> 1)}); + auto cos_reshaped = std::make_shared<ov::op::v1::Reshape>(cos_theta_node, cos_sin_shape, true); + auto sin_reshaped = std::make_shared<ov::op::v1::Reshape>(sin_theta_node, cos_sin_shape, true); + + auto split_axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {3}); + auto split_a = std::make_shared<ov::op::v1::Split>(data_node, split_axis, 2); + auto x0 = split_a->output(0); + auto x1 = split_a->output(1); + auto mul_a = std::make_shared<ov::op::v1::Multiply>(x0, cos_reshaped); + auto mul_b = std::make_shared<ov::op::v1::Multiply>(x1, sin_reshaped); + auto sub = std::make_shared<ov::op::v1::Subtract>(mul_a, mul_b); + + auto mul_c = std::make_shared<ov::op::v1::Multiply>(x0, sin_reshaped); + auto mul_d = std::make_shared<ov::op::v1::Multiply>(x1, cos_reshaped); + auto add = std::make_shared<ov::op::v1::Add>(mul_c, mul_d); + + res = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{sub, add}, 3); + } + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/scale.cpp b/ggml/src/ggml-openvino/openvino/op/scale.cpp new file mode 100644 index 00000000000..0f3d800c199 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/scale.cpp @@ -0,0 +1,41 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <openvino/op/add.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/multiply.hpp> +#include <vector> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_scale(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + float scale; + float bias; + memcpy(&scale, (float *) context.get_output_op_params() + 0, sizeof(float)); + memcpy(&bias, (float *) context.get_output_op_params() + 1, sizeof(float)); + + auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale}); + auto scaled = std::make_shared<ov::op::v1::Multiply>(context.get_input(0), scale_node); + + std::shared_ptr<ov::Node> res; + if (bias != 0.0f) { + auto bias_node = + std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{bias}); + res = std::make_shared<ov::op::v1::Add>(scaled, bias_node); + } else { + res = scaled; + } + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/set_rows.cpp b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp new file mode 100644 index 00000000000..136e4265b42 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/set_rows.cpp @@ -0,0 +1,76 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <cassert> +#include <cstdint> +#include <memory> +#include <openvino/core/node.hpp> +#include <openvino/core/node_output.hpp> +#include <openvino/frontend/exception.hpp> +#include <openvino/op/concat.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/gather.hpp> +#include <openvino/op/reshape.hpp> +#include <openvino/op/scatter_update.hpp> +#include <openvino/op/shape_of.hpp> +#include <openvino/op/slice.hpp> +#include <openvino/op/squeeze.hpp> +#include <openvino/op/transpose.hpp> +#include <vector> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_set_rows(const NodeContext & context) { + num_inputs_check(context, 3, 3); + + auto data = context.get_input(0); + auto indices = context.get_input(1); + auto dst = context.get_input(2); + + data = std::make_shared<ov::op::v0::Convert>(data, context.get_output_type()); + + auto dst_shape = context.get_output_shape().to_shape(); + + auto ind_squeezed = + std::make_shared<ov::op::v0::Squeeze>(indices, ov::op::v0::Constant::create(ov::element::i64, {3}, {0, 1, 2})); + auto data_reshaped = std::make_shared<ov::op::v1::Reshape>( + data, + ov::op::v0::Constant::create(ov::element::i64, {4}, + {(int64_t) 1, (int64_t) 1, (int64_t) -1, (int64_t) dst_shape[3]}), + false); + auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{}, {2}); + + Output<Node> res; + if (context.is_stateful()) { + int concat_axis = 1; + int64_t dim2 = dst.get_partial_shape()[2].get_length(); + int64_t dim3 = dst.get_partial_shape()[3].get_length(); + data = std::make_shared<ov::op::v1::Reshape>( + data, ov::op::v0::Constant::create(ov::element::i64, {4}, {(int64_t) 1, (int64_t) -1, dim2, dim3}), false); + res = std::make_shared<ov::op::v0::Concat>(OutputVector{dst, data}, concat_axis); + } else { + res = std::make_shared<ov::op::v3::ScatterUpdate>(dst, ind_squeezed, data_reshaped, axes); + } + + if (auto dst_reshape = std::dynamic_pointer_cast<ov::op::v1::Reshape>(dst.get_node_shared_ptr())) { + // Fix the case of multiple sequences, reshape back to original shape [1, n_seq, ctx_per_seq, emb] + // ctx_per_seq is not fixed due to llama-bench compatibility + auto dst_shape_partial = dst_reshape->get_input_partial_shape(0); + std::vector<int64_t> dst_shape = {dst_shape_partial[0].get_length(), dst_shape_partial[1].get_length(), + dst_shape_partial[2].is_static() ? dst_shape_partial[2].get_length() : -1, + dst_shape_partial[3].get_length()}; + res = std::make_shared<ov::op::v1::Reshape>(res, ov::op::v0::Constant::create(ov::element::i64, {4}, dst_shape), + false); + } + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/softmax.cpp b/ggml/src/ggml-openvino/openvino/op/softmax.cpp new file mode 100644 index 00000000000..9f6330862be --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/softmax.cpp @@ -0,0 +1,89 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <climits> +#include <cstdint> +#include <memory> +#include <openvino/core/node.hpp> +#include <openvino/core/node_output.hpp> +#include <openvino/op/add.hpp> +#include <openvino/op/concat.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/matmul.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/slice.hpp> +#include <openvino/op/softmax.hpp> +#include <vector> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_soft_max(const NodeContext & context) { + // TODO code is outdated + num_inputs_check(context, 1, 2); + + auto input_node = context.get_input(0).get_node_shared_ptr(); + ov::Output<Node> res; + + float scale = 1.0f; + float max_bias = 0.0f; + auto * op_params = context.get_output_op_params(); + memcpy(&scale, (float *) op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) op_params + 1, sizeof(float)); + auto src0_shape = context.get_input_shape(0).get_shape(); + const uint32_t h = src0_shape[2]; + const uint32_t n_head = src0_shape[0]; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + + const float m0 = powf(2.0f, -(max_bias) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const float slope = + (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2 * (h - n_head_log2) + 1) : 1.0f; + + auto scale_node = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{scale}); + auto scaled_input = std::make_shared<ov::op::v1::Multiply>(input_node, scale_node); + + if (context.get_input_size() < 2) { + res = std::make_shared<ov::op::v8::Softmax>(scaled_input, 2); + return rename_outputs_with_suffix({res}, context.get_name()); + } + + ov::Output<ov::Node> mask_node_sliced; + if (context.has_input("KQ_mask_sliced")) { + mask_node_sliced = context.get_input("KQ_mask_sliced"); + } else { + auto token_len = get_dimensions(input_node, {1}); + auto mask_node = context.get_input(1); + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + mask_node_sliced = std::make_shared<ov::op::v8::Slice>(mask_node, zero, token_len, one, one); + } + + if (mask_node_sliced.get_element_type() != context.get_output_type()) { + mask_node_sliced = std::make_shared<ov::op::v0::Convert>(mask_node_sliced, context.get_output_type()); + } + + Output<Node> slope_mask; + if (slope != 1.0f) { + auto slope_node = + std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{}, std::vector<float>{slope}); + slope_mask = std::make_shared<ov::op::v1::Multiply>(mask_node_sliced, slope_node); + throw std::runtime_error("Slope != 1.0f in softmax has not been tested, verify it before use."); + } + slope_mask = mask_node_sliced; + + auto input_slope_mask_node = std::make_shared<ov::op::v1::Add>(scaled_input, slope_mask); + + res = std::make_shared<ov::op::v8::Softmax>(input_slope_mask_node, 2); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/transpose.cpp b/ggml/src/ggml-openvino/openvino/op/transpose.cpp new file mode 100644 index 00000000000..8e62e83c0d7 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/transpose.cpp @@ -0,0 +1,23 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <openvino/op/transpose.hpp> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_transpose(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto res = std::make_shared<ov::op::v1::Transpose>( + context.get_input(0), ov::op::v0::Constant::create(ov::element::i64, {4}, {0, 1, 3, 2})); + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp new file mode 100644 index 00000000000..d1e9efc33a5 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp @@ -0,0 +1,25 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <openvino/core/node_output.hpp> +#include <openvino/op/gelu.hpp> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary_gelu(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto input = context.get_input(0); + auto res = std::make_shared<ov::op::v7::Gelu>(input); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp new file mode 100644 index 00000000000..037e0b94df1 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp @@ -0,0 +1,27 @@ +#include "../node_context.h" +#include "../op_table.h" +#include "../utils.h" + +#include <openvino/core/node_output.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/sigmoid.hpp> + +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_unary_silu(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + auto input = context.get_input(0); + auto sigmoid = std::make_shared<ov::op::v0::Sigmoid>(input); + auto res = std::make_shared<ov::op::v1::Multiply>(input, sigmoid); + + return rename_outputs_with_suffix({res}, context.get_name()); +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op/view.cpp b/ggml/src/ggml-openvino/openvino/op/view.cpp new file mode 100644 index 00000000000..8528d252336 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op/view.cpp @@ -0,0 +1,53 @@ +#include "../op_table.h" +#include "../utils.h" +#include <openvino/op/reshape.hpp> +namespace ov { +namespace frontend { +namespace ggml { +namespace op { + +OutputVector translate_view(const NodeContext & context) { + num_inputs_check(context, 1, 1); + + if (context.get_op_case() == 2) { + auto dst_shape = context.get_output_shape().to_shape(); + return rename_outputs_with_suffix({process_view_input(context, 0, dst_shape[2] * dst_shape[3])}, + context.get_name()); + } + // op_case 3 + if (context.get_op_case() == 3) { + auto input = context.get_input(0); + auto input_ov_shape = input.get_partial_shape(); + + auto input_llama_shape = context.get_input_shape(0).to_shape(); + + // if the input ov shape size is different from the input llama shape size, it means the input is already reshaped and we need to reshape it back to the original shape before slicing + if (input_ov_shape.size() != input_llama_shape.size()) { + input = std::make_shared<ov::op::v1::Reshape>(input, ov::op::v0::Constant::create(ov::element::i64, {input_llama_shape.size()}, input_llama_shape), false); + } + + auto dst_shape = context.get_output_shape().to_shape(); + + // find the index of dst_shape that is different from input shape, and use that index to slice the input + int slice_dim = -1; + for (size_t i = 0; i < dst_shape.size(); ++i) { + if (dst_shape[i] != input_llama_shape[i]) { + slice_dim = i; + break; + } + } + + auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {dst_shape[slice_dim]}); + auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_dim}); + auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes); + return {sliced}; + } + return {context.get_input(0)}; +} + +} // namespace op +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.cpp b/ggml/src/ggml-openvino/openvino/op_table.cpp new file mode 100644 index 00000000000..1385539279c --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op_table.cpp @@ -0,0 +1,47 @@ +#include "op_table.h" + +#include "utils.h" + +#include <openvino/op/add.hpp> +#include <openvino/op/divide.hpp> +#include <openvino/op/gather.hpp> +#include <openvino/op/matmul.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/subtract.hpp> + +namespace ov { +namespace frontend { +namespace ggml { + +std::unordered_map<std::string, CreatorFunction> get_supported_ops() { + using namespace ov::op; + return { + {"GGML_OP_ADD", op::translate_1to1_match_2_inputs<v1::Add> }, + {"GGML_OP_ADD1", op::translate_1to1_match_2_inputs<v1::Add> }, + {"GGML_OP_CONT", op::translate_cont }, + {"GGML_OP_DIV", op::translate_1to1_match_2_inputs<v1::Divide> }, + {"GGML_OP_GET_ROWS", op::translate_get_rows }, + {"GGML_OP_MUL", op::translate_1to1_match_2_inputs<v1::Multiply>}, + {"GGML_OP_MUL_MAT", op::translate_mulmat }, + {"GGML_OP_PERMUTE", op::translate_permute }, + {"GGML_OP_RESHAPE", op::translate_reshape }, + {"GGML_OP_RMS_NORM", op::translate_rms_norm }, + {"GGML_OP_ROPE", op::translate_rope }, + {"GGML_OP_SCALE", op::translate_scale }, + {"GGML_OP_SOFT_MAX", op::translate_soft_max }, + {"GGML_OP_SUB", op::translate_1to1_match_2_inputs<v1::Subtract>}, + {"GGML_OP_TRANSPOSE", op::translate_transpose }, + {"GGML_UNARY_OP_GELU", op::translate_unary_gelu }, + {"GGML_UNARY_OP_SILU", op::translate_unary_silu }, + {"GGML_OP_VIEW", op::translate_view }, + {"GGML_GLU_OP_SWIGLU", op::translate_glu_swiglu }, + {"GGML_GLU_OP_GEGLU", op::translate_glu_geglu }, + {"GGML_OP_SET_ROWS", op::translate_set_rows }, + {"GGML_OP_CPY", op::translate_cpy }, + {"GGML_OP_FLASH_ATTN_EXT", op::translate_flash_attn_ext }, + }; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/op_table.h b/ggml/src/ggml-openvino/openvino/op_table.h new file mode 100644 index 00000000000..f546796d2ee --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/op_table.h @@ -0,0 +1,40 @@ +#pragma once + +#include "node_context.h" + +namespace ov { +namespace frontend { +namespace ggml { + +namespace op { + +#define GGML_OP_CONVERTER(op) OutputVector op(const NodeContext& context) + +GGML_OP_CONVERTER(translate_add); +GGML_OP_CONVERTER(translate_cont); +GGML_OP_CONVERTER(translate_get_rows); +GGML_OP_CONVERTER(translate_mul); +GGML_OP_CONVERTER(translate_mulmat); +GGML_OP_CONVERTER(translate_permute); +GGML_OP_CONVERTER(translate_reshape); +GGML_OP_CONVERTER(translate_rms_norm); +GGML_OP_CONVERTER(translate_rope); +GGML_OP_CONVERTER(translate_scale); +GGML_OP_CONVERTER(translate_unary_silu); +GGML_OP_CONVERTER(translate_unary_gelu); +GGML_OP_CONVERTER(translate_soft_max); +GGML_OP_CONVERTER(translate_transpose); +GGML_OP_CONVERTER(translate_view); +GGML_OP_CONVERTER(translate_glu_swiglu); +GGML_OP_CONVERTER(translate_glu_geglu); +GGML_OP_CONVERTER(translate_set_rows); +GGML_OP_CONVERTER(translate_cpy); +GGML_OP_CONVERTER(translate_flash_attn_ext); + +} // namespace op + +std::unordered_map<std::string, CreatorFunction> get_supported_ops(); + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp new file mode 100644 index 00000000000..0671542ee38 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp @@ -0,0 +1,60 @@ +#include "fuse_to_sdpa.h" + +#include <openvino/core/graph_util.hpp> +#include <openvino/core/rt_info.hpp> +#include <openvino/op/add.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/matmul.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/scaled_dot_product_attention.hpp> +#include <openvino/op/softmax.hpp> +#include <openvino/op/transpose.hpp> +#include <openvino/pass/pattern/op/label.hpp> +#include <openvino/pass/pattern/op/pattern.hpp> +#include <openvino/pass/pattern/op/wrap_type.hpp> + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +FuseToSDPA::FuseToSDPA() { + // Not maintained since FLASH_ATTN_EXT has replaced this pattern + const auto m_k = ov::pass::pattern::any_input(); + const auto m_q = ov::pass::pattern::any_input(); + const auto m_qk = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_q, m_k}); + const auto m_qk_f32 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_qk}); + const auto m_scale = ov::pass::pattern::any_input(); + const auto m_scaled_qk = ov::pass::pattern::wrap_type<ov::op::v1::Multiply>({m_qk_f32, m_scale}); + const auto m_mask = ov::pass::pattern::any_input(); + const auto m_masked_qk = ov::pass::pattern::wrap_type<ov::op::v1::Add>({m_scaled_qk, m_mask}); + const auto m_softmax_qk = ov::pass::pattern::wrap_type<ov::op::v8::Softmax>({m_masked_qk}); + const auto m_softmax_qk_f16 = ov::pass::pattern::wrap_type<ov::op::v0::Convert>({m_softmax_qk}); + const auto m_v = ov::pass::pattern::any_input(); + const auto m_qkv = ov::pass::pattern::wrap_type<ov::op::v0::MatMul>({m_softmax_qk_f16, m_v}); + + const auto callback = [=](ov::pass::pattern::Matcher & m) { + auto & pattern_to_output = m.get_pattern_value_map(); + auto k = pattern_to_output[m_k]; + auto q = pattern_to_output[m_q]; + auto v = pattern_to_output[m_v]; + auto mask = pattern_to_output[m_mask]; + auto scale = pattern_to_output[m_scale]; + + auto mask_f16 = register_new_node<ov::op::v0::Convert>(mask, ov::element::f16); + auto scale_f16 = register_new_node<ov::op::v0::Convert>(scale, ov::element::f16); + auto sdpa = std::make_shared<ov::op::v13::ScaledDotProductAttention>(q, k, v, mask_f16, scale_f16, false); + + ov::replace_node(m.get_match_root(), sdpa); + ov::copy_runtime_info(m.get_matched_nodes(), sdpa); + + return true; + }; + register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_qkv, "ov::frontend::ggml::pass::FuseToSDPA"), + callback); +} + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h new file mode 100644 index 00000000000..8b5164d2329 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h @@ -0,0 +1,17 @@ +#include "openvino/pass/matcher_pass.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +class FuseToSDPA : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::FuseToSDPA") + FuseToSDPA(); +}; + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h b/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h new file mode 100644 index 00000000000..b95385611e8 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h @@ -0,0 +1,29 @@ +#pragma once + +#include "mark_decompression_convert_constant_folding.h" +#include "openvino/pass/matcher_pass.hpp" +#include "openvino/core/visibility.hpp" + +#ifdef OPENVINO_STATIC_LIBRARY +# define TRANSFORMATIONS_API +#else +# ifdef IMPLEMENT_OPENVINO_API +# define TRANSFORMATIONS_API OPENVINO_CORE_EXPORTS +# else +# define TRANSFORMATIONS_API OPENVINO_CORE_IMPORTS +# endif // IMPLEMENT_OPENVINO_API +#endif // OPENVINO_STATIC_LIBRARY + +namespace ov { +namespace pass { + +class TRANSFORMATIONS_API MarkCompressedFloatConstants; + +} // namespace pass +} // namespace ov + +class ov::pass::MarkCompressedFloatConstants : public MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("MarkCompressedFloatConstants") + MarkCompressedFloatConstants(); +}; diff --git a/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp new file mode 100644 index 00000000000..20a3a374934 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp @@ -0,0 +1,58 @@ +#include "squeeze_matmul.h" + +#include <openvino/core/graph_util.hpp> +#include <openvino/core/rt_info.hpp> +#include <openvino/op/constant.hpp> +#include <openvino/op/matmul.hpp> +#include <openvino/op/squeeze.hpp> +#include <openvino/op/unsqueeze.hpp> +#include <openvino/pass/pattern/op/label.hpp> +#include <openvino/pass/pattern/op/pattern.hpp> +#include <openvino/pass/pattern/op/wrap_type.hpp> + +namespace opp = ov::pass::pattern; + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +// For quantized models, NPUW expects the activation to be 3d in DQ(DynamicQuantization) opt, e.g. DQMatMulGQ2i +SqueezeMatmul::SqueezeMatmul() { + auto m_act = opp::any_input(); + auto m_wei = opp::any_input(); + auto m_matmul = opp::wrap_type<ov::op::v0::MatMul>({m_act, m_wei}); + + const auto callback = [=](ov::pass::pattern::Matcher & m) { + const auto & pattern_map = m.get_pattern_value_map(); + auto matmul_node = + std::dynamic_pointer_cast<ov::op::v0::MatMul>(pattern_map.at(m_matmul).get_node_shared_ptr()); + auto act = pattern_map.at(m_act); + auto wei = pattern_map.at(m_wei); + auto act_shape = act.get_partial_shape(); + auto wei_shape = wei.get_partial_shape(); + if (act_shape.rank().is_dynamic() || wei_shape.rank().is_dynamic()) { + return false; + } + if (act_shape.rank().get_length() == 4 && wei_shape.rank().get_length() == 2) { + auto axis = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}); + auto squeezed_act = std::make_shared<ov::op::v0::Squeeze>(act, axis); + auto new_matmul = std::make_shared<ov::op::v0::MatMul>(squeezed_act, wei, matmul_node->get_transpose_a(), + matmul_node->get_transpose_b()); + auto unsqueezed_output = std::make_shared<ov::op::v0::Unsqueeze>(new_matmul, axis); + unsqueezed_output->set_friendly_name(matmul_node->get_friendly_name()); + ov::copy_runtime_info(matmul_node, {squeezed_act, new_matmul, unsqueezed_output}); + ov::replace_node(matmul_node, unsqueezed_output); + return true; + } + return false; + }; + + register_matcher(std::make_shared<ov::pass::pattern::Matcher>(m_matmul, "ov::frontend::ggml::pass::SqueezeMatmul"), + callback); +} + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h new file mode 100644 index 00000000000..f8fbc69d546 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h @@ -0,0 +1,17 @@ +#include "openvino/pass/matcher_pass.hpp" + +namespace ov { +namespace frontend { +namespace ggml { +namespace pass { + +class SqueezeMatmul : public ov::pass::MatcherPass { +public: + OPENVINO_MATCHER_PASS_RTTI("ov::frontend::ggml::pass::SqueezeMatmul") + SqueezeMatmul(); +}; + +} // namespace pass +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp new file mode 100644 index 00000000000..f051891c481 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp @@ -0,0 +1,41 @@ +// Copyright (C) 2018-2026 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include <openvino/core/core_visibility.hpp> +#include <openvino/core/node.hpp> +#include <openvino/core/runtime_attribute.hpp> + +namespace ov { + +/** + * @brief Holds weightless caching attributes of a single constant. + * + * WeightlessCacheAttribute class represents runtime info attribute that holds + * the values of original size of the constant in bytes and the binary offset of the + * constant's data in the weights file used by the weightless caching mechanism. It's + * not copyable in case the data was changed (the original node was replaced by a new + * one produced during the tranformation pipeline) - in that case weightless caching + * can't be used for that constant. + */ +class OPENVINO_API WeightlessCacheAttribute : public RuntimeAttribute { +public: + OPENVINO_RTTI("WeightlessCacheAttribute", "0", RuntimeAttribute) + + WeightlessCacheAttribute() = delete; + + WeightlessCacheAttribute(size_t original_size, size_t bin_offset, ov::element::Type original_dtype) + : original_size(original_size), + bin_offset(bin_offset), + original_dtype(original_dtype) {} + + bool is_copyable() const override; + + size_t original_size; + size_t bin_offset; + ov::element::Type original_dtype; +}; + +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.cpp b/ggml/src/ggml-openvino/openvino/translate_session.cpp new file mode 100644 index 00000000000..0f68a1f5062 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/translate_session.cpp @@ -0,0 +1,317 @@ +#include "translate_session.h" + +#include "ggml-openvino/openvino/node_context.h" +#include "ggml-openvino/openvino/utils.h" +#include "input_model.h" +#include "pass/mark_decompression_convert_constant_folding.h" +#include "pass/squeeze_matmul.h" +#include "rt_info/weightless_caching_attributes.hpp" + +#include <cstdint> +#include <cstdlib> +#include <map> +#include <memory> +#include <openvino/core/node.hpp> +#include <openvino/core/preprocess/pre_post_process.hpp> +#include <openvino/op/add.hpp> +#include <openvino/op/broadcast.hpp> +#include <openvino/op/concat.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/convert_like.hpp> +#include <openvino/op/cos.hpp> +#include <openvino/op/divide.hpp> +#include <openvino/op/gather.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/parameter.hpp> +#include <openvino/op/range.hpp> +#include <openvino/op/reshape.hpp> +#include <openvino/op/result.hpp> +#include <openvino/op/sin.hpp> +#include <openvino/op/slice.hpp> +#include <openvino/op/squeeze.hpp> +#include <openvino/op/strided_slice.hpp> +#include <openvino/op/transpose.hpp> +#include <openvino/op/unsqueeze.hpp> +#include <openvino/pass/constant_folding.hpp> +#include <openvino/pass/make_stateful.hpp> + +namespace ov { +namespace frontend { +namespace ggml { + +using namespace ov::op; + +namespace { + +ov::pass::MakeStateful::ParamResPairs get_kv_param_res_pairs( + const std::shared_ptr<ov::Model> & model, + const std::map<std::string, std::string> & kv_param_res_names) { + ov::pass::MakeStateful::ParamResPairs pairs; + const auto & params = model->get_parameters(); + const auto & results = model->get_results(); + + for (const auto & param_res : kv_param_res_names) { + const auto & param_name = param_res.first; + const auto & res_name = param_res.second; + + auto param_it = std::find_if(params.begin(), params.end(), [&](const std::shared_ptr<v0::Parameter> & node) { + return node->get_friendly_name() == param_name; + }); + + OPENVINO_ASSERT(param_it != params.end(), "The tensor name ", param_name, + " is not associated with any of " + "Parameters in the network."); + + auto res_it = std::find_if(results.begin(), results.end(), [&](const std::shared_ptr<v0::Result> & node) { + return node->get_friendly_name() == res_name; + }); + + OPENVINO_ASSERT(res_it != results.end(), "The tensor name ", res_name, + " is not associated with any of " + "Results in the network."); + + std::shared_ptr<ov::op::v0::Parameter> param = *param_it; + std::shared_ptr<ov::op::v0::Result> res = *res_it; + pairs.emplace_back(param, res); + } + return pairs; +} + +void add_sliced_mask(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) { + + auto create_sliced_mask = [&](const std::string & mask_name, const std::string & sliced_name, bool is_static) { + if ((tensor_map.find(mask_name) != tensor_map.end()) && + (tensor_map.find("token_len_per_seq") != tensor_map.end())) { + auto token_len_per_seq = tensor_map.at("token_len_per_seq").get_node_shared_ptr(); + auto mask = tensor_map.at(mask_name).get_node_shared_ptr(); + std::shared_ptr<ov::Node> mask_sliced; + if (is_static) { + mask_sliced = mask; + } else if (ggml_model_decoder.is_stateful()) { + auto zero_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {0,0}); + auto one_2d = ov::op::v0::Constant::create(ov::element::i64, {2}, {1,1}); + auto zero_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto three_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {3}); + auto neg_one_1d = ov::op::v0::Constant::create(ov::element::i64, {1}, {-1}); + auto axes = ov::op::v0::Constant::create(ov::element::i64, {2}, {-2,-1}); + auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); + auto gather_inp_pos = std::make_shared<ov::op::v8::Gather>(inp_pos, neg_one_1d, three_1d); + auto reshaped_inp_pos = std::make_shared<ov::op::v1::Reshape>(gather_inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {1}), false); + auto inp_pos_incremented = std::make_shared<ov::op::v1::Add>(reshaped_inp_pos, ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {1})); + auto stop = std::make_shared<ov::op::v0::Concat>(ov::OutputVector{token_len_per_seq, std::make_shared<v1::ConvertLike>(inp_pos_incremented, token_len_per_seq)}, 0); + mask_sliced = + std::make_shared<ov::op::v8::Slice>(mask, zero_2d, stop, one_2d, axes); + mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16); + mask_sliced->set_friendly_name(sliced_name); + } else { + auto zero = ov::op::v0::Constant::create(ov::element::i64, {1}, {0}); + auto one = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto two = ov::op::v0::Constant::create(ov::element::i64, {1}, {2}); + mask_sliced = std::make_shared<ov::op::v8::Slice>(mask, zero, token_len_per_seq, one, two); + mask_sliced = std::make_shared<ov::op::v0::Convert>(mask_sliced, ov::element::f16); + mask_sliced->set_friendly_name(sliced_name); + } + tensor_map.insert({sliced_name, mask_sliced->output(0)}); + } + }; + + create_sliced_mask("self_kq_mask", "KQ_mask_sliced", ggml_model_decoder.is_static()); + create_sliced_mask("self_kq_mask_swa", "KQ_mask_swa_sliced", ggml_model_decoder.is_static()); +} + +void add_rope_sin_cos(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) { + int32_t * rope_params = ggml_model_decoder.get_rope_params(); + if (tensor_map.find("inp_pos") == tensor_map.end() || rope_params == nullptr) { + return; + } + auto inp_pos = tensor_map.at("inp_pos").get_node_shared_ptr(); + std::shared_ptr<ov::Node> rope_freqs_weight; + if (tensor_map.find("rope_freqs.weight") != tensor_map.end()) { + rope_freqs_weight = tensor_map.at("rope_freqs.weight").get_node_shared_ptr(); + } + + auto sin_cos = make_sin_cos(rope_params, inp_pos, rope_freqs_weight); + auto sin_theta = sin_cos.first; + auto cos_theta = sin_cos.second; + + cos_theta.get_node_shared_ptr()->set_friendly_name("rope_cos"); + sin_theta.get_node_shared_ptr()->set_friendly_name("rope_sin"); + tensor_map.insert({"rope_cos", cos_theta}); + tensor_map.insert({"rope_sin", sin_theta}); +} + +// Create common patterns +void preprocess(TensorMap & tensor_map, GgmlDecoder & ggml_model_decoder) { + add_sliced_mask(tensor_map, ggml_model_decoder); + add_rope_sin_cos(tensor_map, ggml_model_decoder); +} + +} // namespace + +TranslateSession::TranslateSession(const frontend::InputModel::Ptr & input_model, + const std::unordered_map<std::string, CreatorFunction> & translator_map, + bool naive) : + m_input_model(input_model), + m_translator_map(translator_map), + m_ov_model(nullptr), + m_naive(naive) {} + +std::shared_ptr<Model> TranslateSession::get_converted_model() { + if (m_ov_model) { + return m_ov_model; + } + m_ov_model = translate_graph(m_input_model); + return m_ov_model; +} + +std::shared_ptr<Model> TranslateSession::translate_graph(const frontend::InputModel::Ptr & input_model) { + ov::ParameterVector params; + ov::ResultVector results; + auto tensor_map = std::make_shared<TensorMap>(); + std::shared_ptr<Model> resulting_model; + + const auto & ggml_model = std::dynamic_pointer_cast<InputModel>(input_model); + std::shared_ptr<GgmlDecoder> ggml_model_decoder = ggml_model->get_model_decoder(); + + for (const auto & it : ggml_model_decoder->get_model_inputs()) { + params.push_back(std::dynamic_pointer_cast<ov::op::v0::Parameter>(it.second)); + (*tensor_map)[it.first] = it.second; + } + + for (const auto & it : ggml_model_decoder->get_model_extra_inputs()) { + if (std::dynamic_pointer_cast<ov::op::v0::Parameter>(it.second)) { + params.push_back(std::dynamic_pointer_cast<ov::op::v0::Parameter>(it.second)); + } + (*tensor_map)[it.first] = it.second; + } + + for (const auto & it : ggml_model_decoder->get_model_weights()) { + (*tensor_map)[it.first] = it.second; + } + + auto node_visitor = [&](std::shared_ptr<GgmlDecoder> decoder, int node_idx) { + auto operation_type = decoder->get_op_type(node_idx); + if (operation_type == "GGML_OP_NONE") { + return; + } + + ov::OutputVector converted_outputs; + auto it = m_translator_map.find(operation_type); + FRONT_END_OP_CONVERSION_CHECK(it != m_translator_map.end(), "Translation for operation type ", operation_type, + " is not implemented."); + NodeContext node_context(decoder, tensor_map, node_idx, this); + converted_outputs = it->second(node_context); + + const auto & node_output_names = decoder->get_output_names(node_idx); + FRONT_END_OP_CONVERSION_CHECK(node_output_names.size() == converted_outputs.size(), "Number of ", + operation_type, " outputs greater than number of converted outputs, which are ", + node_output_names.size(), " and ", converted_outputs.size(), " respectively."); + + for (size_t i = 0; i < node_output_names.size(); ++i) { + auto output_name = node_output_names[i]; + if (i < converted_outputs.size() && converted_outputs[i].get_node_shared_ptr() != nullptr) { + (*tensor_map)[output_name] = converted_outputs[i]; + } + } + }; + + if (!m_naive) { + preprocess(*tensor_map, *ggml_model_decoder); + } + ggml_model_decoder->visit_subgraph(node_visitor); + + for (const auto & name : ggml_model_decoder->get_model_output_names()) { + FRONT_END_GENERAL_CHECK(tensor_map->find(name) != tensor_map->end(), + "Output name not found in tensor map: ", name); + auto result = std::make_shared<v0::Result>(tensor_map->at(name)); + result->set_friendly_name(name); + results.push_back(result); + } + + ov::ParameterVector used_params; + for (const auto & param : params) { + if (!param->output(0).get_target_inputs().empty()) { + used_params.push_back(param); + } + } + // if (auto diff = params.size() - used_params.size()) { + // GGML_LOG_INFO("%zu parameters are not used in the model.", diff); + // } + resulting_model = std::make_shared<Model>(results, used_params); + + apply_transformations(resulting_model); + + // Set WeightlessCacheAttribute on large constants to avoid unnecessary memory copies + // in the NPUW plugin. Without this attribute, NPUW's LazyTensor constructor + // (lazy_tensor.cpp, op::Const::Const) will memcpy every constant "in case export + // occurs", doubling memory usage per compile_model call. + // + // The bin_offset field serves as a unique key (not a real file offset) — this is + // the same convention the GPU plugin uses for non-IR models (see + // Plugin::set_weightless_cache_attributes in intel_gpu/src/plugin/plugin.cpp). + // Each constant must have a distinct bin_offset, otherwise GPU's weightless cache + // import will map multiple constants to the same data. + // + // Small constants (< 16 elements) are excluded since they may be introduced by + // optimization patterns and the overhead is negligible. + size_t offset = 0; + for (auto & node : resulting_model->get_ordered_ops()) { + if (auto cnst = ov::as_type_ptr<ov::op::v0::Constant>(node); + cnst && cnst->get_byte_size() / cnst->get_element_type().size() >= 16) { + auto & rt_info = cnst->get_rt_info(); + if (rt_info.find(ov::WeightlessCacheAttribute::get_type_info_static()) == rt_info.end()) { + rt_info[ov::WeightlessCacheAttribute::get_type_info_static()] = + ov::WeightlessCacheAttribute(cnst->get_byte_size(), offset++, cnst->get_element_type()); + } + } + } + return resulting_model; +} + +std::shared_ptr<Model> TranslateSession::apply_transformations(std::shared_ptr<Model> model) { + auto ggml_model_decoder = std::dynamic_pointer_cast<InputModel>(m_input_model)->get_model_decoder(); + { + ov::pass::Manager manager; + manager.set_per_pass_validation(true); + manager.register_pass<ov::pass::MarkCompressedFloatConstants>(); + + if (ggml_model_decoder->is_stateful()) { + const auto kv_param_res_names = ggml_model_decoder->get_kv_param_res_names(); + const auto kv_param_res_pairs = get_kv_param_res_pairs(model, kv_param_res_names); + manager.register_pass<ov::pass::MakeStateful>(kv_param_res_pairs); + } + + if (ggml_model_decoder->is_static()) { + manager.register_pass<pass::SqueezeMatmul>(); + } + manager.run_passes(model); + if (ggml_model_decoder->is_stateful()) { + auto output_names = ggml_model_decoder->get_model_output_names(); + std::map<std::string, int> model_output_indexes; + for (size_t i=0; i<output_names.size(); i++) { + model_output_indexes.insert(std::make_pair(output_names[i], i)); + } + ov::preprocess::PrePostProcessor ppp(model); + for (size_t i=0; i<model->get_output_size(); i++) { + auto output_friendly_name = model->output(i).get_node_shared_ptr()->get_friendly_name(); + auto output_id = model_output_indexes[output_friendly_name]; + auto model_output_shape = model->output(i).get_partial_shape(); + auto decoder_output_shape = ggml_model_decoder->get_output_shape(output_id); + if (model_output_shape.rank().is_static() && decoder_output_shape.rank().is_static() + && model_output_shape.rank().get_length() + 1 == decoder_output_shape.rank().get_length() + && decoder_output_shape[0].is_static() && decoder_output_shape[0].get_length() == 1) { + ppp.output(i).postprocess().custom([](const ov::Output<ov::Node>& node) { + auto axes = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{1}, {0}); + return std::make_shared<ov::op::v0::Unsqueeze>(node, axes); + }); + } + } + model = ppp.build(); + } + } + return model; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/translate_session.h b/ggml/src/ggml-openvino/openvino/translate_session.h new file mode 100644 index 00000000000..56a14ae7c07 --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/translate_session.h @@ -0,0 +1,28 @@ +#pragma once + +#include "input_model.h" +#include "node_context.h" + +namespace ov { +namespace frontend { +namespace ggml { + +class TranslateSession { +public: + TranslateSession(const frontend::InputModel::Ptr& input_model, + const std::unordered_map<std::string, CreatorFunction>& translator_map, bool naive = false); + + std::shared_ptr<Model> get_converted_model(); + std::shared_ptr<Model> translate_graph(const frontend::InputModel::Ptr& input_model); + +private: + std::shared_ptr<Model> apply_transformations(std::shared_ptr<Model> model); + const frontend::InputModel::Ptr m_input_model; + const std::unordered_map<std::string, CreatorFunction>& m_translator_map; + std::shared_ptr<Model> m_ov_model; + bool m_naive; +}; + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/utils.cpp b/ggml/src/ggml-openvino/openvino/utils.cpp new file mode 100644 index 00000000000..0baaf88e17a --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/utils.cpp @@ -0,0 +1,257 @@ +#include "utils.h" + +#include "ggml-impl.h" + +#include <cmath> +#include <cstddef> +#include <ctime> +#include <memory> +#include <openvino/op/add.hpp> +#include <openvino/op/clamp.hpp> +#include <openvino/op/convert.hpp> +#include <openvino/op/cos.hpp> +#include <openvino/op/divide.hpp> +#include <openvino/op/gather.hpp> +#include <openvino/op/maximum.hpp> +#include <openvino/op/multiply.hpp> +#include <openvino/op/reshape.hpp> +#include <openvino/op/shape_of.hpp> +#include <openvino/op/sin.hpp> +#include <openvino/op/squeeze.hpp> +#include <openvino/op/subtract.hpp> +#include <openvino/op/transpose.hpp> +#include <string> + +namespace ov { +namespace frontend { +namespace ggml { + +std::string getCurrentTime() { + std::time_t now = std::time(nullptr); + char buf[100]; + std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now)); + return buf; +} + +void num_inputs_check(const NodeContext & context, size_t min_inputs, size_t max_inputs) { + auto input_size = context.get_input_size(); + FRONT_END_OP_CONVERSION_CHECK(input_size >= min_inputs, "Got less inputs than expected"); + FRONT_END_OP_CONVERSION_CHECK(input_size <= max_inputs, "Got more inputs than expected"); +} + +int non_cont_dim(std::vector<size_t> ne, std::vector<size_t> nb) { + int dim = nb.size() - 1; + size_t bytes = nb[dim]; + for (int i = dim; i > 0; i--) { + bytes *= ne[i]; + if (bytes != nb[i - 1]) { + return i; + } + } + return 0; +} + +std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::ShapeOf> & shape, + const std::vector<int> & dims) { + using namespace ov::op; + const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0}); + const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims); + return std::make_shared<v8::Gather>(shape, dims_const, zero); +} + +std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node> & node, const std::vector<int> & dims) { + return get_dimensions(std::make_shared<ov::op::v3::ShapeOf>(node), dims); +} + +OutputVector rename_outputs_with_suffix(const OutputVector & outputs, const std::string & suffix) { + for (const auto & output : outputs) { + auto node = output.get_node_shared_ptr(); + std::string name = node->get_friendly_name(); + name += "_"; + name += suffix; + node->set_friendly_name(name); + // std::cout << name << " " << output.get_partial_shape() << std::endl; + } + return outputs; +} + +namespace { +ov::Output<ov::Node> rope_yarn_ramp_mix(int n_dims, const float corr_dims[2], float ext_factor) { + int half_n_dims = n_dims / 2; + std::vector<float> dim_ids_vec(half_n_dims); + std::iota(dim_ids_vec.begin(), dim_ids_vec.end(), 0); + auto dim_ids = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, (size_t) half_n_dims}, dim_ids_vec); + auto corr_low = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {corr_dims[0]}); + auto corr_high = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {corr_dims[1]}); + auto denom = std::make_shared<ov::op::v1::Maximum>( + std::make_shared<ov::op::v1::Subtract>(corr_high, corr_low), + ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {0.001f})); + auto ramp_y = + std::make_shared<ov::op::v1::Divide>(std::make_shared<ov::op::v1::Subtract>(dim_ids, corr_low), denom); + auto ramp_clamped = std::make_shared<ov::op::v0::Clamp>(ramp_y, 0.0f, 1.0f); + // rope_yarn_ramp returns (1 - clamp(y)), so invert before scaling + auto one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + auto ramp_inverted = std::make_shared<ov::op::v1::Subtract>(one, ramp_clamped); + auto ext_factor_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {ext_factor}); + auto ramp_mix = std::make_shared<ov::op::v1::Multiply>(ramp_inverted, ext_factor_node); + return ramp_mix; +} + +float ggml_rope_yarn_corr_dim(int n_dims, int n_ctx_orig, float n_rot, float base) { +#ifndef M_PI +# define M_PI 3.14159265358979323846 +#endif + return n_dims * logf(n_ctx_orig / (n_rot * 2 * (float) M_PI)) / (2 * logf(base)); +} + +void ggml_rope_yarn_corr_dims(int n_dims, + int n_ctx_orig, + float freq_base, + float beta_fast, + float beta_slow, + float dims[2]) { + float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base)); + float end = ceilf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base)); + dims[0] = std::max(0.0f, start); + dims[1] = std::min(static_cast<float>(n_dims - 1), end); +} +} // namespace + +std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t * rope_params, + std::shared_ptr<ov::Node> inp_pos, + std::shared_ptr<ov::Node> rope_freqs_weight, + bool imrope, + bool stateful) { + if (stateful) { + inp_pos = std::make_shared<ov::op::v0::Squeeze>(inp_pos, ov::op::v0::Constant::create(ov::element::i64, {1}, {0})); + inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32); + auto pos_perm = + std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{3}, std::vector<int64_t>{2, 1, 0}); + inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm); + } else if (imrope) { + inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32); + auto pos_shape = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{5}, {0, 0, 0, 4, -1}); + inp_pos = std::make_shared<ov::op::v1::Reshape>(inp_pos, pos_shape, true); + auto pos_transpose_shape = + std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{5}, std::vector<int64_t>{0, 1, 2, 4, 3}); + inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_transpose_shape); + } else { + inp_pos = std::make_shared<ov::op::v0::Convert>(inp_pos, ov::element::f32); + auto pos_perm = + std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, 3, 1, 2}); + inp_pos = std::make_shared<ov::op::v1::Transpose>(inp_pos, pos_perm); + } + + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; + const int n_dims = rope_params[1]; + const size_t n_dims_half = n_dims >> 1; + const int n_ctx_orig = rope_params[4]; + memcpy(&freq_base, rope_params + 5, sizeof(float)); + memcpy(&freq_scale, rope_params + 6, sizeof(float)); + memcpy(&ext_factor, rope_params + 7, sizeof(float)); + memcpy(&attn_factor, rope_params + 8, sizeof(float)); + memcpy(&beta_fast, rope_params + 9, sizeof(float)); + memcpy(&beta_slow, rope_params + 10, sizeof(float)); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + + std::vector<float> factor(n_dims_half); + + Output<Node> freq_factors; + + Output<Node> theta; + float mscale = attn_factor; + if (imrope) { + std::vector<int64_t> gather_indices(n_dims_half); + for (size_t j = 0; j < n_dims_half; j++) { + gather_indices[j] = j % 3; + factor[j] = std::pow(theta_scale, j); + } + auto gather_indices_const = + std::make_shared<ov::op::v0::Constant>(ov::element::i64, ov::Shape{n_dims_half}, gather_indices); + auto gather_axis = ov::op::v0::Constant::create(ov::element::i32, ov::Shape{}, {4}); + inp_pos = std::make_shared<ov::op::v8::Gather>(inp_pos, gather_indices_const, gather_axis); + auto factor_const = std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{n_dims_half}, factor); + theta = std::make_shared<ov::op::v1::Multiply>(inp_pos, factor_const); + } else { + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + factor[0] = 1.0f; + for (size_t i = 1; i < factor.size(); i++) { + factor[i] = theta_scale * factor[i - 1]; + } + if (stateful) { + freq_factors = + std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, factor.size()}, factor); + } else { + freq_factors = + std::make_shared<ov::op::v0::Constant>(ov::element::f32, ov::Shape{1, 1, 1, factor.size()}, factor); + } + if (rope_freqs_weight) { + freq_factors = std::make_shared<ov::op::v1::Divide>(freq_factors, rope_freqs_weight); + } + + auto theta_extrap = std::make_shared<ov::op::v1::Multiply>(freq_factors, inp_pos); + auto theta_interp = std::make_shared<ov::op::v1::Multiply>( + theta_extrap, ov::op::v0::Constant::create(ov::element::f32, {1}, {freq_scale})); + + if (ext_factor == 0.0f) { + theta = theta_interp; + } else { + auto ramp_mix = rope_yarn_ramp_mix(n_dims, corr_dims, ext_factor); + Output<Node> one; + if (stateful) { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1}, {1.0f}); + } else { + one = ov::op::v0::Constant::create(ov::element::f32, Shape{1, 1, 1, 1}, {1.0f}); + } + auto one_minus_ramp = std::make_shared<ov::op::v1::Subtract>(one, ramp_mix); + + theta = std::make_shared<ov::op::v1::Add>(std::make_shared<ov::op::v1::Multiply>(theta_interp, one_minus_ramp), + std::make_shared<ov::op::v1::Multiply>(theta_extrap, ramp_mix)); + mscale *= (1.0f + 0.1f * std::log(1.0f / freq_scale)); + } + } + + Output<Node> cos_theta = std::make_shared<ov::op::v0::Cos>(theta); + Output<Node> sin_theta = std::make_shared<ov::op::v0::Sin>(theta); + + if (!imrope) { + auto mscale_node = ov::op::v0::Constant::create(ov::element::f32, Shape{}, {mscale}); + + cos_theta = std::make_shared<ov::op::v1::Multiply>(cos_theta, mscale_node); + sin_theta = std::make_shared<ov::op::v1::Multiply>(sin_theta, mscale_node); + } + + return std::make_pair(sin_theta, cos_theta); +} + +ov::Output<ov::Node> process_view_input(const NodeContext & context, int input_index, int slice_len) { + // Only works for VIEW operations that slice at the lowest dimension + // If the VIEW also reshape the result, `slice_len` should be provided + auto input = context.get_input(input_index); + auto * op_params = (size_t *) context.get_input_op_params(input_index); + auto src1_stride = context.get_input_stride(input_index); + + int64_t split_addr = op_params[0] / src1_stride[3]; + if (slice_len == 0) { + slice_len = context.get_input_shape(input_index)[3].get_length(); + } + int64_t slice_end = split_addr + slice_len; + + auto begin = ov::op::v0::Constant::create(ov::element::i64, {1}, {split_addr}); + auto end = ov::op::v0::Constant::create(ov::element::i64, {1}, {slice_end}); + auto stride = ov::op::v0::Constant::create(ov::element::i64, {1}, {1}); + auto axes = ov::op::v0::Constant::create(ov::element::i64, {1}, {context.is_stateful() ? 2 : 3}); + auto sliced = std::make_shared<ov::op::v8::Slice>(input, begin, end, stride, axes); + return sliced; +} + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/openvino/utils.h b/ggml/src/ggml-openvino/openvino/utils.h new file mode 100644 index 00000000000..767dd4c53ea --- /dev/null +++ b/ggml/src/ggml-openvino/openvino/utils.h @@ -0,0 +1,86 @@ +#pragma once + +#include <memory> +#include <openvino/core/node.hpp> +#include <openvino/op/shape_of.hpp> +#include <openvino/op/slice.hpp> +#include <utility> + +#include "node_context.h" + +namespace ov { +namespace frontend { +namespace ggml { + +std::string getCurrentTime(); + +void dump_ov_model(std::shared_ptr<ov::Model> model); + +void num_inputs_check(const NodeContext& context, size_t min_inputs, size_t max_inputs); + +int non_cont_dim(std::vector<size_t> ne, std::vector<size_t> nb); + +template <typename T> +std::vector<int> argsort_descend(const std::vector<T>& v) { + std::vector<int> idx(v.size()); + std::iota(idx.begin(), idx.end(), 0); + std::sort(idx.begin(), idx.end(), [&v](int i1, int i2) { + return v[i1] > v[i2]; + }); + return idx; +} + +template <typename T> +std::vector<T> sorted_descend(std::vector<T> v) { + std::sort(v.begin(), v.end(), [](T a, T b) { + return a > b; + }); + return v; +} + +template <typename T> +bool is_permuted(const std::vector<T>& strides) { + for (size_t i = 0; i < strides.size() - 1; ++i) { + if (strides[i] < strides[i + 1]) { + return true; + } + } + return false; +} + +template <typename T> +std::vector<T> permute(const std::vector<T>& x, const std::vector<int>& perm) { + std::vector<T> result; + result.reserve(perm.size()); + for (int i : perm) { + result.push_back(x[i]); + } + return result; +} + +std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::op::v3::ShapeOf>& shape, + const std::vector<int>& dims); +std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node>& node, const std::vector<int>& dims); + +OutputVector rename_outputs_with_suffix(const OutputVector& outputs, const std::string& suffix); + +std::pair<ov::Output<Node>, ov::Output<Node>> make_sin_cos(int32_t* rope_params, + std::shared_ptr<ov::Node> inp_pos, + std::shared_ptr<ov::Node> rope_freqs_weight = nullptr, + bool imrope = false, + bool stateful = false); + +ov::Output<ov::Node> process_view_input(const NodeContext& context, int input_index, int slice_len = 0); + +namespace op { +template <typename T> +OutputVector translate_1to1_match_2_inputs(const NodeContext& context) { + num_inputs_check(context, 2, 2); + auto res = std::make_shared<T>(context.get_input(0), context.get_input(1)); + return rename_outputs_with_suffix({res}, context.get_name()); +} +} // namespace op + +} // namespace ggml +} // namespace frontend +} // namespace ov diff --git a/ggml/src/ggml-openvino/utils.cpp b/ggml/src/ggml-openvino/utils.cpp new file mode 100644 index 00000000000..998ef7c9eb4 --- /dev/null +++ b/ggml/src/ggml-openvino/utils.cpp @@ -0,0 +1,880 @@ +#include "utils.h" + +#include "ggml-impl.h" +#include "ggml-openvino-extra.h" +#include "ggml-openvino/ggml-decoder.h" +#include "ggml.h" +#include "openvino/frontend.h" +#include "openvino/input_model.h" + +#include <algorithm> +#include <cassert> +#include <cmath> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <cstring> +#include <iomanip> +#include <iostream> +#include <memory> +#include <openvino/core/any.hpp> +#include <openvino/core/graph_util.hpp> +#include <openvino/core/shape.hpp> +#include <openvino/core/type/float16.hpp> +#include <openvino/frontend/manager.hpp> +#include <openvino/openvino.hpp> +#include <openvino/runtime/compiled_model.hpp> +#include <openvino/runtime/infer_request.hpp> +#include <openvino/runtime/intel_npu/properties.hpp> +#include <openvino/runtime/properties.hpp> +#include <openvino/runtime/tensor.hpp> +#include <string> +#include <unordered_map> +#include <vector> + +// Suppress deprecation warning for ov::Tensor::data() +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" + +enum ggml_status ov_graph_compute(ggml_cgraph * cgraph, ggml_backend_t backend) { + ggml_backend_openvino_context * ctx = (ggml_backend_openvino_context *) backend->context; + try { + if (getenv("GGML_OPENVINO_DUMP_CGRAPH")) { + std::string filename = "cgraph_ov.txt"; + GgmlOvDecoder::dump_cgraph(cgraph, filename); + } + + const auto is_static = ggml_openvino_is_npu(); + + GGML_ASSERT(ctx->runtime_context != nullptr); + std::shared_ptr<ov_runtime_context> r_ctx = std::static_pointer_cast<ov_runtime_context>(ctx->runtime_context); + + return is_static ? ov_graph_compute_static(cgraph, r_ctx) : ov_graph_compute_dynamic(cgraph, r_ctx); + } catch (const ov::Exception & e) { + GGML_LOG_ERROR("GGML OpenVINO backend ov::Exception: %s\n", e.what()); + return GGML_STATUS_FAILED; + } catch (const std::exception & e) { + GGML_LOG_ERROR("GGML OpenVINO backend std::exception: %s\n", e.what()); + return GGML_STATUS_FAILED; + } catch (...) { + GGML_LOG_ERROR("GGML OpenVINO backend unknown exception\n"); + return GGML_STATUS_FAILED; + } +} + +ov::Tensor create_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, + std::shared_ptr<ov::InferRequest> infer_request, + int output_index, + const ggml_tensor * ggml_tensor) { + auto output_type = ggml_decoder->get_ov_type(ggml_tensor); + ov::Shape output_shape; + if (ggml_decoder->is_static()) { + output_shape = infer_request->get_output_tensor(output_index).get_shape(); + } else { + output_shape = ggml_decoder->get_shape(ggml_tensor); + } + + ov::Tensor output_tensor(output_type, output_shape, ggml_tensor->data); + return output_tensor; +} + +enum ggml_status ov_graph_compute_dynamic(ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx) { + auto & core = ov_singleton_core(); + const auto & config = ggml_openvino_get_compile_config(); + const auto & device = r_ctx->device; + const auto & stateful = r_ctx->stateful; + static auto is_static = false; + + if (is_naive(cgraph)) { + return naive_compute(cgraph, core, device, config); + } + + auto start_time = ggml_time_us(); + + std::shared_ptr<GgmlOvDecoder> ggml_decoder; + std::shared_ptr<ov::InferRequest> infer_request; + ModelParams m_params; + ComputeParams c_params; + std::tie(m_params, c_params) = GgmlOvDecoder::compute_llm_params(cgraph, is_static); + + graph_key key(cgraph); + bool cache_hit; + + int64_t decoder_end_time; + int64_t conversion_end_time; + int64_t compile_end_time; + int64_t infer_end_time; + + { + std::shared_ptr<decoder_runtime_ctx> entry; + ModelParams old_m_params; + + { + std::lock_guard<std::mutex> map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared<std::mutex>(); + entry = std::make_shared<decoder_runtime_ctx>(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard<std::mutex> lock(*(entry->mutex)); + + if (cache_hit) { + ggml_decoder = entry->ptr; + old_m_params = ggml_decoder->get_model_params(); + cache_hit = old_m_params.can_reuse_dynamically(m_params); + } + + if (cache_hit) { + std::map<std::string, std::shared_ptr<ov::Node>> model_weights; + ggml_decoder->set_compute_params(c_params); + ggml_decoder->set_model_params(m_params); + if (old_m_params.kv_buffer_changed(m_params)) { + ggml_decoder->update_io(cgraph); + } + ggml_decoder->add_extra_inputs(); + { + std::lock_guard<std::mutex> map_lock(r_ctx->ctx_mutex); + infer_request = r_ctx->infer_request_cache.at(key); + } + + if (stateful) { + const auto * inp_pos = get_inp_pos_tensor(cgraph); + int32_t * pos_data = (int32_t *) inp_pos->data; + auto pos_shape = ggml_decoder->get_shape(inp_pos); + if (pos_data[0] == 0) { + infer_request->reset_state(); + r_ctx->stateful_kv_size = pos_shape[3]; + } else if (r_ctx->stateful_kv_size == static_cast<size_t>(pos_data[0])) { + r_ctx->stateful_kv_size += pos_shape[3]; + } else { + auto states = infer_request->query_state(); + for (auto state : states) { + auto state_tensor = state.get_state(); + auto state_tensor_shape = state_tensor.get_shape(); + if (static_cast<uint32_t>(pos_data[0]) > r_ctx->stateful_kv_size) { + std::string state_name; + try { + state_name = r_ctx->kv_state_input_name_map.at(state.get_name()); + } catch (...) { + GGML_LOG_ERROR("GGML OpenVINO backend stateful inference failed: no input found for the state\n"); + return GGML_STATUS_FAILED; + } + auto kv_tensor = get_ov_input_tensor(ggml_decoder, state_name); + kv_tensor.set_shape({state_tensor_shape[0], kv_tensor.get_shape()[2], + state_tensor_shape[2], state_tensor_shape[3]}); + state_tensor = kv_tensor; + state_tensor_shape = state_tensor.get_shape(); + } + ov::Coordinate begin = {0, 0, 0, 0}; + ov::Coordinate end = {state_tensor_shape[0], static_cast<uint32_t>(pos_data[0]), + state_tensor_shape[2], state_tensor_shape[3]}; + ov::Tensor new_state_tensor(state_tensor, begin, end); + state.set_state(new_state_tensor); + } + r_ctx->stateful_kv_size = pos_data[0] + 1; + } + } + + decoder_end_time = ggml_time_us(); + conversion_end_time = decoder_end_time; + compile_end_time = decoder_end_time; + } else { + { + std::lock_guard<std::mutex> map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + } + + std::shared_ptr<ov::Model> model; + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); + + ggml_decoder = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static, stateful); + decoder_end_time = ggml_time_us(); + + auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder); + model = ov::frontend::ggml::FrontEnd::convert(input_model); + ggml_decoder->clear_model_weights(); + conversion_end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_DUMP_IR")) { + char timestamped_filename[64]; + auto timestamp = (long long) ggml_time_us(); + snprintf(timestamped_filename, sizeof(timestamped_filename), "model_%lld.xml", timestamp); + ov::serialize(model, timestamped_filename); + } + + ov::CompiledModel compiled_model; + auto remote_context = ggml_openvino_get_remote_context(); + if (remote_context.has_value()) { + compiled_model = core.compile_model(model, remote_context.value(), config); + } else { + compiled_model = core.compile_model(model, device, config); + } + compile_end_time = ggml_time_us(); + infer_request = std::make_shared<ov::InferRequest>(compiled_model.create_infer_request()); + entry->ptr = ggml_decoder; + + std::vector<std::string> ov_input_names; + std::vector<std::string> ov_output_names; + for (const auto & ov_param : model->get_parameters()) { + ov_input_names.push_back(ov_param->get_friendly_name()); + } + for (const auto & ov_output : model->get_results()) { + ov_output_names.push_back(ov_output->get_friendly_name()); + } + + { + std::lock_guard<std::mutex> map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache[key] = infer_request; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } + + if (stateful) { + const auto * inp_pos = get_inp_pos_tensor(cgraph); + auto pos_shape = ggml_decoder->get_shape(inp_pos); + r_ctx->stateful_kv_size = pos_shape[3]; + const auto kv_param_res_names = ggml_decoder->get_kv_param_res_names(); + for (const auto& pair : kv_param_res_names) { + r_ctx->kv_state_input_name_map[pair.first+pair.second] = pair.first; + } + } + } + + std::vector<std::string> ov_input_names; + std::vector<std::string> ov_output_names; + { + std::lock_guard<std::mutex> map_lock(r_ctx->ctx_mutex); + ov_input_names = r_ctx->ov_input_names_cache[key]; + ov_output_names = r_ctx->ov_output_names_cache[key]; + } + + for (size_t i = 0; i < ov_input_names.size(); i++) { + auto param_name = ov_input_names[i]; + auto input_tensor = get_ov_input_tensor(ggml_decoder, param_name); + infer_request->set_input_tensor(i, input_tensor); + + if (getenv("GGML_OPENVINO_DEBUG_INPUT")) { + print_input_tensor_info(param_name, input_tensor); + } + } + + for (size_t i = 0; i < ov_output_names.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names[i]); + auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); + infer_request->set_output_tensor(i, output_tensor); + } + + infer_request->infer(); + infer_end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { + for (size_t i = 0; i < ov_output_names.size(); i++) { + const auto output_tensor = infer_request->get_output_tensor(i); + print_output_tensor_info(ov_output_names[i], output_tensor, output_tensor.data()); + } + } + + if (getenv("GGML_OPENVINO_PROFILING")) { + GGML_LOG_INFO("\nGGML OpenVINO Backend: \n"); + GGML_LOG_INFO(" - Graph decoder time: %ld ms \n", (decoder_end_time - start_time) / 1000); + if (!cache_hit) { + GGML_LOG_INFO(" - Graph conversion time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000); + GGML_LOG_INFO(" - Graph compile time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000); + } + GGML_LOG_INFO(" - Graph inference time: %ld ms \n", (infer_end_time - compile_end_time) / 1000); + } + } + + return GGML_STATUS_SUCCESS; +} + +enum ggml_status ov_graph_compute_static(ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx) { + auto & core = ov_singleton_core(); + + auto get_prefill_chunk_size = [] { + const char * chunk_size_str = getenv("GGML_OPENVINO_PREFILL_CHUNK_SIZE"); + if (chunk_size_str && atoi(chunk_size_str) > 0) { + return atoi(chunk_size_str); + } + return 256; + }; + + static std::string device = "NPU"; + static auto is_static = true; + static auto stateful = false; + static auto prefill_chunk_size = get_prefill_chunk_size(); + const auto & config = ggml_openvino_get_compile_config(); + + if (is_naive(cgraph)) { + return naive_compute(cgraph, core, device, config); + } + + auto start_time = ggml_time_us(); + + std::shared_ptr<GgmlOvDecoder> ggml_decoder; + std::shared_ptr<ov::InferRequest> infer_request; + ModelParams m_params; + ComputeParams c_params; + std::tie(m_params, c_params) = GgmlOvDecoder::compute_llm_params(cgraph, is_static); + + const auto * inp_pos = get_inp_pos_tensor(cgraph); + const auto is_prefill = get_is_prefill(inp_pos); + graph_key key(cgraph); + bool cache_hit; + + int64_t decoder_end_time; + int64_t conversion_end_time; + int64_t compile_end_time; + int64_t infer_end_time; + + std::shared_ptr<decoder_runtime_ctx> entry; + ModelParams old_m_params; + + { + std::lock_guard<std::mutex> map_lock(r_ctx->ctx_mutex); + auto it = r_ctx->decoder_cache.find(key); + cache_hit = it != r_ctx->decoder_cache.end(); + if (cache_hit) { + entry = it->second; + } else { + auto mutex = std::make_shared<std::mutex>(); + entry = std::make_shared<decoder_runtime_ctx>(mutex); + r_ctx->decoder_cache[key] = entry; + } + } + + std::lock_guard<std::mutex> lock(*(entry->mutex)); + + if (cache_hit) { + ggml_decoder = entry->ptr; + old_m_params = ggml_decoder->get_model_params(); + cache_hit = old_m_params.can_reuse_statically(m_params); + } + + if (cache_hit) { + std::map<std::string, std::shared_ptr<ov::Node>> model_weights; + ggml_decoder->m_is_prefill = is_prefill; + ggml_decoder->set_model_params(m_params); + ggml_decoder->set_compute_params(c_params); + if (old_m_params.kv_buffer_changed(m_params)) { + ggml_decoder->update_io(cgraph); + } + ggml_decoder->add_extra_inputs(); + { + std::lock_guard<std::mutex> map_lock(r_ctx->ctx_mutex); + infer_request = + is_prefill ? r_ctx->infer_request_cache_prefill.at(key) : r_ctx->infer_request_cache.at(key); + } + + decoder_end_time = ggml_time_us(); + conversion_end_time = decoder_end_time; + compile_end_time = decoder_end_time; + } else { + { + std::lock_guard<std::mutex> map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache.erase(key); + r_ctx->infer_request_cache_prefill.erase(key); + } + + std::shared_ptr<ov::Model> model; + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph); + + auto ggml_decoder_prefill = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, + is_static, stateful, true, prefill_chunk_size); + auto ggml_decoder_decode = std::make_shared<GgmlOvDecoder>(cgraph, m_params, c_params, model_weights, is_static, + stateful, false, prefill_chunk_size); + decoder_end_time = ggml_time_us(); + + auto input_model_prefill = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_prefill); + auto input_model_decode = std::make_shared<ov::frontend::ggml::InputModel>(ggml_decoder_decode); + + auto model_prefill = ov::frontend::ggml::FrontEnd::convert(input_model_prefill); + ggml_decoder_prefill->clear_model_weights(); + auto model_decode = ov::frontend::ggml::FrontEnd::convert(input_model_decode); + ggml_decoder_decode->clear_model_weights(); + conversion_end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_DUMP_IR")) { + char timestamped_filename[64]; + auto timestamp = (long long) ggml_time_us(); + snprintf(timestamped_filename, sizeof(timestamped_filename), "model_prefill_%lld.xml", timestamp); + ov::serialize(model_prefill, timestamped_filename); + snprintf(timestamped_filename, sizeof(timestamped_filename), "model_decode_%lld.xml", timestamp); + ov::serialize(model_decode, timestamped_filename); + } + + ov::CompiledModel compiled_model_prefill; + ov::CompiledModel compiled_model_decode; + auto remote_context = ggml_openvino_get_remote_context(); + if (remote_context.has_value()) { + compiled_model_prefill = core.compile_model(model_prefill, remote_context.value(), config); + compiled_model_decode = core.compile_model(model_decode, remote_context.value(), config); + } else { + compiled_model_prefill = core.compile_model(model_prefill, device, config); + compiled_model_decode = core.compile_model(model_decode, device, config); + } + + auto infer_request_prefill = std::make_shared<ov::InferRequest>(compiled_model_prefill.create_infer_request()); + auto infer_request_decode = std::make_shared<ov::InferRequest>(compiled_model_decode.create_infer_request()); + compile_end_time = ggml_time_us(); + + model = is_prefill ? model_prefill : model_decode; + ggml_decoder = is_prefill ? ggml_decoder_prefill : ggml_decoder_decode; + infer_request = is_prefill ? infer_request_prefill : infer_request_decode; + entry->ptr = ggml_decoder; + + std::vector<std::string> ov_input_names; + std::vector<std::string> ov_output_names; + for (const auto & ov_param : model->get_parameters()) { + ov_input_names.push_back(ov_param->get_friendly_name()); + } + for (const auto & ov_output : model->get_results()) { + ov_output_names.push_back(ov_output->get_friendly_name()); + } + + { + std::lock_guard<std::mutex> map_lock(r_ctx->ctx_mutex); + r_ctx->infer_request_cache_prefill[key] = infer_request_prefill; + r_ctx->infer_request_cache[key] = infer_request_decode; + r_ctx->ov_input_names_cache[key] = std::move(ov_input_names); + r_ctx->ov_output_names_cache[key] = std::move(ov_output_names); + } + } + + std::vector<std::string> ov_input_names_local; + std::vector<std::string> ov_output_names_local; + { + std::lock_guard<std::mutex> map_lock(r_ctx->ctx_mutex); + ov_input_names_local = r_ctx->ov_input_names_cache[key]; + ov_output_names_local = r_ctx->ov_output_names_cache[key]; + } + + if (is_prefill) { + auto inp_len = inp_pos->ne[0]; + for (int chunk_index = 0; chunk_index * prefill_chunk_size < inp_len; chunk_index++) { + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; + auto input_tensor = get_ov_input_tensor_static_prefill(ggml_decoder, param_name, chunk_index); + infer_request->set_input_tensor(i, input_tensor); + + if (getenv("GGML_OPENVINO_DEBUG_INPUT")) { + const auto input_tensor = infer_request->get_input_tensor(i); + print_input_tensor_info(param_name, input_tensor); + } + } + + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); + auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); + infer_request->set_output_tensor(i, output_tensor); + } + + infer_request->infer(); + + if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + const auto output_tensor = infer_request->get_output_tensor(i); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); + } + } + } + infer_end_time = ggml_time_us(); + } else { + for (size_t i = 0; i < ov_input_names_local.size(); i++) { + auto param_name = ov_input_names_local[i]; + auto input_tensor = get_ov_input_tensor_static_decode(ggml_decoder, param_name); + infer_request->set_input_tensor(i, input_tensor); + + if (getenv("GGML_OPENVINO_DEBUG_INPUT")) { + const auto input_tensor = infer_request->get_input_tensor(i); + print_input_tensor_info(param_name, input_tensor); + } + } + + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + auto * ggml_tensor = ggml_decoder->get_model_outputs().at(ov_output_names_local[i]); + auto output_tensor = create_ov_output_tensor(ggml_decoder, infer_request, i, ggml_tensor); + infer_request->set_output_tensor(i, output_tensor); + } + + infer_request->infer(); + infer_end_time = ggml_time_us(); + + if (getenv("GGML_OPENVINO_DEBUG_OUTPUT")) { + for (size_t i = 0; i < ov_output_names_local.size(); i++) { + const auto output_tensor = infer_request->get_output_tensor(i); + print_output_tensor_info(ov_output_names_local[i], output_tensor, output_tensor.data()); + } + } + } + + if (getenv("GGML_OPENVINO_PROFILING")) { + GGML_LOG_INFO("\nGGML OpenVINO Backend: \n"); + GGML_LOG_INFO(" - Graph decoder time: %ld ms \n", (decoder_end_time - start_time) / 1000); + if (!cache_hit) { + GGML_LOG_INFO(" - Graph conversion time: %ld ms \n", (conversion_end_time - decoder_end_time) / 1000); + GGML_LOG_INFO(" - Graph compile time: %ld ms \n", (compile_end_time - conversion_end_time) / 1000); + } + GGML_LOG_INFO(" - Graph inference time: %ld ms \n", (infer_end_time - compile_end_time) / 1000); + } + + return GGML_STATUS_SUCCESS; +} + +bool is_naive(ggml_cgraph * cgraph) { + constexpr int naive_graph_size_threshold = 20; + int count = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + if (cgraph->nodes[i]->op != GGML_OP_NONE) { + count++; + } + } + return count < naive_graph_size_threshold; +} + +enum ggml_status naive_compute(ggml_cgraph * cgraph, + ov::Core & core, + const std::string & device, + const ov::AnyMap & config) { + if (cgraph->n_nodes == 1 && (cgraph->nodes[0]->op == GGML_OP_NONE || cgraph->nodes[0]->op == GGML_OP_VIEW)) { + return GGML_STATUS_SUCCESS; + } + + bool naive = true; + auto model_weights = GgmlOvDecoder::create_weight_nodes(cgraph, naive); + auto decoder = std::make_shared<GgmlOvDecoder>(cgraph, model_weights); + auto input_model = std::make_shared<ov::frontend::ggml::InputModel>(decoder); + auto model = ov::frontend::ggml::FrontEnd::convert(input_model, naive); + if (getenv("GGML_OPENVINO_DUMP_IR")) { + ov::serialize(model, "IR_naive.xml"); + } + + std::shared_ptr<ov::InferRequest> infer_request; + auto remote_context = ggml_openvino_get_remote_context(); + if (cgraph->nodes[0]->op == GGML_OP_MUL_MAT) { + // TODO ACCURACY hint triggers a bug in GPU plugin/driver on Lunar Lake. Remove once CVS-182166 is resolved + core.set_property(device, ov::hint::execution_mode(ov::hint::ExecutionMode::PERFORMANCE)); + } else { + core.set_property(device, ov::hint::execution_mode(ov::hint::ExecutionMode::ACCURACY)); + } + if (remote_context.has_value()) { + infer_request = std::make_shared<ov::InferRequest>( + core.compile_model(model, remote_context.value(), config).create_infer_request()); + } else { + infer_request = + std::make_shared<ov::InferRequest>(core.compile_model(model, device, config).create_infer_request()); + } + + auto ov_params = model->get_parameters(); + for (size_t i = 0; i < ov_params.size(); i++) { + auto param_name = ov_params[i]->get_friendly_name(); + auto input_tensor = get_ov_input_tensor(decoder, param_name); + infer_request->set_input_tensor(i, input_tensor); + } + + auto ov_results = model->get_results(); + for (size_t i = 0; i < ov_results.size(); i++) { + auto * ggml_tensor = decoder->get_model_outputs().at(ov_results[i]->get_friendly_name()); + auto output_tensor = create_ov_output_tensor(decoder, infer_request, i, ggml_tensor); + infer_request->set_output_tensor(i, output_tensor); + } + + infer_request->infer(); + return GGML_STATUS_SUCCESS; +} + +namespace { +ov::Tensor convert_ggml_input_to_ov(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & name) { + const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(name); + + if (ggml_tensor->extra != nullptr) { + // GGML_LOG_DEBUG("Using ggml_tensor->extra as ov::Tensor for input: %s\n", name.c_str()); + auto * extra_base = static_cast<ggml_openvino_extra_base *>(ggml_tensor->extra); + if (extra_base->type != ggml_openvino_extra_base::Type::TENSOR) { + throw std::runtime_error("ggml tensor extra is not of type TENSOR for input: " + name); + } + auto * tensor_extra = static_cast<ggml_openvino_tensor_extra *>(extra_base); + return *tensor_extra->tensor; + } + + // GGML_LOG_DEBUG("Converting ggml tensor to ov::Tensor for input: %s\n", name.c_str()); + auto * input_data = ggml_tensor->data; + ov::Shape input_shape; + if (ggml_tensor->op == GGML_OP_VIEW) { + // This case is added to make test-backend-ops work + input_shape = ggml_decoder->get_shape(ggml_tensor->view_src); + } else { + input_shape = ggml_decoder->get_shape(ggml_tensor); + } + auto input_tensor = ov::Tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape, input_data); + return input_tensor; +} +} // namespace + +ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & param_name) { + ov::Tensor input_tensor; + if (ggml_decoder->get_model_extra_inputs().find(param_name) != ggml_decoder->get_model_extra_inputs().end()) { + input_tensor = *ggml_decoder->get_model_extra_input_values().at(param_name); + } else { + input_tensor = convert_ggml_input_to_ov(ggml_decoder, param_name); + } + return input_tensor; +} + +ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder, + const std::string & param_name) { + // NPU decoding stage + const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name); + const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor); + + if (GgmlOvDecoder::is_inp_tok(ggml_tensor, op) || GgmlOvDecoder::is_inp_pos(ggml_tensor, op) || + GgmlOvDecoder::is_kv_idx(ggml_tensor, op)) { + assert(ggml_tensor->ne[0] == 1); + ov::Shape input_shape = {1, 1, 1, 1}; + ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); + if (ggml_tensor->type == GGML_TYPE_I32) { + *input_tensor.data<int32_t>() = *((int32_t *) ggml_tensor->data); + } else if (ggml_tensor->type == GGML_TYPE_I64) { + *input_tensor.data<int64_t>() = *((int64_t *) ggml_tensor->data); + } else { + throw std::runtime_error("Unexpected tensor type for " + param_name); + } + return input_tensor; + } + + if (GgmlOvDecoder::is_output_idx(ggml_tensor, op)) { + ov::Shape input_shape = {1, 1, 1, 1}; + ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); + int32_t inp_out_id = *((int32_t *) ggml_tensor->data); + assert(ggml_tensor->ne[0] == 1); + assert(inp_out_id == 0); + *input_tensor.data<int32_t>() = inp_out_id; + return input_tensor; + } + + if (GgmlOvDecoder::is_inp_mask(ggml_tensor, op)) { + size_t context_size = ggml_decoder->get_ctx_size(); + std::vector<float> padded_data = pad_input<float>(ggml_tensor, 1, context_size, -INFINITY); + ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, 1, context_size}); + auto * data_ptr = input_tensor.data<float>(); + std::copy(padded_data.begin(), padded_data.begin() + context_size, data_ptr); + return input_tensor; + } + + return get_ov_input_tensor(ggml_decoder, param_name); +} + +ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder, + const std::string & param_name, + int chunk_index) { + // NPU prompt processing stage + const auto * ggml_tensor = ggml_decoder->get_input_ggml_tensor(param_name); + const auto * op = ggml_decoder->get_tensor_used_op(ggml_tensor); + + const size_t input_len = ggml_decoder->get_input_len(); + const size_t chunk_size = ggml_decoder->m_prefill_chunk_size; + const size_t chunk_valid_size = std::min(chunk_size, input_len - chunk_index * chunk_size); + const size_t chunk_pad_size = chunk_size - chunk_valid_size; + + if (GgmlOvDecoder::is_inp_tok(ggml_tensor, op) || GgmlOvDecoder::is_inp_pos(ggml_tensor, op) || + GgmlOvDecoder::is_kv_idx(ggml_tensor, op)) { + ov::Shape input_shape = {1, 1, 1, chunk_size}; + ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); + // copy the chunk_index-th chunk from ggml_tensor + size_t element_size = ggml_type_size(ggml_tensor->type); + void * input_data = (char *) ggml_tensor->data + chunk_index * chunk_size * element_size; + std::memcpy(input_tensor.data(), input_data, chunk_valid_size * element_size); + // pad the rest with last_value + 1, so that kv's of padded positions are inserted + // to the next row after the valids row in the kvcache + if (chunk_pad_size > 0) { + if (ggml_tensor->type == GGML_TYPE_I32) { + int32_t last_value = + *((int32_t *) ggml_tensor->data + (chunk_index * chunk_size + chunk_valid_size - 1)); + int32_t * output_data = input_tensor.data<int32_t>(); + std::fill(output_data + chunk_valid_size, output_data + chunk_size, last_value + 1); + } else if (ggml_tensor->type == GGML_TYPE_I64) { + int64_t last_value = + *((int64_t *) ggml_tensor->data + (chunk_index * chunk_size + chunk_valid_size - 1)); + int64_t * output_data = input_tensor.data<int64_t>(); + std::fill(output_data + chunk_valid_size, output_data + chunk_size, last_value + 1); + } else { + throw std::runtime_error("Unexpected tensor type for " + param_name); + } + } + return input_tensor; + } + + if (GgmlOvDecoder::is_output_idx(ggml_tensor, op)) { + size_t output_len = ggml_decoder->get_compute_params().output_len; + ov::Shape input_shape = {1, 1, 1, output_len}; + ov::Tensor input_tensor(ggml_decoder->get_ov_type(ggml_tensor), input_shape); + if (ggml_tensor->ne[0] == 0) { + *input_tensor.data<int32_t>() = 0; + } else { + auto * data_addr = input_tensor.data<int32_t>(); + for (size_t i = 0; i < output_len; i++) { + data_addr[i] = ((int32_t *) ggml_tensor->data)[i] % chunk_size; + } + } + return input_tensor; + } + + if (GgmlOvDecoder::is_inp_mask(ggml_tensor, op)) { + size_t cols = ggml_tensor->ne[0]; + size_t rows = ggml_tensor->ne[1]; + float * ggml_data = (float *) ggml_tensor->data + chunk_index * chunk_size * cols; + size_t chunk_valid_rows = std::min(chunk_size, rows - chunk_index * chunk_size); + size_t context_size = ggml_decoder->get_ctx_size(); + std::vector<float> padded_data = + pad_input<float>(ggml_data, chunk_valid_rows, cols, chunk_size, context_size, -INFINITY); + set_zero_diagonal(padded_data, chunk_size, context_size); + ov::Tensor input_tensor(ov::element::f32, ov::Shape{1, 1, chunk_size, context_size}); + auto * data_ptr = input_tensor.data<float>(); + std::copy(padded_data.begin(), padded_data.begin() + chunk_size * context_size, data_ptr); + return input_tensor; + } + + return get_ov_input_tensor(ggml_decoder, param_name); +} + +size_t checksum(const void * data, size_t size) { + const uint8_t * bytes = static_cast<const uint8_t *>(data); + size_t sum = 0; + for (size_t i = 0; i < size; ++i) { + sum += (uint8_t) i; + sum += bytes[i]; + } + return sum; +} + +void print_input_tensor_info(const std::string & name, const ov::Tensor & tensor) { + std::cout << "Input name: " << name << ", Input shape: " << tensor.get_shape() << ", Address: " << tensor.data() + << std::endl; + switch (tensor.get_element_type()) { + case ov::element::f32: { + if (name.find("self_kq_mask") == std::string::npos) { + std::cout << *(tensor.data<float>()) << std::endl; + } else { + size_t rows = tensor.get_shape()[2]; + size_t cols = tensor.get_shape()[3]; + auto * data = tensor.data<float>(); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + float val = data[i * cols + j]; + if (std::isinf(val) && val < 0) { + std::cout << std::setw(5) << "-inf"; + } else { + std::cout << std::setw(5) << val; + } + } + std::cout << std::endl; + } + } + + break; + } + case ov::element::f16: + std::cout << *(tensor.data<ov::float16>()) << std::endl; + break; + case ov::element::i32: + for (size_t i = 0; i < tensor.get_size(); ++i) { + std::cout << tensor.data<int32_t>()[i] << " "; + } + std::cout << std::endl; + break; + case ov::element::i64: + for (size_t i = 0; i < tensor.get_size(); ++i) { + std::cout << tensor.data<int64_t>()[i] << " "; + } + std::cout << std::endl; + break; + default: + break; + } +} + +void print_output_tensor_info(const std::string & name, const ov::Tensor & tensor, const void * output_dst) { + std::cout << "Output name: " << name << ", Output shape: " << tensor.get_shape() << ", Address: " << output_dst + << std::endl; + + auto print_float_stats = [](const std::string & type_name, size_t size, auto get_value) { + if (size == 0) { + return; + } + + float first = get_value(0); + float min = first; + float max = first; + double sum = first; + + for (size_t i = 1; i < size; ++i) { + float v = get_value(i); + if (v < min) { + min = v; + } + if (v > max) { + max = v; + } + sum += v; + } + double mean = sum / size; + + std::cout << std::right << std::setw(6) << type_name << std::right << std::setw(12) << "First" << std::setw(12) + << "Min" << std::setw(12) << "Max" << std::setw(12) << "Mean" << std::endl; + std::cout << std::right << std::setw(6) << "" << std::right << std::setw(12) << first << std::setw(12) << min + << std::setw(12) << max << std::setw(12) << mean << std::endl; + }; + + switch (tensor.get_element_type()) { + case ov::element::f32: { + const float * data = tensor.data<float>(); + size_t size = tensor.get_size(); + print_float_stats("[f32]", size, [data](size_t i) { return data[i]; }); + break; + } + case ov::element::f16: { + const ov::float16 * data = tensor.data<ov::float16>(); + size_t size = tensor.get_size(); + print_float_stats("[f16]", size, [data](size_t i) { return static_cast<float>(data[i]); }); + break; + } + default: + break; + } +} + +void set_zero_diagonal(std::vector<float> & matrix, size_t rows, size_t cols) { + for (size_t i = 0; i < rows; ++i) { + size_t diag_col = std::min(i, cols - 1); + matrix[i * cols + diag_col] = 0.0f; + } +} + +const ggml_tensor * get_inp_pos_tensor(ggml_cgraph * cgraph) { + for (int i = 0; i < cgraph->n_nodes; ++i) { + auto * op = cgraph->nodes[i]; + for (int j = 0; j < GGML_MAX_SRC; ++j) { + auto * src = op->src[j]; + if (src == nullptr) { + break; + } + if (GgmlOvDecoder::is_inp_pos(src, op)) { + return src; + } + } + } + GGML_LOG_ERROR("get_inp_pos_tensor: inp_pos not found in cgraph"); + throw std::runtime_error("get_inp_pos_tensor: inp_pos not found in cgraph"); +} + +bool get_is_prefill(const ggml_tensor * inp_pos) { + return inp_pos->ne[0] > 1; +} + +#pragma GCC diagnostic pop diff --git a/ggml/src/ggml-openvino/utils.h b/ggml/src/ggml-openvino/utils.h new file mode 100644 index 00000000000..2c72e33c352 --- /dev/null +++ b/ggml/src/ggml-openvino/utils.h @@ -0,0 +1,143 @@ +#include "ggml-backend-impl.h" +#include "ggml-decoder.h" +#include "ggml-impl.h" + +#include <algorithm> +#include <atomic> +#include <cstddef> +#include <memory> +#include <mutex> +#include <openvino/runtime/core.hpp> +#include <openvino/runtime/infer_request.hpp> +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +struct graph_key { + int n_nodes; + std::string first_node_name; + std::string last_node_name; + + graph_key(const ggml_cgraph * cgraph) : n_nodes(cgraph->n_nodes) { + if (n_nodes > 0) { + first_node_name = cgraph->nodes[0]->name; + last_node_name = cgraph->nodes[n_nodes - 1]->name; + } + } + + bool operator==(const graph_key & other) const { + return n_nodes == other.n_nodes && first_node_name == other.first_node_name && + last_node_name == other.last_node_name; + } +}; + +struct graph_key_hash { + size_t operator()(const graph_key & key) const { + size_t h = std::hash<int>{}(key.n_nodes); + if (key.n_nodes > 0) { + h ^= std::hash<std::string>{}(key.first_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2); + h ^= std::hash<std::string>{}(key.last_node_name) + 0x9e3779b9 + (h << 6) + (h >> 2); + } + return h; + } +}; + +struct decoder_runtime_ctx { + decoder_runtime_ctx(std::shared_ptr<std::mutex> mutex) : mutex(std::move(mutex)) {} + std::shared_ptr<std::mutex> mutex; + std::shared_ptr<GgmlOvDecoder> ptr; +}; + +struct ov_runtime_context { + mutable std::mutex ctx_mutex; + std::string device; + bool stateful; + std::unordered_map<graph_key, std::shared_ptr<decoder_runtime_ctx>, graph_key_hash> decoder_cache; + std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache; + std::unordered_map<graph_key, std::shared_ptr<ov::InferRequest>, graph_key_hash> infer_request_cache_prefill; + std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_input_names_cache; + std::unordered_map<graph_key, std::vector<std::string>, graph_key_hash> ov_output_names_cache; + //TODO: Stateful is only supported for single request at a time. + // Simultanous stateful inference request support to be added. + size_t stateful_kv_size; + std::map<std::string, std::string> kv_state_input_name_map; + std::atomic<int> backend_count; + + ov_runtime_context() : + device("CPU"), + stateful(false), + stateful_kv_size(0), + backend_count(0) {} + + void clear_caches() { + std::lock_guard<std::mutex> lock(ctx_mutex); + decoder_cache.clear(); + infer_request_cache.clear(); + infer_request_cache_prefill.clear(); + ov_input_names_cache.clear(); + ov_output_names_cache.clear(); + } +}; + +enum ggml_status ov_graph_compute(struct ggml_cgraph * cgraph, ggml_backend_t backend); + +enum ggml_status ov_graph_compute_dynamic(struct ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx); +enum ggml_status ov_graph_compute_static(struct ggml_cgraph * cgraph, std::shared_ptr<ov_runtime_context> r_ctx); + +size_t checksum(const void * data, size_t size); + +void print_input_tensor_info(const std::string & name, const ov::Tensor & tensor); + +void print_output_tensor_info(const std::string & name, const ov::Tensor & tensor, const void * output_dst); + +template <typename T> +std::vector<T> pad_input(const T * data, + size_t rows, + size_t cols, + size_t padded_rows, + size_t padded_cols, + T pad_value) { + std::vector<T> padded(padded_rows * padded_cols, pad_value); + + for (size_t i = 0; i < std::min(rows, padded_rows); ++i) { + for (size_t j = 0; j < std::min(cols, padded_cols); ++j) { + padded[i * padded_cols + j] = data[i * cols + j]; + } + } + + return padded; +} + +template <typename T> +std::vector<T> pad_input(const ggml_tensor * tensor, size_t padded_rows, size_t padded_cols, T pad_value) { + return pad_input<T>(reinterpret_cast<const T *>(tensor->data), + static_cast<size_t>(tensor->ne[1]), // rows + static_cast<size_t>(tensor->ne[0]), // cols + padded_rows, padded_cols, pad_value); +} + +void set_zero_diagonal(std::vector<float> & matrix, size_t rows, size_t cols); + +const ggml_tensor * get_inp_pos_tensor(struct ggml_cgraph * cgraph); + +bool get_is_prefill(const ggml_tensor * inp_pos); + +ov::Tensor get_ov_input_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, const std::string & param_name); +ov::Tensor get_ov_input_tensor_static_decode(std::shared_ptr<GgmlOvDecoder> ggml_decoder, + const std::string & param_name); +ov::Tensor get_ov_input_tensor_static_prefill(std::shared_ptr<GgmlOvDecoder> ggml_decoder, + const std::string & param_name, + int chunk_index); + +ov::Tensor create_ov_output_tensor(std::shared_ptr<GgmlOvDecoder> ggml_decoder, + std::shared_ptr<ov::InferRequest> infer_request, + int output_index, + const ggml_tensor * ggml_tensor); + +bool is_naive(struct ggml_cgraph * cgraph); + +enum ggml_status naive_compute(struct ggml_cgraph * cgraph, + ov::Core & core, + const std::string & device, + const ov::AnyMap & config); diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp index e078ad14a39..53903defa8f 100644 --- a/ggml/src/ggml-opt.cpp +++ b/ggml/src/ggml-opt.cpp @@ -589,6 +589,7 @@ void ggml_opt_free(ggml_opt_context_t opt_ctx) { ggml_backend_buffer_free(opt_ctx->buf_cpu); ggml_free(opt_ctx->ctx_static); ggml_free(opt_ctx->ctx_cpu); + ggml_free(opt_ctx->ctx_copy); delete opt_ctx; } diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index de5cbd75e86..15d231f70c0 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -13,6 +13,10 @@ #include <stdlib.h> // for qsort #include <stdio.h> // for GGML_ASSERT +#ifdef GGML_USE_OPENMP +#include <omp.h> +#endif + #define GROUP_MAX_EPS 1e-15f #define GROUP_MAX_EPS_IQ3_XXS 1e-8f #define GROUP_MAX_EPS_IQ2_S 1e-8f @@ -32,6 +36,41 @@ static inline int best_index_int8(int n, const int8_t * val, float x) { return x - val[mu-1] < val[mu] - x ? mu-1 : mu; } +// reference implementation for deterministic creation of model files +void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK1_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float sum_abs = 0.0f; + for (int j = 0; j < qk; j++) { + sum_abs += fabsf(x[i*qk + j]); + } + const float d = sum_abs / qk; + + y[i].d = GGML_FP32_TO_FP16(d); + + // Clear all bits first + for (int j = 0; j < qk / 8; ++j) { + y[i].qs[j] = 0; + } + + // Just store sign of each weight directly (no normalization) + for (int j = 0; j < qk; ++j) { + const int bit_index = j; + const int byte_index = bit_index / 8; + const int bit_offset = bit_index % 8; + + if (x[i*qk + j] >= 0.0f) { + y[i].qs[byte_index] |= (1 << bit_offset); + } + } + } +} + // reference implementation for deterministic creation of model files void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -304,6 +343,61 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE } } +void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k) { + static const int qk = QK_NVFP4; + static const int qk_sub = QK_NVFP4_SUB; + static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + for (int s = 0; s < n_sub; s++) { + const float * xb = x + i*qk + s*qk_sub; + + float amax = 0.0f; + for (int j = 0; j < qk_sub; j++) { + if (amax < fabsf(xb[j])) { + amax = fabsf(xb[j]); + } + } + + // UE4M3 scale: amax / 6.0 maps the max E2M1 value (6.0) to amax + const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f); + y[i].d[s] = ue; + const float d = ggml_ue4m3_to_fp32(ue); + + for (int j = 0; j < qk_sub/2; ++j) { + const uint8_t x0 = best_index_mxfp4(xb[0 + j], d); + const uint8_t x1 = best_index_mxfp4(xb[qk_sub/2 + j], d); + + y[i].qs[s*(qk_sub/2) + j] = x0 | (x1 << 4); + } + } + } +} + +void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK1_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_FP16_TO_FP32(x[i].d); + const float neg_d = -d; + + for (int j = 0; j < qk; ++j) { + const int byte_index = j / 8; + const int bit_offset = j % 8; + const uint8_t bit = (x[i].qs[byte_index] >> bit_offset) & 1; + y[i*qk + j] = bit ? d : neg_d; + } + } +} + void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { static const int qk = QK4_0; @@ -434,6 +528,31 @@ void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_REST } } +void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + static const int qk = QK_NVFP4; + static const int qk_sub = QK_NVFP4_SUB; + static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + for (int s = 0; s < n_sub; s++) { + const float d = ggml_ue4m3_to_fp32(x[i].d[s]); + float * yb = y + i*qk + s*qk_sub; + + for (int j = 0; j < qk_sub/2; ++j) { + const int8_t v0 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] & 0x0F]; + const int8_t v1 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] >> 4]; + + yb[j + 0 ] = v0*d; + yb[j + qk_sub/2] = v1*d; + } + } + } +} + // // 2-6 bit quantization in super-blocks // @@ -1918,6 +2037,22 @@ static void quantize_row_q4_0_impl(const float * GGML_RESTRICT x, block_q4_0 * G } } +size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q1_0_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q1_0, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q1_0, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + quantize_row_q1_0_ref(src, (block_q1_0*)qrow, n_per_row); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + + size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); @@ -2098,6 +2233,12 @@ size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row); } +size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + GGML_UNUSED(quant_weights); + quantize_row_nvfp4_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row); +} + // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) { @@ -2927,70 +3068,121 @@ void iq2xs_init_impl(enum ggml_type type) { } kmap_q2xs[index] = i; } - int8_t pos[8]; - int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + // The neighbour search runs in three passes: + // 1. Parallel: for each i, qsort and count its neighbours into n_per_i, + // and reduce the totals (num_neighbors, num_not_in_map). + // 2. Serial: prefix-sum n_per_i into offsets[], so each i has a + // pre-assigned slice of kneighbors_q2xs to write into. + // 3. Parallel: redo the qsort and write each i's neighbour list at + // offsets[i]. + int * n_per_i = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(n_per_i); int num_neighbors = 0, num_not_in_map = 0; - for (int i = 0; i < kmap_size; ++i) { - if (kmap_q2xs[i] >= 0) continue; - ++num_not_in_map; - for (int k = 0; k < 8; ++k) { - int l = (i >> 2*k) & 0x3; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); - int d2 = 0; - for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); - int n = 0; int d2 = dist2[0]; - int nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - ++n; - } - num_neighbors += n; +#ifdef GGML_USE_OPENMP + #pragma omp parallel reduction(+:num_neighbors,num_not_in_map) +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[8]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) { + n_per_i[i] = 0; + continue; + } + ++num_not_in_map; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + n_per_i[i] = n; + num_neighbors += n; + } + free(dist2); } //printf("%s: %d neighbours in total\n", __func__, num_neighbors); kneighbors_q2xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); iq2_data[gindex].neighbours = kneighbors_q2xs; + + int * offsets = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(offsets); int counter = 0; for (int i = 0; i < kmap_size; ++i) { - if (kmap_q2xs[i] >= 0) continue; - for (int k = 0; k < 8; ++k) { - int l = (i >> 2*k) & 0x3; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); - int d2 = 0; - for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); - kmap_q2xs[i] = -(counter + 1); - int d2 = dist2[0]; - uint16_t * start = &kneighbors_q2xs[counter++]; - int n = 0, nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - kneighbors_q2xs[counter++] = dist2[2*j+1]; - ++n; - } - *start = n; - } - free(dist2); + if (kmap_q2xs[i] >= 0) { + offsets[i] = -1; + continue; + } + offsets[i] = counter; + counter += 1 + n_per_i[i]; + } + +#ifdef GGML_USE_OPENMP + #pragma omp parallel +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[8]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q2xs[i] >= 0) continue; + for (int k = 0; k < 8; ++k) { + int l = (i >> 2*k) & 0x3; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q2xs + j); + int d2 = 0; + for (int k = 0; k < 8; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq2_compare_func); + int local_counter = offsets[i]; + kmap_q2xs[i] = -(local_counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q2xs[local_counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q2xs[local_counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); + } + free(offsets); + free(n_per_i); } void iq2xs_free_impl(enum ggml_type type) { @@ -3104,6 +3296,11 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML } float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight); float eff_max = scale*kMaxQ; + if (eff_max <= 0) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } float best = 0; for (int is = -6; is <= 6; ++is) { float id = (2*kMaxQ-1+is*0.1f)/eff_max; @@ -3273,9 +3470,9 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_ } float max = xval[0]; for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + memset(L, 0, 16); if (max < GROUP_MAX_EPS) { scales[ib] = 0; - memset(L, 0, 16); continue; } float best = 0; @@ -3521,70 +3718,115 @@ void iq3xs_init_impl(int grid_size) { } kmap_q3xs[index] = i; } - int8_t pos[4]; - int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + // See explanation of parallelism in iq2xs_init_impl + int * n_per_i = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(n_per_i); int num_neighbors = 0, num_not_in_map = 0; - for (int i = 0; i < kmap_size; ++i) { - if (kmap_q3xs[i] >= 0) continue; - ++num_not_in_map; - for (int k = 0; k < 4; ++k) { - int l = (i >> 3*k) & 0x7; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); - int d2 = 0; - for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); - int n = 0; int d2 = dist2[0]; - int nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - ++n; - } - num_neighbors += n; +#ifdef GGML_USE_OPENMP + #pragma omp parallel reduction(+:num_neighbors,num_not_in_map) +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[4]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) { + n_per_i[i] = 0; + continue; + } + ++num_not_in_map; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int n = 0; int d2 = dist2[0]; + int nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + ++n; + } + n_per_i[i] = n; + num_neighbors += n; + } + free(dist2); } //printf("%s: %d neighbours in total\n", __func__, num_neighbors); kneighbors_q3xs = (uint16_t *)malloc((num_neighbors + num_not_in_map)*sizeof(uint16_t)); iq3_data[gindex].neighbours = kneighbors_q3xs; + + int * offsets = (int *)malloc(kmap_size*sizeof(int)); + GGML_ASSERT(offsets); int counter = 0; for (int i = 0; i < kmap_size; ++i) { - if (kmap_q3xs[i] >= 0) continue; - for (int k = 0; k < 4; ++k) { - int l = (i >> 3*k) & 0x7; - pos[k] = 2*l + 1; - } - for (int j = 0; j < grid_size; ++j) { - const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); - int d2 = 0; - for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); - dist2[2*j+0] = d2; - dist2[2*j+1] = j; - } - qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); - kmap_q3xs[i] = -(counter + 1); - int d2 = dist2[0]; - uint16_t * start = &kneighbors_q3xs[counter++]; - int n = 0, nhave = 1; - for (int j = 0; j < grid_size; ++j) { - if (dist2[2*j] > d2) { - if (nhave == nwant) break; - d2 = dist2[2*j]; - ++nhave; - } - kneighbors_q3xs[counter++] = dist2[2*j+1]; - ++n; - } - *start = n; - } - free(dist2); + if (kmap_q3xs[i] >= 0) { + offsets[i] = -1; + continue; + } + offsets[i] = counter; + counter += 1 + n_per_i[i]; + } + +#ifdef GGML_USE_OPENMP + #pragma omp parallel +#endif + { + int * dist2 = (int *)malloc(2*grid_size*sizeof(int)); + GGML_ASSERT(dist2); + int8_t pos[4]; + int i; +#ifdef GGML_USE_OPENMP + #pragma omp for schedule(dynamic, 64) +#endif + for (i = 0; i < kmap_size; ++i) { + if (kmap_q3xs[i] >= 0) continue; + for (int k = 0; k < 4; ++k) { + int l = (i >> 3*k) & 0x7; + pos[k] = 2*l + 1; + } + for (int j = 0; j < grid_size; ++j) { + const int8_t * pg = (const int8_t *)(kgrid_q3xs + j); + int d2 = 0; + for (int k = 0; k < 4; ++k) d2 += (pg[k] - pos[k])*(pg[k] - pos[k]); + dist2[2*j+0] = d2; + dist2[2*j+1] = j; + } + qsort(dist2, grid_size, 2*sizeof(int), iq3_compare_func); + int local_counter = offsets[i]; + kmap_q3xs[i] = -(local_counter + 1); + int d2 = dist2[0]; + uint16_t * start = &kneighbors_q3xs[local_counter++]; + int n = 0, nhave = 1; + for (int j = 0; j < grid_size; ++j) { + if (dist2[2*j] > d2) { + if (nhave == nwant) break; + d2 = dist2[2*j]; + ++nhave; + } + kneighbors_q3xs[local_counter++] = dist2[2*j+1]; + ++n; + } + *start = n; + } + free(dist2); + } + free(offsets); + free(n_per_i); } void iq3xs_free_impl(int grid_size) { @@ -3714,9 +3956,9 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT } float max = xval[0]; for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + memset(L, 0, 32); if (max < GROUP_MAX_EPS_IQ3_XXS) { scales[ib] = 0; - memset(L, 0, 32); continue; } float best = 0; @@ -3922,6 +4164,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * GGML_RESTRICT } float max = xval[0]; for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]); + memset(L, 0, block_size); if (!max) { scales[ib] = 0; continue; @@ -4245,6 +4488,7 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); if (max < GROUP_MAX_EPS_IQ1_S) { scales[ib] = 0; + shifts[ib] = 1; memset(L, 1, block_size); continue; } @@ -4285,7 +4529,12 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R } } } - GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); + if (besti1 < 0 || besti2 < 0 || best_shift == 0) { + scales[ib] = 0; + shifts[ib] = 1; + memset(L, 1, block_size); + continue; + } for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; @@ -4429,6 +4678,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); if (max < GROUP_MAX_EPS_IQ1_M) { scales[ib] = 0; + shifts[ib] = 0; memset(L, 1, block_size); continue; } @@ -4527,7 +4777,12 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R } } } - GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0); + if (besti1 < 0 || besti2 < 0 || best_k < 0) { + scales[ib] = 0; + shifts[ib] = 0; + memset(L, 1, block_size); + continue; + } for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; @@ -4683,7 +4938,7 @@ static void quantize_row_iq4_nl_impl(const int super_block_size, const int block sumqx += w*q*xb[j]; sumq2 += w*q*q; } - d = sumqx/sumq2; + d = sumq2 > 0 ? sumqx/sumq2 : 0.f; float best = d*sumqx; for (int itry = -ntry; itry <= ntry; ++itry) { id = (itry + values[0])/max; @@ -4874,6 +5129,7 @@ static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_R } float max = xval[0]; for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + memset(L, 0, 16); if (max < GROUP_MAX_EPS_IQ2_S) { scales[ib] = 0; continue; @@ -5201,6 +5457,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte } } } break; + case GGML_TYPE_Q1_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_q1_0, data, nb); + } break; case GGML_TYPE_Q4_0: { VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); @@ -5225,6 +5485,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb); } break; + case GGML_TYPE_NVFP4: + { + // UE4M3 scales are uint8_t — all byte values are valid + GGML_UNUSED(data); + GGML_UNUSED(nb); + } break; case GGML_TYPE_Q2_K: { VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 3b688f31c21..d56c86da890 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -14,6 +14,7 @@ extern "C" { // NOTE: these functions are defined as GGML_API because they used by the CPU backend // Quantization +GGML_API void quantize_row_q1_0_ref(const float * GGML_RESTRICT x, block_q1_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); @@ -22,6 +23,7 @@ GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k); @@ -40,6 +42,7 @@ GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_ GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); // Dequantization +GGML_API void dequantize_row_q1_0(const block_q1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -48,6 +51,7 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -88,6 +92,7 @@ GGML_API size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_q1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); @@ -95,6 +100,7 @@ GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API void iq2xs_init_impl(enum ggml_type type); GGML_API void iq2xs_free_impl(enum ggml_type type); diff --git a/ggml/src/ggml-rpc/CMakeLists.txt b/ggml/src/ggml-rpc/CMakeLists.txt index f5acb8ec2cb..40e11fead63 100644 --- a/ggml/src/ggml-rpc/CMakeLists.txt +++ b/ggml/src/ggml-rpc/CMakeLists.txt @@ -2,8 +2,32 @@ message(STATUS "Using RPC backend") ggml_add_backend_library(ggml-rpc ggml-rpc.cpp + transport.cpp ) if (WIN32) target_link_libraries(ggml-rpc PRIVATE ws2_32) endif() + +# RDMA auto-detection (Linux only, requires libibverbs) +if (NOT WIN32 AND NOT APPLE) + find_library(IBVERBS_LIB ibverbs) + if (IBVERBS_LIB) + option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" ON) + else() + option(GGML_RPC_RDMA "ggml: enable RDMA transport for RPC" OFF) + endif() +else() + set(GGML_RPC_RDMA OFF CACHE BOOL "RDMA not available on this platform" FORCE) +endif() + +if (GGML_RPC_RDMA) + if (NOT IBVERBS_LIB) + find_library(IBVERBS_LIB ibverbs REQUIRED) + endif() + target_compile_definitions(ggml-rpc PRIVATE GGML_RPC_RDMA) + target_link_libraries(ggml-rpc PRIVATE ${IBVERBS_LIB}) + message(STATUS " RDMA transport enabled (auto-detected)") +else() + message(STATUS " RDMA transport disabled") +endif() diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index d7c8ad8c168..d3805772183 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -2,30 +2,17 @@ #include "ggml-impl.h" #include "ggml-backend-impl.h" #include "ggml-cpp.h" +#include "transport.h" +#include <array> #include <cinttypes> +#include <optional> #include <string> #include <vector> #include <memory> #include <mutex> #include <unordered_map> #include <unordered_set> -#ifdef _WIN32 -# define WIN32_LEAN_AND_MEAN -# ifndef NOMINMAX -# define NOMINMAX -# endif -# include <windows.h> -# include <winsock2.h> -#else -# include <arpa/inet.h> -# include <sys/socket.h> -# include <sys/types.h> -# include <netinet/in.h> -# include <netinet/tcp.h> -# include <netdb.h> -# include <unistd.h> -#endif #include <cstring> #include <fstream> #include <filesystem> @@ -39,29 +26,6 @@ static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); namespace fs = std::filesystem; -static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB - -#ifdef _WIN32 -typedef SOCKET sockfd_t; -using ssize_t = __int64; -#else -typedef int sockfd_t; -#endif - -// cross-platform socket -struct socket_t { - sockfd_t fd; - socket_t(sockfd_t fd) : fd(fd) {} - ~socket_t() { - LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); -#ifdef _WIN32 - closesocket(this->fd); -#else - close(this->fd); -#endif - } -}; - // macro for nicer error messages on server crash #define RPC_STATUS_ASSERT(x) if (!(x)) GGML_ABORT("Remote RPC server crashed or returned malformed response") @@ -115,10 +79,16 @@ static_assert(RPC_CMD_HELLO == 14, "RPC_CMD_HELLO must be always 14"); // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold const size_t HASH_THRESHOLD = 10 * 1024 * 1024; +struct rpc_msg_hello_req { + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; +}; + struct rpc_msg_hello_rsp { uint8_t major; uint8_t minor; uint8_t patch; + uint8_t padding; + uint8_t conn_caps[RPC_CONN_CAPS_SIZE]; }; struct rpc_msg_device_count_rsp { @@ -229,6 +199,14 @@ static ggml_guid_t ggml_backend_rpc_guid() { return &guid; } +struct ggml_backend_rpc_device_context { + std::string endpoint; + uint32_t device; + std::string name; + std::string description; + uint64_t last_graph_uid; +}; + struct ggml_backend_rpc_buffer_type_context { std::string endpoint; uint32_t device; @@ -237,35 +215,10 @@ struct ggml_backend_rpc_buffer_type_context { size_t max_size; }; -struct graph_cache { - - bool is_cached(const ggml_cgraph * cgraph) { - if ((int)last_graph.size() != cgraph->n_nodes) { - return false; - } - for (int i = 0; i < cgraph->n_nodes; i++) { - if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) { - return false; - } - } - return true; - } - - void add(const ggml_cgraph * cgraph) { - last_graph.resize(cgraph->n_nodes); - for (int i = 0; i < cgraph->n_nodes; i++) { - memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)); - } - } - - std::vector<ggml_tensor> last_graph; -}; - struct ggml_backend_rpc_context { std::string endpoint; uint32_t device; std::string name; - graph_cache gc; }; struct ggml_backend_rpc_buffer_context { @@ -288,153 +241,27 @@ static uint64_t fnv_hash(const uint8_t * data, size_t len) { return hash; } -static std::shared_ptr<socket_t> make_socket(sockfd_t fd) { -#ifdef _WIN32 - if (fd == INVALID_SOCKET) { - return nullptr; - } -#else - if (fd < 0) { - return nullptr; - } -#endif - return std::make_shared<socket_t>(fd); -} - -static bool set_no_delay(sockfd_t sockfd) { - int flag = 1; - // set TCP_NODELAY to disable Nagle's algorithm - int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static bool set_reuse_addr(sockfd_t sockfd) { - int flag = 1; - int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); - return ret == 0; -} - -static std::shared_ptr<socket_t> socket_connect(const char * host, int port) { - struct sockaddr_in addr; - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock_ptr = make_socket(sockfd); - if (sock_ptr == nullptr) { - return nullptr; - } - if (!set_no_delay(sockfd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - addr.sin_family = AF_INET; - addr.sin_port = htons(port); - struct hostent * server = gethostbyname(host); - if (server == NULL) { - GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); - return nullptr; - } - memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); - if (connect(sock_ptr->fd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { - return nullptr; - } - return sock_ptr; -} - -static std::shared_ptr<socket_t> socket_accept(sockfd_t srv_sockfd) { - auto client_socket_fd = accept(srv_sockfd, NULL, NULL); - auto client_socket = make_socket(client_socket_fd); - if (client_socket == nullptr) { - return nullptr; - } - if (!set_no_delay(client_socket_fd)) { - GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); - return nullptr; - } - return client_socket; -} - -static std::shared_ptr<socket_t> create_server_socket(const char * host, int port) { - auto sockfd = socket(AF_INET, SOCK_STREAM, 0); - auto sock = make_socket(sockfd); - if (sock == nullptr) { - return nullptr; - } - if (!set_reuse_addr(sockfd)) { - GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); - return nullptr; - } - if (inet_addr(host) == INADDR_NONE) { - GGML_LOG_ERROR("Invalid host address: %s\n", host); - return nullptr; - } - struct sockaddr_in serv_addr; - serv_addr.sin_family = AF_INET; - serv_addr.sin_addr.s_addr = inet_addr(host); - serv_addr.sin_port = htons(port); - - if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { - return nullptr; - } - if (listen(sockfd, 1) < 0) { - return nullptr; - } - return sock; -} - -static bool send_data(sockfd_t sockfd, const void * data, size_t size) { - size_t bytes_sent = 0; - while (bytes_sent < size) { - size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); - ssize_t n = send(sockfd, (const char *)data + bytes_sent, size_to_send, 0); - if (n < 0) { - GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", - bytes_sent, size_to_send); - return false; - } - bytes_sent += (size_t)n; - } - return true; -} - -static bool recv_data(sockfd_t sockfd, void * data, size_t size) { - size_t bytes_recv = 0; - while (bytes_recv < size) { - size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); - ssize_t n = recv(sockfd, (char *)data + bytes_recv, size_to_recv, 0); - if (n < 0) { - GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", - bytes_recv, size_to_recv); - return false; - } - if (n == 0) { - LOG_DBG("recv returned 0 (peer closed?)\n"); - return false; - } - bytes_recv += (size_t)n; - } - return true; -} - -static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) { - if (!send_data(sockfd, &msg_size, sizeof(msg_size))) { +static bool send_msg(socket_ptr sock, const void * msg, size_t msg_size) { + if (!sock->send_data(&msg_size, sizeof(msg_size))) { return false; } - return send_data(sockfd, msg, msg_size); + return sock->send_data(msg, msg_size); } -static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) { +static bool recv_msg(socket_ptr sock, void * msg, size_t msg_size) { uint64_t size; - if (!recv_data(sockfd, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } if (size != msg_size) { return false; } - return recv_data(sockfd, msg, msg_size); + return sock->recv_data(msg, msg_size); } -static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) { +static bool recv_msg(socket_ptr sock, std::vector<uint8_t> & input) { uint64_t size; - if (!recv_data(sockfd, &size, sizeof(size))) { + if (!sock->recv_data(&size, sizeof(size))) { return false; } try { @@ -443,7 +270,7 @@ static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) { GGML_LOG_ERROR("Failed to allocate input buffer of size %" PRIu64 "\n", size); return false; } - return recv_data(sockfd, input.data(), size); + return sock->recv_data(input.data(), size); } static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) { @@ -452,21 +279,25 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int return false; } host = endpoint.substr(0, pos); - port = std::stoi(endpoint.substr(pos + 1)); + try { + port = std::stoi(endpoint.substr(pos + 1)); + } catch (...) { + return false; + } return true; } // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // No response -static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size) { uint8_t cmd_byte = cmd; - if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) { + if (!sock->send_data(&cmd_byte, sizeof(cmd_byte))) { return false; } - if (!send_data(sock->fd, &input_size, sizeof(input_size))) { + if (!sock->send_data(&input_size, sizeof(input_size))) { return false; } - if (!send_data(sock->fd, input, input_size)) { + if (!sock->send_data(input, input_size)) { return false; } return true; @@ -474,20 +305,18 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) | // RPC response: | response_size (8 bytes) | response_data (response_size bytes) | -static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { +static bool send_rpc_cmd(socket_ptr sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) { if (!send_rpc_cmd(sock, cmd, input, input_size)) { return false; } - // TODO: currently the output_size is always known, do we need support for commands with variable output size? - // even if we do, we can skip sending output_size from the server for commands with known output size uint64_t out_size; - if (!recv_data(sock->fd, &out_size, sizeof(out_size))) { + if (!sock->recv_data(&out_size, sizeof(out_size))) { return false; } if (out_size != output_size) { return false; } - if (!recv_data(sock->fd, output, output_size)) { + if (!sock->recv_data(output, output_size)) { return false; } return true; @@ -495,17 +324,25 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm // RPC client-side implementation -static bool check_server_version(const std::shared_ptr<socket_t> & sock) { - rpc_msg_hello_rsp response; - bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response)); +// Performs HELLO handshake with transport auto-negotiation. +// Advertises local capabilities via conn_caps; if the server responds with +// matching capabilities, the socket is upgraded transparently. +static bool negotiate_hello(const std::shared_ptr<socket_t> & sock) { + rpc_msg_hello_req request = {}; + rpc_msg_hello_rsp response = {}; + + sock->get_caps(request.conn_caps); + + bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, &request, sizeof(request), &response, sizeof(response)); RPC_STATUS_ASSERT(status); + if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { - GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + GGML_LOG_ERROR("RPC server version mismatch: %d.%d.%d\n", + response.major, response.minor, response.patch); return false; } - if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) { - GGML_LOG_INFO("WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); - } + + sock->update_caps(response.conn_caps); return true; } @@ -513,7 +350,6 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) { static std::mutex mutex; std::lock_guard<std::mutex> lock(mutex); static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets; - static bool initialized = false; auto it = sockets.find(endpoint); if (it != sockets.end()) { @@ -527,26 +363,18 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) { GGML_LOG_ERROR("Failed to parse endpoint: %s\n", endpoint.c_str()); return nullptr; } -#ifdef _WIN32 - if (!initialized) { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - return nullptr; - } - initialized = true; + + if (!rpc_transport_init()) { + return nullptr; } -#else - GGML_UNUSED(initialized); -#endif - auto sock = socket_connect(host.c_str(), port); + auto sock = socket_t::connect(host.c_str(), port); if (sock == nullptr) { return nullptr; } - if (!check_server_version(sock)) { + if (!negotiate_hello(sock)) { return nullptr; } - LOG_DBG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); + LOG_DBG("[%s] connected to %s\n", __func__, endpoint.c_str()); sockets[endpoint] = sock; return sock; } @@ -589,8 +417,10 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { ggml_backend_buffer_t buffer = tensor->buffer; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; result.buffer = ctx != nullptr ? ctx->remote_ptr : 0; + result.data = reinterpret_cast<uint64_t>(tensor->data); } else { result.buffer = 0; + result.data = 0; } for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { result.ne[i] = tensor->ne[i]; @@ -606,7 +436,6 @@ static rpc_tensor serialize_tensor(const ggml_tensor * tensor) { } result.view_src = reinterpret_cast<uint64_t>(tensor->view_src); result.view_offs = tensor->view_offs; - result.data = reinterpret_cast<uint64_t>(tensor->data); // Avoid sending uninitialized data over the wire memset(result.name, 0, sizeof(result.name)); @@ -705,6 +534,8 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, /* .clear = */ ggml_backend_rpc_buffer_clear, /* .reset = */ NULL, @@ -867,9 +698,11 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; + ggml_backend_dev_t rpc_dev = ggml_backend_get_device(backend); + ggml_backend_rpc_device_context * rpc_dev_ctx = (ggml_backend_rpc_device_context *)rpc_dev->context; GGML_ASSERT(cgraph->n_nodes > 0); - bool reuse = rpc_ctx->gc.is_cached(cgraph); + bool reuse = cgraph->uid != 0 && rpc_dev_ctx->last_graph_uid == cgraph->uid; if (reuse) { rpc_msg_graph_recompute_req request; request.device = rpc_ctx->device; @@ -877,7 +710,7 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request)); RPC_STATUS_ASSERT(status); } else { - rpc_ctx->gc.add(cgraph); + rpc_dev_ctx->last_graph_uid = cgraph->uid; std::vector<uint8_t> input; serialize_graph(rpc_ctx->device, cgraph, input); auto sock = get_socket(rpc_ctx->endpoint); @@ -892,6 +725,8 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .free = */ ggml_backend_rpc_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ ggml_backend_rpc_synchronize, /* .graph_plan_create = */ NULL, @@ -941,10 +776,9 @@ ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, u ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) { std::string dev_name = "RPC" + std::to_string(device) + "[" + std::string(endpoint) + "]"; ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { - /* .endpoint = */ endpoint, - /* .device = */ device, - /* .name = */ dev_name, - /* .gc = */ {}, + /* .endpoint = */ endpoint, + /* .device = */ device, + /* .name = */ dev_name, }; auto reg = ggml_backend_rpc_add_server(endpoint); ggml_backend_t backend = new ggml_backend { @@ -1008,8 +842,8 @@ class rpc_server { bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response); struct stored_graph { - ggml_context_ptr ctx_ptr; - ggml_cgraph * graph; + std::vector<uint8_t> buffer; + ggml_cgraph * graph; }; private: @@ -1162,12 +996,18 @@ ggml_tensor * rpc_server::deserialize_tensor(struct ggml_context * ctx, const rp return nullptr; } + // Fix: Prevent division by zero if blck_size is 0 (e.g., deprecated types) + if (ggml_blck_size((enum ggml_type)tensor->type) == 0) { + GGML_LOG_ERROR("[%s] invalid tensor type received (blck_size is 0): %u\n", __func__, tensor->type); + return nullptr; + } + ggml_tensor * result = ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); // ggml_new_tensor_4d might fail if dimensions are invalid, although less likely to crash than invalid type if (result == nullptr) { - GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\\n", __func__, tensor->type); + GGML_LOG_ERROR("[%s] ggml_new_tensor_4d failed for type %u\n", __func__, tensor->type); return nullptr; } @@ -1245,7 +1085,7 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) { fs::path cache_file = fs::path(cache_dir) / hash_str; std::ofstream ofs(cache_file, std::ios::binary); ofs.write((const char *)data, size); - GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.c_str()); + GGML_LOG_INFO("[%s] saved to '%s'\n", __func__, cache_file.string().c_str()); } ggml_backend_tensor_set(tensor, data, offset, size); return true; @@ -1333,7 +1173,9 @@ bool rpc_server::init_tensor(const rpc_msg_init_tensor_req & request) { if (buffer && buffer->iface.init_tensor) { buffer->iface.init_tensor(buffer, tensor); } else { - GGML_LOG_ERROR("Null buffer for tensor passed to init_tensor function\n"); + if (!buffer) { + GGML_LOG_ERROR("Tensor with null buffer passed to init_tensor function\n"); + } } if (tensor->extra != nullptr) { @@ -1440,6 +1282,10 @@ ggml_tensor * rpc_server::create_node(uint64_t id, if (result == nullptr) { return nullptr; } + if (result->buffer == nullptr && result->data != nullptr) { + GGML_LOG_ERROR("[%s] invalid data ptr", __func__); + return nullptr; + } tensor_map[id] = result; for (int i = 0; i < GGML_MAX_SRC; i++) { // Check if the source ID is 0 before calling create_node recursively @@ -1505,10 +1351,12 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) { LOG_DBG("[%s] device: %u, n_nodes: %u, n_tensors: %u\n", __func__, device, n_nodes, n_tensors); size_t buf_size = ggml_tensor_overhead()*(n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); - + if (stored_graphs[device].buffer.size() < buf_size) { + stored_graphs[device].buffer.resize(buf_size); + } struct ggml_init_params params = { /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ NULL, + /*.mem_buffer =*/ stored_graphs[device].buffer.data(), /*.no_alloc =*/ true, }; ggml_context_ptr ctx_ptr { ggml_init(params) }; @@ -1538,7 +1386,6 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) { } ggml_status status = ggml_backend_graph_compute(backends[device], graph); GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC"); - stored_graphs[device].ctx_ptr.swap(ctx_ptr); stored_graphs[device].graph = graph; return true; } @@ -1579,27 +1426,46 @@ rpc_server::~rpc_server() { } static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const char * cache_dir, - sockfd_t sockfd) { + socket_ptr sock) { rpc_server server(backends, cache_dir); uint8_t cmd; - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { return; } - // the first command sent by the client must be HELLO if (cmd != RPC_CMD_HELLO) { GGML_LOG_ERROR("Expected HELLO command, update client\n"); return; } - if (!recv_msg(sockfd, nullptr, 0)) { + + // Read input_size and validate protocol version + uint64_t hello_input_size; + if (!sock->recv_data(&hello_input_size, sizeof(hello_input_size))) { + return; + } + + if (hello_input_size != sizeof(rpc_msg_hello_req)) { + GGML_LOG_ERROR("HELLO request size mismatch (%zu vs %zu) — client needs upgrade to protocol v%d.x\n", + (size_t)hello_input_size, sizeof(rpc_msg_hello_req), RPC_PROTO_MAJOR_VERSION); + return; + } + + rpc_msg_hello_req req = {}; + if (!sock->recv_data(&req, sizeof(req))) { return; } - rpc_msg_hello_rsp response; - server.hello(response); - if (!send_msg(sockfd, &response, sizeof(response))) { + + rpc_msg_hello_rsp rsp = {}; + server.hello(rsp); + // Advertise server transport capabilities based on client's caps + sock->get_caps(rsp.conn_caps); + if (!send_msg(sock, &rsp, sizeof(rsp))) { return; } + + // Activate transport upgrade using client's caps + sock->update_caps(req.conn_caps); while (true) { - if (!recv_data(sockfd, &cmd, 1)) { + if (!sock->recv_data(&cmd, 1)) { break; } if (cmd >= RPC_CMD_COUNT) { @@ -1613,115 +1479,115 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const return; } case RPC_CMD_DEVICE_COUNT: { - if (!recv_msg(sockfd, nullptr, 0)) { + if (!recv_msg(sock, nullptr, 0)) { return; } rpc_msg_device_count_rsp response; response.device_count = backends.size(); - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_alloc_buffer_rsp response; if (!server.alloc_buffer(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALLOC_SIZE: { rpc_msg_get_alloc_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alloc_size_rsp response; if (!server.get_alloc_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_ALIGNMENT: { rpc_msg_get_alignment_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_alignment_rsp response; if (!server.get_alignment(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GET_MAX_SIZE: { rpc_msg_get_max_size_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_max_size_rsp response; if (!server.get_max_size(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_BUFFER_GET_BASE: { rpc_msg_buffer_get_base_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_buffer_get_base_rsp response; if (!server.buffer_get_base(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_FREE_BUFFER: { rpc_msg_free_buffer_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.free_buffer(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_BUFFER_CLEAR: { rpc_msg_buffer_clear_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.buffer_clear(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_SET_TENSOR: { std::vector<uint8_t> input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.set_tensor(input)) { @@ -1731,62 +1597,62 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const } case RPC_CMD_SET_TENSOR_HASH: { rpc_msg_set_tensor_hash_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_set_tensor_hash_rsp response; if (!server.set_tensor_hash(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_INIT_TENSOR: { rpc_msg_init_tensor_req request; - if (!recv_msg(sockfd, &request,sizeof(request))) { + if (!recv_msg(sock, &request,sizeof(request))) { return; } if (!server.init_tensor(request)) { return; } - if (!send_msg(sockfd, nullptr, 0)) { + if (!send_msg(sock, nullptr, 0)) { return; } break; } case RPC_CMD_GET_TENSOR: { rpc_msg_get_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } std::vector<uint8_t> response; if (!server.get_tensor(request, response)) { return; } - if (!send_msg(sockfd, response.data(), response.size())) { + if (!send_msg(sock, response.data(), response.size())) { return; } break; } case RPC_CMD_COPY_TENSOR: { rpc_msg_copy_tensor_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_copy_tensor_rsp response; if (!server.copy_tensor(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; } case RPC_CMD_GRAPH_COMPUTE: { std::vector<uint8_t> input; - if (!recv_msg(sockfd, input)) { + if (!recv_msg(sock, input)) { return; } if (!server.graph_compute(input)) { @@ -1796,7 +1662,7 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const } case RPC_CMD_GRAPH_RECOMPUTE: { rpc_msg_graph_recompute_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } if (!server.graph_recompute(request)) { @@ -1806,14 +1672,14 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const } case RPC_CMD_GET_DEVICE_MEMORY: { rpc_msg_get_device_memory_req request; - if (!recv_msg(sockfd, &request, sizeof(request))) { + if (!recv_msg(sock, &request, sizeof(request))) { return; } rpc_msg_get_device_memory_rsp response; if (!server.get_device_memory(request, response)) { return; } - if (!send_msg(sockfd, &response, sizeof(response))) { + if (!send_msg(sock, &response, sizeof(response))) { return; } break; @@ -1866,50 +1732,39 @@ void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir if (!parse_endpoint(endpoint, host, port)) { return; } -#ifdef _WIN32 - { - WSADATA wsaData; - int res = WSAStartup(MAKEWORD(2, 2), &wsaData); - if (res != 0) { - fprintf(stderr, "WSAStartup failed: %d\n", res); - return; - } + +#ifdef GGML_RPC_RDMA + printf(" transport : TCP (RDMA auto-negotiate enabled)\n"); +#else + printf(" transport : TCP\n"); +#endif // GGML_RPC_RDMA + if (!rpc_transport_init()) { + fprintf(stderr, "Failed to initialize RPC transport\n"); + return; } -#endif - auto server_socket = create_server_socket(host.c_str(), port); + auto server_socket = socket_t::create_server(host.c_str(), port); if (server_socket == nullptr) { fprintf(stderr, "Failed to create server socket\n"); return; } while (true) { - auto client_socket = socket_accept(server_socket->fd); + auto client_socket = server_socket->accept(); if (client_socket == nullptr) { fprintf(stderr, "Failed to accept client connection\n"); return; } printf("Accepted client connection\n"); fflush(stdout); - rpc_serve_client(backends, cache_dir, client_socket->fd); + rpc_serve_client(backends, cache_dir, client_socket); printf("Client connection closed\n"); fflush(stdout); } -#ifdef _WIN32 - WSACleanup(); -#endif + rpc_transport_shutdown(); for (auto backend : backends) { ggml_backend_free(backend); } } -// device interface - -struct ggml_backend_rpc_device_context { - std::string endpoint; - uint32_t device; - std::string name; - std::string description; -}; - static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) { ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context; @@ -2091,10 +1946,11 @@ ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint) { std::string dev_name = "RPC" + std::to_string(dev_id); std::string dev_desc = std::string(endpoint); ggml_backend_rpc_device_context * dev_ctx = new ggml_backend_rpc_device_context { - /* .endpoint = */ endpoint, - /* .device = */ ind, - /* .name = */ dev_name, - /* .description = */ dev_desc + /* .endpoint = */ endpoint, + /* .device = */ ind, + /* .name = */ dev_name, + /* .description = */ dev_desc, + /* .last_graph_uid = */ 0, }; ggml_backend_dev_t dev = new ggml_backend_device { diff --git a/ggml/src/ggml-rpc/transport.cpp b/ggml/src/ggml-rpc/transport.cpp new file mode 100644 index 00000000000..a728152421f --- /dev/null +++ b/ggml/src/ggml-rpc/transport.cpp @@ -0,0 +1,683 @@ +#include "transport.h" +#include "ggml-impl.h" + +#ifdef _WIN32 +# define WIN32_LEAN_AND_MEAN +# ifndef NOMINMAX +# define NOMINMAX +# endif +# include <windows.h> +# include <winsock2.h> +#else +# include <arpa/inet.h> +# include <sys/socket.h> +# include <sys/types.h> +# include <netinet/in.h> +# include <netinet/tcp.h> +# include <netdb.h> +# include <unistd.h> +#endif +#include <cstdlib> +#include <mutex> +#include <optional> + +#ifdef GGML_RPC_RDMA +# include <infiniband/verbs.h> +# include <time.h> +# ifndef _WIN32 +# include <poll.h> +# endif +#endif // GGML_RPC_RDMA + +#ifdef _WIN32 +typedef SOCKET sockfd_t; +using ssize_t = __int64; +#else +typedef int sockfd_t; +#endif + +static const char * RPC_DEBUG = std::getenv("GGML_RPC_DEBUG"); + +#define LOG_DBG(...) \ + do { if (RPC_DEBUG) GGML_LOG_DEBUG(__VA_ARGS__); } while (0) + +#ifdef GGML_RPC_RDMA +static constexpr size_t RDMA_CHUNK = 256 * 1024; // 256 KiB per send/recv (fits default 8 MiB memlock) +static constexpr int RDMA_RX_DEPTH = 24; // pre-posted recv ring: 24 × 256 KiB = 6 MiB +static constexpr size_t RDMA_GID_SIZE = 16; // RoCE GID / IB GID is always 16 bytes +using rdma_gid_t = std::array<uint8_t, RDMA_GID_SIZE>; + +struct rdma_conn { + struct ibv_context * ctx = nullptr; + struct ibv_pd * pd = nullptr; + struct ibv_cq * scq = nullptr; // send completions + struct ibv_cq * rcq = nullptr; // recv completions + struct ibv_qp * qp = nullptr; + + void * tx_buf = nullptr; + struct ibv_mr * tx_mr = nullptr; + + void * rx_buf = nullptr; // RDMA_RX_DEPTH × RDMA_CHUNK contiguous + struct ibv_mr * rx_mr = nullptr; + int rx_head = 0; + + uint32_t max_inline = 0; + + uint8_t * rx_slot(int i) const { + return static_cast<uint8_t *>(rx_buf) + static_cast<size_t>(i) * RDMA_CHUNK; + } + + bool post_rx(int i) { + struct ibv_sge sge = {}; + sge.addr = (uintptr_t)rx_slot(i); + sge.length = RDMA_CHUNK; + sge.lkey = rx_mr->lkey; + struct ibv_recv_wr wr = {}, * bad = nullptr; + wr.wr_id = (uint64_t)i; + wr.sg_list = &sge; + wr.num_sge = 1; + return ibv_post_recv(qp, &wr, &bad) == 0; + } + + ~rdma_conn() { + if (tx_mr) ibv_dereg_mr(tx_mr); + if (rx_mr) ibv_dereg_mr(rx_mr); + free(tx_buf); + free(rx_buf); + if (qp) ibv_destroy_qp(qp); + if (scq) ibv_destroy_cq(scq); + if (rcq) ibv_destroy_cq(rcq); + if (pd) ibv_dealloc_pd(pd); + if (ctx) ibv_close_device(ctx); + } +}; + +// Local RDMA parameters captured during the probe phase and later consumed +// by rdma_activate() after the remote side's caps arrive via HELLO. +struct rdma_local_info { + uint32_t qpn = 0; + uint32_t psn = 0; + uint8_t gid[RDMA_GID_SIZE] = {}; + uint8_t ib_port = 0; + int gid_idx = 0; + enum ibv_mtu path_mtu = IBV_MTU_1024; +}; + +struct rdma_caps { + uint32_t qpn; + uint32_t psn; + uint8_t gid[RDMA_GID_SIZE]; +}; + +static_assert(sizeof(rdma_caps) == RPC_CONN_CAPS_SIZE, "rdma_caps must match conn_caps size"); + +#endif // GGML_RPC_RDMA + +struct socket_t::impl { + impl(sockfd_t fd) : use_rdma(false), fd(fd) {} + ~impl(); + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + +#ifdef GGML_RPC_RDMA + bool tcp_peer_closed(); + std::optional<rdma_gid_t> rdma_build_target_gid(); + bool rdma_probe(); + bool rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid); + bool rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc); + bool rdma_send(const void * data, size_t size); + bool rdma_recv(void * data, size_t size); + + std::unique_ptr<rdma_conn> rdma; + rdma_local_info rdma_local = {}; +#endif // GGML_RPC_RDMA + bool use_rdma; + sockfd_t fd; +}; + +socket_t::impl::~impl() { +#ifdef GGML_RPC_RDMA + rdma.reset(); +#endif // GGML_RPC_RDMA + LOG_DBG("[%s] closing socket %d\n", __func__, this->fd); +#ifdef _WIN32 + if (fd != INVALID_SOCKET) closesocket(this->fd); +#else + if (fd >= 0) close(this->fd); +#endif +} + +#ifdef GGML_RPC_RDMA + +bool socket_t::impl::tcp_peer_closed() { + if (fd < 0) return false; +#ifndef _WIN32 + struct pollfd pfd = { fd, POLLIN | POLLRDHUP, 0 }; + int r = poll(&pfd, 1, 0); + return r > 0 && (pfd.revents & (POLLHUP | POLLERR | POLLRDHUP)); +#else + return false; +#endif +} + +// Build a RoCE GID-shaped 16-byte target from a TCP socket's local address. +// Used to match the socket's local IP against the kernel's GID table so that +// a single memcmp handles IPv4, IPv4-mapped IPv6, and native IPv6 uniformly: +// AF_INET -> ::ffff:a.b.c.d (bytes 10-11 = 0xff, last 4 = IPv4) +// AF_INET6 (IPv4-mapped) -> ::ffff:a.b.c.d (already in GID shape) +// AF_INET6 (native v6) -> the 16-byte IPv6 address as-is +// Returns std::nullopt on unsupported family or getsockname failure. +std::optional<rdma_gid_t> socket_t::impl::rdma_build_target_gid() { + sockaddr_storage addr = {}; + socklen_t addr_len = sizeof(addr); + if (getsockname(fd, reinterpret_cast<sockaddr *>(&addr), &addr_len) != 0) { + return std::nullopt; + } + rdma_gid_t target = {}; + if (addr.ss_family == AF_INET) { + const auto * a = reinterpret_cast<const sockaddr_in *>(&addr); + target[10] = 0xff; + target[11] = 0xff; + memcpy(&target[12], &a->sin_addr, 4); + return target; + } + if (addr.ss_family == AF_INET6) { + const auto * a = reinterpret_cast<const sockaddr_in6 *>(&addr); + memcpy(target.data(), &a->sin6_addr, RDMA_GID_SIZE); + return target; + } + return std::nullopt; +} + +bool socket_t::impl::rdma_probe() { + const char * dev_env = std::getenv("GGML_RDMA_DEV"); + const char * gid_env = std::getenv("GGML_RDMA_GID"); + + auto target_gid = rdma_build_target_gid(); + if (!target_gid) { + return false; + } + + const uint8_t ib_port = 1; + int num_devs = 0; + ibv_device ** devs = ibv_get_device_list(&num_devs); + if (!devs || num_devs == 0) return false; + + ibv_context * ibctx = nullptr; + const char * matched_dev = nullptr; + int gid_idx = gid_env ? atoi(gid_env) : -1; + int gid_version = IBV_GID_TYPE_IB; // 0 = unknown/IB + + for (int d = 0; d < num_devs; d++) { + const char * dn = ibv_get_device_name(devs[d]); + if (dev_env && strcmp(dev_env, dn) != 0) continue; + + ibv_context * ctx = ibv_open_device(devs[d]); + if (!ctx) continue; + + ibv_port_attr pa; + if (ibv_query_port(ctx, ib_port, &pa) != 0) { ibv_close_device(ctx); continue; } + + int found_gid = gid_idx; + int found_version = IBV_GID_TYPE_IB; + if (found_gid < 0) { + // Find a GID on this port whose bytes equal the local TCP address + // (IPv4 or IPv6). Prefer RoCE v2 (UDP/IP, L3-routable) over v1 + // (raw Ethernet, same-L2 only) so silent hangs on L3-routed paths + // are avoided. ibv_query_gid_ex returns gid+type in one call. + int v2_idx = -1; + int v1_idx = -1; + for (int i = 0; i < pa.gid_tbl_len; i++) { + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, i, &entry, 0) != 0) continue; + if (memcmp(entry.gid.raw, target_gid->data(), RDMA_GID_SIZE) != 0) continue; + if (entry.gid_type == IBV_GID_TYPE_ROCE_V2 && v2_idx < 0) { + v2_idx = i; + } else if (entry.gid_type == IBV_GID_TYPE_ROCE_V1 && v1_idx < 0) { + v1_idx = i; + } + } + if (v2_idx >= 0) { + found_gid = v2_idx; + found_version = IBV_GID_TYPE_ROCE_V2; + } else if (v1_idx >= 0) { + found_gid = v1_idx; + found_version = IBV_GID_TYPE_ROCE_V1; + } + } else { + // Explicit GID index from GGML_RDMA_GID — fetch its type for logging. + ibv_gid_entry entry = {}; + if (ibv_query_gid_ex(ctx, ib_port, found_gid, &entry, 0) == 0) { + found_version = entry.gid_type; + } + } + if (found_gid >= 0) { + ibctx = ctx; + gid_idx = found_gid; + gid_version = found_version; + matched_dev = dn; + rdma_local.path_mtu = pa.active_mtu; + break; + } + ibv_close_device(ctx); + } + ibv_free_device_list(devs); + if (!ibctx) return false; + + rdma_local.ib_port = ib_port; + rdma_local.gid_idx = gid_idx; + + rdma = std::make_unique<rdma_conn>(); + rdma->ctx = ibctx; + + rdma->pd = ibv_alloc_pd(ibctx); + if (!rdma->pd) return false; + + rdma->scq = ibv_create_cq(ibctx, 16, nullptr, nullptr, 0); + rdma->rcq = ibv_create_cq(ibctx, RDMA_RX_DEPTH + 4, nullptr, nullptr, 0); + if (!rdma->scq || !rdma->rcq) return false; + + ibv_qp_init_attr qia = {}; + qia.send_cq = rdma->scq; + qia.recv_cq = rdma->rcq; + qia.qp_type = IBV_QPT_RC; + qia.cap.max_send_wr = 4; + qia.cap.max_recv_wr = RDMA_RX_DEPTH + 4; + qia.cap.max_send_sge = 1; + qia.cap.max_recv_sge = 1; + qia.cap.max_inline_data = 256; + + rdma->qp = ibv_create_qp(rdma->pd, &qia); + if (!rdma->qp) return false; + rdma->max_inline = qia.cap.max_inline_data; + + rdma->tx_buf = aligned_alloc(4096, RDMA_CHUNK); + rdma->rx_buf = aligned_alloc(4096, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK); + if (!rdma->tx_buf || !rdma->rx_buf) return false; + + rdma->tx_mr = ibv_reg_mr(rdma->pd, rdma->tx_buf, RDMA_CHUNK, IBV_ACCESS_LOCAL_WRITE); + rdma->rx_mr = ibv_reg_mr(rdma->pd, rdma->rx_buf, static_cast<size_t>(RDMA_RX_DEPTH) * RDMA_CHUNK, + IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE); + if (!rdma->tx_mr || !rdma->rx_mr) return false; + + ibv_gid local_gid; + if (ibv_query_gid(ibctx, ib_port, gid_idx, &local_gid) != 0) return false; + + rdma_local.qpn = rdma->qp->qp_num; + rdma_local.psn = rdma->qp->qp_num & 0xffffff; + memcpy(&rdma_local.gid, &local_gid, RDMA_GID_SIZE); + + const char * ver_str = ""; + if (gid_version == IBV_GID_TYPE_ROCE_V2) { + ver_str = " RoCEv2"; + } else if (gid_version == IBV_GID_TYPE_ROCE_V1) { + ver_str = " RoCEv1"; + } + GGML_LOG_INFO("RDMA probed: dev=%s gid=%d%s qpn=%u inline=%u\n", + matched_dev, gid_idx, ver_str, rdma_local.qpn, rdma->max_inline); + return true; +} + +// Phase 2: Given remote QPN/PSN/GID, transition QP: RESET->INIT->pre-post->RTR->RTS. +// On success, the connection is live and ready for rdma_send/rdma_recv. +bool socket_t::impl::rdma_activate(uint32_t remote_qpn, uint32_t remote_psn, const uint8_t * remote_gid) { + // RESET -> INIT + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_INIT; + a.port_num = rdma_local.ib_port; + a.pkey_index = 0; + a.qp_access_flags = IBV_ACCESS_REMOTE_WRITE | IBV_ACCESS_REMOTE_READ | IBV_ACCESS_LOCAL_WRITE; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_ACCESS_FLAGS) != 0) { + return false; + } + } + + for (int i = 0; i < RDMA_RX_DEPTH; i++) { + if (!rdma->post_rx(i)) return false; + } + + // INIT -> RTR + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTR; + a.path_mtu = rdma_local.path_mtu; + a.dest_qp_num = remote_qpn; + a.rq_psn = remote_psn; + a.max_dest_rd_atomic = 1; + a.min_rnr_timer = 1; + a.ah_attr.is_global = 1; + memcpy(&a.ah_attr.grh.dgid, remote_gid, RDMA_GID_SIZE); + a.ah_attr.grh.hop_limit = 1; + a.ah_attr.grh.sgid_index = rdma_local.gid_idx; + a.ah_attr.dlid = 0; + a.ah_attr.port_num = rdma_local.ib_port; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | IBV_QP_DEST_QPN | + IBV_QP_RQ_PSN | IBV_QP_MAX_DEST_RD_ATOMIC | IBV_QP_MIN_RNR_TIMER) != 0) { + return false; + } + } + + // RTR -> RTS + { + struct ibv_qp_attr a = {}; + a.qp_state = IBV_QPS_RTS; + a.timeout = 14; + a.retry_cnt = 7; + a.rnr_retry = 7; + a.sq_psn = rdma_local.psn; + a.max_rd_atomic = 1; + if (ibv_modify_qp(rdma->qp, &a, + IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | IBV_QP_RNR_RETRY | + IBV_QP_SQ_PSN | IBV_QP_MAX_QP_RD_ATOMIC) != 0) { + return false; + } + } + + GGML_LOG_INFO("RDMA activated: qpn=%u->%u mtu=%d rx_depth=%d\n", + rdma_local.qpn, remote_qpn, 128 << rdma_local.path_mtu, RDMA_RX_DEPTH); + return true; +} + +bool socket_t::impl::rdma_poll(struct ibv_cq * cq, struct ibv_wc * wc) { + for (uint64_t s = 0; ; s++) { + int n = ibv_poll_cq(cq, 1, wc); + if (n > 0) { + if (wc->status != IBV_WC_SUCCESS) { + GGML_LOG_ERROR("RDMA CQ wc error: status=%d (%s) vendor_err=0x%x\n", + wc->status, ibv_wc_status_str(wc->status), wc->vendor_err); + } + return wc->status == IBV_WC_SUCCESS; + } + if (n < 0) return false; + if ((s & 0xFFFFF) == 0 && s > 0) { + if (tcp_peer_closed()) { + return false; + } + } + } +} + +bool socket_t::impl::rdma_send(const void * data, size_t size) { + rdma_conn * c = rdma.get(); + const uint8_t * src = (const uint8_t *)data; + size_t rem = size; + while (rem > 0) { + size_t chunk = std::min(rem, RDMA_CHUNK); + + struct ibv_sge sge = {}; + struct ibv_send_wr wr = {}, * bad = nullptr; + wr.opcode = IBV_WR_SEND; + wr.sg_list = &sge; + wr.num_sge = 1; + + if (chunk <= c->max_inline) { + sge.addr = (uintptr_t)src; + sge.length = chunk; + wr.send_flags = IBV_SEND_SIGNALED | IBV_SEND_INLINE; + } else { + memcpy(c->tx_buf, src, chunk); + sge.addr = (uintptr_t)c->tx_buf; + sge.length = chunk; + sge.lkey = c->tx_mr->lkey; + wr.send_flags = IBV_SEND_SIGNALED; + } + + if (ibv_post_send(c->qp, &wr, &bad) != 0) return false; + struct ibv_wc wc; + if (!rdma_poll(c->scq, &wc)) return false; + + src += chunk; + rem -= chunk; + } + return true; +} + +bool socket_t::impl::rdma_recv(void * data, size_t size) { + rdma_conn * c = rdma.get(); + uint8_t * dst = (uint8_t *)data; + size_t rem = size; + while (rem > 0) { + struct ibv_wc wc; + if (!rdma_poll(c->rcq, &wc)) return false; + + int slot = (int)wc.wr_id; + size_t got = wc.byte_len; + memcpy(dst, c->rx_slot(slot), got); + + if (!c->post_rx(slot)) return false; + + dst += got; + rem -= got; + } + return true; +} + +#endif // GGML_RPC_RDMA + +bool socket_t::impl::send_data(const void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_send(data, size); + } +#endif + size_t bytes_sent = 0; + while (bytes_sent < size) { + size_t size_to_send = std::min(size - bytes_sent, MAX_CHUNK_SIZE); + ssize_t n = send(fd, (const char *)data + bytes_sent, size_to_send, 0); + if (n < 0) { + GGML_LOG_ERROR("send failed (bytes_sent=%zu, size_to_send=%zu)\n", + bytes_sent, size_to_send); + return false; + } + bytes_sent += (size_t)n; + } + return true; +} + +bool socket_t::impl::recv_data(void * data, size_t size) { +#ifdef GGML_RPC_RDMA + if (use_rdma) { + return rdma_recv(data, size); + } +#endif + size_t bytes_recv = 0; + while (bytes_recv < size) { + size_t size_to_recv = std::min(size - bytes_recv, MAX_CHUNK_SIZE); + ssize_t n = recv(fd, (char *)data + bytes_recv, size_to_recv, 0); + if (n < 0) { + GGML_LOG_ERROR("recv failed (bytes_recv=%zu, size_to_recv=%zu)\n", + bytes_recv, size_to_recv); + return false; + } + if (n == 0) { + LOG_DBG("recv returned 0 (peer closed?)\n"); + return false; + } + bytes_recv += (size_t)n; + } + return true; +} + +void socket_t::impl::get_caps(uint8_t * local_caps) { + memset(local_caps, 0, RPC_CONN_CAPS_SIZE); +#ifdef GGML_RPC_RDMA + rdma_local = {}; + if (rdma_probe()) { + rdma_caps rc = {}; + rc.qpn = rdma_local.qpn; + rc.psn = rdma_local.psn; + memcpy(rc.gid, rdma_local.gid, RDMA_GID_SIZE); + memcpy(local_caps, &rc, sizeof(rc)); + } else { + rdma.reset(); + } +#endif // GGML_RPC_RDMA +} + +void socket_t::impl::update_caps(const uint8_t * remote_caps) { +#ifdef GGML_RPC_RDMA + if (!rdma) { + return; + } + rdma_caps rc = {}; + memcpy(&rc, remote_caps, sizeof(rc)); + if (rc.qpn == 0) { + rdma.reset(); + return; + } + if (rdma_activate(rc.qpn, rc.psn, rc.gid)) { + use_rdma = true; + } else { + GGML_LOG_ERROR("RDMA activate failed, staying on TCP\n"); + rdma.reset(); + } +#else + (void)remote_caps; +#endif // GGML_RPC_RDMA +} + + +///////////////////////////////////////////////////////////////////////////// + +socket_t::socket_t(std::unique_ptr<impl> p) : pimpl(std::move(p)) {} + +socket_t::~socket_t() = default; + +bool socket_t::send_data(const void * data, size_t size) { + return pimpl->send_data(data, size); +} + +bool socket_t::recv_data(void * data, size_t size) { + return pimpl->recv_data(data, size); +} + +void socket_t::get_caps(uint8_t * local_caps) { + return pimpl->get_caps(local_caps); +} + +void socket_t::update_caps(const uint8_t * remote_caps) { + return pimpl->update_caps(remote_caps); +} + +static bool is_valid_fd(sockfd_t sockfd) { +#ifdef _WIN32 + return sockfd != INVALID_SOCKET; +#else + return sockfd >= 0; +#endif +} + +static bool set_no_delay(sockfd_t sockfd) { + int flag = 1; + // set TCP_NODELAY to disable Nagle's algorithm + int ret = setsockopt(sockfd, IPPROTO_TCP, TCP_NODELAY, (char *)&flag, sizeof(int)); + return ret == 0; +} + +static bool set_reuse_addr(sockfd_t sockfd) { + int flag = 1; + int ret = setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, (char *)&flag, sizeof(int)); + return ret == 0; +} + +socket_ptr socket_t::accept() { + auto client_socket_fd = ::accept(pimpl->fd, NULL, NULL); + if (!is_valid_fd(client_socket_fd)) { + return nullptr; + } + if (!set_no_delay(client_socket_fd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique<impl>(client_socket_fd))); +} + +socket_ptr socket_t::create_server(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_reuse_addr(sockfd)) { + GGML_LOG_ERROR("Failed to set SO_REUSEADDR\n"); + return nullptr; + } + if (inet_addr(host) == INADDR_NONE) { + GGML_LOG_ERROR("Invalid host address: %s\n", host); + return nullptr; + } + struct sockaddr_in serv_addr; + serv_addr.sin_family = AF_INET; + serv_addr.sin_addr.s_addr = inet_addr(host); + serv_addr.sin_port = htons(port); + + if (bind(sockfd, (struct sockaddr *) &serv_addr, sizeof(serv_addr)) < 0) { + return nullptr; + } + if (listen(sockfd, 1) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique<impl>(sockfd))); +} + +socket_ptr socket_t::connect(const char * host, int port) { + auto sockfd = socket(AF_INET, SOCK_STREAM, 0); + if (!is_valid_fd(sockfd)) { + return nullptr; + } + if (!set_no_delay(sockfd)) { + GGML_LOG_ERROR("Failed to set TCP_NODELAY\n"); + return nullptr; + } + struct sockaddr_in addr; + addr.sin_family = AF_INET; + addr.sin_port = htons(port); + struct hostent * server = gethostbyname(host); + if (server == NULL) { + GGML_LOG_ERROR("Cannot resolve host '%s'\n", host); + return nullptr; + } + memcpy(&addr.sin_addr.s_addr, server->h_addr, server->h_length); + if (::connect(sockfd, (struct sockaddr *)&addr, sizeof(addr)) < 0) { + return nullptr; + } + return socket_ptr(new socket_t(std::make_unique<impl>(sockfd))); +} + +#ifdef _WIN32 +static std::mutex g_rpc_transport_mu; +static bool g_rpc_transport_wsa_started = false; +#endif + +bool rpc_transport_init() { +#ifdef _WIN32 + std::lock_guard<std::mutex> lock(g_rpc_transport_mu); + if (g_rpc_transport_wsa_started) { + return true; + } + WSADATA wsaData; + int res = WSAStartup(MAKEWORD(2, 2), &wsaData); + if (res != 0) { + return false; + } + g_rpc_transport_wsa_started = true; + return true; +#else + return true; +#endif +} + +void rpc_transport_shutdown() { +#ifdef _WIN32 + std::lock_guard<std::mutex> lock(g_rpc_transport_mu); + if (!g_rpc_transport_wsa_started) { + return; + } + WSACleanup(); + g_rpc_transport_wsa_started = false; +#endif +} diff --git a/ggml/src/ggml-rpc/transport.h b/ggml/src/ggml-rpc/transport.h new file mode 100644 index 00000000000..73b85cc530a --- /dev/null +++ b/ggml/src/ggml-rpc/transport.h @@ -0,0 +1,34 @@ +#pragma once + +#include <cstddef> +#include <cstdint> +#include <memory> + +struct socket_t; +typedef std::shared_ptr<socket_t> socket_ptr; + +static constexpr size_t MAX_CHUNK_SIZE = 1024ull * 1024ull * 1024ull; // 1 GiB +static constexpr size_t RPC_CONN_CAPS_SIZE = 24; + +struct socket_t { + ~socket_t(); + + bool send_data(const void * data, size_t size); + bool recv_data(void * data, size_t size); + + socket_ptr accept(); + + void get_caps(uint8_t * local_caps); + void update_caps(const uint8_t * remote_caps); + + static socket_ptr create_server(const char * host, int port); + static socket_ptr connect(const char * host, int port); + +private: + struct impl; + explicit socket_t(std::unique_ptr<impl> p); + std::unique_ptr<impl> pimpl; +}; + +bool rpc_transport_init(); +void rpc_transport_shutdown(); diff --git a/ggml/src/ggml-sycl/CMakeLists.txt b/ggml/src/ggml-sycl/CMakeLists.txt index 5a89d8dd688..180de92202d 100644 --- a/ggml/src/ggml-sycl/CMakeLists.txt +++ b/ggml/src/ggml-sycl/CMakeLists.txt @@ -1,7 +1,7 @@ message(STATUS "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}") -if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$") - message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD") +if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL)$") + message(FATAL_ERROR "GGML_SYCL_TARGET: Invalid target, the supported options are [INTEL]") endif() check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL) @@ -25,6 +25,11 @@ ggml_add_backend_library(ggml-sycl file(GLOB GGML_HEADERS_SYCL "*.hpp") file(GLOB GGML_SOURCES_SYCL "*.cpp") +file(GLOB SRCS "template-instances/fattn-tile*.cpp") +list(APPEND GGML_SOURCES_SYCL ${SRCS}) +file(GLOB SRCS "template-instances/fattn-vec*.cpp") +list(APPEND GGML_SOURCES_SYCL ${SRCS}) + target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL}) if (WIN32) @@ -34,6 +39,18 @@ if (WIN32) set(CMAKE_CXX_COMPILER "icx") set(CMAKE_CXX_COMPILER_ID "IntelLLVM") endif() + # Level Zero SDK path for Windows (only when GGML_SYCL_SUPPORT_LEVEL_ZERO is enabled) + if(GGML_SYCL_SUPPORT_LEVEL_ZERO) + if(DEFINED ENV{LEVEL_ZERO_V1_SDK_PATH}) + set(LEVEL_ZERO_V1_SDK_PATH $ENV{LEVEL_ZERO_V1_SDK_PATH}) + if(EXISTS "${LEVEL_ZERO_V1_SDK_PATH}") + target_include_directories(ggml-sycl PRIVATE "${LEVEL_ZERO_V1_SDK_PATH}/include") + set(LEVEL_ZERO_V1_SDK_LIB_PATH "${LEVEL_ZERO_V1_SDK_PATH}/lib") + else() + message(WARNING "LEVEL_ZERO_V1_SDK_PATH set but folder not found: ${LEVEL_ZERO_V1_SDK_PATH}") + endif() + endif() + endif() endif() macro(detect_and_find_package package_name) @@ -88,6 +105,23 @@ endif() target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing") +message(STATUS "GGML_SYCL_SUPPORT_LEVEL_ZERO ${GGML_SYCL_SUPPORT_LEVEL_ZERO}") +if (GGML_SYCL_SUPPORT_LEVEL_ZERO) + # Link against Level Zero loader for direct device memory allocation. + # Avoids sycl::malloc_device triggering DMA-buf/TTM system RAM staging + # in the xe kernel driver during multi-GPU inference. + find_path(LEVEL_ZERO_INCLUDE_DIR level_zero/ze_api.h HINTS ${ONEAPI_ROOT}/include ${LEVEL_ZERO_V1_SDK_PATH}/include) + find_library(ZE_LOADER_LIB ze_loader HINTS ${ONEAPI_ROOT}/lib ${LEVEL_ZERO_V1_SDK_LIB_PATH} ENV LD_LIBRARY_PATH) + if(ZE_LOADER_LIB AND LEVEL_ZERO_INCLUDE_DIR) + target_link_libraries(ggml-sycl PRIVATE ${ZE_LOADER_LIB}) + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_SUPPORT_LEVEL_ZERO) + message(STATUS "Level Zero loader found: ${ZE_LOADER_LIB}") + message(STATUS "Level Zero headers found: ${LEVEL_ZERO_INCLUDE_DIR}") + else() + message(WARNING "Level Zero loader or headers not found, Level Zero support disabled") + endif() +endif() + # Link against oneDNN set(GGML_SYCL_DNNL 0) if(GGML_SYCL_DNN) @@ -125,110 +159,49 @@ endif() target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL}) if (GGML_SYCL_F16) - if (GGML_SYCL_TARGET STREQUAL "AMD") - message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.") - endif() add_compile_definitions(GGML_SYCL_F16) endif() if (GGML_SYCL_TARGET STREQUAL "INTEL") add_compile_definitions(GGML_SYCL_WARP_SIZE=16) - target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) -elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA") - add_compile_definitions(GGML_SYCL_WARP_SIZE=32) -elseif (GGML_SYCL_TARGET STREQUAL "AMD") - # INFO: Allowed Sub_group_sizes are not consistent through all - # hip targets. For example, 64 is used for certain models, but the backend - # does not support it. - # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32) - add_compile_definitions(GGML_SYCL_WARP_SIZE=32) + if (NOT GGML_SYCL_DEVICE_ARCH) + target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required) + else() + message(STATUS "Skipping -ze-intel-greater-than-4GB-buffer-required for spir64_gen AOT") + endif() + + # Link against Intel oneMKL + if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + set(SYCL_COMPILER ON) + endif() + find_package(MKL REQUIRED) + target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS) else() # default for other target + message(FATAL_ERROR "GGML_SYCL_TARGET is not supported") add_compile_definitions(GGML_SYCL_WARP_SIZE=32) endif() if (GGML_SYCL_GRAPH) + message(STATUS "find GGML_SYCL_GRAPH") target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH) endif() -# Link against Intel oneMKL or oneMath -if (GGML_SYCL_TARGET STREQUAL "INTEL") - # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically - # See https://github.com/uxlfoundation/oneMath/issues/654 - if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") - set(SYCL_COMPILER ON) - endif() - find_package(MKL REQUIRED) - target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS) - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL) -else() - find_package(oneMath QUIET) - if (NOT oneMath_FOUND) - message(STATUS "oneMath not found: oneMath will be automatically downloaded") - # Use FetchContent to automatically pull and build oneMath - include(FetchContent) - set(BUILD_FUNCTIONAL_TESTS False) - set(BUILD_EXAMPLES False) - set(TARGET_DOMAINS blas) - if (GGML_SYCL_TARGET STREQUAL "NVIDIA") - set(ENABLE_MKLCPU_BACKEND False) - set(ENABLE_MKLGPU_BACKEND False) - set(ENABLE_CUBLAS_BACKEND True) - elseif (GGML_SYCL_TARGET STREQUAL "AMD") - set(ENABLE_MKLCPU_BACKEND False) - set(ENABLE_MKLGPU_BACKEND False) - set(ENABLE_ROCBLAS_BACKEND True) - # Ensure setting a string variable here is not overriden by oneMath CACHE variables - cmake_policy(SET CMP0126 NEW) - # Setting the device architecture is only needed and useful for AMD devices in oneMath - set(HIP_TARGETS ${GGML_SYCL_DEVICE_ARCH} CACHE STRING "oneMath HIP target" FORCE) - endif() - FetchContent_Declare( - ONEMATH - GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git - GIT_TAG 8efe85f5aaebb37f1d8c503b7af66315feabf142 - ) - FetchContent_MakeAvailable(ONEMATH) - # Create alias to match with find_package targets name - function(onemath_alias target) - if (TARGET ${target}_obj) - # Silence verbose warnings from external libraries - target_compile_options(${target}_obj PRIVATE -w) - endif() - if (TARGET ${target}) - add_library(ONEMATH::${target} ALIAS ${target}) - endif() - endfunction() - onemath_alias(onemath) - onemath_alias(onemath_blas_mklcpu) - onemath_alias(onemath_blas_mklgpu) - onemath_alias(onemath_blas_cublas) - onemath_alias(onemath_blas_rocblas) - endif() - - # Below oneMath compile-time dispatching is used for better performance - if (GGML_SYCL_TARGET STREQUAL "NVIDIA") - target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_cublas) - target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda") - target_link_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda") - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA) - elseif (GGML_SYCL_TARGET STREQUAL "AMD") - if (NOT GGML_SYCL_DEVICE_ARCH) - message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.") - endif() - target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas) - target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") - target_link_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa") - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_AMD) - else() - # Fallback to oneMath runtime dispatcher - target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath) - target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GENERIC) - endif() +if (GGML_SYCL_HOST_MEM_FALLBACK) + message(STATUS "find GGML_SYCL_HOST_MEM_FALLBACK") + target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_HOST_MEM_FALLBACK) endif() if (GGML_SYCL_DEVICE_ARCH) - target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) - target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) + message(STATUS "GGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} (AOT via spir64_gen)") + target_compile_options( + ggml-sycl PRIVATE + -fsycl-targets=spir64_gen + "SHELL:-Xsycl-target-backend=spir64_gen \"-device ${GGML_SYCL_DEVICE_ARCH}\"" + ) + target_link_options( + ggml-sycl PRIVATE + -fsycl-targets=spir64_gen + "SHELL:-Xsycl-target-backend=spir64_gen \"-device ${GGML_SYCL_DEVICE_ARCH}\"" + ) endif() - diff --git a/ggml/src/ggml-sycl/add-id.cpp b/ggml/src/ggml-sycl/add-id.cpp index 00c073cf937..e0adc4fe423 100644 --- a/ggml/src/ggml-sycl/add-id.cpp +++ b/ggml/src/ggml-sycl/add-id.cpp @@ -55,7 +55,11 @@ void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { const int32_t* src2_d = (const int32_t*)src2->data; float* dst_d = (float*)dst->data; - int threads = std::min((int)ne00, 768); // cols + const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; + GGML_ASSERT(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + + int threads = std::min((unsigned int)ne00, max_work_group_size); // cols + ctx.stream()->parallel_for( sycl::nd_range<3>( sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads), diff --git a/ggml/src/ggml-sycl/backend.hpp b/ggml/src/ggml-sycl/backend.hpp index 75657f3fca2..a526d8e58bc 100644 --- a/ggml/src/ggml-sycl/backend.hpp +++ b/ggml/src/ggml-sycl/backend.hpp @@ -23,6 +23,8 @@ #include "dequantize.hpp" #include "dmmv.hpp" #include "element_wise.hpp" +#include "fattn.hpp" +#include "gated_delta_net.hpp" #include "gla.hpp" #include "im2col.hpp" #include "mmq.hpp" @@ -30,6 +32,7 @@ #include "norm.hpp" #include "outprod.hpp" #include "pad.hpp" +#include "pad_reflect_1d.hpp" #include "quantize.hpp" #include "quants.hpp" #include "roll.hpp" @@ -38,8 +41,8 @@ #include "ssm_conv.hpp" #include "softmax.hpp" #include "tsembd.hpp" +#include "upscale.hpp" #include "wkv.hpp" -#include "pad_reflect_1d.hpp" #endif // GGML_SYCL_BACKEND_HPP diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index 0a3883ae1ed..92dd18889f4 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -11,8 +11,8 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, + int s00, int s01, int s02, int s03, + int s10, int s11, int s12, int s13, const sycl::nd_item<3> &item_ct1) { const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -44,7 +44,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, for (int i0 = i0s; i0 < ne0; i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]); } } @@ -53,8 +53,8 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, + int s00, int s01, int s02, int s03, + int s10, int s11, int s12, int s13, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + @@ -82,7 +82,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t dst_t * dst_row = dst + i_dst; const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]); } @@ -95,7 +95,8 @@ struct bin_bcast_sycl { const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03, const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, - const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { + const bool src1_is_contiguous, const bool src0_is_permuted, const bool src1_is_permuted, + queue_ptr stream) { int nr0 = ne10 / ne0; int nr1 = ne11/ne1; int nr2 = ne12/ne2; @@ -123,7 +124,7 @@ struct bin_bcast_sycl { cnb[3] *= cne[3]; }; - if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { + if (src0_is_contiguous && src1_is_contiguous && !src0_is_permuted && !src1_is_permuted) { for (int i = 0; i < 4; i++) { if (nr[i] != 1) { break; @@ -164,7 +165,7 @@ struct bin_bcast_sycl { size_t nb12 = cnb1[2]; size_t nb13 = cnb1[3]; - size_t s0 = nb0 / sizeof(dst_t); + // size_t s0 = nb0 / sizeof(dst_t); size_t s1 = nb1 / sizeof(dst_t); size_t s2 = nb2 / sizeof(dst_t); size_t s3 = nb3 / sizeof(dst_t); @@ -196,9 +197,6 @@ struct bin_bcast_sycl { GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s10 == 1); - const int block_size = 128; int64_t hne0 = std::max(ne0/2LL, 1LL); @@ -232,8 +230,8 @@ struct bin_bcast_sycl { [=](sycl::nd_item<3> item_ct1) { k_bin_bcast_unravel<bin_op>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, - s03, s11, s12, s13, item_ct1); + ne10, ne11, ne12, ne13, s1, s2, s3, s00, s01, s02, + s03, s10, s11, s12, s13, item_ct1); }); } } else { @@ -251,7 +249,7 @@ struct bin_bcast_sycl { [=](sycl::nd_item<3> item_ct1) { k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, - s1, s2, s3, s01, s02, s03, s11, s12, s13, + s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, item_ct1); }); } @@ -268,24 +266,27 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, - ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, - nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), + nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, - nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), + main_stream); } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, - nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), + main_stream); } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, - nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), + main_stream); } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ggml/src/ggml-sycl/common.cpp b/ggml/src/ggml-sycl/common.cpp index 05fd5ef46c7..ae08abad81b 100644 --- a/ggml/src/ggml-sycl/common.cpp +++ b/ggml/src/ggml-sycl/common.cpp @@ -11,6 +11,10 @@ // #include "common.hpp" +#include <sycl/backend.hpp> +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +#include <level_zero/ze_api.h> +#endif #include "ggml-backend-impl.h" #include "ggml-impl.h" @@ -55,6 +59,20 @@ bool gpu_has_xmx(sycl::device &dev) { return dev.has(sycl::aspect::ext_intel_matrix); } +static int ggml_sycl_get_env(const char *env_name, int default_val) { + char *user_device_string = getenv(env_name); + int user_number = default_val; + + unsigned n; + if (user_device_string != NULL && + sscanf(user_device_string, " %u", &n) == 1) { + user_number = (int)n; + } else { + user_number = default_val; + } + return user_number; +} + int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block_size) { const int64_t max_range = std::numeric_limits<int>::max(); int64_t sycl_down_blk_size = block_size; @@ -66,6 +84,61 @@ int64_t downsample_sycl_global_range(int64_t accumulate_block_num, int64_t block return sycl_down_blk_size; } +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +static bool ggml_sycl_use_level_zero_device_alloc(sycl::queue &q) { + return ggml_sycl_get_env("GGML_SYCL_ENABLE_LEVEL_ZERO", 1) && + q.get_device().is_gpu() && + q.get_backend() == sycl::backend::ext_oneapi_level_zero; +} +#endif + +// Use Level Zero zeMemAllocDevice to avoid sycl::malloc_device triggering +// DMA-buf/TTM system RAM staging in the xe kernel driver during multi-GPU inference. +// The decision is made from the queue and runtime env because large buffers can be +// allocated before ggml_check_sycl() initializes g_ggml_sycl_enable_level_zero. +void * ggml_sycl_malloc_device(size_t size, sycl::queue &q) { +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + if (ggml_sycl_use_level_zero_device_alloc(q)) { + void *ptr = nullptr; + auto ze_ctx = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q.get_context()); + auto ze_dev = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q.get_device()); +#ifdef ZE_RELAXED_ALLOCATION_LIMITS_EXP_NAME + ze_relaxed_allocation_limits_exp_desc_t relaxed_desc = { + ZE_STRUCTURE_TYPE_RELAXED_ALLOCATION_LIMITS_EXP_DESC, + nullptr, + ZE_RELAXED_ALLOCATION_LIMITS_EXP_FLAG_MAX_SIZE, + }; + ze_device_mem_alloc_desc_t alloc_desc = { + ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, + &relaxed_desc, + 0, + 0, + }; +#else + ze_device_mem_alloc_desc_t alloc_desc = {ZE_STRUCTURE_TYPE_DEVICE_MEM_ALLOC_DESC, nullptr, 0, 0}; +#endif + ze_result_t r = zeMemAllocDevice(ze_ctx, &alloc_desc, size, 64, ze_dev, &ptr); + if (r == ZE_RESULT_SUCCESS && ptr) { + return ptr; + } + return nullptr; + } +#endif + return sycl::malloc_device(size, q); +} + +void ggml_sycl_free_device(void *ptr, sycl::queue &q) { + if (!ptr) return; +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + if (ggml_sycl_use_level_zero_device_alloc(q)) { + auto ze_ctx = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q.get_context()); + zeMemFree(ze_ctx, ptr); + return; + } +#endif + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, q))); +} + void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams) { for (int i = 0; i < ggml_sycl_info().device_count; ++i) { for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) { @@ -75,8 +148,7 @@ void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> str } if (extra->data_device[i] != nullptr && streams.size()>0) { ggml_sycl_set_device(i); - SYCL_CHECK( - CHECK_TRY_ERROR(sycl::free(extra->data_device[i], *(streams[i])))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(extra->data_device[i], *(streams[i])))); } } delete extra; diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 519638fd416..d8bb3638dfd 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -19,10 +19,22 @@ #include <string> #include "dpct/helper.hpp" +#include "ggml.h" +#include "ggml-impl.h" #include "ggml-sycl.h" #include "presets.hpp" +#include "type.hpp" #include "sycl_hw.hpp" +#include "fattn-buffers.hpp" + +namespace syclexp = sycl::ext::oneapi::experimental; +#if defined(__INTEL_LLVM_COMPILER) && __has_include(<sycl/ext/oneapi/bfloat16.hpp>) + #include <sycl/ext/oneapi/bfloat16.hpp> + #ifndef GGML_SYCL_HAS_BF16 + #define GGML_SYCL_HAS_BF16 + #endif +#endif #if GGML_SYCL_DNNL #include "dnnl.hpp" @@ -31,6 +43,10 @@ #define GGML_COMMON_DECL_SYCL #define GGML_COMMON_IMPL_SYCL +#define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building. +#define SYCL_FAST_FP16 //don't change. remove it will break fattn-tile.hpp building +#define GGML_SYCL_FA_ALL_QUANTS //define it to enable all quantization types in flash attention. undefine it to only support F16, Q4_0 and Q8_0 in flash attention. + /* suppress warning spam */ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wnested-anon-types" @@ -45,6 +61,8 @@ void ggml_sycl_host_free(void* ptr); extern int g_ggml_sycl_debug; extern int g_ggml_sycl_disable_optimize; extern int g_ggml_sycl_prioritize_dmmv; +extern int g_ggml_sycl_enable_flash_attention; + #if defined(__clang__) && __has_builtin(__builtin_expect) // Hint the optimizer to pipeline the more likely following instruction in branches @@ -76,10 +94,10 @@ extern int g_ggml_sycl_prioritize_dmmv; #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP -#define VER_4VEC 610 // todo for hardward optimize. -#define VER_GEN9 700 // todo for hardward optimize. -#define VER_GEN12 1000000 // todo for hardward optimize. -#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardward optimize. +#define VER_4VEC 610 // todo for hardware optimize. +#define VER_GEN9 700 // todo for hardware optimize. +#define VER_GEN12 1000000 // todo for hardware optimize. +#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardware optimize. #define GGML_SYCL_MAX_NODES 8192 // TODO: adapt to hardwares @@ -170,6 +188,10 @@ static size_t g_scratch_offset = 0; int get_current_device_id(); +inline int ggml_sycl_get_device() { + return get_current_device_id(); +} + inline dpct::err0 ggml_sycl_set_device(const int device) try { int current_device_id; SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id())); @@ -194,14 +216,18 @@ struct optimize_feature { }; struct sycl_device_info { - int cc; // compute capability + int cc; // compute capability int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum // number of compute units on a SYCL device. // size_t smpb; // max. shared memory per block size_t smpbo; // max. shared memory per block (with opt-in) + int warp_size; // WARP_SIZE(16)|WARP_32_SIZE(32)|WARP_16_SIZE(16). For Intel GPU, 16 is better in most cases. Some OP support 32 only. + int max_wg_per_cu; // max work groups per compute unit - refer to + // cudaOccupancyMaxActiveBlocksPerMultiprocessor bool vmm; // virtual memory support + size_t vmm_granularity; // granularity of virtual memory size_t total_vram; - //sycl_hw_info hw_info; \\ device id and aarch, currently not used + sycl_hw_info hw_info; optimize_feature opt_feature; }; @@ -214,10 +240,14 @@ struct ggml_sycl_device_info { std::array<float, GGML_SYCL_MAX_DEVICES> default_tensor_split = {}; int max_work_group_sizes[GGML_SYCL_MAX_DEVICES] = {0}; + + bool ext_oneapi_level_zero = true; // sycl::backend::ext_oneapi_level_zero used by all enumerated GPU devices }; const ggml_sycl_device_info & ggml_sycl_info(); +static constexpr size_t SYCL_BUFFER_ALIGNMENT = 128; + struct ggml_sycl_pool { virtual ~ggml_sycl_pool() = default; @@ -286,6 +316,10 @@ struct ggml_tensor_extra_gpu { optimize_feature optimized_feature; }; +extern int g_ggml_sycl_enable_level_zero; +void * ggml_sycl_malloc_device(size_t size, sycl::queue &q); +void ggml_sycl_free_device(void *ptr, sycl::queue &q); + void release_extra_gpu(ggml_tensor_extra_gpu * extra, std::vector<queue_ptr> streams={}); namespace sycl_ex = sycl::ext::oneapi::experimental; @@ -381,12 +415,16 @@ struct ggml_backend_sycl_context { std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES]; std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map; + std::unique_ptr<ggml_sycl_fattn_kv_buffers> fattn_bufs[GGML_SYCL_MAX_DEVICES]; + std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES]; static std::unique_ptr<ggml_sycl_pool> new_pool_for_device(queue_ptr qptr, int device); static std::unique_ptr<ggml_sycl_pool> new_pool_for_host(queue_ptr qptr, int device); + static std::unique_ptr<ggml_sycl_fattn_kv_buffers> new_fattn_kv_buffers(queue_ptr qptr, int device); + ggml_sycl_pool & pool(int device) { if (pools[device] == nullptr) { pools[device] = new_pool_for_device(stream(device,0), device); @@ -398,6 +436,17 @@ struct ggml_backend_sycl_context { return pool(device); } + ggml_sycl_fattn_kv_buffers & fattn_buffers(int device) { + if (fattn_bufs[device] == nullptr) { + fattn_bufs[device] = new_fattn_kv_buffers(stream(device, 0), device); + } + return *fattn_bufs[device]; + } + + ggml_sycl_fattn_kv_buffers & fattn_buffers() { + return fattn_buffers(device); + } + #ifdef GGML_SYCL_GRAPH std::unique_ptr<sycl_ex::command_graph<sycl_ex::graph_state::executable>> exec_graph = nullptr; #endif @@ -435,13 +484,15 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) { return a; } -template <int width = WARP_SIZE> +/* use WARP_SIZE or WARP_32_SIZE*/ +template <int width> static __dpct_inline__ int warp_reduce_sum(int x) { return sycl::reduce_over_group( sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>()); } -template <int width = WARP_SIZE> +/* use WARP_SIZE or WARP_32_SIZE*/ +template <int width> static __dpct_inline__ float warp_reduce_sum(float x) { #pragma unroll for (int offset = width / 2; offset > 0; offset >>= 1) { @@ -451,7 +502,19 @@ static __dpct_inline__ float warp_reduce_sum(float x) { return x; } -template <int width = WARP_SIZE> +/* use WARP_SIZE or WARP_32_SIZE*/ +template <int width> +static __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x += dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), x, offset); + } + return x; +} + +/* use WARP_SIZE or WARP_32_SIZE*/ +template <int width> static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) { #pragma unroll for (int offset = width / 2; offset > 0; offset >>= 1) { @@ -465,7 +528,8 @@ static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) { return a; } -template <int width = WARP_SIZE> +/* use WARP_SIZE or WARP_32_SIZE*/ +template <int width> static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) { #pragma unroll for (int offset = width / 2; offset > 0; offset >>= 1) { @@ -481,7 +545,52 @@ static constexpr int ggml_sycl_get_physical_warp_size() { return WARP_SIZE; } -template <int width = WARP_SIZE> +/* use WARP_SIZE or WARP_32_SIZE*/ +template <int width> +static __dpct_inline__ int warp_reduce_all(int x) { + if (width == ggml_sycl_get_physical_warp_size()) { + return sycl::all_of_group( + sycl::ext::oneapi::this_work_item::get_sub_group(), + (~0xffffffff & + (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group() + .get_local_linear_id())) || + x); + } else { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x = dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, + offset, width) && + x; + } + return x; + } +} + +/* use WARP_SIZE or WARP_32_SIZE*/ +template <int width> +static __dpct_inline__ int warp_reduce_any(int x) { + if (width == ggml_sycl_get_physical_warp_size()) { + return sycl::any_of_group( + sycl::ext::oneapi::this_work_item::get_sub_group(), + (0xffffffff & + (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group() + .get_local_linear_id())) && + x); + } else { +#pragma unroll + for (int offset = width / 2; offset > 0; offset >>= 1) { + x = dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), x, + offset, width) || + x; + } + return x; + } +} + +/* use WARP_SIZE or WARP_32_SIZE*/ +template <int width> static __dpct_inline__ float warp_reduce_max(float x) { #pragma unroll for (int offset = width / 2; offset > 0; offset >>= 1) { @@ -629,6 +738,42 @@ static const sycl::uint3 init_fastdiv_values(uint32_t d) { return sycl::uint3(mp, L, d); } +// Maximum number of bytes that can be copied in a single instruction. +// Set by test result. +static constexpr int ggml_sycl_get_max_cpy_bytes() { + return 16; +} + +// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes. +template <int nbytes, int alignment = 0> +static __dpct_inline__ void ggml_sycl_memcpy_1(void * dst, const void * src) { + if constexpr (alignment != 0) { + static_assert(nbytes % alignment == 0, "bad alignment"); + } + constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment; + +#pragma unroll + for (int i = 0; i < nbytes/nb_per_cpy; ++i) { + if constexpr (nb_per_cpy == 1) { + ((char *) dst)[i] = ((const char *) src)[i]; + } else if constexpr (nb_per_cpy == 2) { + ((short *) dst)[i] = ((const short *) src)[i]; + } else if constexpr (nb_per_cpy == 4) { + ((int *) dst)[i] = ((const int *) src)[i]; + } else if constexpr (nb_per_cpy == 8) { + ((sycl::int2 *) dst)[i] = ((const sycl::int2 *) src)[i]; + } else if constexpr (nb_per_cpy == 16) { + ((sycl::int4 *) dst)[i] = ((const sycl::int4 *) src)[i]; + } else { + static_assert(nbytes == 0 && nbytes == -1, "bad nbytes"); + } + } +} +template <typename T> +sycl::half2 __dpct_inline__ make_half2( T x, T y) { + sycl::half2 res(static_cast<sycl::half>(x),static_cast<sycl::half>(y)); + return res; +} static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) { const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x()); @@ -636,6 +781,17 @@ static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_va } +template <typename T> +sycl::float2 __dpct_inline__ make_float2( T x, T y) { + sycl::float2 res(static_cast<float>(x),static_cast<float>(y)); + return res; +} + +sycl::float2 __dpct_inline__ __half22float2(sycl::half2 &H) { + sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y())); + return float2_value; +} + static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) { const uint32_t div_val = fastdiv(n, fastdiv_values); const uint32_t mod_val = n - div_val * fastdiv_values.z(); @@ -659,5 +815,194 @@ static __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) { return result; } +sycl::float2 __dpct_inline__ __half22float2(const sycl::half2 &H) { + sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y())); + return float2_value; +} + +float __dpct_inline__ __half2float(sycl::half H) { + return static_cast<float>(H); +} + +static __dpct_inline__ void ggml_sycl_mad(float & acc, const float v, const float u) { + acc += v*u; +} + +static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::float2 v, const sycl::float2 u) { + acc += v.x() * u.x(); + acc += v.y() * u.y(); +} + +static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::half2 v, const sycl::half2 u) { +#ifdef GGML_SYCL_F16 + const sycl::float2 tmp = (v * u).template convert<float, sycl::rounding_mode::automatic>(); + acc += tmp.x() + tmp.y(); +#else + const sycl::float2 tmpv = __half22float2(v); + const sycl::float2 tmpu = __half22float2(u); + acc += tmpv.x() * tmpu.x(); + acc += tmpv.y() * tmpu.y(); +#endif // GGML_SYCL_F16 +} + +static __dpct_inline__ void ggml_sycl_mad(sycl::half2 & acc, const sycl::half2 v, const sycl::half2 u) { +#ifdef GGML_SYCL_F16 + acc += v*u; +#else + const sycl::float2 tmpv = __half22float2(v); + const sycl::float2 tmpu = __half22float2(u); + sycl::float2 tmpacc = __half22float2(acc); + // tmpacc.x += tmpv.x() * tmpu.x(); + // tmpacc.y += tmpv.y() * tmpu.y(); + sycl::float2 tmp1(tmpacc.x() + tmpv.x() * tmpu.x(), tmpacc.y() + tmpv.y() * tmpu.y()); + acc = make_half2(tmp1.x(), tmp1.y()); +#endif // GGML_SYCL_F16 +} + +template <int n> +struct ggml_sycl_unroll { + template <typename Func, typename... Args> + void operator()(const Func & f, Args... args) const { + f(n - 1, args...); + ggml_sycl_unroll<n - 1>{}(f, args...); + } +}; + +template <> +struct ggml_sycl_unroll<1> { + template <typename Func, typename... Args> + void operator()(const Func & f, Args... args) const { + f(0, args...); + } +}; + +static __dpct_inline__ sycl::half2 ggml_sycl_hmax2(const sycl::half2 a, const sycl::half2 b) { + sycl::half2 ret; + reinterpret_cast<sycl::half &>(ret.x()) = + sycl::vec<float, 1>(sycl::fmax(a[0], b[0])).convert<sycl::half, sycl::rounding_mode::automatic>()[0]; + reinterpret_cast<sycl::half &>(ret.y()) = + sycl::vec<float, 1>(sycl::fmax(a[1], b[1])).convert<sycl::half, sycl::rounding_mode::automatic>()[0]; + return ret; +} + +static __dpct_inline__ sycl::half ggml_sycl_hmax(const sycl::half a, const sycl::half b) { + return sycl::vec<float, 1>( + sycl::fmax(sycl::vec<sycl::half, 1>(a).convert<float, sycl::rounding_mode::automatic>()[0], + sycl::vec<sycl::half, 1>(b).convert<float, sycl::rounding_mode::automatic>()[0])) + .convert<sycl::half, sycl::rounding_mode::automatic>()[0]; +} + +static __dpct_inline__ uint32_t __hgt2_mask(const sycl::half2 a, const sycl::half2 b) { + const uint32_t mask_low = 0x0000FFFF * (float(a[0]) > float(b[0])); + const uint32_t mask_high = 0xFFFF0000 * (float(a[1]) > float(b[1])); + return mask_low | mask_high; +} + +static __dpct_inline__ uint32_t fastmodulo(uint32_t n, const sycl::uint3 fastdiv_values) { + // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values) + return n - fastdiv(n, fastdiv_values) * fastdiv_values.z(); +} + +static bool fast_fp16_available(const int cc) { + GGML_UNUSED(cc); + return true; //Intel GPUs always support FP16. +} + +enum class block_reduce_method { + MAX, + SUM, +}; + +template<block_reduce_method method_t, typename T, int warp_size> +struct block_reduce_policy; + +template <typename T, typename... Ts> +inline constexpr bool is_any = (std::is_same_v<T, Ts> || ...); + +template<typename...> +inline constexpr bool ggml_sycl_dependent_false_v = false; + +#define WARP_32_SIZE 32 + +template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::SUM, T, warp_size> { + static T reduce(T val) { + if constexpr (is_any<T, float, sycl::float2, sycl::half2, int>) { + return warp_reduce_sum<warp_size>(val); + } else { + static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum"); + } + } + + static T sentinel() { + if constexpr (std::is_same_v<T, float>) { + return 0.0f; + } else if constexpr (std::is_same_v<T, sycl::float2>) { + return sycl::float2(0.0f, 0.0f); + } else if constexpr (std::is_same_v<T, sycl::half2>) { + return sycl::half2(0.0f, 0.0f); + } else if constexpr (std::is_same_v<T, int>) { + return 0; + } else { + static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce sum"); + } + } +}; + +template <typename T, int warp_size> struct block_reduce_policy<block_reduce_method::MAX, T, warp_size> { + static T reduce(T val) { + if constexpr (is_any<T, float, sycl::half2>) { + return warp_reduce_max<warp_size>(val); + } else { + static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max"); + } + } + + static T sentinel() { + if constexpr (std::is_same_v<T, float>) { + return -INFINITY; + } else if constexpr (std::is_same_v<T, sycl::half2>) { + return sycl::half2(-INFINITY, -INFINITY); + } else { + static_assert(ggml_sycl_dependent_false_v<T>, "Unsupported type for block reduce max"); + } + } +}; + + +template <block_reduce_method reduce_method_t, int warp_size, typename T> +static T block_reduce(T val, T * shared_vals, int block_size_template) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + val = block_reduce_policy<reduce_method_t, T,warp_size>::reduce(val); + const int block_size = block_size_template == 0 ? item_ct1.get_local_range(2) : block_size_template; + const int nthreads = item_ct1.get_local_range(2); + const int nwarps = nthreads / WARP_SIZE; + + if (block_size > warp_size) { + assert((block_size <= 1024) && (block_size % warp_size) == 0); + const int warp_id = item_ct1.get_local_id(2) / warp_size; + const int lane_id = item_ct1.get_local_id(2) % warp_size; + if (lane_id == 0) { + shared_vals[warp_id] = val; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + size_t nreduce = nwarps / WARP_SIZE; + float tmp = 0.f; + if (lane_id < (static_cast<int>(block_size) / warp_size)) { + for (size_t i = 0; i < nreduce; i += 1) + { + tmp += shared_vals[lane_id + i * WARP_SIZE]; + } + } + return block_reduce_policy<reduce_method_t, T, warp_size>::reduce(tmp); + } + return val; +} + +static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) { + const uint32_t bits = x * (x != 0x7F && x != 0xFF); + const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits); + return static_cast<float>(xf) / 2; +} #endif // GGML_SYCL_COMMON_HPP diff --git a/ggml/src/ggml-sycl/convert.cpp b/ggml/src/ggml-sycl/convert.cpp index 8bdae36458c..65593402e7d 100644 --- a/ggml/src/ggml-sycl/convert.cpp +++ b/ggml/src/ggml-sycl/convert.cpp @@ -2,13 +2,6 @@ #include "dequantize.hpp" #include "presets.hpp" -#if defined(__INTEL_LLVM_COMPILER) - #if __has_include(<sycl/ext/oneapi/bfloat16.hpp>) - #include <sycl/ext/oneapi/bfloat16.hpp> - #define GGML_SYCL_HAS_BF16 - #endif -#endif - template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, const sycl::nd_item<3> &item_ct1) { @@ -114,6 +107,19 @@ static void dequantize_row_q3_K_sycl(const void *vx, dst_t *y, const int64_t k, #endif } +template <typename dst_t> +static void dequantize_row_q3_K_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q3_K_reorder(vx, y, item_ct1, nb); + }); +} + template <typename dst_t> static void dequantize_row_q4_0_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -151,6 +157,25 @@ static void dequantize_row_q4_0_sycl_reorder(const void *vx, dst_t *y, const int } +template <typename dst_t> +static void dequantize_row_q8_0_sycl_reorder(const void *vx, dst_t *y, const int64_t k, + dpct::queue_ptr stream) { + + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + int constexpr WARP_K = WARP_SIZE * QK8_0; + const int n_warp = (k + WARP_K - 1) / WARP_K; + GGML_ASSERT(k % QK8_0 == 0); + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, n_warp) * + sycl::range<3>(1, 1, WARP_SIZE), + sycl::range<3>(1, 1, WARP_SIZE)), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]]{ + dequantize_block_q8_0_reorder(vx, y, k, item_ct1); + }); + +} + template <typename dst_t> static void dequantize_row_q4_1_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -240,6 +265,23 @@ static void dequantize_row_q5_K_sycl(const void *vx, dst_t *y, const int64_t k, #endif } +template <typename dst_t> +static void dequantize_row_q5_K_sycl_reorder(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + const int64_t nb = k / QK_K; + + dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor<uint8_t, 1> scale_local_acc(sycl::range<1>(K_SCALE_SIZE), cgh); + + cgh.parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 64), sycl::range<3>(1, 1, 64)), + [=](sycl::nd_item<3> item_ct1) { + dequantize_block_q5_K_reorder(vx, y, get_pointer(scale_local_acc), item_ct1, nb); + }); + }); +} + template <typename dst_t> static void dequantize_row_q6_K_sycl(const void *vx, dst_t *y, const int64_t k, dpct::queue_ptr stream) { @@ -482,6 +524,75 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t }); } +template <typename dst_t> +static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) { + GGML_ASSERT(k % QK_NVFP4 == 0); + const int nb = k / QK_NVFP4; + stream->parallel_for( + sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)), + [=](sycl::nd_item<3> /*item_ct1*/) { + dequantize_block_nvfp4(vx, y, k); + }); +} + + +template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> +static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t s01, const int64_t s02, const int64_t s03) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i00 = 2 * (int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2)); + + if (i00 >= ne00) { + return; + } + + const int64_t i01 = item_ct1.get_group(1); + const int64_t i02 = item_ct1.get_group(0) % ne02; + const int64_t i03 = item_ct1.get_group(0) / ne02; + + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + #ifdef GGML_SYCL_F16 + sycl::half2 v; + #else + sycl::float2 v; + #endif + + dequantize_kernel(vx, ib, iqs, v); + + const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_sycl_cast<dst_t>(v.x()); + y[iy0 + y_offset] = ggml_sycl_cast<dst_t>(v.y()); +} + + +template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> +static void dequantize_block_nc_sycl(const void * vx, + dst_t * y, + const int64_t ne00, + const int64_t ne01, + const int64_t ne02, + const int64_t ne03, + const int64_t s01, + const int64_t s02, + const int64_t s03, + dpct::queue_ptr stream) { + const dpct::dim3 num_blocks((ne00 + 2 * SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * SYCL_DEQUANTIZE_BLOCK_SIZE), ne01, + ne02 * ne03); + stream->parallel_for(sycl::nd_range<3>(num_blocks * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + dequantize_block_nc<qk, qr, dequantize_kernel>(vx, y, ne00, ne01, ne02, s01, s02, s03); + }); +} template <typename src_t, typename dst_t> static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03, @@ -545,11 +656,20 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { case GGML_TYPE_Q5_1: return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>; case GGML_TYPE_Q8_0: - return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q8_0_sycl_reorder; + } else { + return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>; + } case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q3_K_sycl_reorder; + } else { + return dequantize_row_q3_K_sycl; + } case GGML_TYPE_Q4_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { return dequantize_row_q4_K_sycl_reorder; @@ -557,7 +677,11 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_q4_K_sycl; } case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q5_K_sycl_reorder; + } else { + return dequantize_row_q5_K_sycl; + } case GGML_TYPE_Q6_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { return dequantize_row_q6_K_sycl_reorder; @@ -584,6 +708,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return dequantize_row_iq4_nl_sycl; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_sycl; case GGML_TYPE_F32: return convert_unary_sycl<float>; #ifdef GGML_SYCL_HAS_BF16 @@ -591,6 +717,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) { return convert_unary_sycl<sycl::ext::oneapi::bfloat16>; #endif default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); return nullptr; } } @@ -611,11 +738,20 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { case GGML_TYPE_Q5_1: return dequantize_block_sycl<QK5_1, QR5_1, dequantize_q5_1>; case GGML_TYPE_Q8_0: - return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>; + if (dst->src[0]->extra && + ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q8_0_sycl_reorder; + } else { + return dequantize_block_sycl<QK8_0, QR8_0, dequantize_q8_0>; + } case GGML_TYPE_Q2_K: return dequantize_row_q2_K_sycl; case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q3_K_sycl_reorder; + } else { + return dequantize_row_q3_K_sycl; + } case GGML_TYPE_Q4_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu*)dst->src[0]->extra)->optimized_feature.reorder) { @@ -624,7 +760,11 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_q4_K_sycl; } case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_sycl; + if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + return dequantize_row_q5_K_sycl_reorder; + } else { + return dequantize_row_q5_K_sycl; + } case GGML_TYPE_Q6_K: if (dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { return dequantize_row_q6_K_sycl_reorder; @@ -651,6 +791,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return dequantize_row_iq4_nl_sycl; case GGML_TYPE_MXFP4: return dequantize_row_mxfp4_sycl; + case GGML_TYPE_NVFP4: + return dequantize_row_nvfp4_sycl; case GGML_TYPE_F16: return convert_unary_sycl<sycl::half>; #ifdef GGML_SYCL_HAS_BF16 @@ -658,11 +800,29 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) { return convert_unary_sycl<sycl::ext::oneapi::bfloat16>; #endif default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); return nullptr; } } -to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) { + +#ifdef GGML_SYCL_HAS_BF16 +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * /*dst*/) { + switch (type) { + case GGML_TYPE_F32: + return convert_unary_sycl<float>; + case GGML_TYPE_F16: + return convert_unary_sycl<sycl::half>; + case GGML_TYPE_BF16: + return convert_unary_sycl<sycl::ext::oneapi::bfloat16>; + default: + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type)); + return nullptr; + } +} +#endif + +to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_nc_sycl<float>; @@ -670,6 +830,16 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) { case GGML_TYPE_BF16: return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>; #endif + case GGML_TYPE_Q4_0: + return dequantize_block_nc_sycl<QK4_0, QR4_0, dequantize_q4_0>; + case GGML_TYPE_Q4_1: + return dequantize_block_nc_sycl<QK4_1, QR4_1, dequantize_q4_1>; + case GGML_TYPE_Q5_0: + return dequantize_block_nc_sycl<QK5_0, QR5_0, dequantize_q5_0>; + case GGML_TYPE_Q5_1: + return dequantize_block_nc_sycl<QK5_1, QR5_1, dequantize_q5_1>; + case GGML_TYPE_Q8_0: + return dequantize_block_nc_sycl<QK8_0, QR8_0, dequantize_q8_0>; default: return nullptr; } diff --git a/ggml/src/ggml-sycl/convert.hpp b/ggml/src/ggml-sycl/convert.hpp index f8cb573e368..8de79d10ff6 100644 --- a/ggml/src/ggml-sycl/convert.hpp +++ b/ggml/src/ggml-sycl/convert.hpp @@ -23,12 +23,42 @@ typedef to_t_sycl_t<sycl::half> to_fp16_sycl_t; to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst); to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor * dst); +#ifdef GGML_SYCL_HAS_BF16 +typedef to_t_sycl_t<sycl::ext::oneapi::bfloat16> to_bf16_sycl_t; +to_bf16_sycl_t ggml_get_to_bf16_sycl(ggml_type type, ggml_tensor * dst); +#endif + // Nc = Non-contiguous template <typename T> using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03, int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue); typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t; -to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type); +to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type); + +template<typename dst_t, typename src_t> + inline dst_t ggml_sycl_cast(src_t x) { + if constexpr (std::is_same_v<dst_t, src_t>) { + return x; +#ifdef GGML_SYCL_HAS_BF16 + } else if constexpr (std::is_same_v<dst_t, sycl::ext::oneapi::bfloat16>) { + return sycl::ext::oneapi::bfloat16(float(x)); + } else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) { + return static_cast<float>(x); +#endif + } else if constexpr (std::is_same_v<src_t, sycl::float2> && std::is_same_v<dst_t, sycl::half2>) { + return x.template convert<sycl::half, sycl::rounding_mode::rte>(); +#ifdef GGML_SYCL_HAS_BF16 + } else if constexpr (std::is_same_v<src_t, sycl::float2> && + std::is_same_v<dst_t, sycl::vec<sycl::ext::oneapi::bfloat16, 2>>) { + return {x.x, x.y}; +#endif + } else if constexpr(std::is_same_v<dst_t, int32_t>) { + return int32_t(x); + } else { + return float(x); + } +} + #endif // GGML_SYCL_CONVERT_HPP diff --git a/ggml/src/ggml-sycl/count-equal.cpp b/ggml/src/ggml-sycl/count-equal.cpp index b0a8b4820de..4580354cd9d 100644 --- a/ggml/src/ggml-sycl/count-equal.cpp +++ b/ggml/src/ggml-sycl/count-equal.cpp @@ -18,7 +18,7 @@ static void count_equal(const T *__restrict__ x, const T *__restrict__ y, nequal += xi == yi; } - nequal = warp_reduce_sum(nequal); + nequal = warp_reduce_sum<WARP_SIZE>(nequal); if (item_ct1.get_local_id(2) != 0) { return; diff --git a/ggml/src/ggml-sycl/cumsum.cpp b/ggml/src/ggml-sycl/cumsum.cpp new file mode 100644 index 00000000000..c1c5fe4fe4a --- /dev/null +++ b/ggml/src/ggml-sycl/cumsum.cpp @@ -0,0 +1,148 @@ +#include "cumsum.hpp" +#include "common.hpp" + +#include <algorithm> + +#define SYCL_CUMSUM_BLOCK_SIZE 256 + +static __dpct_inline__ float warp_prefix_inclusive_sum_f32(float x, const sycl::nd_item<3> & item) { + return sycl::inclusive_scan_over_group(item.get_sub_group(), x, sycl::plus<float>()); +} + +static void cumsum_f32_kernel( + const float * __restrict__ src, float * __restrict__ dst, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, + const int64_t d1, const int64_t d2, const int64_t d3, + const sycl::nd_item<3> & item, float * smem) { + + const int tid = item.get_local_id(2); + const int block_size = item.get_local_range(2); + const int lane = tid % WARP_SIZE; + const int warp = tid / WARP_SIZE; + const int warps_per_block = block_size / WARP_SIZE; + + float * s_vals = smem; + float * s_warp_sums = smem + block_size; + float * s_carry = smem + block_size + warps_per_block; + + if (tid == 0) { + s_carry[0] = 0.0f; + } + item.barrier(sycl::access::fence_space::local_space); + + const int64_t i3 = item.get_group(0); + const int64_t i2 = item.get_group(1); + const int64_t i1 = item.get_group(2); + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + const float * src_row = src + i1 * s01 + i2 * s02 + i3 * s03; + float * dst_row = dst + i1 * d1 + i2 * d2 + i3 * d3; + + constexpr int num_unroll = 4; + float temp[num_unroll]; + + for (int64_t i = 0; i < ne00; i += num_unroll * block_size) { + int64_t idx = i + tid * num_unroll; + + temp[0] = (idx < ne00 ? src_row[idx] : 0.0f); +#pragma unroll + for (int j = 1; j < num_unroll; j++) { + temp[j] = temp[j - 1]; + if (idx + j < ne00) { + temp[j] += src_row[idx + j]; + } + } + + float val = (idx < ne00) ? temp[num_unroll - 1] : 0.0f; + + val = warp_prefix_inclusive_sum_f32(val, item); + s_vals[tid] = val; + + if (lane == WARP_SIZE - 1) { + s_warp_sums[warp] = val; + } + item.barrier(sycl::access::fence_space::local_space); + + if (warp == 0) { + float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f; + float inc = warp_prefix_inclusive_sum_f32(w, item); + if (tid < warps_per_block) { + s_warp_sums[tid] = inc - w; + } + if (tid == warps_per_block - 1) { + s_carry[1] = inc; + } + } + item.barrier(sycl::access::fence_space::local_space); + + float carry = s_carry[0]; + float final_offset = s_vals[tid] + s_warp_sums[warp] + carry - temp[num_unroll - 1]; + +#pragma unroll + for (int j = 0; j < num_unroll; j++) { + if (idx + j < ne00) { + dst_row[idx + j] = temp[j] + final_offset; + } + } + + item.barrier(sycl::access::fence_space::local_space); + + if (tid == 0) { + s_carry[0] += s_carry[1]; + } + } +} + +inline void ggml_sycl_op_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src_d = static_cast<const float *>(src0->data); + float * dst_d = static_cast<float *>(dst->data); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const size_t ts = sizeof(float); + const int64_t s01 = src0->nb[1] / ts; + const int64_t s02 = src0->nb[2] / ts; + const int64_t s03 = src0->nb[3] / ts; + const int64_t d1 = dst->nb[1] / ts; + const int64_t d2 = dst->nb[2] / ts; + const int64_t d3 = dst->nb[3] / ts; + + const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE; + int block_size = num_warps * WARP_SIZE; + block_size = std::min(block_size, SYCL_CUMSUM_BLOCK_SIZE); + const int warps_per_block = block_size / WARP_SIZE; + const int smem_size = block_size + warps_per_block + 2; + + const sycl::range<3> grid(ne03, ne02, ne01); + const sycl::range<3> block(1, 1, block_size); + + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh); + cgh.parallel_for( + sycl::nd_range<3>(grid * block, block), + [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + cumsum_f32_kernel(src_d, dst_d, ne00, ne01, ne02, ne03, + s01, s02, s03, d1, d2, d3, + item, get_pointer(smem_acc)); + }); + }); +} + +void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_cumsum(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/cumsum.hpp b/ggml/src/ggml-sycl/cumsum.hpp new file mode 100644 index 00000000000..f1a564472c5 --- /dev/null +++ b/ggml/src/ggml-sycl/cumsum.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_cumsum(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/dequantize.hpp b/ggml/src/ggml-sycl/dequantize.hpp index da2a605daa8..ca8cd96c08c 100644 --- a/ggml/src/ggml-sycl/dequantize.hpp +++ b/ggml/src/ggml-sycl/dequantize.hpp @@ -14,11 +14,16 @@ #define GGML_SYCL_DEQUANTIZE_HPP #include "common.hpp" +#include "convert.hpp" typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v); typedef void (*dequantize_kernel_t_reorder)(const void *d, const int64_t ib, const void *qs, const int iqs, dfloat2 &v); +#if QK_K == 256 +static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m); +#endif + static __dpct_inline__ void dequantize_q4_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q4_0 * x = (const block_q4_0 *) vx; @@ -89,6 +94,474 @@ static __dpct_inline__ void dequantize_q4_1(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } +static __dpct_inline__ void dequantize_q4_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q4_K * x = (const block_q4_K *) vx; + const sycl::half2 dm = x[ib].dm; + const float dall = dm[0]; + const float dmin = dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int il = idx / 64; + const int in = idx % 64; + const int is = 2 * il + (in >= 32 ? 1 : 0); + const int off = in & 31; + const int qsi = 32 * il + off; + + uint8_t sc; + uint8_t m; + get_scale_min_k4(is, x[ib].scales, sc, m); + + const uint8_t q = x[ib].qs[qsi]; + const uint8_t qv = (in >= 32) ? (q >> 4) : (q & 0xF); + return sycl::fma((dfloat) qv, (dfloat) (dall * sc), (dfloat) (-dmin * m)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q4_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q2_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q2_K * x = (const block_q2_K *) vx; + const float dall = x[ib].dm[0]; + const float dmin = x[ib].dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int n = idx / 128; + const int r = idx % 128; + const int g = r / 32; + const int l = r % 32; + const int is = 8 * n + l / 16; + + const uint8_t q = x[ib].qs[32 * n + l]; + const uint8_t sc = x[ib].scales[is + 2 * g]; + const float d = dall * (sc & 0xF); + const float m = dmin * (sc >> 4); + + return sycl::fma((dfloat) ((q >> (2 * g)) & 3), (dfloat) d, (dfloat) (-m)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q2_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q3_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q3_K * x = (const block_q3_K *) vx; + const float d_all = x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int n = idx / 128; + const int r = idx % 128; + const int j = r / 32; + const int l = r % 32; + + const int is0 = l / 16; + const int is = 8 * n + 2 * j + is0; + const int shift = 2 * j; + const uint8_t m = 1 << (4 * n + j); + + const int8_t us = is < 4 ? (x[ib].scales[is - 0] & 0xF) | (((x[ib].scales[is + 8] >> 0) & 3) << 4) : + is < 8 ? (x[ib].scales[is - 0] & 0xF) | (((x[ib].scales[is + 4] >> 2) & 3) << 4) : + is < 12 ? (x[ib].scales[is - 8] >> 4) | (((x[ib].scales[is + 0] >> 4) & 3) << 4) : + (x[ib].scales[is - 8] >> 4) | (((x[ib].scales[is - 4] >> 6) & 3) << 4); + + const float dl = d_all * (us - 32); + const uint8_t q = x[ib].qs[32 * n + l]; + const uint8_t h = x[ib].hmask[l]; + const int8_t qv = ((q >> shift) & 3) - ((h & m) ? 0 : 4); + + return (dfloat) (dl * qv); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q3_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q5_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q5_K * x = (const block_q5_K *) vx; + const float dall = x[ib].dm[0]; + const float dmin = x[ib].dm[1]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int il = idx / 64; + const int in = idx % 64; + const int is = 2 * il + (in >= 32 ? 1 : 0); + const int ir = (in & 31) / 2; + const int iq = in & 1; + + const uint8_t q = x[ib].qs[32 * il + 2 * ir + iq]; + const uint8_t h = x[ib].qh[2 * ir + iq]; + const uint8_t qv = (in >= 32) ? (q >> 4) : (q & 0xF); + + uint8_t sc; + uint8_t m; + get_scale_min_k4(is, x[ib].scales, sc, m); + + const float d = dall * sc; + const float mn = dmin * m; + const uint8_t hm = 1 << (2 * il + (in >= 32 ? 1 : 0)); + + return sycl::fma((dfloat) (qv + ((h & hm) ? 16 : 0)), (dfloat) d, (dfloat) (-mn)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q5_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_q6_K(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_q6_K * x = (const block_q6_K *) vx; + const float d = x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ip = idx / 128; + const int in = idx % 128; + const int il = in & 31; + const int ig = in / 32; + const int is = 8 * ip + il / 16; + + const uint8_t ql0 = x[ib].ql[64 * ip + il]; + const uint8_t ql1 = x[ib].ql[64 * ip + il + 32]; + const uint8_t qh = x[ib].qh[32 * ip + il]; + const int8_t * sc = x[ib].scales + is; + + uint8_t qv; + int8_t scale; + if (ig == 0) { + qv = (ql0 & 0xF) | (((qh >> 0) & 3) << 4); + scale = sc[0]; + } else if (ig == 1) { + qv = (ql1 & 0xF) | (((qh >> 2) & 3) << 4); + scale = sc[2]; + } else if (ig == 2) { + qv = (ql0 >> 4) | (((qh >> 4) & 3) << 4); + scale = sc[4]; + } else { + qv = (ql1 >> 4) | (((qh >> 6) & 3) << 4); + scale = sc[6]; + } + + return (dfloat) (d * scale * ((int8_t) qv - 32)); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("Q6_K dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_mxfp4(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_mxfp4 * x = (const block_mxfp4 *) vx; + const float d = ggml_sycl_e8m0_to_fp32(x[ib].e); + const uint8_t q = x[ib].qs[iqs]; + + v.x() = d * kvalues_mxfp4[q & 0xF] * 0.5f; + v.y() = d * kvalues_mxfp4[q >> 4] * 0.5f; +} + +static __dpct_inline__ void dequantize_q1_0(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_q1_0 * x = (const block_q1_0 *) vx; + const dfloat d = x[ib].d; + + const int bit_index_0 = iqs + 0; + const int bit_index_1 = iqs + 1; + + const int bit_0 = (x[ib].qs[bit_index_0 / 8] >> (bit_index_0 % 8)) & 1; + const int bit_1 = (x[ib].qs[bit_index_1 / 8] >> (bit_index_1 % 8)) & 1; + + v.x() = (2 * bit_0 - 1) * d; + v.y() = (2 * bit_1 - 1) * d; +} + +static __dpct_inline__ void dequantize_nvfp4(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_nvfp4 & xb = ((const block_nvfp4 *) vx)[ib]; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int sub = idx / QK_NVFP4_SUB; + const int j = idx % QK_NVFP4_SUB; + const int jh = j % (QK_NVFP4_SUB / 2); + + const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + jh]; + const uint8_t qv = (j < (QK_NVFP4_SUB / 2)) ? (q & 0x0F) : (q >> 4); + + return d * kvalues_mxfp4[qv]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +} + +static __dpct_inline__ void dequantize_iq2_xxs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_xxs * x = (const block_iq2_xxs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * q2 = x[ib].qs + 4 * ib8; + const uint8_t * aux8 = (const uint8_t *) q2; + const uint8_t * grid = (const uint8_t *) (iq2xxs_grid + aux8[il]); + const uint32_t aux32 = q2[2] | (q2[3] << 16); + const float d = (float) x[ib].d * (0.5f + (aux32 >> 28)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> (7 * il)) & 127]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_XXS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq2_xs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_xs * x = (const block_iq2_xs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * q2 = x[ib].qs + 4 * ib8; + const uint8_t * grid = (const uint8_t *) (iq2xs_grid + (q2[il] & 511)); + const float d = (float) x[ib].d * (0.5f + ((x[ib].scales[ib8] >> (4 * (il / 2))) & 0xf)) * 0.25f; + const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_XS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq2_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq2_s * x = (const block_iq2_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | ((x[ib].qh[ib8] << (8 - 2 * il)) & 0x300); + const uint8_t * grid = (const uint8_t *) (iq2s_grid + grid_id); + const float d = (float) x[ib].d * (0.5f + ((x[ib].scales[ib8] >> (4 * (il / 2))) & 0xf)) * 0.25f; + const uint8_t signs = x[ib].qs[QK_K / 8 + 4 * ib8 + il]; + + return d * grid[j] * ((signs & kmask_iq2xs[j]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ2_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq3_xxs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq3_xxs * x = (const block_iq3_xxs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint8_t * q3 = x[ib].qs + 8 * ib8; + const uint16_t * gas = (const uint16_t *) (x[ib].qs + QK_K / 4) + 2 * ib8; + const uint8_t * grid1 = (const uint8_t *) (iq3xxs_grid + q3[2 * il + 0]); + const uint8_t * grid2 = (const uint8_t *) (iq3xxs_grid + q3[2 * il + 1]); + const uint32_t aux32 = gas[0] | (gas[1] << 16); + const float d = (float) x[ib].d * (0.5f + (aux32 >> 28)) * 0.5f; + const uint8_t signs = ksigns_iq2xs[(aux32 >> (7 * il)) & 127]; + + if (j < 4) { + return d * grid1[j] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + } + return d * grid2[j - 4] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ3_XXS dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq3_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq3_s * x = (const block_iq3_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint8_t * qs = x[ib].qs + 8 * ib8; + const uint16_t grid1_id = qs[2 * il + 0] | ((x[ib].qh[ib8] << (8 - 2 * il)) & 256); + const uint16_t grid2_id = qs[2 * il + 1] | ((x[ib].qh[ib8] << (7 - 2 * il)) & 256); + const uint8_t * grid1 = (const uint8_t *) (iq3s_grid + grid1_id); + const uint8_t * grid2 = (const uint8_t *) (iq3s_grid + grid2_id); + const float d = (float) x[ib].d * (1 + 2 * ((x[ib].scales[ib8 / 2] >> (4 * (ib8 % 2))) & 0xf)); + const uint8_t signs = x[ib].signs[4 * ib8 + il]; + + if (j < 4) { + return d * grid1[j] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + } + return d * grid2[j - 4] * ((signs & kmask_iq2xs[j + 0]) ? -1.f : 1.f); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ3_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq1_s(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq1_s * x = (const block_iq1_s *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const float delta = (x[ib].qh[ib8] & 0x8000) ? (-1.f - IQ1S_DELTA) : (-1.f + IQ1S_DELTA); + const float d = (float) x[ib].d * (2 * ((x[ib].qh[ib8] >> 12) & 7) + 1); + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | (((x[ib].qh[ib8] >> (3 * il)) & 7) << 8); + const uint32_t g = iq1s_grid_gpu[grid_id]; + const int8_t qv = (j < 4) ? ((g >> (8 * j)) & 0x0F) : ((g >> (8 * (j - 4) + 4)) & 0x0F); + + return d * (qv + delta); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ1_S dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq1_m(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq1_m * x = (const block_iq1_m *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int il = r / 8; + const int j = r % 8; + + const uint16_t * sc = (const uint16_t *) x[ib].scales; + iq1m_scale_t scale; + scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); + + const int ib16 = 2 * ib8 + il / 2; + const float d = (float) scale.f16 * (2 * ((sc[ib16 / 4] >> (3 * (ib16 % 4))) & 0x7) + 1); + + const uint8_t qh = x[ib].qh[2 * ib8 + il / 2]; + const float delta = (qh & (0x08 << (4 * (il % 2)))) ? (-1.f - IQ1M_DELTA) : (-1.f + IQ1M_DELTA); + + const uint16_t grid_id = x[ib].qs[4 * ib8 + il] | (((qh >> (4 * (il % 2))) & 7) << 8); + const uint32_t g = iq1s_grid_gpu[grid_id]; + const int8_t qv = (j < 4) ? ((g >> (8 * j)) & 0x0F) : ((g >> (8 * (j - 4) + 4)) & 0x0F); + + return d * (qv + delta); + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ1_M dequantize not supported for QK_K != 256"); +#endif +} + +static __dpct_inline__ void dequantize_iq4_nl(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { + const block_iq4_nl * x = (const block_iq4_nl *) vx; + const float d = (float) x[ib].d; + + auto dequantize_one = [&](const int idx) -> dfloat { + if (idx < 16) { + return d * kvalues_iq4nl[x[ib].qs[idx] & 0xF]; + } + return d * kvalues_iq4nl[x[ib].qs[idx - 16] >> 4]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +} + +static __dpct_inline__ void dequantize_iq4_xs(const void *vx, const int64_t ib, + const int iqs, dfloat2 &v) { +#if QK_K == 256 + const block_iq4_xs * x = (const block_iq4_xs *) vx; + + auto dequantize_one = [&](const int idx) -> dfloat { + const int ib8 = idx / 32; + const int r = idx % 32; + const int byte_idx = (r < 16) ? r : (r - 16); + const uint8_t q = x[ib].qs[16 * ib8 + byte_idx]; + const uint8_t qv = (r < 16) ? (q & 0x0F) : (q >> 4); + + const float d = (float) x[ib].d * ((((x[ib].scales_l[ib8 / 2] >> (4 * (ib8 % 2))) & 0xf) | + (((x[ib].scales_h >> (2 * ib8)) & 3) << 4)) - 32); + return d * kvalues_iq4nl[qv]; + }; + + v.x() = dequantize_one(iqs + 0); + v.y() = dequantize_one(iqs + 1); +#else + GGML_ABORT("IQ4_XS dequantize not supported for QK_K != 256"); +#endif +} + static __dpct_inline__ void dequantize_q5_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q5_0 * x = (const block_q5_0 *) vx; @@ -143,6 +616,22 @@ static __dpct_inline__ void dequantize_q5_1(const void *vx, const int64_t ib, #endif // GGML_SYCL_F16 } +static __dpct_inline__ void dequantize_q8_0_reorder(const void *d_ptr, const int64_t ib, const void *qs, + const int iqs, dfloat2 &v) { + const dfloat d = (const dfloat)*((const sycl::half*)d_ptr + ib); + + v.x() = ((const int8_t *)qs)[iqs + 0]; + v.y() = ((const int8_t *)qs)[iqs + 1]; + +#ifdef GGML_SYCL_F16 + v.s0() *= d; + v.s1() *= d; +#else + v.x() *= d; + v.y() *= d; +#endif // GGML_SYCL_F16 +} + static __dpct_inline__ void dequantize_q8_0(const void *vx, const int64_t ib, const int iqs, dfloat2 &v) { const block_q8_0 * x = (const block_q8_0 *) vx; @@ -222,6 +711,34 @@ static void dequantize_block_q4_0_reorder(const void * __restrict__ vx, dst_t * } +// Dequantize Q8_0 from reorder layout: [all qs (k bytes)][all d values] +// Each thread handles one block of QK8_0 elements. +template<typename dst_t> +static void dequantize_block_q8_0_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t k, + const sycl::nd_item<3> &item_ct1) { + + const int64_t i = item_ct1.get_group(2); + const int64_t tid = item_ct1.get_local_id(2); + const int lane_ib = i * WARP_SIZE + tid; + + if (lane_ib >= k / QK8_0) { + return; + } + + dst_t * y_ptr = yy + lane_ib * QK8_0; + + auto qs = (const int8_t*)vx + lane_ib * QK8_0; + auto s_ptr = (const sycl::half*)((const uint8_t*)vx + k) + lane_ib; + + const float d = float(*s_ptr); + +#pragma unroll + for (int l = 0; l < QK8_0; ++l) { + y_ptr[l] = d * qs[l]; + } + +} + template<typename dst_t> static void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int64_t nb32, const sycl::nd_item<3> &item_ct1) { @@ -345,6 +862,63 @@ static void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restri } +template<typename dst_t> +static void dequantize_block_q3_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, + const sycl::nd_item<3> & item_ct1, int64_t n_blocks) { +#if QK_K == 256 + const int64_t i = item_ct1.get_group(2); + if (i >= n_blocks) { + return; + } + + const uint8_t * base = static_cast<const uint8_t *>(vx); + const size_t qs_offset = i * (QK_K / 4); + const size_t hmask_offset = n_blocks * (QK_K / 4) + i * (QK_K / 8); + const size_t scales_offset = n_blocks * (QK_K / 4) + n_blocks * (QK_K / 8) + i * 12; + const size_t d_offset = n_blocks * (QK_K / 4) + n_blocks * (QK_K / 8) + n_blocks * 12 + + i * sizeof(ggml_half); + + const uint8_t * qs = base + qs_offset; + const uint8_t * hmask = base + hmask_offset; + const uint8_t * scales = base + scales_offset; + const float d_all = static_cast<float>(*reinterpret_cast<const ggml_half *>(base + d_offset)); + + const int64_t r = item_ct1.get_local_id(2) / 4; + const int64_t tid = r / 2; + const int64_t is0 = r % 2; + const int64_t l0 = 16 * is0 + 4 * (item_ct1.get_local_id(2) % 4); + const int64_t n = tid / 4; + const int64_t j = tid - 4 * n; + const int64_t is = 8 * n + 2 * j + is0; + const int shift = 2 * j; + uint8_t m = 1 << (4 * n + j); + + uint8_t us = is < 4 + ? (scales[is - 0] & 0xF) | (((scales[is + 8] >> 0) & 3) << 4) + : is < 8 + ? (scales[is - 0] & 0xF) | (((scales[is + 4] >> 2) & 3) << 4) + : is < 12 + ? (scales[is - 8] >> 4) | (((scales[is + 0] >> 4) & 3) << 4) + : (scales[is - 8] >> 4) | (((scales[is - 4] >> 6) & 3) << 4); + + const float dl = d_all * (us - 32); + + dst_t * y = yy + i * QK_K + 128 * n + 32 * j; + const uint8_t * q = qs + 32 * n; + const uint8_t * hm = hmask; + + for (int l = l0; l < l0 + 4; ++l) { + y[l] = dl * ((int8_t) ((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4)); + } +#else + GGML_UNUSED(vx); + GGML_UNUSED(yy); + GGML_UNUSED(item_ct1); + GGML_UNUSED(n_blocks); + GGML_ABORT("Q3_K reorder dequantize not supported for QK_K != 256"); +#endif +} + #if QK_K == 256 static inline void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { if (j < 4) { @@ -492,6 +1066,63 @@ static void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restri #endif } +template <typename dst_t> +static void dequantize_block_q5_K_reorder(const void * __restrict__ vx, dst_t * __restrict__ yy, + uint8_t * scales_local, const sycl::nd_item<3> & item_ct1, int64_t n_blocks) { + const int64_t ib = item_ct1.get_group(2); + +#if QK_K == 256 + // assume 64 threads + const int64_t tid = item_ct1.get_local_id(2); + const int64_t il = tid / 16; // 0...3 + const int64_t ir = tid % 16; // 0...15 + const int64_t is = 2 * il; + + dst_t * y = yy + ib * QK_K + 64 * il + 2 * ir; + + const uint8_t * base = static_cast<const uint8_t *>(vx); + + // Reordered layout: [qs (QK_K/2 per block)] [qh (QK_K/8 per block)] [scales (K_SCALE_SIZE per block)] [dm (half2 per block)] + const size_t qs_offset = ib * (QK_K / 2); + const size_t qh_offset = n_blocks * (QK_K / 2) + ib * (QK_K / 8); + const size_t scales_offset = n_blocks * (QK_K / 2) + n_blocks * (QK_K / 8) + ib * K_SCALE_SIZE; + const size_t dm_offset = n_blocks * (QK_K / 2) + n_blocks * (QK_K / 8) + n_blocks * K_SCALE_SIZE + ib * sizeof(ggml_half2); + + const uint8_t * qs_ptr = base + qs_offset; + const uint8_t * qh_ptr = base + qh_offset; + const uint8_t * scales_ptr = base + scales_offset; + const ggml_half2 dm_values = *reinterpret_cast<const ggml_half2 *>(base + dm_offset); + + const float dall = dm_values.x(); + const float dmin = dm_values.y(); + + const uint8_t * ql = qs_ptr + 32 * il + 2 * ir; + const uint8_t * qh = qh_ptr + 2 * ir; + + if (tid < K_SCALE_SIZE) { + scales_local[tid] = scales_ptr[tid]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + uint8_t sc, m; + get_scale_min_k4(is + 0, scales_local, sc, m); + const float d1 = dall * sc; const float m1 = dmin * m; + get_scale_min_k4(is + 1, scales_local, sc, m); + const float d2 = dall * sc; const float m2 = dmin * m; + + uint8_t hm = 1 << (2 * il); + y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1; + y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1; + hm <<= 1; + y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2; + y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; +#else + GGML_UNUSED(ib); GGML_UNUSED(tid); GGML_UNUSED(yy); GGML_UNUSED(scales_local); GGML_UNUSED(n_blocks); + GGML_ABORT("Q5_K reorder dequantize not supported for QK_K != 256"); +#endif +} + template<typename dst_t> static void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy, const sycl::nd_item<3> &item_ct1) { @@ -838,4 +1469,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr } } + +template <typename dst_t> +static void dequantize_block_nvfp4( + const void * __restrict__ vx, + dst_t * __restrict__ yy, + const int64_t ne) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_group(2); + const int tid = item_ct1.get_local_id(2); + + const int64_t base = i * QK_NVFP4; + if (base >= ne) { + return; + } + + const block_nvfp4 * x = (const block_nvfp4 *) vx; + const block_nvfp4 & xb = x[i]; + + const int sub = tid / (QK_NVFP4_SUB / 2); + const int j = tid % (QK_NVFP4_SUB / 2); + + const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]); + const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j]; + + const int64_t y0 = base + sub * QK_NVFP4_SUB + j; + const int64_t y1 = y0 + QK_NVFP4_SUB / 2; + + yy[y0] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]); + yy[y1] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q >> 4]); +} + + #endif // GGML_SYCL_DEQUANTIZE_HPP diff --git a/ggml/src/ggml-sycl/diag.cpp b/ggml/src/ggml-sycl/diag.cpp new file mode 100644 index 00000000000..c4264fee342 --- /dev/null +++ b/ggml/src/ggml-sycl/diag.cpp @@ -0,0 +1,67 @@ +#include "diag.hpp" +#include "common.hpp" + +#define SYCL_DIAG_BLOCK_SIZE 256 + +template <typename T> +static void diag_kernel(T * __restrict__ dst, const T * __restrict__ src, + const int64_t ne0, const int64_t ne1, + const int64_t ne2, const int64_t ne3, + const int64_t total_elements, + const sycl::nd_item<1> & item) { + const int64_t i = item.get_global_id(0); + if (i >= total_elements) { + return; + } + + const int64_t i0 = i % ne0; + const int64_t i1 = (i / ne0) % ne1; + const int64_t i2 = (i / (ne0 * ne1)) % ne2; + const int64_t i3 = i / (ne0 * ne1 * ne2); + + const int64_t dst_idx = ((i3 * ne2 + i2) * ne1 + i1) * ne0 + i0; + + if (i0 == i1) { + const int64_t batch_idx = i3 * ne2 + i2; + dst[dst_idx] = src[batch_idx * ne0 + i0]; + } else { + dst[dst_idx] = T(0); + } + + (void)ne3; +} + +inline void ggml_sycl_op_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(src0->ne[1] == 1); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const void * src0_d = src0->data; + void * dst_d = dst->data; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + const int64_t ne3 = dst->ne[3]; + const int64_t n_elems = ggml_nelements(dst); + const int64_t num_blocks = (n_elems + SYCL_DIAG_BLOCK_SIZE - 1) / SYCL_DIAG_BLOCK_SIZE; + + GGML_ASSERT(dst->type == GGML_TYPE_F32); + stream->parallel_for( + sycl::nd_range<1>(num_blocks * SYCL_DIAG_BLOCK_SIZE, SYCL_DIAG_BLOCK_SIZE), + [=](sycl::nd_item<1> item) { + diag_kernel(static_cast<float *>(dst_d), + static_cast<const float *>(src0_d), + ne0, ne1, ne2, ne3, n_elems, item); + }); +} + +void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_diag(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/diag.hpp b/ggml/src/ggml-sycl/diag.hpp new file mode 100644 index 00000000000..20d7ce4895d --- /dev/null +++ b/ggml/src/ggml-sycl/diag.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_diag(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp index 4f2760110c2..d80b0a38219 100644 --- a/ggml/src/ggml-sycl/dmmv.cpp +++ b/ggml/src/ggml-sycl/dmmv.cpp @@ -3,6 +3,13 @@ #include "dequantize.hpp" #include "presets.hpp" +#if defined(__INTEL_LLVM_COMPILER) + #if __has_include(<sycl/ext/oneapi/bfloat16.hpp>) + #include <sycl/ext/oneapi/bfloat16.hpp> + #define GGML_SYCL_DMMV_HAS_BF16 + #endif +#endif + static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const sycl::half *x = (const sycl::half *)vx; @@ -11,6 +18,16 @@ static void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat v.y() = x[ib + iqs + 1]; } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ + const sycl::ext::oneapi::bfloat16 *x = (const sycl::ext::oneapi::bfloat16 *)vx; + + // automatic bfloat16 -> float type cast if dfloat == float + v.x() = x[ib + iqs + 0]; + v.y() = x[ib + iqs + 1]; +} +#endif + static void convert_f32(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ const float * x = (const float *) vx; @@ -217,6 +234,28 @@ static void convert_mul_mat_vec_f16_sycl(const void *vx, const dfloat *y, } } +#ifdef GGML_SYCL_DMMV_HAS_BF16 +static void convert_mul_mat_vec_bf16_sycl(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + // The qk=1 kernel iterates with stride 2*GGML_SYCL_DMMV_X, so ncols must be a + // multiple of that — not just GGML_SYCL_DMMV_X — to avoid out-of-bounds reads. + GGML_ASSERT(ncols % (2*GGML_SYCL_DMMV_X) == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + dequantize_mul_mat_vec<1, 1, convert_bf16>(vx, y, dst, ncols, + nrows, item_ct1); + }); + } +} +#endif + /* DPCT1110:4: The total declared local variable size in device function dequantize_mul_mat_vec_q2_k exceeds 128 bytes and may cause high register @@ -462,6 +501,103 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx, } } +static void dequantize_mul_mat_vec_q3_k_reorder(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [qs: nb * (QK_K/4)] [hmask: nb * (QK_K/8)] [scales: nb * 12] [d: nb * sizeof(half)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * qs_base = (const uint8_t *)vx; + const uint8_t * hmask_base = qs_base + (size_t)nb * (QK_K / 4); + const uint8_t * scales_base = hmask_base + (size_t)nb * (QK_K / 8); + const sycl::half * d_base = (const sycl::half *)(scales_base + (size_t)nb * 12); + + float tmp = 0; // partial sum for thread in warp + +#if QK_K == 256 + + const uint16_t kmask1 = 0x0303; + const uint16_t kmask2 = 0x0f0f; + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop + const int step = 16/K_QUANTS_PER_ITERATION; + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0....15 or 0...7 + + const uint8_t m = 1 << (4*im); + + const int l0 = n*in; // 0...15 or 0...14 in steps of 2 + const int q_offset = 32*im + l0; + const int y_offset = 128*im + l0; + + uint16_t utmp[4]; + const int8_t * s = (const int8_t *)utmp; + + const uint16_t s_shift = 4*im; + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * q = qs_base + bi * (QK_K / 4) + q_offset; + const uint8_t * h = hmask_base + bi * (QK_K / 8) + l0; + + const uint16_t * a = (const uint16_t *)(scales_base + bi * 12); + utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); + utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); + utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); + utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); + + const float d = d_base[bi]; + + float sum = 0; + for (int l = 0; l < n; ++l) { + sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) + + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) + + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) + + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); + sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) + + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) + + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) + + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); + } + tmp += d * sum; + } +#else + GGML_UNUSED(vx); + GGML_UNUSED(yy); + GGML_UNUSED(ncols); + GGML_UNUSED(item_ct1); + GGML_ABORT("Q3_K reorder DMMV not supported for QK_K != 256"); +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + /* DPCT1110:6: The total declared local variable size in device function dequantize_mul_mat_vec_q4_k exceeds 128 bytes and may cause high register @@ -615,6 +751,162 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx, } } +static void dequantize_mul_mat_vec_q4_k_reorder(const void *__restrict__ vx, + const float *__restrict__ yy, + float *__restrict__ dst, + const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [qs: nb * QK_K/2] [scales: nb * K_SCALE_SIZE] [dm: nb * sizeof(half2)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * qs_base = (const uint8_t *)vx; + const uint8_t * scales_base = qs_base + (size_t)nb * (QK_K / 2); + const sycl::half2 * dm_base = (const sycl::half2 *)(scales_base + (size_t)nb * K_SCALE_SIZE); + +#if QK_K == 256 + const uint16_t kmask1 = 0x3f3f; + const uint16_t kmask2 = 0x0f0f; + const uint16_t kmask3 = 0xc0c0; + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 + + const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 + + const int il = tid/step; // 0...3 + const int ir = tid - step*il; // 0...7 or 0...3 + const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 + + const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const int in = il%2; + + const int l0 = n*(2*ir + in); + const int q_offset = 32*im + l0; + const int y_offset = 64*im + l0; + + uint16_t aux[4]; + const uint8_t * sc = (const uint8_t *)aux; + +#if K_QUANTS_PER_ITERATION == 2 + uint32_t q32[4]; + const uint8_t * q4 = (const uint8_t *)q32; +#else + uint16_t q16[4]; + const uint8_t * q4 = (const uint8_t *)q16; +#endif + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y1 = yy + i*QK_K + y_offset; + const float * y2 = y1 + 128; + + const sycl::half2 dm_val = dm_base[bi]; + const float dall = dm_val[0]; + const float dmin = dm_val[1]; + + const uint16_t * a = (const uint16_t *)(scales_base + bi * K_SCALE_SIZE); + aux[0] = a[im+0] & kmask1; + aux[1] = a[im+2] & kmask1; + aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); + aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); + +#if K_QUANTS_PER_ITERATION == 2 + const uint32_t * q1 = (const uint32_t *)(qs_base + bi * (QK_K / 2) + q_offset); + const uint32_t * q2 = q1 + 16; + + q32[0] = q1[0] & 0x0f0f0f0f; + q32[1] = q1[0] & 0xf0f0f0f0; + q32[2] = q2[0] & 0x0f0f0f0f; + q32[3] = q2[0] & 0xf0f0f0f0; + + sycl::float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 4; ++l) { + s.x() += y1[l] * q4[l + 0]; s.y() += y1[l + 32] * q4[l + 4]; + s.z() += y2[l] * q4[l + 8]; s.w() += y2[l + 32] * q4[l + 12]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f + + s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) - + dmin * smin; +#else + const uint16_t * q1 = (const uint16_t *)(qs_base + bi * (QK_K / 2) + q_offset); + const uint16_t * q2 = q1 + 32; + + q16[0] = q1[0] & 0x0f0f; + q16[1] = q1[0] & 0xf0f0; + q16[2] = q2[0] & 0x0f0f; + q16[3] = q2[0] & 0xf0f0; + + float4 s = {0.f, 0.f, 0.f, 0.f}; + float smin = 0; + for (int l = 0; l < 2; ++l) { + s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; + s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; + smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; + } + tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; +#endif + + } +#else + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); + + const int step = tid * K_QUANTS_PER_ITERATION; + + uint16_t aux16[2]; + const uint8_t * s = (const uint8_t *)aux16; + + float tmp = 0; + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const uint8_t * q = qs_base + bi * (QK_K / 2) + step; + const float * y = yy + i*QK_K + step; + const uint16_t * a = (const uint16_t *)(scales_base + bi * K_SCALE_SIZE); + aux16[0] = a[0] & 0x0f0f; + aux16[1] = (a[0] >> 4) & 0x0f0f; + const sycl::half2 dm_val = dm_base[bi]; + const float d = (float)dm_val[0]; + const float m = (float)dm_val[1]; + float sum = 0.f; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) + + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + + y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]); + } + tmp += sum; + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + /* DPCT1110:7: The total declared local variable size in device function dequantize_mul_mat_vec_q5_k exceeds 128 bytes and may cause high register @@ -864,6 +1156,129 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa } } +static void dequantize_mul_mat_vec_q6_k_reorder(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows, + const sycl::nd_item<3> &item_ct1) { + + static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row > nrows) return; + + const int num_blocks_per_row = ncols / QK_K; + const int ib0 = row*num_blocks_per_row; + + // SOA base pointers for the reordered layout: + // [ql: nb * QK_K/2] [qh: nb * QK_K/4] [scales: nb * QK_K/16] [d: nb * sizeof(half)] + const int nb = nrows * num_blocks_per_row; + const uint8_t * ql_base = (const uint8_t *)vx; + const uint8_t * qh_base = ql_base + (size_t)nb * (QK_K / 2); + const int8_t * scales_base = (const int8_t *)(qh_base + (size_t)nb * (QK_K / 4)); + const sycl::half * d_base = (const sycl::half *)((const uint8_t *)scales_base + (size_t)nb * (QK_K / 16)); + +#if QK_K == 256 + + const int tid = + item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 + const int ix = + item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1 + + const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 + + const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... + const int in = tid - step*im; // 0...15 or 0...7 + +#if K_QUANTS_PER_ITERATION == 1 + const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 + const int is = 0; +#else + const int l0 = 4 * in; // 0, 4, 8, ..., 28 + const int is = in / 4; +#endif + const int ql_offset = 64*im + l0; + const int qh_offset = 32*im + l0; + const int s_offset = 8*im + is; + const int y_offset = 128*im + l0; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + y_offset; + const uint8_t * ql = ql_base + bi * (QK_K / 2) + ql_offset; + const uint8_t * qh = qh_base + bi * (QK_K / 4) + qh_offset; + const int8_t * s = scales_base + bi * (QK_K / 16) + s_offset; + + const float d = d_base[bi]; + +#if K_QUANTS_PER_ITERATION == 1 + float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) + + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) + + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) + + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) + + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) + + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) + + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) + +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); + tmp += sum; +#else + float sum = 0; + for (int l = 0; l < 4; ++l) { + sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) + + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) + + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) + + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + } + tmp += sum; +#endif + + } + +#else + + const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7 + const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3 + + const int step = tid * K_QUANTS_PER_ITERATION; + + float tmp = 0; // partial sum for thread in warp + + for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { + const int bi = ib0 + i; + + const float * y = yy + i * QK_K + step; + const uint8_t * ql = ql_base + bi * (QK_K / 2) + step; + const uint8_t * qh = qh_base + bi * (QK_K / 4) + step; + const int8_t * s = scales_base + bi * (QK_K / 16); + + const float d = d_base[bi]; + + float sum = 0; + for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { + sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) + + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + + y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32); + } + tmp += sum; + + } + +#endif + + // sum up partial sums and write back result +#pragma unroll + for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += + dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { + dst[row] = tmp; + } +} + static void dequantize_mul_mat_vec_q4_0_sycl_reorder(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -972,6 +1387,103 @@ static void dequantize_mul_mat_vec_q5_1_sycl(const void *vx, const dfloat *y, } } +static void dequantize_mul_mat_vec_q8_0_sycl_reorder(const void *vx, const dfloat *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % GGML_SYCL_DMMV_X == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + { + dpct::has_capability_or_fail(stream->get_device(), + {sycl::aspect::fp16}); + + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + // Q8_0 reorder layout: [all qs (ncols*nrows bytes)][all d values] + // Cannot reuse dequantize_mul_mat_vec_reorder template because it has + // Q4_0-specific constants hardcoded (d_ptr offset and qs stride). + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + if (row >= nrows) return; + + const int tid = item_ct1.get_local_id(2); + const int iter_stride = 8*2*GGML_SYCL_DMMV_X; + const int vals_per_iter = iter_stride / WARP_SIZE; + const int ncols_left = ncols % (QK8_0*WARP_SIZE); + const int ncols_align = ncols - ncols_left; + +#ifdef GGML_SYCL_F16 + sycl::half2 tmp = {0.0f, 0.0f}; +#else + float tmp = 0.0f; +#endif + const char *d_ptr = (const char*)vx + ncols*nrows; // d after all qs + + int i = 0; + for (i = 0; i < ncols_align; i += iter_stride) { + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/QK8_0; + const int iqs = col % QK8_0; + +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + dfloat2 v; + dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx, + ib * QK8_0 + iqs + j, v); + +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[col + j + 0], y[col + j + 1]}; + tmp += v * t1; +#else + tmp += v.x() * y[col + j + 0]; + tmp += v.y() * y[col + j + 1]; +#endif + } + } + + // handle remaining columns + for (; i < ncols; i += iter_stride) { + if (tid >= ncols_left/QK8_0) continue; + const int col = i + vals_per_iter*tid; + const int ib = (row*ncols + col)/QK8_0; + const int iqs = col % QK8_0; + +#pragma unroll + for (int j = 0; j < vals_per_iter; j += 2) { + dfloat2 v; + dequantize_q8_0_reorder((const void *)d_ptr, ib, (const void *)vx, + ib * QK8_0 + iqs + j, v); + +#ifdef GGML_SYCL_F16 + dfloat2 t1{y[col + j + 0], y[col + j + 1]}; + tmp += v * t1; +#else + tmp += v.x() * y[col + j + 0]; + tmp += v.y() * y[col + j + 1]; +#endif + } + } + + // reduce + const int mask_start = ncols > GGML_SYCL_DMMV_X ? WARP_SIZE >> 1 : WARP_SIZE >> 2; + for (int mask = mask_start; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (tid == 0) { +#ifdef GGML_SYCL_F16 + dst[row] = tmp.x() + tmp.y(); +#else + dst[row] = tmp; +#endif + } + }); + } +} + static void dequantize_mul_mat_vec_q8_0_sycl(const void *vx, const dfloat *y, float *dst, const int ncols, const int nrows, @@ -1025,6 +1537,22 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y, }); } +static void dequantize_mul_mat_vec_q3_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q3_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y, float *dst, const int ncols, const int nrows, @@ -1070,6 +1598,38 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y, }); } +static void dequantize_mul_mat_vec_q4_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q4_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + +static void dequantize_mul_mat_vec_q6_K_sycl_reorder(const void *vx, const float *y, + float *dst, const int ncols, + const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int ny = 2 / K_QUANTS_PER_ITERATION; + const int block_num_y = (nrows + ny - 1) / ny; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(QK_WARP_SIZE)]] { + dequantize_mul_mat_vec_q6_k_reorder(vx, y, dst, ncols, nrows, item_ct1); + }); +} + void ggml_sycl_op_dequantize_mul_mat_vec( ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, @@ -1089,7 +1649,8 @@ void ggml_sycl_op_dequantize_mul_mat_vec( bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16 || + src0->type == GGML_TYPE_BF16; if (src1_convert_f16) { scope_op_debug_print scope_dbg_print(__func__, "/to_fp16_sycl", dst, /*num_src=*/2, @@ -1122,19 +1683,28 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q5_1_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q8_0_sycl_reorder(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q8_0_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q2_K: dequantize_mul_mat_vec_q2_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q3_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q3_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - // reorder is currently not supported for dmmv - GGML_ABORT("Unimplemented dequantize case case for q4_k reorder"); + dequantize_mul_mat_vec_q4_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); } else { dequantize_mul_mat_vec_q4_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); } @@ -1143,11 +1713,21 @@ void ggml_sycl_op_dequantize_mul_mat_vec( dequantize_mul_mat_vec_q5_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); break; case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + dequantize_mul_mat_vec_q6_K_sycl_reorder(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } else { + dequantize_mul_mat_vec_q6_K_sycl(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); + } break; case GGML_TYPE_F16: convert_mul_mat_vec_f16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); break; +#ifdef GGML_SYCL_DMMV_HAS_BF16 + case GGML_TYPE_BF16: + convert_mul_mat_vec_bf16_sycl(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); + break; +#endif default: printf("ggml_sycl_op_dequantize_mul_mat_vec unsupported GGML_TYPE %d\n", src0->type); GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-sycl/dpct/helper.hpp b/ggml/src/ggml-sycl/dpct/helper.hpp index 30ec1e8dafc..791d3cac52e 100644 --- a/ggml/src/ggml-sycl/dpct/helper.hpp +++ b/ggml/src/ggml-sycl/dpct/helper.hpp @@ -15,18 +15,9 @@ #include <sycl/sycl.hpp> #include <sycl/half_type.hpp> -#include <syclcompat/math.hpp> -#include <map> - -#ifdef GGML_SYCL_USE_INTEL_ONEMKL #include <oneapi/mkl.hpp> -// Allow to use the same namespace for Intel oneMKL and oneMath -namespace oneapi { - namespace math = mkl; -} -#else -#include <oneapi/math.hpp> -#endif + +#include <map> #include "ggml.h" @@ -92,32 +83,13 @@ inline std::string get_device_backend_and_type(const sycl::device &device) { } template <typename Ts> struct matrix_info_t { - oneapi::math::transpose transpose_info[2]; + oneapi::mkl::transpose transpose_info[2]; Ts value_info[2]; std::int64_t size_info[3]; std::int64_t ld_info[3]; std::int64_t groupsize_info; }; -inline auto get_onemath_backend(sycl::queue& queue) -#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL) - -> sycl::queue& -#endif -{ -// If the backend is known at compile-time, use oneMath backend_selector to use -// compile-time dispatching and avoid the need to dlopen libraries. Otherwise -// fallback to runtime dispatching. -#if defined(GGML_SYCL_NVIDIA) - return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue }; -#elif defined(GGML_SYCL_AMD) - return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue }; -#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL) - return queue; -#else - static_assert(false, "Unsupported backend"); -#endif -} - namespace dpct { typedef sycl::queue *queue_ptr; @@ -1735,7 +1707,7 @@ namespace dpct namespace detail { template <class Ta, class Tb, class Tc, class Ts> - inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb, const void * beta, void * c, int ldc) { Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q); @@ -1743,7 +1715,7 @@ namespace dpct auto data_a = get_memory<const Ta>(a); auto data_b = get_memory<const Tb>(b); auto data_c = get_memory<Tc>(c); - oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a, + oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb, beta_value, data_c, ldc); } @@ -1775,7 +1747,7 @@ namespace dpct }; template <class Ta, class Tb, class Tc, class Ts> - inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b, int ldb, const void * beta, void ** c, int ldc, int batch_size, matrix_info_t<float> * matrix_info) { @@ -1794,8 +1766,8 @@ namespace dpct matrix_info->ld_info[2] = ldc; matrix_info->groupsize_info = batch_size; - sycl::event e = oneapi::math::blas::column_major::gemm_batch( - get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1, + sycl::event e = oneapi::mkl::blas::column_major::gemm_batch( + q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1, @@ -1804,7 +1776,7 @@ namespace dpct } template <class Ta, class Tb, class Tc, class Ts> - inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, + inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, int lda, long long int stride_a, const void * b, int ldb, long long int stride_b, const void * beta, void * c, int ldc, long long int stride_c, int batch_size) { @@ -1813,7 +1785,7 @@ namespace dpct auto data_a = get_memory<const Ta>(a); auto data_b = get_memory<const Tb>(b); auto data_c = get_memory<Tc>(c); - oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, + oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c, batch_size); } @@ -2300,7 +2272,7 @@ namespace dpct sycl::range<3>(x, y, 1), direction); } - inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n, + inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b, library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc, library_data_t scaling_type) { @@ -2367,7 +2339,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>( + detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); break; } @@ -2406,7 +2378,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>( + detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); break; } @@ -2448,7 +2420,7 @@ namespace dpct /// \param [in] ldc Leading dimension of C. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda, const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[], library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type, @@ -2486,7 +2458,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>( + detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } @@ -2494,7 +2466,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>( + detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>( q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info); break; } @@ -2570,7 +2542,7 @@ namespace dpct /// \param [in] stride_c Stride between the different C matrices. /// \param [in] batch_size Specifies the number of matrix multiply operations to perform. /// \param [in] scaling_type Data type of the scaling factors. - inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, + inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda, long long int stride_a, const void * b, library_data_t b_type, int ldb, long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc, @@ -2643,7 +2615,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float): { - detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>( + detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); break; @@ -2652,7 +2624,7 @@ namespace dpct library_data_t::real_bfloat16, library_data_t::real_bfloat16, library_data_t::real_float, library_data_t::real_float): { - detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>( + detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>( q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); break; @@ -3025,6 +2997,778 @@ namespace dpct return 0; } + template <int n_nondefault_params, int n_default_params, typename T> + class args_selector; + + /// args_selector is a helper class for extracting arguments from an + /// array of pointers to arguments or buffer of arguments to pass to a + /// kernel function. + /// + /// \param R(Ts...) The type of the kernel + /// \param n_nondefault_params The number of nondefault parameters of the + /// kernel (excluding parameters that like sycl::nd_item, etc.) \param + /// n_default_params The number of default parameters of the kernel + /// + /// Example usage: + /// With the following kernel: + /// void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float + /// f=.1) {} + /// and with the declaration: + /// args_selector<2, 1, decltype(foo)> selector(kernelParams, extra); + /// we have: + /// selector.get<0>() returns a reference to sycl::float*, + /// selector.get<1>() returns a reference to int, + /// selector.get<2>() returns a reference to float + template <int n_nondefault_params, int n_default_params, typename R, + typename... Ts> + class args_selector<n_nondefault_params, n_default_params, R(Ts...)> { + private: + void **kernel_params; + char *args_buffer; + + template <int i> static constexpr int account_for_default_params() { + constexpr int n_total_params = sizeof...(Ts); + if constexpr (i >= n_nondefault_params) { + return n_total_params - n_default_params + + (i - n_nondefault_params); + } else { + return i; + } + } + + public: + /// Get the type of the ith argument of R(Ts...) + /// \param [in] i Index of parameter to get + /// \returns Type of ith parameter + template <int i> + using arg_type = std::tuple_element_t<account_for_default_params<i>(), + std::tuple<Ts...>>; + static constexpr int params_num = sizeof...(Ts); + + private: + template <int i> static constexpr int get_offset() { + if constexpr (i == 0) { + // we can assume args_buffer is properly aligned to the + // first argument + return 0; + } else { + constexpr int prev_off = get_offset<i - 1>(); + constexpr int prev_past_end = + prev_off + sizeof(arg_type<i - 1>); + using T = arg_type<i>; + // is the past-the-end of the i-1st element properly aligned + // with the ith element's alignment? + if constexpr (prev_past_end % alignof(T) == 0) { + return prev_past_end; + } + // otherwise bump prev_past_end to match alignment + else { + return prev_past_end + + (alignof(T) - (prev_past_end % alignof(T))); + } + } + } + + static char *get_args_buffer(void **extra) { + if (!extra) + return nullptr; + for (; (std::size_t)*extra != 0; ++extra) { + if ((std::size_t)*extra == 1) { + return static_cast<char *>(*(extra + 1)); + } + } + return nullptr; + } + + public: + /// If kernel_params is nonnull, then args_selector will + /// extract arguments from kernel_params. Otherwise, it + /// will extract them from extra. + /// \param [in] kernel_params Array of pointers to arguments + /// a or null pointer. + /// \param [in] extra Array containing pointer to argument buffer. + args_selector(void **kernel_params, void **extra) + : kernel_params(kernel_params), + args_buffer(get_args_buffer(extra)) {} + + /// Get a reference to the ith argument extracted from kernel_params + /// or extra. + /// \param [in] i Index of argument to get + /// \returns Reference to the ith argument + template <int i> arg_type<i> &get() { + if (kernel_params) { + return *static_cast<arg_type<i> *>(kernel_params[i]); + } else { + return *reinterpret_cast<arg_type<i> *>(args_buffer + + get_offset<i>()); + } + } + }; // COPY from DPCT head file + // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp + + /// Utility class for launching SYCL kernels through kernel + /// function wrapper. + /// For example: + /// A SYCL kernel function: + /// void kernel_func(int *ptr, sycl::nd_item<3> item); + /// Kernel function wrapper: + /// void kernel_func_wrapper(int *ptr) { + /// sycl::queue queue = *dpct::kernel_launcher::_que; + /// unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size; + /// sycl::nd_range<3> nr = dpct::kernel_launcher::_nr; + /// queue.parallel_for( + /// nr, + /// [=](sycl::nd_item<3> item_ct1) { + /// kernel_func(ptr, item_ct1); + /// }); + /// } + /// Then launch the kernel through wrapper like: + /// typedef void(*fpt)(int *); + /// fpt fp = kernel_func_wrapper; + /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0, + /// device_ptr); + /// If the origin function type is erased, then need to register it first: + /// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get(); + /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args, + /// 0, 0); + class kernel_launcher { + template <typename FuncT, typename ArgSelector, std::size_t... Index> + static void launch_helper(FuncT &&func, ArgSelector &selector, + std::index_sequence<Index...>) { + func(selector.template get<Index>()...); + } + static void set_execution_config(dim3 group_range, dim3 local_range, + unsigned int local_mem_size, + queue_ptr que) { + if (que) { + _que = que; + } else { + _que = &get_default_queue(); + } + _nr = sycl::nd_range<3>( + static_cast<sycl::range<3>>(group_range * local_range), + static_cast<sycl::range<3>>(local_range)); + _local_mem_size = local_mem_size; + + + }; + static inline std::mutex kernel_function_ptr_map_mutex; + + public: + /// Variables for storing execution configuration. + static inline thread_local sycl::queue *_que = nullptr; + static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>(); + static inline thread_local unsigned int _local_mem_size = 0; + /// Map for retrieving launchable functor from a raw pointer. + static inline std::map< + const void *, + std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>> + kernel_function_ptr_map = {}; + + /// Registers a kernel function pointer with a corresponding launchable + /// functor. + /// \param [in] func Pointer to the kernel function. + /// \param [in] launcher Functor to handle kernel invocation. + static void register_kernel_ptr( + const void *func, + std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)> + launcher) { + std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex); + kernel_function_ptr_map[func] = std::move(launcher); + } + /// Launches a kernel function with arguments provided directly through + /// kernel function wrapper. + /// \tparam FuncT Type of the kernel function wrapper. + /// \tparam ArgsT Types of kernel arguments. + /// \param [in] func Pointer to the kernel function wrapper. + /// \param [in] group_range SYCL group range. + /// \param [in] local_range SYCL local range. + /// \param [in] local_mem_size The size of local memory required by the + /// kernel function. \param [in] que SYCL queue used to execute kernel. + /// \param [in] args Kernel arguments. + template <typename FuncT, typename... ArgsT> + static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void> + launch(FuncT *func, dim3 group_range, dim3 local_range, + unsigned int local_mem_size, queue_ptr que, ArgsT... args) { + set_execution_config(group_range, local_range, local_mem_size, que); + func(args...); + } + /// Launches a kernel function through registered kernel function + /// wrapper. \param [in] func Pointer to the registered kernel function + /// wrapper. \param [in] group_range SYCL group range. \param [in] + /// local_range SYCL local range. \param [in] args Array of pointers to + /// kernel arguments. \param [in] local_mem_size The size of local + /// memory required by the kernel function. \param [in] que SYCL queue + /// used to execute kernel. + static void launch(const void *func, dim3 group_range, dim3 local_range, + void **args, unsigned int local_mem_size, + queue_ptr que) { + std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex); + auto Iter = kernel_function_ptr_map.find(func); + if (Iter == kernel_function_ptr_map.end()) { + throw std::runtime_error("dpct::launch() : no registered " + "kernel function wrapper found."); + } + (Iter->second)(group_range, local_range, args, local_mem_size, que); + } + /// Launches a kernel function with packed arguments through kernel + /// function wrapper. + /// \tparam FuncT Type of the kernel function wrapper. + /// \param [in] func Pointer to the kernel function wrapper. + /// \param [in] group_range SYCL group range. + /// \param [in] local_range SYCL local range. + /// \param [in] args Array of pointers to kernel arguments. + /// \param [in] local_mem_size The size of local memory required by the + /// kernel function. \param [in] que SYCL queue used to execute kernel. + template <typename FuncT> + static std::enable_if_t<std::is_function_v<FuncT>, void> + launch(FuncT *func, dim3 group_range, dim3 local_range, void **args, + unsigned int local_mem_size, queue_ptr que) { + constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num; + set_execution_config(group_range, local_range, local_mem_size, que); + args_selector<p_num, p_num, FuncT> selector(args, nullptr); + launch_helper(func, selector, std::make_index_sequence<p_num>{}); + } + }; // COPY from DPCT head file + // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp + + // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp + template <typename T> + T select_from_sub_group( + sycl::sub_group g, + T x, + int remote_local_id, + int logical_sub_group_size = 32) { + unsigned int start_index = g.get_local_linear_id() / + logical_sub_group_size * + logical_sub_group_size; + return sycl::select_from_group( + g, x, start_index + remote_local_id % logical_sub_group_size); + } + + // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp + template <typename T> + void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + int lane_group8_row = lane / 8; + int lane_group8_col = lane % 8; + + if (!trans) { + // calculate the source lane + int src_lane = 2 * lane_group8_row; + if (lane_group8_col >= 4) + src_lane += 1; + + // Broadcast the address from the source lane + auto recv_addr_uintp = + dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane); + + // Cast the received address from uintptr_t to the type of 'm' + auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp); + + // Non-transposed load + *m = recv_addr[lane_group8_col % 4]; + } else { + // calculate the source lane + int src_lane = (lane % 4) * 2; + + // Broadcast the address from the source lane + auto recv_addr_uintp_1 = + dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane); + auto recv_addr_uintp_2 = + dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1); + + // Cast the received address from uintptr_t to 'half *' + auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1); + auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2); + + // Transposed load + int index = lane / 4; + sycl::half val0 = recv_addr_1[index]; + sycl::half val1 = recv_addr_2[index]; + + // Combine the two 16-bits into one 32-bit value + sycl::half2 val = sycl::half2(val0, val1); + *m = *reinterpret_cast<T*>(&val); + } + } + + template <typename T> + void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) { + // Load 1st matrix + ldmatrix(addr, m1, trans, 0); + // Load 2nd matrix + ldmatrix(addr, m2, trans, 1); + } + + template <typename T> + void ldmatrix( + uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) { + // Load 1st matrix + ldmatrix(addr, m1, trans, 0); + // Load 2nd matrix + ldmatrix(addr, m2, trans, 1); + // Load 3rd matrix + ldmatrix(addr, m3, trans, 2); + // Load 4th matrix + ldmatrix(addr, m4, trans, 3); + } + + // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp + + /// A helper struct that defines the pack type for the input matrix + /// fragments + /// of mma() function based on the type of input matrix fragments. + /// The MMAType struct is specialized for different types of input matrices. + /// Currently, the specialization for f16, bf16 and s8 types is defined + /// below. \tparam [in] T The type of the input matrix fragments + template <typename T> + struct MMAType { + using PackType = uint32_t; + }; + + /// Each work item of a sub-group (limited to size 32) calling this function + /// calculates a subset fragment for the output matrix D using MAD operation + /// on A, B & C matrix fragments (D = A * B + C). Current supported shapes & + /// types: + /// - m8n8k4 (f32.f16.f16.f32) + /// - m8n8k16 (s32.s8.s8.s32) + /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32) + /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32) + /// - m16n8k32 (s32.s8.s8.s32) + /// Here, m, n & k define the shapes of A, B & C matrices respectively + /// (A = [m x k], B = [k x n], C = [m x n]). + /// \tparam [in] M The rows of A, C & D matrices + /// \tparam [in] N The columns of B, C, D matrices + /// \tparam [in] K The columns & rows of A & B matrices respectively + /// \tparam [in] ABType The type of the input matrix (A & B) fragment + /// \tparam [in] CDType The type of the output matrix (C & D) fragment + /// \param [out] d_mat_frag The fragment of the output matrix D to store the + /// result of A * B + C + /// \param [in] a_mat_frag The fragment of the input matrix A to be + /// multiplied with B matrix fragment \param [in] b_mat_frag The fragment of + /// the input matrix B to be multiplied with A matrix fragment \param [in] + /// c_mat_frag The fragment of the input matrix C to be added with the + /// result of A * B fragments + template <int M, int N, int K, typename ABType, typename CDType> + void mma( + volatile void** d_mat_frag, + void* a_mat_frag, + void* b_mat_frag, + void* c_mat_frag) { + auto d = reinterpret_cast<volatile CDType**>(d_mat_frag); + auto a = + reinterpret_cast<typename MMAType<ABType>::PackType*>(a_mat_frag); + auto b = + reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag); + auto c = reinterpret_cast<CDType*>(c_mat_frag); + + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + static_assert( + (M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) || + (M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) || + (M == 16 && N == 8 && K == 32), + "Unsupported MMA shape!"); + + short row_load_offset = 4 * (lane >> 2); + short col_load_offset = 8 * (lane % 4); + + if constexpr (M == 8 && N == 8 && K == 4) { + if constexpr (std::is_floating_point_v<CDType>) { + col_load_offset = row_load_offset % 16; + + // Init D matrix with fragments of C matrix + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + *d[4] = c[4]; + *d[5] = c[5]; + *d[6] = c[6]; + *d[7] = c[7]; + + // Calculate the row and col offset indices to iterate through the row + // & col fragments of A & B matrices + int r_ind = (lane % 2) ? 1 : 0; + int c_ind = ((lane % 4) / 2) ? 2 : 0; + + // Each sub-group is responsible for computing a fragment size of 8*8 + // elements of matrix D for each of 4 MMA computations. + // Each work item computes 8 elements of matrix D by gathering + // their corresponding col & row matrix fragments of length k (4) + // from A & B matrices respectively using below mapping logic: + // row0 = (i % 4) if (lane < 16) else (i % 4) + 4 + // col0 = (lane % 4) + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + typename MMAType<ABType>::PackType recv_a[2], recv_b[2]; + + for (int i = 0; i < 4; i++) { + // Load partial fragment from col0 of matrix A ({a0, a1}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from col0 of matrix A ({a2, a3}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + + // Load partial fragment from row0 of matrix B ({b0, b1}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from row0 of matrix B ({b2, b3}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + i); + + auto ra = reinterpret_cast<ABType*>(recv_a); + auto rb = reinterpret_cast<ABType*>(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment (for + // even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{ + // a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 } + // * row0{ b1 } (for odd work item indices) d0 += col0{ a1 } * row0{ + // b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 } + // d3 += col1{ a3 } * row0{ b3 } + *d[0] += + static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]); + *d[1] += static_cast<float>(ra[r_ind]) * + static_cast<float>(rb[c_ind + 1]); + *d[2] += static_cast<float>(ra[r_ind + 2]) * + static_cast<float>(rb[c_ind]); + *d[3] += static_cast<float>(ra[r_ind + 2]) * + static_cast<float>(rb[c_ind + 1]); + + // Load partial fragment from row1 of matrix B ({b0, b1}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16); + // Load partial fragment from row1 of matrix B ({b2, b3}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16); + + // (for even work item indices) + // d0 += col0{ a0 } * row1{ b0 } + // d1 += col0{ a0 } * row1{ b1 } + // d2 += col1{ a2 } * row1{ b0 } + // d3 += col1{ a2 } * row1{ b1 } + // (for odd work item indices) + // d0 += col0{ a1 } * row1{ b2 } + // d1 += col0{ a1 } * row1{ b3 } + // d2 += col1{ a3 } * row1{ b2 } + // d3 += col1{ a3 } * row1{ b3 } + *d[4] += + static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]); + *d[5] += static_cast<float>(ra[r_ind]) * + static_cast<float>(rb[c_ind + 1]); + *d[6] += static_cast<float>(ra[r_ind + 2]) * + static_cast<float>(rb[c_ind]); + *d[7] += static_cast<float>(ra[r_ind + 2]) * + static_cast<float>(rb[c_ind + 1]); + } + } + } else if constexpr (M == 8 && N == 8 && K == 16) { + if constexpr (std::is_integral_v<ABType>) { + // Init D matrix with fragments of C matrix + *d[0] = c[0]; + *d[1] = c[1]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 2 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (16) + // from A & B matrices respectively using below mapping logic: + // row0 = ((lane % 4) * 4) + i + // col0 = (lane >> 2) + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType<ABType>::PackType recv_a, recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3}) + recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4); + + auto a = reinterpret_cast<ABType*>(&recv_a); + auto b = reinterpret_cast<ABType*>(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{ + // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row0{ a0, a1, a2, + // a3 } * col0{ b0, b1, b2, b3 } d3 += row0{ a0, a1, a2, a3 } * + // col1{ b0, b1, b2, b3 } + for (int j = 0; j < 4; j++) { + *d[0] += a[j] * b[j]; + *d[1] += a[j] * b[j + 4]; + } + } + } + } else if constexpr (M == 16 && N == 8 && K == 8) { + if constexpr (std::is_floating_point_v<CDType>) { + // Init D matrix fragment with C matrix fragment + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 + (i & 0x1) + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType<ABType>::PackType recv_a[2], recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a0, a1}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a2, a3}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b0, b1}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b0, b1}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4); + + auto ra = reinterpret_cast<ABType*>(recv_a); + auto rb = reinterpret_cast<ABType*>(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1 } * col0{ b0, b1 } d1 += row0{ a0, a1 } * col1{ + // b0, b1 } d2 += row1{ a2, a3 } * col0{ b0, b1 } d3 += row1{ a2, a3 + // } * col1{ b0, b1 } + for (int j = 0; j < 2; j++) { + *d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]); + *d[1] += + static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]); + *d[2] += + static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]); + *d[3] += + static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]); + } + } + } + } else if constexpr (M == 16 && N == 8 && K == 16) { + if constexpr (std::is_floating_point_v<CDType>) { + // Init D matrix fragment with C matrix fragment + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType<ABType>::PackType recv_a[4], recv_b[4]; + + // Load partial fragment from row0 of matrix A ({a0, a1}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row0 of matrix A ({a2, a3}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[2], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a0, a1}) + recv_a[2] = + dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a2, a3}) + recv_a[3] = + dpct::select_from_sub_group(sg, a[3], row_load_offset + i); + + // Load partial fragment from col0 of matrix B ({b0, b1}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col0 of matrix B ({b2, b3}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b0, b1}) + recv_b[2] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i); + // Load partial fragment from col1 of matrix B ({b2, b3}) + recv_b[3] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i); + + auto ra = reinterpret_cast<ABType*>(recv_a); + auto rb = reinterpret_cast<ABType*>(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{ + // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, a2, + // a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } * + // col1{ b0, b1, b2, b3 } + for (int j = 0; j < 4; j++) { + *d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]); + *d[1] += + static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]); + *d[2] += + static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]); + *d[3] += static_cast<CDType>(ra[j + 4]) * + static_cast<CDType>(rb[j + 4]); + } + } + } else if constexpr (std::is_integral_v<ABType>) { + // Init D matrix with fragments of C matrix + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType<ABType>::PackType recv_a[2], recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4); + + auto ra = reinterpret_cast<ABType*>(recv_a); + auto rb = reinterpret_cast<ABType*>(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{ + // a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } d2 += row1{ a4, a5, a6, + // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } * + // col1{ b4, b5, b6, b7 } + for (int i = 0; i < 4; i++) { + *d[0] += ra[i] * rb[i]; + *d[1] += ra[i] * rb[i + 4]; + *d[2] += ra[i + 4] * rb[i]; + *d[3] += ra[i + 4] * rb[i + 4]; + } + } + } + } else if constexpr (M == 16 && N == 8 && K == 32) { + if constexpr (std::is_integral_v<ABType>) { + // Init D matrix with fragments of C matrix + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (32) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i + // & 0x3) As each row & col fragment of A & B matrices is distributed + // across 4 work items, each iteration of below loop loads a partial + // fragment of matrix A (row) and matrix B (col) using the row & col + // offsets. + for (int i = 0; i < 4; i++) { + typename MMAType<ABType>::PackType recv_a[2], recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4); + + auto a = reinterpret_cast<ABType*>(recv_a); + auto b = reinterpret_cast<ABType*>(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{ + // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a4, a5, a6, + // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } * + // col1{ b0, b1, b2, b3 } + for (int j = 0; j < 4; j++) { + *d[0] += a[j] * b[j]; + *d[1] += a[j] * b[j + 4]; + *d[2] += a[j + 4] * b[j]; + *d[3] += a[j + 4] * b[j + 4]; + } + } + + for (int i = 0; i < 4; i++) { + typename MMAType<ABType>::PackType recv_a[2], recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a8, a9, a10, a11}) + recv_a[0] = + dpct::select_from_sub_group(sg, a[2], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a12, a13, a14, + // a15}) + recv_a[1] = + dpct::select_from_sub_group(sg, a[3], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b4, b5, b6, b7}) + recv_b[0] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4); + + auto a = reinterpret_cast<ABType*>(recv_a); + auto b = reinterpret_cast<ABType*>(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 } d1 += row0{ + // a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 } d2 += row1{ a12, a13, + // a14, a15 } * col0{ b4, b5, b6, b7 } d3 += row1{ a12, a13, a14, + // a15 } * col1{ b4, b5, b6, b7 } + for (int j = 0; j < 4; j++) { + *d[0] += a[j] * b[j]; + *d[1] += a[j] * b[j + 4]; + *d[2] += a[j + 4] * b[j]; + *d[3] += a[j + 4] * b[j + 4]; + } + } + } + } + } } // COPY from DPCT head files #endif // GGML_SYCL_DPCT_HELPER_HPP diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index 8d83b2446bd..249e80c826e 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -9,23 +9,32 @@ #define SYCL_LOCAL_ID_CALC(ITEM, IDX) \ (ITEM.get_local_range(IDX) * ITEM.get_group(IDX) + ITEM.get_local_id(IDX)) +static void acc_f32(const float * x, const float * y, float * dst, const int64_t ne, + const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13, + const int64_t s11, const int64_t s12, const int64_t s13, const int64_t offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = SYCL_LOCAL_ID_CALC(item_ct1, 2); -static void acc_f32(const float * x, const float * y, float * dst, const int ne, - const int ne10, const int ne11, const int ne12, - const int nb1, const int nb2, int offset, const sycl::nd_item<1> &item_ct1) { - const int i = SYCL_LOCAL_ID_CALC(item_ct1, 0); if (i >= ne) { return; } - int src1_idx = i - offset; - int oz = src1_idx / nb2; - int oy = (src1_idx - (oz * nb2)) / nb1; - int ox = src1_idx % nb1; - if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) { - dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11]; - } else { - dst[i] = x[i]; + + int64_t src1_idx = i - offset; + + int64_t tmp = src1_idx; + const int64_t i13 = tmp / s13; + tmp -= i13 * s13; + const int64_t i12 = tmp / s12; + tmp -= i12 * s12; + const int64_t i11 = tmp / s11; + tmp -= i11 * s11; + const int64_t i10 = tmp; + + float val = x[i]; + if (src1_idx >= 0 && i10 < ne10 && i11 < ne11 && i12 < ne12 && i13 < ne13) { + val += y[((i13*ne12 + i12) * ne11 + i11) * ne10 + i10]; } + dst[i] = val; } /* Unary OP funcs */ @@ -123,6 +132,15 @@ static __dpct_inline__ T op_log(T x) { return sycl::log(x); } +template<typename T> +static __dpct_inline__ T op_softplus(T x) { + const float xf = (float) x; + const float ax = sycl::fabs(xf); + const float m = sycl::fmax(xf, 0.0f); + const float y = m + sycl::log1p(sycl::exp(-ax)); + return (T) y; +} + template<typename T> static __dpct_inline__ T op_neg(T x) { return -x; @@ -276,30 +294,6 @@ static void unary_op_trunc_kernel(const T * x, T * dst, const int k, const sycl: } } -template<typename T> -static void upscale(const T *x, T *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) { - int index = item_ct1.get_local_id(0) + - item_ct1.get_group(0) * item_ct1.get_local_range(0); - if (index >= ne10 * ne11 * ne12 * ne13) { - return; - } - // operation - int i10 = index % ne10; - int i11 = (index / ne10) % ne11; - int i12 = (index / (ne10 * ne11)) % ne12; - int i13 = (index / (ne10 * ne11 * ne12)) % ne13; - - int i00 = static_cast<int>(i10 / sf0); - int i01 = static_cast<int>(i11 / sf1); - int i02 = static_cast<int>(i12 / sf2); - int i03 = static_cast<int>(i13 / sf3); - - dst[index] = *(const T *)((const char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00); -} - template<typename T> static void clamp(const T * x, T * dst, const float min, const float max, const int k, const sycl::nd_item<1> &item_ct1) { @@ -355,18 +349,15 @@ static void gated_op_fused_geglu_quick(const T * x, const T * g, T * dst, const namespace ggml_sycl_detail { static void acc_f32_sycl(const float *x, const float *y, float *dst, - const int n_elements, const int ne10, const int ne11, - const int ne12, const int nb1, const int nb2, - const int offset, queue_ptr stream) { - int num_blocks = ceil_div(n_elements, SYCL_ACC_BLOCK_SIZE); - stream->parallel_for( - sycl::nd_range<1>(sycl::range<1>(num_blocks) * - sycl::range<1>(SYCL_ACC_BLOCK_SIZE), - sycl::range<1>(SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { - acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset, - item_ct1); - }); + const int64_t n_elements, const int64_t ne10, const int64_t ne11, + const int64_t ne12, const int64_t ne13, const int64_t s1, const int64_t s2, const int64_t s3, + const int64_t offset, queue_ptr stream) { + const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; + stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), + [=](sycl::nd_item<3> /*item_ct1*/) { + acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); + }); } template<typename T> @@ -377,41 +368,21 @@ static void arange_kernel(T * dst, const int k, T start, T step, } } -template<typename T> -static void upscale_sycl(const T *x, T *dst, const int nb00, const int nb01, - const int nb02, const int nb03, const int ne10, const int ne11, - const int ne12, const int ne13, const float sf0, const float sf1, - const float sf2, const float sf3, queue_ptr stream) { - int dst_size = ne10 * ne11 * ne12 * ne13; - int num_blocks = ceil_div(dst_size, SYCL_UPSCALE_BLOCK_SIZE); - sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE); - stream->parallel_for( - sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { - upscale(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1); - }); -} - template<typename KernelInvoker, typename... Args> static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { -#if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); switch (dst->type) { -#if defined (GGML_SYCL_F16) case GGML_TYPE_F16: { auto data_pts = cast_data<sycl::half>(dst); kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward<Args>(args)...); break; } -#endif case GGML_TYPE_F32: { auto data_pts = cast_data<float>(dst); @@ -425,14 +396,10 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, template<typename KernelInvoker, typename... Args> static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { -#if defined (GGML_SYCL_F16) GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif GGML_ASSERT(dst->src[0]->type == dst->type); + dpct::queue_ptr main_stream = ctx.stream(); SYCL_CHECK(ggml_sycl_set_device(ctx.device)); const ggml_tensor * src0 = dst->src[0]; @@ -454,7 +421,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c GGML_ASSERT(src0->type == src1->type); } switch (dst->type) { -#if defined (GGML_SYCL_F16) case GGML_TYPE_F16: { sycl::half * src0_p = (sycl::half *) src0_d; @@ -475,7 +441,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c std::forward<Args>(args)...); break; } -#endif case GGML_TYPE_F32: { float * src0_p = (float *) src0_d; @@ -502,48 +467,6 @@ static inline void dispatch_ggml_sycl_op_fused_glu(ggml_backend_sycl_context & c } } -template<typename KernelInvoker, typename... Args> -static inline void dispatch_ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { -#if defined (GGML_SYCL_F16) - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); -#else - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); -#endif - GGML_ASSERT(dst->src[0]->type == dst->type); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const float sf0 = (float) dst->ne[0] / dst->src[0]->ne[0]; - const float sf1 = (float) dst->ne[1] / dst->src[0]->ne[1]; - const float sf2 = (float) dst->ne[2] / dst->src[0]->ne[2]; - const float sf3 = (float) dst->ne[3] / dst->src[0]->ne[3]; - switch (dst->type) { -#if defined (GGML_SYCL_F16) - case GGML_TYPE_F16: - { - auto data_pts = cast_data<sycl::half>(dst); - kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], - (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, - main_stream, std::forward<Args>(args)...); - break; - } -#endif - case GGML_TYPE_F32: - { - auto data_pts = cast_data<float>(dst); - kernel_invoker(data_pts.src, data_pts.dst, (int)dst->src[0]->nb[0], (int)dst->src[0]->nb[1], (int)dst->src[0]->nb[2], - (int)dst->src[0]->nb[3], (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], sf0, sf1, sf2, sf3, - main_stream, std::forward<Args>(args)...); - break; - } - default: - GGML_ABORT("GGML tensor type not supported!\n"); - } -} - template<typename F> static inline void ggml_sycl_op_unary( ggml_backend_sycl_context & ctx, ggml_tensor * dst, F func) { @@ -695,6 +618,12 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor }); } +static inline void ggml_sycl_op_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_softplus(x); + }); +} + static inline void ggml_sycl_op_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { return op_neg(x); @@ -781,15 +710,6 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor }); } -static inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_upscale(ctx, dst, - [](const auto* src, auto* dst_ptr, int nb00, int nb01, int nb02, int nb03, - int ne10, int ne11, int ne12, int ne13, float sf0, float sf1, float sf2, float sf3, - queue_ptr stream) { - ggml_sycl_detail::upscale_sycl(src, dst_ptr, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, stream); - }); -} - static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { float min_val; float max_val; @@ -821,16 +741,9 @@ static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tens } static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst, - [](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) { - const int num_blocks = ceil_div(k_elements, 256); - stream->parallel_for( - sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), - sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { - unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1); - }); - }); + ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) { + return op_ceil(x); + }); } static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { @@ -860,22 +773,31 @@ static inline void ggml_sycl_op_trunc(ggml_backend_sycl_context & ctx, ggml_tens } static inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->src[1]->type == GGML_TYPE_F32); + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); - GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - const float * src0_dd = static_cast<const float *>(dst->src[0]->data); - const float * src1_dd = static_cast<const float*>(dst->src[1]->data); - float * dst_dd = static_cast<float *>(dst->data); - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(dst->nb[0] == ggml_element_size(dst)); + GGML_ASSERT(ggml_is_contiguously_allocated(dst)); + + const int64_t s1 = dst->op_params[0] / sizeof(float); + const int64_t s2 = dst->op_params[1] / sizeof(float); + const int64_t s3 = dst->op_params[2] / sizeof(float); + const int64_t offset = dst->op_params[3] / sizeof(float); - ggml_sycl_detail::acc_f32_sycl(src0_dd, src1_dd, dst_dd, (int)ggml_nelements(dst), (int)dst->src[1]->ne[0], (int)dst->src[1]->ne[1], (int)dst->src[1]->ne[2], nb1, nb2, offset, main_stream); + ggml_sycl_detail::acc_f32_sycl(src0_d, src1_d, dst_d, ggml_nelements(dst), + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + s1, s2, s3, offset, stream); } static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { @@ -1101,6 +1023,11 @@ void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_log(ctx, dst); } +void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_softplus(ctx, dst); +} + void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_neg(ctx, dst); @@ -1121,12 +1048,6 @@ void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_sycl_op_sqr(ctx, dst); } -void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { - scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); - ggml_sycl_op_upscale(ctx, dst); -} - - void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); ggml_sycl_op_clamp(ctx, dst); diff --git a/ggml/src/ggml-sycl/element_wise.hpp b/ggml/src/ggml-sycl/element_wise.hpp index 0913a2e529b..997132166ab 100644 --- a/ggml/src/ggml-sycl/element_wise.hpp +++ b/ggml/src/ggml-sycl/element_wise.hpp @@ -61,6 +61,8 @@ void ggml_sycl_exp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_log(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_softplus(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + void ggml_sycl_neg(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_step(ggml_backend_sycl_context & ctx, ggml_tensor * dst); @@ -69,8 +71,6 @@ void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, ggml_tensor * dst); -void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); - void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, ggml_tensor * dst); void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/fattn-buffers.cpp b/ggml/src/ggml-sycl/fattn-buffers.cpp new file mode 100644 index 00000000000..46cf6d551f1 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-buffers.cpp @@ -0,0 +1,56 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#include "common.hpp" + +sycl::half * ggml_sycl_fattn_kv_buffers::kv_buffer::ensure_half(size_t n_elems) { + const size_t need_bytes = n_elems * sizeof(sycl::half); + + if (capacity >= need_bytes) { + return ptr; + } + + if (ptr) { + SYCL_CHECK(CHECK_TRY_ERROR(qptr->wait())); + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + ptr = nullptr; + capacity = 0; + } + + size_t cap = 0; + while (cap < need_bytes) { + cap += CHUNK_SIZE; + } + + void * dev_ptr; + SYCL_CHECK( + CHECK_TRY_ERROR(dev_ptr = sycl::malloc_device( + cap, *qptr))); + + if (!dev_ptr) { + GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, cap); + GGML_ABORT("fattn buffer alloc failed"); + } + + ptr = static_cast<sycl::half *>(dev_ptr); + capacity = cap; + return ptr; +} + +ggml_sycl_fattn_kv_buffers::kv_buffer::~kv_buffer() { +#ifdef DEBUG_SYCL_POOL + GGML_LOG_INFO("ggml_sycl_fattn_kv_buffer[%d]: %.2f MiB\n", device, capacity / 1024.0 / 1024.0); +#endif + if (ptr) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + } +} diff --git a/ggml/src/ggml-sycl/fattn-buffers.hpp b/ggml/src/ggml-sycl/fattn-buffers.hpp new file mode 100644 index 00000000000..c00461de620 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-buffers.hpp @@ -0,0 +1,63 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_FATTN_BUFFERS_HPP +#define GGML_SYCL_FATTN_BUFFERS_HPP + +#include <sycl/sycl.hpp> + +typedef sycl::queue *queue_ptr; + +struct ggml_sycl_fattn_kv_buffers { + // buffers grow in chunks of this size + static constexpr size_t CHUNK_SIZE = 16ull << 20; // 16 MiB + + struct kv_buffer { + kv_buffer(queue_ptr qptr_, int device_) : qptr(qptr_), device(device_) {} + ~kv_buffer(); + + kv_buffer(const kv_buffer &) = delete; + kv_buffer & operator=(const kv_buffer &) = delete; + + sycl::half * ensure_half(size_t n_elems); + + private: + sycl::half * ptr = nullptr; + size_t capacity = 0; + queue_ptr qptr = nullptr; + [[maybe_unused]] int device = 0; + }; + + kv_buffer K; + kv_buffer V; + + ggml_sycl_fattn_kv_buffers(queue_ptr qptr, int device) : K(qptr, device), V(qptr, device) {} + + ggml_sycl_fattn_kv_buffers(const ggml_sycl_fattn_kv_buffers &) = delete; + ggml_sycl_fattn_kv_buffers & operator=(const ggml_sycl_fattn_kv_buffers &) = delete; +}; + +/** + * Imitates `ggml_sycl_pool_alloc` to keep the code calling alloc unchanged. + */ +struct ggml_sycl_fattn_alloc { + ggml_sycl_fattn_kv_buffers::kv_buffer & buf; + sycl::half * ptr = nullptr; + + explicit ggml_sycl_fattn_alloc(ggml_sycl_fattn_kv_buffers::kv_buffer & buf_) : buf(buf_) {} + + sycl::half * alloc(size_t n_elems) { + ptr = buf.ensure_half(n_elems); + return ptr; + } +}; +#endif diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp new file mode 100644 index 00000000000..c6cc13cfb00 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-common.hpp @@ -0,0 +1,1181 @@ +#pragma once + +#include <sycl/sycl.hpp> +#include "dpct/helper.hpp" +#include "common.hpp" +#include "convert.hpp" +#include "vecdotq.hpp" +#include "fattn-buffers.hpp" + +#include "ggml.h" + +#include <cstdint> +#include <cmath> +#include <float.h> + + +#define FATTN_KQ_STRIDE 256 +#define HALF_MAX_HALF sycl::half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction. +#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. +#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f) + +typedef void (*fattn_kernel_t)( + const char* Q, + const char* K, + const char* V, + const char* mask, + const char* sinks, + const int* KV_max, + float* dst, + sycl::float2* dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, + const sycl::uint3 ne01, + const int32_t ne02, + const int32_t ne03, + const int32_t nb01, + const int32_t nb02, + const int32_t nb03, + const int32_t ne10, + const int32_t ne11, + const int32_t ne12, + const int32_t ne13, + const int32_t nb11, + const int32_t nb12, + const int64_t nb13, + const int32_t nb21, + const int32_t nb22, + const int64_t nb23, + const int32_t ne31, + const int32_t ne32, + const int32_t ne33, + const int32_t nb31, + const int32_t nb32, + const int64_t nb33); + +typedef float (*vec_dot_KQ_t)( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); + +template <int D, int nthreads> +static __dpct_inline__ float vec_dot_fattn_vec_KQ_f16(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + const sycl::half2 * K_h2 = (const sycl::half2 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { + sycl::half2 tmp[cpy_ne]; + ggml_sycl_memcpy_1<sizeof(tmp)>( + tmp, + K_h2 + k_KQ_0 + (sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2) % nthreads) * cpy_ne); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { +#ifdef GGML_SYCL_F16 + ggml_sycl_mad(sum, tmp[k_KQ_1] , ((const sycl::half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#else + ggml_sycl_mad(sum, __half22float2(tmp[k_KQ_1]), ((const sycl::float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]); +#endif // GGML_SYCL_F16 + } + } + + return sum; +} + +template <int D, int nthreads, int warp_size> +static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_0(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + + const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = + k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_0; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_sycl_dp4a(v, u, 0); + + const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads]; + sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x() - (8/QI8_1)*Q_ds.y()); + } + + return sum; +} + +template <int D, int nthreads , int warp_size> +static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_1(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = + k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI4_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_sycl_dp4a(v, u, 0); + + const sycl::float2 K_dm = (K_q4_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>(); + const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads]; + + sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1; + } + + return sum; +} + +template <int D, int nthreads, int warp_size> +static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_0(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = + k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_0; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + + { + int vh; + ggml_sycl_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh); + vh >>= iqs8 * QI5_0; + + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + } + + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_sycl_dp4a(v, u, 0); + + const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads]; + + sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x() - (16/QI8_1)*Q_ds.y()); + } + + return sum; +} + +template <int D, int nthreads, int warp_size> +static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_1(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = + k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads); + + const int ib = k_KQ / QI8_1; + const int iqs4 = k_KQ % QI5_1; + const int iqs8 = k_KQ % QI8_1; + const int shift = k_KQ & (QI8_1/2); + + int v; + ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4); + v = (v >> shift) & 0x0F0F0F0F; + + { + int vh; + ggml_sycl_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh); + vh >>= iqs8 * QI5_0; + + v |= (vh << 4) & 0x00000010; // 0 -> 4 + v |= (vh << 11) & 0x00001000; // 1 -> 12 + v |= (vh << 18) & 0x00100000; // 2 -> 20 + v |= (vh << 25) & 0x10000000; // 3 -> 28 + } + + const int u = Q_q8[k_KQ_0/nthreads]; + + const int sumi = ggml_sycl_dp4a(v, u, 0); + + const sycl::float2 K_dm = (K_q5_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>(); + const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads]; + + sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1; + } + + return sum; +} + +template <int D, int nthreads, int warp_size> +static __dpct_inline__ float vec_dot_fattn_vec_KQ_q8_0(const char * __restrict__ K_c, + const void * __restrict__ Q_v, + const int * __restrict__ Q_q8, + const void * __restrict__ Q_ds_v) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c; + GGML_UNUSED(Q_v); + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) { + const int k_KQ = + k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads); + + const int ib = k_KQ / QI8_0; + const int iqs = k_KQ % QI8_0; + + int v; + ggml_sycl_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs); + + const sycl::float2 * Q_ds = (const sycl::float2 *) Q_ds_v; + const float Q_d = Q_ds[k_KQ_0 / nthreads].x(); + + sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d); + } + + return sum; +} + +template <typename Tds, int ni, int warp_size> +static __dpct_inline__ void quantize_q8_1_to_shared(const float * __restrict__ x, + const float scale, + int * __restrict__ yq32, + void * __restrict__ yds) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + + float vals[sizeof(int)] = { 0.0f }; +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + vals[l] = + (ni == warp_size || item_ct1.get_local_id(2) < ni) ? scale * x[4 * item_ct1.get_local_id(2) + l] : 0.0f; + } + + float amax = sycl::fabs(vals[0]); + float sum = vals[0]; +#pragma unroll + for (int l = 1; l < int(sizeof(int)); ++l) { + amax = sycl::fmax(amax, sycl::fabs(vals[l])); + sum += vals[l]; + } +#pragma unroll + for (int mask = QI8_1/2; mask > 0; mask >>= 1) { + amax = sycl::fmax( + amax, dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), amax, mask)); + sum += dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), sum, mask); + } + + const float d = amax / 127; + int q32 = 0; + int8_t * q8 = (int8_t *) &q32; + + if (d != 0.0f) { +#pragma unroll + for (int l = 0; l < int(sizeof(int)); ++l) { + q8[l] = sycl::round(vals[l] / d); + } + } + + yq32[item_ct1.get_local_id(2)] = q32; + if (item_ct1.get_local_id(2) % QI8_1 == 0 && (ni == warp_size || item_ct1.get_local_id(2) < ni)) { + if (std::is_same<Tds, sycl::half2>::value) { + ((sycl::half2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_half2(d, sum); + } else { + ((sycl::float2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_float2(d, sum); + } + } +} + +typedef void (*dequantize_V_t)(const void *, void *, const int64_t); + +template <typename T, int ne> +static __dpct_inline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + if constexpr (std::is_same_v<T, sycl::half>) { + ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(dst, (const sycl::half *) vx + i0); + } else if constexpr (std::is_same_v<T, float>) { + static_assert(ne % 2 == 0, "bad ne"); + sycl::half2 tmp[ne / 2]; + ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(tmp, (const sycl::half *) vx + i0); + sycl::float2 * dst_f2 = (sycl::float2 *) dst; +#pragma unroll + for (int l = 0; l < ne/2; ++l) { + dst_f2[l] = tmp[l].template convert<float, sycl::rounding_mode::automatic>(); + } + } else { + static_assert(std::is_same_v<T, void>, "unsupported type"); + } +} + +template <typename T, int ne> +static __dpct_inline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q4_0 * x = (const block_q4_0 *) vx; + + const int64_t ib = i0 / QK4_0; + const int iqs = i0 % (QK4_0/2); + const int shift = (i0 % QK4_0) / (QK4_0/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + q = dpct::vectorized_binary<sycl::char4>(q, 0x08080808, dpct::sub_sat()); + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef GGML_SYCL_F16 + if constexpr (std::is_same_v<T, sycl::half>) { + const sycl::half2 d = sycl::half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]); + } + } else +#endif // GGML_SYCL_F16 + if constexpr (std::is_same_v<T, float>) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * q8[l]; + } + } else { + static_assert(std::is_same_v<T, void>, "bad type"); + } +} + +template <typename T, int ne> +static __dpct_inline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q4_1 * x = (const block_q4_1 *) vx; + + const int64_t ib = i0 / QK4_1; + const int iqs = i0 % (QK4_1/2); + const int shift = (i0 % QK4_1) / (QK4_1/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef GGML_SYCL_F16 + if constexpr (std::is_same_v<T, sycl::half>) { + const sycl::half2 dm = x[ib].dm; + const sycl::half2 d = sycl::half2(dm[0]); + const sycl::half2 m = sycl::half2(dm[1]); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m; + } + } else +#endif // GGML_SYCL_F16 + if constexpr (std::is_same_v<T, float>) { + const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>(); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = dm.x() * q8[l] + dm.y(); + } + } else { + static_assert(std::is_same_v<T, void>, "bad type"); + } +} + +template <typename T, int ne> +static __dpct_inline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q5_0 * x = (const block_q5_0 *) vx; + + const int64_t ib = i0 / QK5_0; + const int idq = i0 % QK5_0; + const int iqs = i0 % (QK5_0/2); + const int shift = (i0 % QK5_0) / (QK5_0/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + { + int qh; + ggml_sycl_memcpy_1<ne, 2>(&qh, x[ib].qh); +#pragma unroll + for (int l = 0; l < ne; ++l) { + q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4); + } + } + + q = dpct::vectorized_binary<sycl::char4>(q, 0x10101010, dpct::sub_sat()); + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef GGML_SYCL_F16 + if constexpr (std::is_same_v<T, sycl::half>) { + const sycl::half2 d = sycl::half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]); + } + } else +#endif // GGML_SYCL_F16 + if constexpr (std::is_same_v<T, float>) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * q8[l]; + } + } else { + static_assert(std::is_same_v<T, void>, "bad type"); + } +} + +template <typename T, int ne> +static __dpct_inline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q5_1 * x = (const block_q5_1 *) vx; + + const int64_t ib = i0 / QK5_1; + const int idq = i0 % QK5_1; + const int iqs = i0 % (QK5_1/2); + const int shift = (i0 % QK5_1) / (QK5_1/2); + + int q; + static_assert(ne == 2 || ne == 4, "bad ne"); + ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs); + q >>= 4*shift; + q &= 0x0F0F0F0F; + + { + int qh; + ggml_sycl_memcpy_1<ne>(&qh, x[ib].qh); +#pragma unroll + for (int l = 0; l < ne; ++l) { + q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4); + } + } + + const int8_t * q8 = (const int8_t *) &q; + +#ifdef GGML_SYCL_F16 + if constexpr (std::is_same_v<T, sycl::half>) { + const sycl::half2 dm = x[ib].dm; + const sycl::half2 d = sycl::half2(dm[0]); + const sycl::half2 m = sycl::half2(dm[1]); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m; + } + } else +#endif // GGML_SYCL_F16 + if constexpr (std::is_same_v<T, float>) { + const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>(); + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = dm.x() * q8[l] + dm.y(); + } + } else { + static_assert(std::is_same_v<T, void>, "bad type"); + } +} + +template <typename T, int ne> +static __dpct_inline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_q8_0 * x = (const block_q8_0 *) vx; + + const int64_t ib = i0 / QK8_0; + const int iqs = i0 % QK8_0; + + static_assert(ne % 2 == 0, "bad ne"); + int8_t qs[ne]; + ggml_sycl_memcpy_1<ne, 2>(qs, x[ib].qs + iqs); + +#ifdef GGML_SYCL_F16 + if constexpr (std::is_same<T, sycl::half>::value) { + const sycl::half2 d = sycl::half2(x[ib].d); + +#pragma unroll + for (int l0 = 0; l0 < ne; l0 += 2) { + ((sycl::half2 *) dst)[l0 / 2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]); + } + } else +#endif // GGML_SYCL_F16 + if constexpr (std::is_same<T, float>::value) { + const float d = x[ib].d; + +#pragma unroll + for (int l = 0; l < ne; ++l) { + ((float *) dst)[l] = d * qs[l]; + } + } else { + static_assert(std::is_same_v<T, void>, "unsupported type"); + } +} + +template <int type_K, int D, int nthreads, int warp_size> +constexpr vec_dot_KQ_t get_vec_dot_KQ() { + if constexpr (type_K == GGML_TYPE_F16) { + return vec_dot_fattn_vec_KQ_f16<D, nthreads>; + } else if constexpr (type_K == GGML_TYPE_Q4_0) { + return vec_dot_fattn_vec_KQ_q4_0<D, nthreads, warp_size>; + } else if constexpr (type_K == GGML_TYPE_Q4_1) { + return vec_dot_fattn_vec_KQ_q4_1<D, nthreads, warp_size>; + } else if constexpr (type_K == GGML_TYPE_Q5_0) { + return vec_dot_fattn_vec_KQ_q5_0<D, nthreads, warp_size>; + } else if constexpr (type_K == GGML_TYPE_Q5_1) { + return vec_dot_fattn_vec_KQ_q5_1<D, nthreads, warp_size>; + } else if constexpr (type_K == GGML_TYPE_Q8_0) { + return vec_dot_fattn_vec_KQ_q8_0<D, nthreads, warp_size>; + } else { + static_assert(type_K == -1, "bad type"); + return nullptr; + } +} + +template <int type_V, typename T, int ne> +constexpr dequantize_V_t get_dequantize_V() { + if constexpr (type_V == GGML_TYPE_F16) { + return dequantize_V_f16<T, ne>; + } else if constexpr (type_V == GGML_TYPE_Q4_0) { + return dequantize_V_q4_0<T, ne>; + } else if constexpr (type_V == GGML_TYPE_Q4_1) { + return dequantize_V_q4_1<T, ne>; + } else if constexpr (type_V == GGML_TYPE_Q5_0) { + return dequantize_V_q5_0<T, ne>; + } else if constexpr (type_V == GGML_TYPE_Q5_1) { + return dequantize_V_q5_1<T, ne>; + } else if constexpr (type_V == GGML_TYPE_Q8_0) { + return dequantize_V_q8_0<T, ne>; + } else { + static_assert(type_V == -1, "bad type"); + return nullptr; + } +} + +template <int ncols1, int warp_size> +static void flash_attn_mask_to_KV_max(const sycl::half2 * __restrict__ mask, + int * __restrict__ KV_max, + const int ne30, + const int s31, + const int s33, + int * buf_iw) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int ne31 = item_ct1.get_group_range(2); + const int tid = item_ct1.get_local_id(2); + const int sequence = item_ct1.get_group(1); + const int jt = item_ct1.get_group(2); + + mask += sequence*s33 + jt*ncols1*s31; + + if (tid < warp_size) { + buf_iw[tid] = 1; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; + for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) { + int all_inf = 1; + +#pragma unroll + for (int j = 0; j < ncols1; ++j) { + const sycl::float2 tmp = + mask[j * s31 + KV_max_sj / 2 + tid].template convert<float, sycl::rounding_mode::automatic>(); + all_inf = all_inf && int(sycl::isinf((float) (tmp.x()))) && int(sycl::isinf((float) (tmp.y()))); + } + + all_inf = warp_reduce_all<warp_size>(all_inf); + if (tid % warp_size == 0) { + buf_iw[tid / warp_size] = all_inf; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + all_inf = buf_iw[tid % warp_size]; + item_ct1.barrier(sycl::access::fence_space::local_space); + all_inf = warp_reduce_all<warp_size>(all_inf); + + if (!all_inf) { + break; + } + } + + // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE. + // If the break was triggered it's the lower edge of the tile with the first non-masked values. + // In either case, walk back the decrementation by FATTN_KQ_STRIDE. + KV_max_sj += FATTN_KQ_STRIDE; + + if (item_ct1.get_local_id(2) != 0) { + return; + } + + KV_max[sequence*ne31 + jt] = KV_max_sj; +} + +template <int D, int ncols1, int ncols2> // D == head size + +static void flash_attn_stream_k_fixup(float * __restrict__ dst, + const sycl::float2 * __restrict__ dst_fixup, + const int ne01, + const int ne02, + const int ne03, + const int ne11, + const int ne12, + const int nbatch_fa) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + constexpr int ncols = ncols1 * ncols2; + + const int bidx0 = item_ct1.get_group(2); + const int j = item_ct1.get_group(1); + const int c = item_ct1.get_group(0); + const int jc = j*ncols2 + c; + const int tid = item_ct1.get_local_id(2); + + const float * dst_fixup_data = ((const float *) dst_fixup) + item_ct1.get_group_range(2) * (2 * 2 * ncols); + + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + + const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa; + const int iter_j = (ne01 + (ncols1 - 1)) / ncols1; + const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2; + + const int kbc0 = int64_t(bidx0 + 0) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2); + const int kbc0_stop = + int64_t(bidx0 + 1) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2); + + const bool did_not_have_any_data = kbc0 == kbc0_stop; + const bool wrote_beginning_of_tile = kbc0 % iter_k == 0; + const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0; + if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) { + return; + } + + // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index + const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12); + const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa); + const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j); + const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k; + + const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index. + + if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) { + return; + } + + dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid; + + // Load the partial result that needs a fixup: + float dst_val = 0.0f; + float max_val = 0.0f; + float rowsum = 0.0f; + { + dst_val = *dst; + + const sycl::float2 tmp = dst_fixup[bidx0 * ncols + jc]; + max_val = tmp.x(); + rowsum = tmp.y(); + } + + // Iterate over previous blocks and compute the combined results. + // All SYCL blocks that get here must have a previous block that needs a fixup. + int bidx = bidx0 - 1; + int kbc_stop = kbc0; + while(true) { + const int kbc = int64_t(bidx) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2); + if (kbc == kbc_stop) { // Did not have any data. + bidx--; + kbc_stop = kbc; + continue; + } + + const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid]; + + const sycl::float2 tmp = dst_fixup[(item_ct1.get_group_range(2) + bidx) * ncols + jc]; + + // Scale the current and new value accumulators depending on the max. values. + const float max_val_new = sycl::fmax(max_val, tmp.x()); + + const float diff_val = max_val - max_val_new; + const float diff_add = tmp.x() - max_val_new; + + const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_val) : 0.0f; + const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_add) : 0.0f; + + dst_val = scale_val*dst_val + scale_add*dst_add; + rowsum = scale_val * rowsum + scale_add * tmp.y(); + + max_val = max_val_new; + + // If this block started in a previous tile we are done and don't need to combine additional partial results. + if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) { + break; + } + bidx--; + kbc_stop = kbc; + } + + // Write back final result: + *dst = dst_val / rowsum; +} + +template <int D> // D == head size + +static void flash_attn_combine_results(const float * __restrict__ VKQ_parts, + const sycl::float2 * __restrict__ VKQ_meta, + float * __restrict__ dst, + const int parallel_blocks, + uint8_t * dpct_local) { + // Dimension 0: threadIdx.x + // Dimension 1: blockIdx.x + // Dimension 2: blockIdx.y + // Dimension 3: blockIdx.z + // Memory layout is permuted with [0, 2, 1, 3] + + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int ne01 = item_ct1.get_group_range(2); + const int ne02 = item_ct1.get_group_range(1); + + const int col = item_ct1.get_group(2); + const int head = item_ct1.get_group(1); + const int sequence = item_ct1.get_group(0); + + const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head; + + VKQ_parts += j_dst_unrolled * parallel_blocks*D; + VKQ_meta += j_dst_unrolled * parallel_blocks; + dst += j_dst_unrolled * D; + + const int tid = item_ct1.get_local_id(2); + __builtin_assume(tid < D); + + auto meta = (sycl::float2 *) dpct_local; + for (int i = tid; i < 2*parallel_blocks; i += D) { + ((float *) meta)[i] = ((const float *)VKQ_meta) [i]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + float kqmax = meta[0].x(); + for (int l = 1; l < parallel_blocks; ++l) { + kqmax = sycl::max(kqmax, meta[l].x()); + } + + float VKQ_numerator = 0.0f; + float VKQ_denominator = 0.0f; + for (int l = 0; l < parallel_blocks; ++l) { + const float KQ_max_scale = sycl::native::exp(meta[l].x() - kqmax); + + VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid]; + VKQ_denominator += KQ_max_scale * meta[l].y(); + } + + dst[tid] = VKQ_numerator / VKQ_denominator; +} + +template <fattn_kernel_t fattn_kernel, int warp_size> +static void lauch_kernel( + dpct::dim3 group_range, + dpct::dim3 local_range, + queue_ptr q, + unsigned int local_mem_size, + const char* __restrict__ Q, + const char* __restrict__ K, + const char* __restrict__ V, + const char* __restrict__ mask, + const char* __restrict__ sinks, + const int* __restrict__ KV_max, + float* __restrict__ dst, + sycl::float2* __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, + const sycl::uint3 ne01, + const int32_t ne02, + const int32_t ne03, + const int32_t nb01, + const int32_t nb02, + const int32_t nb03, + const int32_t ne10, + const int32_t ne11, + const int32_t ne12, + const int32_t ne13, + const int32_t nb11, + const int32_t nb12, + const int64_t nb13, + const int32_t nb21, + const int32_t nb22, + const int64_t nb23, + const int32_t ne31, + const int32_t ne32, + const int32_t ne33, + const int32_t nb31, + const int32_t nb32, + const int64_t nb33) { + GGML_UNUSED(local_mem_size); + q->submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::nd_range<3>( + static_cast<sycl::range<3>>(group_range * local_range), + static_cast<sycl::range<3>>(local_range)), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { + GGML_UNUSED(item_ct1); + fattn_kernel(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, ne00, + ne01, ne02, ne03, nb01, nb02, nb03, ne10, ne11, + ne12, ne13, nb11, nb12, nb13, nb21, nb22, nb23, + ne31, ne32, ne33, nb31, nb32, nb33); + }); + }); +} + +template <int DV, int ncols1, int ncols2, fattn_kernel_t fattn_kernel, int warp_size> +void launch_fattn( + ggml_backend_sycl_context & ctx, ggml_tensor * dst, const int nwarps, const size_t nbytes_shared, + const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k) { + + constexpr int ncols = ncols1 * ncols2; + + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs)); + + const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * sinks = dst->src[4]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(Q->nb[0] == ggml_element_size(Q)); + GGML_ASSERT(K->nb[0] == ggml_element_size(K)); + GGML_ASSERT(V->nb[0] == ggml_element_size(V)); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + + ggml_sycl_pool & pool = ctx.pool(); + ggml_sycl_fattn_kv_buffers & fbuf = ctx.fattn_buffers(); + dpct::queue_ptr main_stream = ctx.stream(); + const int id = ggml_sycl_get_device(); + const int nsm = ggml_sycl_info().devices[id].nsm; + + ggml_sycl_fattn_alloc K_f16(fbuf.K); + ggml_sycl_fattn_alloc V_f16(fbuf.V); + ggml_sycl_pool_alloc<int> KV_max(pool); + ggml_sycl_pool_alloc<float> dst_tmp(pool); + ggml_sycl_pool_alloc<sycl::float2> dst_tmp_meta(pool); + + const char * K_data = (const char *) K->data; + size_t nb11 = K->nb[1]; + size_t nb12 = K->nb[2]; + size_t nb13 = K->nb[3]; + + const char * V_data = (const char *) V->data; + size_t nb21 = V->nb[1]; + size_t nb22 = V->nb[2]; + size_t nb23 = V->nb[3]; + + if (need_f16_K && K->type != GGML_TYPE_F16) { + const size_t bs = ggml_blck_size(K->type); + const size_t ts = ggml_type_size(K->type); + + K_f16.alloc(ggml_nelements(K)); + if (ggml_is_contiguously_allocated(K)) { + to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(K->type, dst); + to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + + nb11 = nb11 * bs * sizeof(sycl::half) / ts; + nb12 = nb12 * bs * sizeof(sycl::half) / ts; + nb13 = nb13 * bs * sizeof(sycl::half) / ts; + } else { + GGML_ASSERT(K->nb[0] == ts); + to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(K->type); + const int64_t s01 = nb11 / ts; + const int64_t s02 = nb12 / ts; + const int64_t s03 = nb13 / ts; + to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + + nb11 = K->ne[0] * sizeof(sycl::half); + nb12 = K->ne[1] * nb11; + nb13 = K->ne[2] * nb12; + } + K_data = (char *) K_f16.ptr; + } + + if (need_f16_V && V->type != GGML_TYPE_F16) { + if (V_is_K_view) { + V_data = K_data; + nb21 = nb11; + nb22 = nb12; + nb23 = nb13; + } else { + const size_t bs = ggml_blck_size(V->type); + const size_t ts = ggml_type_size(V->type); + + V_f16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(V->type, dst); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = nb21 * bs * sizeof(sycl::half) / ts; + nb22 = nb22 * bs * sizeof(sycl::half) / ts; + nb23 = nb23 * bs * sizeof(sycl::half) / ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(sycl::half); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16.ptr; + } + } + + const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2); + const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3]; + + // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. + // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or + // multiple sequences of possibly different lengths. + if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { + const int s31 = mask->nb[1] / sizeof(sycl::half2); + const int s33 = mask->nb[3] / sizeof(sycl::half2); + + const dpct::dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1); + const dpct::dim3 block_dim_KV_max(FATTN_KQ_STRIDE / 2, 1, 1); + + const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y; + const int iter_k = K->ne[1] / FATTN_KQ_STRIDE; + + KV_max.alloc(ne_KV_max); + { + dpct::has_capability_or_fail(main_stream->get_device(), { sycl::aspect::fp16 }); + + main_stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor<int, 1> buf_iw_acc_ct1(sycl::range<1>(warp_size), cgh); + + auto mask_data_ct0 = (const sycl::half2 *) mask->data; + auto KV_max_ptr_ct1 = KV_max.ptr; + + cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { + GGML_UNUSED(item_ct1); + flash_attn_mask_to_KV_max<ncols1, warp_size>( + mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33, + buf_iw_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get()); + }); + }); + } + SYCL_CHECK(0); + } + + const dpct::dim3 block_dim(warp_size, nwarps, 1); + + // Max. number of active blocks limited by occupancy. + int max_blocks_per_sm = ggml_sycl_info().devices[id].max_wg_per_cu; + int parallel_blocks = max_blocks_per_sm; + dpct::dim3 blocks_num; + if (stream_k) { + // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup. + const int max_blocks = max_blocks_per_sm*nsm; + const int nblocks_stream_k = max_blocks; + const bool use_stream_k = true; + + blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total; + blocks_num.y = 1; + blocks_num.z = 1; + + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2))); + } + } else { + const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. + + // parallel_blocks must not be larger than what the tensor size allows: + parallel_blocks = std::min(parallel_blocks, ntiles_KQ); + // todo fix the hard code change + // parallel_blocks = ntiles_KQ; + + // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects. + // Test whether parallel_blocks can be set to a higher value for better efficiency. + const int blocks_per_wave = nsm * max_blocks_per_sm; + int nwaves_best = 0; + int efficiency_percent_best = 0; + for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) { + const int nblocks_total = ntiles_total * parallel_blocks_test; + const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave; + const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave); + + // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead. + if (efficiency_percent_best >= 95 && nwaves > nwaves_best) { + break; + } + + if (efficiency_percent > efficiency_percent_best) { + nwaves_best = nwaves; + efficiency_percent_best = efficiency_percent; + parallel_blocks = parallel_blocks_test; + } + } + + blocks_num.x = ntiles_x; + blocks_num.y = parallel_blocks; + blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3]; + + if (parallel_blocks > 1) { + dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); + dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV)); + } + } + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + const uint32_t n_head = Q->ne[2]; + const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head)))); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + // TODO other tensor dimensions after removal of WMMA kernel: + const sycl::uint3 ne01 = init_fastdiv_values(Q->ne[1]); + + GGML_ASSERT(block_dim.x % warp_size == 0); + + lauch_kernel<fattn_kernel, warp_size>( + blocks_num, block_dim, main_stream, (unsigned int) nbytes_shared, (const char *) Q->data, K_data, V_data, + mask ? ((const char *) mask->data) : nullptr, sinks ? ((const char *) sinks->data) : nullptr, KV_max.ptr, + !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, (sycl::float2 *)dst_tmp_meta.ptr, scale, max_bias, m0, m1, + n_head_log2, logit_softcap, Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0], + K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0, + mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, + mask ? mask->nb[3] : 0); + SYCL_CHECK(0); + + if (stream_k) { + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + const dpct::dim3 block_dim_combine(DV, 1, 1); + const dpct::dim3 blocks_num_combine = { blocks_num.x, ncols1, ncols2 }; + + main_stream->submit([&](sycl::handler & cgh) { + auto KQV_data_ct0 = (float *) KQV->data; + auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr; + auto Q_ne_ct2 = Q->ne[1]; + auto Q_ne_ct3 = Q->ne[2]; + auto Q_ne_ct4 = Q->ne[3]; + auto K_ne_ct5 = K->ne[1]; + auto K_ne_ct6 = K->ne[2]; + + cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { + GGML_UNUSED(item_ct1); + flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1, + Q_ne_ct2, Q_ne_ct3, Q_ne_ct4, + K_ne_ct5, K_ne_ct6, nbatch_fa); + }); + }); + } + } else if (parallel_blocks > 1) { + const dpct::dim3 block_dim_combine(DV, 1, 1); + const dpct::dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]); + const size_t nbytes_shared_combine = parallel_blocks * sizeof(sycl::float2); + main_stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(nbytes_shared_combine), cgh); + + auto dst_tmp_ptr_ct0 = dst_tmp.ptr; + auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr; + auto KQV_data_ct2 = (float *) KQV->data; + + cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] { + GGML_UNUSED(item_ct1); + flash_attn_combine_results<DV>( + dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks, + dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get()); + }); + }); + } + SYCL_CHECK(0); +} diff --git a/ggml/src/ggml-sycl/fattn-tile.cpp b/ggml/src/ggml-sycl/fattn-tile.cpp new file mode 100644 index 00000000000..9449d75784d --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-tile.cpp @@ -0,0 +1,59 @@ +#include <sycl/sycl.hpp> +#include <sycl/ext/oneapi/work_group_static.hpp> +#include "dpct/helper.hpp" +#include "common.hpp" +#include "fattn-common.hpp" +#include "fattn-tile.hpp" +#include <cmath> +#include <float.h> +namespace syclex = sycl::ext::oneapi::experimental; + +void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + switch (K->ne[0]) { + case 40: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case< 40, 40>(ctx, dst); + } break; + case 64: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case< 64, 64>(ctx, dst); + } break; + case 72: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case< 72, 72>(ctx, dst); + } break; + case 80: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case< 80, 80>(ctx, dst); + } break; + case 96: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case< 96, 96>(ctx, dst); + } break; + case 112: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case<112, 112>(ctx, dst); + } break; + case 128: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case<128, 128>(ctx, dst); + } break; + case 256: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst); + } break; + case 512: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_sycl_flash_attn_ext_tile_case<512, 512>(ctx, dst); + } break; + case 576: { + GGML_ASSERT(V->ne[0] == 512); + ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst); + } break; + default: { + GGML_ABORT("Unsupported head size"); + } break; + } +} diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp new file mode 100644 index 00000000000..9ba5296968d --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-tile.hpp @@ -0,0 +1,1246 @@ +#include <sycl/sycl.hpp> +#include <sycl/ext/oneapi/work_group_static.hpp> +#include "dpct/helper.hpp" +#include "common.hpp" +#include "fattn-common.hpp" + +#include <cmath> +#include <float.h> + +namespace syclex = sycl::ext::oneapi::experimental; + +#define GGML_SYCL_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads) <= 512, "bad nthreads"); \ + static_assert((occupancy) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K) <= 256, "bad nbatch_K"); \ + return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \ + } \ + +static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, const int DV, const int ncols) { + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 64, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 64, 64) + + return 0; +} + +static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, const int DV, const int ncols) { + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 2, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 4, 128, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 8, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2, 64, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 64, 64) + + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64) + GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) + + return 0; +} + +static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if(fast_fp16_available(cc)) + return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols); + else + return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols); +} + +static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) { +#ifdef SYCL_FAST_FP16 + return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols); +#else + return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols); +#endif // SYCL_FAST_FP16 +} + +static int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1); +} + +static constexpr int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1); +} + +static int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1); +} + +static constexpr int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1); +} + +static int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1); +} + +static constexpr int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1); +} + +static int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1); +} + +static constexpr int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) { + return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1); +} + +template <int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check> +static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV, + sycl::half2 * const __restrict__ tile_KV, + const int stride_KV, + const int i_sup) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j); + const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + item_ct1.get_local_id(1) * stride_i + + (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0 * cpy_ne + (stride_j == warp_size ? item_ct1.get_local_id(2) : + item_ct1.get_local_id(2) % stride_j) * + cpy_ne; + + const __dpct_align__(16) sycl::half2 zero[cpy_ne] = { + { 0.0f, 0.0f } + }; + ggml_sycl_memcpy_1<cpy_nb>( + tile_KV + i*(J/2 + J_padding) + j, + !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + } + } + } + }; + // 1: max 64*16=512 bytes, 512 half + // 2: max 32*16=512 bytes, 256 half + // 3: max 16*16=256 bytes, 128 half + // 4: max 8*16=128 bytes, 64 half + // 5: max 4*16= 64 bytes, 32 half + // 6: max 2*16= 32 bytes, 16 half + // 7: max 1*16= 16 bytes, 8 half + static_assert(J % 8 == 0, "bad J"); + static_assert((J/2) % cpy_ne == 0, "bad J"); + ggml_sycl_unroll<7>{}(load); +} + +template <int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check> +static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV, + float * const __restrict__ tile_KV, + const int stride_KV, + const int i_sup) { + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] (const int n) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j); + const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + item_ct1.get_local_id(1) * stride_i + + (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0 * (cpy_ne / 2) + (stride_j == warp_size ? item_ct1.get_local_id(2) : + item_ct1.get_local_id(2) % stride_j) * + (cpy_ne / 2); + + const sycl::half2 zero[cpy_ne / 2] = { + { 0.0f, 0.0f } + }; + __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne / 2]; + ggml_sycl_memcpy_1<sizeof(tmp_h2)>( + tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + + __dpct_align__(16) sycl::float2 tmp_f2[cpy_ne / 2]; +#pragma unroll + for (int l = 0; l < cpy_ne/2; ++l) { + tmp_f2[l] = tmp_h2[l].template convert<float, sycl::rounding_mode::automatic>(); + } + ggml_sycl_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); + } + } + } + }; + // 1: max 32*16=512 bytes, 128 float + // 2: max 16*16=256 bytes, 64 float + // 3: max 8*16=128 bytes, 32 float + // 4: max 4*16= 64 bytes, 16 float + // 5: max 2*16= 32 bytes, 8 float + static_assert(J % 8 == 0, "bad J"); + static_assert(J % cpy_ne == 0, "bad J"); + ggml_sycl_unroll<5>{}(load); +} + +// Function that performs a single iteration in for the KQ matrix multiplication: +template <int warp_size, + int nwarps, + int ncols1, + int ncols2, + int DKQ, + int nbatch_fa, + int nbatch_K, + bool use_logit_softcap, + bool oob_check, + typename T_vec_dot> +static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp, + const sycl::half2 * const __restrict__ K_h2, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int k_VKQ_0, + const int k_VKQ_sup, + const int k_KQ_0, + float * KQ_acc) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check> + (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + item_ct1.barrier(sycl::access::fence_space::local_space); + +#ifdef SYCL_FAST_FP16 + static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { + __dpct_align__(16) sycl::half2 K_k[nbatch_fa / (np * warp_size)][cpy_ne]; + __dpct_align__(16) sycl::half2 Q_k[cpw][cpy_ne]; +#else + static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { + __dpct_align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + __dpct_align__(16) float Q_k[cpw][cpy_ne]; +#endif // SYCL_FAST_FP16 + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2); + +#ifdef SYCL_FAST_FP16 + ggml_sycl_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]); +#else + ggml_sycl_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]); +#endif // SYCL_FAST_FP16 + } +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw; + +#ifdef SYCL_FAST_FP16 + ggml_sycl_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]); +#else + ggml_sycl_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]); +#endif // SYCL_FAST_FP16 + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#pragma unroll + for (int k = 0; k < cpy_ne; ++k) { + ggml_sycl_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]); + } + } + } + } + + if (k_KQ_0 + nbatch_K < DKQ) { + item_ct1.barrier(sycl::access::fence_space::local_space); // Sync not needed on last iteration. + } +} + +// Function that performs a single iteration of the main loop over up to nbatch_fa tokens. +template <int warp_size, + int nwarps, + int ncols1, + int ncols2, + int DKQ, + int DV, + int nbatch_fa, + int nbatch_K, + bool use_logit_softcap, + bool oob_check, + typename T_vec_dot, + typename T_KQ, + typename T_acc> +/* +The total declared local variable size in device function flash_attn_tile_iter exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. +*/ +static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp, + const sycl::half2 * const __restrict__ K_h2, + const sycl::half2 * const __restrict__ V_h2, + const sycl::half * const __restrict__ mask, + const sycl::uint3 ne01, + const float logit_softcap, + const float slope, + T_KQ * const KQ, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int stride_V2, + const int stride_mask, + float * const KQ_max, + float * const KQ_sum, + T_acc * const VKQ, + const int k_VKQ_0, + const int k_VKQ_max, + const int col_Q_0, + float * KQ_max_new_shared) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + +#ifdef SYCL_FAST_FP16 + constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; +#else + constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; +#endif // SYCL_FAST_FP16 + static_assert(cpw % KQ_cs == 0, "bad KQ_cs"); + const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data + + float KQ_max_new[cpw]; +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_max_new[jc0] = KQ_max[jc0]; + } + + float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication. + + // KQ = K @ Q matrix multiplication: + constexpr int nbatch_K_last = DKQ % nbatch_K; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { + flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + if (nbatch_K_last > 0) { + constexpr int k_KQ_0 = DKQ - nbatch_K_last; + flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + + // Apply logit softcap + mask, update KQ_max: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int j = fastmodulo(col_Q_0 + (jc0 + (item_ct1.get_local_id(1) / np) * cpw) / ncols2, ne01); + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2); + +#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16) + // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation. + // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again. + KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f; +#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16) + + if (use_logit_softcap) { + KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] = + logit_softcap * sycl::tanh((float) KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0]); + } + + if (!oob_check || i_KQ < k_VKQ_sup) { + KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] += + (ncols2 > 1 || mask) ? slope * sycl::vec<sycl::half, 1>(mask[j * stride_mask + k_VKQ_0 + i_KQ]) + .convert<float, sycl::rounding_mode::automatic>()[0] : + 0.0f; + + KQ_max_new[jc0] = + sycl::fmax((float) KQ_max_new[jc0], + (float) (KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] + FATTN_KQ_MAX_OFFSET)); + } + } + + KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]); + } + + if constexpr (np == 1) { + item_ct1.barrier(sycl::access::fence_space::local_space); + } else { + static_assert(cpw == 1, "bad cpw"); + + if (item_ct1.get_local_id(2) == 0) { + KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0]; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np]; + KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]); + } + + // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { +#ifdef SYCL_FAST_FP16 + __dpct_align__(16) sycl::half tmp[nbatch_fa / (np * warp_size)][KQ_cs]; +#else + __dpct_align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#endif // SYCL_FAST_FP16 + +#pragma unroll + for (int jc1 = 0; jc1 < KQ_cs; ++jc1) { + const int jc = jc0 + jc1; + + const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc] - KQ_max_new[jc])); + KQ_max[jc] = KQ_max_new[jc]; + + float KQ_sum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const float val = + !oob_check || i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2) < + static_cast<uint32_t>(k_VKQ_sup) ? + sycl::native::exp((float) (KQ_acc[(i0 / (np * warp_size)) * cpw + jc] - KQ_max[jc])) : + 0.0f; + KQ_sum_add += val; + tmp[i0/(np*warp_size)][jc1] = val; + } + KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; + +#ifdef SYCL_FAST_FP16 + const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale_h2.x(); + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale_h2.y(); + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale; + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale; + } +#endif // SYCL_FAST_FP16 + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const int i = i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2); + + ggml_sycl_memcpy_1<sizeof(tmp[0])>( + KQ + (jc0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs)) * (nbatch_fa * KQ_cs) + i * KQ_cs, + tmp[i0 / (np * warp_size)]); + } + } + + // VKQ = V @ KQ matrix multiplication: + static_assert(DV <= DKQ, "bad DV"); + static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K"); + constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K. + static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V"); + static_assert(nbatch_V % np == 0, "bad nbatch_V"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { + flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check> + (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + item_ct1.barrier(sycl::access::fence_space::local_space); + +#ifdef SYCL_FAST_FP16 +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + __dpct_align__(16) sycl::half2 V_k[(DVp / 2) / warp_size]; + __dpct_align__(16) sycl::half2 KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_sycl_memcpy_1<cpy_ne_D * 4>(&V_k[i0 / warp_size], + &KV_tmp[(k1 + item_ct1.get_local_id(1) % np) * (DV / 2) + i0 + + item_ct1.get_local_id(2) * cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs); + + __dpct_align__(16) sycl::half tmp[KQ_cs]; + ggml_sycl_memcpy_1<KQ_cs * sizeof(sycl::half)>( + &tmp, KQ + jc_KQ * (nbatch_fa * KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np) * KQ_cs); +#pragma unroll + for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) { + KQ_k[jc_VKQ_0 + jc_VKQ_1] = sycl::half2(tmp[jc_VKQ_1]); + } + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() += + V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0].x(); + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() += + V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0].y(); + } + } + } +#else +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + __dpct_align__(16) sycl::float2 V_k[(DVp/2)/warp_size]; + __dpct_align__(16) float KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_sycl_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + item_ct1.get_local_id(1) % np)*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (item_ct1.get_local_id(1) / np)*(cpw/KQ_cs); + + ggml_sycl_memcpy_1<KQ_cs*sizeof(float)>( + &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np)*KQ_cs); + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() += V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0]; + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() += V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0]; + } + } + } +#endif // SYCL_FAST_FP16 + item_ct1.barrier(sycl::access::fence_space::local_space); + } +} + +template <int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, int warp_size> // D == head size +/* +The total declared local variable size in device function flash_attn_tile exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure. +*/ +static void flash_attn_tile(const char * Q, + const char * K, + const char * V, + const char * mask, + const char * sinks, + const int * KV_max, + float * dst, + sycl::float2 * dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, + const sycl::uint3 ne01, + const int32_t ne02, + const int32_t ne03, + const int32_t nb01, + const int32_t nb02, + const int32_t nb03, + const int32_t ne10, + const int32_t ne11, + const int32_t ne12, + const int32_t ne13, + const int32_t nb11, + const int32_t nb12, + const int64_t nb13, + const int32_t nb21, + const int32_t nb22, + const int64_t nb23, + const int32_t ne31, + const int32_t ne32, + const int32_t ne33, + const int32_t nb31, + const int32_t nb32, + const int64_t nb33) { +#ifdef SYCL_FLASH_ATTN + // Skip unused kernel variants for faster compilation: + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + if ((use_logit_softcap && !(DV == 128 || DV == 256))) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + return; + } + + static_assert(ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined"); + + constexpr int ncols = ncols1*ncols2; + + constexpr int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size; + constexpr int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2); + constexpr int nbatch_K = ggml_sycl_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int col_Q_0 = item_ct1.get_group(2) * ncols1; // Index of the first Q column for this SYCL block to work on. + + const int sequence = item_ct1.get_group(0) / (ne02 / ncols2); + const int head0 = item_ct1.get_group(0) * ncols2 - sequence * ne02; // == item_ct1.get_group(0) % (ne02/ncols2) + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0); + const sycl::half2 * K_h2 = (const sycl::half2 *) (K + nb13 * sequence + nb12 * (head0 / gqa_ratio)); + const sycl::half2 * V_h2 = + (const sycl::half2 *) (V + nb23 * sequence + nb22 * (head0 / gqa_ratio)); // K and V have same shape + + const sycl::half * maskh = mask ? (const sycl::half *) (mask + nb33 * (sequence % ne33)) : nullptr; + + const int stride_K2 = nb11 / sizeof(sycl::half2); + const int stride_V2 = nb21 / sizeof(sycl::half2); + const int stride_mask = nb31 / sizeof(sycl::half); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp. + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column. + + static_assert(cpw == 1 || np == 1, "bad cpw / np"); + static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0"); + + constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size. + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel. + // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11. + // KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV). + // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications. + // VKQ == Accumulators in registers for the final VKQ result. + + +#ifdef SYCL_FAST_FP16 + constexpr size_t lsm_size1 = ncols * DKQ/2 ; + constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV ; + constexpr size_t lsm_size3 = ncols * nbatch_fa; + constexpr size_t lsm_size4 = nwarps; + + constexpr size_t local_share_mem_size = lsm_size1 * sizeof(sycl::half2) + + lsm_size2 * sizeof(sycl::half2) + + lsm_size3 * sizeof(sycl::half) + + lsm_size4 * sizeof(float); + + syclex::work_group_static<char[local_share_mem_size]> lsm; + + sycl::half2 *Q_tmp = (sycl::half2 *)&lsm; + sycl::half2 *KV_tmp = (sycl::half2*)(Q_tmp +lsm_size1); + sycl::half *KQ = (sycl::half *)(KV_tmp+lsm_size2); + float *KQ_max_new_shared = (float *)(KQ+lsm_size3); + + __dpct_align__(16) sycl::half2 VKQ[cpw * ((DVp / 2) / warp_size)] = { + { 0.0f, 0.0f } + }; +#else + constexpr size_t lsm_size1 = ncols * DKQ ; + constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV; + constexpr size_t lsm_size3 = ncols * nbatch_fa; + constexpr size_t lsm_size4 = nwarps; + + constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 +lsm_size3 + lsm_size4) * sizeof(float); + + syclex::work_group_static<char[local_share_mem_size]> lsm; + + float *Q_tmp = (float *)&lsm; + float *KV_tmp = Q_tmp +lsm_size1; + float *KQ = KV_tmp+lsm_size2; + float *KQ_max_new_shared = KQ+lsm_size3; + + __dpct_align__(16) sycl::float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; + + +#endif // SYCL_FAST_FP16 + + float KQ_max[cpw] = {}; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + KQ_max[j0/nwarps] = -FLT_MAX/2.0f; + } + float KQ_sum[cpw] = {0.0f}; + + // Load Q data, convert to FP16 if fast: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { + if (i0 + np * warp_size * cpy_ne_D <= DKQ || + i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) + item_ct1.get_local_id(2) * cpy_ne_D < + DKQ) { + __dpct_align__(16) float tmp_f[cpy_ne_D] = { 0.0f }; + ggml_sycl_memcpy_1<sizeof(tmp_f)>( + tmp_f, &Q_f[c * (nb02 / sizeof(float)) + fastmodulo(col_Q_0 + j, ne01) * (nb01 / sizeof(float)) + + i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) + + item_ct1.get_local_id(2) * cpy_ne_D]); + +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp_f[i1] *= scale; + } + +#ifdef SYCL_FAST_FP16 + __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne_D / 2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { + tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); +#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16) + // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation. + // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again. + tmp_h2[i1 / 2] *= sycl::half2(0.25f, 0.25f); +#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16) + } + ggml_sycl_memcpy_1<sizeof(tmp_h2)>( + &Q_tmp[jc * (DKQ / 2) + i0 / 2 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D / 2) + + item_ct1.get_local_id(2) * (cpy_ne_D / 2)], + tmp_h2); +#else + ggml_sycl_memcpy_1<sizeof(tmp_f)>( + &Q_tmp[jc* DKQ + i0 + (item_ct1.get_local_id(1) % np)*(warp_size*cpy_ne_D) + item_ct1.get_local_id(2)* cpy_ne_D], + tmp_f); +#endif // SYCL_FAST_FP16 + } + } + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + // Main loop over KV cache: + const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11; + if (ncols2 == 1) { + // Branch with out-of-bounds checks. + int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa; + while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, + oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, + stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, + KQ_max_new_shared); + k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa; + } + if (k_VKQ_0 < k_VKQ_max) { + constexpr bool oob_check = true; + flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, + oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, + stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, + KQ_max_new_shared); + } + } else { + // Branch without out-of-bounds checks. + for (int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa; k_VKQ_0 < k_VKQ_max; + k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa) { + + constexpr bool oob_check = false; + flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap, + oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2, + stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0, + KQ_max_new_shared); + } + } + +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]); + } + + if constexpr (np > 1) { + static_assert(cpw == 1, "bad cpw"); + static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small"); + +#ifdef SYCL_FAST_FP16 + sycl::half2 * VKQ_combine = (sycl::half2 *) KV_tmp; +#else + float * VKQ_combine = (float *) KV_tmp; +#endif // SYCL_FAST_FP16 + + float * KQ_sum_combine = (float *) Q_tmp; + + if (item_ct1.get_local_id(1) % np != 0) { + +#ifdef SYCL_FAST_FP16 + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_sycl_memcpy_1<cpy_ne_D * 4>( + &VKQ_combine[item_ct1.get_local_id(1) * (DVp / 2) + i0 + item_ct1.get_local_id(2) * cpy_ne_D], + &VKQ[i0 / warp_size]); + } +#else + + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_sycl_memcpy_1<cpy_ne_D*4>( + &VKQ_combine[item_ct1.get_local_id(1)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D], ((const float *) VKQ) + i0/warp_size); + } +#endif // SYCL_FAST_FP16 + + if (item_ct1.get_local_id(2) == 0) { + KQ_sum_combine[item_ct1.get_local_id(1)] = KQ_sum[0]; + } + return; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + +#pragma unroll + for (int ip = 1; ip < np; ++ip) { +#ifdef SYCL_FAST_FP16 + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + __dpct_align__(16) sycl::half2 tmp[cpy_ne_D]; + ggml_sycl_memcpy_1<cpy_ne_D * 4>(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip) * (DVp / 2) + i0 + + item_ct1.get_local_id(2) * cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + VKQ[i0/warp_size + i1] += tmp[i1]; + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + __dpct_align__(16) float tmp[cpy_ne_D]; + ggml_sycl_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + ((float *)VKQ)[i0/warp_size + i1] += tmp[i1]; + } + } +#endif // SYCL_FAST_FP16 + + KQ_sum[0] += KQ_sum_combine[item_ct1.get_local_id(1) + ip]; + } + } + + // Attention sink: adjust KQ max and sum only for the first of all parallel blocks: + if (sinks && item_ct1.get_group(1) == 0) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw; + const float sink = ((const float *) sinks)[head0 + jc % ncols2]; + + float KQ_max_new_j = sycl::fmax((float) KQ_max[jc0], sink); + const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc0] - KQ_max_new_j)); + KQ_max[jc0] = KQ_max_new_j; + + const float val = sycl::native::exp((float) (sink - KQ_max[jc0])); + KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val; + +#ifdef SYCL_FAST_FP16 + const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale; + } +#endif // SYCL_FAST_FP16 + } + } + + // Write back results: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z())) { + return; + } + + const float scale = item_ct1.get_group_range(1) == 1 ? 1.0f / KQ_sum[jc0] : 1.0f; + + const int j_dst_unrolled = + ((sequence * int(ne01.z()) + col_Q_0 + j) * ne02 + head0 + c) * item_ct1.get_group_range(1) + + item_ct1.get_group(1); + +#ifdef SYCL_FAST_FP16 + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + __dpct_align__(16) sycl::float2 tmp[cpy_ne_D]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp[i1] = VKQ[jc0 * ((DVp / 2) / warp_size) + i0 / warp_size + i1] + .template convert<float, sycl::rounding_mode::automatic>(); + tmp[i1].x() *= scale; + tmp[i1].y() *= scale; + } + if (i0 + warp_size * cpy_ne_D <= DV / 2 || i0 + item_ct1.get_local_id(2) * cpy_ne_D < DV / 2) { + ggml_sycl_memcpy_1<sizeof(tmp)>( + &dst[j_dst_unrolled * DV + 2 * i0 + item_ct1.get_local_id(2) * (2 * cpy_ne_D)], tmp); + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + if (i0 + warp_size*cpy_ne_D <= DV || i0 + item_ct1.get_local_id(2)*cpy_ne_D < DV) { +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x() *= scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y() *= scale; + } + ggml_sycl_memcpy_1<cpy_ne_D*4>( + &dst[j_dst_unrolled*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D], + &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); + } + } +#endif // SYCL_FAST_FP16 + + if (item_ct1.get_group_range(1) != 1 && item_ct1.get_local_id(2) == 0) { + dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); +#endif // SYCL_FLASH_ATTN +} + +template <int DKQ, int DV, int ncols2, bool use_logit_softcap> +static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_sycl_get_device(); + const int cc = ggml_sycl_info().devices[id].cc; + const int warp_size = WARP_32_SIZE; //can't support WARP_16_SIZE + + constexpr size_t nbytes_shared = 0; + + if (DV < 512 && Q->ne[1] < 32) { + if constexpr (ncols2 <= 32) { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn<DV, cols_per_block/ncols2, ncols2, + flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + if constexpr (ncols2 <= 16) { + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn<DV, cols_per_block/ncols2, ncols2, + flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn<DV, cols_per_block/ncols2, ncols2, + flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + } + + if constexpr (ncols2 <= 4) { + if (Q->ne[1] > 2/ncols2) { + constexpr int cols_per_block = 4; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn<DV, cols_per_block/ncols2, ncols2, + flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + } + + if constexpr (ncols2 <= 2) { + constexpr int cols_per_block = 2; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn<DV, cols_per_block/ncols2, ncols2, + flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + + { + constexpr int cols_per_block = ncols2*2; + const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + launch_fattn<DV, cols_per_block/ncols2, ncols2, + flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size> + (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false); + return; + } + + GGML_ABORT("fatal error"); +} + +template <int DKQ, int DV, bool use_logit_softcap> +static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases. + // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented. + //const bool nvidia = GGML_SYCL_CC_IS_NVIDIA(ggml_sycl_info().devices[ggml_sycl_get_device()].cc); + const int gqa_limit = gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; + const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + + if constexpr (DV == 512) { + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst); + return; + } + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst); + return; + } + // ncols2=2 and ncols2=1 fallbacks only for cases where ncols=2 config exists (DKQ == DV). + // For DKQ == 576, DV == 512 only GQA-optimized variants are implemented. + if constexpr (DKQ == DV) { + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst); + return; + } + launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst); + return; + } + } + + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst); + return; + } + GGML_ABORT("fatal error"); +} + +template <int DKQ, int DV> +void ggml_sycl_flash_attn_ext_tile_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst); + } +} + +void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +#define DECL_FATTN_TILE_CASE(DKQ, DV) \ + template void ggml_sycl_flash_attn_ext_tile_case \ + <DKQ, DV>(ggml_backend_sycl_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_TILE_CASE( 40, 40); +extern DECL_FATTN_TILE_CASE( 64, 64); +extern DECL_FATTN_TILE_CASE( 72, 72); +extern DECL_FATTN_TILE_CASE( 80, 80); +extern DECL_FATTN_TILE_CASE( 96, 96); +extern DECL_FATTN_TILE_CASE(112, 112); +extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(512, 512); +extern DECL_FATTN_TILE_CASE(576, 512); + diff --git a/ggml/src/ggml-sycl/fattn-vec.hpp b/ggml/src/ggml-sycl/fattn-vec.hpp new file mode 100644 index 00000000000..8031acfdff8 --- /dev/null +++ b/ggml/src/ggml-sycl/fattn-vec.hpp @@ -0,0 +1,674 @@ +#ifndef GGML_SYCL_FATTN_VEC_HPP +#define GGML_SYCL_FATTN_VEC_HPP + +#include <sycl/sycl.hpp> +#include <sycl/ext/oneapi/work_group_static.hpp> +#include <iostream> +#include <iomanip> + +#include "dpct/helper.hpp" +#include "common.hpp" +#include "ggml.h" +#include "fattn-common.hpp" +#include <cmath> +#include <float.h> + +namespace syclex = sycl::ext::oneapi::experimental; + +static int ggml_sycl_fattn_vec_get_nthreads_host(const int cc) { + return 128; + GGML_UNUSED(cc); +} + +static constexpr int ggml_sycl_fattn_vec_get_nthreads_device() { + return 128; +} + +// Currenlty llvm with the amdgcn target dose not support unrolling loops +// that contain a break that can not be resolved at compile time. +#ifdef __clang__ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wpass-failed" +#endif // __clang__ + +template <int D, + int ncols, + int type_K, + int type_V, + bool use_logit_softcap, + int warp_size> // D == head size +static void flash_attn_ext_vec(const char* __restrict__ Q, + const char* __restrict__ K, + const char* __restrict__ V, + const char* __restrict__ mask, + const char* __restrict__ sinks, + const int* __restrict__ KV_max, + float* __restrict__ dst, + sycl::float2* __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, + const sycl::uint3 ne01, + const int32_t ne02, + const int32_t ne03, + const int32_t nb01, + const int32_t nb02, + const int32_t nb03, + const int32_t ne10, + const int32_t ne11, + const int32_t ne12, + const int32_t ne13, + const int32_t nb11, + const int32_t nb12, + const int64_t nb13, + const int32_t nb21, + const int32_t nb22, + const int64_t nb23, + const int32_t ne31, + const int32_t ne32, + const int32_t ne33, + const int32_t nb31, + const int32_t nb32, + const int64_t nb33) { +#ifdef SYCL_FLASH_ATTN + // Skip unused kernel variants for faster compilation: + + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + return; + } + + //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int nthreads_KQ_q = (D/4 < warp_size ? D/4 : warp_size); + constexpr int nthreads_V_q = (D/4 < warp_size ? D/4 : warp_size); + + constexpr int nthreads = ggml_sycl_fattn_vec_get_nthreads_device(); + constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q; + constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q; + + static_assert(warp_size % nthreads_KQ == 0, "bad nthreads_K"); + static_assert(warp_size % nthreads_V == 0, "bad nthreads_V"); + + constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4; + constexpr int V_cols_per_iter = warp_size / nthreads_V; + + constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ, warp_size>(); + constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16; +#ifdef GGML_SYCL_F16 + constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, sycl::half, V_rows_per_thread>(); +#else + constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>(); +#endif // GGML_SYCL_F16 + + const int ic0 = item_ct1.get_group(2) * ncols; // Index of the Q/QKV column to work on. + + const int sequence = item_ct1.get_group(0) / ne02; + const int head = item_ct1.get_group(0) - sequence * ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + Q += nb03*sequence + nb02* head + nb01*ic0; + K += nb13*sequence + nb12*(head / gqa_ratio); + V += nb23*sequence + nb22*(head / gqa_ratio); + + const sycl::half * maskh = (const sycl::half *) (mask + nb33 * (sequence % ne33) + nb31 * ic0); + + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + + static_assert(D % (2*warp_size) == 0, "D not divisible by 2*warp_size == 64."); + constexpr int nwarps = nthreads / warp_size; + const int tid = warp_size * item_ct1.get_local_id(1) + item_ct1.get_local_id(2); + __builtin_assume(tid < nthreads); + + constexpr int ne_KQ = ncols*D; + constexpr int ne_combine = nwarps*V_cols_per_iter*D; + + constexpr size_t lsm_size1 = ncols * warp_size; + constexpr size_t lsm_size2 = ncols * warp_size; +#ifdef GGML_SYCL_F16 + sycl::half2 VKQ[ncols][(D / 2) / nthreads_V] = { { { 0.0f, 0.0f } } }; + constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine); + constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2)*sizeof(float) + lsm_size3*sizeof(sycl::half); + + syclex::work_group_static<char[local_share_mem_size]> lsm; + + float *KQ_max_shared = (float *)&lsm; + float *KQ_sum_shared = KQ_max_shared+lsm_size1; + sycl::half* KQ = (sycl::half*)(KQ_sum_shared + lsm_size2); + + +#else + sycl::float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}}; + + constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine); + constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 + lsm_size3)*sizeof(float); + + + syclex::work_group_static<char[local_share_mem_size]> lsm; + float *KQ_max_shared = (float *)&lsm; + float *KQ_sum_shared = KQ_max_shared+lsm_size1; + float* KQ = KQ_sum_shared + lsm_size2; + +#endif // GGML_SYCL_F16 + + float KQ_max[ncols]; + float KQ_sum[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_max[j] = -FLT_MAX/2.0f; + KQ_sum[j] = 0.0f; + } + + // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers: +#ifdef GGML_SYCL_F16 + sycl::half2 Q_reg[ncols][(D / 2) / nthreads_KQ] = {{{0.0f, 0.0f}}}; // Will be initialized completely. +#else + sycl::float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized. +#endif // GGML_SYCL_F16 + int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)]; + sycl::float2 Q_ds[ncols][1 > D / (sizeof(int) * nthreads_KQ) ? 1 : D / (sizeof(int) * nthreads_KQ)]; + if constexpr (Q_q8_1) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + item_ct1.get_local_id(1); + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + + // Reuse KQ as temporary storage for converting Q to q8_1: + int * tmp_q_i32 = (int *) &KQ[j*D]; + sycl::float2 * tmp_q_ds = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int)); + + // Set memory to zero if out of bounds: + if (ncols > 1 && ic0 + j >= int(ne01.z())) { +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += warp_size) { + const int i = i0 + item_ct1.get_local_id(2); + + if (i0 + warp_size <= int(D/sizeof(int)) || i < int(D/sizeof(int))) { + tmp_q_i32[i] = 0; + } + } + if (item_ct1.get_local_id(2) < D/QK8_1) { + tmp_q_ds[item_ct1.get_local_id(2)] = sycl::float2(0.0f, 0.0f); + } + } else { + const float * Q_f = (const float *) (Q + j*nb01); + constexpr int nthreads_quantize = D/sizeof(int) < warp_size ? D/sizeof(int) : warp_size; +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) { + quantize_q8_1_to_shared<sycl::float2, nthreads_quantize, warp_size> + (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1); + } + } + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + int * tmp_q_i32 = (int *) &KQ[j*D]; + sycl::float2 * tmp_q_ds = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int)); + +#pragma unroll + for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) { + const int i = + i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ); + + Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i]; + Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1]; + } + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + } else { +#ifdef GGML_SYCL_F16 + const sycl::half2 scale_h2 = sycl::half2(scale, scale); +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j * nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { + const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : + item_ct1.get_local_id(2) % nthreads_KQ) * + cpy_ne; + + sycl::float2 tmp[cpy_ne] = { + { 0.0f, 0.0f } + }; + if (ncols == 1 || ic0 + j < int(ne01.z())) { + ggml_sycl_memcpy_1<cpy_nb>(tmp, &Q_j[i]); + ggml_sycl_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]); + } +#pragma unroll + for (int i1 = 0; i1 < cpy_ne; ++i1) { + Q_reg[j][i0 / nthreads_KQ + i1] = sycl::half2(tmp[i1].x(), tmp[i1].y()); + } + } +#pragma unroll + for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { + Q_reg[j][k] *= scale_h2; + } + } +#else +#pragma unroll + for (int j = 0; j < ncols; ++j) { + const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j*nb01); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) { + const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ)*cpy_ne; + if (ncols == 1 || ic0 + j < int(ne01.z())) { + ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]); + ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]); + } + } +#pragma unroll + for (int k = 0; k < (D/2)/nthreads_KQ; ++k) { + Q_reg[j][k].x() *= scale; + Q_reg[j][k].y() *= scale; + } + } +#endif // GGML_SYCL_F16 + } + + const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11; + K += item_ct1.get_group(1) * nthreads * nb11; + V += item_ct1.get_group(1) * nthreads * nb21; + maskh += item_ct1.get_group(1) * nthreads; + for (int k_VKQ_0 = item_ct1.get_group(1) * nthreads; k_VKQ_0 < k_VKQ_max; + k_VKQ_0 += item_ct1.get_group_range(1) * nthreads, + // Increment pointers after each loop: + K += item_ct1.get_group_range(1) * nthreads * nb11, V += item_ct1.get_group_range(1) * nthreads * nb21, + maskh += item_ct1.get_group_range(1) * nthreads) { + // Calculate KQ tile and keep track of new maximum KQ values: + float KQ_reg[ncols]={}; // KQ in registers. + float KQ_max_new[ncols]={}; + + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_max_new[j] = KQ_max[j]; + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) { + const int i_KQ = item_ct1.get_local_id(1) * warp_size + + (nthreads_KQ == warp_size ? 0 : (item_ct1.get_local_id(2) & ~(nthreads_KQ - 1))) + i_KQ_0; + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]); + sum = warp_reduce_sum<nthreads_KQ>(sum); + + if (use_logit_softcap) { + sum = logit_softcap * sycl::tanh(sum); + } + if (mask) { + sum += slope * sycl::vec<sycl::half, 1>(maskh[j * ne11 + i_KQ]) + .convert<float, sycl::rounding_mode::automatic>()[0]; + } + + KQ_max_new[j] = sycl::fmax((float) KQ_max_new[j], sum); + + if (int(nthreads_KQ == warp_size ? item_ct1.get_local_id(2) + : item_ct1.get_local_id(2) % + nthreads_KQ) == i_KQ_0) { + KQ_reg[j] = sum; + } + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { +#pragma unroll + for (int offset = nthreads_KQ; offset < warp_size; offset <<= 1) { + KQ_max_new[j] = sycl::fmax( + (float)KQ_max_new[j], + (float)dpct::permute_sub_group_by_xor( + sycl::ext::oneapi::this_work_item::get_sub_group(), + KQ_max_new[j], + offset, + warp_size)); + } + const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - KQ_max_new[j])); + KQ_max[j] = KQ_max_new[j]; + + KQ_reg[j] = sycl::native::exp((float) (KQ_reg[j] - KQ_max[j])); + KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j]; + KQ[j*nthreads + tid] = KQ_reg[j]; + +#ifdef GGML_SYCL_F16 + const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale; + VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale; + } +#endif // GGML_SYCL_F16 + } + + sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_sub_group()); + +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += V_cols_per_iter) { + const int k = item_ct1.get_local_id(1) * warp_size + k0 + + (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V); + +#ifdef GGML_SYCL_F16 + sycl::half2 KQ_k[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_k[j] = sycl::half2(KQ[j * nthreads + k]); + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + sycl::half2 tmp[V_rows_per_thread / 2]; + dequantize_V(V + k * nb21, tmp, + 2 * i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : + item_ct1.get_local_id(2) % nthreads_V) * + V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j]; + } + } + } +#else + float KQ_k[ncols]; +#pragma unroll + for (int j = 0; j < ncols; ++j) { + KQ_k[j] = KQ[j*nthreads + k]; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + sycl::float2 tmp[V_rows_per_thread/2]; + dequantize_V(V + k*nb21, tmp, + 2*i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*V_rows_per_thread); +#pragma unroll + for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) { +#pragma unroll + for (int j = 0; j < ncols; ++j) { + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x() += tmp[i_VKQ_1].x()*KQ_k[j]; + VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y() += tmp[i_VKQ_1].y()*KQ_k[j]; + } + } + } +#endif // GGML_SYCL_F16 + } + } + + if (sinks && item_ct1.get_group(1) == 0) { + const float sink = ((const float *) sinks)[head]; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + item_ct1.get_local_id(1); + + if (j0 + nwarps > ncols && j >= ncols) { + break; + } + const float kqmax_new_j = sycl::fmax(sink, (float) KQ_max[j]); + const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - kqmax_new_j)); + KQ_max[j] = kqmax_new_j; + + KQ_sum[j] = KQ_sum[j] * KQ_max_scale + + (item_ct1.get_local_id(2) == 0 ? sycl::native::exp((float) (sink - KQ_max[j])) : 0.0f); +#ifdef GGML_SYCL_F16 + const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale; + VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale; + } +#endif // GGML_SYCL_F16 + } + } + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (item_ct1.get_local_id(1) == 0) { + KQ_max_shared[j*warp_size+item_ct1.get_local_id(2)] = -FLT_MAX / 2.0f; + KQ_sum_shared[j*warp_size+item_ct1.get_local_id(2)] = 0.0f; + } + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + +#pragma unroll + for (int j = 0; j < ncols; ++j) { + if (item_ct1.get_local_id(2) == 0) { + KQ_max_shared[j*warp_size+item_ct1.get_local_id(1)] = KQ_max[j]; + } + } + + + item_ct1.barrier(sycl::access::fence_space::local_space); + +#pragma unroll + for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) { + if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z())) { + break; + } + + float kqmax_new = KQ_max_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)]; + kqmax_new = warp_reduce_max<warp_size>(kqmax_new); + const float kqmax_scale = sycl::native::exp((float) (KQ_max[j_VKQ] - kqmax_new)); + KQ_max[j_VKQ] = kqmax_new; + +#ifdef GGML_SYCL_F16 + sycl::half2 * VKQ_tmp = (sycl::half2 *) KQ + item_ct1.get_local_id(1) * (V_cols_per_iter * D / 2) + + (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V) * (D / 2); + + const sycl::half2 kqmax_scale_h2 = sycl::half2(kqmax_scale, kqmax_scale); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i_VKQ = + i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V) * + (V_rows_per_thread / 2); + + ggml_sycl_memcpy_1<V_rows_per_thread * sizeof(sycl::half)>(VKQ_tmp + i_VKQ, + &VKQ[j_VKQ][i_VKQ_0 / nthreads_V]); + } +#else + sycl::float2 * VKQ_tmp = (sycl::float2 *) KQ + item_ct1.get_local_id(1)*(V_cols_per_iter*D/2) + + (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V)*(D/2); +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) { + VKQ[j_VKQ][i_VKQ_0/nthreads_V].x() *= kqmax_scale; + VKQ[j_VKQ][i_VKQ_0/nthreads_V].y() *= kqmax_scale; + } +#pragma unroll + for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) { + const int i_VKQ = i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*(V_rows_per_thread/2); + + ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]); + ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]); + } +#endif // GGML_SYCL_F16 + + KQ_sum[j_VKQ] *= kqmax_scale; + KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]); + if (item_ct1.get_local_id(2) == 0) { + KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(1)] = KQ_sum[j_VKQ]; + } + + item_ct1.barrier(sycl::access::fence_space::local_space); + + + if (nthreads <= D || tid < D) { + KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)]; + KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]); + +#pragma unroll + for (int i0 = 0; i0 < D; i0 += nthreads) { + float dst_val = 0; +#pragma unroll + for (int w = 0; w < nwarps; ++w) { +#pragma unroll + for (int v = 0; v < V_cols_per_iter; ++v) { + dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]); + } + } + if (item_ct1.get_group_range(1) == 1) { + dst_val /= KQ_sum[j_VKQ]; + } + dst[(((sequence * int(ne01.z()) + ic0 + j_VKQ) * ne02 + head) * item_ct1.get_group_range(1) + + item_ct1.get_group(1)) * + D + + i0 + tid] = dst_val; + } + } + + if (j_VKQ < ncols-1) { + item_ct1.barrier(sycl::access::fence_space::local_space); + } + + } + + if (item_ct1.get_group_range(1) != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z()))) { + dst_meta[((sequence * int(ne01.z()) + ic0 + tid) * ne02 + head) * item_ct1.get_group_range(1) + + item_ct1.get_group(1)] = make_float2(KQ_max[tid], KQ_sum[tid]); + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + +#endif // SYCL_FLASH_ATTN +} +#ifdef __clang__ +#pragma clang diagnostic pop +#endif // __clang__ + + +template <int D, int cols_per_block, int type_K, int type_V, bool use_logit_softcap> +void ggml_sycl_flash_attn_ext_vec_case_impl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + + const int warp_size = WARP_16_SIZE; //better performance than WARP_32_SIZE + + const int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc; + + const int nthreads = ggml_sycl_fattn_vec_get_nthreads_host(cc); + const int nwarps = nthreads / warp_size; + + const bool need_f16_K = type_K == GGML_TYPE_F16; + const bool need_f16_V = type_V == GGML_TYPE_F16; + constexpr size_t nbytes_shared = 0; + + launch_fattn<D, cols_per_block, 1, + flash_attn_ext_vec<D, cols_per_block, type_K, type_V, + use_logit_softcap, warp_size>, warp_size>( + ctx, dst, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false); +} + +template <int D, int type_K, int type_V> +void ggml_sycl_flash_attn_ext_vec_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (Q->ne[1] == 1) { + constexpr int cols_per_block = 1; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); + } + return; + } + + constexpr int cols_per_block = 2; + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst); + } +} + +#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \ + template void ggml_sycl_flash_attn_ext_vec_case \ + <D, type_K, type_V>(ggml_backend_sycl_context & ctx, ggml_tensor * dst) \ + +#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \ + extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \ + +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0) + +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0) + +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0) + +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_F16) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q4_1) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_0) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q5_1) +EXTERN_DECL_FATTN_VEC_CASES(512, GGML_TYPE_Q8_0) + +#endif // GGML_SYCL_FATTN_VEC_HPP diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp new file mode 100644 index 00000000000..7c6e6112fdc --- /dev/null +++ b/ggml/src/ggml-sycl/fattn.cpp @@ -0,0 +1,227 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + + +#include <sycl/sycl.hpp> +#include "dpct/helper.hpp" +#include "common.hpp" +#include "fattn-common.hpp" +#include "fattn-tile.hpp" +#include "fattn-vec.hpp" +#include "fattn.hpp" + + +#define FATTN_VEC_CASE(D, type_K, type_V) \ + { \ + const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \ + const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \ + if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \ + ggml_sycl_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \ + return; \ + } \ + } \ + +#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \ + FATTN_VEC_CASE( 64, type_K, type_V) \ + FATTN_VEC_CASE(128, type_K, type_V) \ + FATTN_VEC_CASE(256, type_K, type_V) \ + FATTN_VEC_CASE(512, type_K, type_V) \ + +static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_tensor * Q = dst->src[0]; + ggml_tensor * K = dst->src[1]; + ggml_tensor * V = dst->src[2]; + +#ifdef GGML_SYCL_FA_ALL_QUANTS + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1) + + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) +#else + FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0) +#endif // GGML_SYCL_FA_ALL_QUANTS + + GGML_ABORT("Not match KV type in vec"); +} + +// Best FlashAttention kernel for a specific GPU: +enum best_fattn_kernel { + BEST_FATTN_KERNEL_NONE = 0, + BEST_FATTN_KERNEL_VEC = 100, + BEST_FATTN_KERNEL_TILE = 200, +}; + +static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const ggml_tensor * dst) { + GGML_UNUSED(device); +#ifndef SYCL_FLASH_ATTN + GGML_UNUSED(dst); + return BEST_FATTN_KERNEL_NONE; +#endif// SYCL_FLASH_ATTN + + if(!g_ggml_sycl_enable_flash_attention) return BEST_FATTN_KERNEL_NONE; + + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + const ggml_tensor * mask = dst->src[3]; + + const int gqa_ratio = Q->ne[2] / K->ne[2]; + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + for (const ggml_tensor * t : {Q, K, V, mask}) { + if (t == nullptr || ggml_is_quantized(t->type)) { + continue; + } + for (size_t i = 1; i < GGML_MAX_DIMS; ++i) { + if (t->nb[i] % 16 != 0) { + gqa_opt_applies = false; + break; + } + } + } + + switch (K->ne[0]) { + case 40: + case 64: + case 72: + case 80: + case 96: + case 128: + case 112: + case 256: + case 512: + if (V->ne[0] != K->ne[0]) { + return BEST_FATTN_KERNEL_NONE; + } + break; + case 576: + if (V->ne[0] != 512) { + return BEST_FATTN_KERNEL_NONE; + } + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_NONE; + } + break; + default: + return BEST_FATTN_KERNEL_NONE; + } + +#ifndef GGML_SYCL_FA_ALL_QUANTS + if (K->type != V->type) { + return BEST_FATTN_KERNEL_NONE; + } +#endif // GGML_SYCL_FA_ALL_QUANTS + + switch (K->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: +#ifndef GGML_SYCL_FA_ALL_QUANTS + return BEST_FATTN_KERNEL_NONE; +#endif // GGML_SYCL_FA_ALL_QUANTS + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + break; + default: + return BEST_FATTN_KERNEL_NONE; + } + + if (mask && mask->ne[2] != 1) { + return BEST_FATTN_KERNEL_NONE; + } + + // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: + const bool can_use_vector_kernel = Q->ne[0] <= 512 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + + // Todo: Use the XMX kernel if possible: + + // If there are no tensor cores available, use the generic tile kernel: + if (can_use_vector_kernel) { + if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } + } + return BEST_FATTN_KERNEL_TILE; +} + +void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_sycl_set_device(ctx.device); + switch (ggml_sycl_get_best_fattn_kernel(ggml_sycl_get_device(), dst)) { + case BEST_FATTN_KERNEL_NONE: + GGML_ABORT("Not support Flash-Attention"); + case BEST_FATTN_KERNEL_TILE: + ggml_sycl_flash_attn_ext_tile(ctx, dst); + break; + case BEST_FATTN_KERNEL_VEC: + ggml_sycl_flash_attn_ext_vec(ctx, dst); + break; + } +} + +bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst) { + return ggml_sycl_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE; +} diff --git a/ggml/src/ggml-sycl/fattn.hpp b/ggml/src/ggml-sycl/fattn.hpp new file mode 100644 index 00000000000..f2a8ffc97de --- /dev/null +++ b/ggml/src/ggml-sycl/fattn.hpp @@ -0,0 +1,22 @@ +// +// MIT license +// Copyright (C) 2025 Intel Corporation +// SPDX-License-Identifier: MIT +// + +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// + +#ifndef GGML_SYCL_FATTN_HPP +#define GGML_SYCL_FATTN_HPP + +#include "common.hpp" + +void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst); + +#endif // GGML_SYCL_FATTN_HPP diff --git a/ggml/src/ggml-sycl/fill.cpp b/ggml/src/ggml-sycl/fill.cpp new file mode 100644 index 00000000000..28e618e4ef5 --- /dev/null +++ b/ggml/src/ggml-sycl/fill.cpp @@ -0,0 +1,55 @@ +#include "fill.hpp" +#include "common.hpp" + +#define SYCL_FILL_BLOCK_SIZE 256 + +template <typename T> +static void fill_kernel(T * dst, const int64_t k, const T value, + const sycl::nd_item<1> & item) { + const int64_t i = (int64_t)item.get_global_id(0); + if (i >= k) { + return; + } + dst[i] = value; +} + +inline void ggml_sycl_op_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(dst)); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + float value; + memcpy(&value, dst->op_params, sizeof(float)); + + const int64_t k = ggml_nelements(dst); + const int64_t num_blocks = (k + SYCL_FILL_BLOCK_SIZE - 1) / SYCL_FILL_BLOCK_SIZE; + void * dst_d = dst->data; + + switch (dst->type) { + case GGML_TYPE_F32: + stream->parallel_for( + sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE), + [=](sycl::nd_item<1> item) { + fill_kernel(static_cast<float *>(dst_d), k, value, item); + }); + break; + case GGML_TYPE_F16: + { + sycl::half h_value = sycl::half(value); + stream->parallel_for( + sycl::nd_range<1>(num_blocks * SYCL_FILL_BLOCK_SIZE, SYCL_FILL_BLOCK_SIZE), + [=](sycl::nd_item<1> item) { + fill_kernel(static_cast<sycl::half *>(dst_d), k, h_value, item); + }); + } + break; + default: + GGML_ABORT("unsupported type"); + } +} + +void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/0); + ggml_sycl_op_fill(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/fill.hpp b/ggml/src/ggml-sycl/fill.hpp new file mode 100644 index 00000000000..b2adb94ff52 --- /dev/null +++ b/ggml/src/ggml-sycl/fill.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_fill(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/gated_delta_net.cpp b/ggml/src/ggml-sycl/gated_delta_net.cpp new file mode 100644 index 00000000000..239e00bd7e5 --- /dev/null +++ b/ggml/src/ggml-sycl/gated_delta_net.cpp @@ -0,0 +1,347 @@ +#include <sycl/sycl.hpp> +#include "dpct/helper.hpp" +#include "common.hpp" +#include "ggml.h" +#include "gated_delta_net.hpp" +#include <cmath> + + +template <int S_v, bool KDA, bool keep_rs_t> +void gated_delta_net_sycl(const float * q, + const float * k, + const float * v, + const float * g, + const float * beta, + const float * curr_state, + float * dst, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + const sycl::uint3 neqk1_magic, + const sycl::uint3 rq3_magic, + float scale, + int K) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const uint32_t h_idx = item_ct1.get_group(2); + const uint32_t sequence = item_ct1.get_group(1); + // each warp owns one column, using warp-level primitives to reduce across rows + const int lane = item_ct1.get_local_id(2); + const int col = item_ct1.get_group(0) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); + + const uint32_t iq1 = fastmodulo(h_idx, neqk1_magic); + const uint32_t iq3 = fastdiv(sequence, rq3_magic); + + const int64_t attn_score_elems = S_v * H * n_tokens * n_seqs; + float * attn_data = dst; + float * state = dst + attn_score_elems; + + // input state holds s0 only [S_v, S_v, H, n_seqs] — seq stride is D = H * S_v * S_v. + // output state layout (per-slot D * n_seqs) — same per-(seq,head) offset as before. + const int64_t state_in_offset = sequence * H * S_v * S_v + h_idx * S_v * S_v; + const int64_t state_out_offset = (sequence * H + h_idx) * S_v * S_v; + const int64_t state_size_per_token = S_v * S_v * H * n_seqs; // per-slot stride in output + state += state_out_offset; + curr_state += state_in_offset + col * S_v; + attn_data += (sequence * n_tokens * H + h_idx) * S_v; + + constexpr int warp_size = ggml_sycl_get_physical_warp_size() < S_v ? ggml_sycl_get_physical_warp_size() : S_v; + static_assert(S_v % warp_size == 0, "S_v must be a multiple of warp_size"); + constexpr int rows_per_lane = (S_v + warp_size - 1) / warp_size; + float s_shard[rows_per_lane]; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = curr_state[i]; + } + + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K only slots 0..n_tokens-1 are written; older slots are caller-owned. + + for (int t = 0; t < n_tokens; t++) { + const float * q_t = q + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * k_t = k + iq3 * sq3 + t * sq2 + iq1 * sq1; + const float * v_t = v + sequence * sv3 + t * sv2 + h_idx * sv1; + + const int64_t gb_offset = sequence * sb3 + t * sb2 + h_idx * sb1; + const float * beta_t = beta + gb_offset; + const float * g_t = g + gb_offset * (KDA ? S_v : 1); + + const float beta_val = *beta_t; + + if constexpr (!KDA) { + const float g_val = sycl::native::exp(*g_t); + + // kv[col] = (S^T @ k)[col] = sum_i S[i][col] * k[i] + float kv_shard = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += s_shard[r] * k_t[i]; + } + float kv_col = warp_reduce_sum<warp_size>(kv_shard); + + // delta[col] = (v[col] - g * kv[col]) * beta + float delta_col = (v_t[col] - g_val * kv_col) * beta_val; + + // fused: S[i][col] = g * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_partial = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = g_val * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; + } + + float attn_col = warp_reduce_sum<warp_size>(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } + } else { + // kv[col] = sum_i g[i] * S[i][col] * k[i] + float kv_shard = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + kv_shard += sycl::native::exp(g_t[i]) * s_shard[r] * k_t[i]; + } + + float kv_col = warp_reduce_sum<warp_size>(kv_shard); + + // delta[col] = (v[col] - kv[col]) * beta + float delta_col = (v_t[col] - kv_col) * beta_val; + + // fused: S[i][col] = g[i] * S[i][col] + k[i] * delta[col] + // attn[col] = (S^T @ q)[col] = sum_i S[i][col] * q[i] + float attn_partial = 0.0f; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + s_shard[r] = sycl::native::exp(g_t[i]) * s_shard[r] + k_t[i] * delta_col; + attn_partial += s_shard[r] * q_t[i]; + } + + float attn_col = warp_reduce_sum<warp_size>(attn_partial); + + if (lane == 0) { + attn_data[col] = attn_col * scale; + } + } + + attn_data += S_v * H; + + + // Write state back to global memory + if constexpr (keep_rs_t) { + const int target_slot = (int) n_tokens - 1 - t; + if (target_slot >= 0 && target_slot < K) { + float * curr_state = (dst + attn_score_elems) + target_slot * state_size_per_token + state_out_offset; +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + curr_state[col * S_v + i] = s_shard[r]; + } + } + } + } + + if constexpr (!keep_rs_t) { +#pragma unroll + for (int r = 0; r < rows_per_lane; r++) { + const int i = r * warp_size + lane; + state[col * S_v + i] = s_shard[r]; + } + } +} + +template <bool KDA, bool keep_rs_t> +static void launch_gated_delta_net(const float * q_d, + const float * k_d, + const float * v_d, + const float * g_d, + const float * b_d, + const float * s_d, + float * dst_d, + int64_t S_v, + int64_t H, + int64_t n_tokens, + int64_t n_seqs, + int64_t sq1, + int64_t sq2, + int64_t sq3, + int64_t sv1, + int64_t sv2, + int64_t sv3, + int64_t sb1, + int64_t sb2, + int64_t sb3, + int64_t neqk1, + int64_t rq3, + float scale, + int K, + dpct::queue_ptr stream) { + //TODO: Add chunked kernel for even faster pre-fill + const int warp_size = ggml_sycl_info().devices[ggml_sycl_get_device()].warp_size; + + const int num_warps = 4; + dpct::dim3 grid_dims(H, n_seqs, (S_v + num_warps - 1) / num_warps); + dpct::dim3 block_dims(warp_size <= S_v ? warp_size : S_v, num_warps, 1); + + const sycl::uint3 neqk1_magic = init_fastdiv_values(neqk1); + const sycl::uint3 rq3_magic = init_fastdiv_values(rq3); + + switch (S_v) { + case 16: + { + constexpr int sv = 16; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl<sv, KDA, keep_rs_t>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, + sb3, neqk1_magic, rq3_magic, scale, K); + }); + } + break; + case 32: + { + constexpr int sv = 32; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl<sv, KDA, keep_rs_t>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, + n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, sb1, sb2, + sb3, neqk1_magic, rq3_magic, scale, K); + }); + } + break; + case 64: { + { + constexpr int sv = 64; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl<sv, KDA, keep_rs_t>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); + }); + } + break; + } + case 128: { + { + constexpr int sv = 128; + stream->parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + gated_delta_net_sycl<sv, KDA, keep_rs_t>( + q_d, k_d, v_d, g_d, b_d, s_d, dst_d, H, n_tokens, n_seqs, sq1, sq2, + sq3, sv1, sv2, sv3, sb1, sb2, sb3, neqk1_magic, rq3_magic, scale, K); + }); + } + break; + } + default: + GGML_ABORT("fatal error"); + break; + } +} + +void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + ggml_tensor * src_q = dst->src[0]; + ggml_tensor * src_k = dst->src[1]; + ggml_tensor * src_v = dst->src[2]; + ggml_tensor * src_g = dst->src[3]; + ggml_tensor * src_beta = dst->src[4]; + ggml_tensor * src_state = dst->src[5]; + + GGML_TENSOR_LOCALS(int64_t, neq, src_q, ne); + GGML_TENSOR_LOCALS(size_t , nbq, src_q, nb); + GGML_TENSOR_LOCALS(int64_t, nek, src_k, ne); + GGML_TENSOR_LOCALS(size_t , nbk, src_k, nb); + GGML_TENSOR_LOCALS(int64_t, nev, src_v, ne); + GGML_TENSOR_LOCALS(size_t, nbv, src_v, nb); + GGML_TENSOR_LOCALS(size_t, nbb, src_beta, nb); + + const int64_t S_v = nev0; + const int64_t H = nev1; + const int64_t n_tokens = nev2; + const int64_t n_seqs = nev3; + + const bool kda = (src_g->ne[0] == S_v); + + GGML_ASSERT(neq1 == nek1); + const int64_t neqk1 = neq1; + + const int64_t rq3 = nev3 / neq3; + + const float * q_d = (const float *) src_q->data; + const float * k_d = (const float *) src_k->data; + const float * v_d = (const float *) src_v->data; + const float * g_d = (const float *) src_g->data; + const float * b_d = (const float *) src_beta->data; + + const float * s_d = (const float *) src_state->data; + float * dst_d = (float *) dst->data; + + GGML_ASSERT(ggml_is_contiguous_rows(src_q)); + GGML_ASSERT(ggml_is_contiguous_rows(src_k)); + GGML_ASSERT(ggml_is_contiguous_rows(src_v)); + GGML_ASSERT(ggml_are_same_stride(src_q, src_k)); + GGML_ASSERT(src_g->ne[0] == 1 || kda); + GGML_ASSERT(ggml_is_contiguous(src_g)); + GGML_ASSERT(ggml_is_contiguous(src_beta)); + GGML_ASSERT(ggml_is_contiguous(src_state)); + + // strides in floats (beta strides used for both g and beta offset computation) + const int64_t sq1 = nbq1 / sizeof(float); + const int64_t sq2 = nbq2 / sizeof(float); + const int64_t sq3 = nbq3 / sizeof(float); + const int64_t sv1 = nbv1 / sizeof(float); + const int64_t sv2 = nbv2 / sizeof(float); + const int64_t sv3 = nbv3 / sizeof(float); + const int64_t sb1 = nbb1 / sizeof(float); + const int64_t sb2 = nbb2 / sizeof(float); + const int64_t sb3 = nbb3 / sizeof(float); + + const float scale = 1.0f / sqrtf((float) S_v); + + dpct::queue_ptr stream = ctx.stream(); + + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const int K = ggml_get_op_params_i32(dst, 0); + const bool keep_rs = K > 1; + + if (kda) { + if (keep_rs) { + launch_gated_delta_net<true, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net<true, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } + } else { + if (keep_rs) { + launch_gated_delta_net<false, true>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } else { + launch_gated_delta_net<false, false>(q_d, k_d, v_d, g_d, b_d, s_d, dst_d, + S_v, H, n_tokens, n_seqs, sq1, sq2, sq3, sv1, sv2, sv3, + sb1, sb2, sb3, neqk1, rq3, scale, K, stream); + } + } +} + +void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/6); + ggml_sycl_op_gated_delta_net(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/gated_delta_net.hpp b/ggml/src/ggml-sycl/gated_delta_net.hpp new file mode 100644 index 00000000000..350b4ce2f66 --- /dev/null +++ b/ggml/src/ggml-sycl/gated_delta_net.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include <sycl/sycl.hpp> +#include "dpct/helper.hpp" +#include "common.hpp" +#include "ggml.h" + +void ggml_sycl_op_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_gated_delta_net(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/gemm.hpp b/ggml/src/ggml-sycl/gemm.hpp index dcf6c7aeeb4..c202da110be 100644 --- a/ggml/src/ggml-sycl/gemm.hpp +++ b/ggml/src/ggml-sycl/gemm.hpp @@ -29,6 +29,9 @@ class DnnlGemmWrapper { static constexpr dt to_dt() { if constexpr (std::is_same_v<T, float>) return dt::f32; else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16; +#ifdef GGML_SYCL_HAS_BF16 + else if constexpr (std::is_same_v<T, sycl::ext::oneapi::bfloat16>) return dt::bf16; +#endif else static_assert(0); } diff --git a/ggml/src/ggml-sycl/getrows.cpp b/ggml/src/ggml-sycl/getrows.cpp index 03f8dd90748..298f247f84e 100644 --- a/ggml/src/ggml-sycl/getrows.cpp +++ b/ggml/src/ggml-sycl/getrows.cpp @@ -129,11 +129,11 @@ static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *sr GGML_UNUSED(ctx); } -template <typename src0_t> +template <typename src0_t, typename dst_t> static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst, const src0_t *src0_dd, const int32_t *src1_dd, - float *dst_dd, queue_ptr stream) { + dst_t *dst_dd, queue_ptr stream) { GGML_TENSOR_BINARY_OP_LOCALS @@ -170,7 +170,7 @@ static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tens void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_I32 ); GGML_ASSERT(dst->src[0]->nb[0] == ggml_type_size(dst->src[0]->type)); GGML_ASSERT(dst->src[1]->nb[0] == ggml_type_size(dst->src[1]->type)); @@ -183,10 +183,74 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::half *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_BF16: + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const sycl::ext::oneapi::bfloat16 *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_F32: get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_I32: + get_rows_sycl_float(ctx, dst->src[0], dst->src[1], dst, (const int32_t *)dst->src[0]->data, + src1_i32, (int32_t *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q1_0: + get_rows_sycl<QK1_0, 1, dequantize_q1_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_MXFP4: + get_rows_sycl<QK_MXFP4, 2, dequantize_mxfp4>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_NVFP4: + get_rows_sycl<QK_NVFP4, 1, dequantize_nvfp4>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_XXS: + get_rows_sycl<QK_K, 1, dequantize_iq2_xxs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_XS: + get_rows_sycl<QK_K, 1, dequantize_iq2_xs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ2_S: + get_rows_sycl<QK_K, 1, dequantize_iq2_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ3_XXS: + get_rows_sycl<QK_K, 1, dequantize_iq3_xxs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ1_S: + get_rows_sycl<QK_K, 1, dequantize_iq1_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ1_M: + get_rows_sycl<QK_K, 1, dequantize_iq1_m>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ3_S: + get_rows_sycl<QK_K, 1, dequantize_iq3_s>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ4_NL: + get_rows_sycl<QK4_NL, 1, dequantize_iq4_nl>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_IQ4_XS: + get_rows_sycl<QK_K, 1, dequantize_iq4_xs>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q2_K: + get_rows_sycl<QK_K, 1, dequantize_q2_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q3_K: + get_rows_sycl<QK_K, 1, dequantize_q3_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q4_0: get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); @@ -195,6 +259,10 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_Q4_K: + get_rows_sycl<QK_K, 1, dequantize_q4_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q5_0: get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); @@ -203,6 +271,14 @@ void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); break; + case GGML_TYPE_Q5_K: + get_rows_sycl<QK_K, 1, dequantize_q5_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; + case GGML_TYPE_Q6_K: + get_rows_sycl<QK_K, 1, dequantize_q6_K>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, + src1_i32, (float *)dst->data, ctx.stream()); + break; case GGML_TYPE_Q8_0: get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, dst->src[0], dst->src[1], dst, (const float *)dst->src[0]->data, src1_i32, (float *)dst->data, ctx.stream()); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 8f8176b678a..3f246e8672d 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -19,6 +19,7 @@ #include <cstdlib> #include <float.h> #include <limits> +#include <optional> #include <stdint.h> #include <stdio.h> #include <vector> @@ -30,11 +31,21 @@ #include <regex> #include <sycl/sycl.hpp> +#include <sycl/backend.hpp> +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +#include <level_zero/ze_api.h> +#endif #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC # include <sycl/ext/oneapi/experimental/async_alloc/async_alloc.hpp> #endif +#if SYCL_EXT_ONEAPI_VIRTUAL_MEM +# include <sycl/ext/oneapi/virtual_mem/physical_mem.hpp> +# include <sycl/ext/oneapi/virtual_mem/virtual_mem.hpp> +# define GGML_SYCL_USE_VMM +#endif #include <sycl/half_type.hpp> +#include "ggml.h" #include "ggml-sycl.h" #include "ggml-impl.h" #include "ggml-backend-impl.h" @@ -43,25 +54,35 @@ #include "ggml-sycl/backend.hpp" #include "ggml-sycl/common.hpp" #include "ggml-sycl/element_wise.hpp" +#include "ggml-sycl/gemm.hpp" +#include "ggml-sycl/getrows.hpp" #include "ggml-sycl/norm.hpp" #include "ggml-sycl/presets.hpp" -#include "ggml-sycl/gemm.hpp" +#include "ggml-sycl/quantize.hpp" +#include "ggml-sycl/repeat_back.hpp" #include "ggml-sycl/set_rows.hpp" #include "ggml-sycl/set.hpp" -#include "ggml-sycl/sycl_hw.hpp" -#include "ggml-sycl/getrows.hpp" -#include "ggml-sycl/repeat_back.hpp" -#include "ggml-sycl/quantize.hpp" #include "ggml-sycl/ssm_conv.hpp" -#include "ggml.h" +#include "ggml-sycl/sycl_hw.hpp" +#include "ggml-sycl/ssm_scan.hpp" +#include "ggml-sycl/fill.hpp" +#include "ggml-sycl/cumsum.hpp" +#include "ggml-sycl/diag.hpp" +#include "ggml-sycl/solve_tri.hpp" +#include "ggml-sycl/gated_delta_net.hpp" static bool g_sycl_loaded = false; int g_ggml_sycl_debug = 0; int g_ggml_sycl_disable_optimize = 0; int g_ggml_sycl_disable_graph = 0; int g_ggml_sycl_disable_dnn = 0; +int g_ggml_sycl_enable_vmm = 1; int g_ggml_sycl_prioritize_dmmv = 0; int g_ggml_sycl_use_async_mem_op = 0; +int g_ggml_sycl_use_async_mem_op_requested = 1; +int g_ggml_sycl_enable_level_zero = 0; +int g_ggml_sycl_enable_flash_attention = 1; + static ggml_sycl_device_info ggml_sycl_init() { ggml_sycl_device_info info = {}; @@ -82,23 +103,50 @@ static ggml_sycl_device_info ggml_sycl_init() { // GGML_LOG_INFO("%s: SYCL_USE_XMX: no\n", __func__); // #endif for (int i = 0; i < info.device_count; ++i) { - info.devices[i].vmm = 0; dpct::device_info prop; - sycl::device device = dpct::dev_mgr::instance().get_device(i); + auto & device = dpct::dev_mgr::instance().get_device(i); SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info( prop, device))); +#if !defined(GGML_SYCL_USE_VMM) + info.devices[i].vmm = 0; +#else + info.devices[i].vmm = device.has(sycl::aspect::ext_oneapi_virtual_mem); + if (info.devices[i].vmm) { + // NB: SYCL's get_mem_granularity always returns the _minimum_ granularity, + // but the L0 API requires a larger page size for allocs above 2 MiB and + // rejects non-multiples with UR_RESULT_ERROR_INVALID_VALUE [sic]. + // Here we clamp it to 2 MiB for simplicity, but other devices may require + // calling zeVirtualMemQueryPageSize or yet unexposed public API. + const size_t physical_page = 2ull << 20; // 2 MiB + info.devices[i].vmm_granularity = std::max<size_t>( + sycl::ext::oneapi::experimental::get_mem_granularity( + device, sycl::context(device)), + physical_page); + } +#endif + info.default_tensor_split[i] = total_vram; total_vram += prop.get_global_mem_size(); info.devices[i].cc = 100 * prop.get_major_version() + 10 * prop.get_minor_version(); - info.devices[i].nsm = prop.get_max_compute_units(); + info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu); info.devices[i].smpbo = prop.get_local_mem_size(); + info.devices[i].warp_size = WARP_SIZE; info.max_work_group_sizes[i] = prop.get_max_work_group_size(); + info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units(); + info.devices[i].hw_info = get_device_hw_info(&device); + + // Only check GPU devices; CPU devices use OpenCL and would otherwise + // disable Level Zero for the GPUs on systems without ONEAPI_DEVICE_SELECTOR set. + if (device.is_gpu() && device.default_queue().get_backend() != sycl::backend::ext_oneapi_level_zero) { + GGML_LOG_WARN("SYCL GPU device %d does not use Level Zero backend, disabling Level Zero memory API\n", i); + info.ext_oneapi_level_zero = false; + } } for (int id = 0; id < info.device_count; ++id) { @@ -210,8 +258,54 @@ static void ggml_check_sycl() try { g_ggml_sycl_disable_optimize = get_sycl_env("GGML_SYCL_DISABLE_OPT", 0); g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1); g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0); + g_ggml_sycl_enable_vmm = get_sycl_env("GGML_SYCL_ENABLE_VMM", 1); g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0); +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + g_ggml_sycl_enable_level_zero = get_sycl_env("GGML_SYCL_ENABLE_LEVEL_ZERO", ggml_sycl_info().ext_oneapi_level_zero); +#else + g_ggml_sycl_enable_level_zero = 0; +#endif + +#ifdef SYCL_FLASH_ATTN + g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1); +#else + g_ggml_sycl_enable_flash_attention = 0; +#endif + GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n"); + + GGML_LOG_INFO("Build with Macros:\n"); +#if defined(GGML_SYCL_FORCE_MMQ) + GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n"); +#endif +#if defined(GGML_SYCL_F16) + GGML_LOG_INFO(" GGML_SYCL_F16: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_F16: no\n"); +#endif +#if defined(GGML_SYCL_GRAPH) + GGML_LOG_INFO(" GGML_SYCL_GRAPH: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_GRAPH: no\n"); +#endif +#if defined(GGML_SYCL_DNNL) + GGML_LOG_INFO(" GGML_SYCL_DNNL: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n"); +#endif +#if defined(GGML_SYCL_SUPPORT_LEVEL_ZERO) + GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_SUPPORT_LEVEL_ZERO: no\n"); +#endif +#if defined(GGML_SYCL_USE_VMM) + GGML_LOG_INFO(" GGML_SYCL_USE_VMM: yes\n"); +#else + GGML_LOG_INFO(" GGML_SYCL_USE_VMM: no\n"); +#endif + GGML_LOG_INFO("Running with Environment Variables:\n"); GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug); GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize); @@ -220,22 +314,30 @@ static void ggml_check_sycl() try { #else GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n"); #endif +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: %d\n", g_ggml_sycl_enable_level_zero); +#else + GGML_LOG_INFO(" GGML_SYCL_ENABLE_LEVEL_ZERO: Level Zero disabled by compile flag\n"); +#endif #if GGML_SYCL_DNNL GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn); #else GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n"); #endif - GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); - GGML_LOG_INFO("Build with Macros:\n"); -#if defined(GGML_SYCL_FORCE_MMQ) - GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n"); +#if defined(GGML_SYCL_USE_VMM) + GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: %d\n", g_ggml_sycl_enable_vmm); #else - GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n"); + GGML_LOG_INFO(" GGML_SYCL_ENABLE_VMM: virtual memory extension is not available\n"); #endif -#if defined(GGML_SYCL_F16) - GGML_LOG_INFO(" GGML_SYCL_F16: yes\n"); + GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv); + g_ggml_sycl_use_async_mem_op_requested = get_sycl_env("GGML_SYCL_USE_ASYNC_MEM_OP", 1); + GGML_LOG_INFO(" GGML_SYCL_USE_ASYNC_MEM_OP: %d\n", g_ggml_sycl_use_async_mem_op_requested); + +#ifdef SYCL_FLASH_ATTN + GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention); #else - GGML_LOG_INFO(" GGML_SYCL_F16: no\n"); + GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\n", + g_ggml_sycl_enable_flash_attention); #endif /* NOT REMOVE, keep it for next optimize for XMX. @@ -245,11 +347,11 @@ static void ggml_check_sycl() try { fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__); #endif */ - // Currently, we only use async malloc / free when graphs are enabled as it is required for the calls to be - // properly recorded. As this SYCL extension matures it may be beneficial to enable as the default path and in - // other places. + // Async USM allocation/free is also useful outside the graph path: it avoids the host waits in the reorder + // staging path while preserving queue ordering semantics. Graph support still depends on the extension being + // available, but it no longer needs to control the non-graph fast path. #if defined(GGML_SYCL_GRAPH) && SYCL_EXT_ONEAPI_ASYNC_MEMORY_ALLOC - g_ggml_sycl_use_async_mem_op = !g_ggml_sycl_disable_graph; + g_ggml_sycl_use_async_mem_op = g_ggml_sycl_use_async_mem_op_requested || !g_ggml_sycl_disable_graph; if (g_ggml_sycl_use_async_mem_op) { for (unsigned int i = 0; i < dpct::dev_mgr::instance().device_count(); ++i) { if (!dpct::dev_mgr::instance().get_device(i).has(sycl::aspect::ext_oneapi_async_memory_alloc)) { @@ -333,7 +435,7 @@ struct ggml_backend_sycl_buffer_context { ~ggml_backend_sycl_buffer_context() { if (dev_ptr != nullptr) { ggml_sycl_set_device(device); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(dev_ptr, *stream))); } //release extra used by tensors @@ -379,11 +481,22 @@ ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer, assert(tensor->view_src->buffer->buft == buffer->buft); return GGML_STATUS_SUCCESS; } - if ((tensor->type == GGML_TYPE_Q4_0 || tensor->type == GGML_TYPE_Q4_K || tensor->type == GGML_TYPE_Q6_K) && - !g_ggml_sycl_disable_optimize) { - ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; - tensor->extra = extra; - ctx->tensor_extras.push_back(extra); //used to release it when destroy ctx. + + if (!g_ggml_sycl_disable_optimize) { + // set reorder extra buffer based on supported type + switch (tensor->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q6_K:{ + ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{}; + tensor->extra = extra; + ctx->tensor_extras.push_back(extra); + break; + } + default: + break; + } } if (ggml_is_quantized(tensor->type)) { @@ -455,8 +568,43 @@ catch (sycl::exception const &exc) { std::exit(1); } +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO +static bool ggml_sycl_is_l0_discrete_gpu(sycl::queue &q) { + if (!q.get_device().is_gpu() || q.get_backend() != sycl::backend::ext_oneapi_level_zero) { + return false; + } + + ze_device_handle_t ze_dev = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q.get_device()); + ze_device_properties_t props = {}; + props.stype = ZE_STRUCTURE_TYPE_DEVICE_PROPERTIES; + ze_result_t r = zeDeviceGetProperties(ze_dev, &props); + return r == ZE_RESULT_SUCCESS && !(props.flags & ZE_DEVICE_PROPERTY_FLAG_INTEGRATED); +} +#endif + static void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst, const void *ptr_src, size_t size) { +#ifdef GGML_SYCL_SUPPORT_LEVEL_ZERO + // Use Level Zero direct copy for dGPU-to-dGPU transfers. + const bool l0_copy_supported = + ggml_sycl_is_l0_discrete_gpu(q_dst) && ggml_sycl_is_l0_discrete_gpu(q_src); + if (g_ggml_sycl_enable_level_zero && l0_copy_supported) { + auto ze_ctx = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q_dst.get_context()); + auto ze_dev = sycl::get_native<sycl::backend::ext_oneapi_level_zero>(q_dst.get_device()); + ze_command_queue_desc_t cq_desc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC, nullptr, 0, 0, + 0, ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL}; + ze_command_list_handle_t cl; + ze_result_t r = zeCommandListCreateImmediate(ze_ctx, ze_dev, &cq_desc, &cl); + if (r == ZE_RESULT_SUCCESS) { + r = zeCommandListAppendMemoryCopy(cl, ptr_dst, ptr_src, size, nullptr, 0, nullptr); + zeCommandListDestroy(cl); + if (r == ZE_RESULT_SUCCESS) { + return; + } + } + } +#endif + // Host-staged copy char *host_buf = (char *)malloc(size); q_src.memcpy(host_buf, (const char *)ptr_src, size).wait(); q_dst.memcpy((char *)ptr_dst, host_buf, size).wait(); @@ -537,9 +685,15 @@ static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer, SYCL_CHECK( CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw())); - SYCL_CHECK(CHECK_TRY_ERROR((*stream) - .memset(ctx->dev_ptr, value, buffer->size) - .wait())); + constexpr size_t MAX_CHUNK = 2ULL << 30; // 2 GiB + for (size_t off = 0; off < buffer->size; off += MAX_CHUNK) { + size_t chunk = std::min(buffer->size - off, MAX_CHUNK); + SYCL_CHECK(CHECK_TRY_ERROR( + (*stream) + .memset(static_cast<char*>(ctx->dev_ptr) + off, value, chunk) + .wait() + )); + } } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ @@ -589,6 +743,8 @@ static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { /* .memset_tensor = */ ggml_backend_sycl_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, /* .clear = */ ggml_backend_sycl_buffer_clear, /* .reset = */ ggml_backend_sycl_buffer_reset, @@ -618,8 +774,7 @@ ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size = std::max(size, (size_t)1); // syclMalloc returns null for size 0 void * dev_ptr; - SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device( - size, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)ggml_sycl_malloc_device(size, *stream))); if (!dev_ptr) { GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device\n", __func__, size); return nullptr; @@ -634,7 +789,7 @@ catch (sycl::exception const &exc) { } static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; + return SYCL_BUFFER_ALIGNMENT; GGML_UNUSED(buft); } @@ -860,18 +1015,10 @@ ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer, size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); } - // FIXME: do not crash if SYCL Buffer alloc fails - // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first ggml_sycl_set_device(i); const queue_ptr stream = ctx->streams[i]; char * buf; - /* - DPCT1009:208: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string - was inserted. You need to rewrite this code. - */ - SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device( - size, *stream))); + SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)ggml_sycl_malloc_device(size, *stream))); if (!buf) { char err_buf[1024]; snprintf(err_buf, 1023, "%s: can't allocate %lu Bytes of memory on device\n", __func__, size); @@ -1035,6 +1182,8 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = { /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_sycl_split_buffer_clear, /* .reset = */ NULL, @@ -1063,7 +1212,7 @@ static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(gg } static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - return 128; + return SYCL_BUFFER_ALIGNMENT; GGML_UNUSED(buft); } @@ -1157,13 +1306,28 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ GGML_UNUSED(buft); } +inline void * aligned_malloc_host(size_t alignment, size_t size) { +#ifdef _WIN32 + return _aligned_malloc(size, alignment); +#else + return aligned_alloc(alignment, size); +#endif +} + +inline void free_aligned_mem_host(void * memblock) { +#ifdef _WIN32 + _aligned_free(memblock); +#else + free(memblock); +#endif +} + static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { - ggml_sycl_host_free(buffer->context); + free_aligned_mem_host((void *)buffer->context); } static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { - void * ptr = ggml_sycl_host_malloc(size); - + void * ptr = aligned_malloc_host(TENSOR_ALIGNMENT, size); if (ptr == nullptr) { // fallback to cpu buffer return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); @@ -1212,16 +1376,53 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) : device(device_), qptr(qptr_) {} ~ggml_sycl_pool_leg() { +#ifdef DEBUG_SYCL_POOL + int n_cached = 0; + size_t bytes_cached = 0; + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + if (buffer_pool[i].ptr != nullptr) { + ++n_cached; + bytes_cached += buffer_pool[i].size; + } + } + GGML_LOG_INFO("%s: %d buffers, cached = %.2f MiB\n", __func__, + n_cached, bytes_cached / 1024.0 / 1024.0); + const auto slots = format_slots_in_alloc_order(); + if (!slots.empty()) { + GGML_LOG_INFO("%s: slots MiB: %s\n", __func__, slots.c_str()); + } +#endif + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { ggml_sycl_buffer & b = buffer_pool[i]; if (b.ptr != nullptr) { - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(b.ptr, *qptr))); pool_size -= b.size; } } GGML_ASSERT(pool_size == 0); } +#ifdef DEBUG_SYCL_POOL + std::string format_slots_in_alloc_order() const { + std::string line; + char buf[32]; + bool first = true; + for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) { + if (buffer_pool[i].ptr == nullptr) { + continue; + } + if (!first) { + line += '/'; + } + first = false; + snprintf(buf, sizeof(buf), "%.2f", buffer_pool[i].size / 1024.0 / 1024.0); + line += buf; + } + return line; + } +#endif + void * alloc(size_t size, size_t * actual_size) override { #ifdef DEBUG_sycl_MALLOC int nnz = 0; @@ -1263,9 +1464,7 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { void * ptr; size_t look_ahead_size = (size_t) (1.05 * size); - SYCL_CHECK( - CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device( - look_ahead_size, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ptr = (void *)ggml_sycl_malloc_device(look_ahead_size, *qptr))); if (!ptr) { GGML_LOG_ERROR("%s: can't allocate %lu Bytes of memory on device/GPU\n", __func__, look_ahead_size); return nullptr; @@ -1293,11 +1492,126 @@ struct ggml_sycl_pool_leg : public ggml_sycl_pool { } } GGML_LOG_WARN("WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n"); - SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr))); + SYCL_CHECK(CHECK_TRY_ERROR(ggml_sycl_free_device(ptr, *qptr))); pool_size -= size; } }; +// pool with virtual memory management +#if defined(GGML_SYCL_USE_VMM) +struct ggml_sycl_pool_vmm : public ggml_sycl_pool { + static const size_t SYCL_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB + + int device; + sycl::context ctx; + sycl::device dev; + + uintptr_t pool_addr = 0; + size_t pool_used = 0; + size_t pool_size = 0; + size_t granularity; + + // physical_mem owns the commits (unlike cuMemMap) + struct mapping { + sycl::ext::oneapi::experimental::physical_mem phys; + void * map_ptr; + }; + std::vector<mapping> mappings; + + explicit ggml_sycl_pool_vmm(queue_ptr qptr_, int device_) : + device(device_), + ctx(qptr_->get_context()), + dev(qptr_->get_device()), + granularity(ggml_sycl_info().devices[device_].vmm_granularity) { + } + + ~ggml_sycl_pool_vmm() { + if (pool_addr == 0) { + return; + } + + // Per spec, unmap must (a) match the exact (ptr, size) of an earlier + // physical_mem::map() call and (b) precede destruction of the + // physical_mem objects (their dtors won't unmap). + for (auto & m : mappings) { + SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::unmap( + m.map_ptr, m.phys.size(), ctx))); + } + SYCL_CHECK(CHECK_TRY_ERROR(sycl::ext::oneapi::experimental::free_virtual_mem( + pool_addr, SYCL_POOL_VMM_MAX_SIZE, ctx))); + } + + void * alloc(size_t size, size_t * actual_size) override { + // round up the allocation size to the alignment to ensure that all allocations are aligned for all data types + size = GGML_PAD(size, SYCL_BUFFER_ALIGNMENT); + + size_t avail = pool_size - pool_used; + + if (size > avail) { + // round up to the next multiple of the granularity + size_t reserve_size = GGML_PAD(size - avail, granularity); + + GGML_ASSERT(pool_size + reserve_size <= SYCL_POOL_VMM_MAX_SIZE); + + // allocate more physical memory + std::optional<sycl::ext::oneapi::experimental::physical_mem> phys; + SYCL_CHECK(CHECK_TRY_ERROR(phys.emplace(dev, ctx, reserve_size))); + + // reserve virtual address space (if not already reserved) + if (pool_addr == 0) { + SYCL_CHECK(CHECK_TRY_ERROR( + pool_addr = sycl::ext::oneapi::experimental::reserve_virtual_mem( + SYCL_POOL_VMM_MAX_SIZE, ctx))); + } + + // map at the end of the pool + void * map_ptr = nullptr; + SYCL_CHECK(CHECK_TRY_ERROR( + map_ptr = phys->map(pool_addr + pool_size, reserve_size, + sycl::ext::oneapi::experimental::address_access_mode::read_write))); + + // stash these so we could unmap this exact range in dtor + mappings.push_back({ + std::move(*phys), + map_ptr, + }); + + // add to the pool + pool_size += reserve_size; + +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: size increased to %llu MB (reserved %llu MB)\n", + device, (unsigned long long) (pool_size/1024/1024), + (unsigned long long) (reserve_size/1024/1024)); +#endif + } + + GGML_ASSERT(pool_addr != 0); + + void * ptr = reinterpret_cast<void *>(pool_addr + pool_used); + *actual_size = size; + pool_used += size; + +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: allocated %llu bytes at %p\n", device, (unsigned long long) size, ptr); +#endif + + return ptr; + } + + void free(void * ptr, size_t size) override { +#ifdef DEBUG_SYCL_MALLOC + GGML_LOG_INFO("sycl pool[%d]: freed %llu bytes at %p\n", device, (unsigned long long) size, ptr); +#endif + + pool_used -= size; + + // all deallocations must be in reverse order of the allocations + GGML_ASSERT(ptr == reinterpret_cast<void *>(pool_addr + pool_used)); + } +}; +#endif // defined(GGML_SYCL_USE_VMM) + struct ggml_sycl_pool_host : public ggml_sycl_pool { queue_ptr qptr; int device; @@ -1378,15 +1692,18 @@ std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_host(que } std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) { - // TBD: NO VMM support - // if (ggml_sycl_info().devices[device].vmm) { - // return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device)); - // } - return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device)); +#if defined(GGML_SYCL_USE_VMM) + if (g_ggml_sycl_enable_vmm && ggml_sycl_info().devices[device].vmm) { + return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(qptr, device)); + } +#endif // defined(GGML_SYCL_USE_VMM) + return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device)); } -// TBD pool with virtual memory management -// struct ggml_sycl_pool_vmm : public ggml_sycl_pool + +std::unique_ptr<ggml_sycl_fattn_kv_buffers> ggml_backend_sycl_context::new_fattn_kv_buffers(queue_ptr qptr, int device) { + return std::unique_ptr<ggml_sycl_fattn_kv_buffers>(new ggml_sycl_fattn_kv_buffers(qptr, device)); +} /// kernels typedef void (*ggml_sycl_op_mul_mat_t)( @@ -1825,6 +2142,110 @@ static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols, } } +static void top_k_f32_sycl( + const float * src, + int32_t * dst_indices, + const int64_t ncols, + const int64_t nrows, + const int k, + dpct::queue_ptr main_stream +) { + const int block_size = 128; + + const sycl::range<1> block_dims(block_size); + const sycl::range<1> grid_dims(nrows); + + main_stream->submit([&](sycl::handler &cgh) { + sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(block_size * k), cgh); + sycl::local_accessor<int, 1> shared_idx(sycl::range<1>(block_size * k), cgh); + + cgh.parallel_for( + sycl::nd_range<1>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<1> item_ct1) { + const int row = item_ct1.get_group(0); + const int tid = item_ct1.get_local_id(0); + + if (row >= nrows) return; + + const float * src_row = src + row * ncols; + int32_t * dst_idx_row = dst_indices + row * k; + + float local_vals[32]; + int local_idx[32]; + + for (int i = 0; i < k; i++) { + local_vals[i] = -FLT_MAX; + local_idx[i] = -1; + } + + for (int col = tid; col < ncols; col += block_size) { + float val = src_row[col]; + + if (val > local_vals[k-1]) { + int pos = k - 1; + while (pos > 0 && val > local_vals[pos - 1]) { + pos--; + } + + for (int i = k - 1; i > pos; i--) { + local_vals[i] = local_vals[i - 1]; + local_idx[i] = local_idx[i - 1]; + } + local_vals[pos] = val; + local_idx[pos] = col; + } + } + + for (int i = 0; i < k; i++) { + shared_vals[tid * k + i] = local_vals[i]; + shared_idx[tid * k + i] = local_idx[i]; + } + item_ct1.barrier(sycl::access::fence_space::local_space); + + if (tid == 0) { + float final_vals[32]; + int final_idx[32]; + + for (int i = 0; i < k; i++) { + final_vals[i] = -FLT_MAX; + final_idx[i] = -1; + } + + for (int t = 0; t < block_size; t++) { + for (int i = 0; i < k; i++) { + float val = shared_vals[t * k + i]; + int idx = shared_idx[t * k + i]; + + if (val > final_vals[k-1]) { + int pos = k - 1; + while (pos > 0 && val > final_vals[pos - 1]) { + pos--; + } + + for (int j = k - 1; j > pos; j--) { + final_vals[j] = final_vals[j - 1]; + final_idx[j] = final_idx[j - 1]; + } + final_vals[pos] = val; + final_idx[pos] = idx; + } + } + } + + for (int i = 0; i < k; i++) { + dst_idx_row[i] = final_idx[i]; + } + + if (k > 1) { + int32_t temp = dst_idx_row[0]; + dst_idx_row[0] = dst_idx_row[1]; + dst_idx_row[1] = temp; + } + } + }); + }); +} + static void argmax_f32_i32_sycl(const float *x, int *dst, const int ncols, const int nrows, queue_ptr stream) { const sycl::range<3> block_dims(1, 1, SYCL_ARGMAX_BLOCK_SIZE); @@ -2004,6 +2425,31 @@ inline void ggml_sycl_op_mul_mat_sycl( #else bool use_fp16 = false; #endif + +#if GGML_SYCL_DNNL && defined(GGML_SYCL_HAS_BF16) + // Fast path for bf16 src0 + if (src0->type == GGML_TYPE_BF16 && !g_ggml_sycl_disable_dnn && ggml_is_contiguous(src0) && + row_diff == src0->ne[1]) { + using bf16_t = sycl::ext::oneapi::bfloat16; + ggml_sycl_pool_alloc<bf16_t> src1_as_bf16(ctx.pool(), src1_ncols*ne10); + if (src1->type != GGML_TYPE_BF16) { + const to_bf16_sycl_t to_bf16_sycl = ggml_get_to_bf16_sycl(src1->type, dst); + GGML_ASSERT(to_bf16_sycl != nullptr); + to_bf16_sycl(src1_ddf_i, src1_as_bf16.get(), src1_ncols*ne10, stream); + } else { + stream->memcpy(src1_as_bf16.get(), src1_ddf_i, src1_ncols*ne10*sizeof(bf16_t)); + } + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, + src0_dd_i, DnnlGemmWrapper::to_dt<bf16_t>(), + src1_as_bf16.get(), DnnlGemmWrapper::to_dt<bf16_t>(), + dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_padded_row_size); + return; + } +#endif + if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) { ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool()); @@ -2048,8 +2494,8 @@ inline void ggml_sycl_op_mul_mat_sycl( const sycl::half alpha_f16 = 1.0f; const sycl::half beta_f16 = 0.0f; SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm( - *stream, oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10, + *stream, oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00, src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16, dst_f16.get(), dpct::library_data_t::real_half, ldc, @@ -2081,21 +2527,25 @@ inline void ggml_sycl_op_mul_mat_sycl( const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get(); const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get(); + { + const int64_t gemm_flops = (int64_t)row_diff * src1_ncols * ne10; + const bool use_mkl_direct = gemm_flops < 256 * 256 * 256; #if GGML_SYCL_DNNL - if (!g_ggml_sycl_disable_dnn) { - DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, - DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(), - dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream); - } - else + if (!g_ggml_sycl_disable_dnn && !use_mkl_direct) { + DnnlGemmWrapper::row_gemm(ctx, row_diff, src1_ncols, ne10, src0_ddf_i, + DnnlGemmWrapper::to_dt<float>(), src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(), + dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream); + } + else #endif - { - const float alpha = 1.0f; - const float beta = 0.0f; - SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm( - get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff, - src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, - dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + { + const float alpha = 1.0f; + const float beta = 0.0f; + SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm( + *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, + src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, + dpct::get_value(&beta, *stream), dst_dd_i, ldc))); + } } } GGML_UNUSED(dst); @@ -2216,6 +2666,30 @@ inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, ggml_tensor * main_stream, ctx.device); } +static void ggml_sycl_op_top_k(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast<const float *>(src0->data); + int32_t * dst_dd = static_cast<int32_t *>(dst->data); + + const int k = dst->ne[0]; + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + GGML_ASSERT(k > 0 && k <= 32); + GGML_ASSERT(k <= ncols); + + top_k_f32_sycl(src0_dd, dst_dd, ncols, nrows, k, main_stream); +} + inline void ggml_sycl_op_argmax(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_I32); @@ -2248,6 +2722,65 @@ inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, ggml_ten diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream); } +static void tri_f32_sycl( + const float * src, + float * dst, + const int64_t ne0, + const int64_t ne1, + const int64_t ne2, + const int64_t ne3, + const ggml_tri_type ttype, + dpct::queue_ptr main_stream +) { + const size_t total = (size_t) ne0 * (size_t) ne1 * (size_t) ne2 * (size_t) ne3; + + main_stream->parallel_for(sycl::range<1>(total), [=](sycl::id<1> tid) { + const int64_t idx = (int64_t) tid[0]; + + const int64_t i0 = idx % ne0; + const int64_t t1 = idx / ne0; + const int64_t i1 = t1 % ne1; + + bool keep = false; + switch (ttype) { + case GGML_TRI_TYPE_LOWER: keep = (i0 < i1); break; + case GGML_TRI_TYPE_LOWER_DIAG: keep = (i0 <= i1); break; + case GGML_TRI_TYPE_UPPER: keep = (i0 > i1); break; + case GGML_TRI_TYPE_UPPER_DIAG: keep = (i0 >= i1); break; + default: keep = false; break; + } + + dst[idx] = keep ? src[idx] : 0.0f; + }); +} + +static void ggml_sycl_op_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + GGML_ASSERT(src0); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + dpct::queue_ptr main_stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const float * src0_dd = static_cast<const float *>(src0->data); + float * dst_dd = static_cast<float *>(dst->data); + + const ggml_tri_type ttype = (ggml_tri_type) ggml_get_op_params_i32(dst, 0); + + const int64_t ne0 = src0->ne[0]; + const int64_t ne1 = src0->ne[1]; + const int64_t ne2 = src0->ne[2]; + const int64_t ne3 = src0->ne[3]; + + tri_f32_sycl(src0_dd, dst_dd, ne0, ne1, ne2, ne3, ttype, main_stream); +} + + inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32); @@ -2810,7 +3343,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons } #if GGML_SYCL_DNNL - // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl + // oneDNN handles strided data and does not need overhead of ggml_get_to_fp16_nc_sycl const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1; src1_f16_alloc.alloc(ne_src1); const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst); @@ -2819,7 +3352,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons # else const int64_t ne_src1 = ggml_nelements(src1); src1_f16_alloc.alloc(ne_src1); - const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type); + const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type); GGML_ASSERT(to_fp16_nc_sycl != nullptr); to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue); #endif @@ -2963,8 +3496,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons const int64_t smb = ne12 == 1 ? s13 : s12; // there is no broadcast and src0, src1 are contiguous across dims 2, 3 - SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans, - oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans, + oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma, src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf, mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type))); @@ -2988,7 +3521,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons }); SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch( - *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha, + *queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha, (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00, (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta, (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get()))); @@ -3014,8 +3547,11 @@ inline bool ggml_sycl_supports_mmq(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: return true; + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: return !g_ggml_sycl_prioritize_dmmv; default: @@ -3026,6 +3562,7 @@ inline bool ggml_sycl_supports_reorder_mul_mat_sycl(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: return true; default: return false; @@ -3035,7 +3572,10 @@ inline bool ggml_sycl_supports_reorder_dmmv(enum ggml_type type) { inline bool ggml_sycl_supports_reorder_mmvq(enum ggml_type type) { switch (type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: return true; default: @@ -3056,6 +3596,7 @@ static bool ggml_sycl_supports_dmmv(enum ggml_type type) { case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: case GGML_TYPE_F16: + case GGML_TYPE_BF16: return true; default: return false; @@ -3073,7 +3614,7 @@ static inline void * sycl_ext_malloc_device(dpct::queue_ptr stream, size_t size) // If async allocation extension is not available, use_async should always be false. GGML_ASSERT(!use_async); #endif - return sycl::malloc(size, *stream, sycl::usm::alloc::device); + return ggml_sycl_malloc_device(size, *stream); } static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { @@ -3087,12 +3628,58 @@ static inline void sycl_ext_free(dpct::queue_ptr stream, void * ptr) { // If async allocation extension is not available, use_async should always be false. GGML_ASSERT(!use_async); #endif - sycl::free(ptr, *stream); + ggml_sycl_free_device(ptr, *stream); } -static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, +// RAII wrapper for temporary reorder buffers with optional host memory fallback. +// When device allocation fails and GGML_SYCL_HOST_MEM_FALLBACK is enabled, +// falls back to host memory so the reorder kernel can still run (over PCIe). +// Device access to host memory requires Linux kernel 6.8+ (Ubuntu 26.04+). +struct sycl_reorder_temp_buffer { + void * ptr = nullptr; + dpct::queue_ptr stream; + + sycl_reorder_temp_buffer(dpct::queue_ptr stream, size_t size) : stream(stream) { + ptr = sycl_ext_malloc_device(stream, size); +#ifdef GGML_SYCL_HOST_MEM_FALLBACK + if (!ptr) { + ptr = sycl::malloc_host(size, *stream); + if (ptr) { + host_fallback = true; + GGML_LOG_WARN("%s: device alloc of %zu bytes failed, using host memory fallback\n", __func__, size); + } + } +#endif + } + + ~sycl_reorder_temp_buffer() { + if (!ptr) { + return; + } + if (host_fallback) { + sycl::free(ptr, *stream); + } else { + sycl_ext_free(stream, ptr); + } + } + + explicit operator bool() const { return ptr != nullptr; } + + sycl_reorder_temp_buffer(const sycl_reorder_temp_buffer &) = delete; + sycl_reorder_temp_buffer & operator=(const sycl_reorder_temp_buffer &) = delete; + +private: + bool host_fallback = false; +}; + +static bool reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, dpct::queue_ptr stream) { - uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3121,16 +3708,60 @@ static void reorder_qw_q4_0(uint8_t * data_device, const int ncols, const int nr if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { +static bool reorder_qw_q8_0(uint8_t * data_device, const int ncols, const int nrows, size_t size, size_t offset, + dpct::queue_ptr stream) { + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + GGML_ASSERT((size % sizeof(block_q8_0) == 0)); + GGML_ASSERT((offset % sizeof(block_q8_0) == 0)); + int offset_blks = offset / sizeof(block_q8_0); + auto qs_ptr = data_device + offset_blks * QK8_0; + auto d_ptr = (sycl::half*)(qs_ptr + ncols * nrows) + offset_blks; + + auto reorder_event = stream->parallel_for( + size / sizeof(block_q8_0), + [=](auto i) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + const block_q8_0* x = (const block_q8_0*)tmp_buf; + const int ib = i; + + for (int j = 0; j < QK8_0; j++) + { + *((int8_t*)qs_ptr + ib * QK8_0 + j) = x[ib].qs[j]; + } + *(d_ptr + ib) = x[ib].d; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + return true; +} + +static bool reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q4_K) == 0); GGML_ASSERT(offset % sizeof(block_q4_K) == 0); const int nblocks = size / sizeof(block_q4_K); - uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3159,16 +3790,117 @@ static void reorder_qw_q4_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; +} + +static bool reorder_qw_q3_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q3_K) == 0); + GGML_ASSERT(offset % sizeof(block_q3_K) == 0); + + const int nblocks = size / sizeof(block_q3_K); + + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + auto * qs_ptr = data_device; + auto * hmask_ptr = qs_ptr + (QK_K / 4) * nblocks; + auto * scales_ptr = hmask_ptr + (QK_K / 8) * nblocks; + sycl::half * d_ptr = (sycl::half *) (scales_ptr + 12 * nblocks); + + auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { + const block_q3_K * x = (const block_q3_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 4; ++j) { + qs_ptr[ib * (QK_K / 4) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < QK_K / 8; ++j) { + hmask_ptr[ib * (QK_K / 8) + j] = x[ib].hmask[j]; + } + + for (int j = 0; j < 12; ++j) { + scales_ptr[ib * 12 + j] = x[ib].scales[j]; + } + + d_ptr[ib] = x[ib].d; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + return true; +} + +static bool reorder_qw_q5_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { + GGML_ASSERT(size % sizeof(block_q5_K) == 0); + GGML_ASSERT(offset % sizeof(block_q5_K) == 0); + + const int nblocks = size / sizeof(block_q5_K); + + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr); + + sycl::event copy_event; + SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); + if (!g_ggml_sycl_use_async_mem_op) { + copy_event.wait(); + } + + auto * qs_ptr = data_device; + auto * qh_ptr = qs_ptr + (QK_K / 2) * nblocks; + auto * scales_ptr = qh_ptr + (QK_K / 8) * nblocks; + auto * dm_ptr = (sycl::half2 *) (scales_ptr + K_SCALE_SIZE * nblocks); + + auto reorder_event = stream->parallel_for(nblocks, [=](auto i) { + const block_q5_K * x = (const block_q5_K *) tmp_buf; + const int ib = i; + + for (int j = 0; j < QK_K / 2; ++j) { + qs_ptr[ib * (QK_K / 2) + j] = x[ib].qs[j]; + } + + for (int j = 0; j < QK_K / 8; ++j) { + qh_ptr[ib * (QK_K / 8) + j] = x[ib].qh[j]; + } + + for (int j = 0; j < K_SCALE_SIZE; ++j) { + scales_ptr[ib * K_SCALE_SIZE + j] = x[ib].scales[j]; + } + + dm_ptr[ib] = x[ib].dm; + }); + if (!g_ggml_sycl_use_async_mem_op) { + reorder_event.wait_and_throw(); + } + return true; } -static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { +static bool reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, dpct::queue_ptr stream) { GGML_ASSERT(size % sizeof(block_q6_K) == 0); GGML_ASSERT(offset % sizeof(block_q6_K) == 0); const int nblocks = size / sizeof(block_q6_K); - uint8_t * tmp_buf = static_cast<uint8_t *>(sycl_ext_malloc_device(stream, size)); + sycl_reorder_temp_buffer tmp(stream, size); + if (!tmp) { + GGML_LOG_WARN("%s: failed to allocate %zu bytes for reorder temp buffer, skipping reorder\n", __func__, size); + return false; + } + uint8_t * tmp_buf = static_cast<uint8_t *>(tmp.ptr); sycl::event copy_event; SYCL_CHECK(CHECK_TRY_ERROR(copy_event = stream->memcpy(tmp_buf, data_device, size))); @@ -3207,10 +3939,10 @@ static void reorder_qw_q6_k(uint8_t * data_device, size_t size, size_t offset, d if (!g_ggml_sycl_use_async_mem_op) { reorder_event.wait_and_throw(); } - sycl_ext_free(stream, tmp_buf); + return true; } -static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { +static bool reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { uint8_t * data_device = (uint8_t *) src0->data; size_t ncols = src0->ne[0]; size_t nrows = src0->ne[1]; @@ -3218,17 +3950,20 @@ static void reorder_qw(const ggml_tensor * src0, dpct::queue_ptr stream) { switch (src0->type) { case GGML_TYPE_Q4_0: - reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); - break; + return reorder_qw_q4_0(data_device, ncols, nrows, size, 0, stream); + case GGML_TYPE_Q8_0: + return reorder_qw_q8_0(data_device, ncols, nrows, size, 0, stream); + case GGML_TYPE_Q3_K: + return reorder_qw_q3_k(data_device, size, 0, stream); case GGML_TYPE_Q4_K: - reorder_qw_q4_k(data_device, size, 0, stream); - break; + return reorder_qw_q4_k(data_device, size, 0, stream); + case GGML_TYPE_Q5_K: + return reorder_qw_q5_k(data_device, size, 0, stream); case GGML_TYPE_Q6_K: - reorder_qw_q6_k(data_device, size, 0, stream); - break; + return reorder_qw_q6_k(data_device, size, 0, stream); default: GGML_ABORT("reorder_qw() called with unsupported type"); - break; + return false; } } @@ -3236,7 +3971,9 @@ static bool should_reorder_tensor(ggml_backend_sycl_context& ctx, const ggml_ten return !g_ggml_sycl_disable_optimize && //allow optimize, controlled by $GGML_SYCL_DISABLE_OPT ctx.opt_feature.reorder && //allow this device due to good perf, skip the devices with bad perf. dst->op == GGML_OP_MUL_MAT && //limit to some supported cases of Q4_0, to do for more cases. - dst->src[1]->ne[1]==1 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; + // ne[1] <= 8 so multi-column decode (spec / MTP verify) also bootstraps the reorder; + // all reorderable types have a _switch_ncols kernel. + dst->src[1]->ne[1] <= 8 && dst->src[1]->ne[2]==1 && dst->src[1]->ne[3]==1; } static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * src0, const ggml_tensor * /* src1 */, @@ -3268,14 +4005,20 @@ static void opt_for_reorder(ggml_backend_sycl_context * ctx, const ggml_tensor * break; } - reorder_qw(src0, ctx->stream()); - extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering + if (reorder_qw(src0, ctx->stream())) { + extra->optimized_feature.reorder = true; // Used to decode/dequan in next steps and avoid re-reordering + } } static bool can_use_dequantize_mul_mat_vec(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + // The F16/BF16 qk=1 kernel iterates with stride 2*DMMV_X, requiring ne[0] to be + // a multiple of 2*DMMV_X. Quantized types use block-structured kernels that only + // need ne[0] % DMMV_X == 0. + const int64_t dmmv_x_required = (src0->type == GGML_TYPE_BF16 || src0->type == GGML_TYPE_F16) ? + 2*GGML_SYCL_DMMV_X : GGML_SYCL_DMMV_X; return ggml_sycl_supports_dmmv(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && - src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1; + src0->ne[0] % dmmv_x_required == 0 && src1->ne[1] == 1; } static bool can_use_mul_mat_vec_q(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -3316,19 +4059,25 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor // mmvq and mmq need the __dp4a instruction which is available for gen12+ - // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e + // Workaround in https://github.com/ggml-org/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS); #ifdef SYCL_USE_XMX use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE); #endif // SYCL_USE_XMX - // mmvq path is faster in the CUDA backend. - if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda - // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization - // is enabled takes precedence over DMMV, the current if-else implementation - // requires disabling DMMV if both conditions are met - || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) { - use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization + // is enabled takes precedence over DMMV, the current if-else implementation + // requires disabling DMMV if both conditions are met + + if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) && + ggml_sycl_supports_reorder_mmvq(src0->type)))) { + // Arc770 get benefit with Q4_0 by skipping it. + if (!(ggml_sycl_info().devices[ctx.device].hw_info.arch == + gpu_arch::intel_gpu_acm_g10 && + src0->type == GGML_TYPE_Q4_0)) { + use_dequantize_mul_mat_vec = + use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; + } } if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { @@ -3373,35 +4122,17 @@ struct mmid_row_mapping { __dpct_inline__ static void k_copy_src1_to_contiguous( const char *__restrict__ src1_original, char *__restrict__ src1_contiguous, - int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping, - const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0, + const mmid_row_mapping *__restrict__ row_mapping, int64_t ne11, int64_t ne10, size_t nb11, size_t nb12, - const sycl::nd_item<3> &item_ct1, int &src1_row) { - int32_t iid1 = item_ct1.get_group(2); - int32_t id = item_ct1.get_group(1); - - const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0); + const sycl::nd_item<3> &item_ct1) { + const int32_t src1_row = item_ct1.get_group(2); - if (row_id_i != i02) { - return; - } + const int32_t iid1 = row_mapping[src1_row].i2; + const int32_t id = row_mapping[src1_row].i1; const int64_t i11 = id % ne11; const int64_t i12 = iid1; - if (item_ct1.get_local_id(2) == 0) { - src1_row = - dpct::atomic_fetch_add<sycl::access::address_space::generic_space>( - cur_src1_row, 1); - row_mapping[src1_row] = {id, iid1}; - } - /* - DPCT1065:194: Consider replacing sycl::nd_item::barrier() with - sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better - performance if there is no access to global memory. - */ - item_ct1.barrier(); - const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12); float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11); @@ -3431,6 +4162,92 @@ __dpct_inline__ static void k_copy_dst_from_contiguous( } } +// Fused MoE TG fast path. Returns false to fall back to the per-expert loop below. +static bool ggml_sycl_mul_mat_id_mmvq_fused( + ggml_backend_sycl_context & ctx, const ggml_tensor * src0, + const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) +{ + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + if (ne12 != 1) return false; + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) return false; + if (ne10 != src0->ne[0] || ne10 % QK8_1 != 0) return false; + if (!ggml_is_contiguous(src1)) return false; + + // Reorder layout not supported; fall back. + const ggml_tensor_extra_gpu * src0_extra = + static_cast<const ggml_tensor_extra_gpu *>(src0->extra); + if (src0_extra && src0_extra->optimized_feature.reorder) return false; + + const int64_t n_ids_per_group = ids->ne[0]; + if (ids->ne[1] != 1) return false; + if (ne11 != 1 && ne11 != n_ids_per_group) return false; + + const queue_ptr stream = ctx.stream(); + const int src1_padded_cols = GGML_PAD((int) ne10, MATRIX_ROW_PADDING); + const int n_experts_used = (int) n_ids_per_group; + const int nrows = (int) src0->ne[1]; + + ggml_sycl_pool_alloc<char> src1_q8_alloc(ctx.pool(), + (size_t) ne11 * src1_padded_cols * sizeof(block_q8_1) / QK8_1); + char * src1_ddq = src1_q8_alloc.get(); + quantize_row_q8_1_sycl<quantize_q8_1>( + (const float *) src1->data, src1_ddq, (int) ne10, (int) ne11, + src1_padded_cols, stream); + + const size_t bytes_per_qrow = (size_t) src1_padded_cols * sizeof(block_q8_1) / QK8_1; + const size_t src1_row_stride = (ne11 == 1) ? 0 : bytes_per_qrow; + + return ggml_sycl_mul_mat_vec_q_id( + src0->type, src0->data, src1_ddq, (const int32_t *) ids->data, + (float *) dst->data, (int) ne10, nrows, n_experts_used, + /*expert_weight_stride=*/ src0->nb[2], + /*dst_row_stride=*/ dst->nb[1], + src1_row_stride, stream); +} + +// counting sort of the routed rows by expert id (row_id_i, as chosen by the router): +// builds a projection of a memory layout where each expert's slice is contiguous +static void mmid_counting_sort_rows( + const ggml_tensor * ids, const char * ids_host, + int64_t n_ids, int64_t n_as, int64_t n_routed_rows, + std::vector<int64_t> & expert_counts, + std::vector<int64_t> & expert_row_offsets, + std::vector<mmid_row_mapping> & routed_row_src) { + + // frequencies: how many routed rows each expert "owns" + expert_counts.assign(n_as, 0); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + expert_counts[row_id_i]++; + } + } + + // where each expert's slice starts (row indices) and the previous ends + expert_row_offsets.assign(n_as + 1, 0); + for (int64_t i02 = 0; i02 < n_as; i02++) { + expert_row_offsets[i02 + 1] = expert_row_offsets[i02] + expert_counts[i02]; + } + + std::vector<int64_t> expert_row_next = expert_row_offsets; + routed_row_src.resize(n_routed_rows); + for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { + for (int64_t id = 0; id < n_ids; id++) { + const int32_t row_id_i = *(const int32_t *) (ids_host + iid1*ids->nb[1] + id*ids->nb[0]); + GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + + // find and validate the next free row for a given expert (row_id_i) + const int64_t routed_row = expert_row_next[row_id_i]++; + GGML_ASSERT(routed_row >= expert_row_offsets[row_id_i]); + GGML_ASSERT(routed_row < expert_row_offsets[row_id_i + 1]); + routed_row_src[routed_row] = {(int32_t) id, (int32_t) iid1}; + } + } +} + static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, ggml_tensor *dst) try { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); @@ -3446,6 +4263,12 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const int64_t n_as = ne02; const int64_t n_ids = ids->ne[0]; + if (ne12 == 1) { + if (ggml_sycl_mul_mat_id_mmvq_fused(ctx, src0, src1, ids, dst)) { + return; + } + } + std::vector<char> ids_host(ggml_nbytes(ids)); const char * ids_dev = (const char *) ids->data; @@ -3496,105 +4319,98 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, } } } else { - ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1)); - ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst)); + const int64_t n_routed_rows = ids->ne[1] * n_ids; + ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne10); + ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*n_routed_rows*ne0); src1_row.data = src1_contiguous.get(); dst_row.data = dst_contiguous.get(); - for (int64_t i02 = 0; i02 < n_as; i02++) { - int64_t num_src1_rows = 0; - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]); + // how many "owned" routed rows to pass to each expert + std::vector<int64_t> expert_row_counts; + // where each expert's slice starts and the previous ends (row indices, right-exclusive) + std::vector<int64_t> expert_row_offsets; + // the sources (slot/token pairs) of contiguous rows to guide k_copy_src1_to_contiguous + std::vector<mmid_row_mapping> routed_row_src; - GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as); + mmid_counting_sort_rows(ids, ids_host.data(), n_ids, n_as, n_routed_rows, + expert_row_counts, expert_row_offsets, routed_row_src); - if (row_id_i != i02) { - continue; - } + ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), n_routed_rows); + SYCL_CHECK(CHECK_TRY_ERROR( + stream->memcpy(dev_row_mapping.get(), routed_row_src.data(), n_routed_rows*sizeof(mmid_row_mapping)))); - num_src1_rows++; - } - } + const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; + assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); + sycl::range<3> grid_dims(1, 1, n_routed_rows); + stream->submit([&](sycl::handler &cgh) { + char *__restrict src1_contiguous_get = + src1_contiguous.get(); + mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_src1_to_contiguous( + src1_original, src1_contiguous_get, + dev_row_mapping_get, + ne11, ne10, nb11, nb12, + item_ct1); + }); + }); + } + + for (int64_t i02 = 0; i02 < n_as; i02++) { + const int64_t num_src1_rows = expert_row_counts[i02]; if (num_src1_rows == 0) { continue; } - - ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1); - ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows); - SYCL_CHECK(CHECK_TRY_ERROR( - stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); - - const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; - assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); - - { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); - sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); - stream->submit([&](sycl::handler &cgh) { - sycl::local_accessor<int, 0> src1_row_acc(cgh); - - char *__restrict src1_contiguous_get = - src1_contiguous.get(); - int *__restrict dev_cur_src1_row_get = - dev_cur_src1_row.get(); - mmid_row_mapping *__restrict dev_row_mapping_get = - dev_row_mapping.get(); - size_t ids_nb_ct6 = ids->nb[1]; - size_t ids_nb_ct7 = ids->nb[0]; - - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_copy_src1_to_contiguous( - src1_original, src1_contiguous_get, - dev_cur_src1_row_get, - dev_row_mapping_get, ids_dev, i02, - ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12, - item_ct1, src1_row_acc); - }); - }); - } + const int64_t expert_row_offset = expert_row_offsets[i02]; src0_row.data = src0_original + i02*nb02; GGML_ASSERT(nb11 == sizeof(float)*ne10); GGML_ASSERT(nb1 == sizeof(float)*ne0); + src1_row.data = src1_contiguous.get() + expert_row_offset*nb11; src1_row.ne[1] = num_src1_rows; src1_row.nb[1] = nb11; src1_row.nb[2] = num_src1_rows*nb11; src1_row.nb[3] = num_src1_rows*nb11; + dst_row.data = dst_contiguous.get() + expert_row_offset*nb1; dst_row.ne[1] = num_src1_rows; dst_row.nb[1] = nb1; dst_row.nb[2] = num_src1_rows*nb1; dst_row.nb[3] = num_src1_rows*nb1; ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); + } - { - sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); - sycl::range<3> grid_dims(1, 1, num_src1_rows); - stream->submit([&](sycl::handler &cgh) { - const char *__restrict dst_contiguous_get = - dst_contiguous.get(); - const mmid_row_mapping *__restrict dev_row_mapping_get = - dev_row_mapping.get(); - - cgh.parallel_for( - sycl::nd_range<3>(grid_dims * block_dims, block_dims), - [=](sycl::nd_item<3> item_ct1) { - k_copy_dst_from_contiguous(dst_original, - dst_contiguous_get, - dev_row_mapping_get, - ne0, nb1, nb2, item_ct1); - }); - }); - } + { + sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size)); + sycl::range<3> grid_dims(1, 1, n_routed_rows); + stream->submit([&](sycl::handler &cgh) { + const char *__restrict dst_contiguous_get = + dst_contiguous.get(); + const mmid_row_mapping *__restrict dev_row_mapping_get = + dev_row_mapping.get(); + + cgh.parallel_for( + sycl::nd_range<3>(grid_dims * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + k_copy_dst_from_contiguous(dst_original, + dst_contiguous_get, + dev_row_mapping_get, + ne0, nb1, nb2, item_ct1); + }); + }); } } } @@ -3624,6 +4440,11 @@ static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) ggml_sycl_op_im2col(ctx, dst); } +static void ggml_sycl_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_im2col_3d(ctx, dst); +} + static void ggml_sycl_sum(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); GGML_ASSERT(ggml_is_contiguous(dst->src[0])); @@ -3771,6 +4592,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_UNARY_OP_EXP: ggml_sycl_exp(ctx, dst); break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_sycl_softplus(ctx, dst); + break; case GGML_UNARY_OP_SGN: ggml_sycl_sgn(ctx, dst); break; @@ -3897,6 +4721,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_TRANSPOSE: GGML_SYCL_DEBUG("%s: Tensor NO-OP\n", __func__); break; + case GGML_OP_TRI: + ggml_sycl_op_tri(ctx, dst); + break; case GGML_OP_DIAG_MASK_INF: ggml_sycl_diag_mask_inf(ctx, dst); break; @@ -3909,9 +4736,15 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_ROPE: ggml_sycl_rope(ctx, dst); break; + case GGML_OP_ROPE_BACK: + ggml_sycl_rope_back(ctx, dst); + break; case GGML_OP_IM2COL: ggml_sycl_im2col(ctx, dst); break; + case GGML_OP_IM2COL_3D: + ggml_sycl_im2col_3d(ctx, dst); + break; case GGML_OP_POOL_2D: ggml_sycl_pool2d(ctx, dst); break; @@ -3927,6 +4760,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_sycl_argsort(ctx, dst); break; + case GGML_OP_TOP_K: + ggml_sycl_op_top_k(ctx, dst); + break; case GGML_OP_TIMESTEP_EMBEDDING: ggml_sycl_op_timestep_embedding(ctx, dst); break; @@ -3939,15 +4775,36 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg case GGML_OP_GATED_LINEAR_ATTN: ggml_sycl_op_gated_linear_attn(ctx, dst); break; + case GGML_OP_GATED_DELTA_NET: + ggml_sycl_gated_delta_net(ctx, dst); + break; case GGML_OP_SSM_CONV: ggml_sycl_ssm_conv(ctx, dst); break; + case GGML_OP_SSM_SCAN: + ggml_sycl_ssm_scan(ctx, dst); + break; + case GGML_OP_FILL: + ggml_sycl_fill(ctx, dst); + break; + case GGML_OP_CUMSUM: + ggml_sycl_cumsum(ctx, dst); + break; + case GGML_OP_DIAG: + ggml_sycl_diag(ctx, dst); + break; + case GGML_OP_SOLVE_TRI: + ggml_sycl_solve_tri(ctx, dst); + break; case GGML_OP_ROLL: ggml_sycl_roll(ctx, dst); break; case GGML_OP_ARANGE: ggml_sycl_arange(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_sycl_flash_attn_ext(ctx, dst); + break; default: return false; } @@ -3978,16 +4835,6 @@ void ggml_backend_sycl_get_device_memory(int device, size_t *free, GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n"); ggml_sycl_set_device(device); - /* - DPCT1009:218: SYCL uses exceptions to report errors and does not use the - error codes. The original code was commented out and a warning string was - inserted. You need to rewrite this code. - */ - /* - DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for - device information which may not be supported by all compilers or runtimes. - You may need to adjust the code. - */ SYCL_CHECK(CHECK_TRY_ERROR( dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total))); } @@ -4109,6 +4956,9 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device)); for (int j = 0; j < GGML_MAX_SRC; j++) { @@ -4252,6 +5102,8 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .free = */ ggml_backend_sycl_free, /* .set_tensor_async = */ ggml_backend_sycl_set_tensor_async, /* .get_tensor_async = */ ggml_backend_sycl_get_tensor_async, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, // ggml_backend_sycl_cpy_tensor_async, // // TODO: update for the new // interface @@ -4386,10 +5238,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_ELU: + case GGML_UNARY_OP_CEIL: return true; case GGML_UNARY_OP_FLOOR: - case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: #if defined (GGML_SYCL_F16) @@ -4419,26 +5272,19 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g struct ggml_tensor * a = op->src[0]; struct ggml_tensor * b = op->src[1]; - if (a->ne[3] != b->ne[3]) { + // disable Q1_0 until implementation + if (a->type == GGML_TYPE_Q1_0 || b->type == GGML_TYPE_Q1_0) { return false; } - ggml_type a_type = a->type; - if (a_type == GGML_TYPE_IQ4_NL || a_type == GGML_TYPE_IQ4_XS || - a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ3_S || - a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ2_S || - a_type == GGML_TYPE_IQ1_S || a_type == GGML_TYPE_IQ1_M - ) { - if (b->ne[1] == 1 && ggml_nrows(b) > 1) { - return false; - } - } - ggml_type src0_type = op->src[0]->type; - if (src0_type == GGML_TYPE_BF16 ) { - // TODO: support GGML_TYPE_BF16 - // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added + + if (a->ne[3] != b->ne[3]) { return false; } + ggml_type src0_type = op->src[0]->type; + + + // TODO: The configuration below needs more work to be supported with oneDNN if (ggml_is_permuted(a) && !ggml_is_contiguous(a) && a->ne[2] > 1 && a->ne[3] > 1 && src0_type == GGML_TYPE_F16) { @@ -4457,12 +5303,31 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { + case GGML_TYPE_I32: case GGML_TYPE_F16: + case GGML_TYPE_BF16: case GGML_TYPE_F32: + case GGML_TYPE_Q1_0: + case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: case GGML_TYPE_Q8_0: return true; default: @@ -4588,18 +5453,23 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return (op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32) && (op->type == op->src[0]->type); #endif case GGML_OP_NORM: - return true; case GGML_OP_L2_NORM: case GGML_OP_GROUP_NORM: - return ggml_is_contiguous(op->src[0]); case GGML_OP_RMS_NORM: - return ((op->src[0]->ne[0] % WARP_SIZE) == 0); + return true; case GGML_OP_RMS_NORM_BACK: - return ((op->src[0]->ne[0] % WARP_SIZE) == 0); + return ggml_is_contiguous(op->src[0]); case GGML_OP_SCALE: return true; case GGML_OP_CONT: return op->src[0]->type != GGML_TYPE_BF16; + case GGML_OP_TRI: + { + const ggml_tensor * src0 = op->src[0]; + return src0 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0); + } case GGML_OP_DIAG_MASK_INF: return true; case GGML_OP_SOFT_MAX: @@ -4610,10 +5480,11 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return max_bias == 0.0f; } case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: case GGML_OP_IM2COL: - return true; + case GGML_OP_IM2COL_3D: case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); + return true; case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: @@ -4621,20 +5492,30 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_ARGSORT: return op->src[0]->ne[0] * sizeof(int) <= ggml_sycl_info().devices[device].smpbo; + case GGML_OP_TOP_K: { + const ggml_tensor * src0 = op->src[0]; + const int k = op->ne[0]; + return src0 && + op->type == GGML_TYPE_I32 && + src0->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0) && + k > 0 && k <= 32; + } case GGML_OP_POOL_2D: - case GGML_OP_ACC: return true; + case GGML_OP_ACC: + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); case GGML_OP_PAD: - // TODO: add circular padding support for syscl, see https://github.com/ggml-org/llama.cpp/pull/16985 if (ggml_get_op_params_i32(op, 8) != 0) { return false; } - return ggml_is_contiguous(op->src[0]); + return true; case GGML_OP_LEAKY_RELU: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: case GGML_OP_GATED_LINEAR_ATTN: + case GGML_OP_GATED_DELTA_NET: return true; case GGML_OP_SSM_CONV: return op->type == GGML_TYPE_F32 && @@ -4644,6 +5525,23 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g return op->type == GGML_TYPE_F32; case GGML_OP_ARANGE: return op->type == GGML_TYPE_F32; + case GGML_OP_SSM_SCAN: + if (op->src[3]->ne[0] == 1) { + // Mamba2 + // (kernel only supports (d_state == 128 || d_state == 256) && d_head % WARP_SIZE == 0) + return (op->src[0]->ne[0] == 128 || op->src[0]->ne[0] == 256) && op->src[0]->ne[1] % WARP_SIZE == 0; + } else { + // TODO Mamba-1 not yet ported to SYCL + return false; + } + case GGML_OP_FILL: + case GGML_OP_CUMSUM: + case GGML_OP_DIAG: + return true; + case GGML_OP_SOLVE_TRI: + return op->src[0]->ne[0] <= SYCL_SOLVE_TRI_MAX_N && op->src[1]->ne[0] <= SYCL_SOLVE_TRI_MAX_K; + case GGML_OP_FLASH_ATTN_EXT: + return ggml_sycl_flash_attn_ext_supported(device, op); default: return false; } diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 6d75d34d83f..7bf3584fb97 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2026 Intel Corporation // SPDX-License-Identifier: MIT // @@ -12,125 +12,389 @@ #include "im2col.hpp" -#include <sycl/sycl.hpp> -#include <type_traits> // For std::is_same_v - -#include "ggml.h" +#define MAX_GRIDDIM_Z 65535 template <typename T> -static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_t offset_delta, int64_t IC, int64_t IW, - int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW, - int s0, int s1, int p0, int p1, int d0, int d1, const sycl::nd_item<3> & item_ct1) { - const int64_t work_group_size = item_ct1.get_local_range(2); - const int64_t global_id = item_ct1.get_local_id(2) + (work_group_size * item_ct1.get_group(2)); - - // make each work-item deal with more elements since sycl global range can not exceed max int - for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) { - const int64_t ksize = OW * KH; - const int64_t kx = i / ksize; - const int64_t kd = kx * ksize; - const int64_t ky = (i - kd) / OW; - const int64_t ix = i % OW; - - const int64_t oh = item_ct1.get_group(1); - const int64_t batch = item_ct1.get_group(0) / IC; - const int64_t ic = item_ct1.get_group(0) % IC; - - const int64_t iiw = (ix * s0) + (kx * d0) - p0; - const int64_t iih = (oh * s1) + (ky * d1) - p1; - - const int64_t offset_dst = (((batch * OH + oh) * OW + ix) * CHW) + (ic * (KW * KH) + ky * KW + kx); - - const int64_t offset_src_base = (ic * offset_delta) + (batch * batch_offset); - const int64_t offset_src = offset_src_base + (iih * IW) + iiw; - - const bool out_of_bounds = (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW); - const float src_val = out_of_bounds ? 0.0f : x[offset_src]; - - if constexpr (std::is_same_v<T, sycl::half>) { - dst[offset_dst] = sycl::half(src_val); - } else if constexpr (std::is_same_v<T, float>) { - dst[offset_dst] = src_val; - } +static void im2col_kernel( + const float * x, T * dst, + int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, + int64_t IC_IH_IW, int64_t IH_IW, int64_t N_OH, int64_t KH_KW, int64_t IC_KH_KW, + int s0, int s1, int p0, int p1, int d0, int d1) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (i >= IC_KH_KW) { + return; } -} -template <typename T> -static void im2col_sycl_internal(const float * x, T * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, - int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, - int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { - const int64_t parallel_elements = OW * KW * KH; - const int64_t num_blocks = (parallel_elements + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + const int64_t iic = i / (KH_KW); + const int64_t rem = i - iic * KH_KW; + const int64_t ikh = rem / KW; + const int64_t ikw = rem - ikh * KW; - // decrease global range when it exceeds the max int - int64_t local_size = downsample_sycl_global_range(batch * IC * OH * num_blocks, SYCL_IM2COL_BLOCK_SIZE); + const int64_t iow = item_ct1.get_group(1); + for (int64_t iz = item_ct1.get_group(0); iz < N_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OH; + const int64_t ioh = iz - in * OH; - sycl::range<3> block_nums(batch * IC, OH, num_blocks); - sycl::range<3> local_range(1, 1, local_size); + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; - const int64_t CHW = IC * KH * KW; + const int64_t offset_dst = + ((in * OH + ioh) * OW + iow) * IC_KH_KW + iic * KH_KW + ikh * KW + ikw; - stream->parallel_for(sycl::nd_range<3>(block_nums * local_range, local_range), [=](sycl::nd_item<3> item_ct1) { - im2col_kernel<T>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, CHW, s0, s1, - p0, p1, d0, d1, item_ct1); - }); + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = iic * IC_IH_IW + in * IH_IW; + dst[offset_dst] = x[offset_src + iih * IW + iiw]; + } + } + + GGML_UNUSED(IC); + GGML_UNUSED(KH); } -static void im2col_sycl_f16(const float * x, sycl::half * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, - int64_t KW, int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, - int64_t offset_delta, int s0, int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { - if (!stream->get_device().has(sycl::aspect::fp16)) { - throw sycl::exception(sycl::make_error_code(sycl::errc::kernel_not_supported), - "Device does not support half precision (fp16) operations!"); - } - im2col_sycl_internal<sycl::half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, - p1, d0, d1, stream); +// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW] +template <typename T> +static void im2col_sycl(const float * x, + T * dst, + int64_t IW, + int64_t IH, + int64_t OW, + int64_t OH, + int64_t KW, + int64_t KH, + int64_t IC, + int64_t N, + int64_t IC_IH_IW, + int64_t IH_IW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + dpct::queue_ptr stream) { + const int64_t IC_KH_KW = IC * KH * KW; + const int64_t num_blocks = (IC_KH_KW + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + const int64_t N_OH = N * OH; + const int64_t KH_KW = KW*KH; + dpct::dim3 block_nums(num_blocks, OW, MIN(N_OH, MAX_GRIDDIM_Z)); + /* + DPCT1049:73: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->parallel_for(sycl::nd_range<3>(block_nums * sycl::range<3>(1, 1, MIN(IC_KH_KW, SYCL_IM2COL_BLOCK_SIZE)), + sycl::range<3>(1, 1, MIN(IC_KH_KW, SYCL_IM2COL_BLOCK_SIZE))), + [=](sycl::nd_item<3> item_ct1) { + im2col_kernel(x, dst, IC, IW, IH, OH, OW, KW, KH, IC_IH_IW, IH_IW, N_OH, KH_KW, IC_KH_KW, + s0, s1, p0, p1, d0, d1); + }); +} + +static void im2col_sycl_f16(const float * x, + sycl::half * dst, + int64_t IW, + int64_t IH, + int64_t OW, + int64_t OH, + int64_t KW, + int64_t KH, + int64_t IC, + int64_t N, + int64_t IC_IH_IW, + int64_t IH_IW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + dpct::queue_ptr stream) { + im2col_sycl<sycl::half>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } -static void im2col_sycl_f32(const float * x, float * dst, int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, - int64_t KH, int64_t IC, int64_t batch, int64_t batch_offset, int64_t offset_delta, int s0, - int s1, int p0, int p1, int d0, int d1, queue_ptr stream) { - im2col_sycl_internal<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, - d0, d1, stream); +static void im2col_sycl_f32(const float * x, + float * dst, + int64_t IW, + int64_t IH, + int64_t OW, + int64_t OH, + int64_t KW, + int64_t KH, + int64_t IC, + int64_t N, + int64_t IC_IH_IW, + int64_t IH_IW, + int s0, + int s1, + int p0, + int p1, + int d0, + int d1, + dpct::queue_ptr stream) { + im2col_sycl<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); } void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); - const int32_t s0 = ((const int32_t *) (dst->op_params))[0]; - const int32_t s1 = ((const int32_t *) (dst->op_params))[1]; - const int32_t p0 = ((const int32_t *) (dst->op_params))[2]; - const int32_t p1 = ((const int32_t *) (dst->op_params))[3]; - const int32_t d0 = ((const int32_t *) (dst->op_params))[4]; - const int32_t d1 = ((const int32_t *) (dst->op_params))[5]; + const int32_t s0 = ((const int32_t*)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t*)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t*)(dst->op_params))[2]; + const int32_t p1 = ((const int32_t*)(dst->op_params))[3]; + const int32_t d0 = ((const int32_t*)(dst->op_params))[4]; + const int32_t d1 = ((const int32_t*)(dst->op_params))[5]; - const bool is_2D = ((const int32_t *) (dst->op_params))[6] == 1; + const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1; const int64_t IC = src1->ne[is_2D ? 2 : 1]; const int64_t IH = is_2D ? src1->ne[1] : 1; - const int64_t IW = src1->ne[0]; + const int64_t IW = src1->ne[0]; const int64_t KH = is_2D ? src0->ne[1] : 1; - const int64_t KW = src0->ne[0]; + const int64_t KW = src0->ne[0]; const int64_t OH = is_2D ? dst->ne[2] : 1; - const int64_t OW = dst->ne[1]; + const int64_t OW = dst->ne[1]; + + const int64_t IC_IH_IW = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const int64_t N = src1->ne[is_2D ? 3 : 2]; + const int64_t IH_IW = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 + + if(dst->type == GGML_TYPE_F16) { + im2col_sycl_f16(src1_d, (sycl::half *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, + d0, d1, stream); + } else { + im2col_sycl_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream); + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template <typename T> +static void im2col_3d_kernel( + const float * src, T * dst, + int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, + int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW, + int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW, + int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, + int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t i = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (i >= IC_KD_KH_KW) { + return; + } + GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH); + GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW); + + const int64_t iic = i / KD_KH_KW; + const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const int64_t ikw = i % KW; + + const int64_t iow = item_ct1.get_group(1); + for (int64_t iz = item_ct1.get_group(0); iz < N_OD_OH; iz += MAX_GRIDDIM_Z) { + const int64_t in = iz / OD_OH; + const int64_t iod = (iz - in*OD_OH) / OH; + const int64_t ioh = iz % OH; + + const int64_t iiw = iow * s0 + ikw * d0 - p0; + const int64_t iih = ioh * s1 + ikh * d1 - p1; + const int64_t iid = iod * s2 + ikd * d2 - p2; + + const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { + dst[offset_dst] = 0.0f; + } else { + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); + dst[offset_dst] = src[offset_src]; + } + } +} + +// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW] +template <typename T> +static void im2col_3d_sycl(const float * src, + T * dst, + int64_t N, + int64_t IC, + int64_t ID, + int64_t IH, + int64_t IW, + int64_t OC, + int64_t KD, + int64_t KH, + int64_t KW, + int64_t OD, + int64_t OH, + int64_t OW, + int64_t stride_q, + int64_t stride_z, + int64_t stride_y, + int64_t stride_x, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + dpct::queue_ptr stream) { + const int64_t OH_OW = OH*OW; + const int64_t KD_KH_KW = KD*KH*KW; + const int64_t ID_IH_IW = ID*IH*IW; + const int64_t KH_KW = KH*KW; + const int64_t IH_IW = IH*IW; + const int64_t IC_KD_KH_KW = IC*KD*KH*KW; + const int64_t OW_KD_KH_KW = OW*KD*KH*KW; + const int64_t N_OD_OH = N*OD*OH; + const int64_t OD_OH = OD*OH; + const int64_t IC_ID_IH_IW = IC*ID*IH*IW; + const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + const int64_t num_blocks = (IC_KD_KH_KW + SYCL_IM2COL_BLOCK_SIZE - 1) / SYCL_IM2COL_BLOCK_SIZE; + dpct::dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z)); + /* + DPCT1049:74: The work-group size passed to the SYCL kernel may exceed the limit. To get the device limit, query info::device::max_work_group_size. Adjust the work-group size if needed. + */ + stream->parallel_for(sycl::nd_range<3>(block_nums * sycl::range<3>(1, 1, MIN(IC_KD_KH_KW, SYCL_IM2COL_BLOCK_SIZE)), + sycl::range<3>(1, 1, MIN(IC_KD_KH_KW, SYCL_IM2COL_BLOCK_SIZE))), + [=](sycl::nd_item<3> item_ct1) { + im2col_3d_kernel(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, OH_OW, KD_KH_KW, + ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, IC_KD_KH_KW, OW_KD_KH_KW, + OD_OH_OW_IC_KD_KH_KW, OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH, + stride_q, stride_z, stride_y, stride_x, s0, s1, s2, p0, p1, p2, d0, d1, + d2); + }); +} + +static void im2col_3d_sycl_f16(const float * src, + sycl::half * dst, + int64_t N, + int64_t IC, + int64_t ID, + int64_t IH, + int64_t IW, + int64_t OC, + int64_t KD, + int64_t KH, + int64_t KW, + int64_t OD, + int64_t OH, + int64_t OW, + int64_t stride_q, + int64_t stride_z, + int64_t stride_y, + int64_t stride_x, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + dpct::queue_ptr stream) { + im2col_3d_sycl<sycl::half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, stride_q, stride_z, stride_y, + stride_x, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +static void im2col_3d_sycl_f32(const float * src, + float * dst, + int64_t N, + int64_t IC, + int64_t ID, + int64_t IH, + int64_t IW, + int64_t OC, + int64_t KD, + int64_t KH, + int64_t KW, + int64_t OD, + int64_t OH, + int64_t OW, + int64_t stride_q, + int64_t stride_z, + int64_t stride_y, + int64_t stride_x, + int s0, + int s1, + int s2, + int p0, + int p1, + int p2, + int d0, + int d1, + int d2, + dpct::queue_ptr stream) { + im2col_3d_sycl<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); +} + +void ggml_sycl_op_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const float * src1_d = (const float *)src1->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t OC = ne03 / IC; + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; - const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / sizeof(float); - const int64_t batch = src1->ne[is_2D ? 3 : 2]; - const size_t batch_offset = src1->nb[is_2D ? 3 : 2] / sizeof(float); + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; - queue_ptr stream = ctx.stream(); + const size_t es = ggml_element_size(src1); + const int64_t stride_x = src1->nb[0] / es; + const int64_t stride_y = src1->nb[1] / es; + const int64_t stride_z = src1->nb[2] / es; + const int64_t stride_q = src1->nb[3] / es; - if (dst->type == GGML_TYPE_F16) { - im2col_sycl_f16((const float *) src1->data, (sycl::half *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, - batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); + if(dst->type == GGML_TYPE_F16) { + im2col_3d_sycl_f16(src1_d, (sycl::half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); } else { - im2col_sycl_f32((const float *) src1->data, (float *) dst->data, IW, IH, OW, OH, KW, KH, IC, batch, - batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream); + im2col_3d_sycl_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); } } diff --git a/ggml/src/ggml-sycl/im2col.hpp b/ggml/src/ggml-sycl/im2col.hpp index dbbb248ddb4..976d1094636 100644 --- a/ggml/src/ggml-sycl/im2col.hpp +++ b/ggml/src/ggml-sycl/im2col.hpp @@ -1,6 +1,6 @@ // // MIT license -// Copyright (C) 2024 Intel Corporation +// Copyright (C) 2026 Intel Corporation // SPDX-License-Identifier: MIT // @@ -15,7 +15,9 @@ #include "common.hpp" -void ggml_sycl_op_im2col( - ggml_backend_sycl_context & ctx, ggml_tensor *dst); +#define SYCL_IM2COL_BLOCK_SIZE 256 + +void ggml_sycl_op_im2col(ggml_backend_sycl_context & ctx, ggml_tensor * dst); +void ggml_sycl_op_im2col_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst); #endif // GGML_SYCL_IM2COL_HPP diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp index 316aa0d0fb5..cf2b59576aa 100644 --- a/ggml/src/ggml-sycl/mmvq.cpp +++ b/ggml/src/ggml-sycl/mmvq.cpp @@ -56,6 +56,65 @@ static void mul_mat_vec_q_reorder(const void * __restrict__ vx, const void * __r } } +template <typename reorder_vec_dot_q_sycl, int ncols_dst> +static void mul_mat_vec_q_reorder_ncols(const void * __restrict__ vx, const void * __restrict__ vy, + float * __restrict__ dst, const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + const sycl::nd_item<3> & nd_item) { + using block_type = ggml_sycl_reordered::block_q_t<reorder_vec_dot_q_sycl::gtype>; + using block_traits = typename block_type::traits; + + const auto sg = nd_item.get_sub_group(); + const int sg_range = sg.get_group_linear_range(); + const int workgroup_id = nd_item.get_group_linear_id(); + const int sg_id = sg.get_group_linear_id(); + const int row = workgroup_id * sg_range + sg_id; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / block_traits::qk; + constexpr int blocks_per_subgroup = ceil_div(block_traits::vdr_mmvq * WARP_SIZE, block_traits::qi); + constexpr int block_elements_per_subgroup = block_traits::qi / block_traits::vdr_mmvq; + const int nblocks = nrows * (ncols / block_traits::qk); + + static_assert(blocks_per_subgroup > 0); + static_assert(block_elements_per_subgroup > 0); + + float partial_sum[ncols_dst] = {0.0f}; + for (int i = sg.get_local_linear_id() / block_elements_per_subgroup; i < blocks_per_row; i += blocks_per_subgroup) { + const int ibx = row * blocks_per_row + i; + + const auto bx_offset = block_type::get_block_offset(ibx, nblocks); + const auto d_offset = block_type::get_d_offset(nrows, ncols, ibx); + const int iby = i * block_type::block_to_q8_1_ratio(); + +#pragma unroll + for (int elem = 0; elem < block_elements_per_subgroup; elem += WARP_SIZE) { + const int iqs = elem + block_traits::vdr_mmvq * (sg.get_local_linear_id() % block_elements_per_subgroup); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + const char * vy_j = (const char *)vy + j * stride_col_y_bytes; + const int8_t * q8_1_quant_ptr = (const int8_t *)vy_j + iby * QK8_1; + const sycl::half2* q8_1_ds_ptr = (const sycl::half2 *)(vy_j + ncols + iby * sizeof(sycl::half2)); + + partial_sum[j] += reorder_vec_dot_q_sycl()(vx, bx_offset, d_offset, q8_1_quant_ptr, q8_1_ds_ptr, iqs); + } + } + } + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + float sum = sycl::reduce_over_group(nd_item.get_sub_group(), partial_sum[j], std::plus<>()); + + if (sg.leader()) { + dst[j * stride_col_dst + row] = sum; + } + } +} + template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl> static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows, const sycl::nd_item<3> & item_ct1) { @@ -100,6 +159,70 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_ } } +template <int qk, int qi, typename block_q_t, int vdr, + vec_dot_q_sycl_t vec_dot_q_sycl, int ncols_dst> +static void mul_mat_vec_q_ncols( + const void * __restrict__ vx, + const void * __restrict__ vy, + float * __restrict__ dst, + const int ncols, + const int nrows, + const int stride_col_y, + const int stride_col_dst, + const sycl::nd_item<3> & item_ct1) { + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; + + // partial sums: one per output column + float tmp[ncols_dst] = {0.0f}; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); + i < blocks_per_row; + i += blocks_per_warp) { + + const int ibx = row * blocks_per_row + i; + const int iby = i * (qk / QK8_1); + + // read weight block once, dot against all columns + for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) { + const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr)); + +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + tmp[j] += vec_dot_q_sycl(&x[ibx], &y[j * stride_col_y + iby], iqs); + } + } + } + + // reduce within subgroup +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp[j] += dpct::permute_sub_group_by_xor( + item_ct1.get_sub_group(), tmp[j], mask); + } + } + + if (item_ct1.get_local_id(2) == 0) { +#pragma unroll + for (int j = 0; j < ncols_dst; ++j) { + dst[j * stride_col_dst + row] = tmp[j]; + } + } +} + template <int qk, int qi, typename block_q_t, int vdr> static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx, const void *__restrict__ vy, @@ -537,9 +660,9 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx, static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -553,6 +676,45 @@ static void reorder_mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK4_0 == 0); @@ -571,6 +733,45 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void * vx, const void * vy, float * } } +template <int ncols_dst> +static void mul_mat_vec_q4_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK4_0, QI4_0, block_q4_0, + VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -595,6 +796,45 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q4_1_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK4_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK4_0, QI4_1, block_q4_1, + VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_1_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_1 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_MXFP4 == 0); @@ -613,6 +853,101 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float } } +template <int ncols_dst> +static void mul_mat_vec_mxfp4_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_MXFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_MXFP4, QI_MXFP4, block_mxfp4, + VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_mxfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_mxfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_mxfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_mxfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_mxfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_mxfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_mxfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_mxfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for MXFP4 multi-col MMVQ", ncols_dst); + } +} + +static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_NVFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>( + vx, vy, dst, ncols, nrows, item_ct1); + }); + }); + } +} + +template <int ncols_dst> +static void mul_mat_vec_nvfp4_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_NVFP4 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_NVFP4, QI_NVFP4, block_nvfp4, + VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_nvfp4_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_nvfp4_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_nvfp4_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_nvfp4_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_nvfp4_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_nvfp4_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_nvfp4_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_nvfp4_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for NVFP4 multi-col MMVQ", ncols_dst); + } +} static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -638,6 +973,45 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK5_0, QI5_0, block_q5_0, + VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -662,6 +1036,103 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q5_1_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK5_1 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK5_1, QI5_1, block_q5_1, + VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_1_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_1_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_1_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_1_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_1_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_1_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_1_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_1_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_1_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_1 multi-col MMVQ", ncols_dst); + } +} + +static void reorder_mul_mat_vec_q8_0_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, (block_num_y * WARP_SIZE)); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + +template <int ncols_dst> +static void reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -686,6 +1157,45 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q8_0_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK8_0 == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK8_0, QI8_0, block_q8_0, + VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q8_0_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q8_0_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q8_0_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q8_0_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q8_0_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q8_0_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q8_0_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q8_0_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q8_0 multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -710,6 +1220,45 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q2_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI2_K, block_q2_K, + VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q2_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q2_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q2_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q2_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q2_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q2_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q2_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q2_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q2_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q2_K multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -734,6 +1283,105 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy, } } +static void reorder_mul_mat_vec_q3_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>>(vx, vy, dst, ncols, nrows, + nd_item); + }); + }); +} + +template <int ncols_dst> +static void reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q3_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q3_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K reorder multi-col MMVQ", ncols_dst); + } +} + +template <int ncols_dst> +static void mul_mat_vec_q3_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI3_K, block_q3_K, + VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q3_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q3_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q3_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q3_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q3_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q3_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q3_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q3_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q3_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q3_K multi-col MMVQ", ncols_dst); + } +} + + static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -758,13 +1406,58 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q4_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI4_K, block_q4_K, + VDR_Q4_K_Q8_1_MMVQ, + vec_dot_q4_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q4_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q4_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q4_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q4_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q4_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q4_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q4_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q4_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q4_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K multi-col MMVQ", ncols_dst); + } +} + static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); - const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); - constexpr size_t num_subgroups = 16; - GGML_ASSERT(block_num_y % num_subgroups == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); @@ -778,6 +1471,44 @@ static void reorder_mul_mat_vec_q4_k_q8_1_sycl(const void * vx, const void * vy, }); } +template <int ncols_dst> +static void reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q4_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q4_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q4_K reorder multi-col MMVQ", ncols_dst); + } +} static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -803,9 +1534,55 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy, } } -static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, +template <int ncols_dst> +static void mul_mat_vec_q5_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI5_K, block_q5_K, + VDR_Q5_K_Q8_1_MMVQ, + vec_dot_q5_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q5_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q5_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q5_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q5_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q5_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q5_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q5_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q5_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q5_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K multi-col MMVQ", ncols_dst); + } +} + +static void reorder_mul_mat_vec_q5_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, dpct::queue_ptr stream) { GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); constexpr size_t num_subgroups = 16; GGML_ASSERT(block_num_y % num_subgroups == 0); @@ -813,6 +1590,64 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>>(vx, vy, dst, ncols, + nrows, nd_item); + }); + }); +} + +template <int ncols_dst> +static void reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q5_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q5_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q5_K reorder multi-col MMVQ", ncols_dst); + } +} + +static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, + const int nrows, dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + // Round up to a whole number of subgroup-sized workgroups; out-of-range rows are skipped inside the kernel. + constexpr size_t num_subgroups = WARP_SIZE; + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y * (int) num_subgroups) * (int) num_subgroups; + + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { @@ -821,6 +1656,46 @@ static void reorder_mul_mat_vec_q6_k_q8_1_sycl(const void * vx, const void * vy, }); }); } + +template <int ncols_dst> +static void reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = ceil_div(nrows, GGML_SYCL_MMV_Y); + constexpr size_t num_subgroups = 16; + GGML_ASSERT(block_num_y % num_subgroups == 0); + const sycl::range<3> global_size(1, GGML_SYCL_MMV_Y, block_num_y * WARP_SIZE); + const sycl::range<3> workgroup_size(1, GGML_SYCL_MMV_Y, num_subgroups * WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for(sycl::nd_range<3>(global_size, workgroup_size), + [=](sycl::nd_item<3> nd_item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_reorder_ncols<reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K>, ncols_dst>( + vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, nd_item); + }); + }); +} + +static void reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, const int ncols_dst, + const int stride_col_y_bytes, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: reorder_mul_mat_vec_q6_k_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 3: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 4: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 5: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 6: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 7: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + case 8: reorder_mul_mat_vec_q6_k_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y_bytes, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K reorder multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, const int nrows, @@ -845,6 +1720,51 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_q6_K_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI6_K, block_q6_K, + VDR_Q6_K_Q8_1_MMVQ, + vec_dot_q6_K_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_q6_K_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_q6_K_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_q6_K_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_q6_K_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_q6_K_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_q6_K_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_q6_K_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_q6_K_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_q6_K_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for Q6_K multi-col MMVQ", ncols_dst); + } +} + static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy, float *dst, const int ncols, @@ -1041,6 +1961,51 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy, } } +template <int ncols_dst> +static void mul_mat_vec_iq4_xs_q8_1_sycl_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + GGML_ASSERT(ncols % QK_K == 0); + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, 1, block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) + [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_ncols<QK_K, QI4_XS/4, block_iq4_xs, + 1, + vec_dot_iq4_xs_q8_1, + ncols_dst>( + vx, vy, dst, ncols, nrows, + stride_col_y, stride_col_dst, item_ct1); + }); + }); +} + +static void mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols( + const void * vx, const void * vy, float * dst, + const int ncols, const int nrows, + const int ncols_dst, + const int stride_col_y, const int stride_col_dst, + dpct::queue_ptr stream) { + switch (ncols_dst) { + case 1: mul_mat_vec_iq4_xs_q8_1_sycl(vx, vy, dst, ncols, nrows, stream); break; + case 2: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<2>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 3: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<3>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 4: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<4>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 5: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<5>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 6: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<6>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 7: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<7>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + case 8: mul_mat_vec_iq4_xs_q8_1_sycl_ncols<8>(vx, vy, dst, ncols, nrows, stride_col_y, stride_col_dst, stream); break; + default: GGML_ABORT("unsupported ncols_dst=%d for IQ4_XS multi-col MMVQ", ncols_dst); + } +} + void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, @@ -1067,50 +2032,219 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens case GGML_TYPE_Q4_0: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); - reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_0_q8_1_sycl\n"); mul_mat_vec_q4_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; case GGML_TYPE_Q4_1: - mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_1_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q4_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_0: - mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q5_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q5_1: - mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_1_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_1_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q5_1_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q8_0: - mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q8_0_q8_1_sycl\n"); + reorder_mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q8_0_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q8_0_q8_1_sycl\n"); + mul_mat_vec_q8_0_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q2_K: - mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q2_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q2_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_q2_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q3_K: - mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q3_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q3_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q3_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q3_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q3_K_q8_1_sycl\n"); + mul_mat_vec_q3_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q4_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q4_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q4_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q4_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q4_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q4_K_q8_1_sycl\n"); mul_mat_vec_q4_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } break; case GGML_TYPE_Q5_K: - mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && + ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q5_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q5_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q5_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q5_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + GGML_SYCL_DEBUG("Calling mul_mat_vec_q5_K_q8_1_sycl\n"); + mul_mat_vec_q5_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_Q6_K: if ((ggml_tensor_extra_gpu *) dst->src[0]->extra && ((ggml_tensor_extra_gpu *) dst->src[0]->extra)->optimized_feature.reorder) { - GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); - reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); - } else { + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y_bytes = src1_padded_col_size * q8_1_ts / q8_1_bs; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + reorder_mul_mat_vec_q6_k_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y_bytes, stride_col_dst, stream); + return; + } else { + GGML_SYCL_DEBUG("Calling reorder_mul_mat_vec_q6_k_q8_1_sycl\n"); + reorder_mul_mat_vec_q6_k_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + } else if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_K_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_q6_K_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { GGML_SYCL_DEBUG("Calling mul_mat_vec_q6_k_q8_1_sycl\n"); mul_mat_vec_q6_K_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); } @@ -1140,13 +2274,46 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens mul_mat_vec_iq4_nl_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); break; case GGML_TYPE_IQ4_XS: - mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_iq4_xs_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_iq4_xs_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; case GGML_TYPE_MXFP4: - mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_mxfp4_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } + break; + case GGML_TYPE_NVFP4: + if (i == 0 && src1_ncols > 1 && src1_ncols <= 8) { + const int stride_col_y = src1_padded_col_size / QK8_1; + const int stride_col_dst = dst->ne[0]; + GGML_SYCL_DEBUG("Calling mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols ncols=%d\n", (int)src1_ncols); + mul_mat_vec_nvfp4_q8_1_sycl_switch_ncols( + src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, + src1_ncols, stride_col_y, stride_col_dst, stream); + return; + } else if (i == 0 || src1_ncols == 1) { + mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream); + } break; default: - GGML_ABORT("fatal error"); + GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type)); } } GGML_UNUSED(src1); @@ -1154,3 +2321,154 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens GGML_UNUSED(src1_ddf_i); GGML_UNUSED(ctx); } + +// src1_row_stride: 0 for shared src1 (gate/up proj), else per-expert stride (down proj). +template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl> +static void mul_mat_vec_q_moe( + const void * __restrict__ vx_base, const void * __restrict__ vy_base, + float * __restrict__ dst_base, const int32_t * __restrict__ ids_dev, + const int ncols, const int nrows, + const size_t expert_weight_stride, const size_t dst_row_stride, + const size_t src1_row_stride, + const sycl::nd_item<3> & item_ct1) { + + const int expert_idx = item_ct1.get_group(1); + const int i02 = ids_dev[expert_idx]; + + const char * vx = (const char *) vx_base + (size_t) i02 * expert_weight_stride; + const char * vy = (const char *) vy_base + (size_t) expert_idx * src1_row_stride; + float * dst = (float *) ((char *) dst_base + (size_t) expert_idx * dst_row_stride); + + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1); + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + constexpr int blocks_per_warp = (vdr * WARP_SIZE + qi - 1) / qi; + + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = item_ct1.get_local_id(2) / (qi / vdr); i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; + const int iby = i * (qk / QK8_1); + + for (size_t elem = 0; elem < qi / vdr; elem += WARP_SIZE) { + const int iqs = elem + vdr * (item_ct1.get_local_id(2) % (qi / vdr)); + tmp += vec_dot_q_sycl(&x[ibx], &y[iby], iqs); + } + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask); + } + + if (item_ct1.get_local_id(2) == 0) { + dst[row] = tmp; + } +} + +template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl> +static void launch_mul_mat_vec_q_moe( + const void * vx_base, const void * vy, const int32_t * ids_dev, + float * dst_base, const int ncols, const int nrows, const int n_experts_used, + const size_t expert_weight_stride, const size_t dst_row_stride, + const size_t src1_row_stride, + dpct::queue_ptr stream) { + const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y; + const sycl::range<3> block_nums(1, (unsigned) n_experts_used, (unsigned) block_num_y); + const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + mul_mat_vec_q_moe<qk, qi, block_q_t, vdr, vec_dot_q_sycl>( + vx_base, vy, dst_base, ids_dev, ncols, nrows, + expert_weight_stride, dst_row_stride, src1_row_stride, item); + }); + }); +} + +bool ggml_sycl_mul_mat_vec_q_id( + enum ggml_type src0_type, + const void * vx_base, + const void * vy, + const int32_t * ids_dev, + float * dst_base, + int ncols, + int nrows, + int n_experts_used, + size_t expert_weight_stride, + size_t dst_row_stride, + size_t src1_row_stride, + dpct::queue_ptr stream) { + switch (src0_type) { + case GGML_TYPE_Q4_0: + launch_mul_mat_vec_q_moe<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q4_1: + launch_mul_mat_vec_q_moe<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_0: + launch_mul_mat_vec_q_moe<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_1: + launch_mul_mat_vec_q_moe<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q8_0: + launch_mul_mat_vec_q_moe<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q2_K: + launch_mul_mat_vec_q_moe<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q3_K: + launch_mul_mat_vec_q_moe<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q4_K: + launch_mul_mat_vec_q_moe<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q5_K: + launch_mul_mat_vec_q_moe<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_Q6_K: + launch_mul_mat_vec_q_moe<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_MXFP4: + launch_mul_mat_vec_q_moe<QK_MXFP4, QI_MXFP4, block_mxfp4, VDR_MXFP4_Q8_1_MMVQ, vec_dot_mxfp4_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + case GGML_TYPE_NVFP4: + launch_mul_mat_vec_q_moe<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>( + vx_base, vy, ids_dev, dst_base, ncols, nrows, n_experts_used, + expert_weight_stride, dst_row_stride, src1_row_stride, stream); + return true; + default: + return false; + } +} diff --git a/ggml/src/ggml-sycl/mmvq.hpp b/ggml/src/ggml-sycl/mmvq.hpp index 049b43d4535..d674dc1d61e 100644 --- a/ggml/src/ggml-sycl/mmvq.hpp +++ b/ggml/src/ggml-sycl/mmvq.hpp @@ -24,4 +24,20 @@ void ggml_sycl_op_mul_mat_vec_q( const int64_t src1_ncols, const int64_t src1_padded_row_size, const dpct::queue_ptr &stream); +// Requires standard (non-reorder) block layout for src0. +// Returns false if src0_type isn't handled; caller should fall back. +bool ggml_sycl_mul_mat_vec_q_id( + enum ggml_type src0_type, + const void * vx_base, // start of stacked expert weights + const void * vy, // pre-quantized src1 (Q8_1) + const int32_t * ids_dev, // device-side int32, length n_experts_used + float * dst_base, + int ncols, + int nrows, + int n_experts_used, + size_t expert_weight_stride, // bytes between experts in vx_base + size_t dst_row_stride, // bytes between dst rows + size_t src1_row_stride, // 0 = shared src1, else per-expert stride in bytes + dpct::queue_ptr stream); + #endif // GGML_SYCL_MMVQ_HPP diff --git a/ggml/src/ggml-sycl/norm.cpp b/ggml/src/ggml-sycl/norm.cpp index 823d3a4828c..09fce1280ad 100644 --- a/ggml/src/ggml-sycl/norm.cpp +++ b/ggml/src/ggml-sycl/norm.cpp @@ -202,47 +202,34 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const int6 } } -static void l2_norm_f32(const float* x, float* dst, const int ncols, const float eps, - const sycl::nd_item<3>& item_ct1, float* s_sum, int block_size) { - const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + - item_ct1.get_local_id(1); - const int tid = item_ct1.get_local_id(2); - const int nthreads = item_ct1.get_local_range(2); - const int nwarps = nthreads / WARP_SIZE; +template<int warp_size> +static void l2_norm_f32(const float * x, float * dst, const int ncols, + const int64_t stride_row, const int64_t stride_channel, + const int64_t stride_sample, const float eps, + const sycl::nd_item<3>& item_ct1, float* s_sum, const int block_size) { + const int nrows = item_ct1.get_group_range(2); + const int nchannels = item_ct1.get_group_range(1); + + const int row = item_ct1.get_group(2); + const int channel = item_ct1.get_group(1); + const int sample = item_ct1.get_group(0); + const int tid = item_ct1.get_local_id(2); + + x += sample*stride_sample + channel*stride_channel + row*stride_row; + dst += ((sample*nchannels + channel)*nrows + row)*ncols; + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { - const float xi = x[row * ncols + col]; + const float xi = x[col]; tmp += xi * xi; } - // sum up partial sums - tmp = warp_reduce_sum(tmp, item_ct1); - if (block_size > WARP_SIZE) { - - int warp_id = item_ct1.get_local_id(2) / WARP_SIZE; - int lane_id = item_ct1.get_local_id(2) % WARP_SIZE; - if (lane_id == 0) { - s_sum[warp_id] = tmp; - } - /* - DPCT1118:3: SYCL group functions and algorithms must be encountered in - converged control flow. You may need to adjust the code. - */ - item_ct1.barrier(sycl::access::fence_space::local_space); - size_t nreduce = nwarps / WARP_SIZE; - tmp = 0.f; - for (size_t i = 0; i < nreduce; i += 1) - { - tmp += s_sum[lane_id + i * WARP_SIZE]; - } - tmp = warp_reduce_sum(tmp, item_ct1); - } - - const float scale = sycl::rsqrt(sycl::max(tmp, eps * eps)); + tmp = block_reduce<block_reduce_method::SUM, warp_size>(tmp, s_sum, block_size); + const float scale = sycl::rsqrt(sycl::fmax(tmp, eps * eps)); for (int col = tid; col < ncols; col += block_size) { - dst[row * ncols + col] = scale * x[row * ncols + col]; + dst[col] = scale * x[col]; } } @@ -251,7 +238,6 @@ static void norm_f32_sycl(const float * x, float * dst, const int ncols, const i const float eps, queue_ptr stream, int device) { const sycl::range<3> global_dims(nsamples, nchannels, nrows); - GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const sycl::range<3> block_dims(1, 1, WARP_SIZE); stream->submit([&](sycl::handler& cgh) { @@ -334,7 +320,6 @@ static void group_norm_f32_sycl(const float* x, float* dst, static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const int nrows, const int nchannels, const int nsamples, const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, queue_ptr stream, int device) { - GGML_ASSERT(ncols % WARP_SIZE == 0); // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); const sycl::range<3> global_dims(nsamples, nchannels, nrows); @@ -371,43 +356,50 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols, const } } -static void l2_norm_f32_sycl(const float* x, float* dst, const int ncols, - const int nrows, const float eps, - queue_ptr stream, int device) { - GGML_ASSERT(ncols % WARP_SIZE == 0); - // printf("%s ncols=%d, nrows=%d, WARP_SIZE=%d\n", __func__, ncols, nrows, WARP_SIZE); +template<int warp_size> +static void l2_norm_f32_sycl(const float * x, + float * dst, + const int ncols, + const int nrows, + const int nchannels, + const int nsamples, + const int64_t stride_row, + const int64_t stride_channel, + const int64_t stride_sample, + const float eps, + queue_ptr stream, + int device) { + const dpct::dim3 blocks_num(nrows, nchannels, nsamples); + if (ncols < 1024) { - const sycl::range<3> block_dims(1, 1, WARP_SIZE); + const dpct::dim3 block_dims(warp_size, 1, 1); stream->submit([&](sycl::handler& cgh) { cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + sycl::nd_range<3>(blocks_num * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - l2_norm_f32(x, dst, ncols, eps, item_ct1, - nullptr, WARP_SIZE); + [[sycl::reqd_sub_group_size(warp_size)]] { + l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, item_ct1, + nullptr, warp_size); }); }); } else { const int work_group_size = ggml_sycl_info().max_work_group_sizes[device]; - assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + assert(work_group_size % (warp_size * warp_size) == 0); const sycl::range<3> block_dims(1, 1, work_group_size); - /* - DPCT1049:19: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ + int lsm_size = block_dims[2] > warp_size ? work_group_size / warp_size * sizeof(float): 0; stream->submit([&](sycl::handler& cgh) { - sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(work_group_size / WARP_SIZE), + sycl::local_accessor<float, 1> s_sum_acc_ct1(sycl::range<1>(lsm_size), cgh); + cgh.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, nrows) * block_dims, + sycl::nd_range<3>(blocks_num * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) - [[sycl::reqd_sub_group_size(WARP_SIZE)]] { - l2_norm_f32(x, dst, ncols, eps, item_ct1, - get_pointer(s_sum_acc_ct1), work_group_size); + [[sycl::reqd_sub_group_size(warp_size)]] { + l2_norm_f32<warp_size>(x, dst, ncols, stride_row, stride_channel, stride_sample, + eps, item_ct1, get_pointer(s_sum_acc_ct1), work_group_size); }); }); } @@ -637,21 +629,28 @@ void ggml_sycl_op_rms_norm_back(ggml_backend_sycl_context & ctx, ggml_tensor * d } void ggml_sycl_op_l2_norm(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + dpct::queue_ptr stream = ctx.stream(); - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); - - const int64_t ne00 = dst->src[0]->ne[0]; - const int64_t nrows = ggml_nrows(dst->src[0]); - const float * src0_dd = static_cast<const float *>(dst->src[0]->data); - float * dst_dd = static_cast<float *>(dst->data); + GGML_TENSOR_UNARY_OP_LOCALS; float eps; memcpy(&eps, dst->op_params, sizeof(float)); + GGML_ASSERT(eps >= 0.0f); - l2_norm_f32_sycl(src0_dd, dst_dd, ne00, nrows, eps, main_stream, ctx.device); + const size_t ts0 = ggml_type_size(src0->type); + GGML_ASSERT(nb00 == ts0); + const int64_t s01 = nb01 / ts0; + const int64_t s02 = nb02 / ts0; + const int64_t s03 = nb03 / ts0; + /*support both WARP_SIZE or WARP_32_SIZE in code + choose by hardware for better performance + */ + l2_norm_f32_sycl<WARP_SIZE>(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream, ctx.device); } diff --git a/ggml/src/ggml-sycl/outprod.cpp b/ggml/src/ggml-sycl/outprod.cpp index 3a17f3a1b88..f52b11f0d6e 100644 --- a/ggml/src/ggml-sycl/outprod.cpp +++ b/ggml/src/ggml-sycl/outprod.cpp @@ -32,12 +32,12 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { // Handle transposition of src1 const bool src1_T = ggml_is_transposed(src1); - const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans; + const oneapi::mkl::transpose src1_op = src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans; const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); try { - // Perform matrix multiplication using oneMath GEMM - oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op, + // Perform matrix multiplication using oneMKL GEMM + oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0); } catch (sycl::exception const& exc) { diff --git a/ggml/src/ggml-sycl/pad.cpp b/ggml/src/ggml-sycl/pad.cpp index f989c5e4b8b..ee93bb51801 100644 --- a/ggml/src/ggml-sycl/pad.cpp +++ b/ggml/src/ggml-sycl/pad.cpp @@ -13,7 +13,8 @@ //#include "common.hpp" #include "pad.hpp" -static void pad_f32(const float * src, float * dst, +static void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, + float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, @@ -27,7 +28,6 @@ static void pad_f32(const float * src, float * dst, return; } - // operation const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0; if ((i0 >= lp0 && i0 < ne0 - rp0) && (i1 >= lp1 && i1 < ne1 - rp1) && @@ -37,12 +37,8 @@ static void pad_f32(const float * src, float * dst, const int64_t i01 = i1 - lp1; const int64_t i02 = i2 - lp2; const int64_t i03 = i3 - lp3; - const int64_t ne02 = ne2 - lp2 - rp2; - const int64_t ne01 = ne1 - lp1 - rp1; - const int64_t ne00 = ne0 - lp0 - rp0; - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + - i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } else { @@ -50,20 +46,19 @@ static void pad_f32(const float * src, float * dst, } } -static void pad_f32_sycl(const float *src, float *dst, const int lp0, - const int rp0, const int lp1, const int rp1, - const int lp2, const int rp2, const int lp3, - const int rp3, const int ne0, const int ne1, - const int ne2, const int ne3, +static void pad_f32_sycl(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, + float * dst, const int lp0, const int rp0, const int lp1, const int rp1, + const int lp2, const int rp2, const int lp3, const int rp3, + const int ne0, const int ne1, const int ne2, const int ne3, dpct::queue_ptr stream) { int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE; - dpct::dim3 gridDim(num_blocks, ne1, ne2 * ne3); + sycl::range<3> grid(ne2 * ne3, ne1, num_blocks); stream->parallel_for( - sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), + sycl::nd_range<3>(grid * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)), [=](sycl::nd_item<3> item_ct1) { - pad_f32(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, - ne2, ne3, item_ct1); + pad_f32(src, s00, s01, s02, s03, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, + ne0, ne1, ne2, ne3, item_ct1); }); } @@ -71,22 +66,27 @@ void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; float * dst_d = (float *)dst->data; - dpct::queue_ptr stream = ctx.stream(); + dpct::queue_ptr stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); - const int32_t lp0 = ((const int32_t*)(dst->op_params))[0]; - const int32_t rp0 = ((const int32_t*)(dst->op_params))[1]; - const int32_t lp1 = ((const int32_t*)(dst->op_params))[2]; - const int32_t rp1 = ((const int32_t*)(dst->op_params))[3]; - const int32_t lp2 = ((const int32_t*)(dst->op_params))[4]; - const int32_t rp2 = ((const int32_t*)(dst->op_params))[5]; - const int32_t lp3 = ((const int32_t*)(dst->op_params))[6]; - const int32_t rp3 = ((const int32_t*)(dst->op_params))[7]; + const size_t ts = ggml_type_size(src0->type); + const size_t s00 = src0->nb[0] / ts; + const size_t s01 = src0->nb[1] / ts; + const size_t s02 = src0->nb[2] / ts; + const size_t s03 = src0->nb[3] / ts; - pad_f32_sycl(src0_d, dst_d, + const int32_t lp0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t rp0 = ((const int32_t *)(dst->op_params))[1]; + const int32_t lp1 = ((const int32_t *)(dst->op_params))[2]; + const int32_t rp1 = ((const int32_t *)(dst->op_params))[3]; + const int32_t lp2 = ((const int32_t *)(dst->op_params))[4]; + const int32_t rp2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t lp3 = ((const int32_t *)(dst->op_params))[6]; + const int32_t rp3 = ((const int32_t *)(dst->op_params))[7]; + + pad_f32_sycl(src0_d, s00, s01, s02, s03, dst_d, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); } diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp index b6517374230..dc4dad1d37a 100644 --- a/ggml/src/ggml-sycl/presets.hpp +++ b/ggml/src/ggml-sycl/presets.hpp @@ -73,4 +73,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA #define MUL_MAT_SRC1_COL_STRIDE 128 #define QK_WARP_SIZE 32 +#define WARP_32_SIZE 32 +#define WARP_16_SIZE 16 + #endif // GGML_SYCL_PRESETS_HPP diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index d0d5ac9a4e8..95287f17510 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -29,7 +29,7 @@ namespace ggml_sycl_reordered { // [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN] // // Notes: out-of-bounds qs will run into d values -// Aligment relies on the allocated size of qs +// Alignment relies on the allocated size of qs template <ggml_type type> struct block_q_t; @@ -58,6 +58,31 @@ template <> struct block_q_t<GGML_TYPE_Q4_0> { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t<GGML_TYPE_Q3_K> { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI3_K; + static constexpr uint32_t qr = QR3_K; + static constexpr uint32_t vdr_mmvq = 1; + }; + + // Reordered layout: [qs (QK_K/4 per block)] [hmask (QK_K/8 per block)] [scales] [d] + static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) { + auto qs_offset = block_index * (QK_K / 4); + auto hmask_offset = n_blocks * (QK_K / 4) + block_index * (QK_K / 8); + return { qs_offset, hmask_offset }; + } + + static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / QK_K)); + auto total_qs_bytes = nblocks * (QK_K / 4) + nblocks * (QK_K / 8); + return { total_qs_bytes + block_index * 12, + total_qs_bytes + nblocks * 12 + block_index * sizeof(ggml_half) }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + template <> struct block_q_t<GGML_TYPE_Q4_K> { struct traits { static constexpr uint32_t qk = QK_K; @@ -79,6 +104,31 @@ template <> struct block_q_t<GGML_TYPE_Q4_K> { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t<GGML_TYPE_Q5_K> { + struct traits { + static constexpr uint32_t qk = QK_K; + static constexpr uint32_t qi = QI5_K; + static constexpr uint32_t qr = QR5_K; + static constexpr uint32_t vdr_mmvq = 2; + }; + + // Reordered layout: [qs (QK_K/2 per block)] [qh (QK_K/8 per block)] [scales] [dm] + static constexpr std::pair<int, int> get_block_offset(const int block_index, const int n_blocks) { + auto qs_offset = block_index * (QK_K / 2); + auto qh_offset = n_blocks * (QK_K / 2) + block_index * (QK_K / 8); + return { qs_offset, qh_offset }; + } + + static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) { + auto nblocks = (nrows * (ncols / QK_K)); + auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 8); + return { total_qs_bytes + block_index * K_SCALE_SIZE, + total_qs_bytes + nblocks * K_SCALE_SIZE + block_index * sizeof(ggml_half2) }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } +}; + template <> struct block_q_t<GGML_TYPE_Q6_K> { struct traits { static constexpr uint32_t qk = QK_K; @@ -105,6 +155,27 @@ template <> struct block_q_t<GGML_TYPE_Q6_K> { static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; +template <> struct block_q_t<GGML_TYPE_Q8_0> { + struct traits { + static constexpr uint32_t qk = QK8_0; // 32 + static constexpr uint32_t qi = QI8_0; // 8 + static constexpr uint32_t qr = QR8_0; // 1 + static constexpr uint32_t vdr_mmvq = 4; + }; + + // Q8_0 reorder layout: [qs0|qs1|...|qsN][d0|d1|...|dN] + // Each block has 32 int8 weights (32 bytes) followed by all scales + static constexpr std::pair<int, int> get_block_offset(const int block_index, const int /* nblocks */) { + return { block_index * QK8_0, 0 }; + } + + static constexpr std::pair<int, int> get_d_offset(int nrows, int ncols, const int block_index) { + return { (ncols * nrows) + block_index * sizeof(ggml_half), 0 }; + } + + static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } // 1 +}; + } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/rope.cpp b/ggml/src/ggml-sycl/rope.cpp index 69140b19a4c..9d83a1e9fa0 100644 --- a/ggml/src/ggml-sycl/rope.cpp +++ b/ggml/src/ggml-sycl/rope.cpp @@ -1,4 +1,5 @@ #include "rope.hpp" +#include "convert.hpp" #include "ggml-sycl/common.hpp" #include "ggml.h" @@ -15,367 +16,489 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { return 1.0f - sycl::min(1.0f, sycl::max(0.0f, y)); } -// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn -// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. -static void rope_yarn( - float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale, - float * cos_theta, float * sin_theta) { - // Get n-d rotational scaling corrected for extrapolation +template <bool forward> +static void rope_yarn(const float theta_extrap, const float freq_scale, + const rope_corr_dims corr_dims, const int64_t i0, + const float ext_factor, float mscale, float &cos_theta, + float &sin_theta) { float theta_interp = freq_scale * theta_extrap; float theta = theta_interp; if (ext_factor != 0.0f) { - float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor; + float ramp_mix = + rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor; theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; - // Get n-d magnitude scaling corrected for interpolation mscale *= 1.0f + 0.1f * sycl::log(1.0f / freq_scale); } - *cos_theta = sycl::cos(theta) * mscale; - *sin_theta = sycl::sin(theta) * mscale; + cos_theta = sycl::cos(theta) * mscale; + sin_theta = sycl::sin(theta) * mscale; + if (!forward) { + sin_theta *= -1.0f; + } } -template <typename T, bool has_ff> -static void rope_norm(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, - const int32_t * pos, float freq_scale, float ext_factor, float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, - const sycl::nd_item<3> & item_ct1) { - const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); - - if (i0 >= ne0) { +template <bool forward, bool has_ff, typename T, typename D> +static void rope_norm(const T *x, D *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, + const int s03, const int s1, const int s2, const int s3, + const int n_dims, const int32_t *pos, + const float freq_scale, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float *freq_factors, + const int64_t *row_indices, const int set_rows_stride) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + + if (i0 >= ne00) { return; } - const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); - const int row0 = row % ne1; - const int channel0 = row / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int i = row * ne0 + i0; - const int i2 = channel0 * s2 + row0 * s1 + i0; + int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03; + + if (set_rows_stride != 0) { + idst = i1 * s1 + i0; + idst += row_indices[i2] * set_rows_stride; + } + const auto &store_coaelsced = [&](float x0, float x1) { + if constexpr (std::is_same_v<float, D>) { + sycl::float2 v = sycl::float2(x0, x1); + ggml_sycl_memcpy_1<8>(dst + idst, &v); + } else if constexpr (std::is_same_v<sycl::half, D>) { + sycl::half2 v = sycl::half2(x0, x1); + ggml_sycl_memcpy_1<4>(dst + idst, &v); + } + }; if (i0 >= n_dims) { - *reinterpret_cast<sycl::vec<T, 2> *>(dst + i) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2); + store_coaelsced(x[ix + 0], x[ix + 1]); return; } - const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); + const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0, + ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i2 + 0]; - const float x1 = x[i2 + 1]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + 1]; - dst[i + 0] = x0 * cos_theta - x1 * sin_theta; - dst[i + 1] = x0 * sin_theta + x1 * cos_theta; + store_coaelsced(x0 * cos_theta - x1 * sin_theta, + x0 * sin_theta + x1 * cos_theta); } -template <typename T, bool has_ff> -static void rope_neox(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, const int n_dims, - const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, - const sycl::nd_item<3> & item_ct1) { - const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + item_ct1.get_local_id(1)); - - if (i0 >= ne0) { +template <bool forward, bool has_ff, typename T, typename D> +static void rope_neox(const T *x, D *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, + const int s03, const int s1, const int s2, const int s3, + const int n_dims, const int32_t *pos, + const float freq_scale, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float *freq_factors, + const int64_t *row_indices, const int set_rows_stride) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + + if (i0 >= ne00) { return; } - const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); + const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); - const int row0 = row % ne1; - const int channel0 = row / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int i = row * ne0 + i0 / 2; - const int i2 = channel0 * s2 + row0 * s1 + i0 / 2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; + + if (set_rows_stride != 0) { + idst = i1 * s1 + i0 / 2; + idst += row_indices[i2] * set_rows_stride; + } if (i0 >= n_dims) { - *reinterpret_cast<sycl::vec<T, 2> *>(dst + i + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i2 + i0 / 2); + dst[idst + i0 / 2 + 0] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 0]); + dst[idst + i0 / 2 + 1] = ggml_sycl_cast<D>(x[ix + i0 / 2 + 1]); + return; } - const float theta_base = pos[channel0] * sycl::pow(theta_scale, i0 / 2.0f); + const float theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f); const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; float cos_theta; float sin_theta; - rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0, + ext_factor, attn_factor, cos_theta, sin_theta); - const float x0 = x[i2 + 0]; - const float x1 = x[i2 + n_dims / 2]; + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims / 2]; - dst[i + 0] = x0 * cos_theta - x1 * sin_theta; - dst[i + n_dims / 2] = x0 * sin_theta + x1 * cos_theta; + dst[idst + 0] = ggml_sycl_cast<D>(x0 * cos_theta - x1 * sin_theta); + dst[idst + n_dims / 2] = ggml_sycl_cast<D>(x0 * sin_theta + x1 * cos_theta); } -template <typename T, bool has_ff> -static void rope_multi(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, - const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, - const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * freq_factors, const mrope_sections sections, - const bool is_imrope, const sycl::nd_item<3> & item_ct1) { - // get index pos - const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); - if (i0 >= ne0) { +template <bool forward, bool has_ff, typename T> +static void rope_multi(const T *x, T *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, + const int s03, const int s1, const int s2, const int s3, + const int n_dims, const int32_t *pos, + const float freq_scale, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float *freq_factors, + const mrope_sections sections, const bool is_imrope) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + + if (i0 >= ne00) { return; } - const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - const int idst = (row_dst * ne0) + (i0 / 2); - const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; if (i0 >= n_dims) { - *reinterpret_cast<sycl::vec<T, 2> *>(dst + idst + i0 / 2) = *reinterpret_cast<const sycl::vec<T, 2> *>(x + i0 / 2 + ix); + dst[idst + i0 / 2 + 0] = x[ix + i0 / 2 + 0]; + dst[idst + i0 / 2 + 1] = x[ix + i0 / 2 + 1]; + return; } - const int sect_dims = sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; + const int sect_dims = + sections.v[0] + sections.v[1] + sections.v[2] + sections.v[3]; const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0; if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * sections.v[1]) { - theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { - theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f); } else { - theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f); + theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f); } } else { if (sector < sections.v[0]) { - theta_base = pos[channel_x]*sycl::pow(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*sycl::pow(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*sycl::pow(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*sycl::pow(theta_scale, i0/2.0f); + theta_base = pos[i2] * dpct::pow(theta_scale, i0 / 2.0f); + } else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2 + ne02 * 1] * dpct::pow(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 2] * dpct::pow(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 3] * dpct::pow(theta_scale, i0 / 2.0f); } } const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; - float cos_theta; - float sin_theta; - rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); - const float x0 = x[ix + 0]; - const float x1 = x[ix + n_dims/2]; - // store results in dst - dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; - dst[idst + n_dims/2] = x0 * sin_theta + x1 * cos_theta; -} + float cos_theta; + float sin_theta; + rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0, + ext_factor, attn_factor, cos_theta, sin_theta); + const float x0 = x[ix + 0]; + const float x1 = x[ix + n_dims / 2]; + + dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; + dst[idst + n_dims / 2] = x0 * sin_theta + x1 * cos_theta; +} -template <typename T, bool has_ff> -static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, - const size_t s2, const int n_dims, const int32_t * pos, const float freq_scale, - const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * freq_factors, const mrope_sections sections, - const sycl::nd_item<3> & item_ct1) { - // get index pos - const int i0 = 2 * (item_ct1.get_group(1) * item_ct1.get_local_range(1) + item_ct1.get_local_id(1)); - if (i0 >= ne0) { +template <bool forward, bool has_ff, typename T> +static void rope_vision(const T *x, T *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, + const int s03, const int s1, const int s2, const int s3, + const int n_dims, const int32_t *pos, + const float freq_scale, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float theta_scale, const float *freq_factors, + const mrope_sections sections) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int i0 = 2 * (item_ct1.get_local_range(1) * item_ct1.get_group(1) + + item_ct1.get_local_id(1)); + + if (i0 >= ne00) { return; } - const int row_dst = (item_ct1.get_group(2) * item_ct1.get_local_range(2)) + item_ct1.get_local_id(2); - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - const int idst = (row_dst * ne0) + (i0 / 2); - const size_t ix = ((size_t) channel_x * s2) + ((size_t) row_x * s1) + (i0 / 2); + + const int row_dst = item_ct1.get_local_range(2) * item_ct1.get_group(2) + + item_ct1.get_local_id(2); + + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; const int sect_dims = sections.v[0] + sections.v[1]; - const int sector = (i0 / 2) % sect_dims; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0f; + float theta_base = 0.0; if (sector < sections.v[0]) { const int p = sector; - theta_base = pos[channel_x] * sycl::pow(theta_scale, (float) p); - } else { - // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0] + theta_base = pos[i2] * dpct::pow(theta_scale, p); + } else if (sector >= sections.v[0] && sector < sec_w) { const int p = sector - sections.v[0]; - theta_base = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p); + theta_base = pos[i2 + ne02] * dpct::pow(theta_scale, p); } const float freq_factor = has_ff ? freq_factors[i0 / 2] : 1.0f; - float cos_theta; - float sin_theta; - rope_yarn(theta_base / freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + + float cos_theta; + float sin_theta; + + rope_yarn<forward>(theta_base / freq_factor, freq_scale, corr_dims, i0, + ext_factor, attn_factor, cos_theta, sin_theta); + const float x0 = x[ix + 0]; const float x1 = x[ix + n_dims]; - // store results in dst - dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; + dst[idst + 0] = x0 * cos_theta - x1 * sin_theta; dst[idst + n_dims] = x0 * sin_theta + x1 * cos_theta; } -template <typename T> -static void rope_norm_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, - const int n_dims, int nr, const int32_t * pos, const float freq_scale, const float freq_base, - const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float * freq_factors, queue_ptr stream) { - GGML_ASSERT(ne0 % 2 == 0); - const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); - const sycl::range<3> block_nums(1, num_blocks_x, nr); +template <bool forward, typename T, typename D> +static void +rope_norm_sycl(const T *x, D *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, const int s03, + const int s1, const int s2, const int s3, const int n_dims, + const int nr, const int32_t *pos, const float freq_scale, + const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float *freq_factors, const int64_t *row_indices, + const int set_rows_stride, dpct::queue_ptr stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = + (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const dpct::dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f / n_dims); - dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - if (freq_factors == nullptr) { - /* - DPCT1049:40: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - rope_norm<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_norm<forward, false>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, row_indices, set_rows_stride); + }); } else { - /* - DPCT1049:41: The work-group size passed to the SYCL kernel may exceed - the limit. To get the device limit, query - info::device::max_work_group_size. Adjust the work-group size if needed. - */ - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - rope_norm<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_norm<forward, true>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, row_indices, set_rows_stride); + }); } } -template <typename T> -static void rope_neox_sycl(const T * x, T * dst, const int ne0, const int ne1, const int s1, const int s2, - const int n_dims, const int nr, const int32_t * pos, const float freq_scale, - const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, queue_ptr stream) { - GGML_ASSERT(ne0 % 2 == 0); - const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int num_blocks_x = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); - const sycl::range<3> block_nums(1, num_blocks_x, nr); +template <bool forward, typename T, typename D> +static void +rope_neox_sycl(const T *x, D *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, const int s03, + const int s1, const int s2, const int s3, const int n_dims, + const int nr, const int32_t *pos, const float freq_scale, + const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float *freq_factors, const int64_t *row_indices, + const int set_rows_stride, dpct::queue_ptr stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = + (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const dpct::dim3 block_nums(nr, n_blocks_x, 1); const float theta_scale = powf(freq_base, -2.0f / n_dims); - dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - if (freq_factors == nullptr) { - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - rope_neox<T, false>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_neox<forward, false>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, row_indices, set_rows_stride); + }); } else { - stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) { - rope_neox<T, true>(x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, - theta_scale, freq_factors, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_neox<forward, true>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, row_indices, set_rows_stride); + }); } } -template <typename T> -static void rope_multi_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, - const size_t s2, const int n_dims, const int nr, const int32_t * pos, - const float freq_scale, const float freq_base, const float ext_factor, - const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, - const mrope_sections sections, const bool is_imrope, queue_ptr stream) { - GGML_ASSERT(ne0 % 2 == 0); - const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); - const sycl::range<3> grid_dims(1, n_blocks_y, nr); - const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); - - const float theta_scale = std::pow(freq_base, -2.0f / n_dims); - // Add FP16 capability check if T could be sycl::half - if constexpr (std::is_same_v<T, sycl::half>) { - dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - } - // launch kernel +template <bool forward, typename T> +static void +rope_multi_sycl(const T *x, T *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, const int s03, + const int s1, const int s2, const int s3, const int n_dims, + const int nr, const int32_t *pos, const float freq_scale, + const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float *freq_factors, const mrope_sections sections, + const bool is_imrope, dpct::queue_ptr stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = + (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const dpct::dim3 block_nums(nr, n_blocks_x, 1); + + const float theta_scale = powf(freq_base, -2.0f / n_dims); + if (freq_factors == nullptr) { - stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { - rope_multi<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_multi<forward, false, T>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, sections, is_imrope); + }); } else { - stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { - rope_multi<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, sections, is_imrope, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_multi<forward, true, T>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, sections, is_imrope); + }); } } +template <bool forward, typename T> +static void +rope_vision_sycl(const T *x, T *dst, const int ne00, const int ne01, + const int ne02, const int s01, const int s02, const int s03, + const int s1, const int s2, const int s3, const int n_dims, + const int nr, const int32_t *pos, const float freq_scale, + const float freq_base, const float ext_factor, + const float attn_factor, const rope_corr_dims corr_dims, + const float *freq_factors, const mrope_sections sections, + dpct::queue_ptr stream) { + GGML_ASSERT(ne00 % 2 == 0); + const dpct::dim3 block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); + const int n_blocks_x = + (ne00 + 2 * SYCL_ROPE_BLOCK_SIZE - 1) / (2 * SYCL_ROPE_BLOCK_SIZE); + const dpct::dim3 block_nums(nr, n_blocks_x, 1); + const float theta_scale = powf(freq_base, -2.0f / n_dims); - -// rope vision -template <typename T> -static void rope_vision_sycl(const T * x, T * dst, const int ne0, const int ne1, const int ne2, const size_t s1, - const size_t s2, const int n_dims, const int nr, const int32_t * pos, - const float freq_scale, const float freq_base, const float ext_factor, - const float attn_factor, const rope_corr_dims corr_dims, const float * freq_factors, - const mrope_sections sections, queue_ptr stream) { - GGML_ASSERT(ne0 % 2 == 0); - const sycl::range<3> block_dims(1, SYCL_ROPE_BLOCK_SIZE, 1); - const int n_blocks_y = ceil_div(ne0, (2 * SYCL_ROPE_BLOCK_SIZE)); - const sycl::range<3> grid_dims(1, n_blocks_y, nr); - const sycl::nd_range<3> nd_range(grid_dims * block_dims, block_dims); - - const float theta_scale = std::pow(freq_base, -2.0f / n_dims); - // Add FP16 capability check if T could be sycl::half - if constexpr (std::is_same_v<T, sycl::half>) { - dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 }); - } - // launch kernel if (freq_factors == nullptr) { - stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { - rope_vision<T, false>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, sections, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_vision<forward, false, T>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, sections); + }); } else { - stream->parallel_for(nd_range, [=](sycl::nd_item<3> item_ct1) { - rope_vision<T, true>(x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, - corr_dims, theta_scale, freq_factors, sections, item_ct1); - }); + stream->parallel_for( + sycl::nd_range<3>(block_nums * block_dims, block_dims), + [=](sycl::nd_item<3> item_ct1) { + GGML_UNUSED(item_ct1); + rope_vision<forward, true, T>( + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, + pos, freq_scale, ext_factor, attn_factor, corr_dims, + theta_scale, freq_factors, sections); + }); } } -inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) { +template <bool forward> +void ggml_sycl_op_rope_impl(ggml_backend_sycl_context &ctx, ggml_tensor *dst, + const ggml_tensor *set_rows = nullptr) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + const ggml_tensor *src2 = dst->src[2]; + + const float *src0_d = (const float *)src0->data; + const float *src1_d = (const float *)src1->data; + + void *dst_d = dst->data; + const int64_t *row_indices = nullptr; + ggml_type dst_type = dst->type; + int set_rows_stride = 0; + + if (set_rows != nullptr) { + GGML_ASSERT(forward); + dst_d = set_rows->data; + row_indices = (const int64_t *)set_rows->src[1]->data; + dst_type = set_rows->type; + set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type); + } + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type || + (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16)); - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); - GGML_ASSERT(dst->src[0]->type == dst->type); - const int64_t ne00 = dst->src[0]->ne[0]; // head dims - const int64_t ne01 = dst->src[0]->ne[1]; // num heads - const int64_t ne02 = dst->src[0]->ne[2]; // num heads - const int64_t nr = ggml_nrows(dst->src[0]); + const int64_t ne00 = src0->ne[0]; // head dims + const int64_t ne01 = src0->ne[1]; // num heads + const int64_t ne02 = src0->ne[2]; // num heads + const int64_t nr = ggml_nrows(src0); - const size_t s01 = dst->src[0]->nb[1] / ggml_type_size(dst->src[0]->type); - const size_t s02 = dst->src[0]->nb[2] / ggml_type_size(dst->src[0]->type); + const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); + const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + const size_t s03 = src0->nb[3] / ggml_type_size(src0->type); + const size_t s1 = dst->nb[1] / ggml_type_size(dst->type); + const size_t s2 = dst->nb[2] / ggml_type_size(dst->type); + const size_t s3 = dst->nb[3] / ggml_type_size(dst->type); - //const int n_past = ((int32_t *) dst->op_params)[0]; - const int n_dims = ((int32_t *) dst->op_params)[1]; - const int mode = ((int32_t *) dst->op_params)[2]; - //const int n_ctx = ((int32_t *) dst->op_params)[3]; - const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + const int n_dims = ((int32_t *)dst->op_params)[1]; + const int mode = ((int32_t *)dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *)dst->op_params)[4]; mrope_sections sections; - // RoPE alteration for extended context float freq_base; float freq_scale; float ext_factor; @@ -383,13 +506,13 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) float beta_fast; float beta_slow; - memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); - memcpy(§ions.v, (int32_t *) dst->op_params + 11, sizeof(int)*4); + memcpy(&freq_base, (int32_t *)dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *)dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *)dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *)dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *)dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *)dst->op_params + 10, sizeof(float)); + memcpy(§ions.v, (int32_t *)dst->op_params + 11, sizeof(int) * 4); const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; @@ -397,82 +520,122 @@ inline void ggml_sycl_op_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst) const bool is_vision = mode == GGML_ROPE_TYPE_VISION; if (is_mrope) { - GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || sections.v[2] > 0); + GGML_ASSERT(sections.v[0] > 0 || sections.v[1] > 0 || + sections.v[2] > 0); } if (is_vision) { - GGML_ASSERT(n_dims == ne00/2); + GGML_ASSERT(n_dims == ne00 / 2); } - const int32_t * pos = (const int32_t *) dst->src[1]->data; + const int32_t *pos = (const int32_t *)src1_d; - const float * freq_factors = nullptr; - if (dst->src[2] != nullptr) { - freq_factors = (const float *) dst->src[2]->data; + const float *freq_factors = nullptr; + if (src2 != nullptr) { + freq_factors = (const float *)src2->data; } rope_corr_dims corr_dims; - ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims.v); - - dpct::queue_ptr main_stream = ctx.stream(); - SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, + beta_slow, corr_dims.v); // compute if (is_neox) { GGML_SYCL_DEBUG("%s: neox path\n", __func__); - if (dst->src[0]->type == GGML_TYPE_F32) { - rope_neox_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); - } else if (dst->src[0]->type == GGML_TYPE_F16) { - rope_neox_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, - n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, - main_stream); + if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { + rope_neox_sycl<forward, float, float>( + (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01, + s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { + rope_neox_sycl<forward, float, sycl::half>( + (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02, + s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + row_indices, set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { + rope_neox_sycl<forward, sycl::half, sycl::half>( + (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01, + ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + row_indices, set_rows_stride, stream); } else { - GGML_ABORT("fatal error"); + GGML_ABORT("Fatal error: Tensor type unsupported!"); } } else if (is_mrope && !is_vision) { GGML_SYCL_DEBUG("%s: mrope path\n", __func__); - if (dst->src[0]->type == GGML_TYPE_F16) { - rope_multi_sycl((const sycl::half *)dst->src[0]->data, (sycl::half *)dst->data, ne00, ne01, ne02, s01, - s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, sections, is_imrope, main_stream); - } else if (dst->src[0]->type == GGML_TYPE_F32) { - rope_multi_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, - is_imrope, main_stream); + if (src0->type == GGML_TYPE_F32) { + rope_multi_sycl<forward>((const float *)src0_d, (float *)dst_d, + ne00, ne01, ne02, s01, s02, s03, s1, s2, + s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, + freq_factors, sections, is_imrope, stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_multi_sycl<forward>( + (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01, + ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + sections, is_imrope, stream); } else { GGML_ABORT("Fatal error: Tensor type unsupported!"); } } else if (is_vision) { GGML_SYCL_DEBUG("%s: vision path\n", __func__); - if (dst->src[0]->type == GGML_TYPE_F16) { - rope_vision_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, ne02, s01, - s02, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, sections, main_stream); - } else if (dst->src[0]->type == GGML_TYPE_F32) { - rope_vision_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, ne02, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, - main_stream); + if (src0->type == GGML_TYPE_F32) { + rope_vision_sycl<forward>( + (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01, + s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, sections, + stream); + } else if (src0->type == GGML_TYPE_F16) { + rope_vision_sycl<forward>( + (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01, + ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + sections, stream); } else { GGML_ABORT("Fatal error: Tensor type unsupported!"); } } else { GGML_SYCL_DEBUG("%s: norm path\n", __func__); - if (dst->src[0]->type == GGML_TYPE_F32) { - rope_norm_sycl((const float *) dst->src[0]->data, (float *) dst->data, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, main_stream); - } else if (dst->src[0]->type == GGML_TYPE_F16) { - rope_norm_sycl((const sycl::half *) dst->src[0]->data, (sycl::half *) dst->data, ne00, ne01, s01, s02, - n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, freq_factors, - main_stream); + if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { + rope_norm_sycl<forward, float, float>( + (const float *)src0_d, (float *)dst_d, ne00, ne01, ne02, s01, + s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { + rope_norm_sycl<forward, float, sycl::half>( + (const float *)src0_d, (sycl::half *)dst_d, ne00, ne01, ne02, + s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + row_indices, set_rows_stride, stream); + } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { + rope_norm_sycl<forward, sycl::half, sycl::half>( + (const sycl::half *)src0_d, (sycl::half *)dst_d, ne00, ne01, + ne02, s01, s02, s03, s1, s2, s3, n_dims, nr, pos, freq_scale, + freq_base, ext_factor, attn_factor, corr_dims, freq_factors, + row_indices, set_rows_stride, stream); } else { - GGML_ABORT("fatal error"); + GGML_ABORT("Fatal error: Tensor type unsupported!"); } } } -void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +void ggml_sycl_rope(ggml_backend_sycl_context &ctx, ggml_tensor *dst) { scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); - ggml_sycl_op_rope(ctx, dst); + + ggml_sycl_op_rope_impl<true>(ctx, dst); } +void ggml_sycl_rope_back(ggml_backend_sycl_context &ctx, ggml_tensor *dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/3); + ggml_sycl_op_rope_impl<false>(ctx, dst); +} + +void ggml_sycl_rope_fused(ggml_backend_sycl_context &ctx, ggml_tensor *rope, + ggml_tensor *set_rows) { + scope_op_debug_print scope_dbg_print(__func__, rope, /*num_src=*/3); + ggml_sycl_op_rope_impl<true>(ctx, rope, set_rows); +} diff --git a/ggml/src/ggml-sycl/rope.hpp b/ggml/src/ggml-sycl/rope.hpp index 8c7141aac5c..b95a585808b 100644 --- a/ggml/src/ggml-sycl/rope.hpp +++ b/ggml/src/ggml-sycl/rope.hpp @@ -15,6 +15,12 @@ #include "common.hpp" +#define SYCL_ROPE_BLOCK_SIZE 256 + void ggml_sycl_rope(ggml_backend_sycl_context & ctx, ggml_tensor *dst); +void ggml_sycl_rope_back(ggml_backend_sycl_context & ctx, ggml_tensor * dst); + +void ggml_sycl_rope_fused(ggml_backend_sycl_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows); + #endif // GGML_SYCL_ROPE_HPP diff --git a/ggml/src/ggml-sycl/set_rows.cpp b/ggml/src/ggml-sycl/set_rows.cpp index a641c100913..8fb41943525 100644 --- a/ggml/src/ggml-sycl/set_rows.cpp +++ b/ggml/src/ggml-sycl/set_rows.cpp @@ -4,7 +4,11 @@ namespace utils { template<typename T> static constexpr bool is_arithmetic_v() { - return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>; + return std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> +#ifdef GGML_SYCL_HAS_BF16 + || std::is_same_v<T, sycl::ext::oneapi::bfloat16> +#endif + ; } } @@ -181,6 +185,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#ifdef GGML_SYCL_HAS_BF16 case GGML_TYPE_BF16: set_rows_sycl<TIn, TIdx, sycl::ext::oneapi::bfloat16>( src0_d, src1_d, (char *)dst->data, @@ -193,6 +198,7 @@ static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * s stream ); break; +#endif case GGML_TYPE_Q8_0: set_rows_sycl_q<TIdx, block_q8_0, QK8_0, cpy_blck_f32_q8_0>(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream); break; diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index b41124acc13..fdf9b843e01 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -37,7 +37,7 @@ struct soft_max_params { }; // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. -// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here. #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wpass-failed" @@ -102,7 +102,7 @@ static void soft_max_f32(const float * x, max_val = sycl::max(max_val, val); } // find the max value in the block - max_val = warp_reduce_max(max_val); + max_val = warp_reduce_max<WARP_SIZE>(max_val); if (block_size > WARP_SIZE) { if (warp_id == 0) { @@ -116,7 +116,7 @@ static void soft_max_f32(const float * x, item_ct1.barrier(); max_val = buf_iw[lane_id]; - max_val = warp_reduce_max(max_val); + max_val = warp_reduce_max<WARP_SIZE>(max_val); } float tmp = 0.0f; // partial sum @@ -133,7 +133,7 @@ static void soft_max_f32(const float * x, vals[col] = val; } // find the sum of exps in the block - tmp = warp_reduce_sum(tmp); + tmp = warp_reduce_sum<WARP_SIZE>(tmp); if (block_size > WARP_SIZE) { item_ct1.barrier(); if (warp_id == 0) { @@ -153,7 +153,7 @@ static void soft_max_f32(const float * x, for (size_t i = 1; i < nreduce; i += 1) { tmp += buf_iw[lane_id + i * WARP_SIZE]; } - tmp = warp_reduce_sum(tmp); + tmp = warp_reduce_sum<WARP_SIZE>(tmp); } if (sinks) { tmp += sycl::native::exp(sinks[i02] - max_val); @@ -191,7 +191,7 @@ static void soft_max_back_f32(const float *grad, const float *dstf, float *dst, dgf_dot += dstf[col]*grad[col]; } - dgf_dot = warp_reduce_sum(dgf_dot); + dgf_dot = warp_reduce_sum<WARP_SIZE>(dgf_dot); for (int col = tid; col < ncols; col += WARP_SIZE) { dst[col] = scale * (grad[col] - dgf_dot) * dstf[col]; diff --git a/ggml/src/ggml-sycl/solve_tri.cpp b/ggml/src/ggml-sycl/solve_tri.cpp new file mode 100644 index 00000000000..39326deee44 --- /dev/null +++ b/ggml/src/ggml-sycl/solve_tri.cpp @@ -0,0 +1,172 @@ +#include "solve_tri.hpp" +#include "common.hpp" +#include <oneapi/mkl/blas.hpp> + +template <int n_template, int k_template> +static void solve_tri_f32_fast(const float * __restrict__ A, + const float * __restrict__ B, + float * __restrict__ X, + const int64_t ne02, [[maybe_unused]] const int64_t ne03, + const int64_t nb02, const int64_t nb03, + const int64_t nb12, const int64_t nb13, + const int64_t nb2, const int64_t nb3, + const int n_arg, const int k_arg, + const sycl::nd_item<2> & item, float * sA) { + + const int n = n_template == 0 ? n_arg : n_template; + const int k = k_template == 0 ? k_arg : k_template; + + const int batch_idx = item.get_group(1); + const int lane = item.get_local_id(1) % WARP_SIZE; + const int col_idx = item.get_local_id(0); + + if (col_idx >= k) { + return; + } + + const int64_t i03 = batch_idx / ne02; + const int64_t i02 = batch_idx % ne02; + + const float * A_batch = (const float *) ((const char *) A + i02 * nb02 + i03 * nb03); + const float * B_batch = (const float *) ((const char *) B + i02 * nb12 + i03 * nb13); + float * X_batch = (float *) ((char *) X + i02 * nb2 + i03 * nb3); + + const int offset = item.get_local_id(1) + item.get_local_id(0) * item.get_local_range(1); + +#pragma unroll + for (int i = 0; i < n * n; i += k * WARP_SIZE) { + const int i0 = i + offset; + if (i0 < n * n) { + sA[i0] = A_batch[i0]; + } + } + + item.barrier(sycl::access::fence_space::local_space); + + float x_low = (lane < n) ? B_batch[lane * k + col_idx] : 0.0f; + float x_high = (WARP_SIZE + lane < n) ? B_batch[(WARP_SIZE + lane) * k + col_idx] : 0.0f; + + const int half = WARP_SIZE; + const int nrows_low = (n < half) ? n : half; + +#pragma unroll + for (int row = 0; row < nrows_low; ++row) { + float sum = 0.0f; + if (lane < row) { + sum += sA[row * n + lane] * x_low; + } + sum = warp_reduce_sum<WARP_SIZE>(sum); + if (lane == row) { + x_low = (x_low - sum) / sA[row * n + row]; + } + } + +#pragma unroll + for (int row = half; row < n; ++row) { + float sum = sA[row * n + lane] * x_low; + const int j = half + lane; + if (j < row) { + sum += sA[row * n + j] * x_high; + } + sum = warp_reduce_sum<WARP_SIZE>(sum); + if (lane == row - half) { + x_high = (x_high - sum) / sA[row * n + row]; + } + } + +#pragma unroll + for (int rr = 0; rr < 2; ++rr) { + const int row = rr * WARP_SIZE + lane; + if (row < n) { + const float val = (row < half) ? x_low : x_high; + X_batch[row * k + col_idx] = val; + } + } +} + +static void solve_tri_f32_mkl(dpct::queue_ptr stream, + const float * A, float * X, + int n, int k, + int64_t ne02, [[maybe_unused]] int64_t ne03, + int64_t nb02, [[maybe_unused]] int64_t nb03, + int64_t nb2, [[maybe_unused]] int64_t nb3) { + const float alpha = 1.0f; + const int64_t total_batches = ne02 * ne03; + if (total_batches == 0) { + return; + } + + const int64_t stride_a = nb02 / sizeof(float); + const int64_t stride_x = nb2 / sizeof(float); + + oneapi::mkl::blas::trsm_batch( + *stream, + oneapi::mkl::side::right, + oneapi::mkl::uplo::upper, + oneapi::mkl::transpose::nontrans, + oneapi::mkl::diag::nonunit, + k, n, alpha, + A, n, stride_a, + X, k, stride_x, + total_batches); +} + +inline void ggml_sycl_op_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + const int n = src0->ne[0]; + const int k = src1->ne[0]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + GGML_ASSERT(n <= SYCL_SOLVE_TRI_MAX_N && k <= SYCL_SOLVE_TRI_MAX_K); + + const float * A_d = static_cast<const float *>(src0->data); + const float * B_d = static_cast<const float *>(src1->data); + float * X_d = static_cast<float *>(dst->data); + + if (X_d != B_d) { + const int64_t total_elements = (int64_t)n * k * ne02 * ne03; + stream->memcpy(X_d, B_d, total_elements * sizeof(float)); + } + + const int64_t nb02 = src0->nb[2]; + const int64_t nb03 = src0->nb[3]; + const int64_t nb12 = src1->nb[2]; + const int64_t nb13 = src1->nb[3]; + const int64_t nb2 = dst->nb[2]; + const int64_t nb3 = dst->nb[3]; + + const int64_t total_batches = ne02 * ne03; + + if (n <= 2 * WARP_SIZE && k <= 32) { + const int smem_size = 2 * WARP_SIZE * 2 * WARP_SIZE; + const sycl::range<2> grid(1, total_batches); + const sycl::range<2> block(k, WARP_SIZE); + stream->submit([&](sycl::handler & cgh) { + sycl::local_accessor<float, 1> smem_acc(sycl::range<1>(smem_size), cgh); + cgh.parallel_for( + sycl::nd_range<2>(grid * block, block), + [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + solve_tri_f32_fast<0, 0>(A_d, B_d, X_d, ne02, ne03, + nb02, nb03, nb12, nb13, nb2, nb3, + n, k, item, get_pointer(smem_acc)); + }); + }); + } else { + solve_tri_f32_mkl(stream, A_d, X_d, n, k, ne02, ne03, nb02, nb03, nb2, nb3); + } +} + +void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_solve_tri(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/solve_tri.hpp b/ggml/src/ggml-sycl/solve_tri.hpp new file mode 100644 index 00000000000..c7c34cfa2bb --- /dev/null +++ b/ggml/src/ggml-sycl/solve_tri.hpp @@ -0,0 +1,8 @@ +#pragma once + +#include "common.hpp" + +#define SYCL_SOLVE_TRI_MAX_N 64 +#define SYCL_SOLVE_TRI_MAX_K 64 + +void ggml_sycl_solve_tri(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/ssm_conv.cpp b/ggml/src/ggml-sycl/ssm_conv.cpp index eea9a73d67e..e55223586a1 100644 --- a/ggml/src/ggml-sycl/ssm_conv.cpp +++ b/ggml/src/ggml-sycl/ssm_conv.cpp @@ -63,7 +63,7 @@ static void kernel_ssm_conv( }); } -void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { +inline void ggml_sycl_op_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { ggml_tensor * src0 = dst->src[0]; ggml_tensor * src1 = dst->src[1]; @@ -125,3 +125,8 @@ void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { throw; } } + +void ggml_sycl_ssm_conv(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2); + ggml_sycl_op_ssm_conv(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/ssm_scan.cpp b/ggml/src/ggml-sycl/ssm_scan.cpp new file mode 100644 index 00000000000..ae652981384 --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_scan.cpp @@ -0,0 +1,156 @@ +#include "ssm_scan.hpp" +#include "common.hpp" + +template <int c_factor, int d_state> +static void ssm_scan_f32_group( + const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, + const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5, + const int32_t * __restrict__ src6, float * __restrict__ dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, + const int src2_nb1, const int src2_nb2, const int src3_nb1, + const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, + const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok, + const sycl::nd_item<2> & item) { + + const int lane = item.get_local_id(1) % WARP_SIZE; + const int warp = item.get_local_id(1) / WARP_SIZE; + const int warp_idx = item.get_group(1) * c_factor + warp; + const int seq_idx = item.get_group(0); + + const int head_idx = warp_idx / d_head; + const int head_off = (warp_idx % d_head) * sizeof(float); + const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); + + const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); + const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; + float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + + const int stride_x = src1_nb2 / sizeof(float); + const int stride_dt = src2_nb1 / sizeof(float); + const int stride_B = src4_nb2 / sizeof(float); + const int stride_C = src5_nb2 / sizeof(float); + const int stride_y = n_head * d_head; + + float state[c_factor]; + float state_sum = 0.0f; + +#pragma unroll + for (int j = 0; j < c_factor; j++) { + state[j] = s0_warp[WARP_SIZE * j + lane]; + } + + for (int64_t i = 0; i < n_tok; i++) { + const float dt_val = dt_warp[i * stride_dt]; + const float dt_soft_plus = (dt_val <= 20.0f ? sycl::log1p(sycl::exp(dt_val)) : dt_val); + + state_sum = 0.0f; + const float dA = sycl::exp(dt_soft_plus * A_warp[0]); + const float x_dt = x_warp[i * stride_x] * dt_soft_plus; +#pragma unroll + for (int j = 0; j < c_factor; j++) { + const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; + const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane]; + state[j] = (state[j] * dA) + (B_val * x_dt); + state_sum += state[j] * C_val; + } + + state_sum = warp_reduce_sum<WARP_SIZE>(state_sum); + + if (lane == 0) { + y_warp[i * stride_y] = state_sum; + } + } + +#pragma unroll + for (int j = 0; j < c_factor; j++) { + s_warp[WARP_SIZE * j + lane] = state[j]; + } +} + +static void ssm_scan_f32_sycl( + const float * src0, const float * src1, const float * src2, const float * src3, + const float * src4, const float * src5, const int32_t * src6, float * dst, + const int src0_nb2, const int src0_nb3, const int src1_nb2, const int src1_nb3, const int src2_nb1, + const int src2_nb2, const int src3_nb1, const int src4_nb2, const int src4_nb3, const int src5_nb2, + const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, + const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, + dpct::queue_ptr stream) { + + // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! + GGML_ASSERT(src3_nb1 == sizeof(float)); + if (d_state == 128) { + constexpr int threads = 128; + constexpr int num_warps = threads / WARP_SIZE; + const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps); + const sycl::range<2> block(1, threads); + stream->parallel_for( + sycl::nd_range<2>(grid * block, block), + [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + ssm_scan_f32_group<128 / WARP_SIZE, 128>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item); + }); + } else if (d_state == 256) { + constexpr int threads = 256; + constexpr int num_warps = threads / WARP_SIZE; + const sycl::range<2> grid(n_seq, (n_head * head_dim + num_warps - 1) / num_warps); + const sycl::range<2> block(1, threads); + stream->parallel_for( + sycl::nd_range<2>(grid * block, block), + [=](sycl::nd_item<2> item) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { + ssm_scan_f32_group<256 / WARP_SIZE, 256>( + src0, src1, src2, src3, src4, src5, src6, dst, + src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, + src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok, item); + }); + } else { + GGML_ABORT("ssm_scan: unsupported d_state (must be 128 or 256)"); + } +} + +inline void ggml_sycl_op_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * src2 = dst->src[2]; + const ggml_tensor * src3 = dst->src[3]; + const ggml_tensor * src4 = dst->src[4]; + const ggml_tensor * src5 = dst->src[5]; + const ggml_tensor * src6 = dst->src[6]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src6->type == GGML_TYPE_I32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t nc = src0->ne[0]; + const int64_t nr = src0->ne[1]; + const int64_t nh = src1->ne[1]; + const int64_t ng = src4->ne[1]; + const int64_t n_t = src1->ne[2]; + const int64_t n_s = src1->ne[3]; + const int64_t s_off = ggml_nelements(src1) * sizeof(float); + + GGML_ASSERT(ggml_nelements(src1) + nc * nr * nh * n_s == ggml_nelements(dst)); + + dpct::queue_ptr stream = ctx.stream(); + SYCL_CHECK(ggml_sycl_set_device(ctx.device)); + + ssm_scan_f32_sycl( + static_cast<const float *>(src0->data), static_cast<const float *>(src1->data), + static_cast<const float *>(src2->data), static_cast<const float *>(src3->data), + static_cast<const float *>(src4->data), static_cast<const float *>(src5->data), + static_cast<const int32_t *>(src6->data), static_cast<float *>(dst->data), + src0->nb[2], src0->nb[3], src1->nb[2], src1->nb[3], src2->nb[1], src2->nb[2], + src3->nb[1], src4->nb[2], src4->nb[3], src5->nb[2], src5->nb[3], + s_off, nc, nr, nh, ng, n_t, n_s, stream); +} + +void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/7); + ggml_sycl_op_ssm_scan(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/ssm_scan.hpp b/ggml/src/ggml-sycl/ssm_scan.hpp new file mode 100644 index 00000000000..1f9731fb6fd --- /dev/null +++ b/ggml/src/ggml-sycl/ssm_scan.hpp @@ -0,0 +1,5 @@ +#pragma once + +#include "common.hpp" + +void ggml_sycl_ssm_scan(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/sycl_hw.cpp b/ggml/src/ggml-sycl/sycl_hw.cpp index 7041140034b..03b0c37a3cd 100644 --- a/ggml/src/ggml-sycl/sycl_hw.cpp +++ b/ggml/src/ggml-sycl/sycl_hw.cpp @@ -1,15 +1,67 @@ #include "sycl_hw.hpp" -// TODO: currently not used -/* -sycl_hw_info get_device_hw_info(sycl::device *device_ptr) { - sycl_hw_info res; - int32_t id = device_ptr->get_info<sycl::ext::intel::info::device::device_id>(); - res.device_id = id; +using namespace std; - syclex::architecture arch = device_ptr->get_info<syclex::info::device::architecture>(); - res.arch = arch; +/*defined in +* /opt/intel/oneapi/compiler/latest/include/sycl/ext/oneapi/experimental/device_architecture.def +*/ +static map<gpu_arch, std::pair<const char*, sycl_intel_gpu_family>> arch2name = { + {gpu_arch::intel_gpu_bdw, {"intel_gpu_bdw", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_skl, {"intel_gpu_skl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_kbl, {"intel_gpu_kbl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_cfl, {"intel_gpu_cfl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_apl, {"intel_gpu_apl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_glk, {"intel_gpu_glk", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_whl, {"intel_gpu_whl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_aml, {"intel_gpu_aml", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_cml, {"intel_gpu_cml", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_icllp, {"intel_gpu_icllp", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_ehl, {"intel_gpu_ehl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_tgllp, {"intel_gpu_tgllp", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_rkl, {"intel_gpu_rkl", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_s, {"intel_gpu_adl_s", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_p, {"intel_gpu_adl_p", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_adl_n, {"intel_gpu_adl_n", GPU_FAMILY_IGPU_NON_XE}}, + {gpu_arch::intel_gpu_dg1, {"intel_gpu_dg1", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g10, {"intel_gpu_acm_g10", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g11, {"intel_gpu_acm_g11", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_acm_g12, {"intel_gpu_acm_g12", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_pvc, {"intel_gpu_pvc", GPU_FAMILY_DGPU_CLOUD}}, + {gpu_arch::intel_gpu_pvc_vg, {"intel_gpu_pvc_vg", GPU_FAMILY_DGPU_CLOUD}}, + {gpu_arch::intel_gpu_mtl_u, {"intel_gpu_mtl_u", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_mtl_h, {"intel_gpu_mtl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_arl_h, {"intel_gpu_arl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_bmg_g21, {"intel_gpu_bmg_g21", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_bmg_g31, {"intel_gpu_bmg_g31", GPU_FAMILY_DGPU_CLIENT_GAME}}, + {gpu_arch::intel_gpu_lnl_m, {"intel_gpu_lnl_m", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_ptl_h, {"intel_gpu_ptl_h", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_ptl_u, {"intel_gpu_ptl_u", GPU_FAMILY_IGPU_XE}}, + {gpu_arch::intel_gpu_wcl, {"intel_gpu_wcl", GPU_FAMILY_IGPU_XE}} +}; + + +sycl_hw_info get_device_hw_info(sycl::device* device_ptr) { + sycl_hw_info res; + int32_t id = + device_ptr->get_info<sycl::ext::intel::info::device::device_id>(); + res.device_id = id; + + res.name = device_ptr->get_info<sycl::info::device::name>(); - return res; + syclex::architecture arch = + device_ptr->get_info<syclex::info::device::architecture>(); + res.arch = arch; + + map<syclex::architecture, + std::pair<const char*, sycl_intel_gpu_family>>::iterator it = + arch2name.find(res.arch); + if (it != arch2name.end()) { + res.arch_name = it->second.first; + res.gpu_family = it->second.second; + } else { + res.arch_name = "unknown"; + res.gpu_family = GPU_FAMILY_UKNOWN; + } + + return res; } -*/ diff --git a/ggml/src/ggml-sycl/sycl_hw.hpp b/ggml/src/ggml-sycl/sycl_hw.hpp index 36b140bf037..a5d20462572 100644 --- a/ggml/src/ggml-sycl/sycl_hw.hpp +++ b/ggml/src/ggml-sycl/sycl_hw.hpp @@ -9,18 +9,30 @@ #include <sycl/sycl.hpp> namespace syclex = sycl::ext::oneapi::experimental; +using gpu_arch = sycl::ext::oneapi::experimental::architecture; + +// It's used to mark the GPU computing capacity +// The value must flow the order of performance. +enum sycl_intel_gpu_family { + GPU_FAMILY_UKNOWN = -1, + // iGPU without Xe core, before Meteor Lake iGPU(Xe) + GPU_FAMILY_IGPU_NON_XE = 0, + // iGPU with Xe core, Meteor Lake iGPU or newer. + GPU_FAMILY_IGPU_XE = 1, + // dGPU for gaming in client/data center (DG1/FLex 140 or newer). + GPU_FAMILY_DGPU_CLIENT_GAME = 2, + // dGPU for AI in cloud, PVC or newer. + GPU_FAMILY_DGPU_CLOUD = 3 +}; -// TODO: currently not used -/* struct sycl_hw_info { syclex::architecture arch; + const char* arch_name; int32_t device_id; + std::string name; + sycl_intel_gpu_family gpu_family; }; -bool is_in_vector(std::vector<int> &vec, int item); - sycl_hw_info get_device_hw_info(sycl::device *device_ptr); -*/ - #endif // SYCL_HW_HPP diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp new file mode 100644 index 00000000000..5c06d42fdbd --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(112, 112); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp new file mode 100644 index 00000000000..f74e1202b83 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(128, 128); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp new file mode 100644 index 00000000000..b574fe9308d --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(256, 256); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp new file mode 100644 index 00000000000..8c8fb692c43 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(40, 40); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp new file mode 100644 index 00000000000..9a6a1877566 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp @@ -0,0 +1,6 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(512, 512); + diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp new file mode 100644 index 00000000000..f218552e85f --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(576, 512); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp new file mode 100644 index 00000000000..99303a53a3c --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(64, 64); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp new file mode 100644 index 00000000000..50592768afd --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(72, 72); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp new file mode 100644 index 00000000000..74f1ea5e90c --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(80, 80); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp new file mode 100644 index 00000000000..cefb46dddc7 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.hpp" + +DECL_FATTN_TILE_CASE(96, 96); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp new file mode 100644 index 00000000000..43ef94c118c --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp new file mode 100644 index 00000000000..9404061d456 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp new file mode 100644 index 00000000000..a8bb9f52d0c --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp new file mode 100644 index 00000000000..7d61f6ab0af --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp new file mode 100644 index 00000000000..753bae09f83 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp new file mode 100644 index 00000000000..546a93b2570 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_F16, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp new file mode 100644 index 00000000000..53c8c2f2654 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp new file mode 100644 index 00000000000..5b409c55f21 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp new file mode 100644 index 00000000000..8c4ef588d63 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp new file mode 100644 index 00000000000..83f0a07552e --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp new file mode 100644 index 00000000000..9df9b03bba4 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp new file mode 100644 index 00000000000..6980c2a65bb --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp new file mode 100644 index 00000000000..bd61bc1dc2b --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp new file mode 100644 index 00000000000..492e229a58e --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp new file mode 100644 index 00000000000..30f88a2ebd5 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp new file mode 100644 index 00000000000..db76663604e --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp new file mode 100644 index 00000000000..1dbcc8a85a8 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp new file mode 100644 index 00000000000..d30996a6259 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp new file mode 100644 index 00000000000..bc0f635d922 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp new file mode 100644 index 00000000000..9e0378107cb --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp new file mode 100644 index 00000000000..a8535ac9156 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp new file mode 100644 index 00000000000..43d4fae9a61 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp new file mode 100644 index 00000000000..23335a41640 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp new file mode 100644 index 00000000000..52550a33757 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp new file mode 100644 index 00000000000..4651f14c050 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp new file mode 100644 index 00000000000..2310fd8792c --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp new file mode 100644 index 00000000000..d2494048bc1 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp new file mode 100644 index 00000000000..be3a1fe97f5 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp new file mode 100644 index 00000000000..be0a89409ca --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp new file mode 100644 index 00000000000..6781efcb0d2 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp new file mode 100644 index 00000000000..43a70ae3543 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_F16); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp new file mode 100644 index 00000000000..fa7eb8163ca --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp new file mode 100644 index 00000000000..79d9cfbee96 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp new file mode 100644 index 00000000000..86befd5d327 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp new file mode 100644 index 00000000000..c2f619b0b16 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1); diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp new file mode 100644 index 00000000000..7cf31f8b8a1 --- /dev/null +++ b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp @@ -0,0 +1,8 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-vec.hpp" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(512, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-sycl/type.hpp b/ggml/src/ggml-sycl/type.hpp new file mode 100644 index 00000000000..d7ff89d7d42 --- /dev/null +++ b/ggml/src/ggml-sycl/type.hpp @@ -0,0 +1,112 @@ +#pragma once + +#include <sycl/sycl.hpp> +#include <cstdint> +#include <limits> + +inline uint8_t float_to_e4m3(float f) +{ + if (sycl::isnan(f)) { + return 0x7F; // Canonical NaN (positive) + } + + uint32_t bits = sycl::bit_cast<uint32_t>(f); + uint32_t sign = (bits >> 31) & 0x1u; + uint32_t exp = (bits >> 23) & 0xFFu; + uint32_t mant = bits & 0x7FFFFFu; + + // Zero + if (exp == 0 && mant == 0) { + return static_cast<uint8_t>(sign << 7); + } + + // Extract biased exponent and mantissa for FP8 + int e = static_cast<int>(exp) - 127; // true exponent (IEEE bias 127) + uint32_t m = mant; + + // Handle very large values → NaN (NVIDIA behavior for E4M3) + if (e > 7) { // max exponent for E4M3 is 7 (biased 14) + return static_cast<uint8_t>((sign << 7) | 0x7F); + } + + // Handle subnormals and normal numbers + if (e < -6) { // smallest normal exponent is -6 + // Subnormal in FP8: shift mantissa right + int shift = -6 - e; + m = (m | 0x800000u) >> (shift + 1); // +1 because we lose the implicit 1 position + if (shift > 23) m = 0; + } else { + // Normal number: adjust exponent bias from 127 to 7 + int new_exp = e + 7; + m = (m >> 20) & 0x7u; // take top 3 mantissa bits (after implicit 1) + m |= (static_cast<uint32_t>(new_exp) << 3); + } + + // Round-to-nearest-even (simple guard + round bit) + // For better accuracy you can add sticky bit, but this is sufficient for most use cases + uint32_t round_bit = (mant >> 19) & 0x1u; // bit after the 3 mantissa bits + if (round_bit) { + m += 1; + // Carry into exponent if mantissa overflows + if ((m & 0x8u) != 0) { + m = (m & 0x7u) | ((m & 0x38u) << 1); // simple carry handling + // If exponent overflows after carry → NaN + if ((m >> 3) > 14) { + return static_cast<uint8_t>((sign << 7) | 0x7F); + } + } + } + + uint8_t result = static_cast<uint8_t>((sign << 7) | (m & 0x7F)); + return result; +} + +inline float e4m3_to_float(uint8_t x) +{ + if (x == 0) return 0.0f; + + uint8_t sign = (x >> 7) & 0x1u; + uint8_t exp = (x >> 3) & 0xFu; + uint8_t mant = x & 0x7u; + + // NaN (NVIDIA uses 0x7F / 0xFF as NaN) + if (exp == 0xF && mant != 0) { + return std::numeric_limits<float>::quiet_NaN(); + } + if (exp == 0xF) { // 0x7F or 0xFF treated as NaN + return std::numeric_limits<float>::quiet_NaN(); + } + + float val; + + if (exp == 0) { + // Subnormal + val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f); + } else { + // Normal: implicit leading 1 + bias 7 + val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast<float>(exp) - 7.0f); + } + + return sign ? -val : val; +} + +// The actual type definition +struct __nv_fp8_e4m3 { + uint8_t raw; + + __nv_fp8_e4m3() = default; + + explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {} + explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast<float>(h))) {} + + operator float() const { return e4m3_to_float(raw); } + operator sycl::half() const { return static_cast<sycl::half>(static_cast<float>(*this)); } + + // Allow direct access for vector loads/stores + operator uint8_t&() { return raw; } + operator uint8_t() const { return raw; } +}; + +using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>; +using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>; + diff --git a/ggml/src/ggml-sycl/upscale.cpp b/ggml/src/ggml-sycl/upscale.cpp new file mode 100644 index 00000000000..e42cb419d83 --- /dev/null +++ b/ggml/src/ggml-sycl/upscale.cpp @@ -0,0 +1,410 @@ +#include "upscale.hpp" + +static void upscale_f32(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne10, const int ne11, const int ne12, const int ne13, + const float sf0, const float sf1, const float sf2, const float sf3) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + int index = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2); + if (index >= ne10 * ne11 * ne12 * ne13) { + return; + } + + int i10 = index % ne10; + int i11 = (index / ne10) % ne11; + int i12 = (index / (ne10 * ne11)) % ne12; + int i13 = (index / (ne10 * ne11 * ne12)) % ne13; + + int i00 = i10 / sf0; + int i01 = i11 / sf1; + int i02 = i12 / sf2; + int i03 = i13 / sf3; + + dst[index] = *((const float*)((const char*)x + i03 * nb03 + i02 * nb02 + + i01 * nb01 + i00 * nb00)); +} + +static void upscale_f32_bilinear(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + int y0_src = (int) sycl::floor((float) y_src_f); + int y1_src = y0_src + 1; + + y0_src = sycl::max(0, sycl::min(y0_src, ne01_src - 1)); + y1_src = sycl::max(0, sycl::min(y1_src, ne01_src - 1)); + + float dy = y_src_f - (float)y0_src; + dy = sycl::max(0.0f, sycl::min(dy, 1.0f)); + + float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + int x0_src = (int) sycl::floor(x_src_f); + int x1_src = x0_src + 1; + + x0_src = sycl::max(0, sycl::min(x0_src, ne00_src - 1)); + x1_src = sycl::max(0, sycl::min(x1_src, ne00_src - 1)); + + float dx = x_src_f - (float)x0_src; + dx = sycl::max(0.0f, sycl::min(dx, 1.0f)); + + const float* p_a = + (const float*)((const char*)x + (int64_t)x0_src * nb00 + + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_b = + (const float*)((const char*)x + (int64_t)x1_src * nb00 + + (int64_t)y0_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_c = + (const float*)((const char*)x + (int64_t)x0_src * nb00 + + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + const float* p_d = + (const float*)((const char*)x + (int64_t)x1_src * nb00 + + (int64_t)y1_src * nb01 + (int64_t)i02_src * nb02 + + (int64_t)i03_src * nb03); + + const float val_a = *p_a; + const float val_b = *p_b; + const float val_c = *p_c; + const float val_d = *p_d; + + float result = val_a * (1.0f - dx) * (1.0f - dy) + + val_b * dx * (1.0f - dy) + + val_c * (1.0f - dx) * dy + + val_d * dx * dy; + + dst[index] = result; +} + +// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) +// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp +static void upscale_f32_bilinear_antialias(const float * src0, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y = ((float)i11_dst + pixel_offset) / sf1; + const float x = ((float)i10_dst + pixel_offset) / sf0; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = sycl::max(1.0f / sf1, 1.0f); + const float invscale1 = 1.0f / support1; + const float support0 = sycl::max(1.0f / sf0, 1.0f); + const float invscale0 = 1.0f / support0; + + // the range of source pixels that contribute + const int64_t x_min = sycl::max(int64_t(0), int64_t(x - support0 + pixel_offset)); + const int64_t x_max = sycl::min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset)); + const int64_t y_min = sycl::max(int64_t(0), int64_t(y - support1 + pixel_offset)); + const int64_t y_max = sycl::min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + auto triangle_filter = [](float x) -> float { + return sycl::max(1.0f - sycl::fabs(x), 0.0f); + }; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = + *(const float*)((const char*)src0 + sx * nb00 + sy * nb01 + + i02_src * nb02 + i03_src * nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + dst[index] = val; +} + +namespace bicubic_interpolation { +static float weight1(float x, const float &a) { return ((a + 2) * x - (a + 3)) * x * x + 1; }; +static float weight2(float x, const float &a) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; }; + +static float bicubic(float p0, float p1, float p2, float p3, float x, float a) { + const float w0 = weight2(x + 1, a); + const float w1 = weight1(x + 0, a); + const float w2 = weight1(1 - x, a); + const float w3 = weight2(2 - x, a); + return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3; +}; + +} + +static void upscale_f32_bicubic(const float * x, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>(); + const float a = -0.75f; + using bicubic_interpolation::bicubic; + + const int64_t index = item_ct1.get_local_id(2) + + item_ct1.get_group(2) * item_ct1.get_local_range(2); + const int64_t dst_total_elements = + ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; + const int y0_src = (int) sycl::floor((float) y_src_f); + const float dy = y_src_f - (float)y0_src; + + const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset; + const int x0_src = (int) sycl::floor((float) x_src_f); + const float dx = x_src_f - (float)x0_src; + + const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03; + + auto load = [=](int x_off, int y_off) -> float { + int i00_src = sycl::max(0, sycl::min(x0_src + x_off, ne00_src - 1)); + int i01_src = sycl::max(0, sycl::min(y0_src + y_off, ne01_src - 1)); + return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01); + }; + + const float result = bicubic( + bicubic(load(-1, -1), load(0, -1), load(1, -1), load(2, -1), dx, a), + bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx, a), + bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx, a), + bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx, a), + dy, + a); + + dst[index] = result; +} + +static void upscale_f32_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne10, + const int ne11, + const int ne12, + const int ne13, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10 * ne11 * ne12 * ne13; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> /*item_ct1*/) { + upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3); + }); +} + +static void upscale_f32_bilinear_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset, + bool antialias, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + if (antialias) { + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> /*item_ct1*/) { + upscale_f32_bilinear_antialias( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, + ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + } else { + stream->parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> /*item_ct1*/) { + upscale_f32_bilinear( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, + ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + } +} + +static void upscale_f32_bicubic_sycl(const float * x, + float * dst, + const int nb00, + const int nb01, + const int nb02, + const int nb03, + const int ne00_src, + const int ne01_src, + const int ne10_dst, + const int ne11_dst, + const int ne12_dst, + const int ne13_dst, + const float sf0, + const float sf1, + const float sf2, + const float sf3, + const float pixel_offset, + dpct::queue_ptr stream) { + const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + const int64_t num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE; + + { + stream->submit([&](sycl::handler & cgh) { + cgh.parallel_for( + sycl::nd_range<3>( + sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE), + sycl::range<3>(1, 1, SYCL_UPSCALE_BLOCK_SIZE)), + [=](sycl::nd_item<3> /*item_ct1*/) { + upscale_f32_bicubic( + x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, + ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + }); + }); + } +} + +void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + dpct::queue_ptr stream = ctx.stream(); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int mode_flags = dst->op_params[0]; + const ggml_scale_mode mode = (ggml_scale_mode)(mode_flags & 0xFF); + + float sf0 = (float)dst->ne[0]/src0->ne[0]; + float sf1 = (float)dst->ne[1]/src0->ne[1]; + float sf2 = (float)dst->ne[2]/src0->ne[2]; + const float sf3 = (float)dst->ne[3]/src0->ne[3]; + + float pixel_offset = 0.5f; + if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 + ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) + : sf0; + sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 + ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) + : sf1; + pixel_offset = 0.0f; + } + + if (mode == GGML_SCALE_MODE_NEAREST) { + upscale_f32_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); + } else if (mode == GGML_SCALE_MODE_BILINEAR) { + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); + upscale_f32_bilinear_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, antialias, stream); + } else if (mode == GGML_SCALE_MODE_BICUBIC) { + upscale_f32_bicubic_sycl( + src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], + src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + sf0, sf1, sf2, sf3, pixel_offset, stream); + } +} + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { + scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1); + ggml_sycl_op_upscale(ctx, dst); +} diff --git a/ggml/src/ggml-sycl/upscale.hpp b/ggml/src/ggml-sycl/upscale.hpp new file mode 100644 index 00000000000..c36c1bdc970 --- /dev/null +++ b/ggml/src/ggml-sycl/upscale.hpp @@ -0,0 +1,9 @@ +#pragma once + +#include <sycl/sycl.hpp> +#include "dpct/helper.hpp" +#include "common.hpp" + +#define SYCL_UPSCALE_BLOCK_SIZE 256 + +void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 43482b3672c..4b58b09ab2c 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -15,6 +15,7 @@ #include "dpct/helper.hpp" #include "ggml.h" +#include "type.hpp" #include "quants.hpp" typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, @@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) { return x32; } +static __dpct_inline__ int get_int_b2(const void * x, const int & i32) { + const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment + + int x32 = x16[2*i32 + 0] << 0; + x32 |= x16[2*i32 + 1] << 16; + + return x32; +} + +static __dpct_inline__ int get_int_b4(const void * x, const int & i32) { + return ((const int *) x)[i32]; // assume at least 4 byte alignment +} static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) { const uint16_t* x16 = @@ -72,6 +85,32 @@ static __dpct_inline__ int get_int_from_uint8_aligned( (const int*)(x8 + sizeof(int) * i32)); // assume at least 4 byte alignment } +static __dpct_inline__ int byte_sub_4(const int a, const int b) { + const uint32_t ua = static_cast<uint32_t>(a); + const uint32_t ub = static_cast<uint32_t>(b); + return static_cast<int>(((ua | 0x80808080u) - ub) ^ 0x80808080u); +} + +static __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq_scalar( + const int vl, const int vh, const int u0, const int u1, const int8_t sc0, + const int8_t sc1, const float d, const float d80, const float d81) { + static_assert(QR6_K == 2, "q6_K MMVQ scalar fast path assumes QR6_K == 2"); + + const int vil0 = (vl >> 0) & 0x0F0F0F0F; + const int vih0 = ((vh >> 0) << 4) & 0x30303030; + const int vi0 = byte_sub_4(vil0 | vih0, 0x20202020); + + const int vil1 = (vl >> 4) & 0x0F0F0F0F; + const int vih1 = ((vh >> 4) << 4) & 0x30303030; + const int vi1 = byte_sub_4(vil1 | vih1, 0x20202020); + + const float sumf = + d80 * (dpct::dp4a(vi0, u0, 0) * sc0) + + d81 * (dpct::dp4a(vi1, u1, 0) * sc1); + + return d * sumf; +} + static __dpct_inline__ void get_int_from_table_16(const uint32_t &q4, const uint8_t *values, int &val1, int &val2) { @@ -266,24 +305,8 @@ vec_dot_q6_K_q8_1_impl_mmvq(const int &vl, const int &vh, const int *__restrict__ u, const int8_t *__restrict__ scales, const float &d, const float *__restrict__ d8) { - - float sumf = 0.0f; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4*i]; - - const int vil = (vl >> (4*i)) & 0x0F0F0F0F; - - const int vih = ((vh >> (4*i)) << 4) & 0x30303030; - - const int vi = dpct::vectorized_binary<sycl::char4>( - (vil | vih), 0x20202020, dpct::sub_sat()); // vi = (vil | vih) - 32 - - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product - } - - return d*sumf; + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]); } // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called @@ -338,6 +361,74 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_0> { }; }; +template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q8_0> { + static constexpr ggml_type gtype = GGML_TYPE_Q8_0; + + using q8_0_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q8_0>; + using q8_0_traits = typename q8_0_block::traits; + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset, + const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * base = static_cast<const uint8_t *>(vbq); + const int8_t * qs = reinterpret_cast<const int8_t *>(base + ibx_offset.first); + const ggml_half d = *reinterpret_cast<const ggml_half *>(base + d_offset.first); + + int v[q8_0_traits::vdr_mmvq]; + int u[q8_0_traits::vdr_mmvq]; + +#pragma unroll + for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { + v[i] = get_int_from_int8(qs, iqs + i); + u[i] = get_int_from_int8_aligned(q8_1_quant_ptr, iqs + i); + } + + int sumi = 0; +#pragma unroll + for (size_t i = 0; i < q8_0_traits::vdr_mmvq; ++i) { + sumi = dpct::dp4a(v[i], u[i], sumi); + } + + const sycl::half2 ds_values = *q8_1_ds; + return static_cast<float>(d) * static_cast<float>(ds_values[0]) * sumi; + } +}; + +template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q3_K> { + static constexpr ggml_type gtype = GGML_TYPE_Q3_K; + + using q3_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q3_K>; + using q3_k_traits = typename q3_k_block::traits; + + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset, + const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * base = static_cast<const uint8_t *>(vbq); + const uint8_t * qs = base + ibx_offset.first; + const uint8_t * hmask = base + ibx_offset.second; + const uint8_t * scales = base + d_offset.first; + const ggml_half d = *reinterpret_cast<const ggml_half *>(base + d_offset.second); + + const int bq8_offset = QR3_K * (iqs / (QI3_K / 2)); + const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1 / 2); + + const int vl = get_int_from_uint8(qs, iqs); + const int vh = ~get_int_from_uint8(hmask, iqs % (QI3_K / 2)) >> bq8_offset; + + int u[QR3_K]; + float d8[QR3_K]; + +#pragma unroll + for (int i = 0; i < QR3_K; ++i) { + const int8_t * quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + u[i] = get_int_from_int8_aligned(quant_base_ptr, iqs % QI8_1); + d8[i] = (*(q8_1_ds + bq8_offset + i))[0]; + } + + return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, scales, scale_offset, static_cast<float>(d), d8); + } +}; + static inline float vec_dot_q4_K_q8_1_common(const int * __restrict__ q4, const uint16_t * __restrict__ scales, const ggml_half2 & dm, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { @@ -428,32 +519,76 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q4_K> { } }; -template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> { - static constexpr ggml_type gtype = GGML_TYPE_Q6_K; +template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q5_K> { + static constexpr ggml_type gtype = GGML_TYPE_Q5_K; - using q6_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>; - using q6_k_traits = typename q6_k_block::traits; + using q5_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q5_K>; + using q5_k_traits = typename q5_k_block::traits; - __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u, - const int8_t * __restrict__ scales, const float d, - const float * __restrict__ d8) { - float sumf = 0.0f; + __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset, + const std::pair<int, int> d_offset, const int8_t * q8_1_quant_ptr, + const sycl::half2 * q8_1_ds, const int & iqs) { + const uint8_t * base = static_cast<const uint8_t *>(vbq); + const uint8_t * qs = base + ibx_offset.first; // low 4 bits + const uint8_t * qh_base = base + ibx_offset.second; // high bit + const uint8_t * scs = base + d_offset.first; + const ggml_half2 * dms = reinterpret_cast<const ggml_half2 *>(base + d_offset.second); -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - const int sc = scales[4 * i]; + const int bq8_offset = QR5_K * ((iqs / 2) / (QI8_1 / 2)); + const int * ql_ptr = (const int *) (qs + 16 * bq8_offset + 4 * ((iqs / 2) % 4)); + const int * qh_ptr = (const int *) (qh_base + 4 * ((iqs / 2) % 4)); + const uint16_t * scales = (const uint16_t *) scs; - const int vil = (vl >> (4 * i)) & 0x0F0F0F0F; + int vl[2]; + int vh[2]; + int u[2 * QR5_K]; + float d8[QR5_K]; - const int vih = ((vh >> (4 * i)) << 4) & 0x30303030; + vl[0] = ql_ptr[0]; + vl[1] = ql_ptr[4]; - const int vi = dpct::vectorized_binary<sycl::char4>((vil | vih), 0x20202020, - dpct::sub_sat()); // vi = (vil | vih) - 32 + vh[0] = qh_ptr[0] >> bq8_offset; + vh[1] = qh_ptr[4] >> bq8_offset; - sumf += d8[i] * (dpct::dp4a(vi, u[i], 0) * sc); // SIMD dot product + uint16_t aux[2]; + const int j = (QR5_K * ((iqs / 2) / (QI8_1 / 2))) / 2; + if (j < 2) { + aux[0] = scales[j + 0] & 0x3f3f; + aux[1] = scales[j + 2] & 0x3f3f; + } else { + aux[0] = ((scales[j + 2] >> 0) & 0x0f0f) | ((scales[j - 2] & 0xc0c0) >> 2); + aux[1] = ((scales[j + 2] >> 4) & 0x0f0f) | ((scales[j - 0] & 0xc0c0) >> 2); } - return d * sumf; + const uint8_t * sc = (const uint8_t *) aux; + const uint8_t * m = sc + 2; + + for (int i = 0; i < QR5_K; ++i) { + const int8_t* quant_base_ptr = q8_1_quant_ptr + (bq8_offset + i) * QK8_1; + sycl::half2 ds_values = *(q8_1_ds + bq8_offset + i); + + d8[i] = ds_values[0]; + + const int * q8 = (const int *) quant_base_ptr + ((iqs / 2) % 4); + u[2 * i + 0] = q8[0]; + u[2 * i + 1] = q8[4]; + } + + return vec_dot_q5_K_q8_1_impl_vmmq(vl, vh, u, sc, m, *dms, d8); + } +}; + +template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> { + static constexpr ggml_type gtype = GGML_TYPE_Q6_K; + + using q6_k_block = ggml_sycl_reordered::block_q_t<GGML_TYPE_Q6_K>; + using q6_k_traits = typename q6_k_block::traits; + + __dpct_inline__ float vec_dot_q6_K_q8_1_impl_mmvq(const int vl, const int vh, const int * __restrict__ u, + const int8_t * __restrict__ scales, const float d, + const float * __restrict__ d8) { + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u[0], u[1], scales[0], scales[4], d, d8[0], d8[1]); } __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair<int, int> ibx_offset, @@ -474,16 +609,15 @@ template <> struct reorder_vec_dot_q_sycl<GGML_TYPE_Q6_K> { const int8_t * scs = scales + scale_offset; - int u[QR6_K]; - float d8[QR6_K]; + const int u0 = get_int_from_int8_aligned( + q8_1_quant_ptr + bq8_offset * QK8_1, iqs % QI8_1); + const int u1 = get_int_from_int8_aligned( + q8_1_quant_ptr + (bq8_offset + 2) * QK8_1, iqs % QI8_1); + const float d80 = (*(q8_1_ds + bq8_offset + 0))[0]; + const float d81 = (*(q8_1_ds + bq8_offset + 2))[0]; -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(q8_1_quant_ptr + (bq8_offset + 2 * i) * QK8_1, iqs % QI8_1); - const sycl::half2 ds_values = *(q8_1_ds + bq8_offset + 2 * i); - d8[i] = ds_values[0]; - } - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scs, *d, d8); + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u0, u1, scs[0], scs[4], *d, d80, d81); } }; #define VDR_Q4_0_Q8_1_MMVQ 2 @@ -650,6 +784,19 @@ static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u, return d8_0*d8_1 * sumi; } +template <typename T, int vdr> +static __dpct_inline__ T vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const T & d8_0, const T & d8_1) { + int sumi = 0; + +#pragma unroll + for (int i = 0; i < vdr; ++i) { + // SIMD dot product of quantized values + sumi = ggml_sycl_dp4a(v[i], u[i], sumi); + } + + return d8_0*d8_1 * ((T) sumi); +} + template <int vdr> static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u, const sycl::half2 &dm8, @@ -742,6 +889,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq, return d * sumi; } +#define VDR_NVFP4_Q8_1_MMVQ 4 +#define VDR_NVFP4_Q8_1_MMQ 8 + +static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq, + const block_q8_1 * __restrict__ bq8_1, + const int32_t & iqs) { + const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq; + float sum = 0.0f; +#pragma unroll + for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) { + const int32_t iqs0 = iqs + 2*i; + const int32_t iqs1 = iqs0 + 1; + const int32_t is = iqs0 >> 1; + const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4); + const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4); + const block_q8_1 * bq8 = bq8_1 + (is >> 1); + const int32_t i8 = ((is & 1) << 2); + + int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0); + sumi = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi); + sumi = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi); + sumi = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi); + + const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0]; + sum += d * float(sumi); + } + + return sum; +} static __dpct_inline__ float vec_dot_q5_0_q8_1(const void *__restrict__ vbq, @@ -1020,16 +1196,15 @@ vec_dot_q6_K_q8_1(const void *__restrict__ vbq, const int8_t * scales = bq6_K->scales + scale_offset; - int u[QR6_K]; - float d8[QR6_K]; - -#pragma unroll - for (int i = 0; i < QR6_K; ++i) { - u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); - d8[i] = bq8_1[bq8_offset + 2 * i].ds[0]; - } + const int u0 = get_int_from_int8_aligned( + bq8_1[bq8_offset + 0].qs, iqs % QI8_1); + const int u1 = get_int_from_int8_aligned( + bq8_1[bq8_offset + 2].qs, iqs % QI8_1); + const float d80 = bq8_1[bq8_offset + 0].ds[0]; + const float d81 = bq8_1[bq8_offset + 2].ds[0]; - return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); + return vec_dot_q6_K_q8_1_impl_mmvq_scalar( + vl, vh, u0, u1, scales[0], scales[4], bq6_K->d, d80, d81); } diff --git a/ggml/src/ggml-sycl/wkv.cpp b/ggml/src/ggml-sycl/wkv.cpp index c10e2f7645e..b56e0c2400f 100644 --- a/ggml/src/ggml-sycl/wkv.cpp +++ b/ggml/src/ggml-sycl/wkv.cpp @@ -1,7 +1,7 @@ #include <sycl/sycl.hpp> #include "wkv.hpp" -constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE +constexpr int WKV_BLOCK_SIZE = 64; // Helper function for the main kernel template <int block_size> diff --git a/ggml/src/ggml-virtgpu/CMakeLists.txt b/ggml/src/ggml-virtgpu/CMakeLists.txt new file mode 100644 index 00000000000..e6b020beb5b --- /dev/null +++ b/ggml/src/ggml-virtgpu/CMakeLists.txt @@ -0,0 +1,70 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + +include(ExternalProject) + +message(STATUS "Including the VirtGPU/Virglrenderer API Remoting") + +# Download venus_hw.h from virglrenderer repository +ExternalProject_Add( + venus_hw_header + URL https://gitlab.freedesktop.org/virgl/virglrenderer/-/raw/virglrenderer-1.2.0/src/venus_hw.h + DOWNLOAD_NO_EXTRACT YES + DOWNLOAD_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include + DOWNLOAD_NAME venus_hw.h + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + LOG_DOWNLOAD ON +) + +if (NOT GGML_VIRTGPU_BACKEND STREQUAL "ONLY") + message(STATUS "Enable the VirtGPU/Virglrenderer API Remoting frontend library") + + find_package(PkgConfig REQUIRED) + pkg_check_modules(DRM REQUIRED libdrm) + if (NOT GGML_BACKEND_DL) + # cannot simply use USE_VIRTGPU, as in the 'else()' case the + # frontend isn't compiled + target_compile_definitions(ggml PUBLIC "GGML_USE_VIRTGPU_FRONTEND") + endif() + + ggml_add_backend_library(ggml-virtgpu + ggml-backend-buffer.cpp + ggml-backend.cpp + ggml-backend-device.cpp + ggml-backend-reg.cpp + ggml-backend-buffer-type.cpp + virtgpu-apir.h + virtgpu-forward.gen.h + virtgpu.cpp + virtgpu-shm.cpp + virtgpu-utils.cpp + virtgpu-forward-device.cpp + virtgpu-forward-buffer-type.cpp + virtgpu-forward-buffer.cpp + virtgpu-forward-backend.cpp + virtgpu-forward-impl.h + apir_cs_ggml-rpc-front.cpp + ../../include/ggml-virtgpu.h) + + target_include_directories(ggml-virtgpu PUBLIC /usr/include/libdrm/) + + target_link_libraries(ggml-virtgpu PUBLIC ${DRM_LIBRARIES}) + target_include_directories(ggml-virtgpu PUBLIC ${DRM_INCLUDE_DIRS}) + target_compile_options(ggml-virtgpu PUBLIC ${DRM_CFLAGS_OTHER}) + + target_include_directories(ggml-virtgpu PUBLIC ./include) + target_include_directories(ggml-virtgpu PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + + # Ensure venus_hw.h is downloaded before building ggml-virtgpu + add_dependencies(ggml-virtgpu venus_hw_header) + + target_compile_options(ggml-virtgpu PRIVATE -std=c++20) +else() + message(STATUS "Not building the VirtGPU/Virglrenderer API Remoting frontend library") +endif() + +if (NOT GGML_VIRTGPU_BACKEND STREQUAL "OFF") + add_subdirectory("backend") +endif() diff --git a/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp b/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp new file mode 100644 index 00000000000..d2e87330a63 --- /dev/null +++ b/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp @@ -0,0 +1,87 @@ +#include "backend/shared/apir_cs_rpc.h" +#include "ggml-backend-impl.h" +#include "ggml-impl.h" +#include "ggml-remoting.h" + +#include <cinttypes> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor) { + apir_rpc_tensor result; + result.id = reinterpret_cast<uint64_t>(tensor); + result.type = tensor->type; + if (tensor->buffer) { + ggml_backend_buffer_t buffer = tensor->buffer; + + result.buffer = BUFFER_TO_HOST_HANDLE(buffer); + } else { + result.buffer = 0; + } + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result.ne[i] = tensor->ne[i]; + result.nb[i] = tensor->nb[i]; + } + result.op = tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result.op_params[i] = tensor->op_params[i]; + } + result.flags = tensor->flags; + for (uint32_t i = 0; i < GGML_MAX_SRC; i++) { + result.src[i] = reinterpret_cast<uint64_t>(tensor->src[i]); + } + result.view_src = reinterpret_cast<uint64_t>(tensor->view_src); + result.view_offs = tensor->view_offs; + result.data = reinterpret_cast<uint64_t>(tensor->data); + if (tensor->data) { + if (!tensor->buffer) { + GGML_ABORT("%s: tensor has data but not buffer", __func__); + } + // tensor->data is serialized as an offset to the buffer base address + result.data -= reinterpret_cast<uint64_t>(BUFFER_TO_GGML_CONTEXT(tensor->buffer)->base); + } + snprintf(result.name, GGML_MAX_NAME, "%s", tensor->name); + return result; +} + +void apir_add_tensor(ggml_tensor * tensor, + std::vector<apir_rpc_tensor> & tensors, + std::unordered_set<ggml_tensor *> & visited) { + if (tensor == nullptr) { + return; + } + if (visited.find(tensor) != visited.end()) { + return; + } + visited.insert(tensor); + for (int i = 0; i < GGML_MAX_SRC; i++) { + apir_add_tensor(tensor->src[i], tensors, visited); + } + apir_add_tensor(tensor->view_src, tensors, visited); + tensors.push_back(apir_serialize_tensor(tensor)); +} + +void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output) { + uint32_t n_nodes = cgraph->n_nodes; + std::vector<apir_rpc_tensor> tensors; + std::unordered_set<ggml_tensor *> visited; + for (uint32_t i = 0; i < n_nodes; i++) { + apir_add_tensor(cgraph->nodes[i], tensors, visited); + } + // serialization format: + // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(apir_rpc_tensor)) | + uint32_t n_tensors = tensors.size(); + int output_size = + sizeof(uint32_t) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t) + n_tensors * sizeof(apir_rpc_tensor); + output.resize(output_size, 0); + memcpy(output.data(), &n_nodes, sizeof(n_nodes)); + for (uint32_t i = 0; i < n_nodes; i++) { + memcpy(output.data() + sizeof(n_nodes) + i * sizeof(uint64_t), &cgraph->nodes[i], sizeof(uint64_t)); + } + uint32_t * out_ntensors = (uint32_t *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t)); + *out_ntensors = n_tensors; + apir_rpc_tensor * out_tensors = + (apir_rpc_tensor *) (output.data() + sizeof(n_nodes) + n_nodes * sizeof(uint64_t) + sizeof(uint32_t)); + memcpy(out_tensors, tensors.data(), n_tensors * sizeof(apir_rpc_tensor)); +} diff --git a/ggml/src/ggml-virtgpu/backend/CMakeLists.txt b/ggml/src/ggml-virtgpu/backend/CMakeLists.txt new file mode 100644 index 00000000000..0b49c403b9a --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) + +message(STATUS "Enable the VirtGPU/Virglrenderer backend library") + +ggml_add_backend_library(ggml-virtgpu-backend + backend.cpp + backend-dispatched.cpp + backend-dispatched-backend.cpp + backend-dispatched-device.cpp + backend-dispatched-buffer.cpp + backend-dispatched-buffer-type.cpp + shared/api_remoting.h + shared/apir_backend.h + shared/apir_cs.h + apir_cs_ggml-rpc-back.cpp) + +target_compile_options(ggml-virtgpu-backend PRIVATE -std=c++20) + +# Add include directory for ggml-backend-impl.h and other core headers +target_include_directories(ggml-virtgpu-backend PRIVATE ../..) diff --git a/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp b/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp new file mode 100644 index 00000000000..60a8a93bfb8 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp @@ -0,0 +1,115 @@ +#include "ggml-backend-impl.h" +#include "ggml-impl.h" +#include "shared/apir_cs_rpc.h" + +#include <cinttypes> +#include <unordered_map> +#include <unordered_set> +#include <vector> + +std::unordered_set<ggml_backend_buffer_t> backend_buffers; + +void apir_track_backend_buffer(ggml_backend_buffer_t buffer) { + backend_buffers.insert(buffer); +} + +bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer) { + auto it = backend_buffers.find(buffer); + if (it == backend_buffers.end()) { + return false; + } + + backend_buffers.erase(it); + return true; +} + +std::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers() { + return backend_buffers; +} + +ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor) { + ggml_tensor * result = + ggml_new_tensor_4d(ctx, (ggml_type) tensor->type, tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + for (uint32_t i = 0; i < GGML_MAX_DIMS; i++) { + result->nb[i] = tensor->nb[i]; + } + result->buffer = reinterpret_cast<ggml_backend_buffer_t>(tensor->buffer); + if (result->buffer && backend_buffers.find(result->buffer) == backend_buffers.end()) { + printf("WARNING: HOST BUFFER NOT FOUND | %p\n", (void *) result->buffer); + result->buffer = nullptr; + } + + uint64_t tensor_data = tensor->data; + if (result->buffer) { + // require that the tensor data does not go beyond the buffer end + uint64_t tensor_size = (uint64_t) ggml_nbytes(result); + uint64_t buffer_start = (uint64_t) ggml_backend_buffer_get_base(result->buffer); + uint64_t buffer_size = (uint64_t) ggml_backend_buffer_get_size(result->buffer); + + // tensor->data is serialized as an offset to the buffer base address + tensor_data += buffer_start; + + GGML_ASSERT(tensor_data + tensor_size >= tensor_data); // check for overflow + GGML_ASSERT(tensor_data >= buffer_start && tensor_data + tensor_size <= buffer_start + buffer_size); + } + + result->op = (ggml_op) tensor->op; + for (uint32_t i = 0; i < GGML_MAX_OP_PARAMS / sizeof(int32_t); i++) { + result->op_params[i] = tensor->op_params[i]; + } + result->flags = tensor->flags; + result->data = reinterpret_cast<void *>(tensor_data); + ggml_set_name(result, tensor->name); + return result; +} + +ggml_tensor * apir_create_node(uint64_t id, + ggml_context * ctx, + const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs, + std::unordered_map<uint64_t, ggml_tensor *> & tensor_map) { + if (id == 0) { + return nullptr; + } + if (tensor_map.find(id) != tensor_map.end()) { + return tensor_map[id]; + } + const apir_rpc_tensor * tensor = tensor_ptrs.at(id); + ggml_tensor * result = apir_deserialize_tensor(ctx, tensor); + if (result == nullptr) { + return nullptr; + } + tensor_map[id] = result; + for (int i = 0; i < GGML_MAX_SRC; i++) { + result->src[i] = apir_create_node(tensor->src[i], ctx, tensor_ptrs, tensor_map); + } + result->view_src = apir_create_node(tensor->view_src, ctx, tensor_ptrs, tensor_map); + result->view_offs = tensor->view_offs; + return result; +} + +ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes, + uint32_t n_tensors, + const apir_rpc_tensor * tensors, + const uint64_t * nodes) { + size_t buf_size = ggml_tensor_overhead() * (n_nodes + n_tensors) + ggml_graph_overhead_custom(n_nodes, false); + ggml_init_params params = { + /*.mem_size =*/buf_size, + /*.mem_buffer =*/NULL, + /*.no_alloc =*/true, + }; + ggml_context * ctx = ggml_init(params); + ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); + graph->n_nodes = n_nodes; + std::unordered_map<uint64_t, const apir_rpc_tensor *> tensor_ptrs; + for (uint32_t i = 0; i < n_tensors; i++) { + tensor_ptrs[tensors[i].id] = &tensors[i]; + } + std::unordered_map<uint64_t, ggml_tensor *> tensor_map; + for (uint32_t i = 0; i < n_nodes; i++) { + int64_t id; + memcpy(&id, &nodes[i], sizeof(id)); + graph->nodes[i] = apir_create_node(id, ctx, tensor_ptrs, tensor_map); + } + + return graph; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-convert.h b/ggml/src/ggml-virtgpu/backend/backend-convert.h new file mode 100644 index 00000000000..1978d21f7ef --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-convert.h @@ -0,0 +1,13 @@ +#include "shared/apir_backend.h" + +#define BUFFER_TO_HOST_HANDLE(name) ggml_buffer_to_apir_handle(name) + +static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) { + // in the backend, the buffer handle is the buffer pointer + return (apir_buffer_host_handle_t) buffer; +} + +static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) { + // in the backend, the buffer handle is the buffer pointer + return (apir_buffer_type_host_handle_t) buft; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp new file mode 100644 index 00000000000..03a037f1cbd --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp @@ -0,0 +1,102 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "shared/apir_backend.h" + +#include <cstdint> + +static uint32_t validate_graph_operation(size_t cgraph_size, uint32_t shmem_res_id, const char * operation) { + if (cgraph_size == 0) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Zero-size computation graph\n", operation); + return 1; + } + + // place-holder: validate that the size of shmem_res_id is <= cgraph_size + // need to add another method in the Virgl->APIR callback interface + GGML_UNUSED(shmem_res_id); + + return 0; // Valid +} + +uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + + static bool async_backend_initialized = false; + static bool async_backend; + + if (!async_backend_initialized) { + ggml_backend_dev_props props; + + dev->iface.get_props(dev, &props); + async_backend = props.caps.async; + async_backend_initialized = true; + } + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + const void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + if (!shmem_data) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); + apir_decoder_set_fatal(dec); + return 1; + } + size_t cgraph_size; + apir_decode_size_t(dec, &cgraph_size); + + if (validate_graph_operation(cgraph_size, shmem_res_id, __func__) != 0) { + apir_decoder_set_fatal(dec); + return 1; + } + + apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size); + + ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size); + + if (!cgraph || apir_decoder_get_fatal(&secondary_dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to deserialize computation graph\n", __func__); + return 1; + } + + if (cgraph->n_nodes < 0 || cgraph->n_leafs < 0) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid negative node/leaf count: nodes=%d leafs=%d\n", __func__, + cgraph->n_nodes, cgraph->n_leafs); + return 1; + } + + ggml_status status; +#if APIR_BACKEND_CHECK_SUPPORTS_OP == 1 + for (int idx = 0; idx < cgraph->n_nodes; idx++) { + ggml_tensor * op = ggml_graph_node(cgraph, idx); + if (dev->iface.supports_op(dev, op)) { + continue; + } + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", __func__, idx, + ggml_op_desc(op)); + + status = GGML_STATUS_ABORTED; + apir_encode_ggml_status(enc, &status); + + return 0; + } +#endif + + // Check if backend is properly initialized + if (!bck) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Backend not initialized (bck is null)\n", __func__); + + return 1; + } + + status = bck->iface.graph_compute(bck, cgraph); + + if (async_backend && bck->iface.synchronize) { + bck->iface.synchronize(bck); + } + + apir_encode_ggml_status(enc, &status); + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp new file mode 100644 index 00000000000..c66dbaa9e8f --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp @@ -0,0 +1,105 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include <cstdint> + +uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + const char * string = buft->iface.get_name(buft); + + const size_t string_size = strlen(string) + 1; + apir_encode_array_size(enc, string_size); + apir_encode_char_array(enc, string, string_size); + + return 0; +} + +uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + size_t value = buft->iface.get_alignment(buft); + apir_encode_size_t(enc, &value); + + return 0; +} + +uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + size_t value = SIZE_MAX; + if (buft->iface.get_max_size) { + value = buft->iface.get_max_size(buft); + } + + apir_encode_size_t(enc, &value); + + return 0; +} + +/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST is deprecated. Keeping the handler for backward compatibility. */ +uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + const bool is_host = false; + + apir_encode_bool_t(enc, &is_host); + + return 0; +} + +uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + size_t size; + apir_decode_size_t(dec, &size); + + ggml_backend_buffer_t buffer; + + buffer = buft->iface.alloc_buffer(buft, size); + + apir_encode_ggml_buffer(enc, buffer); + + if (buffer) { + apir_track_backend_buffer(buffer); + } + + return 0; +} + +uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_type_t buft; + buft = apir_decode_ggml_buffer_type(dec); + + const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec); + + // Check for decode error + if (op == nullptr) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to decode tensor\n", __func__); + apir_decoder_set_fatal(dec); + return 1; + } + + size_t value; + if (buft->iface.get_alloc_size) { + value = buft->iface.get_alloc_size(buft, op); + } else { + value = ggml_nbytes(op); // Default fallback + } + + apir_encode_size_t(enc, &value); + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp new file mode 100644 index 00000000000..3ade8d99b4e --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp @@ -0,0 +1,179 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include <cstdint> + +static uint32_t validate_buffer_operation(size_t offset, size_t size, const char * operation) { + // Only check for critical integer overflow - no arbitrary size limits + if (offset > SIZE_MAX - size) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Integer overflow in offset+size: %zu + %zu\n", operation, offset, size); + return 1; + } + + return 0; // Valid +} + +uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + + uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer); + apir_encode_uintptr_t(enc, &base); + + return 0; +} + +uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + + ggml_tensor * tensor; + // safe to remove the const qualifier here + tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec); + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + size_t offset; + apir_decode_size_t(dec, &offset); + + size_t size; + apir_decode_size_t(dec, &size); + + if (validate_buffer_operation(offset, size, __func__) != 0) { + return 1; + } + + void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + + if (!shmem_data) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); + return 1; + } + + buffer->iface.set_tensor(buffer, tensor, shmem_data, offset, size); + + return 0; +} + +uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + + const ggml_tensor * tensor; + // safe to remove the const qualifier here + tensor = apir_decode_ggml_tensor(dec); + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + size_t offset; + apir_decode_size_t(dec, &offset); + + size_t size; + apir_decode_size_t(dec, &size); + + if (validate_buffer_operation(offset, size, __func__) != 0) { + return 1; + } + + void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + if (!shmem_data) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); + return 1; + } + + buffer->iface.get_tensor(buffer, tensor, shmem_data, offset, size); + + return 0; +} + +uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + + const ggml_tensor * src; + // safe to remove the const qualifier here + src = apir_decode_ggml_tensor(dec); + ggml_tensor * dst = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec); + + bool ret = buffer->iface.cpy_tensor(buffer, src, (ggml_tensor *) dst); + + apir_encode_bool_t(enc, &ret); + + return 0; +} + +uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + + uint8_t value; + apir_decode_uint8_t(dec, &value); + + buffer->iface.clear(buffer, value); + + return 0; +} + +uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(enc); + + ggml_backend_buffer_t buffer; + buffer = apir_decode_ggml_buffer(dec); + + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + + if (!apir_untrack_backend_buffer(buffer)) { + GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: unknown buffer %p\n", __func__, (void *) buffer); + return 1; + } + + buffer->iface.free_buffer(buffer); + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp new file mode 100644 index 00000000000..c7acb8b51ce --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp @@ -0,0 +1,148 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include <cstdint> + +uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + int32_t dev_count = reg->iface.get_device_count(reg); + apir_encode_int32_t(enc, &dev_count); + + return 0; +} + +uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + int32_t dev_count = reg->iface.get_device_count(reg); + apir_encode_int32_t(enc, &dev_count); + + return 0; +} + +uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + const char * string = dev->iface.get_name(dev); + + const size_t string_size = strlen(string) + 1; + apir_encode_array_size(enc, string_size); + apir_encode_char_array(enc, string, string_size); + + return 0; +} + +uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + const char * string = dev->iface.get_description(dev); + + const size_t string_size = strlen(string) + 1; + apir_encode_array_size(enc, string_size); + apir_encode_char_array(enc, string, string_size); + + return 0; +} + +uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + uint32_t type = dev->iface.get_type(dev); + apir_encode_uint32_t(enc, &type); + + return 0; +} + +uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + size_t free, total; + dev->iface.get_memory(dev, &free, &total); + + apir_encode_size_t(enc, &free); + apir_encode_size_t(enc, &total); + + return 0; +} + +uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + + const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec); + + bool supports_op = dev->iface.supports_op(dev, op); + + apir_encode_bool_t(enc, &supports_op); + + return 0; +} + +uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + ggml_backend_buffer_type_t bufft = dev->iface.get_buffer_type(dev); + + apir_encode_ggml_buffer_type(enc, bufft); + + return 0; +} + +uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + ggml_backend_dev_props props; + dev->iface.get_props(dev, &props); + + apir_encode_bool_t(enc, &props.caps.async); + apir_encode_bool_t(enc, &props.caps.host_buffer); + apir_encode_bool_t(enc, &props.caps.buffer_from_host_ptr); + apir_encode_bool_t(enc, &props.caps.events); + + return 0; +} + +uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { + GGML_UNUSED(ctx); + GGML_UNUSED(dec); + + uint32_t shmem_res_id; + apir_decode_virtgpu_shmem_res_id(dec, &shmem_res_id); + + void * shmem_ptr = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); + if (!shmem_ptr) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); + apir_decoder_set_fatal(dec); + return 1; + } + + size_t size; + apir_decode_size_t(dec, &size); + size_t max_tensor_size; + apir_decode_size_t(dec, &max_tensor_size); + + ggml_backend_buffer_t buffer; + buffer = dev->iface.buffer_from_host_ptr(dev, shmem_ptr, size, max_tensor_size); + + apir_encode_ggml_buffer(enc, buffer); + apir_encode_ggml_buffer_type(enc, buffer->buft); + + if (buffer) { + apir_track_backend_buffer(buffer); + } + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp new file mode 100644 index 00000000000..c80e4aabe1f --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp @@ -0,0 +1,51 @@ +#include "backend-dispatched.h" + +#include "backend-virgl-apir.h" +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include <cstdint> + +ggml_backend_reg_t reg = NULL; +ggml_backend_dev_t dev = NULL; +ggml_backend_t bck = NULL; + +uint64_t timer_start = 0; +uint64_t timer_total = 0; +uint64_t timer_count = 0; + +uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) { + if (reg != NULL) { + GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: already initialized\n", __func__); + return APIR_BACKEND_INITIALIZE_ALREADY_INITED; + } + ggml_backend_reg_t (*ggml_backend_reg_fct)(void) = (ggml_backend_reg_t (*)()) ggml_backend_reg_fct_p; + + reg = ggml_backend_reg_fct(); + if (reg == NULL) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend registration failed\n", __func__); + return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED; + } + + size_t device_count = reg->iface.get_device_count(reg); + if (!device_count) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: no device found\n", __func__); + return APIR_BACKEND_INITIALIZE_NO_DEVICE; + } + + dev = reg->iface.get_device(reg, 0); + + if (!dev) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: failed to get device\n", __func__); + return APIR_BACKEND_INITIALIZE_NO_DEVICE; + } + + bck = dev->iface.init_backend(dev, NULL); + if (!bck) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed\n", __func__); + return APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED; + } + + return APIR_BACKEND_INITIALIZE_SUCCESS; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h new file mode 100644 index 00000000000..3dc334e4ce4 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h @@ -0,0 +1,73 @@ +#pragma once + +/* device */ +uint32_t backend_device_get_device_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_count(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_description(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_memory(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_supports_op(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_buffer_type(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_get_props(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_device_buffer_from_ptr(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +/* buffer-type */ +uint32_t backend_buffer_type_get_name(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_get_alignment(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_get_max_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +/* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST is deprecated. Keeping the handler for backward compatibility. */ +uint32_t backend_buffer_type_is_host(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_alloc_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +/* buffer */ +uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); +uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +/* backend */ +uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +extern "C" { +static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = { + + /* device */ + + /* APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = */ backend_device_get_device_count, + /* APIR_COMMAND_TYPE_DEVICE_GET_COUNT = */ backend_device_get_count, + /* APIR_COMMAND_TYPE_DEVICE_GET_NAME = */ backend_device_get_name, + /* APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION = */ backend_device_get_description, + /* APIR_COMMAND_TYPE_DEVICE_GET_TYPE = */ backend_device_get_type, + /* APIR_COMMAND_TYPE_DEVICE_GET_MEMORY = */ backend_device_get_memory, + /* APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP = */ backend_device_supports_op, + /* APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE = */ backend_device_get_buffer_type, + /* APIR_COMMAND_TYPE_DEVICE_GET_PROPS = */ backend_device_get_props, + /* APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR = */ backend_device_buffer_from_ptr, + + /* buffer-type */ + + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = */ backend_buffer_type_get_name, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = */ backend_buffer_type_get_alignment, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = */ backend_buffer_type_get_max_size, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = */ backend_buffer_type_is_host /* DEPRECATED */, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = */ backend_buffer_type_alloc_buffer, + /* APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = */ backend_buffer_type_get_alloc_size, + + /* buffer */ + + /* APIR_COMMAND_TYPE_BUFFER_GET_BASE = */ backend_buffer_get_base, + /* APIR_COMMAND_TYPE_BUFFER_SET_TENSOR = */ backend_buffer_set_tensor, + /* APIR_COMMAND_TYPE_BUFFER_GET_TENSOR = */ backend_buffer_get_tensor, + /* APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR = */ backend_buffer_cpy_tensor, + /* APIR_COMMAND_TYPE_BUFFER_CLEAR = */ backend_buffer_clear, + /* APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = */ backend_buffer_free_buffer, + + /* backend */ + + /* APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = */ backend_backend_graph_compute, +}; +} diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h new file mode 100644 index 00000000000..740ee9e3ffc --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h @@ -0,0 +1,27 @@ +#pragma once + +// clang-format off +#include <cstdint> +#include <cstddef> + +#include <ggml-backend.h> + +#include "backend-convert.h" +#include "backend-virgl-apir.h" +#include "shared/apir_backend.h" +#include "shared/apir_cs.h" +#include "shared/apir_cs_ggml.h" +// clang-format on + +#define GGML_VIRTGPU_BCK "ggml-virtgpu-backend: " + +struct virgl_apir_context { + uint32_t ctx_id; + virgl_apir_callbacks * iface; +}; + +typedef uint32_t (*backend_dispatch_t)(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); + +#include "backend-dispatched.gen.h" + +uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p); diff --git a/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h new file mode 100644 index 00000000000..c65a01cdf9b --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h @@ -0,0 +1,32 @@ +#pragma once + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "shared/api_remoting.h" + +#include <cstdarg> +#include <cstdio> +#include <cstdlib> + +extern ggml_backend_reg_t reg; +extern ggml_backend_dev_t dev; +extern ggml_backend_t bck; + +struct virgl_apir_callbacks { + const char * (*get_config)(uint32_t virgl_ctx_id, const char * key); + void * (*get_shmem_ptr)(uint32_t virgl_ctx_id, uint32_t res_id); +}; + +extern "C" { +ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs); +void apir_backend_deinit(uint32_t virgl_ctx_id); +uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, + virgl_apir_callbacks * virgl_cbs, + uint32_t cmd_type, + char * dec_cur, + const char * dec_end, + char * enc_cur, + const char * enc_end, + char ** enc_cur_after); +} diff --git a/ggml/src/ggml-virtgpu/backend/backend.cpp b/ggml/src/ggml-virtgpu/backend/backend.cpp new file mode 100644 index 00000000000..535a05f3e69 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/backend.cpp @@ -0,0 +1,144 @@ +#include "backend-dispatched.h" +#include "backend-virgl-apir.h" +#include "shared/api_remoting.h" +#include "shared/apir_backend.h" +#include "shared/apir_cs.h" + +#include <dlfcn.h> +#include <ggml-backend.h> + +#include <iostream> + +#define APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_PATH" +#define APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV "APIR_LLAMA_CPP_GGML_LIBRARY_REG" +#define APIR_LLAMA_CPP_LOG_TO_FILE_ENV "APIR_LLAMA_CPP_LOG_TO_FILE" + +#define GGML_DEFAULT_BACKEND_REG "ggml_backend_init" + +static void * backend_library_handle = NULL; +static FILE * apir_logfile = NULL; + +static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) { + FILE * logfile = (FILE *) user_data; + fprintf(logfile, "[%d] %s", level, text); + fflush(logfile); +} + +extern "C" { +void apir_backend_deinit(uint32_t virgl_ctx_id) { + GGML_UNUSED(virgl_ctx_id); + + auto buffers = apir_get_track_backend_buffers(); + for (const auto & buffer : buffers) { + apir_untrack_backend_buffer(buffer); + buffer->iface.free_buffer(buffer); + } + + if (backend_library_handle) { + GGML_LOG_INFO(GGML_VIRTGPU_BCK "The GGML backend library was loaded. Unloading it.\n"); + dlclose(backend_library_handle); + backend_library_handle = NULL; + } + + if (apir_logfile) { + fclose(apir_logfile); + apir_logfile = NULL; + } +} + +#define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path" +#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg" + +ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs) { + const char * dlsym_error; + + const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV); + if (apir_log_to_file) { + apir_logfile = fopen(apir_log_to_file, "w"); + if (apir_logfile) { + ggml_log_set(log_to_file_callback, apir_logfile); + } else { + GGML_LOG_INFO(GGML_VIRTGPU_BCK "Could not open the log file at '%s'\n", apir_log_to_file); + } + } + + const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY); + const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY); + const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG; + + if (!library_name) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: env var '%s' not defined\n", __func__, + APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV); + + return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; + } + + backend_library_handle = dlopen(library_name, RTLD_LAZY); + + if (!backend_library_handle) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: %s\n", __func__, dlerror()); + + return APIR_LOAD_LIBRARY_CANNOT_OPEN; + } + + if (!library_reg) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot register the GGML library: env var '%s' not defined\n", __func__, + APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV); + + return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; + } + + void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg); + dlsym_error = dlerror(); + if (dlsym_error) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n", + __func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error); + + return APIR_LOAD_LIBRARY_SYMBOL_MISSING; + } + + uint32_t ret = backend_dispatch_initialize(ggml_backend_reg_fct); + + return (ApirLoadLibraryReturnCode) (APIR_LOAD_LIBRARY_INIT_BASE_INDEX + ret); +} + +uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, + virgl_apir_callbacks * virgl_cbs, + uint32_t cmd_type, + char * dec_cur, + const char * dec_end, + char * enc_cur, + const char * enc_end, + char ** enc_cur_after) { + apir_encoder enc = { + .cur = enc_cur, + .start = enc_cur, + .end = enc_end, + .fatal = false, + }; + + apir_decoder dec = { + .cur = dec_cur, + .end = dec_end, + .fatal = false, + }; + + virgl_apir_context ctx = { + .ctx_id = virgl_ctx_id, + .iface = virgl_cbs, + }; + + if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Received an invalid dispatch index (%d >= %d)\n", __func__, cmd_type, + APIR_BACKEND_DISPATCH_TABLE_COUNT); + return APIR_BACKEND_FORWARD_INDEX_INVALID; + } + + backend_dispatch_t forward_fct = apir_backend_dispatch_table[cmd_type]; + uint32_t ret = forward_fct(&enc, &dec, &ctx); + + *enc_cur_after = enc.cur; + + return ret; +} +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h new file mode 100644 index 00000000000..6bf97e8a3a2 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h @@ -0,0 +1,95 @@ +#pragma once + +/* the rest of this file must match virglrenderer/src/apir-protocol.h */ + +#include <unistd.h> + +#include <cstdint> + +#define APIR_PROTOCOL_MAJOR 0 +#define APIR_PROTOCOL_MINOR 1 + +#define APIR_HANDSHAKE_MAGIC 0xab1e + +enum ApirCommandType { + APIR_COMMAND_TYPE_HANDSHAKE = 0, + APIR_COMMAND_TYPE_LOADLIBRARY = 1, + APIR_COMMAND_TYPE_FORWARD = 2, + + APIR_COMMAND_TYPE_LENGTH = 3, +}; + +typedef uint64_t ApirCommandFlags; + +enum ApirLoadLibraryReturnCode { + APIR_LOAD_LIBRARY_SUCCESS = 0, + // these error codes are returned by the Virglrenderer APIR component + APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1, + APIR_LOAD_LIBRARY_ALREADY_LOADED = 2, + APIR_LOAD_LIBRARY_ENV_VAR_MISSING = 3, + APIR_LOAD_LIBRARY_CANNOT_OPEN = 4, + APIR_LOAD_LIBRARY_SYMBOL_MISSING = 5, + // any value greater than this is an APIR *backend library* initialization return code + APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, +}; + +enum ApirForwardReturnCode { + APIR_FORWARD_SUCCESS = 0, + // these error codes are returned by the Virglrenderer APIR component + APIR_FORWARD_NO_DISPATCH_FCT = 1, + APIR_FORWARD_TIMEOUT = 2, + APIR_FORWARD_FAILED_TO_SYNC_STREAMS = 3, + // any value greater than this index an APIR *backend library* forward return code + APIR_FORWARD_BASE_INDEX = 4, +}; + +__attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) { + switch (type) { + case APIR_COMMAND_TYPE_HANDSHAKE: + return "HandShake"; + case APIR_COMMAND_TYPE_LOADLIBRARY: + return "LoadLibrary"; + case APIR_COMMAND_TYPE_FORWARD: + return "Forward"; + default: + return "unknown"; + } +} + +__attribute__((unused)) static const char * apir_load_library_error(ApirLoadLibraryReturnCode code) { +#define APIR_LOAD_LIBRARY_ERROR(code_name) \ + do { \ + if (code == code_name) \ + return #code_name; \ + } while (0) + + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SUCCESS); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ALREADY_LOADED); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_ENV_VAR_MISSING); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_CANNOT_OPEN); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_SYMBOL_MISSING); + APIR_LOAD_LIBRARY_ERROR(APIR_LOAD_LIBRARY_INIT_BASE_INDEX); + + return "Unknown APIR_COMMAND_TYPE_LoadLibrary error"; + +#undef APIR_LOAD_LIBRARY_ERROR +} + +__attribute__((unused)) static const char * apir_forward_error(ApirForwardReturnCode code) { +#define APIR_FORWARD_ERROR(code_name) \ + do { \ + if (code == code_name) \ + return #code_name; \ + } while (0) + + APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS); + APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT); + APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT); + APIR_FORWARD_ERROR(APIR_FORWARD_FAILED_TO_SYNC_STREAMS); + APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX); + + return "Unknown APIR_COMMAND_TYPE_FORWARD error"; + +#undef APIR_FORWARD_ERROR +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h new file mode 100644 index 00000000000..520ac9c7299 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h @@ -0,0 +1,94 @@ +typedef enum ApirBackendCommandType { + + /* device */ + APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT = 0, + APIR_COMMAND_TYPE_DEVICE_GET_COUNT = 1, + APIR_COMMAND_TYPE_DEVICE_GET_NAME = 2, + APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION = 3, + APIR_COMMAND_TYPE_DEVICE_GET_TYPE = 4, + APIR_COMMAND_TYPE_DEVICE_GET_MEMORY = 5, + APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP = 6, + APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE = 7, + APIR_COMMAND_TYPE_DEVICE_GET_PROPS = 8, + APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR = 9, + + /* buffer-type */ + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME = 10, + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT = 11, + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE = 12, + APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST = 13, + APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER = 14, + APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE = 15, + + /* buffer */ + APIR_COMMAND_TYPE_BUFFER_GET_BASE = 16, + APIR_COMMAND_TYPE_BUFFER_SET_TENSOR = 17, + APIR_COMMAND_TYPE_BUFFER_GET_TENSOR = 18, + APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR = 19, + APIR_COMMAND_TYPE_BUFFER_CLEAR = 20, + APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER = 21, + + /* backend */ + APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE = 22, + + // last command_type index + 1 + APIR_BACKEND_DISPATCH_TABLE_COUNT = 23, +} ApirBackendCommandType; + +static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) { + switch (type) { + /* device */ + case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT: + return "device_get_device_count"; + case APIR_COMMAND_TYPE_DEVICE_GET_COUNT: + return "device_get_count"; + case APIR_COMMAND_TYPE_DEVICE_GET_NAME: + return "device_get_name"; + case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION: + return "device_get_description"; + case APIR_COMMAND_TYPE_DEVICE_GET_TYPE: + return "device_get_type"; + case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY: + return "device_get_memory"; + case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP: + return "device_supports_op"; + case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE: + return "device_get_buffer_type"; + case APIR_COMMAND_TYPE_DEVICE_GET_PROPS: + return "device_get_props"; + case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR: + return "device_buffer_from_ptr"; + /* buffer-type */ + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME: + return "buffer_type_get_name"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT: + return "buffer_type_get_alignment"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE: + return "buffer_type_get_max_size"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST: + return "buffer_type_is_host"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER: + return "buffer_type_alloc_buffer"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE: + return "buffer_type_get_alloc_size"; + /* buffer */ + case APIR_COMMAND_TYPE_BUFFER_GET_BASE: + return "buffer_get_base"; + case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR: + return "buffer_set_tensor"; + case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR: + return "buffer_get_tensor"; + case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR: + return "buffer_cpy_tensor"; + case APIR_COMMAND_TYPE_BUFFER_CLEAR: + return "buffer_clear"; + case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER: + return "buffer_free_buffer"; + /* backend */ + case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE: + return "backend_graph_compute"; + + default: + return "unknown"; + } +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h new file mode 100644 index 00000000000..da1e21b5b2f --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h @@ -0,0 +1,50 @@ +#pragma once + +#include "apir_backend.gen.h" + +#include <stdint.h> // for uintptr_t +#include <time.h> // for timespec, clock_gettime + +#define APIR_BACKEND_INITIALIZE_SUCCESS 0 +#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY 1 +#define APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY 2 +#define APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS 3 +#define APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS 4 +#define APIR_BACKEND_INITIALIZE_BACKEND_FAILED 5 +#define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED 6 +#define APIR_BACKEND_INITIALIZE_ALREADY_INITED 7 +#define APIR_BACKEND_INITIALIZE_NO_DEVICE 8 +#define APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED 9 + +// new entries here need to be added to the apir_backend_initialize_error function below + +#define APIR_BACKEND_FORWARD_INDEX_INVALID 6 + +// 0 is fast, 1 avoids the backend to crash if an unsupported tensor is received +#define APIR_BACKEND_CHECK_SUPPORTS_OP 0 + +typedef uintptr_t apir_buffer_type_host_handle_t; +typedef uintptr_t apir_buffer_host_handle_t; + +static const char * apir_backend_initialize_error(int code) { +#define APIR_BACKEND_INITIALIZE_ERROR(code_name) \ + do { \ + if (code == code_name) \ + return #code_name; \ + } while (0) + + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_SUCCESS); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_BACKEND_LIBRARY); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_CANNOT_OPEN_GGML_LIBRARY); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_ALREADY_INITED); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_NO_DEVICE); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED); + + return "Unknown APIR_BACKEND_INITIALIZE error:/"; + +#undef APIR_BACKEND_INITIALIZE_ERROR +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h new file mode 100644 index 00000000000..64bf2ec9609 --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h @@ -0,0 +1,378 @@ +#pragma once + +#include "ggml-impl.h" + +#include <cassert> +#include <cstring> + +#define likely(x) __builtin_expect(!!(x), 1) +#define unlikely(x) __builtin_expect(!!(x), 0) + +struct apir_encoder { + char * cur; + const char * start; + const char * end; + bool fatal; +}; + +struct apir_decoder { + const char * cur; + const char * end; + bool fatal; +}; + +/* + * new encoder and decoder + */ + +static apir_decoder apir_new_decoder(const char * ptr, size_t size) { + apir_decoder dec = { + .cur = ptr, + .end = ptr + size, + .fatal = false, + }; + + return dec; +} + +static apir_encoder apir_new_encoder(char * ptr, size_t size) { + apir_encoder enc = { + .cur = ptr, + .start = ptr, + .end = ptr + size, + .fatal = false, + }; + + return enc; +} + +/* + * fatal flag handling + */ + +static inline void apir_encoder_reset_fatal(apir_encoder * enc) { + enc->fatal = false; +} + +static inline void apir_encoder_set_fatal(apir_encoder * enc) { + enc->fatal = true; +} + +static inline bool apir_encoder_get_fatal(const apir_encoder * enc) { + return enc->fatal; +} + +static inline void apir_decoder_reset_fatal(apir_decoder * dec) { + dec->fatal = false; +} + +static inline void apir_decoder_set_fatal(apir_decoder * dec) { + dec->fatal = true; +} + +static inline bool apir_decoder_get_fatal(const apir_decoder * dec) { + return dec->fatal; +} + +/* + * encode peek + */ + +static inline bool apir_decoder_peek_internal(apir_decoder * dec, size_t size, void * val, size_t val_size) { + assert(val_size <= size); + + if (unlikely(size > (size_t) (dec->end - dec->cur))) { + GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__); + apir_decoder_set_fatal(dec); + memset(val, 0, val_size); + return false; + } + + /* we should not rely on the compiler to optimize away memcpy... */ + memcpy(val, dec->cur, val_size); + return true; +} + +static inline void apir_decoder_peek(apir_decoder * dec, size_t size, void * val, size_t val_size) { + apir_decoder_peek_internal(dec, size, val, val_size); +} + +static inline const void * apir_decoder_use_inplace(apir_decoder * dec, size_t size) { + if (unlikely(size > (size_t) (dec->end - dec->cur))) { + GGML_LOG_ERROR("%s: reading too much from the decoder ...\n", __func__); + apir_decoder_set_fatal(dec); + return NULL; + } + const void * addr = dec->cur; + dec->cur += size; + + return addr; +} + +/* + * read/write + */ + +static inline void apir_decoder_read(apir_decoder * dec, size_t size, void * val, size_t val_size) { + if (apir_decoder_peek_internal(dec, size, val, val_size)) { + dec->cur += size; + } +} + +static inline char * apir_encoder_write(apir_encoder * enc, size_t size, const void * val, size_t val_size) { + assert(val_size <= size); + assert(size <= ((size_t) (enc->end - enc->cur))); + + char * write_addr = enc->cur; + /* we should not rely on the compiler to optimize away memcpy... */ + memcpy(write_addr, val, val_size); + enc->cur += size; + + return write_addr; +} + +/* + * encode/decode + */ + +static inline void apir_decode(apir_decoder * dec, size_t size, void * data, size_t data_size) { + assert(size % 4 == 0); + apir_decoder_read(dec, size, data, data_size); +} + +static inline void apir_encode(apir_encoder * enc, size_t size, const void * data, size_t data_size) { + assert(size % 4 == 0); + apir_encoder_write(enc, size, data, data_size); +} + +/* + * typed encode/decode + */ + +/* uint8_t */ + +static inline void apir_encode_uint8_t(apir_encoder * enc, const uint8_t * val) { + apir_encode(enc, sizeof(int), val, sizeof(*val)); +} + +static inline void apir_decode_uint8_t(apir_decoder * dec, uint8_t * val) { + apir_decode(dec, sizeof(int), val, sizeof(*val)); +} + +/* uint64_t */ + +static inline void apir_encode_uint64_t(apir_encoder * enc, const uint64_t * val) { + apir_encode(enc, 8, val, sizeof(*val)); +} + +static inline void apir_decode_uint64_t(apir_decoder * dec, uint64_t * val) { + apir_decode(dec, 8, val, sizeof(*val)); +} + +static inline void apir_encode_uint64_t_array(apir_encoder * enc, const uint64_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_encode(enc, size, val, size); +} + +static inline void apir_decode_uint64_t_array(apir_decoder * dec, uint64_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_decode(dec, size, val, size); +} + +static inline const uint64_t * apir_decode_uint64_t_array_inplace(apir_decoder * dec, uint32_t count) { + return (uint64_t *) (uintptr_t) apir_decoder_use_inplace(dec, count * sizeof(uint64_t)); +} + +/* int32_t */ + +static inline void apir_encode_int32_t(apir_encoder * enc, const int32_t * val) { + apir_encode(enc, 4, val, sizeof(*val)); +} + +static inline void apir_decode_int32_t(apir_decoder * dec, int32_t * val) { + apir_decode(dec, 4, val, sizeof(*val)); +} + +static inline void apir_encode_int32_t_array(apir_encoder * enc, const int32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_encode(enc, size, val, size); +} + +static inline void apir_decode_int32_t_array(apir_decoder * dec, int32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_decode(dec, size, val, size); +} + +/* array size (uint64_t) */ + +static inline void apir_encode_array_size(apir_encoder * enc, uint64_t size) { + apir_encode_uint64_t(enc, &size); +} + +static inline uint64_t apir_decode_array_size(apir_decoder * dec, uint64_t expected_size) { + uint64_t size; + apir_decode_uint64_t(dec, &size); + if (size != expected_size) { + GGML_LOG_ERROR("%s: Couldn't decode array from the decoder\n", __func__); + apir_decoder_set_fatal(dec); + size = 0; + } + return size; +} + +static inline uint64_t apir_decode_array_size_unchecked(apir_decoder * dec) { + uint64_t size; + apir_decode_uint64_t(dec, &size); + return size; +} + +/* non-array pointer */ + +static inline bool apir_encode_simple_pointer(apir_encoder * enc, const void * val) { + apir_encode_array_size(enc, val ? 1 : 0); + return val; +} + +static inline bool apir_decode_simple_pointer(apir_decoder * dec) { + return apir_decode_array_size_unchecked(dec); +} + +/* uint32_t */ + +static inline void apir_encode_uint32_t(apir_encoder * enc, const uint32_t * val) { + apir_encode(enc, 4, val, sizeof(*val)); +} + +static inline void apir_decode_uint32_t(apir_decoder * dec, uint32_t * val) { + apir_decode(dec, 4, val, sizeof(*val)); +} + +static inline void apir_encode_uint32_t_array(apir_encoder * enc, const uint32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_encode(enc, size, val, size); +} + +static inline void apir_decode_uint32_t_array(apir_decoder * dec, uint32_t * val, uint32_t count) { + const size_t size = sizeof(*val) * count; + assert(size >= count); + apir_decode(dec, size, val, size); +} + +/* size_t */ + +static inline void apir_encode_size_t(apir_encoder * enc, const size_t * val) { + const uint64_t tmp = *val; + apir_encode_uint64_t(enc, &tmp); +} + +static inline void apir_decode_size_t(apir_decoder * dec, size_t * val) { + uint64_t tmp; + apir_decode_uint64_t(dec, &tmp); + *val = tmp; +} + +static inline void apir_encode_size_t_array(apir_encoder * enc, const size_t * val, uint32_t count) { + if (sizeof(size_t) == sizeof(uint64_t)) { + apir_encode_uint64_t_array(enc, (const uint64_t *) val, count); + } else { + for (uint32_t i = 0; i < count; i++) { + apir_encode_size_t(enc, &val[i]); + } + } +} + +static inline void apir_decode_size_t_array(apir_decoder * dec, size_t * val, uint32_t count) { + if (sizeof(size_t) == sizeof(uint64_t)) { + apir_decode_uint64_t_array(dec, (uint64_t *) val, count); + } else { + for (uint32_t i = 0; i < count; i++) { + apir_decode_size_t(dec, &val[i]); + } + } +} + +/* opaque blob */ + +static inline void apir_encode_blob_array(apir_encoder * enc, const void * val, size_t size) { + apir_encode(enc, (size + 3) & ~3, val, size); +} + +static inline void apir_decode_blob_array(apir_decoder * dec, void * val, size_t size) { + apir_decode(dec, (size + 3) & ~3, val, size); +} + +/* string */ + +static inline void apir_encode_char_array(apir_encoder * enc, const char * val, size_t size) { + assert(size && strlen(val) < size); + apir_encode_blob_array(enc, val, size); +} + +static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t size) { + apir_decode_blob_array(dec, val, size); + if (size) { + val[size - 1] = '\0'; + } else { + GGML_LOG_ERROR("%s: Couldn't decode the blog array\n", __func__); + apir_decoder_set_fatal(dec); + } +} + +/* (temp) buffer allocation */ + +static inline void * apir_decoder_alloc_array(size_t size, size_t count) { + size_t alloc_size; + if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) { + GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", __func__, size, count); + return NULL; + } + + return malloc(alloc_size); +} + +/* bool */ + +static inline void apir_encode_bool_t(apir_encoder * enc, const bool * val) { + apir_encode(enc, sizeof(int), val, sizeof(bool)); +} + +static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) { + apir_decode(dec, sizeof(int), val, sizeof(bool)); +} + +/* apir_buffer_type_host_handle_t */ + +static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc, + const apir_buffer_type_host_handle_t * val) { + apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t)); +} + +static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec, + apir_buffer_type_host_handle_t * val) { + apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t)); +} + +/* apir_buffer_host_handle_t */ + +static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, const apir_buffer_host_handle_t * val) { + apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t)); +} + +static inline void apir_decode_apir_buffer_host_handle_t(apir_decoder * dec, apir_buffer_host_handle_t * val) { + apir_decode(dec, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t)); +} + +/* uintptr_t */ + +static inline void apir_encode_uintptr_t(apir_encoder * enc, const uintptr_t * val) { + apir_encode(enc, sizeof(*val), val, sizeof(*val)); +} + +static inline void apir_decode_uintptr_t(apir_decoder * dec, uintptr_t * val) { + apir_decode(dec, sizeof(*val), val, sizeof(*val)); +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h new file mode 100644 index 00000000000..fabe3e401ca --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h @@ -0,0 +1,232 @@ +#include "apir_cs.h" +#include "apir_cs_rpc.h" +#include "ggml-impl.h" + +// ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer); + +static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle); + +static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec); + +/* apir_rpc_tensor */ + +static inline void apir_encode_rcp_tensor(apir_encoder * enc, const apir_rpc_tensor * apir_rpc_tensor) { + size_t apir_rpc_tensor_size = sizeof(*apir_rpc_tensor); + apir_encode(enc, apir_rpc_tensor_size, apir_rpc_tensor, apir_rpc_tensor_size); +} + +static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder * dec) { + size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor); + + return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size); +} + +static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, uint32_t n_tensors) { + size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors; + + return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size); +} + +/* ggml_tensor */ + +static inline void apir_encode_ggml_tensor(apir_encoder * enc, const ggml_tensor * tensor) { + apir_rpc_tensor serialized = apir_serialize_tensor(tensor); + + apir_encode_rcp_tensor(enc, &serialized); +} + +static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) { + const apir_rpc_tensor * apir_rpc_tensor = apir_decode_apir_rpc_tensor_inplace(dec); + + if (!apir_rpc_tensor) { + return NULL; + } + + ggml_init_params params{ + /*.mem_size =*/ggml_tensor_overhead(), + /*.mem_buffer =*/NULL, + /*.no_alloc =*/true, + }; + + ggml_context * ctx = ggml_init(params); + + const ggml_tensor * tensor = apir_deserialize_tensor(ctx, apir_rpc_tensor); + + return tensor; +} + +/* *** ggml_backend_buffer_type_t *** */ + +// ggml_backend_buffer_type_t is a POINTER (to a struct). +// Only the host pointer is shared between the host and guest. +// The guest stores it in `buft->context`. +// The host simply writes the pointer address in the buffer variable. + +static inline void apir_encode_ggml_buffer_type(apir_encoder * enc, ggml_backend_buffer_type_t buft) { + apir_buffer_type_host_handle_t handle = ggml_buffer_type_to_apir_handle(buft); + apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle)); +} + +static inline ggml_backend_buffer_type_t apir_decode_ggml_buffer_type(apir_decoder * dec) { + apir_buffer_type_host_handle_t handle; + + apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle)); + + return (ggml_backend_buffer_type_t) handle; +} + +static inline void apir_encode_apir_buffer_type_host_handle(apir_encoder * enc, apir_buffer_type_host_handle_t handle) { + apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle)); +} + +static inline apir_buffer_type_host_handle_t apir_decode_apir_buffer_type_host_handle(apir_decoder * dec) { + apir_buffer_type_host_handle_t handle; + + apir_decoder_read(dec, sizeof(handle), &handle, sizeof(handle)); + + return handle; +} + +/* *** ggml_backend_type_t *** */ + +// ggml_backend_buffer_t is a POINTER. +// same logic as for ggml_backend_buffer_type_t + +static inline void apir_encode_ggml_buffer(apir_encoder * enc, const ggml_backend_buffer_t buffer) { + apir_buffer_host_handle_t handle = BUFFER_TO_HOST_HANDLE(buffer); + apir_encoder_write(enc, sizeof(handle), &handle, sizeof(handle)); +} + +static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec) { + ggml_backend_buffer_t buffer; + size_t buffer_ptr_size = sizeof(buffer); + + apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size); + + // SECURITY: Validate buffer handle against tracked buffers to prevent + // guest VM from providing arbitrary host memory addresses + if (buffer) { + extern std::unordered_set<ggml_backend_buffer_t> backend_buffers; + if (backend_buffers.find(buffer) == backend_buffers.end()) { + GGML_LOG_WARN("ggml-virtgpu-backend: %s: Invalid buffer handle from guest: %p\n", __func__, + (void *) buffer); + // Set fatal flag to prevent further processing with invalid handle + apir_decoder_set_fatal(dec); + return NULL; + } + } + + return buffer; +} + +/* enum ggml_status */ + +static inline void apir_encode_ggml_status(apir_encoder * enc, const ggml_status * status) { + apir_encoder_write(enc, sizeof(*status), status, sizeof(*status)); +} + +static inline void apir_decode_ggml_status(apir_decoder * dec, ggml_status * status) { + apir_decoder_read(dec, sizeof(*status), status, sizeof(*status)); +} + +/* virtgpu_shmem */ + +static inline void apir_encode_virtgpu_shmem_res_id(apir_encoder * enc, uint32_t shmem_res_id) { + apir_encode_uint32_t(enc, &shmem_res_id); +} + +static inline void apir_decode_virtgpu_shmem_res_id(apir_decoder * dec, uint32_t * shmem_res_id) { + apir_decode_uint32_t(dec, shmem_res_id); +} + +/* ggml_cgraph */ + +static inline size_t apir_serialize_ggml_cgraph(ggml_cgraph * cgraph, std::vector<uint8_t> & cgraph_data) { + apir_serialize_graph(cgraph, cgraph_data); + + return cgraph_data.size(); +} + +static inline void apir_encode_cgraph_data(apir_encoder * enc, std::vector<uint8_t> & cgraph_data) { + size_t cgraph_size = cgraph_data.size(); + + apir_encode(enc, cgraph_size, cgraph_data.data(), cgraph_size); +} + +static inline ggml_cgraph * apir_decode_ggml_cgraph(apir_decoder * dec, size_t cgraph_size) { + GGML_UNUSED(cgraph_size); + + uint32_t n_nodes; + apir_decode_uint32_t(dec, &n_nodes); + const uint64_t * nodes = apir_decode_uint64_t_array_inplace(dec, n_nodes); + + uint32_t n_tensors; + apir_decode_uint32_t(dec, &n_tensors); + const apir_rpc_tensor * tensors = apir_decode_apir_rpc_tensor_array_inplace(dec, n_tensors); + + return apir_deserialize_graph(n_nodes, n_tensors, tensors, nodes); +} + +static inline void apir_encode_ggml_buffer_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle) { + apir_encoder_write(enc, sizeof(*handle), &handle, sizeof(*handle)); +} + +static inline void apir_encode_ggml_tensor_inline(apir_encoder * enc, const ggml_tensor * tensor) { + size_t tensor_size = sizeof(*tensor); + + if (tensor->extra) { + GGML_ABORT("%s: Cannot pass tensors with extra", __func__); + } + + if (tensor->src[0] && tensor->buffer) { + static int first = 1; + if (first) { + GGML_LOG_WARN("%s: Cannot pass tensors with src and buffer\n", __func__); + first = 0; + } + } + + apir_encoder_write(enc, tensor_size, tensor, tensor_size); + + // tensor->data is a pointer inside the device buffer. No need to touch it + // tensor->buffer is a pointer to a buffer. Encoding the buffer handle in sequence. + // (could also make a copy of the tensor, and update locally.) + + if (tensor->buffer) { + apir_buffer_host_handle_t buffer_handle = ggml_buffer_to_apir_handle(tensor->buffer); + apir_encode_ggml_buffer_handle(enc, &buffer_handle); + } + + if (tensor->view_src) { + apir_encoder_write(enc, tensor_size, tensor->view_src, tensor_size); + } + + for (int i = 0; tensor->src[i]; i++) { + const ggml_tensor * tensor_src = tensor->src[i]; + apir_encoder_write(enc, tensor_size, tensor_src, tensor_size); + } +} + +static inline const ggml_tensor * apir_decode_ggml_tensor_inplace(apir_decoder * dec) { + // it safe to remove the `const` qualifier here, we *do* want to + // modify the shared memory data to fix the `src` pointers. + ggml_tensor * tensor = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor)); + + // tensor->data is a pointer inside the device buffer. No need to touch it + // tensor->buffer is a pointer to a buffer. Decode the buffer handle encoded in sequence. + if (tensor->buffer) { + tensor->buffer = apir_decode_ggml_buffer(dec); + } + + if (tensor->view_src) { + ggml_tensor * tensor_view_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor)); + tensor->view_src = tensor_view_src; + } + + for (int i = 0; tensor->src[i]; i++) { + ggml_tensor * tensor_src = (ggml_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, sizeof(ggml_tensor)); + tensor->src[i] = tensor_src; // overwrite op->src[i] pointer with the actual location of the src tensor + } + + return tensor; +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h new file mode 100644 index 00000000000..4cb2f047d1e --- /dev/null +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h @@ -0,0 +1,58 @@ +#pragma once + +// clang-format off +#include "ggml.h" +#include "ggml-backend-impl.h" + +#include <unordered_map> +#include <unordered_set> +#include <vector> +#include <cstdint> +// clang-format on + +// ggml_tensor is serialized into apir_rpc_tensor +struct apir_rpc_tensor { + uint64_t id; + uint32_t type; + uint64_t buffer; + uint32_t ne[GGML_MAX_DIMS]; + uint32_t nb[GGML_MAX_DIMS]; + uint32_t op; + int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; + int32_t flags; + uint64_t src[GGML_MAX_SRC]; + uint64_t view_src; + uint64_t view_offs; + uint64_t data; + char name[GGML_MAX_NAME]; + + char padding[4]; +}; + +/* frontend */ + +apir_rpc_tensor apir_serialize_tensor(const ggml_tensor * tensor); + +void apir_serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & output); + +/* backend */ + +void apir_track_backend_buffer(ggml_backend_buffer_t buffer); +bool apir_untrack_backend_buffer(ggml_backend_buffer_t buffer); +std::unordered_set<ggml_backend_buffer_t> apir_get_track_backend_buffers(); + +void apir_add_tensor(ggml_tensor * tensor, + std::vector<apir_rpc_tensor> & tensors, + std::unordered_set<ggml_tensor *> & visited); + +ggml_tensor * apir_deserialize_tensor(ggml_context * ctx, const apir_rpc_tensor * tensor); + +ggml_tensor * apir_create_node(uint64_t id, + ggml_context * ctx, + const std::unordered_map<uint64_t, const apir_rpc_tensor *> & tensor_ptrs, + std::unordered_map<uint64_t, ggml_tensor *> & tensor_map); + +ggml_cgraph * apir_deserialize_graph(uint32_t n_nodes, + uint32_t n_tensors, + const apir_rpc_tensor * tensors, + const uint64_t * nodes); diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp new file mode 100644 index 00000000000..8fa20ff43bd --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp @@ -0,0 +1,81 @@ +#include "ggml-remoting.h" + +static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context)); + if (!context) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the buffer context ...", __func__); + } + + context->gpu = gpu; + + bool async__unused, host_buffer__unused, events__unused; + bool buffer_from_host_ptr; + apir_device_get_props(gpu, &async__unused, &host_buffer__unused, &buffer_from_host_ptr, &events__unused); + + if (buffer_from_host_ptr) { + context->apir_context = apir_device_buffer_from_ptr(gpu, size, size); + context->base = context->apir_context.shmem.mmap_ptr; + context->is_from_ptr = true; + } else { + context->apir_context = apir_buffer_type_alloc_buffer(gpu, gpu->cached_buffer_type.host_handle, size); + context->is_from_ptr = false; + context->base = NULL; + } + + ggml_backend_buffer_t buffer = + ggml_backend_buffer_init(buft, ggml_backend_remoting_buffer_interface, (void *) context, size); + + return buffer; +} + +static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + // Return the prefixed name that was built once during initialization + return gpu->cached_buffer_type.name; +} + +static size_t ggml_backend_remoting_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + return gpu->cached_buffer_type.alignment; +} + +static size_t ggml_backend_remoting_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + return gpu->cached_buffer_type.max_size; +} + +static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, + const ggml_tensor * tensor) { + virtgpu * gpu = BUFT_TO_GPU(buft); + + if (tensor->buffer == NULL || !tensor->buffer->context || + !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) { + return ggml_nbytes(tensor); + } + + return apir_buffer_type_get_alloc_size(gpu, gpu->cached_buffer_type.host_handle, tensor); +} + +const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface = { + /* .get_name = */ ggml_backend_remoting_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_remoting_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_remoting_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_remoting_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_remoting_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface = { + /* .get_name = */ ggml_backend_remoting_buffer_type_get_name, + /* .alloc_buffer = */ NULL, + /* .get_alignment = */ ggml_backend_remoting_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_remoting_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_remoting_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp new file mode 100644 index 00000000000..b6c561cd61e --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp @@ -0,0 +1,123 @@ +#include "ggml-remoting.h" + +#define BUFFER_TO_GPU(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->gpu + +static void * ggml_backend_remoting_buffer_get_base(ggml_backend_buffer_t buffer) { + ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) buffer->context; + if (context->base) { + return context->base; + } + + context->base = apir_buffer_get_base(BUFFER_TO_GPU(buffer), BUFFER_TO_APIR_CONTEXT(buffer)); + + return context->base; +} + +static void ggml_backend_remoting_buffer_set_tensor(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer); + if (context->is_from_ptr) { + memcpy((char *) tensor->data + offset, data, size); + } else { + apir_buffer_set_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size); + } + + return; +} + +static void ggml_backend_remoting_buffer_get_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer); + if (context->is_from_ptr) { + memcpy(data, (const char *) tensor->data + offset, size); + } else { + apir_buffer_get_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), tensor, data, offset, size); + } +} + +static void ggml_backend_remoting_buffer_set_tensor_from_ptr(ggml_backend_buffer_t buffer, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + UNUSED(buffer); + + memcpy((char *) tensor->data + offset, data, size); + + return; +} + +static void ggml_backend_remoting_buffer_get_tensor_from_ptr(ggml_backend_buffer_t buffer, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + UNUSED(buffer); + + memcpy(data, (const char *) tensor->data + offset, size); +} + +static bool ggml_backend_remoting_buffer_cpy_tensor(ggml_backend_buffer_t buffer, + const ggml_tensor * src, + ggml_tensor * dst) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + bool ret = apir_buffer_cpy_tensor(gpu, BUFFER_TO_APIR_CONTEXT(buffer), src, dst); + + return ret; +} + +static void ggml_backend_remoting_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + apir_buffer_clear(gpu, BUFFER_TO_APIR_CONTEXT(buffer), value); + + return; +} + +static void ggml_backend_remoting_buffer_free_buffer(ggml_backend_buffer_t buffer) { + virtgpu * gpu = BUFFER_TO_GPU(buffer); + + apir_buffer_free_buffer(gpu, BUFFER_TO_APIR_CONTEXT(buffer)); + + ggml_backend_remoting_buffer_context * context = BUFFER_TO_GGML_CONTEXT(buffer); + free(context); + buffer->context = NULL; +} + +const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface = { + /* .free_buffer = */ ggml_backend_remoting_buffer_free_buffer, + /* .get_base = */ ggml_backend_remoting_buffer_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, + /* .clear = */ ggml_backend_remoting_buffer_clear, + /* .reset = */ NULL, +}; + +const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface = { + /* .free_buffer = */ ggml_backend_remoting_buffer_free_buffer, + /* .get_base = */ ggml_backend_remoting_buffer_get_base, + /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, + /* .set_tensor = */ ggml_backend_remoting_buffer_set_tensor_from_ptr, + /* .get_tensor = */ ggml_backend_remoting_buffer_get_tensor_from_ptr, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, + /* .cpy_tensor = */ ggml_backend_remoting_buffer_cpy_tensor, + /* .clear = */ ggml_backend_remoting_buffer_clear, + /* .reset = */ NULL, +}; diff --git a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp new file mode 100644 index 00000000000..a978812cd90 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp @@ -0,0 +1,160 @@ +#include "ggml-remoting.h" + +#include <mutex> + +static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + // Return the prefixed name that was built once during initialization + return gpu->cached_device_info.name; +} + +static const char * ggml_backend_remoting_device_get_description(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + // Return the pre-cached description from the virtgpu structure + return gpu->cached_device_info.description; +} + +static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + return (enum ggml_backend_dev_type) gpu->cached_device_info.type; +} + +static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + virtgpu * gpu = DEV_TO_GPU(dev); + + *free = gpu->cached_device_info.memory_free; + *total = gpu->cached_device_info.memory_total; +} + +static bool ggml_backend_remoting_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { +#if USE_ALWAYS_TRUE_SUPPORTS_OP == 1 + /* ggml-rpc cheats it like this */ + /* with the current implementation of serialize_tensor, the src/view aren't properly passed */ + UNUSED(dev); + UNUSED(op); + + return true; +#else + virtgpu * gpu = DEV_TO_GPU(dev); + + return apir_device_supports_op(gpu, op); +#endif +} + +static bool ggml_backend_remoting_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + bool supported = buft->device == dev; + + return supported; +} + +static bool ggml_backend_remoting_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + UNUSED(dev); + UNUSED(op); + + return false; +} + +static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { + props->name = ggml_backend_remoting_device_get_name(dev); + props->description = ggml_backend_remoting_device_get_description(dev); + props->type = ggml_backend_remoting_device_get_type(dev); + ggml_backend_remoting_device_get_memory(dev, &props->memory_free, &props->memory_total); + + virtgpu * gpu = DEV_TO_GPU(dev); + apir_device_get_props(gpu, &props->caps.async, &props->caps.host_buffer, &props->caps.buffer_from_host_ptr, + &props->caps.events); + + props->caps.buffer_from_host_ptr = false; + props->caps.async = false; + props->caps.events = false; +} + +ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + static std::atomic<bool> initialized = false; + static ggml_backend_buffer_type buft; + + if (!initialized) { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + if (!initialized) { + buft = { + /* .iface = */ ggml_backend_remoting_buffer_type_interface, + /* .device = */ dev, + /* .context = */ (void *) gpu->cached_buffer_type.host_handle, + }; + initialized = true; + } + } + + return &buft; +} + +static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) { + virtgpu * gpu = DEV_TO_GPU(dev); + + static std::atomic<bool> initialized = false; + static ggml_backend_buffer_type buft; + + if (!initialized) { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + if (!initialized) { + buft = { + /* .iface = */ ggml_backend_remoting_buffer_from_ptr_type_interface, + /* .device = */ dev, + /* .context = */ (void *) gpu->cached_buffer_type.host_handle, + }; + initialized = true; + } + } + + return &buft; +} + +static ggml_backend_buffer_t ggml_backend_remoting_device_buffer_from_ptr(ggml_backend_dev_t dev, + void * ptr, + size_t size, + size_t max_tensor_size) { + virtgpu * gpu = DEV_TO_GPU(dev); + + ggml_backend_remoting_buffer_context * context = (ggml_backend_remoting_buffer_context *) malloc(sizeof(*context)); + if (!context) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the buffer context ...", __func__); + } + + context->gpu = gpu; + context->apir_context = apir_device_buffer_from_ptr(gpu, size, max_tensor_size); + context->base = ptr; + context->is_from_ptr = true; + + ggml_backend_buffer_t buffer = + ggml_backend_buffer_init(ggml_backend_remoting_device_get_buffer_from_ptr_type(dev), + ggml_backend_remoting_buffer_from_ptr_interface, (void *) context, size); + + return buffer; +} + +const ggml_backend_device_i ggml_backend_remoting_device_interface = { + /* .get_name = */ ggml_backend_remoting_device_get_name, + /* .get_description = */ ggml_backend_remoting_device_get_description, + /* .get_memory = */ ggml_backend_remoting_device_get_memory, + /* .get_type = */ ggml_backend_remoting_device_get_type, + /* .get_props = */ ggml_backend_remoting_device_get_props, + /* .init_backend = */ ggml_backend_remoting_device_init, + /* .get_buffer_type = */ ggml_backend_remoting_device_get_buffer_type, + /* .get_host_buffer_type = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_remoting_device_buffer_from_ptr, + /* .supports_op = */ ggml_backend_remoting_device_supports_op, + /* .supports_buft = */ ggml_backend_remoting_device_supports_buft, + /* .offload_op = */ ggml_backend_remoting_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; diff --git a/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp new file mode 100644 index 00000000000..a4df5956aa3 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp @@ -0,0 +1,213 @@ +#include "ggml-remoting.h" +#include "ggml-virtgpu.h" + +#include <iostream> +#include <mutex> + +void ggml_virtgpu_cleanup(virtgpu * gpu); + +static virtgpu * apir_initialize() { + static virtgpu * gpu = NULL; + static std::atomic<bool> initialized = false; + + if (initialized) { + // fast track + return gpu; + } + + { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + + if (initialized) { + // thread safe + return gpu; + } + + gpu = create_virtgpu(); + if (!gpu) { + initialized = true; + return NULL; + } + + // Pre-fetch and cache all device information, it will not change + gpu->cached_device_info.description = apir_device_get_description(gpu); + if (!gpu->cached_device_info.description) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__); + } + gpu->cached_device_info.device_count = apir_device_get_count(gpu); + gpu->cached_device_info.type = apir_device_get_type(gpu); + + { + // Get the remote name and create prefixed version + char * rmt_device_name = apir_device_get_name(gpu); + if (!rmt_device_name) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu device name", __func__); + } + + size_t device_name_len = strlen(rmt_device_name) + 11; // "[virtgpu] " + null terminator + gpu->cached_device_info.name = (char *) malloc(device_name_len); + if (!gpu->cached_device_info.name) { + free(rmt_device_name); + GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed device name", __func__); + } + snprintf(gpu->cached_device_info.name, device_name_len, "[virtgpu] %s", rmt_device_name); + free(rmt_device_name); + } + + apir_device_get_memory(gpu, &gpu->cached_device_info.memory_free, &gpu->cached_device_info.memory_total); + + apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu); + gpu->cached_buffer_type.host_handle = buft_host_handle; + { + // Get the remote name and create prefixed version + char * rmt_name = apir_buffer_type_get_name(gpu, buft_host_handle); + if (!rmt_name) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu buffer type name", __func__); + } + + size_t prefixed_len = strlen(rmt_name) + 11; // "[virtgpu] " + null terminator + gpu->cached_buffer_type.name = (char *) malloc(prefixed_len); + if (!gpu->cached_buffer_type.name) { + free(rmt_name); + GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed buffer type name", __func__); + } + snprintf(gpu->cached_buffer_type.name, prefixed_len, "[virtgpu] %s", rmt_name); + free(rmt_name); + } + + gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle); + gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle); + + initialized = true; + } + + return gpu; +} + +static int ggml_backend_remoting_get_device_count() { + virtgpu * gpu = apir_initialize(); + if (!gpu) { + return 0; + } + + return gpu->cached_device_info.device_count; +} + +static size_t ggml_backend_remoting_reg_get_device_count(ggml_backend_reg_t reg) { + UNUSED(reg); + + return ggml_backend_remoting_get_device_count(); +} + +static std::vector<ggml_backend_dev_t> devices; + +ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device) { + GGML_ASSERT(device < devices.size()); + return devices[device]; +} + +static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) { + if (devices.size() > 0) { + GGML_LOG_INFO(GGML_VIRTGPU "%s: already initialized\n", __func__); + return; + } + + virtgpu * gpu = apir_initialize(); + if (!gpu) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: apir_initialize failed\n", __func__); + return; + } + + static std::atomic<bool> initialized = false; + + if (initialized) { + return; // fast track + } + + { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + if (!initialized) { + for (int i = 0; i < ggml_backend_remoting_get_device_count(); i++) { + ggml_backend_remoting_device_context * ctx = new ggml_backend_remoting_device_context; + char desc[256] = "ggml-virtgpu API Remoting device"; + + ctx->device = i; + ctx->name = GGML_VIRTGPU_NAME + std::to_string(i); + ctx->description = desc; + ctx->gpu = gpu; + + ggml_backend_dev_t dev = new ggml_backend_device{ + /* .iface = */ ggml_backend_remoting_device_interface, + /* .reg = */ reg, + /* .context = */ ctx, + }; + devices.push_back(dev); + } + initialized = true; + } + } +} + +static ggml_backend_dev_t ggml_backend_remoting_reg_get_device(ggml_backend_reg_t reg, size_t device) { + UNUSED(reg); + + return ggml_backend_remoting_get_device(device); +} + +static const char * ggml_backend_remoting_reg_get_name(ggml_backend_reg_t reg) { + UNUSED(reg); + + return GGML_VIRTGPU_NAME; +} + +static const ggml_backend_reg_i ggml_backend_remoting_reg_i = { + /* .get_name = */ ggml_backend_remoting_reg_get_name, + /* .get_device_count = */ ggml_backend_remoting_reg_get_device_count, + /* .get_device = */ ggml_backend_remoting_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_virtgpu_reg() { + virtgpu * gpu = apir_initialize(); + if (!gpu) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: virtgpu_apir_initialize failed\n", __func__); + } + + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_remoting_reg_i, + /* .context = */ gpu, + }; + + static bool initialized = false; + if (initialized) { + return ® + } + initialized = true; + + ggml_backend_remoting_reg_init_devices(®); + + return ® +} + +// public function, not exposed in the GGML interface at the moment +void ggml_virtgpu_cleanup(virtgpu * gpu) { + if (gpu->cached_device_info.name) { + free(gpu->cached_device_info.name); + gpu->cached_device_info.name = NULL; + } + if (gpu->cached_device_info.description) { + free(gpu->cached_device_info.description); + gpu->cached_device_info.description = NULL; + } + if (gpu->cached_buffer_type.name) { + free(gpu->cached_buffer_type.name); + gpu->cached_buffer_type.name = NULL; + } + + mtx_destroy(&gpu->data_shmem_mutex); +} + +GGML_BACKEND_DL_IMPL(ggml_backend_virtgpu_reg) diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp new file mode 100644 index 00000000000..12756c9282f --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -0,0 +1,71 @@ +#include "../../include/ggml-virtgpu.h" +#include "ggml-remoting.h" + +static const char * ggml_backend_remoting_get_name(ggml_backend_t backend) { + UNUSED(backend); + + return "API Remoting backend"; +} + +static void ggml_backend_remoting_free(ggml_backend_t backend) { + delete backend; +} + +static ggml_status ggml_backend_remoting_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + virtgpu * gpu = DEV_TO_GPU(backend->device); + + return apir_backend_graph_compute(gpu, cgraph); +} + +static void ggml_backend_remoting_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { + virtgpu * gpu = DEV_TO_GPU(backend->device); +#if true + UNUSED(gpu); + UNUSED(cgraph); +#else + // not working yet + + apir_backend_graph_optimize(gpu, cgraph); +#endif +} + +static ggml_backend_i ggml_backend_remoting_interface = { + /* .get_name = */ ggml_backend_remoting_get_name, + /* .free = */ ggml_backend_remoting_free, + /* .set_tensor_async = */ NULL, // ggml_backend_remoting_set_tensor_async, + /* .get_tensor_async = */ NULL, // ggml_backend_remoting_get_tensor_async, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, // ggml_backend_remoting_cpy_tensor_async, + /* .synchronize = */ NULL, // ggml_backend_remoting_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_remoting_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ ggml_backend_remoting_graph_optimize, +}; + +static ggml_guid_t ggml_backend_remoting_guid() { + static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x14, 0x03, 0x86, 0x02, + 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; + + return &guid; +} + +ggml_backend_t ggml_backend_remoting_device_init(ggml_backend_dev_t dev, const char * params) { + UNUSED(params); + + ggml_backend_remoting_device_context * ctx = (ggml_backend_remoting_device_context *) dev->context; + + ggml_backend_t remoting_backend = new ggml_backend{ + /* .guid = */ ggml_backend_remoting_guid(), + /* .interface = */ ggml_backend_remoting_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_virtgpu_reg(), ctx->device), + /* .context = */ ctx, + }; + + return remoting_backend; +} diff --git a/ggml/src/ggml-virtgpu/ggml-remoting.h b/ggml/src/ggml-virtgpu/ggml-remoting.h new file mode 100644 index 00000000000..4f70326bee2 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggml-remoting.h @@ -0,0 +1,71 @@ +#pragma once + +#include "ggml-backend-impl.h" +#include "ggml-backend.h" +#include "ggml-impl.h" +#include "virtgpu.h" + +#include <memory> +#include <string> + +#define GGML_VIRTGPU_NAME "ggml-virtgpu" +#define GGML_VIRTGPU "ggml-virtgpu: " + +// USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoid micro-benchmark crashes + +#define USE_ALWAYS_TRUE_SUPPORTS_OP 1 +#define USE_METAL_GUEST_SUPPORTS_OP 0 + +#define DEV_TO_GPU(name) ((ggml_backend_remoting_device_context *) (name)->context)->gpu + +#define BUFFER_TO_GGML_CONTEXT(name) ((ggml_backend_remoting_buffer_context *) (name)->context) + +#define BUFFER_TO_APIR_CONTEXT(name) &((ggml_backend_remoting_buffer_context *) (name)->context)->apir_context + +#define BUFFER_TO_HOST_HANDLE(name) ((ggml_backend_remoting_buffer_context *) (name)->context)->apir_context.host_handle + +#define GET_DEVICE_CONTEXT() (ggml_backend_remoting_device_context *) ggml_backend_remoting_get_device(0)->context + +#define BUFT_TO_GPU(name) ((ggml_backend_remoting_device_context *) (name)->device->context)->gpu + +struct ggml_backend_remoting_device_context { + size_t device; + std::string name; + std::string description; + + std::vector<std::tuple<void *, size_t, virtgpu_shmem *>> shared_memory; + + virtgpu * gpu; +}; + +struct ggml_backend_remoting_buffer_context { + apir_buffer_context_t apir_context; + + virtgpu * gpu; + + void * base; + + bool is_from_ptr; +}; + +extern const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_type_interface; +extern const ggml_backend_device_i ggml_backend_remoting_device_interface; +extern const ggml_backend_buffer_i ggml_backend_remoting_buffer_interface; +extern const ggml_backend_buffer_type_i ggml_backend_remoting_buffer_from_ptr_type_interface; +extern const ggml_backend_buffer_i ggml_backend_remoting_buffer_from_ptr_interface; + +ggml_backend_dev_t ggml_backend_remoting_get_device(size_t device); +ggml_backend_t ggml_backend_remoting_device_init(ggml_backend_dev_t dev, const char * params); +ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev); + +static inline apir_buffer_type_host_handle_t ggml_buffer_type_to_apir_handle(ggml_backend_buffer_type_t buft) { + // in the backend, the buffer handle is the buffer pointer + return (apir_buffer_type_host_handle_t) buft->context; +} + +static inline apir_buffer_host_handle_t ggml_buffer_to_apir_handle(ggml_backend_buffer_t buffer) { + if (!buffer->context) { + GGML_ABORT(GGML_VIRTGPU "%s: no context available :/", __func__); + } + return BUFFER_TO_HOST_HANDLE(buffer); +} diff --git a/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml b/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml new file mode 100644 index 00000000000..14ef2433e46 --- /dev/null +++ b/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml @@ -0,0 +1,166 @@ +# YAML schema for GGML remoting API functions +# This defines the structure for generating the remoting layer code + +# Configuration for the generated files +config: + # Base path for the generated files + base_path: "ggml/src" + + # Header files to update + files: + apir_backend_header: "ggml-virtgpu-apir/backend/shared/apir_backend.gen.h" + backend_dispatched_header: "ggml-virtgpu-apir/backend/backend-dispatched.gen.h" + virtgpu_forward_header: "ggml-virtgpu-apir/virtgpu-forward.gen.h" + +# Simplified function definitions with grouping and metadata combined +functions: + device: + group_description: "device" + functions: + get_device_count: + # No specific metadata - uses default void return and base params + + get_count: + frontend_return: "int" + + get_name: + frontend_return: "char *" + + get_description: + frontend_return: "char *" + + get_type: + frontend_return: "uint32_t" + + get_memory: + frontend_return: "void" + frontend_extra_params: + - "size_t *free" + - "size_t *total" + + supports_op: + frontend_return: "bool" + frontend_extra_params: + - "const ggml_tensor *op" + + get_buffer_type: + frontend_return: "apir_buffer_type_host_handle_t" + + get_props: + frontend_return: "void" + frontend_extra_params: + - "bool *async" + - "bool *host_buffer" + - "bool *buffer_from_host_ptr" + - "bool *events" + + buffer_from_ptr: + frontend_return: "apir_buffer_context_t" + frontend_extra_params: + - "size_t size" + - "size_t max_tensor_size" + + buffer_type: + group_description: "buffer-type" + functions: + get_name: + frontend_return: "char *" + frontend_extra_params: + - "apir_buffer_type_host_handle_t host_handle" + + get_alignment: + frontend_return: "size_t" + frontend_extra_params: + - "apir_buffer_type_host_handle_t host_handle" + + get_max_size: + frontend_return: "size_t" + frontend_extra_params: + - "apir_buffer_type_host_handle_t host_handle" + + is_host: + deprecated: true + + alloc_buffer: + frontend_return: "apir_buffer_context_t" + frontend_extra_params: + - "apir_buffer_type_host_handle_t host_handle" + - "size_t size" + + get_alloc_size: + frontend_return: "size_t" + frontend_extra_params: + - "apir_buffer_type_host_handle_t host_handle" + - "const ggml_tensor *op" + + buffer: + group_description: "buffer" + functions: + get_base: + frontend_return: "void *" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + + set_tensor: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "ggml_tensor *tensor" + - "const void *data" + - "size_t offset" + - "size_t size" + + get_tensor: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "const ggml_tensor *tensor" + - "void *data" + - "size_t offset" + - "size_t size" + + cpy_tensor: + frontend_return: "bool" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "const ggml_tensor *src" + - "const ggml_tensor *dst" + + clear: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + - "uint8_t value" + + free_buffer: + frontend_return: "void" + frontend_extra_params: + - "apir_buffer_context_t *buffer_context" + + backend: + group_description: "backend" + functions: + graph_compute: + frontend_return: "ggml_status" + frontend_extra_params: + - "ggml_cgraph *cgraph" + + graph_optimize: + frontend_return: "ggml_cgraph *" + frontend_extra_params: + - "ggml_cgraph *cgraph" + enabled: false + +# Naming patterns used for code generation +naming_patterns: + # How to generate enum names + enum_prefix: "APIR_COMMAND_TYPE_" + + # How to generate backend function names + backend_function_prefix: "backend_" + + # How to generate frontend function names + frontend_function_prefix: "apir_" + + # Standard frontend first parameter + frontend_base_param: "struct virtgpu *gpu" diff --git a/ggml/src/ggml-virtgpu/include/apir_hw.h b/ggml/src/ggml-virtgpu/include/apir_hw.h new file mode 100644 index 00000000000..7d6ea2265db --- /dev/null +++ b/ggml/src/ggml-virtgpu/include/apir_hw.h @@ -0,0 +1,9 @@ +#pragma once + +#include <stdint.h> + +struct virgl_renderer_capset_apir { + uint32_t apir_version; + uint32_t supports_blob_resources; + uint32_t reserved[4]; // For future expansion +}; diff --git a/ggml/src/ggml-virtgpu/regenerate_remoting.py b/ggml/src/ggml-virtgpu/regenerate_remoting.py new file mode 100755 index 00000000000..dae75fd1c80 --- /dev/null +++ b/ggml/src/ggml-virtgpu/regenerate_remoting.py @@ -0,0 +1,333 @@ +#!/usr/bin/env python3 +""" +# Generated by Claude AI + +Script to completely regenerate the GGML remoting codebase from YAML configuration. + +This script reads api_functions.yaml and regenerates all the header files and +implementation templates for the GGML remoting layer. + +Usage: + python regenerate_remoting.py + +The script will: +1. Read ggmlremoting_functions.yaml configuration +2. Generate updated header files +3. Generate implementation templates in dedicated files +4. Show a summary of what was generated +""" + +import yaml +from typing import Dict, List, Any +from pathlib import Path +import os +import subprocess +import shutil +import logging + +NL = '\n' # can't have f"{'\n'}" in f-strings + + +class RemotingCodebaseGenerator: + def __init__(self, yaml_path: str = "ggmlremoting_functions.yaml"): + """Initialize the generator with the YAML configuration.""" + self.yaml_path = yaml_path + + if not Path(yaml_path).exists(): + raise FileNotFoundError(f"Configuration file {yaml_path} not found") + + with open(yaml_path, 'r') as f: + self.config = yaml.safe_load(f) + + self.functions = self.config['functions'] + self.naming_patterns = self.config['naming_patterns'] + self.config_data = self.config['config'] + + # Check if clang-format is available + self.clang_format_available = self._check_clang_format_available() + + def _check_clang_format_available(self) -> bool: + """Check if clang-format is available in the system PATH.""" + return shutil.which("clang-format") is not None + + def _format_file_with_clang_format(self, file_path: Path) -> bool: + """Format a file with clang-format -i. Returns True if successful, False otherwise.""" + if not self.clang_format_available: + return False + + try: + subprocess.run( + ["clang-format", "-i", str(file_path)], + check=True, + capture_output=True, + text=True + ) + return True + except subprocess.CalledProcessError: + logging.exception(f" ⚠️ clang-format failed for {file_path}") + return False + except Exception as e: + logging.exception(f" ⚠️ Unexpected error formatting {file_path}: {e}") + return False + + def generate_enum_name(self, group_name: str, function_name: str) -> str: + """Generate the APIR_COMMAND_TYPE enum name for a function.""" + prefix = self.naming_patterns['enum_prefix'] + return f"{prefix}{group_name.upper()}_{function_name.upper()}" + + def generate_backend_function_name(self, group_name: str, function_name: str) -> str: + """Generate the backend function name.""" + function_key = f"{group_name}_{function_name}" + overrides = self.naming_patterns.get('backend_function_overrides', {}) + + if function_key in overrides: + return overrides[function_key] + + prefix = self.naming_patterns['backend_function_prefix'] + return f"{prefix}{group_name}_{function_name}" + + def generate_frontend_function_name(self, group_name: str, function_name: str) -> str: + """Generate the frontend function name.""" + prefix = self.naming_patterns['frontend_function_prefix'] + return f"{prefix}{group_name}_{function_name}" + + def get_enabled_functions(self) -> List[Dict[str, Any]]: + """Get all enabled functions with their metadata.""" + functions = [] + enum_value = 0 + + for group_name, group_data in self.functions.items(): + group_description = group_data['group_description'] + + for function_name, func_metadata in group_data['functions'].items(): + # Handle case where func_metadata is None or empty (functions with only comments) + if func_metadata is None: + func_metadata = {} + + # Functions are enabled by default unless explicitly disabled + if func_metadata.get('enabled', True): + functions.append({ + 'group_name': group_name, + 'function_name': function_name, + 'enum_name': self.generate_enum_name(group_name, function_name), + 'enum_value': enum_value, + 'backend_function': self.generate_backend_function_name(group_name, function_name), + 'frontend_function': self.generate_frontend_function_name(group_name, function_name), + 'frontend_return': func_metadata.get('frontend_return', 'void'), + 'frontend_extra_params': func_metadata.get('frontend_extra_params', []), + 'group_description': group_description, + 'deprecated': func_metadata.get('deprecated', False), + }) + enum_value += 1 + + return functions + + def generate_apir_backend_header(self) -> str: + """Generate the complete apir_backend.h file.""" + functions = self.get_enabled_functions() + + # Generate the enum section + enum_lines = ["typedef enum ApirBackendCommandType {"] + current_group = None + + for func in functions: + # Add comment for new group + if func['group_name'] != current_group: + enum_lines.append("") + enum_lines.append(f" /* {func['group_description']} */") + current_group = func['group_name'] + + enum_lines.append(f" {func['enum_name']} = {func['enum_value']},") + + # Add the count + total_count = len(functions) + enum_lines.append("\n // last command_type index + 1") + enum_lines.append(f" APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},") + enum_lines.append("} ApirBackendCommandType;") + + # Generate function name mapping + func_lines = [] + func_lines.append("static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {") + func_lines.append(" switch (type) {") + + current_group = None + for func in functions: + # Add comment for new group + if func['group_name'] != current_group: + func_lines.append(f" /* {func['group_description']} */") + current_group = func['group_name'] + + # Generate clean function name without backend_ prefix + clean_name = f"{func['group_name']}_{func['function_name']}" + func_lines.append(f" case {func['enum_name']}:") + func_lines.append(f" return \"{clean_name}\";") + + func_lines.append("") + func_lines.append(" default:") + func_lines.append(" return \"unknown\";") + func_lines.append(" }") + func_lines.append("}") + + # Full header template + header_content = NL.join(enum_lines) + "\n\n" + NL.join(func_lines) + "\n" + + return header_content + + def generate_backend_dispatched_header(self) -> str: + """Generate the complete backend-dispatched.h file.""" + functions = self.get_enabled_functions() + + # Function declarations + decl_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + decl_lines.append(f"\n/* {func['group_description']} */") + current_group = func['group_name'] + + signature = "uint32_t" + params = "apir_encoder *enc, apir_decoder *dec, virgl_apir_context *ctx" + if func['deprecated']: + decl_lines.append(f"/* {func['enum_name']} is deprecated. Keeping the handler for backward compatibility. */") + + decl_lines.append(f"{signature} {func['backend_function']}({params});") + + # Dispatch table + table_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + table_lines.append(f"\n /* {func['group_description']} */") + table_lines.append("") + current_group = func['group_name'] + + deprecated = " /* DEPRECATED */" if func['deprecated'] else "" + table_lines.append(f" /* {func['enum_name']} = */ {func['backend_function']}{deprecated},") + + header_content = f'''\ +#pragma once + +{NL.join(decl_lines)} + +extern "C" {{ +static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{ + {NL.join(table_lines)} +}}; +}} +''' + return header_content + + def generate_virtgpu_forward_header(self) -> str: + """Generate the complete virtgpu-forward.gen.h file.""" + functions = self.get_enabled_functions() + + decl_lines = [] + current_group = None + + for func in functions: + if func['group_name'] != current_group: + decl_lines.append("") + decl_lines.append(f"/* {func['group_description']} */") + current_group = func['group_name'] + + if func['deprecated']: + decl_lines.append(f"/* {func['frontend_function']} is deprecated. */") + continue + + # Build parameter list + params = [self.naming_patterns['frontend_base_param']] + params.extend(func['frontend_extra_params']) + param_str = ', '.join(params) + + decl_lines.append(f"{func['frontend_return']} {func['frontend_function']}({param_str});") + + header_content = f'''\ +#pragma once +{NL.join(decl_lines)} +''' + return header_content + + def regenerate_codebase(self) -> None: + """Regenerate the entire remoting codebase.""" + logging.info("🔄 Regenerating GGML Remoting Codebase...") + logging.info("=" * 50) + + # Detect if we're running from frontend directory + current_dir = os.getcwd() + is_frontend_dir = current_dir.endswith('ggml-virtgpu') + + if is_frontend_dir: + # Running from ggml/src/ggml-virtgpu-apir + logging.info("📍 Detected frontend directory execution") + frontend_base = Path(".") + else: + # Running from project root (fallback to original behavior) + logging.info("📍 Detected project root execution") + base_path = self.config_data.get('base_path', 'ggml/src') + frontend_base = Path(base_path) / "ggml-virtgpu" + + # Compute final file paths + backend_base = frontend_base / "backend" + apir_backend_path = backend_base / "shared" / "apir_backend.gen.h" + backend_dispatched_path = backend_base / "backend-dispatched.gen.h" + virtgpu_forward_path = frontend_base / "virtgpu-forward.gen.h" + + # Create output directories for each file + apir_backend_path.parent.mkdir(parents=True, exist_ok=True) + backend_dispatched_path.parent.mkdir(parents=True, exist_ok=True) + virtgpu_forward_path.parent.mkdir(parents=True, exist_ok=True) + + # Generate header files + logging.info("📁 Generating header files...") + + apir_backend_content = self.generate_apir_backend_header() + apir_backend_path.write_text(apir_backend_content) + logging.info(f" ✅ {apir_backend_path.resolve()}") + + backend_dispatched_content = self.generate_backend_dispatched_header() + backend_dispatched_path.write_text(backend_dispatched_content) + logging.info(f" ✅ {backend_dispatched_path.resolve()}") + + virtgpu_forward_content = self.generate_virtgpu_forward_header() + virtgpu_forward_path.write_text(virtgpu_forward_content) + logging.info(f" ✅ {virtgpu_forward_path.resolve()}") + + # Format generated files with clang-format + generated_files = [apir_backend_path, backend_dispatched_path, virtgpu_forward_path] + + if not self.clang_format_available: + logging.warning("\n⚠️clang-format not found in PATH. Generated files will not be formatted.\n" + " Install clang-format to enable automatic code formatting.") + else: + logging.info("\n🎨 Formatting files with clang-format...") + for file_path in generated_files: + if self._format_file_with_clang_format(file_path): + logging.info(f" ✅ Formatted {file_path.name}") + else: + logging.warning(f" ❌ Failed to format {file_path.name}") + + # Generate summary + functions = self.get_enabled_functions() + total_functions = len(functions) + + logging.info("\n📊 Generation Summary:") + logging.info("=" * 50) + logging.info(f" Total functions: {total_functions}") + logging.info(f" Function groups: {len(self.functions)}") + logging.info(" Header files: 3") + logging.info(f" Working directory: {current_dir}") + + +def main(): + try: + generator = RemotingCodebaseGenerator() + generator.regenerate_codebase() + except Exception as e: + logging.exception(f"❌ Error: {e}") + exit(1) + + +if __name__ == "__main__": + main() diff --git a/ggml/src/ggml-virtgpu/virtgpu-apir.h b/ggml/src/ggml-virtgpu/virtgpu-apir.h new file mode 100644 index 00000000000..238f960acd2 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-apir.h @@ -0,0 +1,15 @@ +#include "backend/shared/apir_backend.h" +#include "ggml-alloc.h" +#include "ggml-impl.h" +#include "ggml.h" +#include "virtgpu-shm.h" +#include "virtgpu-utils.h" + +struct apir_buffer_context_t { + apir_buffer_host_handle_t host_handle; + + struct virtgpu_shmem shmem; + apir_buffer_type_host_handle_t buft_host_handle; +}; + +#include "virtgpu-forward.gen.h" diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp new file mode 100644 index 00000000000..4593690c638 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp @@ -0,0 +1,58 @@ +#include "virtgpu-forward-impl.h" + +static long long current_time_ms() { + timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); // Use CLOCK_MONOTONIC for elapsed time + return (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec; +} + +ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE); + + std::vector<uint8_t> cgraph_data; + size_t cgraph_size = apir_serialize_ggml_cgraph(cgraph, cgraph_data); + + virtgpu_shmem temp_shmem; // Local storage for large buffers + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; + + if (cgraph_size <= gpu->data_shmem.mmap_size) { + // Lock mutex before using shared data_shmem buffer + if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) { + GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); + } + using_shared_shmem = true; + shmem = &gpu->data_shmem; + } else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); + } + + apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); + + apir_encode_size_t(encoder, &cgraph_size); + + char * shmem_data = (char *) shmem->mmap_ptr; + apir_encoder secondary_enc = apir_new_encoder(shmem_data, cgraph_size); + + apir_encode_cgraph_data(&secondary_enc, cgraph_data); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + ggml_status status = GGML_STATUS_ABORTED; + apir_decode_ggml_status(decoder, &status); + + remote_call_finish(gpu, encoder, decoder); + + // Unlock mutex before cleanup + if (using_shared_shmem) { + mtx_unlock(&gpu->data_shmem_mutex); + } else { + virtgpu_shmem_destroy(gpu, shmem); + } + + return status; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp new file mode 100644 index 00000000000..38f8ec945e0 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp @@ -0,0 +1,110 @@ +#include "virtgpu-forward-impl.h" + +char * apir_buffer_type_get_name(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME); + + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + const size_t string_size = apir_decode_array_size_unchecked(decoder); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + if (!string) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__); + apir_decoder_set_fatal(decoder); + } + apir_decode_char_array(decoder, string, string_size); + + remote_call_finish(gpu, encoder, decoder); + + return string; +} + +size_t apir_buffer_type_get_alignment(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT); + + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + size_t alignment; + apir_decode_size_t(decoder, &alignment); + + remote_call_finish(gpu, encoder, decoder); + + return alignment; +} + +size_t apir_buffer_type_get_max_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE); + + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + size_t max_size; + apir_decode_size_t(decoder, &max_size); + + remote_call_finish(gpu, encoder, decoder); + + return max_size; +} + +apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + size_t size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + apir_buffer_context_t buffer_context; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER); + + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); + + apir_encode_size_t(encoder, &size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_apir_buffer_host_handle_t(decoder, &buffer_context.host_handle); + + remote_call_finish(gpu, encoder, decoder); + + return buffer_context; +} + +size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + const ggml_tensor * op) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE); + + apir_encode_apir_buffer_type_host_handle(encoder, host_handle); + + apir_encode_ggml_tensor_inline(encoder, op); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + size_t alloc_size; + apir_decode_size_t(decoder, &alloc_size); + + remote_call_finish(gpu, encoder, decoder); + + return alloc_size; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp new file mode 100644 index 00000000000..228284f4a42 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp @@ -0,0 +1,173 @@ +#include "virtgpu-forward-impl.h" + +void * apir_buffer_get_base(virtgpu * gpu, apir_buffer_context_t * buffer_context) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_GET_BASE); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + uintptr_t base; + apir_decode_uintptr_t(decoder, &base); + + remote_call_finish(gpu, encoder, decoder); + + return (void *) base; +} + +void apir_buffer_set_tensor(virtgpu * gpu, + apir_buffer_context_t * buffer_context, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_SET_TENSOR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_ggml_tensor(encoder, tensor); + + virtgpu_shmem temp_shmem; // Local storage for large buffers + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; + + if (size <= gpu->data_shmem.mmap_size) { + // Lock mutex before using shared data_shmem buffer + if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) { + GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); + } + using_shared_shmem = true; + shmem = &gpu->data_shmem; + + } else if (virtgpu_shmem_create(gpu, size, shmem)) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); + } + + memcpy(shmem->mmap_ptr, data, size); + apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); + + apir_encode_size_t(encoder, &offset); + apir_encode_size_t(encoder, &size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + remote_call_finish(gpu, encoder, decoder); + + // Unlock mutex before cleanup + if (using_shared_shmem) { + mtx_unlock(&gpu->data_shmem_mutex); + } else { + virtgpu_shmem_destroy(gpu, shmem); + } + + return; +} + +void apir_buffer_get_tensor(virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_GET_TENSOR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_ggml_tensor(encoder, tensor); + + virtgpu_shmem temp_shmem; // Local storage for large buffers + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; + + if (size <= gpu->data_shmem.mmap_size) { + // Lock mutex before using shared data_shmem buffer + if (mtx_lock(&gpu->data_shmem_mutex) != thrd_success) { + GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); + } + using_shared_shmem = true; + shmem = &gpu->data_shmem; + + } else if (virtgpu_shmem_create(gpu, size, shmem)) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); + } + + apir_encode_virtgpu_shmem_res_id(encoder, shmem->res_id); + apir_encode_size_t(encoder, &offset); + apir_encode_size_t(encoder, &size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + memcpy(data, shmem->mmap_ptr, size); + + remote_call_finish(gpu, encoder, decoder); + + // Unlock mutex before cleanup + if (using_shared_shmem) { + mtx_unlock(&gpu->data_shmem_mutex); + } else { + virtgpu_shmem_destroy(gpu, shmem); + } +} + +bool apir_buffer_cpy_tensor(virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * src, + const ggml_tensor * dst) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_ggml_tensor(encoder, src); + apir_encode_ggml_tensor(encoder, dst); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + bool ret_val; + apir_decode_bool_t(decoder, &ret_val); + + remote_call_finish(gpu, encoder, decoder); + + return ret_val; +} + +void apir_buffer_clear(virtgpu * gpu, apir_buffer_context_t * buffer_context, uint8_t value) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_CLEAR); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + apir_encode_uint8_t(encoder, &value); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + remote_call_finish(gpu, encoder, decoder); +} + +void apir_buffer_free_buffer(virtgpu * gpu, apir_buffer_context_t * buffer_context) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER); + + apir_encode_apir_buffer_host_handle_t(encoder, &buffer_context->host_handle); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + remote_call_finish(gpu, encoder, decoder); +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp new file mode 100644 index 00000000000..9f513c138dd --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp @@ -0,0 +1,192 @@ +#include "virtgpu-forward-impl.h" +#include "virtgpu-shm.h" + +int apir_device_get_count(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_COUNT); + REMOTE_CALL(gpu, encoder, decoder, ret); + + int32_t dev_count = -1; + apir_decode_int32_t(decoder, &dev_count); + + remote_call_finish(gpu, encoder, decoder); + + return dev_count; +} + +char * apir_device_get_name(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_NAME); + REMOTE_CALL(gpu, encoder, decoder, ret); + + const size_t string_size = apir_decode_array_size_unchecked(decoder); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + if (!string) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__); + return NULL; + } + apir_decode_char_array(decoder, string, string_size); + + remote_call_finish(gpu, encoder, decoder); + + return string; +} + +char * apir_device_get_description(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + const size_t string_size = apir_decode_array_size_unchecked(decoder); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + if (!string) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device description buffer\n", __func__); + + return NULL; + } + apir_decode_char_array(decoder, string, string_size); + + remote_call_finish(gpu, encoder, decoder); + + return string; +} + +uint32_t apir_device_get_type(virtgpu * gpu) { + static uint32_t dev_type = 255; + if (dev_type != 255) { + return dev_type; + } + + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_TYPE); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_uint32_t(decoder, &dev_type); + + remote_call_finish(gpu, encoder, decoder); + + return dev_type; +} + +void apir_device_get_memory(virtgpu * gpu, size_t * free, size_t * total) { + static size_t dev_free = 0; + static size_t dev_total = 0; + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_MEMORY); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_size_t(decoder, &dev_free); + apir_decode_size_t(decoder, &dev_total); + + *free = dev_free; + *total = dev_total; + + remote_call_finish(gpu, encoder, decoder); + + return; +} + +bool apir_device_supports_op(virtgpu * gpu, const ggml_tensor * op) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP); + + apir_encode_ggml_tensor_inline(encoder, op); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + bool supports_op; + apir_decode_bool_t(decoder, &supports_op); + + remote_call_finish(gpu, encoder, decoder); + + return supports_op; +} + +apir_buffer_type_host_handle_t apir_device_get_buffer_type(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_buffer_type_host_handle_t buft_handle; + apir_decode_apir_buffer_type_host_handle_t(decoder, &buft_handle); + + remote_call_finish(gpu, encoder, decoder); + + return buft_handle; +} + +void apir_device_get_props(virtgpu * gpu, + bool * async, + bool * host_buffer, + bool * buffer_from_host_ptr, + bool * events) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_GET_PROPS); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_bool_t(decoder, async); + apir_decode_bool_t(decoder, host_buffer); + apir_decode_bool_t(decoder, buffer_from_host_ptr); + apir_decode_bool_t(decoder, events); + + remote_call_finish(gpu, encoder, decoder); + + return; +} + +apir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, size_t max_tensor_size) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirForwardReturnCode ret; + + apir_buffer_context_t buffer_context; + + REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR); + + if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) { + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate %ldb of guest-host shared buffer", __func__, size); + } + + apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id); + + apir_encode_size_t(encoder, &size); + apir_encode_size_t(encoder, &max_tensor_size); + + REMOTE_CALL(gpu, encoder, decoder, ret); + + apir_decode_apir_buffer_host_handle_t(decoder, &buffer_context.host_handle); + buffer_context.buft_host_handle = apir_decode_apir_buffer_type_host_handle(decoder); + + remote_call_finish(gpu, encoder, decoder); + + return buffer_context; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h new file mode 100644 index 00000000000..4d0b6e05c74 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h @@ -0,0 +1,36 @@ +#pragma once + +// clang-format off +#include "virtgpu.h" +#include "ggml-remoting.h" +#include "backend/shared/apir_backend.h" +#include "backend/shared/apir_cs_ggml.h" +#include "ggml-backend-impl.h" +// clang-format on + +#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \ + int32_t REMOTE_CALL_PREPARE_forward_flag = (int32_t) apir_command_type__; \ + const char * REMOTE_CALL_PREPARE_command_name = apir_dispatch_command_name(apir_command_type__); \ + do { \ + encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, REMOTE_CALL_PREPARE_forward_flag); \ + if (!encoder_name) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \ + } \ + } while (0) + +#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \ + do { \ + ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \ + if (!decoder_name) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \ + } \ + if (ret_name < APIR_FORWARD_BASE_INDEX) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \ + apir_forward_error(ret_name), ret_name); \ + } \ + ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \ + if (ret_name != 0) { \ + GGML_ABORT(GGML_VIRTGPU "backend function '%s' failed (return code: %d)", \ + REMOTE_CALL_PREPARE_command_name, ret_name); \ + } \ + } while (0) diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h new file mode 100644 index 00000000000..44b0ad1ffa1 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h @@ -0,0 +1,53 @@ +#pragma once + +/* device */ +void apir_device_get_device_count(struct virtgpu * gpu); +int apir_device_get_count(struct virtgpu * gpu); +char * apir_device_get_name(struct virtgpu * gpu); +char * apir_device_get_description(struct virtgpu * gpu); +uint32_t apir_device_get_type(struct virtgpu * gpu); +void apir_device_get_memory(struct virtgpu * gpu, size_t * free, size_t * total); +bool apir_device_supports_op(struct virtgpu * gpu, const ggml_tensor * op); +apir_buffer_type_host_handle_t apir_device_get_buffer_type(struct virtgpu * gpu); +void apir_device_get_props(struct virtgpu * gpu, + bool * async, + bool * host_buffer, + bool * buffer_from_host_ptr, + bool * events); +apir_buffer_context_t apir_device_buffer_from_ptr(struct virtgpu * gpu, size_t size, size_t max_tensor_size); + +/* buffer-type */ +char * apir_buffer_type_get_name(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +size_t apir_buffer_type_get_alignment(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +size_t apir_buffer_type_get_max_size(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +/* apir_buffer_type_is_host is deprecated. */ +apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + size_t size); +size_t apir_buffer_type_get_alloc_size(struct virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + const ggml_tensor * op); + +/* buffer */ +void * apir_buffer_get_base(struct virtgpu * gpu, apir_buffer_context_t * buffer_context); +void apir_buffer_set_tensor(struct virtgpu * gpu, + apir_buffer_context_t * buffer_context, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size); +void apir_buffer_get_tensor(struct virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * tensor, + void * data, + size_t offset, + size_t size); +bool apir_buffer_cpy_tensor(struct virtgpu * gpu, + apir_buffer_context_t * buffer_context, + const ggml_tensor * src, + const ggml_tensor * dst); +void apir_buffer_clear(struct virtgpu * gpu, apir_buffer_context_t * buffer_context, uint8_t value); +void apir_buffer_free_buffer(struct virtgpu * gpu, apir_buffer_context_t * buffer_context); + +/* backend */ +ggml_status apir_backend_graph_compute(struct virtgpu * gpu, ggml_cgraph * cgraph); diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.cpp b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp new file mode 100644 index 00000000000..7f2c2322d91 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-shm.cpp @@ -0,0 +1,99 @@ +#include "virtgpu-shm.h" + +#include "virtgpu.h" +#include "ggml-remoting.h" + +#include <assert.h> + +static uint32_t virtgpu_ioctl_resource_create_blob(virtgpu * gpu, + uint32_t blob_mem, + uint32_t blob_flags, + size_t blob_size, + uint64_t blob_id, + uint32_t * res_id) { +#ifdef SIMULATE_BO_SIZE_FIX + blob_size = align64(blob_size, 4096); +#endif + + drm_virtgpu_resource_create_blob args = { + .blob_mem = blob_mem, + .blob_flags = blob_flags, + .bo_handle = 0, + .res_handle = 0, + .size = blob_size, + .pad = 0, + .cmd_size = 0, + .cmd = 0, + .blob_id = blob_id, + }; + + if (virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_RESOURCE_CREATE_BLOB, &args)) { + return 0; + } + + *res_id = args.res_handle; + return args.bo_handle; +} + +static void virtgpu_ioctl_gem_close(virtgpu * gpu, uint32_t gem_handle) { + drm_gem_close args = { + .handle = gem_handle, + .pad = 0, + }; + + const int ret = virtgpu_ioctl(gpu, DRM_IOCTL_GEM_CLOSE, &args); + assert(!ret); +#ifdef NDEBUG + UNUSED(ret); +#endif +} + +static void * virtgpu_ioctl_map(virtgpu * gpu, uint32_t gem_handle, size_t size) { + drm_virtgpu_map args = { + .offset = 0, + .handle = gem_handle, + .pad = 0, + }; + + if (virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_MAP, &args)) { + return NULL; + } + + void * ptr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, gpu->fd, args.offset); + if (ptr == MAP_FAILED) { + return NULL; + } + + return ptr; +} + +void virtgpu_shmem_destroy(virtgpu * gpu, virtgpu_shmem * shmem) { + munmap(shmem->mmap_ptr, shmem->mmap_size); + virtgpu_ioctl_gem_close(gpu, shmem->gem_handle); +} + +int virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem) { + size = align64(size, 16384); + + uint32_t res_id; + uint32_t gem_handle = virtgpu_ioctl_resource_create_blob(gpu, VIRTGPU_BLOB_MEM_HOST3D, + VIRTGPU_BLOB_FLAG_USE_MAPPABLE, size, 0, &res_id); + + if (!gem_handle) { + return 1; + } + + void * ptr = virtgpu_ioctl_map(gpu, gem_handle, size); + if (!ptr) { + virtgpu_ioctl_gem_close(gpu, gem_handle); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: virtgpu_ioctl_map failed\n", __func__); + return 1; + } + + shmem->res_id = res_id; + shmem->mmap_size = size; + shmem->mmap_ptr = ptr; + shmem->gem_handle = gem_handle; + + return 0; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-shm.h b/ggml/src/ggml-virtgpu/virtgpu-shm.h new file mode 100644 index 00000000000..606860a0946 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-shm.h @@ -0,0 +1,23 @@ +#pragma once + +#include "virtgpu-utils.h" + +#include <sys/mman.h> + +#include <atomic> +#include <cassert> +#include <cstddef> +#include <cstdint> + +struct virtgpu; + +struct virtgpu_shmem { + uint32_t res_id; + size_t mmap_size; + void * mmap_ptr; + + uint32_t gem_handle; +}; + +int virtgpu_shmem_create(virtgpu * gpu, size_t size, virtgpu_shmem * shmem); +void virtgpu_shmem_destroy(virtgpu * gpu, virtgpu_shmem * shmem); diff --git a/ggml/src/ggml-virtgpu/virtgpu-utils.cpp b/ggml/src/ggml-virtgpu/virtgpu-utils.cpp new file mode 100644 index 00000000000..8a2805e9902 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-utils.cpp @@ -0,0 +1,179 @@ +#include "virtgpu-utils.h" + +#include <malloc.h> +#include <stdlib.h> + +#include <cstring> + +#define NODE_ALLOC_ALIGN 64 +#define NODE_PTR_MASK (~((uintptr_t) NODE_ALLOC_ALIGN - 1)) +#define NODE_LEVEL_MASK ((uintptr_t) NODE_ALLOC_ALIGN - 1) +#define NULL_NODE 0 + +#define os_malloc_aligned(_size, _align) _aligned_malloc(_size, _align) +#define os_free_aligned(_ptr) free(_ptr) +#define p_atomic_cmpxchg(v, old, _new) __sync_val_compare_and_swap((v), (old), (_new)) + +static inline uint64_t util_logbase2_64(uint64_t n) { +#if defined(HAVE___BUILTIN_CLZLL) + return ((sizeof(uint64_t) * 8 - 1) - __builtin_clzll(n | 1)); +#else + uint64_t pos = 0ull; + if (n >= 1ull << 32) { + n >>= 32; + pos += 32; + } + if (n >= 1ull << 16) { + n >>= 16; + pos += 16; + } + if (n >= 1ull << 8) { + n >>= 8; + pos += 8; + } + if (n >= 1ull << 4) { + n >>= 4; + pos += 4; + } + if (n >= 1ull << 2) { + n >>= 2; + pos += 2; + } + if (n >= 1ull << 1) { + pos += 1; + } + return pos; +#endif +} + +void util_sparse_array_init(util_sparse_array * arr, size_t elem_size, size_t node_size) { + memset(arr, 0, sizeof(*arr)); + arr->elem_size = elem_size; + arr->node_size_log2 = util_logbase2_64(node_size); + assert(node_size >= 2 && node_size == (1ull << arr->node_size_log2)); +} + +static inline void * os_malloc_aligned(size_t size, size_t alignment) { + void * ptr; + alignment = (alignment + sizeof(void *) - 1) & ~(sizeof(void *) - 1); + if (posix_memalign(&ptr, alignment, size) != 0) { + return NULL; + } + return ptr; +} + +static inline void * _util_sparse_array_node_data(uintptr_t handle) { + return (void *) (handle & NODE_PTR_MASK); +} + +static inline unsigned _util_sparse_array_node_level(uintptr_t handle) { + return handle & NODE_LEVEL_MASK; +} + +static inline void _util_sparse_array_node_finish(util_sparse_array * arr, uintptr_t node) { + if (_util_sparse_array_node_level(node) > 0) { + uintptr_t * children = (uintptr_t *) _util_sparse_array_node_data(node); + size_t node_size = 1ull << arr->node_size_log2; + for (size_t i = 0; i < node_size; i++) { + if (children[i]) { + _util_sparse_array_node_finish(arr, children[i]); + } + } + } + + os_free_aligned(_util_sparse_array_node_data(node)); +} + +static inline uintptr_t _util_sparse_array_node(void * data, unsigned level) { + assert(data != NULL); + assert(((uintptr_t) data & NODE_LEVEL_MASK) == 0); + assert((level & NODE_PTR_MASK) == 0); + return (uintptr_t) data | level; +} + +inline uintptr_t _util_sparse_array_node_alloc(util_sparse_array * arr, unsigned level) { + size_t size; + if (level == 0) { + size = arr->elem_size << arr->node_size_log2; + } else { + size = sizeof(uintptr_t) << arr->node_size_log2; + } + + void * data = os_malloc_aligned(size, NODE_ALLOC_ALIGN); + memset(data, 0, size); + + return _util_sparse_array_node(data, level); +} + +static inline uintptr_t _util_sparse_array_set_or_free_node(uintptr_t * node_ptr, uintptr_t cmp_node, uintptr_t node) { + uintptr_t prev_node = p_atomic_cmpxchg(node_ptr, cmp_node, node); + + if (prev_node != cmp_node) { + /* We lost the race. Free this one and return the one that was already + * allocated. + */ + os_free_aligned(_util_sparse_array_node_data(node)); + return prev_node; + } else { + return node; + } +} + +void * util_sparse_array_get(util_sparse_array * arr, uint64_t idx) { + const unsigned node_size_log2 = arr->node_size_log2; + uintptr_t root = p_atomic_read(&arr->root); + if (unlikely(!root)) { + unsigned root_level = 0; + uint64_t idx_iter = idx >> node_size_log2; + while (idx_iter) { + idx_iter >>= node_size_log2; + root_level++; + } + uintptr_t new_root = _util_sparse_array_node_alloc(arr, root_level); + root = _util_sparse_array_set_or_free_node(&arr->root, NULL_NODE, new_root); + } + + while (1) { + unsigned root_level = _util_sparse_array_node_level(root); + uint64_t root_idx = idx >> (root_level * node_size_log2); + if (likely(root_idx < (1ull << node_size_log2))) { + break; + } + + /* In this case, we have a root but its level is low enough that the + * requested index is out-of-bounds. + */ + uintptr_t new_root = _util_sparse_array_node_alloc(arr, root_level + 1); + + uintptr_t * new_root_children = (uintptr_t *) _util_sparse_array_node_data(new_root); + new_root_children[0] = root; + + /* We only add one at a time instead of the whole tree because it's + * easier to ensure correctness of both the tree building and the + * clean-up path. Because we're only adding one node we never have to + * worry about trying to free multiple things without freeing the old + * things. + */ + root = _util_sparse_array_set_or_free_node(&arr->root, root, new_root); + } + + void * node_data = _util_sparse_array_node_data(root); + unsigned node_level = _util_sparse_array_node_level(root); + while (node_level > 0) { + uint64_t child_idx = (idx >> (node_level * node_size_log2)) & ((1ull << node_size_log2) - 1); + + uintptr_t * children = (uintptr_t *) node_data; + uintptr_t child = p_atomic_read(&children[child_idx]); + + if (unlikely(!child)) { + child = _util_sparse_array_node_alloc(arr, node_level - 1); + child = _util_sparse_array_set_or_free_node(&children[child_idx], NULL_NODE, child); + } + + node_data = _util_sparse_array_node_data(child); + node_level = _util_sparse_array_node_level(child); + } + + uint64_t elem_idx = idx & ((1ull << node_size_log2) - 1); + return (void *) ((char *) node_data + (elem_idx * arr->elem_size)); +} diff --git a/ggml/src/ggml-virtgpu/virtgpu-utils.h b/ggml/src/ggml-virtgpu/virtgpu-utils.h new file mode 100644 index 00000000000..a0036b4e2bc --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu-utils.h @@ -0,0 +1,86 @@ +#pragma once + +#include <atomic> +#include <cassert> +#include <cerrno> +#include <cstdarg> +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <cstdlib> +#include <ctime> + +#define unlikely(x) __builtin_expect(!!(x), 0) +#define likely(x) __builtin_expect(!!(x), 1) + +#ifndef UNUSED +# define UNUSED(x) (void) (x) +#endif + +/** Checks is a value is a power of two. Does not handle zero. */ +#define IS_POT(v) (((v) & ((v) - 1)) == 0) + +/** Checks is a value is a power of two. Zero handled. */ +#define IS_POT_NONZERO(v) ((v) != 0 && IS_POT(v)) + +/** Align a value to a power of two */ +#define ALIGN_POT(x, pot_align) (((x) + (pot_align) - 1) & ~((pot_align) - 1)) + +#define p_atomic_read(_v) __atomic_load_n((_v), __ATOMIC_ACQUIRE) + +static inline bool util_is_power_of_two_nonzero64(uint64_t v) { + return IS_POT_NONZERO(v); +} + +static inline uint64_t align64(uint64_t value, uint64_t alignment) { + assert(util_is_power_of_two_nonzero64(alignment)); + return ALIGN_POT(value, alignment); +} + +struct list_head { + list_head * prev; + list_head * next; +}; + +struct util_sparse_array { + size_t elem_size; + unsigned node_size_log2; + + uintptr_t root; +}; + +void * util_sparse_array_get(util_sparse_array * arr, uint64_t idx); +void util_sparse_array_init(util_sparse_array * arr, size_t elem_size, size_t node_size); + +inline void os_time_sleep(int64_t usecs) { + timespec time; + time.tv_sec = usecs / 1000000; + time.tv_nsec = (usecs % 1000000) * 1000; + while (clock_nanosleep(CLOCK_MONOTONIC, 0, &time, &time) == EINTR) + ; +} + +struct timer_data { + long long start; + long long total; + long long count; +}; + +static inline void start_timer(timer_data * timer) { + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + timer->start = (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec; +} + +// returns the duration in ns +static inline long long stop_timer(timer_data * timer) { + timespec ts; + clock_gettime(CLOCK_MONOTONIC, &ts); + long long timer_end = (long long) ts.tv_sec * 1000000000LL + ts.tv_nsec; + + long long duration = (timer_end - timer->start); + timer->total += duration; + timer->count += 1; + + return duration; +} diff --git a/ggml/src/ggml-virtgpu/virtgpu.cpp b/ggml/src/ggml-virtgpu/virtgpu.cpp new file mode 100644 index 00000000000..e3ae1cc75e0 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu.cpp @@ -0,0 +1,545 @@ +#include "virtgpu.h" +#include "ggml-remoting.h" + +#include <stdio.h> +#include <unistd.h> + +#include <cassert> +#include <cerrno> +#include <cstdlib> + +static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr dev); +static virt_gpu_result_t virtgpu_open(virtgpu * gpu); + +static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu); +static virt_gpu_result_t virtgpu_init_context(virtgpu * gpu); + +static int virtgpu_ioctl_context_init(virtgpu * gpu, virgl_renderer_capset capset_id); +static int virtgpu_ioctl_get_caps(virtgpu * gpu, + virgl_renderer_capset id, + uint32_t version, + void * capset, + size_t capset_size); +static uint64_t virtgpu_ioctl_getparam(virtgpu * gpu, uint64_t param); +static void virtgpu_init_renderer_info(virtgpu * gpu); + +static void log_call_duration(long long call_duration_ns, const char * name); + +const uint64_t APIR_HANDSHAKE_MAX_WAIT_MS = 2 * 1000; // 2s +const uint64_t APIR_LOADLIBRARY_MAX_WAIT_MS = 60 * 1000; // 60s + +static int virtgpu_handshake(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + + encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_HANDSHAKE, 0); + if (!encoder) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); + return 1; + } + + /* write handshake props */ + + uint32_t guest_major = APIR_PROTOCOL_MAJOR; + uint32_t guest_minor = APIR_PROTOCOL_MINOR; + apir_encode_uint32_t(encoder, &guest_major); + apir_encode_uint32_t(encoder, &guest_minor); + + /* *** */ + + uint32_t ret_magic; + long long call_duration_ns; + ret_magic = remote_call(gpu, encoder, &decoder, APIR_HANDSHAKE_MAX_WAIT_MS, &call_duration_ns); + log_call_duration(call_duration_ns, "API Remoting handshake"); + + if (!decoder) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initiate the communication with the virglrenderer library. " + "Most likely, the wrong virglrenderer library was loaded in the hypervisor.", + __func__); + return 1; + } + + /* read handshake return values */ + + uint32_t host_major; + uint32_t host_minor; + + if (ret_magic != APIR_HANDSHAKE_MAGIC) { + GGML_ABORT(GGML_VIRTGPU "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic, + apir_backend_initialize_error(ret_magic)); + } else { + apir_decode_uint32_t(decoder, &host_major); + apir_decode_uint32_t(decoder, &host_minor); + } + + remote_call_finish(gpu, encoder, decoder); + + if (ret_magic != APIR_HANDSHAKE_MAGIC) { + return 1; + } + + GGML_LOG_INFO(GGML_VIRTGPU "%s: Guest is running with %u.%u\n", __func__, guest_major, guest_minor); + GGML_LOG_INFO(GGML_VIRTGPU "%s: Host is running with %u.%u\n", __func__, host_major, host_minor); + + if (guest_major != host_major) { + GGML_LOG_ERROR(GGML_VIRTGPU "Host major (%d) and guest major (%d) version differ\n", host_major, guest_major); + } else if (guest_minor != host_minor) { + GGML_LOG_WARN(GGML_VIRTGPU "Host minor (%d) and guest minor (%d) version differ\n", host_minor, guest_minor); + } + + return 0; +} + +static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { + apir_encoder * encoder; + apir_decoder * decoder; + ApirLoadLibraryReturnCode ret; + + encoder = remote_call_prepare(gpu, APIR_COMMAND_TYPE_LOADLIBRARY, 0); + if (!encoder) { + GGML_ABORT(GGML_VIRTGPU "%s: hypercall error: failed to prepare the API Remoting command encoder", __func__); + return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR; + } + + long long call_duration_ns; + + ret = (ApirLoadLibraryReturnCode) remote_call(gpu, encoder, &decoder, APIR_LOADLIBRARY_MAX_WAIT_MS, + &call_duration_ns); + log_call_duration(call_duration_ns, "API Remoting LoadLibrary"); + + if (!decoder) { + GGML_ABORT(GGML_VIRTGPU "%s: hypercall error: failed to trigger the API Remoting hypercall.\n", __func__); + return APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR; + } + + remote_call_finish(gpu, encoder, decoder); + + if (ret == APIR_LOAD_LIBRARY_SUCCESS) { + GGML_LOG_INFO(GGML_VIRTGPU "The API Remoting backend was successfully loaded and initialized\n"); + + return ret; + } + + // something wrong happened, find out what. + if (ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { + if (ret == APIR_LOAD_LIBRARY_ENV_VAR_MISSING) { + GGML_ABORT(GGML_VIRTGPU + "%s: virglrenderer could not open the API Remoting backend library, " + "some environment variables are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(ret)); + } else if (ret == APIR_LOAD_LIBRARY_CANNOT_OPEN) { + GGML_ABORT(GGML_VIRTGPU + "%s: virglrenderer could not open the API Remoting backend library. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(ret)); + } else if (ret == APIR_LOAD_LIBRARY_ENV_VAR_MISSING) { + GGML_ABORT(GGML_VIRTGPU + "%s: could not load the backend library, some symbols are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s) ", + __func__, apir_load_library_error(ret)); + } else { + GGML_ABORT(GGML_VIRTGPU "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)", + __func__, apir_load_library_error(ret), ret); + } + return ret; + } + + GGML_LOG_INFO(GGML_VIRTGPU "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__); + + ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX); + + if (apir_ret == APIR_LOAD_LIBRARY_CANNOT_OPEN) { + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(apir_ret)); + } else if (apir_ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) { + GGML_ABORT( + GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(apir_ret)); + } else if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library: apir code=%d | %s)", + __func__, apir_ret, apir_load_library_error(apir_ret)); + } else { + uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX; + GGML_ABORT(GGML_VIRTGPU + "%s: the API Remoting backend library failed to initialize its backend library: apir code=%d)", + __func__, lib_ret); + } + return ret; +} + +virtgpu * create_virtgpu() { + virtgpu * gpu = new virtgpu(); + + gpu->use_apir_capset = getenv("GGML_REMOTING_USE_APIR_CAPSET") != nullptr; + util_sparse_array_init(&gpu->shmem_array, sizeof(virtgpu_shmem), 1024); + + // Initialize mutex to protect shared data_shmem buffer + if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) { + delete gpu; + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize data_shmem mutex", __func__); + return NULL; + } + + if (virtgpu_open(gpu) != APIR_SUCCESS) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to open the virtgpu device\n", __func__); + return NULL; + } + + if (virtgpu_init_capset(gpu) != APIR_SUCCESS) { + if (gpu->use_apir_capset) { + GGML_ABORT(GGML_VIRTGPU + "%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library " + "supports it.", + __func__); + } else { + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu Venus capset", __func__); + } + return NULL; + } + + if (virtgpu_init_context(gpu) != APIR_SUCCESS) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the GPU context", __func__); + return NULL; + } + + if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared reply memory pages", __func__); + return NULL; + } + + if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared data memory pages", __func__); + return NULL; + } + + if (virtgpu_handshake(gpu)) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to handshake with the virglrenderer library", __func__); + return NULL; + } + + if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to load the backend library", __func__); + return NULL; + } + + return gpu; +} + +static virt_gpu_result_t virtgpu_open(virtgpu * gpu) { + drmDevicePtr devs[8]; + int count = drmGetDevices2(0, devs, ARRAY_SIZE(devs)); + if (count < 0) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to enumerate DRM devices\n", __func__); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + virt_gpu_result_t result = APIR_ERROR_INITIALIZATION_FAILED; + for (int i = 0; i < count; i++) { + result = virtgpu_open_device(gpu, devs[i]); + if (result == APIR_SUCCESS) { + break; + } + } + + drmFreeDevices(devs, count); + + return result; +} + +static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr dev) { + const char * node_path = dev->nodes[DRM_NODE_RENDER]; + + int fd = open(node_path, O_RDWR | O_CLOEXEC); + if (fd < 0) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to open %s", __func__, node_path); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + drmVersionPtr version = drmGetVersion(fd); + if (!version || strcmp(version->name, "virtio_gpu") || version->version_major != 0) { + if (version) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: unknown DRM driver %s version %d\n", __func__, version->name, + version->version_major); + } else { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get DRM driver version\n", __func__); + } + + if (version) { + drmFreeVersion(version); + } + close(fd); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + gpu->fd = fd; + + drmFreeVersion(version); + + GGML_LOG_INFO(GGML_VIRTGPU "using DRM device %s\n", node_path); + + return APIR_SUCCESS; +} + +static virt_gpu_result_t virtgpu_init_context(virtgpu * gpu) { + assert(!gpu->capset.version); + const int ret = virtgpu_ioctl_context_init(gpu, gpu->capset.id); + if (ret) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to initialize context: %s\n", __func__, strerror(errno)); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + return APIR_SUCCESS; +} + +static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) { + if (gpu->use_apir_capset) { + GGML_LOG_INFO(GGML_VIRTGPU "Using the APIR capset\n"); + gpu->capset.id = VIRTGPU_DRM_CAPSET_APIR; + } else { + GGML_LOG_INFO(GGML_VIRTGPU "Using the Venus capset\n"); + gpu->capset.id = VIRTGPU_DRM_CAPSET_VENUS; + } + gpu->capset.version = 0; + + int ret = + virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data)); + + if (ret) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get APIR v%d capset: %s\n", __func__, gpu->capset.version, + strerror(errno)); + return APIR_ERROR_INITIALIZATION_FAILED; + } + + assert(gpu->capset.data.supports_blob_resources); + + return APIR_SUCCESS; +} + +static int virtgpu_ioctl_context_init(virtgpu * gpu, virgl_renderer_capset capset_id) { + drm_virtgpu_context_set_param ctx_set_params[3] = { + { + .param = VIRTGPU_CONTEXT_PARAM_CAPSET_ID, + .value = capset_id, + }, + { + .param = VIRTGPU_CONTEXT_PARAM_NUM_RINGS, + .value = 1, + }, + { + .param = VIRTGPU_CONTEXT_PARAM_POLL_RINGS_MASK, + .value = 0, /* don't generate drm_events on fence signaling */ + }, + }; + + drm_virtgpu_context_init args = { + .num_params = ARRAY_SIZE(ctx_set_params), + .pad = 0, + .ctx_set_params = (uintptr_t) &ctx_set_params, + }; + + return virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_CONTEXT_INIT, &args); +} + +static int virtgpu_ioctl_get_caps(virtgpu * gpu, + virgl_renderer_capset id, + uint32_t version, + void * capset, + size_t capset_size) { + drm_virtgpu_get_caps args = { + .cap_set_id = id, + .cap_set_ver = version, + .addr = (uintptr_t) capset, + .size = (__u32) capset_size, + .pad = 0, + }; + + return virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_GET_CAPS, &args); +} + +static uint64_t virtgpu_ioctl_getparam(virtgpu * gpu, uint64_t param) { + /* val must be zeroed because kernel only writes the lower 32 bits */ + uint64_t val = 0; + drm_virtgpu_getparam args = { + .param = param, + .value = (uintptr_t) &val, + }; + + const int ret = virtgpu_ioctl(gpu, DRM_IOCTL_VIRTGPU_GETPARAM, &args); + return ret ? 0 : val; +} + +apir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, int32_t cmd_flags) { + /* + * Prepare the command encoder and its buffer + */ + + thread_local char encoder_buffer[4096]; + + thread_local apir_encoder enc; + enc = { + .cur = encoder_buffer, + .start = encoder_buffer, + .end = encoder_buffer + sizeof(encoder_buffer), + .fatal = false, + }; + + /* + * Fill the command encoder with the common args: + * - cmd_type (int32_t) + * - cmd_flags (int32_t) + * - reply res id (uint32_t) + */ + + int32_t cmd_type = apir_cmd_type; + + // for testing during the hypervisor transition + if (!gpu->use_apir_capset) { + cmd_type += VENUS_COMMAND_TYPE_LENGTH; + } + apir_encode_int32_t(&enc, &cmd_type); + apir_encode_int32_t(&enc, &cmd_flags); + + uint32_t reply_res_id = gpu->reply_shmem.res_id; + apir_encode_uint32_t(&enc, &reply_res_id); + + return &enc; +} + +void remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec) { + UNUSED(gpu); + + if (!enc) { + GGML_ABORT(GGML_VIRTGPU "%s: Invalid (null) encoder", __func__); + } + + if (!dec) { + GGML_ABORT(GGML_VIRTGPU "%s: Invalid (null) decoder", __func__); + } + + if (apir_encoder_get_fatal(enc)) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Failed to encode the output parameters.", __func__); + } + + if (apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: Failed to decode the input parameters.", __func__); + } +} + +uint32_t remote_call(virtgpu * gpu, + apir_encoder * encoder, + apir_decoder ** decoder, + float max_wait_ms, + long long * call_duration_ns) { + /* + * Prepare the reply notification pointer + */ + + volatile std::atomic_uint * atomic_reply_notif = (volatile std::atomic_uint *) gpu->reply_shmem.mmap_ptr; + *atomic_reply_notif = 0; + + /* + * Trigger the execbuf ioctl + */ + + drm_virtgpu_execbuffer args = { + .flags = VIRTGPU_EXECBUF_RING_IDX, + .size = (uint32_t) (encoder->cur - encoder->start), + .command = (uintptr_t) encoder->start, + + .bo_handles = 0, + .num_bo_handles = 0, + + .fence_fd = 0, + .ring_idx = 0, + .syncobj_stride = 0, + .num_in_syncobjs = 0, + .num_out_syncobjs = 0, + .in_syncobjs = 0, + .out_syncobjs = 0, + }; + + *decoder = NULL; + + int ret = drmIoctl(gpu->fd, DRM_IOCTL_VIRTGPU_EXECBUFFER, &args); + + if (ret != 0) { + GGML_ABORT(GGML_VIRTGPU "%s: the virtgpu EXECBUFFER ioctl failed (%d)", __func__, ret); + } + + /* + * Wait for the response notification + */ + timer_data wait_host_reply_timer = { 0, 0, 0 }; + + start_timer(&wait_host_reply_timer); + + timespec ts_start, ts_end; + clock_gettime(CLOCK_MONOTONIC, &ts_start); + long long start_time = (long long) ts_start.tv_sec * 1000000000LL + ts_start.tv_nsec; + + bool timedout = false; + uint32_t notif_value = 0; + while (true) { + notif_value = std::atomic_load_explicit(atomic_reply_notif, std::memory_order_acquire); + + if (notif_value != 0) { + break; + } + + int64_t base_sleep_us = 15; + + os_time_sleep(base_sleep_us); + + if (max_wait_ms) { + clock_gettime(CLOCK_MONOTONIC, &ts_end); + long long end_time = (long long) ts_end.tv_sec * 1000000000LL + ts_end.tv_nsec; + float duration_ms = (end_time - start_time) / 1000000; + + if (duration_ms > max_wait_ms) { + timedout = true; + break; + } + } + } + + if (call_duration_ns) { + *call_duration_ns = stop_timer(&wait_host_reply_timer); + } + + if (max_wait_ms && timedout) { + GGML_LOG_ERROR(GGML_VIRTGPU "%s: timed out waiting for the host answer...\n", __func__); + return APIR_FORWARD_TIMEOUT; + } + + /* + * Prepare the decoder + */ + static apir_decoder response_dec; + response_dec.cur = (char *) gpu->reply_shmem.mmap_ptr + sizeof(*atomic_reply_notif); + response_dec.end = (char *) gpu->reply_shmem.mmap_ptr + gpu->reply_shmem.mmap_size; + *decoder = &response_dec; + + // extract the actual return value from the notif flag + uint32_t returned_value = notif_value - 1; + return returned_value; +} + +static void log_call_duration(long long call_duration_ns, const char * name) { + double call_duration_ms = (double) call_duration_ns / 1e6; // 1 millisecond = 1e6 nanoseconds + double call_duration_s = (double) call_duration_ns / 1e9; // 1 second = 1e9 nanoseconds + + if (call_duration_s > 1) { + GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fs for the %s host reply...\n", call_duration_s, name); + } else if (call_duration_ms > 1) { + GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fms for the %s host reply...\n", call_duration_ms, name); + } else { + GGML_LOG_INFO(GGML_VIRTGPU "waited %lldns for the %s host reply...\n", call_duration_ns, name); + } +} diff --git a/ggml/src/ggml-virtgpu/virtgpu.h b/ggml/src/ggml-virtgpu/virtgpu.h new file mode 100644 index 00000000000..6b8de583893 --- /dev/null +++ b/ggml/src/ggml-virtgpu/virtgpu.h @@ -0,0 +1,115 @@ +#pragma once + +// clang-format off +#include "virtgpu-utils.h" +#include "virtgpu-shm.h" +#include "virtgpu-apir.h" + +#include "backend/shared/api_remoting.h" +#include "backend/shared/apir_cs.h" + +#include <fcntl.h> +#include <stdbool.h> +#include <stdio.h> +#include <sys/stat.h> +#include <sys/sysmacros.h> +#include <threads.h> +#include <xf86drm.h> + +#include <cstring> + +#define VIRGL_RENDERER_UNSTABLE_APIS 1 +#include "apir_hw.h" +#include <drm/virtgpu_drm.h> +#include "venus_hw.h" +// clang-format on + +#ifndef VIRTGPU_DRM_CAPSET_APIR +// Will be defined include/drm/virtgpu_drm.h when +// https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590/diffs +// is merged +# define VIRTGPU_DRM_CAPSET_APIR 10 +#endif + +// Mesa/Virlgrenderer Venus internal. Only necessary during the +// Venus->APIR transition in Virglrenderer +#define VENUS_COMMAND_TYPE_LENGTH 331 + +#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16 +# define VIRTGPU_DRM_CAPSET_VENUS 4 +#endif + +typedef uint32_t virgl_renderer_capset; + +/* from src/virtio/vulkan/vn_renderer_virtgpu.c */ +#define VIRTGPU_PCI_VENDOR_ID 0x1af4 +#define VIRTGPU_PCI_DEVICE_ID 0x1050 +#define VIRTGPU_BLOB_MEM_GUEST_VRAM 0x0004 +#define VIRTGPU_PARAM_GUEST_VRAM 9 + +#define SHMEM_DATA_SIZE 0x1830000 // 24MiB +#define SHMEM_REPLY_SIZE 0x4000 + +#define ARRAY_SIZE(x) (sizeof(x) / sizeof((x)[0])) + +enum virt_gpu_result_t { + APIR_SUCCESS = 0, + APIR_ERROR_INITIALIZATION_FAILED = -1, +}; + +#define PRINTFLIKE(f, a) __attribute__((format(__printf__, f, a))) + +struct virtgpu { + bool use_apir_capset; + + int fd; + + struct { + virgl_renderer_capset id; + uint32_t version; + virgl_renderer_capset_apir data; + } capset; + + util_sparse_array shmem_array; + + /* APIR communication pages */ + virtgpu_shmem reply_shmem; + virtgpu_shmem data_shmem; + + /* Mutex to protect shared data_shmem buffer from concurrent access */ + mtx_t data_shmem_mutex; + + /* Cached device information to prevent memory leaks and race conditions */ + struct { + char * description; + char * name; + int32_t device_count; + uint32_t type; + size_t memory_free; + size_t memory_total; + } cached_device_info; + + /* Cached buffer type information to prevent memory leaks and race conditions */ + struct { + apir_buffer_type_host_handle_t host_handle; + char * name; + size_t alignment; + size_t max_size; + } cached_buffer_type; +}; + +static inline int virtgpu_ioctl(virtgpu * gpu, unsigned long request, void * args) { + return drmIoctl(gpu->fd, request, args); +} + +virtgpu * create_virtgpu(); + +apir_encoder * remote_call_prepare(virtgpu * gpu, ApirCommandType apir_cmd_type, int32_t cmd_flags); + +uint32_t remote_call(virtgpu * gpu, + apir_encoder * enc, + apir_decoder ** dec, + float max_wait_ms, + long long * call_duration_ns); + +void remote_call_finish(virtgpu * gpu, apir_encoder * enc, apir_decoder * dec); diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index de01336cd3f..2d9e85794ad 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -8,6 +8,11 @@ endif() find_package(Vulkan COMPONENTS glslc REQUIRED) +if (DEFINED ENV{VULKAN_SDK}) + list(APPEND CMAKE_PREFIX_PATH "$ENV{VULKAN_SDK}") +endif() +find_package(SPIRV-Headers CONFIG REQUIRED) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") # Parallel build object files add_definitions(/MP) @@ -74,6 +79,12 @@ if (Vulkan_FOUND) "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" ) + test_shader_extension_support( + "GL_NV_cooperative_matrix_decode_vector" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp" + "GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT" + ) + test_shader_extension_support( "GL_EXT_integer_dot_product" "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp" @@ -90,7 +101,7 @@ if (Vulkan_FOUND) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build - # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector + # Possibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) endif() diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index deed5055d54..1b1150e7731 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -21,12 +21,40 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include <vulkan/vulkan.hpp> +// Fallback definitions for VK_NV_cooperative_matrix_decode_vector in case the +// installed Vulkan headers predate the extension. +#ifndef VK_NV_cooperative_matrix_decode_vector +#define VK_NV_cooperative_matrix_decode_vector 1 +#define VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME "VK_NV_cooperative_matrix_decode_vector" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV ((VkStructureType)1000689000) +typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV { + VkStructureType sType; + void* pNext; + VkBool32 cooperativeMatrixDecodeVector; +} VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV; +#endif + +// SPIR-V Headers: different SDK installations expose different include paths. +// LunarG Vulkan SDK on Windows typically provides <spirv-headers/spirv.hpp>. +// Linux packages, MSYS2 and MinGW often use the Khronos layout <spirv/unified1/spirv.hpp>. +#if __has_include(<spirv/unified1/spirv.hpp>) +# include <spirv/unified1/spirv.hpp> +#elif __has_include(<spirv-headers/spirv.hpp>) +# include <spirv-headers/spirv.hpp> +#elif __has_include(<spirv.hpp>) +# include <spirv.hpp> +#else + // Fallback to let the compiler throw a standard "file not found" error +# include <spirv/unified1/spirv.hpp> +#endif + #include <algorithm> #include <cmath> #include <iomanip> #include <iostream> #include <tuple> #include <vector> +#include <deque> #include <sstream> #include <utility> #include <memory> @@ -34,9 +62,10 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #include <map> #include <set> #include <unordered_map> -#include <memory> +#include <shared_mutex> #include <mutex> #include <future> +#include <condition_variable> #include <thread> #if defined(_MSC_VER) @@ -84,6 +113,21 @@ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { } VkPhysicalDeviceShaderBfloat16FeaturesKHR; #endif +#if !defined(VK_VALVE_shader_mixed_float_dot_product) +#define VK_VALVE_shader_mixed_float_dot_product 1 +#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_SPEC_VERSION 1 +#define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_EXTENSION_NAME "VK_VALVE_shader_mixed_float_dot_product" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE ((VkStructureType)1000673000) +typedef struct VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE { + VkStructureType sType; + void* pNext; + VkBool32 shaderMixedFloatDotProductFloat16AccFloat32; + VkBool32 shaderMixedFloatDotProductFloat16AccFloat16; + VkBool32 shaderMixedFloatDotProductBFloat16Acc; + VkBool32 shaderMixedFloatDotProductFloat8AccFloat32; +} VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE; +#endif + #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } @@ -92,11 +136,10 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_APPLE 0x106b #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de +#define VK_VENDOR_ID_QUALCOMM 0x5143 #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 -#define GGML_VK_MAX_NODES 8192 - #define VK_CHECK(err, msg) \ do { \ vk::Result err_ = (err); \ @@ -132,8 +175,9 @@ struct vk_pipeline_struct { uint32_t align; // true if fields have been set by ggml_vk_create_pipeline bool initialized {}; - // set to true to request the pipeline is compiled - std::atomic<bool> needed {}; + // true while a compile is in flight, used to dedupe concurrent claims. + // Protected by device->compile_mutex. + bool compile_pending {}; // set to true when the shader has been compiled std::atomic<bool> compiled {}; // number of registers used, extracted from pipeline executable properties @@ -187,6 +231,12 @@ struct ggml_backend_vk_buffer_type_context { struct vk_queue; +struct vk_command_buffer { + vk::CommandBuffer buf; + uint64_t use_counter = 0; + bool in_use = false; +}; + // Stores command pool/buffers. There's an instance of this // for each (context,queue) pair and for each (device,queue) pair. struct vk_command_pool { @@ -194,10 +244,16 @@ struct vk_command_pool { void destroy(vk::Device& device); vk::CommandPool pool; - uint32_t cmd_buffer_idx; - std::vector<vk::CommandBuffer> cmd_buffers; + // Using deque so the pointers to command buffers + // remain valid even if we add more + std::deque<vk_command_buffer> cmd_buffers; vk_queue *q; + + size_t buffers_in_use() const { + return std::count_if(cmd_buffers.begin(), cmd_buffers.end(), + [](const auto& cb) { return cb.in_use; }); + } }; // Prevent simultaneous submissions to the same queue. @@ -254,6 +310,7 @@ enum vk_device_architecture { AMD_RDNA3, INTEL_XE2, NVIDIA_PRE_TURING, + NVIDIA_TURING, }; static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { @@ -336,18 +393,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties(); bool cooperative_matrix = false; + bool sm_builtins = false; // Detect "pre-turing" based on lack of coopmat support. for (const auto& properties : ext_props) { if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) { cooperative_matrix = true; - break; + } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { + sm_builtins = true; } } if (!cooperative_matrix) { return vk_device_architecture::NVIDIA_PRE_TURING; } + + if (sm_builtins) { + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; + + props2.pNext = &sm_props; + + device.getProperties2(&props2); + + // Turing has 32, following architectures have 48 + if (sm_props.shaderWarpsPerSM == 32) { + return vk_device_architecture::NVIDIA_TURING; + } + } } return vk_device_architecture::OTHER; } @@ -356,6 +429,7 @@ enum vk_conv_shapes { CONV_SHAPE_128x128, CONV_SHAPE_64x32, CONV_SHAPE_32x256, + CONV_SHAPE_64x128, CONV_SHAPE_COUNT, }; @@ -370,6 +444,7 @@ vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = { { 128, 128, 16 }, // CONV_SHAPE_128x128 { 64, 32, 32 }, // CONV_SHAPE_64x32 { 32, 256, 16 }, // CONV_SHAPE_32x256 + { 64, 128, 16 }, // CONV_SHAPE_64x128 }; enum dmmv_wg_sizes { @@ -385,30 +460,36 @@ enum FaCodePath { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc) - : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {} - uint32_t HSK, HSV; - bool small_rows, small_cache; + uint32_t Br, Bc; + uint32_t D_split, row_split; + bool shmem_staging; FaCodePath path; + uint32_t workgroup_size, subgroup_size; bool aligned; bool f32acc; + uint32_t flags; + uint32_t limit_occupancy_shmem; + ggml_type k_type; + ggml_type v_type; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) < - std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc); + return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) < + std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type); } }; struct vk_conv2d_pipeline_state { - vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH) - : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {} + vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH, uint32_t aligned) + : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH), aligned(aligned) {} uint32_t s0, s1, p0, p1, d0, d1, KW, KH; + // when set, shader can skip K/CRS/NPQ bounds checks and address clamps + uint32_t aligned; bool operator<(const vk_conv2d_pipeline_state &b) const { - return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) < - std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH); + return std::tie(s0, s1, p0, p1, d0, d1, KW, KH, aligned) < + std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH, b.aligned); } }; @@ -453,6 +534,12 @@ static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGM GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; +// Snake activation: y = x + sin(a*x)^2 * inv_b. Used by the optimize_graph reorder +// pass so it keeps the chain contiguous and by the dispatcher to detect the fusion. +static constexpr std::initializer_list<ggml_op> snake_pattern { GGML_OP_MUL, GGML_OP_SIN, + GGML_OP_SQR, GGML_OP_MUL, + GGML_OP_ADD }; + //node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ] //node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] //node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ] @@ -549,6 +636,14 @@ static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_vie struct vk_device_struct { std::recursive_mutex mutex; + mutable std::shared_mutex pinned_memory_mutex; + + // Guards compile_pending, all_pipelines, and the dynamic pipeline maps + // (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile + // runs with no lock held, so different pipelines can compile in parallel. + // Lock order is device->mutex -> compile_mutex, never the reverse. + std::mutex compile_mutex; + std::condition_variable compile_cv; vk::PhysicalDevice physical_device; vk::PhysicalDeviceProperties properties; @@ -570,6 +665,7 @@ struct vk_device_struct { vk_queue transfer_queue; bool single_queue; bool support_async; + bool async_use_transfer_queue; uint32_t subgroup_size; uint32_t subgroup_size_log2; uint32_t shader_core_count; @@ -621,6 +717,10 @@ struct vk_device_struct { uint32_t coopmat_int_k; bool coopmat2; + bool coopmat2_bf16_support {}; + bool coopmat2_decode_vector; + + bool dot2_f16 {}; bool pipeline_executable_properties_support {}; @@ -633,6 +733,15 @@ struct vk_device_struct { bool mul_mat_id_m[GGML_TYPE_COUNT]; bool mul_mat_id_s[GGML_TYPE_COUNT]; + // Separate flags for the q8_1 (integer dot) mmq path, whose shader uses + // a different shared-memory layout than the float matmul shaders. + bool mul_mat_l_int[GGML_TYPE_COUNT]; + bool mul_mat_m_int[GGML_TYPE_COUNT]; + bool mul_mat_s_int[GGML_TYPE_COUNT]; + bool mul_mat_id_l_int[GGML_TYPE_COUNT]; + bool mul_mat_id_m_int[GGML_TYPE_COUNT]; + bool mul_mat_id_s_int[GGML_TYPE_COUNT]; + vk::DescriptorSetLayout dsl; vk_matmul_pipeline pipeline_matmul_f32 {}; @@ -669,6 +778,7 @@ struct vk_device_struct { vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; + vk_pipeline pipeline_set_f32; // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] vk_pipeline pipeline_add[2][2][2]; @@ -701,9 +811,10 @@ struct vk_device_struct { vk_pipeline pipeline_clamp_f32; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; - vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; - vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; - vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; + vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32; + vk_pipeline pipeline_repeat_i16; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32; @@ -722,6 +833,7 @@ struct vk_device_struct { // [src/dst 0=fp32,1=fp16] vk_pipeline pipeline_exp[2]; + vk_pipeline pipeline_elu[2]; vk_pipeline pipeline_gelu[2]; vk_pipeline pipeline_gelu_erf[2]; vk_pipeline pipeline_gelu_quick[2]; @@ -740,6 +852,7 @@ struct vk_device_struct { vk_pipeline pipeline_ceil[2]; vk_pipeline pipeline_floor[2]; vk_pipeline pipeline_trunc[2]; + vk_pipeline pipeline_sgn[2]; vk_pipeline pipeline_add1_f16_f16; vk_pipeline pipeline_add1_f16_f32; @@ -748,6 +861,7 @@ struct vk_device_struct { vk_pipeline pipeline_arange_f32; vk_pipeline pipeline_fill_f32; + vk_pipeline pipeline_fill_f16; vk_pipeline pipeline_geglu[2]; vk_pipeline pipeline_reglu[2]; @@ -775,6 +889,7 @@ struct vk_device_struct { vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines]; vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_fwht_f32[4]; vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_cumsum_small_f32; vk_pipeline pipeline_cumsum_multipass1_f32; @@ -786,12 +901,19 @@ struct vk_device_struct { vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_conv_transpose_1d_f32; + vk_pipeline pipeline_snake_f32; + vk_pipeline pipeline_snake_f16; + vk_pipeline pipeline_snake_bf16; vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; + // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128 + vk_pipeline pipeline_gated_delta_net[3][2]; vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; + vk_pipeline pipeline_ssm_conv_silu_f32; + vk_pipeline pipeline_ssm_conv_bias_silu_f32; vk_pipeline pipeline_opt_step_adamw_f32; vk_pipeline pipeline_opt_step_sgd_f32; std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT]; @@ -801,7 +923,9 @@ struct vk_device_struct { vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; - std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; + std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16; + + std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt; vk_pipeline pipeline_flash_attn_split_k_reduce; vk_pipeline pipeline_count_experts; @@ -852,10 +976,12 @@ struct vk_device_struct { }; void vk_command_pool::init(vk_device& device, vk_queue *q_) { - cmd_buffer_idx = 0; + cmd_buffers.clear(); q = q_; - vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index); + vk::CommandPoolCreateInfo command_pool_create_info( + vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT | VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT), + q->queue_family_index); pool = device->device.createCommandPool(command_pool_create_info); } @@ -896,22 +1022,28 @@ struct vk_subbuffer { } }; -// vk_event is used for the event-related backend interfaces. It uses 'event' for -// event_wait and 'fence' for event_synchronize. Polling on an event for +struct vk_semaphore { + vk::Semaphore s; + uint64_t value; +}; + +// vk_event is used for the event-related backend interfaces. It uses vk::Events for +// event_wait and a timeline semaphore for event_synchronize. Polling on an event for // event_synchronize wouldn't be sufficient to wait for command buffers to complete, // and would lead to validation errors. struct vk_event { + std::vector<vk::Event> events_free; // Events available for reuse + std::vector<vk::Event> events_submitted; // Events that are fully submitted and can be reused on next synchronize vk::Event event; - vk::Fence fence; -}; + bool has_event; -struct vk_semaphore { - vk::Semaphore s; - uint64_t value; + vk_semaphore tl_semaphore; + vk_command_buffer* cmd_buffer = nullptr; + uint64_t cmd_buffer_use_counter = 0; }; struct vk_submission { - vk::CommandBuffer buffer; + vk_command_buffer* buffer = nullptr; std::vector<vk_semaphore> wait_semaphores; std::vector<vk_semaphore> signal_semaphores; }; @@ -922,6 +1054,7 @@ struct vk_mat_mat_push_constants { uint32_t M; uint32_t N; uint32_t K; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t base_work_group_z; uint32_t num_batches; uint32_t k_split; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; uint32_t padded_N; @@ -941,6 +1074,7 @@ struct vk_mat_vec_push_constants { uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t fusion_flags; + uint32_t base_work_group_y; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; @@ -991,6 +1125,8 @@ struct vk_mat_vec_id_push_constants { uint32_t fusion_flags; uint32_t nei0; uint32_t ne11; + uint32_t expert_i1; + uint32_t nbi1; }; struct vk_flash_attn_push_constants { @@ -1044,6 +1180,13 @@ struct vk_op_push_constants { float param4; }; +struct vk_op_fwht_push_constants { + uint32_t n_rows; + uint32_t src_offset; + uint32_t dst_offset; + float scale; +}; + struct vk_op_count_experts_push_constants { uint32_t ne00; uint32_t ne01; @@ -1059,6 +1202,16 @@ struct vk_op_glu_push_constants { uint32_t mode; // 0: default, 1: swapped, 2: split float alpha; // for swiglu_oai float limit; + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t ne01; + uint32_t ne02; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t ne11; + uint32_t ne12; }; struct vk_op_unary_push_constants { @@ -1244,25 +1397,32 @@ struct vk_op_diag_mask_push_constants { struct vk_op_rope_push_constants { uint32_t rope_mode; - uint32_t ncols; uint32_t nrows; uint32_t n_dims; float freq_scale; - uint32_t p_delta_rows; float freq_base; float ext_factor; float attn_factor; float corr_dims[2]; float theta_scale; uint32_t has_ff; - uint32_t ne02; - uint32_t s1; - uint32_t s2; int32_t sections[4]; uint32_t is_imrope; uint32_t is_back; uint32_t set_rows_stride; + uint32_t ne00; + uint32_t ne01; + uint32_t ne02; + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t a_offset; + uint32_t d_offset; }; +static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128"); // For fused rms_norm+mul+rope(+view+set_rows) struct vk_op_rms_norm_mul_rope_push_constants { @@ -1319,7 +1479,7 @@ struct vk_op_im2col_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; uint32_t KW; uint32_t KH; - uint32_t pelements; + uint32_t OH_batch; uint32_t CHW; int32_t s0; int32_t s1; int32_t p0; int32_t p1; @@ -1380,6 +1540,11 @@ struct vk_op_conv_transpose_1d_push_constants { int32_t s0; }; +struct vk_op_snake_push_constants { + uint32_t ne0; + uint32_t ne1; +}; + struct vk_op_pool2d_push_constants { uint32_t IW; uint32_t IH; uint32_t OW; uint32_t OH; @@ -1404,6 +1569,19 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t C; uint32_t H; }; +struct vk_op_gated_delta_net_push_constants { + uint32_t H; + uint32_t n_tokens; + uint32_t n_seqs; + uint32_t s_off; + uint32_t sq1, sq2, sq3; + uint32_t sv1, sv2, sv3; + uint32_t sb1, sb2, sb3; + uint32_t neq1, rq3; + float scale; + uint32_t K; +}; + struct vk_op_ssm_scan_push_constants { uint32_t nb02, nb03, nb12, nb13; uint32_t nb21, nb22, nb31; @@ -1516,6 +1694,27 @@ struct vk_quantize_q8_1_push_constants { uint32_t num_blocks; }; +struct vk_op_flash_attn_split_k_reduce_push_constants { + uint32_t D; + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + uint32_t k_num; + uint32_t sinks; +}; + +struct vk_op_flash_attn_mask_opt_push_constants { + uint32_t nem0; + uint32_t nem1; + uint32_t nem2; + uint32_t nbm1; + uint32_t nbm2; + uint32_t nbm3; + uint32_t nbd1; + uint32_t nbd2; + uint32_t nbd3; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -1556,7 +1755,7 @@ struct ggml_vk_garbage_collector { }; static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx); -static void ggml_vk_load_shaders(vk_device& device); +static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr); static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx); static bool vk_memory_logger_enabled = false; @@ -1604,6 +1803,7 @@ static bool vk_perf_logger_concurrent = false; static bool vk_enable_sync_logger = false; // number of calls between perf logger prints static uint32_t vk_perf_logger_frequency = 1; +static std::string vk_pipeline_stats_filter; class vk_perf_logger { public: @@ -1724,6 +1924,7 @@ class vk_perf_logger { " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " << " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " << " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")"; + *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3]; return name.str(); } if (node->op == GGML_OP_TOP_K) { @@ -1792,6 +1993,9 @@ struct ggml_backend_vk_context { // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. vk_pipeline_struct * prealloc_y_last_pipeline_used {}; const ggml_tensor * prealloc_y_last_tensor_used {}; + // True when prealloc_y holds the padded fp16 layout used by the coopmat2 B decode-vector callback. + // If false, then it's contiguous. + bool prealloc_y_last_decode_vector_staging {}; // Track which nodes have been used since the last sync, and whether they were written to std::vector<const ggml_tensor *> unsynced_nodes_written; @@ -1802,7 +2006,10 @@ struct ggml_backend_vk_context { bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; vk_context_ref compute_ctx; + vk_context_ref transfer_ctx; + vk_semaphore transfer_semaphore; + uint64_t transfer_semaphore_last_submitted {}; std::vector<vk_context_ref> tensor_ctxs; @@ -1888,6 +2095,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + p.src_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.dst_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + struct ggml_backend_vk_buffer_context { vk_device_ref device; vk_buffer dev_buffer; @@ -1928,9 +2144,9 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) { const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal); std::string type = device ? "device" : "host"; auto it = allocations.find(buf->buffer); - total_device -= device ? it->second : 0; - total_host -= device ? 0 : it->second; if (it != allocations.end()) { + total_device -= device ? it->second : 0; + total_host -= device ? 0 : it->second; VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host)); allocations.erase(it); } else { @@ -2009,10 +2225,135 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) { ctx->device->device.resetFences({ ctx->fence }); } -// variables to track number of compiles in progress -static uint32_t compile_count = 0; -static std::mutex compile_count_mutex; -static std::condition_variable compile_count_cond; +static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367; +static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447; +static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4; + +// Remove SPV_NV_cooperative_matrix_decode_vector usage from a SPIR-V module so it +// can be loaded on drivers that only support SPV_NV_cooperative_matrix2. Drops the +// OpExtension declaration, the CooperativeMatrixDecodeVectorNV OpCapability, and the +// DecodeVectorFunc operand from any OpCooperativeMatrixLoadTensorNV instruction. +// Returns true when the input used the extension (and `out` was populated with a +// stripped copy); returns false otherwise without touching `out`. +static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count, std::vector<uint32_t> & out) { + static const char kDecodeVectorExt[] = "SPV_NV_cooperative_matrix_decode_vector"; + + if (word_count < 5) { + return false; + } + + bool uses_decode_vector = false; + for (size_t pos = 5; pos < word_count; ) { + uint32_t word = code[pos]; + uint32_t wc = word >> spv::WordCountShift; + uint32_t op = word & spv::OpCodeMask; + GGML_ASSERT(wc > 0 && pos + wc <= word_count); + if (op == spv::OpExtension && wc >= 2) { + const char * s = reinterpret_cast<const char *>(&code[pos + 1]); + if (strcmp(s, kDecodeVectorExt) == 0) { + uses_decode_vector = true; + break; + } + } + pos += wc; + } + + if (!uses_decode_vector) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_strip_decode_vector: stripping SPV_NV_cooperative_matrix_decode_vector"); + + // Bulk-copy unchanged runs and only break the run when an instruction needs to + // be dropped or patched. Use reserve + insert/push_back so the destination buffer + // is touched exactly once (no zero-initialization pass from resize()). + out.clear(); + out.reserve(word_count); + + size_t run_start = 0; + auto flush_run = [&](size_t up_to) { + if (up_to > run_start) { + out.insert(out.end(), code + run_start, code + up_to); + } + }; + + for (size_t pos = 5; pos < word_count; ) { + uint32_t word = code[pos]; + uint32_t wc = word >> spv::WordCountShift; + uint32_t op = word & spv::OpCodeMask; + GGML_ASSERT(wc > 0 && pos + wc <= word_count); + + if (op == spv::OpExtension && wc >= 2) { + const char * s = reinterpret_cast<const char *>(&code[pos + 1]); + if (strcmp(s, kDecodeVectorExt) == 0) { + flush_run(pos); + pos += wc; + run_start = pos; + continue; + } + } + + if (op == spv::OpCapability && wc == 2 && code[pos + 1] == kSpvCapabilityCooperativeMatrixDecodeVectorNV) { + flush_run(pos); + pos += wc; + run_start = pos; + continue; + } + + if (op == kSpvOpCooperativeMatrixLoadTensorNV) { + // [opcode/wc][ResultType][Result][Pointer][Object][TensorLayout][MemOperand mask][mem extras...][TA mask][ta extras...] + GGML_ASSERT(wc >= 8); + + uint32_t mem_mask = code[pos + 6]; + size_t cur = pos + 7; + // Each of these MemoryAccess bits (when set) carries one trailing operand. + cur += (mem_mask & 0x2) ? 1 : 0; // Aligned + cur += (mem_mask & 0x8) ? 1 : 0; // MakePointerAvailable + cur += (mem_mask & 0x10) ? 1 : 0; // MakePointerVisible + cur += (mem_mask & 0x10000) ? 1 : 0; // AliasScopeINTELMask + cur += (mem_mask & 0x20000) ? 1 : 0; // NoAliasINTELMask + GGML_ASSERT(cur < pos + wc); + + uint32_t ta_mask = code[cur]; + if ((ta_mask & kSpvTensorAddressingDecodeVectorFuncBit) == 0) { + pos += wc; + continue; // leave instruction inside the current unchanged run + } + + flush_run(pos); + + // Append unchanged prefix of the instruction (header through the mem-extras). + size_t inst_start = out.size(); + size_t pre_n = cur - pos; + out.insert(out.end(), code + pos, code + pos + pre_n); + + // Emit TA mask with the DecodeVectorFunc bit cleared. + out.push_back(ta_mask & ~kSpvTensorAddressingDecodeVectorFuncBit); + + // TA extras: TensorView (0x1) and DecodeFunc (0x2) are kept verbatim; + // DecodeVectorFunc (0x4) is dropped along with its trailing id operand. + size_t keep_ta_extras = ((ta_mask & 0x1) ? 1 : 0) + ((ta_mask & 0x2) ? 1 : 0); + if (keep_ta_extras) { + out.insert(out.end(), code + cur + 1, code + cur + 1 + keep_ta_extras); + } + + GGML_ASSERT(wc == pre_n + 1 + keep_ta_extras + 1); + + // Patch the instruction header with the new (one-shorter) word count. + uint32_t new_wc = wc - 1; + out[inst_start] = (new_wc << spv::WordCountShift) | op; + + pos += wc; + run_start = pos; + continue; + } + + pos += wc; + } + + flush_run(word_count); + return true; +} static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, @@ -2025,6 +2366,78 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data)); + + // Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for + // separate shader variants compiled with -DRTE16. + std::vector<uint32_t> spirv; + if (device->float_controls_rte_fp16) { + const uint32_t* spv_words = reinterpret_cast<const uint32_t *>(spv_data); + size_t word_count = spv_size / sizeof(uint32_t); + spirv.assign(spv_words, spv_words + word_count); + + // Find insertion points respecting SPIR-V layout order: + // Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ... + size_t pos = 5; // skip header + size_t cap_insert_pos = pos; + size_t ext_insert_pos = pos; + size_t exec_insert_pos = pos; + uint32_t entry_point_id = 0; + + while (pos < spirv.size()) { + uint32_t opcode = spirv[pos] & spv::OpCodeMask; + uint32_t len = spirv[pos] >> spv::WordCountShift; + if (len == 0) break; + + if (opcode == spv::OpCapability) { + cap_insert_pos = pos + len; + ext_insert_pos = pos + len; + } else if (opcode == spv::OpExtension) { + ext_insert_pos = pos + len; + } else if (opcode == spv::OpEntryPoint) { + entry_point_id = spirv[pos + 2]; + exec_insert_pos = pos + len; + } else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) { + exec_insert_pos = pos + len; + } else if (entry_point_id != 0) { + break; + } + + pos += len; + } + + // Insert from latest position first so earlier indices stay valid. + + // OpExecutionMode %entrypoint RoundingModeRTE 16 + uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 }; + spirv.insert(spirv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode)); + + // OpExtension "SPV_KHR_float_controls" + const char ext_str[] = "SPV_KHR_float_controls"; + size_t ext_str_words = CEIL_DIV(sizeof(ext_str), sizeof(uint32_t)); + std::vector<uint32_t> extension(1 + ext_str_words, 0); + extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension; + memcpy(&extension[1], ext_str, sizeof(ext_str)); + spirv.insert(spirv.begin() + ext_insert_pos, extension.begin(), extension.end()); + + // OpCapability RoundingModeRTE + uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE }; + spirv.insert(spirv.begin() + cap_insert_pos, std::begin(capability), std::end(capability)); + + shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data()); + } + +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + if (device->coopmat2 && !device->coopmat2_decode_vector) { + const uint32_t * src = spirv.empty() ? reinterpret_cast<const uint32_t *>(spv_data) : spirv.data(); + size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size(); + std::vector<uint32_t> stripped; + if (ggml_vk_strip_decode_vector(src, src_n, stripped)) { + spirv = std::move(stripped); + shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data()); + } + } +#endif + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); vk::PushConstantRange pcr( @@ -2106,7 +2519,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin std::cerr << "ggml_vulkan: " << e.what() << std::endl; throw e; } - pipeline->compiled = true; if (vk_instance.debug_utils_support) { vk::DebugUtilsObjectNameInfoEXT duoni; @@ -2121,7 +2533,32 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin executableInfo.pipeline = pipeline->pipeline; auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo); + + bool print_stats = !vk_pipeline_stats_filter.empty() && + pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos; + if (print_stats) { + std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl; + } + for (auto & s : statistics) { + if (print_stats) { + std::cerr << "ggml_vulkan: " << s.name.data() << ": "; + switch (s.format) { + case vk::PipelineExecutableStatisticFormatKHR::eBool32: + std::cerr << (s.value.b32 ? "true" : "false"); + break; + case vk::PipelineExecutableStatisticFormatKHR::eInt64: + std::cerr << s.value.i64; + break; + case vk::PipelineExecutableStatisticFormatKHR::eUint64: + std::cerr << s.value.u64; + break; + case vk::PipelineExecutableStatisticFormatKHR::eFloat64: + std::cerr << s.value.f64; + break; + } + std::cerr << std::endl; + } // "Register Count" is reported by NVIDIA drivers. if (strcmp(s.name, "Register Count") == 0) { VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers"); @@ -2130,14 +2567,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin } } - device->all_pipelines.push_back(pipeline); - { - std::lock_guard<std::mutex> guard(compile_count_mutex); - assert(compile_count > 0); - compile_count--; + std::lock_guard<std::mutex> guard(device->compile_mutex); + device->all_pipelines.push_back(pipeline); + pipeline->compiled = true; + pipeline->compile_pending = false; } - compile_count_cond.notify_all(); + device->compile_cv.notify_all(); } static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { @@ -2153,8 +2589,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); ctx->pipeline_descriptor_set_requirements += n; if (!pipeline->compiled) { - pipeline->needed = true; - ggml_vk_load_shaders(ctx->device); + ggml_vk_load_shaders(ctx->device, pipeline); } ggml_pipeline_allocate_descriptor_sets(ctx); } @@ -2197,25 +2632,15 @@ static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx } } -static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) { +static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) { VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()"); - - if (p.cmd_buffers.size() > p.cmd_buffer_idx) { - // Reuse command buffer - return p.cmd_buffers[p.cmd_buffer_idx++]; - } - vk::CommandBufferAllocateInfo command_buffer_alloc_info( p.pool, vk::CommandBufferLevel::ePrimary, 1); const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info); - auto buf = cmd_buffers.front(); - - p.cmd_buffers.push_back(buf); - p.cmd_buffer_idx++; - - return buf; + p.cmd_buffers.push_back({ cmd_buffers.front(), 0, true }); + return &p.cmd_buffers[p.cmd_buffers.size()-1]; } static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { @@ -2282,7 +2707,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) { tl_wait_semaphores[idx].data(), stage_flags[idx].data(), 1, - &submission.buffer, + &submission.buffer->buf, (uint32_t) submission.signal_semaphores.size(), tl_signal_semaphores[idx].data(), }; @@ -2406,7 +2831,11 @@ static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) // Requires command buffers to be done device->device.resetCommandPool(p.pool); - p.cmd_buffer_idx = 0; + // Don't clear the command buffers and mark them as not in use. + // This allows us to reuse them + for (auto& cmd_buffer : p.cmd_buffers) { + cmd_buffer.in_use = false; + } } static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { @@ -2415,10 +2844,10 @@ static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { // Arbitrary frequency to cleanup/reuse command buffers static constexpr uint32_t cleanup_frequency = 10; - if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + if (device->compute_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) { ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool); } - if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + if (device->transfer_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) { ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool); } } @@ -2666,7 +3095,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; } - subctx->s->buffer.pipelineBarrier( + subctx->s->buffer->buf.pipelineBarrier( subctx->p->q->stage_flags, subctx->p->q->stage_flags, {}, @@ -2679,10 +3108,19 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct ); } +static void ggml_vk_reset_event(vk_context& ctx, vk::Event& event) { + VK_LOG_DEBUG("ggml_vk_set_event()"); + + ctx->s->buffer->buf.resetEvent( + event, + ctx->p->q->stage_flags + ); +} + static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) { VK_LOG_DEBUG("ggml_vk_set_event()"); - ctx->s->buffer.setEvent( + ctx->s->buffer->buf.setEvent( event, ctx->p->q->stage_flags ); @@ -2694,7 +3132,7 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events return; } - ctx->s->buffer.waitEvents( + ctx->s->buffer->buf.waitEvents( events, ctx->p->q->stage_flags, ctx->p->q->stage_flags, @@ -2704,137 +3142,379 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events ); } -// number of rows/cols for flash attention shader -static constexpr uint32_t flash_attention_num_small_rows = 32; -static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; +struct vk_fa_tuning_params { + FaCodePath path; + uint32_t workgroup_size; + uint32_t subgroup_size; + uint32_t block_rows; + uint32_t block_cols; + uint32_t d_split; + uint32_t row_split; + bool shmem_staging; + bool disable_subgroups; + uint32_t limit_occupancy_shmem; + + void print() const { + std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size << + " block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split << + " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups << + " limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl; + } +}; + +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type); +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type = GGML_TYPE_F16); + +static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) { - if (hsv >= 192) { - return 2; - } else if ((hsv | hsk) & 8 || small_cache) { - return 4; + vk_fa_tuning_params result{}; + result.path = FA_SCALAR; + + if (device->vendor_id == VK_VENDOR_ID_INTEL) { + // Disable subgroup use due to performance issues when enforcing subgroup sizes + result.subgroup_size = 32; + result.disable_subgroups = true; + } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { + result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size; } else { - return 8; + result.subgroup_size = device->subgroup_size; } -} -// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. -// 128 threads split into four subgroups, each subgroup does 1/4 -// of the Bc dimension. -static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; -static constexpr uint32_t scalar_flash_attention_Bc = 64; -static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; + // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers + uint32_t row_split_max_hsk = 64; + if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) { + row_split_max_hsk = n_rows <= 8 ? 64 : 128; + } + result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4; -static uint32_t get_fa_num_small_rows(FaCodePath path) { - if (path == FA_COOPMAT2) { - return flash_attention_num_small_rows; + if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) { + result.workgroup_size = result.subgroup_size * 2; } else { - return scalar_flash_attention_num_small_rows; + result.workgroup_size = result.subgroup_size * 4; } -} -static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) { - GGML_UNUSED(clamp); + const uint32_t D = hsk | hsv; - if (path == FA_SCALAR) { - if (small_rows) { - return {scalar_flash_attention_num_small_rows, 64}; - } else { - if ((hsv | hsk) & 8) { - // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter - // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; - } else { - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32}; - } - } - } + const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL; - if (path == FA_COOPMAT1) { - if (small_rows) { - return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; + if (n_rows == 1) { + result.block_rows = 1; + result.block_cols = 64; + } else { + // row_split 1 means higher register use per row, so block size has to be adjusted + if (result.row_split == 1) { + result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8); } else { - return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16); } + + result.block_cols = (D & 8) ? 64 : 32; } - // small rows, large cols - if (small_rows) { - return {get_fa_num_small_rows(FA_COOPMAT2), 32}; + const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit + + result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4); + + result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; + + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) { + result.block_rows /= 2; } - // small cols to reduce register count - if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) { - if (hsk >= 512 || hsv >= 512) { - return {32, 32}; - } else { - return {64, 32}; + // On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled + // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy. + // This targets an occupancy of 4 subgroups per SIMD. + if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) { + if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) { + // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size + // Values are guessed, tested on RDNA2 + result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4; + } else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) { + // Same thing for GCN, with an occupancy target of 2 subgroups per SIMD. + // Here low-batch FA with large head size is affected. + // n_rows < 4 switch because workgroup size switches from 128 to 256 there. + result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4; } } - return {64, 64}; -} -static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) { - return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1]; + return result; } -static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) { +static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { + GGML_UNUSED(n_rows); + GGML_UNUSED(n_kv); + GGML_UNUSED(k_type); + GGML_UNUSED(v_type); + GGML_UNUSED(f32acc); - uint32_t lut_size = 0; - switch (src0_type) { - case GGML_TYPE_IQ1_S: - case GGML_TYPE_IQ1_M: - lut_size = 2*2048 + 4*2048; - break; - case GGML_TYPE_IQ2_XXS: - lut_size = 8*256; - break; - case GGML_TYPE_IQ2_XS: - lut_size = 8*512; - break; - case GGML_TYPE_IQ2_S: - lut_size = 8*1024; - break; - case GGML_TYPE_IQ3_XXS: - lut_size = 4*256; - break; - case GGML_TYPE_IQ3_S: - lut_size = 4*512; - break; - case GGML_TYPE_IQ4_NL: - case GGML_TYPE_IQ4_XS: - case GGML_TYPE_MXFP4: - lut_size = 4*16; - break; - default: - break; - } + vk_fa_tuning_params result{}; + result.path = FA_COOPMAT1; - // Needs to be kept up to date on shader changes - const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; - const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); - const uint32_t warps = warptile[0] / warptile[10]; + const uint32_t D = hsk | hsv; - const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; - const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0; - const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; - const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0; + const uint32_t coopmat_block_rows = 16; + const uint32_t coopmat_block_cols = 16; - const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh; - const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + const uint32_t num_subgroups = 4; - VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " - "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported); + result.block_rows = coopmat_block_rows; + result.block_cols = coopmat_block_cols * num_subgroups; + result.row_split = num_subgroups; + result.subgroup_size = device->subgroup_size; + result.workgroup_size = num_subgroups * result.subgroup_size; - return supported; + const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit + result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4); + + result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; + + return result; } -struct GpuPipelineConfig { - // GPU architecture identifier. - // Example: vk_device_architecture::AMD_GCN - vk_device_architecture arch; +static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { + GGML_UNUSED(n_kv); + GGML_UNUSED(f32acc); - // Mapping of pipeline names to their specific subgroup sizes. + vk_fa_tuning_params result{}; + result.path = FA_COOPMAT2; + + const uint32_t D = hsk | hsv; + + const bool small_rows = n_rows < 32; + + if (small_rows) { + result.block_rows = 32; + result.block_cols = 32; + } else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) { + result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64; + result.block_cols = 32; + } else { + result.block_rows = 64; + result.block_cols = 64; + } + + result.subgroup_size = device->subgroup_size; + result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128; + + return result; +} + +static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) { + FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : + device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT2 && k_type == GGML_TYPE_BF16 && !device->coopmat2_bf16_support) { + path = FA_COOPMAT1; + } + if (path == FA_COOPMAT1 && k_type == GGML_TYPE_BF16 && !device->coopmat_bf16_support) { + path = FA_SCALAR; + } + + if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) { + // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 + path = FA_SCALAR; + } + + if (path == FA_COOPMAT1) { + bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || + (!f32acc && device->coopmat_support_16x16x16_f16acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); + bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc, k_type); + + if (!shape_ok || !shmem_ok) { + path = FA_SCALAR; + } + } + + // scalar is faster than coopmat when N==1 + if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) { + path = FA_SCALAR; + } + + // Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it. + if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) { + path = FA_COOPMAT2; + } + + switch (path) { + case FA_SCALAR: + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); + case FA_COOPMAT1: + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); + case FA_COOPMAT2: + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc); + default: + throw std::runtime_error("unsupported FaCodePath"); + } +} + +static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, + bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) { + const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary && + (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2); + + uint32_t flags = (use_mask_opt ? 1 : 0) | + (use_mask ? 2 : 0) | + (use_logit_softcap ? 4 : 0) | + (old_amd_windows ? 8 : 0); + + const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size; + + return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type}; +} + +static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) { + const auto fa_block_bytes = [](ggml_type t) -> uint32_t { + if (t == GGML_TYPE_F32) return 16u; + return (uint32_t) ggml_type_size(t); + }; + return { + /* 0 WorkGroupSize */ state.workgroup_size, + /* 1 Br */ state.Br, + /* 2 Bc */ state.Bc, + /* 3 HSK */ state.HSK, + /* 4 HSV */ state.HSV, + /* 5 Clamp */ static_cast<uint32_t>(!state.aligned), + /* 6 D_split */ state.D_split, + /* 7 row_split */ state.row_split, + /* 8 SubGroupSize */ state.subgroup_size, + /* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u, + /*10 Flags */ state.flags, + /*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem, + /*12 FaTypeK */ static_cast<uint32_t>(state.k_type), + /*13 FaTypeV */ static_cast<uint32_t>(state.v_type), + /*14 FaBlockBytesK */ fa_block_bytes(state.k_type), + /*15 FaBlockBytesV */ fa_block_bytes(state.v_type), + }; +} + +static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) { + + uint32_t lut_size = 0; + switch (src0_type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + // Regular matmul uses the compact uint16_t IQ1 grid; the expanded + // uint32_t grid is only enabled for the q8_1/int-dot vector path. + lut_size = 2*2048; + break; + case GGML_TYPE_IQ2_XXS: + lut_size = 8*256; + break; + case GGML_TYPE_IQ2_XS: + lut_size = 8*512; + break; + case GGML_TYPE_IQ2_S: + lut_size = 8*1024; + break; + case GGML_TYPE_IQ3_XXS: + lut_size = 4*256; + break; + case GGML_TYPE_IQ3_S: + lut_size = 4*512; + break; + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: + lut_size = 4*16; + break; + case GGML_TYPE_NVFP4: + // Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4). + lut_size = 4*16 + 128u * (uint32_t)sizeof(float); + break; + default: + break; + } + + // Needs to be kept up to date on shader changes + const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; + const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + const uint32_t warps = warptile[0] / warptile[10]; + + const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; + const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0; + const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; + const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0; + + const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported); + + return supported; +} + +// Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses +// block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather +// than the float load buffers checked by ggml_vk_matmul_shmem_support. +// Sizes follow std430 rules. Returns false for types without a q8_1 pipeline. +static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) { + + // FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float. + const uint32_t fp_size = device->fp16 ? 2u : 4u; + const uint32_t fp_align = fp_size; + const uint32_t fp2_size = 2u * fp_size; + const uint32_t fp2_align = device->fp16 ? 4u : 8u; + + struct member { uint32_t size, align; }; + auto std430_size = [](std::initializer_list<member> members) { + uint32_t off = 0, struct_align = 1; + for (const auto &m : members) { + off = (off + m.align - 1) & ~(m.align - 1); + off += m.size; + struct_align = std::max(struct_align, m.align); + } + return (off + struct_align - 1) & ~(struct_align - 1); + }; + + uint32_t block_a_size = 0; + switch (src0_type) { + case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm + case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2) + case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm + case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2) + case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm + case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d + case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2) + case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2) + case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2) + case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2) + case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2) + default: + return false; + } + + // block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; } + const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); + + const uint32_t BM = warptile[1]; + const uint32_t BN = warptile[2]; + // mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise. + const uint32_t BK_STEP = mul_mat_id ? 1u : 4u; + + const uint32_t buf_a_size = BM * BK_STEP * block_a_size; + const uint32_t buf_b_size = BN * BK_STEP * block_b_size; + const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u; + + const uint32_t warps = warptile[0] / warptile[10]; + const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u; + + const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " + "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported); + + return supported; +} + +struct GpuPipelineConfig { + // GPU architecture identifier. + // Example: vk_device_architecture::AMD_GCN + vk_device_architecture arch; + + // Mapping of pipeline names to their specific subgroup sizes. // Example: {"soft_max_f32", 64} std::unordered_map<std::string, uint32_t> pipelines; @@ -2896,10 +3576,40 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev return 0; // If no matching configuration is found } -static void ggml_vk_load_shaders(vk_device& device) { +// Whether scalar flash attention will use the MMQ path for the given k_type. +static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) { +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + return device->integer_dot_product && device->subgroup_clustered && + (k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 || + k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 || + k_type == GGML_TYPE_Q8_0); +#else + GGML_UNUSED(device); + GGML_UNUSED(k_type); + return false; +#endif +} + +// load_shaders walks the pipeline list under compile_mutex and either claims +// the requested pipeline for compilation or, if another thread is already +// compiling it, drops the lock and waits on compile_cv. Compiles themselves +// run unlocked. +struct CompileTask { + vk_pipeline pipeline; + size_t spv_size; + const void * spv_data; + std::string entrypoint; + uint32_t parameter_count; + std::array<uint32_t, 3> wg_denoms; + std::vector<uint32_t> specialization_constants; + bool disable_robustness; + bool require_full_subgroups; + uint32_t required_subgroup_size; +}; + +static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")"); - std::lock_guard<std::recursive_mutex> guard(device->mutex); // some shaders have a minimum subgroup size const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u); const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u); @@ -2929,6 +3639,15 @@ static void ggml_vk_load_shaders(vk_device& device) { l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; uint32_t l_align, m_align, s_align; + + vk_pipeline wait_pipeline; + CompileTask claimed_task {}; + bool has_claimed_task = false; + + // The rest of the walk reads and writes shared device state, so hold the + // lock until we're done deciding what to compile. + std::unique_lock<std::mutex> compile_lock(device->compile_mutex); + if (device->coopmat2) { // spec constants and tile sizes for non-quant matmul/matmul_id l_warptile = { 256, 128, 256, 64, 1 }; @@ -2955,9 +3674,10 @@ static void ggml_vk_load_shaders(vk_device& device) { s_mmq_wg_denoms_k = { 32, 64, 1 }; // spec constants and tile sizes for quant matmul_id - l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size }; - m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size }; - s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size }; + const uint32_t mmqid_bk = device->coopmat2_decode_vector ? 64u : 32u; + l_warptile_mmqid = { 256, 128, 128, mmqid_bk, 1, device->subgroup_size }; + m_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size }; + s_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size }; l_mmqid_wg_denoms = { 128, 128, 1 }; m_mmqid_wg_denoms = { 128, 64, 1 }; s_mmqid_wg_denoms = { 128, 64, 1 }; @@ -3061,6 +3781,40 @@ static void ggml_vk_load_shaders(vk_device& device) { } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) { device->mul_mat_id_l[i] = false; } + + // The q8_1 mmq path has its own (larger) shmem layout, check it separately. + // K-quants use the _int_k warptiles, others use _int. + const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K || + t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K || + t == GGML_TYPE_Q6_K); + const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int; + const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int; + const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int; + const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int; + const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int; + const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int; + + if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) { + device->mul_mat_s_int[i] = false; + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) { + device->mul_mat_m_int[i] = false; + device->mul_mat_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) { + device->mul_mat_l_int[i] = false; + } + + if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) { + device->mul_mat_id_s_int[i] = false; + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) { + device->mul_mat_id_m_int[i] = false; + device->mul_mat_id_l_int[i] = false; + } else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) { + device->mul_mat_id_l_int[i] = false; + } } } @@ -3080,7 +3834,6 @@ static void ggml_vk_load_shaders(vk_device& device) { device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>(); } - std::vector<std::future<void>> compiles; auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) { @@ -3114,23 +3867,33 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif } - if (!pipeline->needed || pipeline->compiled) { + // We only care about the pipeline this call asked for; the rest + // (including the 64-bit indexing variant) are handled by their + // own request_descriptor_sets / load_shaders calls. + if (pipeline.get() != requested.get()) { continue; } - // TODO: We're no longer benefitting from the async compiles (shaders are - // compiled individually, as needed) and this complexity can be removed. - { - // wait until fewer than N compiles are in progress - uint32_t N = std::max(1u, std::thread::hardware_concurrency()); - std::unique_lock<std::mutex> guard(compile_count_mutex); - while (compile_count >= N) { - compile_count_cond.wait(guard); - } - compile_count++; + + if (pipeline->compiled) { + continue; } - compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, - parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size)); + wait_pipeline = pipeline; + + if (!pipeline->compile_pending) { + pipeline->compile_pending = true; + claimed_task.pipeline = pipeline; + claimed_task.spv_size = spv_size; + claimed_task.spv_data = spv_data; + claimed_task.entrypoint = entrypoint; + claimed_task.parameter_count = parameter_count; + claimed_task.wg_denoms = wg_denoms; + claimed_task.specialization_constants = specialization_constants; + claimed_task.disable_robustness = disable_robustness; + claimed_task.require_full_subgroups = require_full_subgroups; + claimed_task.required_subgroup_size = required_subgroup_size; + has_claimed_task = true; + } } }; @@ -3142,81 +3905,132 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> { - return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; - }; - - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> { - // For large number of rows, 128 invocations seems to work best. - // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we - // can't use 256 for D==80. - // For scalar, use 128 (arbitrary) - // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. - const uint32_t D = (hsk|hsv); - uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) - ? scalar_flash_attention_workgroup_size - : ((small_rows && (D % 32) == 0) ? 256 : 128); - auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); - - // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. - // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. - const uint32_t D_lsb = D ^ (D & (D-1)); - uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); - - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; - }; + // FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V + // quant type is selected at runtime via the FaTypeK / FaTypeV spec constants. + + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_SCALAR) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + const uint32_t fa_sgs = fa.first.subgroup_size; + const bool fa_ds = fa.first.subgroup_size == 0; + + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; + const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type); + const void * spv_data = nullptr; + size_t spv_size = 0; + const char *name = nullptr; + if (bf16_kv) { + spv_data = flash_attn_f32_f16_fp32_data; + spv_size = flash_attn_f32_f16_fp32_len; + name = aligned ? "flash_attn_f32_bf16_aligned" : "flash_attn_f32_bf16"; + } else if (use_mmq) { +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->fp16) { + if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; } + else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; } + } else { + spv_data = flash_attn_f32_f16_fp32_int8_data; + spv_size = flash_attn_f32_f16_fp32_int8_len; + } +#endif + name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; + } else { + if (device->fp16) { + if (device->dot2_f16) { + if (f32acc) { spv_data = flash_attn_f32_f16_dot2_data; spv_size = flash_attn_f32_f16_dot2_len; } + else { spv_data = flash_attn_f32_f16_dot2_f16acc_data; spv_size = flash_attn_f32_f16_dot2_f16acc_len; } + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; } + else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; } + } + } else { + spv_data = flash_attn_f32_f16_fp32_data; + spv_size = flash_attn_f32_f16_fp32_len; + } + name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16"; + } + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, + !fa_ds, !fa_ds ? fa_sgs : 0); + } -#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ - for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ - uint32_t HSK = fa.first.HSK; \ - uint32_t HSV = fa.first.HSV; \ - bool small_rows = fa.first.small_rows; \ - bool small_cache = fa.first.small_cache; \ - FaCodePath path = fa.first.path; \ - bool aligned = fa.first.aligned; \ - bool f32acc = fa.first.f32acc; \ - if (path == FAPATH) { \ - if (aligned) { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - } \ - } else { \ - if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ - } \ - } \ - } \ - } - - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { - CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_COOPMAT1) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + const uint32_t fa_sgs = fa.first.subgroup_size; + const bool fa_ds = fa.first.subgroup_size == 0; + + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; + + const void * spv_data; + size_t spv_size; + const char *name; + if (bf16_kv) { +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!device->coopmat_bf16_support) continue; + spv_data = flash_attn_f32_f16_bf16_cm1_data; + spv_size = flash_attn_f32_f16_bf16_cm1_len; + name = aligned ? "flash_attn_f32_bf16_aligned_cm1" : "flash_attn_f32_bf16_cm1"; +#else + continue; +#endif + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; } + else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; } + name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1"; + } + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, + !fa_ds, !fa_ds ? fa_sgs : 0); + } } #endif + #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { - CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) - CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) + for (auto &fa : device->pipeline_flash_attn_f32_f16) { + if (fa.first.path != FA_COOPMAT2) continue; + const uint32_t Br = fa.first.Br; + const uint32_t Bc = fa.first.Bc; + const bool aligned = fa.first.aligned; + const bool f32acc = fa.first.f32acc; + + const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16; + const void * spv_data; + size_t spv_size; + const char * name; + if (bf16_kv) { +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!device->coopmat2_bf16_support) continue; + spv_data = flash_attn_f32_f16_bf16_cm2_data; + spv_size = flash_attn_f32_f16_bf16_cm2_len; + name = aligned ? "flash_attn_f32_bf16_aligned_cm2" : "flash_attn_f32_bf16_cm2"; +#else + continue; +#endif + } else if (aligned) { + if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; } + else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; } + } else { + if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; } + else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; } + } + ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7, + sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, + get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0); + } } #endif -#undef CREATE_FA const int mul_mat_id_param_count = 5; @@ -3243,6 +4057,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) } #endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q1_0], matmul_q1_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) @@ -3263,6 +4078,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) GGML_ASSERT(device->subgroup_ballot); @@ -3272,6 +4088,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5) } #endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) @@ -3292,6 +4109,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5) #undef CREATE_MM #undef CREATE_MM2 } else @@ -3333,6 +4151,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->coopmat_acc_f16_support) { + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3354,7 +4173,9 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } else { + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); @@ -3376,6 +4197,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); } GGML_ASSERT(device->subgroup_ballot); @@ -3389,6 +4211,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); @@ -3409,13 +4232,30 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); #undef CREATE_MM2 #undef CREATE_MM } else #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->fp16) { // Create 6 variants, {s,m,l}x{unaligned,aligned} + // Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + + // bf16 scalar path promotes to f32, no dot2 variant +#define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ @@ -3430,13 +4270,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ - if (device->mul_mat ## ID ## _l[TYPE]) { \ + if (device->mul_mat ## ID ## _l_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ - if (device->mul_mat ## ID ## _m[TYPE]) { \ + if (device->mul_mat ## ID ## _m_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ - if (device->mul_mat ## ID ## _s[TYPE]) { \ + if (device->mul_mat ## ID ## _s_int[TYPE]) { \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ } \ @@ -3450,14 +4290,14 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); - CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3473,6 +4313,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3496,8 +4337,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3518,6 +4359,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3540,8 +4382,8 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); - + CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -3562,6 +4404,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3584,6 +4427,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #undef CREATE_MM2 #undef CREATE_MMQ #undef CREATE_MM +#undef CREATE_MM_NODOT2 } else { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ @@ -3601,11 +4445,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ - if (device->mul_mat ## ID ## _l[TYPE]) \ + if (device->mul_mat ## ID ## _l_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _m[TYPE]) \ + if (device->mul_mat ## ID ## _m_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ - if (device->mul_mat ## ID ## _s[TYPE]) \ + if (device->mul_mat ## ID ## _s_int[TYPE]) \ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); @@ -3615,6 +4459,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); @@ -3636,6 +4481,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3659,6 +4505,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_subgroup_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3679,12 +4526,14 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); } else { CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -3705,6 +4554,7 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } } // reusing CREATE_MM from the fp32 path @@ -3713,20 +4563,17 @@ static void ggml_vk_load_shaders(vk_device& device) { && !device->coopmat_bf16_support #endif ) { + const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32; + // use scalar tile sizes l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 }; - s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, 2, 2, 1, subgroup_size_8 }; l_wg_denoms = {128, 128, 1 }; m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; - if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) { - // Xe2/Xe3 - bf16 warptile performance tuning - l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 }; - } - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } @@ -3780,6 +4627,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f32_f32", arr_dmmv_q1_0_f32_f32_len[reduc], arr_dmmv_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); @@ -3800,10 +4648,12 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f16_f32", arr_dmmv_q1_0_f16_f32_len[reduc], arr_dmmv_q1_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); @@ -3824,6 +4674,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3854,6 +4705,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q1_0], "mul_mat_vec_id_q1_0_f32", arr_dmmv_id_q1_0_f32_f32_len[reduc], arr_dmmv_id_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size); @@ -3874,6 +4726,7 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16); #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (device->integer_dot_product) { @@ -3908,6 +4761,7 @@ static void ggml_vk_load_shaders(vk_device& device) { // dequant shaders ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q1_0], "dequant_q1_0", dequant_q1_0_len, dequant_q1_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 8, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); @@ -3928,11 +4782,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); // get_rows ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q1_0], "get_rows_q1_0", get_rows_q1_0_len, get_rows_q1_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -3953,11 +4809,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q1_0], "get_rows_q1_0_f32", get_rows_q1_0_f32_len, get_rows_q1_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); @@ -3978,9 +4836,15 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + + for (auto &it : device->pipeline_fa_mask_opt) { + auto BrBc = it.first; + ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size); + } if (device->subgroup_clustered && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); @@ -4005,20 +4869,20 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); - if (device->float_controls_rte_fp16 && - sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { + if (sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) { ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); - ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_len, rms_norm_mul_rope_f32_f16_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true); } ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_bf16_f32,"cpy_bf16_f32",cpy_bf16_f32_len,cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4027,49 +4891,39 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_bf16_f32,"contig_cpy_bf16_f32",contig_cpy_bf16_f32_len,contig_cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); - } - -#define SET_ROWS(itype, rte) \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); - - if (device->float_controls_rte_fp16) { - SET_ROWS(_i32, _rte) - SET_ROWS(_i64, _rte) - } else { - SET_ROWS(_i32, ) - SET_ROWS(_i64, ) - } + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + +#define SET_ROWS(itype) \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## _len, set_rows_f32 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## _len, set_rows_f16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## _len, set_rows_bf16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## _len, set_rows_q1_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## _len, set_rows_q4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## _len, set_rows_q4_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## _len, set_rows_q5_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + + SET_ROWS(_i32) + SET_ROWS(_i64) #undef SET_ROWS + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q1_0], "cpy_q1_0_f32", cpy_q1_0_f32_len, cpy_q1_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q1_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); @@ -4085,11 +4939,10 @@ static void ggml_vk_load_shaders(vk_device& device) { return s; }; - bool rte = device->float_controls_rte_fp16; #define CREATE_BINARY(name, namemod, spec, bindings) \ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ - #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); CREATE_BINARY(add, , {0}, 4) @@ -4113,7 +4966,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -4131,13 +4985,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - } + ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4151,13 +5000,16 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_i32, "repeat_i32", repeat_i32_len, repeat_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_i16, "repeat_i16", repeat_i16_len, repeat_i16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + #define CREATE_UNARY(name) \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + CREATE_UNARY(elu) CREATE_UNARY(gelu) CREATE_UNARY(gelu_erf) CREATE_UNARY(gelu_quick) @@ -4176,19 +5028,10 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_UNARY(ceil) CREATE_UNARY(floor) CREATE_UNARY(trunc) + CREATE_UNARY(sgn) + CREATE_UNARY(exp) #undef CREATE_UNARY -#define CREATE_UNARY_RTE(name) \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ - } - CREATE_UNARY_RTE(exp) -#undef CREATE_UNARY_RTE - ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -4196,15 +5039,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_fill_f16, "fill_f16", fill_f16_len, fill_f16_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); #define CREATE_GLU(name) \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ - } + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); CREATE_GLU(geglu) CREATE_GLU(reglu) @@ -4237,25 +5076,14 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - if (device->float_controls_rte_fp16) { - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - } else { - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); - } + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2); @@ -4289,6 +5117,24 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + // Intel Arc B390 was observed segfaulting with this shader. + if (device->subgroup_basic && device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL) { + int idx = 0; + for (uint32_t n : {64, 128, 256, 512}) { + if (device->subgroup_size <= n) { + ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_f32", fwht_f32_len, fwht_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { device->subgroup_size, n }, 1, true, true, device->subgroup_size); + } + ++idx; + } + } else if (device->driver_id != vk::DriverId::eIntelProprietaryWindows) { + // Disabled on Intel Windows due to a driver bug: https://github.com/ggml-org/llama.cpp/pull/23964#issuecomment-4598226147 + int idx = 0; + for (uint32_t n : {64, 128, 256, 512}) { + const uint32_t block_size = std::min(device->subgroup_size, n); + ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_shmem_f32", fwht_shmem_f32_len, fwht_shmem_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { block_size, n }, 1); + ++idx; + } + } const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4; ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size); @@ -4317,13 +5163,8 @@ static void ggml_vk_load_shaders(vk_device& device) { #define IM2COL(bda) \ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - if (device->float_controls_rte_fp16) { \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - } else { \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ - ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ - } + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); if (device->shader_int64 && device->buffer_device_address) { IM2COL(_bda) } else { @@ -4334,12 +5175,63 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_snake_bf16, "snake_bf16", snake_bf16_len, snake_bf16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + { + const uint32_t gdn_sizes[] = {32, 64, 128}; + const char * gdn_names[][2] = { + {"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"}, + {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"}, + {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"}, + }; + const bool use_subgroup_reduce = device->subgroup_arithmetic; + for (uint32_t si = 0; si < 3; si++) { + const uint32_t S_V = gdn_sizes[si]; + GGML_ASSERT(is_pow2(S_V)); + + uint32_t lanes_per_column; + if (S_V >= 128u && device->subgroup_clustered) { + lanes_per_column = 8u; + } else { + // Use largest power-of-two that divides both S_V and subgroup_size so that + // (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0. + // This means we don't need extra bounds checking logic in the shader. + lanes_per_column = std::min(S_V, device->subgroup_size); + } + + const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size); + size_t gdn_len; + const void * gdn_data; + if (use_subgroup_reduce && need_clustered_shader) { + gdn_len = gated_delta_net_f32_len; + gdn_data = (const void *)gated_delta_net_f32_data; + } else if (use_subgroup_reduce) { + gdn_len = gated_delta_net_f32_nocluster_len; + gdn_data = (const void *)gated_delta_net_f32_nocluster_data; + } else { + gdn_len = gated_delta_net_f32_shmem_len; + gdn_data = (const void *)gated_delta_net_f32_shmem_data; + } + + const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column; + const std::array<uint32_t, 3> wg_denoms = {1u, 1u, cols_per_wg}; + + for (uint32_t kda = 0; kda < 2; kda++) { + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda], + gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), + wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size); + } + } + } + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true); ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true); @@ -4348,7 +5240,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); } - ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1); ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); @@ -4356,7 +5250,8 @@ static void ggml_vk_load_shaders(vk_device& device) { // conv2d, conv_transpose_2d for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { - uint32_t conv2d_WG_SIZE = 256; + // smaller WG for the small-tile fallback gives more concurrent WGs per SM + uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256; uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. uint32_t conv2d_TS_K = (s == CONV_SHAPE_64x32) ? 4 : 8; uint32_t conv2d_SHMEM_PAD = 4; @@ -4395,18 +5290,77 @@ static void ggml_vk_load_shaders(vk_device& device) { conv2d_BS.CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used. } - uint32_t conv2d_shmem_req = - (conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float); - if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { + // cm1 is used only when cm2 is unavailable; capped at 64x128 (due to shared memory size). + // Requires 16x16x16 f16-acc since that's the fragment shape hard-coded in the shader. + // Subgroup size must be 32 or 64 (to keep WG_SIZE sane) and we need + // subgroup_size_control to force the driver to actually use it. + bool conv2d_use_cm1 = false; +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + conv2d_use_cm1 = !device->coopmat2 && + device->coopmat_support && device->coopmat_support_16x16x16_f16acc && + device->subgroup_size_control && + (device->subgroup_size == 32 || device->subgroup_size == 64) && + s != CONV_SHAPE_128x128; +#endif + + const uint32_t conv2d_cm1_shmem_pad = 8; + + auto shmem_req = [&](uint32_t pad, bool csh_store, bool fp16_shmem) { + const uint32_t elem_size = fp16_shmem ? (uint32_t)sizeof(uint16_t) : (uint32_t)sizeof(float); + const uint32_t csh_elems = csh_store ? conv2d_BS.K * conv2d_BS.NPQ : 0u; + return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size; + }; + + // coopmat1 needs to store the output through shared memory, so check up front + // whether it'll fit and disable it before applying coopmat1 parameters. + if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) { + conv2d_use_cm1 = false; + } + + uint32_t conv2d_WM = 16, conv2d_WN = 16; // cm1 subgroup tile, ignored otherwise + if (conv2d_use_cm1) { + conv2d_SHMEM_PAD = conv2d_cm1_shmem_pad; + // 16x16x16 fragments; pick WM/WN to keep WG_SIZE at 256 + // (i.e. 8 subgroups for sg=32, 4 subgroups for sg=64). + const bool sg64 = (device->subgroup_size == 64); + switch (s) { + case CONV_SHAPE_64x32: conv2d_WM = sg64 ? 32 : 16; conv2d_WN = 16; break; + case CONV_SHAPE_64x128: conv2d_WM = 32; conv2d_WN = sg64 ? 64 : 32; break; + case CONV_SHAPE_32x256: conv2d_WM = sg64 ? 16 : 32; conv2d_WN = sg64 ? 128 : 32; break; + default: break; + } + const uint32_t warps_M = conv2d_BS.K / conv2d_WM; + const uint32_t warps_N = conv2d_BS.NPQ / conv2d_WN; + conv2d_WG_SIZE = warps_M * warps_N * device->subgroup_size; + } + + // stage cm2 accumulator through shmem for coalesced global stores; + // skipped on 128x128 where the extra Csh footprint hurts occupancy. + // cm1 always uses the staged path. + uint32_t conv2d_csh_store = (device->coopmat2 && s != CONV_SHAPE_128x128) ? 1u : 0u; + if (conv2d_use_cm1) { + conv2d_csh_store = 1; + } + + // shmem is fp16 on cm2/cm1 (matches Csh), fp32 on scalar + const bool conv2d_use_fp16_shmem = device->coopmat2 || conv2d_use_cm1; + + // shrink CRS if the non-cm1 config still doesn't fit + if (device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_SHMEM_PAD, conv2d_csh_store, conv2d_use_fp16_shmem)) { + GGML_ASSERT(!conv2d_use_cm1); conv2d_BS.CRS = 8; if (use_collectives) { conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS); } + conv2d_csh_store = 0; } std::array<uint32_t, 3> wg_denoms = { conv2d_BS.K, 1, 1 }; std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; + // cm1 needs a fixed subgroup width to match the WG_SIZE we computed + const uint32_t conv2d_required_subgroup_size = conv2d_use_cm1 ? device->subgroup_size : 0; + #define CREATE_CONV(name, type_suffix, spv_suffix) \ for (auto &c : device->pipeline_##name##type_suffix[s]) { \ const vk_conv2d_pipeline_state &state = c.first; \ @@ -4419,10 +5373,14 @@ static void ggml_vk_load_shaders(vk_device& device) { spec_constants_cpy.push_back(state.d1); \ spec_constants_cpy.push_back(state.KW); \ spec_constants_cpy.push_back(state.KH); \ + spec_constants_cpy.push_back(state.aligned); \ + spec_constants_cpy.push_back(conv2d_csh_store); \ + spec_constants_cpy.push_back(conv2d_WM); \ + spec_constants_cpy.push_back(conv2d_WN); \ ggml_vk_create_pipeline( \ device, c.second, #name #type_suffix, \ name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ - sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \ + sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives || conv2d_required_subgroup_size, conv2d_required_subgroup_size); \ } #define CREATE_CONVS(spv_suffix) \ CREATE_CONV(conv2d, _f32, spv_suffix) \ @@ -4433,6 +5391,11 @@ static void ggml_vk_load_shaders(vk_device& device) { if (device->coopmat2) { CREATE_CONVS(_cm2) } else +#endif +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (conv2d_use_cm1) { + CREATE_CONVS(_cm1) + } else #endif if (conv2d_UNROLL) { CREATE_CONVS(_unroll) @@ -4454,12 +5417,30 @@ static void ggml_vk_load_shaders(vk_device& device) { } } - for (auto &c : compiles) { - c.wait(); + // Drop compile_mutex so other threads can walk while we compile. + compile_lock.unlock(); + + // Compile what we claimed; create_pipeline_func reacquires compile_mutex + // at the end to flip compile_pending/compiled and notify waiters. + if (has_claimed_task) { + auto & task = claimed_task; + ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data, + task.entrypoint, task.parameter_count, task.wg_denoms, + task.specialization_constants, task.disable_robustness, + task.require_full_subgroups, task.required_subgroup_size); + } + + // Another thread may be compiling the pipeline we need; block on it here. + if (wait_pipeline) { + std::unique_lock<std::mutex> wait_lock(device->compile_mutex); + device->compile_cv.wait(wait_lock, [&] { + return wait_pipeline->compiled.load(); + }); } } static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); +static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev); static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -4504,11 +5485,13 @@ static vk_device ggml_vk_get_device(size_t idx) { bool amd_shader_core_properties2 = false; bool pipeline_robustness = false; bool coopmat2_support = false; + bool coopmat2_decode_vector_support = false; bool pipeline_executable_properties_support = false; device->coopmat_support = false; device->integer_dot_product = false; device->shader_64b_indexing = false; bool bfloat16_support = false; + bool dot2_f16_support = false; for (const auto& properties : ext_props) { if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { @@ -4538,6 +5521,9 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; #endif + } else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) { + coopmat2_decode_vector_support = true; #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { @@ -4548,6 +5534,9 @@ static vk_device ggml_vk_get_device(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_DOT2")) { + dot2_f16_support = true; } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) { pipeline_executable_properties_support = true; } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && @@ -4676,6 +5665,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->shader_core_count = sm_props.shaderSMCount; } else if (amd_shader_core_properties2) { device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) { + device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device); } else { device->shader_core_count = 0; } @@ -4693,6 +5684,11 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); +#ifdef __APPLE__ + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_shuffle = false; + } +#endif device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered); @@ -4719,8 +5715,11 @@ static vk_device ggml_vk_get_device(size_t idx) { std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties(); // Try to find a non-graphics compute queue and transfer-focused queues - const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); - const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); + // Allow overriding avoiding the graphics queue because it can increase performance on RADV + const bool allow_graphics_queue = (getenv("GGML_VK_ALLOW_GRAPHICS_QUEUE") != nullptr); + const vk::QueueFlagBits graphics_flag = allow_graphics_queue ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics; + const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1); + const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1); const float priorities[] = { 1.0f, 1.0f }; device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; @@ -4734,7 +5733,7 @@ static vk_device ggml_vk_get_device(size_t idx) { } else { device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); } - vk::DeviceCreateInfo device_create_info; + vk::DeviceCreateInfo device_create_info{}; std::vector<const char *> device_extensions; vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); @@ -4810,6 +5809,14 @@ static vk_device ggml_vk_get_device(size_t idx) { } #endif + VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {}; + coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV; + if (coopmat2_decode_vector_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + device_extensions.push_back(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME); + } + #if defined(VK_KHR_shader_bfloat16) VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; bfloat16_features.pNext = nullptr; @@ -4837,6 +5844,14 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_shader_integer_dot_product"); } + VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {}; + dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE; + if (dot2_f16_support) { + last_struct->pNext = (VkBaseOutStructure *)&dot2_features; + last_struct = (VkBaseOutStructure *)&dot2_features; + device_extensions.push_back("VK_VALVE_shader_mixed_float_dot_product"); + } + VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {}; pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR; if (pipeline_executable_properties_support) { @@ -4871,6 +5886,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->bf16 = false; #endif + device->dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32; + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && @@ -4895,11 +5912,7 @@ static vk_device ggml_vk_get_device(size_t idx) { #if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; - - // coopmat1 fa shader currently assumes 32 invocations per subgroup - device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && - device->subgroup_size_control && device->subgroup_min_size <= 32 && - device->subgroup_max_size >= 32; + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support; #endif if (coopmat2_support) { @@ -4933,46 +5946,73 @@ static vk_device ggml_vk_get_device(size_t idx) { found_fp16_256 = false, found_fp32_128 = false, found_fp32_256 = false; + bool found_bf16_128 = false, + found_bf16_256 = false; // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 // with 32x16x16 and 256 with 32x32x16. for (auto &prop : flexible_dimensions) { if (prop.saturatingAccumulation == VK_FALSE && - prop.scope == VK_SCOPE_WORKGROUP_KHR && - prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - - if (prop.workgroupInvocations == 128 && - prop.MGranularity <= 32 && - prop.NGranularity <= 16 && - prop.KGranularity <= 16) { - if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - found_fp16_128 = true; + prop.scope == VK_SCOPE_WORKGROUP_KHR) { + + if (prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } } - if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { - found_fp32_128 = true; + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } } } - if (prop.workgroupInvocations == 256 && - prop.MGranularity <= 32 && - prop.NGranularity <= 32 && - prop.KGranularity <= 16) { - if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { - found_fp16_256 = true; + +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + found_bf16_128 = true; } - if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && - prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { - found_fp32_256 = true; + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + found_bf16_256 = true; } } +#endif } } if (found_fp16_128 && found_fp16_256 && found_fp32_128 && found_fp32_256 && coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { device->coopmat2 = true; + device->coopmat2_bf16_support = found_bf16_128 && found_bf16_256; + device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; } } #endif @@ -5107,12 +6147,10 @@ static vk_device ggml_vk_get_device(size_t idx) { #endif device->name = GGML_VK_NAME + std::to_string(idx); - device_create_info = { - vk::DeviceCreateFlags(), - device_queue_create_infos, - {}, - device_extensions - }; + device_create_info + .setFlags(vk::DeviceCreateFlags()) + .setQueueCreateInfos(device_queue_create_infos) + .setPEnabledExtensionNames(device_extensions); device_create_info.setPNext(&device_features2); device->device = device->physical_device.createDevice(device_create_info); @@ -5132,19 +6170,19 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; - case VK_VENDOR_ID_INTEL: - if (!device->coopmat_support || device->architecture != INTEL_XE2) { - device->mul_mat_l[i] = false; - device->mul_mat_id_l[i] = false; - } else { - device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel - device->mul_mat_id_l[i] = true; - } + case VK_VENDOR_ID_INTEL: { + // Current Windows driver does not expose BF16 support. + // We only want to use l_warptile if coopmat is available and is Xe2+ + const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2; + const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat; + device->mul_mat_l[i] = use_l_warptile; + device->mul_mat_id_l[i] = use_l_warptile; device->mul_mat_m[i] = true; device->mul_mat_s[i] = true; device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; + } case VK_VENDOR_ID_APPLE: device->mul_mat_l[i] = false; device->mul_mat_m[i] = true; @@ -5163,6 +6201,26 @@ static vk_device ggml_vk_get_device(size_t idx) { device->mul_mat_id_s[i] = true; break; } + +#if VK_HEADER_VERSION >= 287 + // Honeykrisp driver for Asahi Linux doesn't report VK_VENDOR_ID_APPLE. + // Check for Honeykrisp driver and force same configuration as the VK_VENDOR_ID_APPLE case. + if (device->driver_id == vk::DriverId::eMesaHoneykrisp) { + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = false; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = false; + } +#endif + + device->mul_mat_l_int[i] = device->mul_mat_l[i]; + device->mul_mat_m_int[i] = device->mul_mat_m[i]; + device->mul_mat_s_int[i] = device->mul_mat_s[i]; + device->mul_mat_id_l_int[i] = device->mul_mat_id_l[i]; + device->mul_mat_id_m_int[i] = device->mul_mat_id_m[i]; + device->mul_mat_id_s_int[i] = device->mul_mat_id_s[i]; } @@ -5183,13 +6241,24 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_load_shaders(device); + // Prefer a dedicated transfer queue on AMD dGPUs (non-GCN) when graphics queue use is disabled. + const bool prefers_transfer_queue = + device->vendor_id == VK_VENDOR_ID_AMD && + device->architecture != AMD_GCN && + !device->uma && + !allow_graphics_queue; + if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); + + device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); } else { // TODO: Use pointer or reference to avoid copy device->transfer_queue.copyFrom(device->compute_queue); device->transfer_queue.cmd_pool.init(device, &device->transfer_queue); + + device->async_use_transfer_queue = false; } device->buffer_type = { @@ -5243,8 +6312,10 @@ static void ggml_vk_print_gpu_info(size_t idx) { bool fp16_compute = false; bool coopmat_support = false; bool coopmat2_support = false; + bool coopmat2_decode_vector_support = false; bool integer_dot_product = false; bool bfloat16_support = false; + bool dot2_f16_support = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { @@ -5261,6 +6332,9 @@ static void ggml_vk_print_gpu_info(size_t idx) { !getenv("GGML_VK_DISABLE_COOPMAT2")) { coopmat2_support = true; #endif + } else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) { + coopmat2_decode_vector_support = true; #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { @@ -5271,6 +6345,9 @@ static void ggml_vk_print_gpu_info(size_t idx) { !getenv("GGML_VK_DISABLE_BFLOAT16")) { bfloat16_support = true; #endif + } else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_DOT2")) { + dot2_f16_support = true; } } @@ -5345,6 +6422,29 @@ static void ggml_vk_print_gpu_info(size_t idx) { } #endif +#if defined(VK_NV_cooperative_matrix2) + VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; + coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; + last_struct = (VkBaseOutStructure *)&coopmat2_features; + } +#endif + + VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {}; + coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV; + if (coopmat2_decode_vector_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features; + } + + VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {}; + dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE; + if (dot2_f16_support) { + last_struct->pNext = (VkBaseOutStructure *)&dot2_features; + last_struct = (VkBaseOutStructure *)&dot2_features; + } + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; @@ -5369,11 +6469,34 @@ static void ggml_vk_print_gpu_info(size_t idx) { #endif && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); - std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + coopmat2_support = coopmat2_support && + coopmat2_features.cooperativeMatrixWorkgroupScope && + coopmat2_features.cooperativeMatrixFlexibleDimensions && + coopmat2_features.cooperativeMatrixReductions && + coopmat2_features.cooperativeMatrixConversions && + coopmat2_features.cooperativeMatrixPerElementOperations && + coopmat2_features.cooperativeMatrixTensorAddressing && + coopmat2_features.cooperativeMatrixBlockLoads; +#else + coopmat2_support = false; +#endif + + coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector; +#if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + coopmat2_decode_vector_support = false; +#endif + + std::string matrix_cores = coopmat2_support ? (coopmat2_decode_vector_support ? "NV_coopmat2v" : "NV_coopmat2") + : coopmat_support ? "KHR_coopmat" + : "none"; + + bool dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32; + const char *fp16_str = fp16 ? (dot2_f16 ? "dot2" : "1") : "0"; std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", - idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size, + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %s | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16_str, bf16, subgroup_size, props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { @@ -5467,6 +6590,10 @@ static void ggml_vk_instance_init() { vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr; vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr; vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr; + const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS"); + if (GGML_VK_PIPELINE_STATS != nullptr) { + vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS; + } const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY"); if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) { @@ -5513,22 +6640,30 @@ static void ggml_vk_instance_init() { if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) { // Check if there are two physical devices corresponding to the same GPU + // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux), + // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication. + // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards, + // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip when both old/new + // driver is MoltenVK auto old_device = std::find_if( vk_instance.device_indices.begin(), vk_instance.device_indices.end(), - [&devices, &new_id](const size_t k){ + [&devices, &new_id, &new_driver](const size_t k){ vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceDriverProperties old_driver; vk::PhysicalDeviceIDProperties old_id; - old_props.pNext = &old_id; + old_props.pNext = &old_driver; + old_driver.pNext = &old_id; devices[k].getProperties2(&old_props); - bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); - equals = equals || ( + bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); + same_uuid = same_uuid || ( old_id.deviceLUIDValid && new_id.deviceLUIDValid && std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID)) ); + bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk); - return equals; + return same_uuid && !both_molten_vk; } ); if (old_device == vk_instance.device_indices.end()) { @@ -5565,6 +6700,10 @@ static void ggml_vk_instance_init() { driver_priorities[vk::DriverId::eMesaNvk] = 2; #endif break; + case VK_VENDOR_ID_QUALCOMM: + driver_priorities[vk::DriverId::eQualcommProprietary] = 1; + driver_priorities[vk::DriverId::eMesaTurnip] = 2; + break; } driver_priorities[vk::DriverId::eMesaDozen] = 100; @@ -5647,7 +6786,15 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->almost_ready_fence = ctx->device->device.createFence({}); ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); - ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + if (ctx->device->async_use_transfer_queue) { + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + ctx->transfer_semaphore.s = ctx->device->device.createSemaphore(ci); + ctx->transfer_semaphore.value = 0; + + ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + } if (vk_perf_logger_enabled) { ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger()); @@ -5665,6 +6812,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); switch (type) { case GGML_TYPE_F32: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5685,6 +6833,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -5736,6 +6885,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte } switch (src0_type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5756,6 +6906,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -5801,6 +6952,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5821,6 +6973,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -5891,6 +7044,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); switch (src0_type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5911,6 +7065,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -5959,6 +7114,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -5979,6 +7135,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return nullptr; @@ -6025,7 +7182,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) { return nullptr; } - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex); device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); return buf->ptr; @@ -6036,7 +7193,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { return; } VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex); vk_buffer buf; size_t index; @@ -6060,7 +7217,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) { } static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { - std::lock_guard<std::recursive_mutex> guard(device->mutex); + std::shared_lock<std::shared_mutex> guard(device->pinned_memory_mutex); buf = nullptr; buf_offset = 0; for (size_t i = 0; i < device->pinned_memory.size(); i++) { @@ -6100,13 +7257,25 @@ static vk_subbuffer ggml_vk_tensor_subbuffer( return vk_subbuffer{buffer, offset, size}; } +// Get a command buffer from pool. Create a new one if no reusable buffer is available +static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) { + for (auto& cmd_buffer : pool.cmd_buffers) { + if (!cmd_buffer.in_use) { + cmd_buffer.use_counter++; + cmd_buffer.in_use = true; + return &cmd_buffer; + } + } + return ggml_vk_create_cmd_buffer(device, pool); +} + static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) { vk_submission s; - s.buffer = ggml_vk_create_cmd_buffer(device, p); + s.buffer = ggml_vk_get_or_create_cmd_buffer(device, p); if (one_time) { - s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); + s.buffer->buf.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); } else { - s.buffer.begin({ vk::CommandBufferUsageFlags{} }); + s.buffer->buf.begin({ vk::CommandBufferUsageFlags{} }); } return s; @@ -6159,21 +7328,14 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); - subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants)); - subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); - subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, + subctx->s->buffer->buf.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants)); + subctx->s->buffer->buf.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); + subctx->s->buffer->buf.bindDescriptorSets(vk::PipelineBindPoint::eCompute, pipeline->layout, 0, { descriptor_set }, {}); - subctx->s->buffer.dispatch(wg0, wg1, wg2); -} - -static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) { - s.buffer.end(); - - s.wait_semaphores = std::move(wait_semaphores); - s.signal_semaphores = std::move(signal_semaphores); + subctx->s->buffer->buf.dispatch(wg0, wg1, wg2); } static void ggml_vk_ctx_end(vk_context& ctx) { @@ -6182,7 +7344,7 @@ static void ggml_vk_ctx_end(vk_context& ctx) { return; } - ctx->s->buffer.end(); + ctx->s->buffer->buf.end(); ctx->s = nullptr; } @@ -6196,6 +7358,48 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); } +static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) { + vk_context result; + if (!ctx->compute_ctx.expired()) { + result = ctx->compute_ctx.lock(); + } else { + result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + + ctx->compute_ctx = result; + ggml_vk_ctx_begin(ctx->device, result); + } + + if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) { + result->s->wait_semaphores.push_back(ctx->transfer_semaphore); + ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value; + } + + return result; +} + +// Submit any pending transfer queue work and signal the transfer semaphore. +// The next compute context created via ggml_vk_get_compute_ctx will wait on this semaphore. +// Returns true if work was submitted. +static bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx) { + if (!ctx->device->async_use_transfer_queue || ctx->transfer_ctx.expired()) { + return false; + } + + vk_context cpy_ctx = ctx->transfer_ctx.lock(); + ggml_vk_ctx_end(cpy_ctx); + + for (auto& cpy : cpy_ctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ctx->transfer_semaphore.value++; + cpy_ctx->seqs.back().back().signal_semaphores.push_back(ctx->transfer_semaphore); + + ggml_vk_submit(cpy_ctx, {}); + ctx->transfer_ctx.reset(); + return true; +} + static size_t ggml_vk_align_size(size_t width, size_t align) { VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")"); return CEIL_DIV(width, align) * align; @@ -6286,7 +7490,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; for (uint64_t i0 = 0; i0 < ne0; i0++) { - slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); + slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 }); } } } @@ -6295,7 +7499,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } ggml_vk_sync_buffers(ctx, subctx); - subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices); return; } @@ -6310,7 +7514,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont VkBufferCopy buf_copy{ 0, offset, copy_size }; ggml_vk_sync_buffers(ctx, subctx); - vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); for (uint64_t i3 = 0; i3 < ne3; i3++) { for (uint64_t i2 = 0; i2 < ne2; i2++) { @@ -6334,7 +7538,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont } } -static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) { +static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); // Check if src is pinned memory vk_buffer buf = nullptr; @@ -6344,7 +7548,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz if (buf != nullptr) { // Memory is pinned, use as staging buffer std::vector<vk::BufferCopy> slices(1); - if (width == spitch) { + if (width == spitch && width == dpitch) { // Only do single write if stride is equal slices[0].srcOffset = buf_offset; slices[0].dstOffset = offset; @@ -6353,13 +7557,13 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz slices.resize(height); for (size_t i = 0; i < height; i++) { slices[i].srcOffset = buf_offset + i * spitch; - slices[i].dstOffset = offset + i * width; + slices[i].dstOffset = offset + i * dpitch; slices[i].size = width; } } ggml_vk_sync_buffers(nullptr, subctx); - subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices); return true; } VK_LOG_DEBUG("STAGING"); @@ -6370,21 +7574,30 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz } // Staging buffer required - const size_t copy_size = width*height; - ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); + const size_t staging_size = width * height; + ggml_vk_ensure_sync_staging_buffer(dst->device, staging_size); vk_buffer& staging_buffer = dst->device->sync_staging; - VkBufferCopy buf_copy = { - 0, - offset, - copy_size}; + std::vector<vk::BufferCopy> slices(1); + if (width == dpitch) { + slices[0].srcOffset = 0; + slices[0].dstOffset = offset; + slices[0].size = staging_size; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = i * width; + slices[i].dstOffset = offset + i * dpitch; + slices[i].size = width; + } + } ggml_vk_sync_buffers(nullptr, subctx); - vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + subctx->s->buffer->buf.copyBuffer((VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, slices); if (width == spitch) { - deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); + deferred_memcpy((uint8_t *)staging_buffer->ptr, src, staging_size, &subctx->in_memcpys); } else { for (size_t i = 0; i < height; i++) { deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); @@ -6395,24 +7608,28 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); - return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); + return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, size, 1, sync_staging); } -static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { +static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height) { VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); // Buffer is already mapped if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - for (size_t i = 0; i < height; i++) { - memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); + if (width == spitch && width == dpitch) { + memcpy((uint8_t *)dst->ptr + offset, src, width * height); + } else { + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width); + } } } else { std::lock_guard<std::recursive_mutex> guard(dst->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); - bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, dpitch, width, height, true); GGML_ASSERT(ret); ggml_vk_ctx_end(subctx); @@ -6433,7 +7650,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); - ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); + ggml_vk_buffer_write_2d(dst, offset, src, size, size, size, 1); } static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { @@ -6467,7 +7684,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size if (buf != nullptr) { // Memory is pinned, use as staging buffer ggml_vk_sync_buffers(nullptr, subctx); - subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); + subctx->s->buffer->buf.copyBuffer(src->buffer, buf->buffer, slices); return true; } @@ -6479,15 +7696,35 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size } // Fall back to staging buffer - const size_t copy_size = dpitch * height; - ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); + const size_t staging_size = width * height; + ggml_vk_ensure_sync_staging_buffer(src->device, staging_size); vk_buffer& staging_buffer = src->device->sync_staging; + std::vector<vk::BufferCopy> staging_slices(1); + if (width == spitch) { + staging_slices[0].srcOffset = offset; + staging_slices[0].dstOffset = 0; + staging_slices[0].size = staging_size; + } else { + staging_slices.resize(height); + for (size_t i = 0; i < height; i++) { + staging_slices[i].srcOffset = offset + i * spitch; + staging_slices[i].dstOffset = i * width; + staging_slices[i].size = width; + } + } + ggml_vk_sync_buffers(nullptr, subctx); - subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); + subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, staging_slices); - deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); + if (width == dpitch) { + deferred_memcpy(dst, staging_buffer->ptr, staging_size, &subctx->out_memcpys); + } else { + for (size_t i = 0; i < height; i++) { + deferred_memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) staging_buffer->ptr + i * width, width, &subctx->out_memcpys); + } + } return true; } @@ -6495,8 +7732,8 @@ static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); } -static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { - VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); +static void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height) { + VK_LOG_DEBUG("ggml_vk_buffer_read_2d(" << src->buffer << ", " << offset << ", " << width << ", " << height << ")"); // If the device is not an UMA device the memory is host-accessible through rebar. While writing // through PCIe is sufficient fast reading back data from PCIe is slower than going through @@ -6504,18 +7741,24 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); - memcpy(dst, (uint8_t *) src->ptr + offset, size); + if (width == spitch && width == dpitch) { + memcpy(dst, (const uint8_t *) src->ptr + offset, width * height); + } else { + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width); + } + } } else { std::lock_guard<std::recursive_mutex> guard(src->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(src->device, subctx); - bool ret = ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); + bool ret = ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, spitch, dpitch, width, height, true); GGML_ASSERT(ret); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, src->device->fence); - VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read_2d waitForFences"); src->device->device.resetFences({ src->device->fence }); ggml_vk_queue_command_pools_cleanup(src->device); @@ -6525,6 +7768,11 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_ } } +static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); + ggml_vk_buffer_read_2d(src, offset, dst, size, size, size, 1); +} + static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); // Make sure both buffers are on same device @@ -6532,7 +7780,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds VkBufferCopy bc{ src_offset, dst_offset, size }; - vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); + vkCmdCopyBuffer(ctx->s->buffer->buf, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); } static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { @@ -6556,7 +7804,7 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr // Copy to src staging buffer ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); // Copy to dst buffer - ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1); + ggml_vk_buffer_write(dst, dst_offset, src->device->sync_staging->ptr, size); } } @@ -6570,7 +7818,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t } // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers - ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); + ctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c); } static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { @@ -6585,7 +7833,7 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz std::lock_guard<std::recursive_mutex> guard(dst->device->mutex); vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); ggml_vk_ctx_begin(dst->device, subctx); - subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); + subctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c); ggml_vk_ctx_end(subctx); ggml_vk_submit(subctx, dst->device->fence); @@ -6639,6 +7887,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type]; + if (ctx->device->coopmat2) { const uint32_t shader_core_count = ctx->device->shader_core_count; const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); @@ -6655,26 +7910,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, // split_k==3 with large tiles likely better than medium tiles with no split_k. (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); - if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; - - GGML_UNUSED(src1_type); } static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { @@ -6691,8 +7944,16 @@ static void ggml_vk_matmul( uint32_t padded_n) { VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); if (split_k == 1) { - const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2])); + + uint32_t base_work_group_z = 0; + while (base_work_group_z < batch) { + uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z }); + base_work_group_z += groups_z; + } return; } @@ -6706,44 +7967,59 @@ static void ggml_vk_matmul( uint32_t k_split = CEIL_DIV(k, split_k); k_split = ROUNDUP_POW2(k_split, 256); - const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; - // Make sure enough workgroups get assigned for split k to work - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2])); + + uint32_t base_work_group_z = 0; + while (base_work_group_z < batch) { + uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; + // Make sure enough workgroups get assigned for split k to work + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z }); + base_work_group_z += groups_z; + } ggml_vk_sync_buffers(ctx, subctx); const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); ctx->prealloc_split_k_need_sync = true; } -static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + + // The q8_1 (integer dot) mmq path uses a different shader with its own + // shared-memory layout, so use the int-specific availability flags. + const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1); + const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type]; + const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type]; + const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type]; if (ctx->device->coopmat2) { // Use large shader when the N dimension is greater than the medium shader's tile size uint32_t crossover_large = mmp->m->wg_denoms[1]; - if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) { return aligned ? mmp->a_l : mmp->l; } // Use medium shader when the N dimension is greater than the small shader's tile size uint32_t crossover_medium = mmp->s->wg_denoms[1]; - if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { + if ((mm_m && (n > crossover_medium)) || !mm_s) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { + if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) { return aligned ? mmp->a_s : mmp->s; } - if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { + if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) { return aligned ? mmp->a_m : mmp->m; } return aligned ? mmp->a_l : mmp->l; } -static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { - VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); - return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; } static void ggml_vk_matmul_id( @@ -6820,6 +8096,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_cpy_f32_bf16; } } + if (src->type == GGML_TYPE_BF16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_bf16_f32; + } else { + return ctx->device->pipeline_cpy_bf16_f32; + } + } if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { if (contig) { return ctx->device->pipeline_contig_cpy_f32_i32; @@ -6836,6 +8119,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const } if (src->type == GGML_TYPE_F32) { switch (to) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6850,6 +8134,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const if (to == GGML_TYPE_F32) { switch (src->type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -6916,6 +8201,40 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& ggml_vk_sync_buffers(ctx, subctx); } +// Copy/convert tensor into a caller-defined dense layout. Destination strides +// are in output elements, not bytes. +static void ggml_vk_cpy_to_strided( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, + const vk_subbuffer & in, const vk_subbuffer & out, + uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13) { + VK_LOG_DEBUG("ggml_vk_cpy_to_strided((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), "; + std::cerr << "dst_nb=(" << nb10 << ", " << nb11 << ", " << nb12 << ", " << nb13 << "), buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); + const int tensor_type_size = ggml_type_size(tensor->type); + + const uint32_t ne = ggml_nelements(tensor); + std::array<uint32_t, 3> elements; + + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + vk_op_unary_push_constants pc = { + (uint32_t)ne, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], nb10, nb11, nb12, nb13, + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + init_pushconst_fastdiv(pc); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); +} + static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) { switch(type) { case GGML_TYPE_Q8_1: @@ -7037,10 +8356,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type))); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); @@ -7104,7 +8425,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } @@ -7172,24 +8492,28 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -7230,8 +8554,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return false; } - // General performance issue with q3_k and q6_k due to 2-byte alignment - if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + // q6_k only has 2-byte alignment which makes it somewhat problematic, + // using MMVQ is only a win on Intel. + bool mmvq_q6 = device->vendor_id == VK_VENDOR_ID_INTEL; + if (src0_type == GGML_TYPE_Q6_K && !mmvq_q6) { return false; } @@ -7243,7 +8569,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: - if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { return true; } @@ -7270,6 +8596,19 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return true; } case VK_VENDOR_ID_INTEL: + if (device->architecture == vk_device_architecture::INTEL_XE2) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) { + return true; + } + } + + if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { + // Intel Windows proprietary driver MMVQ performance for !Q2/Q3/Q6 is worse than fp16, + // see https://github.com/ggml-org/llama.cpp/issues/17628 and + // https://github.com/ggml-org/llama.cpp/pull/23056 + return false; + } + if (k < 2048) { return false; } @@ -7402,7 +8741,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -7433,24 +8771,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } @@ -7497,22 +8839,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1; } - // compute - const vk_mat_vec_push_constants pc = { - (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - stride_batch_x, stride_batch_y, stride_batch_d, - fusion_flags, - (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, - }; - ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { - d_X, - d_Y, - d_D, - d_F0, - d_F1, - }, - pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1])); + + uint32_t base_work_group_y = 0; + while (base_work_group_y < ne12 * ne13) { + + uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + stride_batch_x, stride_batch_y, stride_batch_d, + fusion_flags, base_work_group_y, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { + d_X, + d_Y, + d_D, + d_F0, + d_F1, + }, + pc, { groups_x, groups_y, groups_z }); + base_work_group_y += groups_y; + } if (x_non_contig) { ctx->prealloc_x_need_sync = true; @@ -7708,6 +9057,68 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); } +static int ggml_vk_fwht_pipeline_idx(int64_t n) { + switch (n) { + case 64: return 0; + case 128: return 1; + case 256: return 2; + case 512: return 3; + default: return -1; + } +} + +static bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst) { + if (ctx->num_additional_fused_ops != 0) { + return false; + } + + if (ggml_get_op_params_i32(dst, 1) != GGML_HINT_SRC0_IS_HADAMARD) { + return false; + } + + const int idx = ggml_vk_fwht_pipeline_idx(src1->ne[0]); + if (idx < 0 || ctx->device->pipeline_fwht_f32[idx] == nullptr) { + return false; + } + + if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + if (!ggml_is_contiguous(src1)) { + return false; + } + GGML_ASSERT(ggml_is_contiguous(dst)); + + return true; +} + +static void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst) { + const int idx = ggml_vk_fwht_pipeline_idx(src->ne[0]); + vk_pipeline pipeline = ctx->device->pipeline_fwht_f32[idx]; + + const uint32_t rows_per_workgroup = 4; + const uint32_t n_rows = (uint32_t)ggml_nrows(src); + const uint32_t max_workgroups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + const uint32_t total_workgroups = CEIL_DIV(n_rows, rows_per_workgroup); + const uint32_t workgroups_x = std::min(total_workgroups, max_workgroups_x); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + const vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src, true); + const vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true); + + vk_op_fwht_push_constants pc = { + n_rows, + 0, + 0, + 1.0f / std::sqrt((float)src->ne[0]), + }; + init_pushconst_tensor_offsets(ctx, pc, src, nullptr, nullptr, nullptr, dst); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, { workgroups_x, 1, 1 }); +} + static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { ggml_tensor * dst = cgraph->nodes[node_idx]; ggml_tensor * src0 = dst->src[0]; @@ -7741,6 +9152,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c m_offset += cur_M_size; } + } else if (ggml_vk_can_use_fwht(ctx, src1, dst)) { + ggml_vk_fwht(ctx, subctx, src1, dst); } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && // detect 0213 permutation, and batch size of 1 src0->nb[0] <= src0->nb[2] && @@ -7750,10 +9163,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c src1->nb[2] <= src1->nb[1] && src1->nb[1] <= src1->nb[3] && src0->ne[3] == 1 && - src1->ne[3] == 1) { + src1->ne[3] == 1 && + src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) { ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx); } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && - !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { + !ggml_is_permuted(src0) && !ggml_is_permuted(src1) && + src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] && + src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) { ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx); // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) // when ne12 and ne13 are one. @@ -7825,12 +9243,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || !ggml_vk_dim01_contiguous(src0); - const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + // B must already be, or be convertible to, the matmul B type used by this path. + const bool y_decode_vector_supported = ctx->device->coopmat2_decode_vector && + (f16_type != GGML_TYPE_BF16 || ctx->device->coopmat2_bf16_support) && + (src1->type == GGML_TYPE_F32 || src1->type == f16_type); + // If B is copied to prealloc_y, we can choose a 4-element-aligned row stride. + const bool y_decode_vector_uses_prealloc = !ggml_vk_dim01_contiguous(src1) || src1->type != f16_type; + // Direct B reads are safe only if row starts and the original buffer offset are 4-element aligned. + const bool y_decode_vector_aligned = + (ne10 % 4 == 0) && + (y_decode_vector_uses_prealloc || get_misalign_bytes(ctx, src1) % (4 * ggml_type_size(src1->type)) == 0); + // Stage B only when decode-vector is available and direct B reads would be misaligned. + const bool y_decode_vector_staging = y_decode_vector_supported && !y_decode_vector_aligned; +#else + const bool y_decode_vector_staging = false; +#endif + const bool y_non_contig = y_decode_vector_staging || + (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || !ggml_vk_dim01_contiguous(src1); - // If src0 is BF16, try to use a BF16 x BF16 multiply - ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + const uint32_t y_staged_row_stride = y_decode_vector_staging ? (uint32_t)ggml_vk_align_size(ne10, 4) : (uint32_t)ne10; const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; @@ -7856,10 +9292,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Not implemented GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT - const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type); + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type)); const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8; - vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type); if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) { pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline); @@ -7867,11 +9305,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; const uint64_t x_ne = ggml_nelements(src0); - const uint64_t y_ne = padded_n * ne10 * ne12 * ne13; + const uint64_t y_ne = (uint64_t)y_staged_row_stride * padded_n * ne12 * ne13; const uint64_t d_ne = ggml_nelements(dst); const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); - const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * ggml_nelements(src1) / ggml_blck_size(src1->type); const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); const uint64_t ids_sz = nbi2; @@ -7881,13 +9319,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& vk_pipeline to_fp16_vk_1 = nullptr; vk_pipeline to_q8_1 = nullptr; + auto make_y_staged_dst = [&]() { + ggml_tensor y_staged_dst = *src1; + y_staged_dst.type = f16_type; + y_staged_dst.nb[0] = ggml_type_size(f16_type); + y_staged_dst.nb[1] = y_staged_dst.nb[0] * y_staged_row_stride; + y_staged_dst.nb[2] = y_staged_dst.nb[1] * padded_n; + y_staged_dst.nb[3] = y_staged_dst.nb[2] * y_staged_dst.ne[2]; + return y_staged_dst; + }; + if (x_non_contig) { to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); } else { to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); } if (y_non_contig) { - to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); + ggml_tensor y_staged_dst; + const ggml_tensor * y_staged_dst_ptr = nullptr; + if (y_decode_vector_staging) { + y_staged_dst = make_y_staged_dst(); + y_staged_dst_ptr = &y_staged_dst; + } + + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, y_staged_dst_ptr, f16_type); } else { to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); } @@ -8005,30 +9460,47 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& } if (y_non_contig) { if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging != y_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + if (y_decode_vector_staging) { + const ggml_tensor y_staged_dst = make_y_staged_dst(); + const uint32_t y_staged_dst_type_size = ggml_type_size(y_staged_dst.type); + ggml_vk_cpy_to_strided( + ctx, subctx, to_fp16_vk_1, src1, + ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), + (uint32_t)(y_staged_dst.nb[0] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[1] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[2] / y_staged_dst_type_size), + (uint32_t)(y_staged_dst.nb[3] / y_staged_dst_type_size)); + } else { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + } ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = y_decode_vector_staging; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } ggml_vk_sync_buffers(ctx, subctx); uint32_t stride_batch_x = ne00*ne01; - uint32_t stride_batch_y = ne10*ne11; + uint32_t stride_b_y = y_decode_vector_staging ? y_staged_row_stride : ne10; + uint32_t stride_batch_y = y_decode_vector_staging ? y_staged_row_stride * padded_n : ne10*ne11; if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); @@ -8043,7 +9515,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& ctx, subctx, pipeline, { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf, - ne01, ne21, ne10, ne10, ne10, ne01, + ne01, ne21, ne10, ne10, stride_b_y, ne01, stride_batch_x, stride_batch_y, ne20*ne21, n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n ); // NOLINT @@ -8083,8 +9555,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte const uint64_t nei0 = ids->ne[0]; const uint64_t nei1 = ids->ne[1]; - - GGML_ASSERT(nei1 == 1); + const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int)); const uint64_t ne20 = dst->ne[0]; const uint64_t ne21 = dst->ne[1]; @@ -8168,7 +9639,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -8202,31 +9673,35 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte if (y_non_contig) { GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y); ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } if (quantize_y) { if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || - ctx->prealloc_y_last_tensor_used != src1) { + ctx->prealloc_y_last_tensor_used != src1 || + ctx->prealloc_y_last_decode_vector_staging) { if (ctx->prealloc_y_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne); ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); ctx->prealloc_y_last_tensor_used = src1; + ctx->prealloc_y_last_decode_vector_staging = false; } } uint32_t stride_batch_y = ne10*ne11; if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { - stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + stride_batch_y = src1->nb[2] / ggml_type_size(src1->type); } const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; @@ -8260,25 +9735,27 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte d_F1 = ggml_vk_tensor_subbuffer(ctx, scale); fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1; - } - - // compute - const vk_mat_vec_id_push_constants pc = { - (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), - fusion_flags, - (uint32_t)nei0, (uint32_t)ne11, - }; - ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { - d_X, - d_Y, - d_D, - d_F0, - d_F1, - d_ids, - }, - pc, { groups_x, (uint32_t)nei0, groups_z }); + } + + // Loop over the batch dimension + for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) { + const vk_mat_vec_id_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21), + fusion_flags, + (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1 + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { + d_X, + d_Y, + d_D, + d_F0, + d_F1, + d_ids, + }, + pc, { groups_x, (uint32_t)nei0, groups_z }); + } if (x_non_contig) { ctx->prealloc_x_need_sync = true; @@ -8292,7 +9769,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no ggml_tensor * dst = cgraph->nodes[node_idx]; ggml_tensor * src0 = dst->src[0]; ggml_tensor * src2 = dst->src[2]; - return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); + return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)); } static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { @@ -8308,55 +9785,94 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) { + GGML_UNUSED(f32acc); + GGML_UNUSED(v_type); // Needs to be kept up to date on shader changes - GGML_UNUSED(hsv); - const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache); - const uint32_t Bc = scalar_flash_attention_Bc; + const uint32_t wg_size = params.workgroup_size; + const uint32_t Br = params.block_rows; + const uint32_t Bc = params.block_cols; + + // BF16 uses the fp32 shader (FLOAT_TYPE=float) + const uint32_t float_type_size = (device->fp16 && k_type != GGML_TYPE_BF16) ? sizeof(ggml_fp16_t) : sizeof(float); + const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type); + + // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); - const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * float_type_size; + + const uint32_t masksh = Bc * (Br + 1) * float_type_size; - const uint32_t masksh = Bc * Br * sizeof(float); + uint32_t Qf, kvsh, kblocksh_size; + if (mmq) { + // block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds + const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size; + Qf = Br * (hsk / 32) * block_b_size; + + // kvsh uses D = HSV (K goes through kblocksh instead) + kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + + // The mixed MMQ shader uses a superset block_a_cache that fits every + // FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm. + // Single-scale types leave dm.y unused; non-Q5_* leave qh unused. + const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size; + kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size; + } else { + Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; - const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); + const uint32_t D = std::max(hsk, hsv); + kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + + kblocksh_size = 0; + } - const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported); return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type) { // Needs to be kept up to date on shader changes - GGML_UNUSED(hsv); - const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = coopmat1_flash_attention_num_large_rows; - const uint32_t Bc = scalar_flash_attention_Bc; + const uint32_t Br = params.block_rows; + const uint32_t Bc = params.block_cols; + + const uint32_t MatBr = 16, MatBc = 16; + + const uint32_t row_split = Bc / MatBc; const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); + const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16); const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; - const uint32_t tmpsh = wg_size * sizeof(float); - const uint32_t tmpshv4 = wg_size * 4 * acctype; + const uint32_t tmpsh = (Bc / MatBc) * sizeof(float); const uint32_t qstride = hsk_pad / 4 + 2; const uint32_t Qf = Br * qstride * f16vec4; + const uint32_t psh_stride = Br / 4 + 2; + const uint32_t Psh = Bc * psh_stride * f16vec4; + const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const uint32_t kshstride = hsk_pad / 4 + 2; - const uint32_t ksh = Bc * kshstride * f16vec4; + const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2; + const uint32_t vsh_stride = MatBc / 4 * row_split; + const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4; - const uint32_t slope = Br * sizeof(float); + // BF16 PVMat accumulator is f32 (no bf16 accumulator support), so pvsh is vec4 (16 bytes) + const uint32_t pvsh_elem_size = (k_type == GGML_TYPE_BF16) ? 16u : f16vec4; + const uint32_t osh_stride = params.row_split * MatBr / 4; + const uint32_t pvsh = MatBc * osh_stride * pvsh_elem_size; - const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const uint32_t slope = Br * acctype; + + const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); @@ -8383,6 +9899,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const uint32_t nem0 = mask ? mask->ne[0] : 0; const uint32_t nem1 = mask ? mask->ne[1] : 0; const uint32_t nem2 = mask ? mask->ne[2] : 0; const uint32_t nem3 = mask ? mask->ne[3] : 0; @@ -8414,74 +9931,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(dst->type == GGML_TYPE_F32); assert(q->type == GGML_TYPE_F32); - assert(k->type == v->type); - - FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : - ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; - - if (path == FA_COOPMAT1) { - const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || - (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); - - const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32); - - if (!coopmat_shape_supported || !coopmat_shmem_supported) { - path = FA_SCALAR; - } - } - uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; uint32_t workgroups_x = (uint32_t)neq1; uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - const bool small_cache = nek1 < 1024; + const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32 || k->type == GGML_TYPE_BF16; // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). - uint32_t max_gqa; - switch (path) { - case FA_SCALAR: - case FA_COOPMAT1: - // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache); - break; - case FA_COOPMAT2: - max_gqa = get_fa_num_small_rows(FA_COOPMAT2); - break; - default: - GGML_ASSERT(0); - } + vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc); + const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u); - if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { // grouped query attention - make the N dimension equal to gqa_ratio, reduce // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 // and change addressing calculations to index Q's dimension 2. gqa_ratio = qk_ratio; N = gqa_ratio; - workgroups_y /= N; - } - - bool small_rows = N <= get_fa_num_small_rows(path); - - // coopmat1 does not actually support "small rows" (it needs 16 rows). - // So use scalar instead. - if (small_rows && path == FA_COOPMAT1) { - path = FA_SCALAR; - } - - // scalar is faster than coopmat2 when N==1 - if (N == 1 && path == FA_COOPMAT2) { - path = FA_SCALAR; + workgroups_y /= gqa_ratio; } - // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory - if (path == FA_SCALAR && - !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) { - small_rows = true; - } + tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc); const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); @@ -8495,25 +9968,38 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache); + const uint32_t alignment = tuning_params.block_cols; bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned. - if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { + if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) { aligned = false; } - bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc); + // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. + bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16; + vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc, + mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type); vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); - auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16; auto it = pipelines.find(fa_pipeline_state); if (it != pipelines.end()) { pipeline = it->second; @@ -8523,29 +10009,46 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx } assert(pipeline); + // Compile early to initialize wg_denoms. + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); uint32_t split_kv = KV; uint32_t split_k = 1; + // Intel Alchemist prefers more workgroups + const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1; + // Use a placeholder core count if one isn't available. split_k is a big help for perf. - const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16; - // Try to use split_k when KV is large enough to be worth the overhead - if (workgroups_x == 1 && shader_core_count > 0) { - // Try to run two workgroups per SM. - split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); - if (split_k > 1) { - // Try to evenly split KV into split_k chunks, but it needs to be a multiple - // of "align", so recompute split_k based on that. - split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); - split_k = CEIL_DIV(KV, split_kv); - workgroups_x = split_k; + const uint32_t Br = fa_pipeline_state.Br; + const uint32_t Bc = fa_pipeline_state.Bc; + + GGML_ASSERT(Br == pipeline->wg_denoms[0]); + const uint32_t Tr = CEIL_DIV(N, Br); + + // Try to use split_k when KV is large enough to be worth the overhead. + if (gqa_ratio > 1 && workgroups_x <= Br) { + split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z); + } else if (gqa_ratio <= 1) { + uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z; + if (total_wgs_no_split < shader_core_count * 2) { + split_k = shader_core_count * 2 / total_wgs_no_split; } } + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); + split_k = CEIL_DIV(KV, split_kv); + } + // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. - const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; + // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3]. + // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3]. + const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0; if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { GGML_ABORT("Requested preallocation size is too large"); } @@ -8554,24 +10057,31 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_preallocate_buffers(ctx, subctx); } - { - // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); - if (split_k > 1) { - ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); - } - } - - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; + const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc); + const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3; - memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); - memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + vk_pipeline pipeline_fa_mask_opt = nullptr; + if (use_mask_opt) { + { + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); + auto &pipelines = ctx->device->pipeline_fa_mask_opt; + auto it = pipelines.find({Br, Bc}); + if (it != pipelines.end()) { + pipeline_fa_mask_opt = it->second; + } else { + pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>(); + } + } + assert(pipeline_fa_mask_opt); + ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1); - if (logit_softcap != 0) { - scale /= logit_softcap; + if (ctx->prealloc_size_y < mask_opt_size) { + ctx->prealloc_size_y = mask_opt_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } } const uint32_t n_head_kv = neq2; @@ -8585,8 +10095,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf; vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf; + vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf; + + uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2; + + if (use_mask_opt) + { + const vk_op_flash_attn_mask_opt_push_constants opt_pc = { + nem0, + nem1, + nem2, + (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)), + (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)), + (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)), + mask_opt_num_dwords, + mask_opt_num_dwords * CEIL_DIV(nem1, Br), + mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2, + }; - uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt, + { mask_buf, mask_opt_buf }, opt_pc, + { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 }); + ggml_vk_sync_buffers(ctx, subctx); + } const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, @@ -8602,28 +10133,40 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx gqa_ratio, split_kv, split_k }; if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } + // We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. + uint32_t dispatch_x; + if (gqa_ratio > 1) { + workgroups_x *= pipeline->wg_denoms[0]; + dispatch_x = split_k * workgroups_x; + } else { + dispatch_x = Tr * split_k * pipeline->wg_denoms[0]; + } + vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf}, - // We only use split_k when group query attention is enabled, which means - // there's no more than one tile of rows (i.e. workgroups_x would have been - // one). We reuse workgroups_x to mean the number of splits, so we need to - // cancel out the divide by wg_denoms[0]. - pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf}, + pc, { dispatch_x, workgroups_y, workgroups_z }); ggml_vk_sync_buffers(ctx, subctx); - const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; + const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, {split_k_buf, sinks_buf, dst_buf}, - pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) }); ctx->prealloc_split_k_need_sync = true; } else { + if (gqa_ratio > 1) { + // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms + workgroups_x *= pipeline->wg_denoms[0]; + } ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, - {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf}, + {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf}, pc, { workgroups_x, workgroups_y, workgroups_z }); } } @@ -8638,10 +10181,23 @@ static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, u // so small convolutions will still choose a smaller tile. const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32; - if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) { + // 128x128 isn't used with cm1 due to shared memory size; fall through to a smaller tile. + bool allow_128x128 = true; +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (!ctx->device->coopmat2 && ctx->device->coopmat_support && ctx->device->coopmat_support_16x16x16_f16acc) { + allow_128x128 = false; + } +#endif + + if (allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) { return CONV_SHAPE_128x128; } else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) { return CONV_SHAPE_32x256; + } else if (K <= 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) { + return CONV_SHAPE_64x128; + } else if (!allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) { + // cm1 fallback for large K when 128x128 isn't available + return CONV_SHAPE_64x128; } else { return CONV_SHAPE_64x32; } @@ -8668,6 +10224,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_acc_f32; } return nullptr; + case GGML_OP_SET: + if (src0->type == src1->type && src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) { + return ctx->device->pipeline_set_f32; + } + return nullptr; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -8807,7 +10369,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_REPEAT: if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { - return ctx->device->pipeline_repeat_f32; + return ctx->device->pipeline_repeat_i32; + } + if (ggml_type_size(src0->type) == 2 && ggml_type_size(dst->type) == 2) { + return ctx->device->pipeline_repeat_i16; } return nullptr; case GGML_OP_REPEAT_BACK: @@ -8869,6 +10434,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_EXP: return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_ELU: + return ctx->device->pipeline_elu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_SILU: return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_GELU: @@ -8905,6 +10472,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16]; case GGML_UNARY_OP_TRUNC: return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SGN: + return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16]; default: break; } @@ -9035,7 +10604,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state); if (it != ctx->device->pipeline_solve_tri_f32.end()) { pipeline = it->second; @@ -9098,6 +10667,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv7_f32; } return nullptr; + case GGML_OP_GATED_DELTA_NET: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + const uint32_t S_v = dst->src[2]->ne[0]; + const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0; + uint32_t si; + switch (S_v) { + case 32: si = 0; break; + case 64: si = 1; break; + case 128: si = 2; break; + default: return nullptr; + } + return ctx->device->pipeline_gated_delta_net[si][kda]; + } + return nullptr; case GGML_OP_SSM_SCAN: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { const uint32_t d_state = src0->ne[0]; @@ -9110,7 +10693,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_SSM_CONV: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_ssm_conv_f32; + switch (ctx->num_additional_fused_ops) { + case 0: return ctx->device->pipeline_ssm_conv_f32; + case 1: return ctx->device->pipeline_ssm_conv_silu_f32; + case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32; + default: return nullptr; + } } return nullptr; case GGML_OP_OPT_STEP_ADAMW: @@ -9144,7 +10732,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0; uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1; uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1; - vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH); + + // tile-aligned shapes let the shader skip bounds checks + const uint32_t Cin = (uint32_t)src1->ne[2]; + const uint32_t CRS = Cin * KW * KH; + const uint32_t BS_K = vk_conv_block_sizes[shape].K; + const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS; + const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ; + const uint32_t aligned = ((K % BS_K == 0) && + (CRS % BS_CRS == 0) && + (NPQ % BS_NPQ == 0)) ? 1u : 0u; + + vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH, aligned); std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr; if (op == GGML_OP_CONV_2D) { @@ -9164,7 +10763,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const vk_pipeline pipeline = nullptr; { - std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex); + std::lock_guard<std::mutex> guard(ctx->device->compile_mutex); auto it = pipelines->find(conv2d_pipeline_state); if (it != pipelines->end()) { pipeline = it->second; @@ -9211,6 +10810,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const if (dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_fill_f32; } + if (dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_fill_f16; + } return nullptr; default: return nullptr; @@ -9288,6 +10890,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk GGML_UNUSED(src3); } +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) { + p.a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + p.d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(src3); +} + template<typename PC> static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) { VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; @@ -9431,7 +11042,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t batch = src1->ne[is_2D ? 3 : 2]; - elements = { OW * KW * KH, OH, batch * IC }; + const uint32_t CHW = IC * KH * KW; + // Cap X workgroups to limit concurrent IC channel reads. + // The shader loops over X to cover the full CHW dimension. + // AMD prefers a lower limit + const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u; + const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW)); + elements = { x_elements, OW, OH * batch }; elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); } break; @@ -9654,16 +11271,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes + int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32 + int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32 + int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32 + int offset = dst->op_params[3] / src0_type_size; // offset in bytes - ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, { + ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, { (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3, 0, 0.0f, 0.0f, offset, }); @@ -9928,6 +11545,63 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ); } +static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + const ggml_tensor * src_q = dst->src[0]; + const ggml_tensor * src_v = dst->src[2]; + const ggml_tensor * src_beta = dst->src[4]; + + GGML_ASSERT(dst->buffer != nullptr); + + const uint32_t S_v = (uint32_t)src_v->ne[0]; + const uint32_t H = (uint32_t)src_v->ne[1]; + const uint32_t n_tokens = (uint32_t)src_v->ne[2]; + const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + + // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs]. + const uint32_t K = (uint32_t)ggml_get_op_params_i32(dst, 0); + + const uint32_t s_off = S_v * H * n_tokens * n_seqs; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer src_buf[6] = {}; + for (int i = 0; i < 6; i++) { + src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]); + } + + const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float)); + const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float)); + const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float)); + const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float)); + const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float)); + const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float)); + const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float)); + const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float)); + const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float)); + + const uint32_t neq1 = (uint32_t)src_q->ne[1]; + const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]); + + const float scale = 1.0f / sqrtf((float)S_v); + const vk_op_gated_delta_net_push_constants pc = { + H, n_tokens, n_seqs, s_off, + sq1, sq2, sq3, + sv1, sv2, sv3, + sb1, sb2, sb3, + neq1, rq3, + scale, + K + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, + pc, { H, n_seqs, S_v }); +} + static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -9984,11 +11658,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, pc, elements); } -static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src1 = dst->src[1]; +static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { + ggml_tensor * conv = cgraph->nodes[node_idx]; + const ggml_tensor * src0 = conv->src[0]; + const ggml_tensor * src1 = conv->src[1]; + + // Pick the destination tensor (last node in the fused chain) and the optional bias. + // Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu. + ggml_tensor * dst = conv; + const ggml_tensor * bias = nullptr; + + if (ctx->num_additional_fused_ops == 1) { + dst = cgraph->nodes[node_idx + 1]; // silu + } else if (ctx->num_additional_fused_ops == 2) { + ggml_tensor * add = cgraph->nodes[node_idx + 1]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + dst = cgraph->nodes[node_idx + 2]; // silu + } + + // The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused. + const ggml_tensor * src2 = bias ? bias : src0; - ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, { + ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, { (uint32_t)src0->nb[1], (uint32_t)src0->nb[2], (uint32_t)src1->nb[1], (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2], @@ -10335,12 +12026,23 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor * uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type); uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type); + uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type); + + uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type); + uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type); + uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type); vk_op_rope_push_constants rope { - (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], - freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, - has_ff, (uint32_t)src0->ne[2], nb01, nb02, + (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff, { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride, + + (uint32_t)src0->ne[0], + (uint32_t)src0->ne[1], + (uint32_t)src0->ne[2], + nb01, nb02, nb03, + nb11, nb12, nb13, + 0, 0, // a_offset, d_offset filled in by init_pushconst_tensor_offsets }; return rope; @@ -10436,6 +12138,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, GGML_ASSERT(buf[i] != nullptr); } + // a_offset is unused (the fused path reads from shared memory), but the rope/set_rows dst can be misaligned. + // Round the binding offset down to the storage buffer alignment; the in-element shift goes in pc.rope.d_offset. + pc.rope.d_offset = get_misalign_bytes(ctx, tensors[5]) / ggml_type_size(tensors[5]->type); + offset[5] &= ~(size_t(ctx->device->properties.limits.minStorageBufferOffsetAlignment) - 1); + std::array<uint32_t, 3> elements; elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] }; @@ -10467,8 +12174,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub } static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - float * op_params = (float *)dst->op_params; - ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); + const float * op_params = (const float *)dst->op_params; + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = op_params[0]; + ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p)); } static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -10493,8 +12202,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const const float alpha = op_params_f[2]; const float limit = op_params_f[3]; - GGML_ASSERT(ggml_is_contiguous(src0)); - if (!split) { GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); } else { @@ -10512,7 +12219,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const (uint32_t)dst->ne[0], mode, alpha, - limit + limit, + (uint32_t)(src0->nb[1] / src0->nb[0]), + (uint32_t)(src0->nb[2] / src0->nb[0]), + (uint32_t)(src0->nb[3] / src0->nb[0]), + (uint32_t)src0->ne[1], + (uint32_t)src0->ne[2], + (uint32_t)(dst->nb[1] / dst->nb[0]), + (uint32_t)(dst->nb[2] / dst->nb[0]), + (uint32_t)(dst->nb[3] / dst->nb[0]), + (uint32_t)dst->ne[1], + (uint32_t)dst->ne[2] }); } @@ -11021,7 +12738,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 - const uint32_t pelements = OW * KW * KH; const uint32_t batch = src1->ne[is_2D ? 3 : 2]; const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; @@ -11033,7 +12749,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co dst_addr, batch_offset, offset_delta, IC, IW, IH, OW, OH, KW, KH, - pelements, + OH * batch, IC * KH * KW, s0, s1, p0, p1, d0, d1, batch * IC }); @@ -11146,6 +12862,45 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p)); } +// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b. +// Match the naive mul -> sin -> sqr -> mul -> add chain and run the +// dedicated kernel directly. The pattern is validated by +// ggml_vk_can_fuse_snake before this call. +static void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + // x carries the full activation shape, a is the broadcast operand + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + // mul1 reads sqr and inv_b in either operand order + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + + vk_pipeline pipeline = nullptr; + switch (x->type) { + case GGML_TYPE_F32: pipeline = ctx->device->pipeline_snake_f32; break; + case GGML_TYPE_F16: pipeline = ctx->device->pipeline_snake_f16; break; + case GGML_TYPE_BF16: pipeline = ctx->device->pipeline_snake_bf16; break; + default: GGML_ABORT("unsupported type"); + } + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x); + vk_subbuffer a_buf = ggml_vk_tensor_subbuffer(ctx, a); + vk_subbuffer inv_b_buf = ggml_vk_tensor_subbuffer(ctx, inv_b); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, add); + + vk_op_snake_push_constants pc{}; + pc.ne0 = static_cast<uint32_t>(x->ne[0]); + pc.ne1 = static_cast<uint32_t>(x->ne[1]); + + std::array<uint32_t, 3> elements = { pc.ne0, pc.ne1, 1 }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { x_buf, a_buf, inv_b_buf, dst_buf }, pc, elements); +} + static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { uint32_t op = static_cast<uint32_t>(dst->op_params[0]); const int32_t k1 = dst->op_params[1]; @@ -11386,7 +13141,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } - ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); @@ -11560,7 +13314,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t free(d_chk); ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); - ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); ggml_vk_destroy_buffer(d_X); ggml_vk_destroy_buffer(d_Y); @@ -11896,7 +13649,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, // y[i] = i % k; } - ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); @@ -11909,7 +13661,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, } } if (mmq) { - ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it); + vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it); } ggml_pipeline_allocate_descriptor_sets(ctx); @@ -12145,7 +13898,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_submit(subctx, {}); ctx->submit_pending = true; ggml_vk_synchronize(ctx); + GGML_ASSERT(ctx->compute_ctx.expired()); ggml_vk_ctx_begin(ctx->device, subctx); + ctx->compute_ctx = subctx; } if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { @@ -12163,6 +13918,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex ggml_vk_destroy_buffer(ctx->prealloc_y); } ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + ctx->prealloc_y_last_pipeline_used = nullptr; + ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; } if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); @@ -12191,6 +13949,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) { return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); ctx->semaphore_idx = 0; @@ -12215,15 +13976,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } - vk_context compute_ctx; - - if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); { // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers @@ -12294,7 +14047,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr if (vk_perf_logger_enabled && vk_perf_logger_concurrent) { ctx->query_node_idx[ctx->query_idx] = node_idx; - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } } // Add all fused nodes to the unsynchronized lists. @@ -12337,6 +14091,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_ACC: + case GGML_OP_SET: ggml_vk_acc(ctx, compute_ctx, src0, src1, node); break; @@ -12356,7 +14111,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_MUL: - ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + if (ctx->num_additional_fused_ops) { + ggml_vk_snake_dispatch_fused(ctx, compute_ctx, cgraph, node_idx); + } else { + ggml_vk_mul(ctx, compute_ctx, src0, src1, node); + } break; case GGML_OP_DIV: @@ -12471,6 +14230,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_GELU: @@ -12489,6 +14249,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_SGN: ggml_vk_unary(ctx, compute_ctx, src0, node); break; case GGML_UNARY_OP_XIELU: @@ -12633,13 +14394,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; + case GGML_OP_GATED_DELTA_NET: + ggml_vk_gated_delta_net(ctx, compute_ctx, node); + + break; + case GGML_OP_SSM_SCAN: ggml_vk_ssm_scan(ctx, compute_ctx, node); break; case GGML_OP_SSM_CONV: - ggml_vk_ssm_conv(ctx, compute_ctx, node); + ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx); break; @@ -12734,13 +14500,17 @@ static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); ctx->prealloc_y_last_pipeline_used = {}; + ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; ctx->unsynced_nodes_written.clear(); ctx->unsynced_nodes_read.clear(); ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); - ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); + if (ctx->device->async_use_transfer_queue) { + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); + } for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); @@ -12769,7 +14539,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); // discard any unsubmitted command buffers - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); // wait for any pending command buffers to finish ggml_vk_synchronize(ctx); @@ -12782,6 +14552,8 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_destroy_buffer(ctx->sync_staging); ctx->prealloc_y_last_pipeline_used = nullptr; + ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; @@ -12802,7 +14574,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->descriptor_sets.clear(); ctx->compute_cmd_pool.destroy(ctx->device->device); - ctx->transfer_cmd_pool.destroy(ctx->device->device); + if (ctx->device->async_use_transfer_queue) { + ctx->device->device.destroySemaphore(ctx->transfer_semaphore.s); + + ctx->transfer_cmd_pool.destroy(ctx->device->device); + } if (vk_perf_logger_enabled) { ctx->perf_logger->print_timings(true); } @@ -12861,6 +14637,10 @@ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, g ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; vk_buffer buf = buf_ctx->dev_buffer; + if (size == 0) { + return; + } + uint32_t val32 = (uint32_t)value * 0x01010101; ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size); } @@ -12870,19 +14650,60 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; vk_buffer buf = buf_ctx->dev_buffer; + if (size == 0) { + return; + } + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } +static void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " << + n_copies << ", " << stride_tensor << ", " << stride_data << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + if (size == 0) { + return; + } + + ggml_vk_buffer_write_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_data, stride_tensor, size, n_copies); +} + static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + if (size == 0) { + return; + } + vk_buffer buf = buf_ctx->dev_buffer; ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); } +static void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " << + n_copies << ", " << stride_tensor << ", " << stride_data << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + if (size == 0) { + return; + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_tensor, stride_data, size, n_copies); +} + static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + if (ggml_nbytes(src) == 0) { + return true; + } + if (ggml_backend_buffer_is_vk(src->buffer)) { ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; @@ -12912,6 +14733,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, + /* .set_tensor_2d = */ ggml_backend_vk_buffer_set_tensor_2d, + /* .get_tensor_2d = */ ggml_backend_vk_buffer_get_tensor_2d, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, /* .clear = */ ggml_backend_vk_buffer_clear, /* .reset = */ NULL, @@ -12974,12 +14797,6 @@ static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_ty UNUSED(buft); } -static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { - return GGML_VK_NAME "_Host"; - - UNUSED(buffer); -} - static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); @@ -13061,151 +14878,259 @@ static void ggml_backend_vk_free(ggml_backend_t backend) { delete backend; } -static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { - ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; +static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return &ctx->device->buffer_type; +} + +static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_2d_async(" << size << ", " << n_copies << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + if (size == 0) { + return; + } + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context cpy_ctx; + + if (ctx->device->async_use_transfer_queue) { + if (ctx->transfer_ctx.expired()) { + cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = cpy_ctx; + ggml_vk_ctx_begin(ctx->device, cpy_ctx); + } else { + cpy_ctx = ctx->transfer_ctx.lock(); + } + } else { + cpy_ctx = ggml_vk_get_compute_ctx(ctx); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; + + bool ret = ggml_vk_buffer_write_2d_async(cpy_ctx, buf, dst_offset, data, stride_data, stride_tensor, size, n_copies); + + if (!ret) { + const size_t staging_size = size * n_copies; + ggml_vk_ensure_sync_staging_buffer(ctx, staging_size); + ggml_vk_sync_buffers(nullptr, cpy_ctx); + + std::vector<vk::BufferCopy> slices(1); + if (size == stride_tensor) { + slices[0].srcOffset = 0; + slices[0].dstOffset = dst_offset; + slices[0].size = staging_size; + } else { + slices.resize(n_copies); + for (size_t i = 0; i < n_copies; i++) { + slices[i].srcOffset = i * size; + slices[i].dstOffset = dst_offset + i * stride_tensor; + slices[i].size = size; + } + } + + cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, slices); - return &ctx->device->buffer_type; + if (size == stride_data) { + deferred_memcpy(ctx->sync_staging->ptr, data, staging_size, &cpy_ctx->in_memcpys); + } else { + for (size_t i = 0; i < n_copies; i++) { + deferred_memcpy((uint8_t *)ctx->sync_staging->ptr + i * size, (const uint8_t *)data + i * stride_data, size, &cpy_ctx->in_memcpys); + } + } + ggml_vk_synchronize(ctx); + } } static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); + ggml_backend_vk_set_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size); +} + +static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, + size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_2d_async(" << size << ", " << n_copies << ")"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + if (size == 0) { + return; + } - vk_context transfer_ctx; + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - if (ctx->transfer_ctx.expired()) { - // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); - } else { - transfer_ctx = ctx->transfer_ctx.lock(); - } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); vk_buffer buf = buf_ctx->dev_buffer; - auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - - bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size); + auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; + bool ret = ggml_vk_buffer_read_2d_async(compute_ctx, buf, src_offset, data, stride_tensor, stride_data, size, n_copies); if (!ret) { - ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, transfer_ctx); + const size_t staging_size = size * n_copies; + ggml_vk_ensure_sync_staging_buffer(ctx, staging_size); + ggml_vk_sync_buffers(nullptr, compute_ctx); + + std::vector<vk::BufferCopy> slices(1); + if (size == stride_tensor) { + slices[0].srcOffset = src_offset; + slices[0].dstOffset = 0; + slices[0].size = staging_size; + } else { + slices.resize(n_copies); + for (size_t i = 0; i < n_copies; i++) { + slices[i].srcOffset = src_offset + i * stride_tensor; + slices[i].dstOffset = i * size; + slices[i].size = size; + } + } - vk::BufferCopy buffer_cpy; - buffer_cpy.srcOffset = 0; - buffer_cpy.dstOffset = dst_offset; - buffer_cpy.size = size; + compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, slices); - transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); - deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys); + if (size == stride_data) { + deferred_memcpy(data, ctx->sync_staging->ptr, staging_size, &compute_ctx->out_memcpys); + } else { + for (size_t i = 0; i < n_copies; i++) { + deferred_memcpy((uint8_t *)data + i * stride_data, (const uint8_t *)ctx->sync_staging->ptr + i * size, size, &compute_ctx->out_memcpys); + } + } ggml_vk_synchronize(ctx); } } static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); - ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; - GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + ggml_backend_vk_get_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size); +} - ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; +static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context; - vk_context transfer_ctx; + // Skip zero-size tensors + if (ggml_nbytes(src) == 0) { + return true; + } - if (ctx->transfer_ctx.expired()) { - // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); - } else { - transfer_ctx = ctx->transfer_ctx.lock(); + if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) { + return false; } - vk_buffer buf = buf_ctx->dev_buffer; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; - auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_read_async(transfer_ctx, buf, src_offset, data, size); + if (ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; - // If that failed, copy synchronously through a staging buffer - if (!ret) { - ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, transfer_ctx); + // Async copy only works within the same device + if (src_buf_ctx->dev_buffer->device != dst_buf->device) { + return false; + } - vk::BufferCopy buffer_cpy; - buffer_cpy.srcOffset = src_offset; - buffer_cpy.dstOffset = 0; - buffer_cpy.size = size; + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); - transfer_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy }); - deferred_memcpy(data, ctx->sync_staging->ptr, size, &transfer_ctx->out_memcpys); - ggml_vk_synchronize(ctx); + ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, + src_buf_ctx->dev_buffer, vk_tensor_offset(src) + src->view_offs, + ggml_nbytes(src)); + return true; } -} - -static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { - VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); - ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; - if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { - ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; - ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; - vk_context transfer_ctx; + if (ggml_backend_buffer_is_host(src->buffer)) { + vk_buffer pinned_buf = nullptr; + size_t pinned_offset = 0; + ggml_vk_host_get(ctx->device, src->data, pinned_buf, pinned_offset); + if (pinned_buf == nullptr) { + return false; + } - if (ctx->transfer_ctx.expired()) { - // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + vk_context cpy_ctx; + if (ctx->device->async_use_transfer_queue) { + if (ctx->transfer_ctx.expired()) { + cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = cpy_ctx; + ggml_vk_ctx_begin(ctx->device, cpy_ctx); + } else { + cpy_ctx = ctx->transfer_ctx.lock(); + } } else { - transfer_ctx = ctx->transfer_ctx.lock(); + cpy_ctx = ggml_vk_get_compute_ctx(ctx); } - vk_buffer src_buf = src_buf_ctx->dev_buffer; - vk_buffer dst_buf = dst_buf_ctx->dev_buffer; - - ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); - return true; + return ggml_vk_buffer_write_async(cpy_ctx, dst_buf, + vk_tensor_offset(dst) + dst->view_offs, + src->data, ggml_nbytes(src)); } + GGML_UNUSED(backend_src); return false; } static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_synchronize()"); - bool do_transfer = !ctx->transfer_ctx.expired(); + bool do_transfer = !ctx->compute_ctx.expired(); + + if (ggml_vk_submit_transfer_ctx(ctx)) { + ctx->submit_pending = true; + } - vk_context transfer_ctx; + vk_context compute_ctx; + vk_command_buffer* cmd_buf = nullptr; if (do_transfer) { - transfer_ctx = ctx->transfer_ctx.lock(); + compute_ctx = ctx->compute_ctx.lock(); + if (compute_ctx->s) { + cmd_buf = compute_ctx->s->buffer; + } - ggml_vk_ctx_end(transfer_ctx); + ggml_vk_ctx_end(compute_ctx); - for (auto& cpy : transfer_ctx->in_memcpys) { + for (auto& cpy : compute_ctx->in_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ggml_vk_submit(transfer_ctx, {}); + ggml_vk_submit(compute_ctx, {}); ctx->submit_pending = true; } if (ctx->submit_pending) { - { + if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) { + vk::TimelineSemaphoreSubmitInfo tl_info{ + 1, &ctx->transfer_semaphore.value, + 0, nullptr, + }; + vk::PipelineStageFlags stage = ctx->device->transfer_queue.stage_flags; + vk::SubmitInfo si{ + 1, &ctx->transfer_semaphore.s, &stage, + 0, nullptr, + 0, nullptr, + }; + si.setPNext(&tl_info); + std::lock_guard<std::mutex> guard(queue_mutex); + ctx->device->compute_queue.queue.submit({ si }, ctx->fence); + ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value; + } else { std::lock_guard<std::mutex> guard(queue_mutex); ctx->device->compute_queue.queue.submit({}, ctx->fence); } ggml_vk_wait_for_fence(ctx); ctx->submit_pending = false; + if (cmd_buf) { + cmd_buf->in_use = false; + cmd_buf->buf.reset(); + } } if (do_transfer) { - for (auto& cpy : transfer_ctx->out_memcpys) { + for (auto& cpy : compute_ctx->out_memcpys) { memcpy(cpy.dst, cpy.src, cpy.n); } - ctx->transfer_ctx.reset(); + ctx->compute_ctx.reset(); } } @@ -13375,6 +15300,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g return true; } +// Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2. +static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, + int node_idx, int num_extra) { + const ggml_tensor * conv = cgraph->nodes[node_idx]; + if (conv->op != GGML_OP_SSM_CONV) { + return false; + } + + const ggml_tensor * silu = nullptr; + const ggml_tensor * bias = nullptr; + + if (num_extra == 1) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) { + return false; + } + silu = cgraph->nodes[node_idx + 1]; + } else if (num_extra == 2) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) { + return false; + } + const ggml_tensor * add = cgraph->nodes[node_idx + 1]; + silu = cgraph->nodes[node_idx + 2]; + bias = (add->src[0] == conv) ? add->src[1] : add->src[0]; + + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + // bias must be channel-wise (one element per channel of the conv output) + if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) { + return false; + } + if (add->type != GGML_TYPE_F32) { + return false; + } + // The shader doesn't apply per-tensor offsets, so reject misaligned bias. + if (get_misalign_bytes(ctx, bias) != 0) { + return false; + } + } else { + return false; + } + + if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) { + return false; + } + if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + // The shader writes to the fused dst using its own strides, but the push constants don't + // carry a per-tensor offset, so the binding must be naturally aligned. + if (get_misalign_bytes(ctx, silu) != 0) { + return false; + } + return true; +} + static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx, topk_moe_mode mode) { @@ -13505,12 +15486,70 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } -// Check whether the tensors overlap in memory but are not equal. -// Fusions can potenitally overwrite src tensors in ways that are not prevented -// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them -// to overlap if they are exactly equal. -// XXX TODO this check is probably missing from several fusion optimizations. -static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) { +// Pattern check for the 5-op Snake fusion: mul -> sin -> sqr -> mul -> add. +// Verifies the chain shape, the closure x_in_add == x_in_mul0, and that +// the broadcast operands a and inv_b share a [1, C] layout. +static bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { + GGML_UNUSED(ctx); + if (!ggml_can_fuse(cgraph, node_idx, snake_pattern)) { + return false; + } + + const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0]; + const ggml_tensor * sin_node = cgraph->nodes[node_idx + 1]; + const ggml_tensor * sqr = cgraph->nodes[node_idx + 2]; + const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3]; + const ggml_tensor * add = cgraph->nodes[node_idx + 4]; + + const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1]; + const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0]; + + const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0]; + const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0]; + + if (x_in_add != x) { + return false; + } + if (x->type != GGML_TYPE_F32 && x->type != GGML_TYPE_F16 && x->type != GGML_TYPE_BF16) { + return false; + } + // Shader bindings: data_a is A_TYPE so it follows x's precision, while + // data_b and data_c are hardcoded float, so the broadcast operands must + // be F32 regardless of x's type. + if (a->type != GGML_TYPE_F32) return false; + if (inv_b->type != GGML_TYPE_F32) return false; + // Chain intermediates and output share x's precision (single A_TYPE / D_TYPE pipeline). + if (mul0->type != x->type) return false; + if (sin_node->type != x->type) return false; + if (sqr->type != x->type) return false; + if (mul1->type != x->type) return false; + if (add->type != x->type) return false; + if (!ggml_are_same_shape(a, inv_b)) { + return false; + } + if (a->ne[0] != 1 || a->ne[1] != x->ne[1]) { + return false; + } + // Dispatch is 2D over (ne0, ne1), so x and add must be 2D and a / inv_b + // must collapse to [1, C, 1, 1]. Higher dims are not handled by the shader. + if (x->ne[2] != 1 || x->ne[3] != 1) return false; + if (add->ne[2] != 1 || add->ne[3] != 1) return false; + if (a->ne[2] != 1 || a->ne[3] != 1) return false; + if (inv_b->ne[2] != 1 || inv_b->ne[3] != 1) return false; + // Shader uses idx = i0 + i1 * ne0 and reads data_b[i1] / data_c[i1], + // so every operand must be contiguous. + if (!ggml_is_contiguous(x) || !ggml_is_contiguous(add) || + !ggml_is_contiguous(a) || !ggml_is_contiguous(inv_b)) { + return false; + } + return true; +} + +// Check whether the tensors overlap in memory. +// Fusions can potentially overwrite src tensors in ways that are not prevented +// by ggml-alloc. If the fusion src is being applied in a way that's elementwise +// with the destination, then it's OK for them to overlap if they are exactly equal. +static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) { ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context; vk_buffer a_buf = a_buf_ctx->dev_buffer; ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context; @@ -13521,7 +15560,7 @@ static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const g auto b_base = vk_tensor_offset(b) + b->view_offs; auto b_size = ggml_nbytes(b); - if (a_base == b_base && a_size == b_size) { + if (elementwise && a_base == b_base && a_size == b_size) { return false; } @@ -13559,16 +15598,8 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co return false; } - // must not overwrite srcs in a way that's not elementwise - ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; - if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) || - ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) { - return false; - } - // conditions for pipeline creation - if (!(ctx->device->float_controls_rte_fp16 && - sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) { + if (sizeof(vk_op_rms_norm_mul_rope_push_constants) > ctx->device->properties.limits.maxPushConstantsSize) { return false; } @@ -13627,6 +15658,18 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru return num_adds; } +static int32_t find_first_set(uint32_t x) { + int32_t ret = 0; + if (!x) { + return -1; + } + while (!(x & 1)) { + x >>= 1; + ret++; + } + return ret; +} + static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; @@ -13645,7 +15688,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg int last_node = cgraph->n_nodes - 1; // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly - while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { + while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) { last_node -= 1; } @@ -13655,6 +15698,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool first_node_in_batch = true; // true if next node will be first node in a batch int submit_node_idx = 0; // index to first node in a batch + ggml_vk_submit_transfer_ctx(ctx); + vk_context compute_ctx; if (vk_perf_logger_enabled) { // allocate/resize the query pool @@ -13680,25 +15725,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0); GGML_ASSERT(ctx->compute_ctx.expired()); - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); + compute_ctx = ggml_vk_get_compute_ctx(ctx); ctx->query_idx = 0; - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } ctx->prealloc_y_last_pipeline_used = nullptr; ctx->prealloc_y_last_tensor_used = nullptr; + ctx->prealloc_y_last_decode_vector_staging = false; if (ctx->prealloc_size_add_rms_partials) { ggml_vk_preallocate_buffers(ctx, nullptr); - if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + compute_ctx = ggml_vk_get_compute_ctx(ctx); // initialize partial sums to zero. ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials); ggml_vk_sync_buffers(ctx, compute_ctx); @@ -13725,6 +15764,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg total_mul_mat_bytes += bytes; } + // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to + // the fused result in an elementwise-way. This affects whether the memory for + // the src is allowed to overlap the memory for the destination. + // The array is sized to handle the largest fusion (asserted later). + bool op_srcs_fused_elementwise[12]; + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; ctx->fused_topk_moe_scale = false; const char *fusion_string {}; @@ -13733,39 +15778,89 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg if (num_adds) { ctx->num_additional_fused_ops = num_adds - 1; fusion_string = "MULTI_ADD"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true); } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 2; fusion_string = "MUL_MAT_ADD_ADD"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ADD"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 2; fusion_string = "MUL_MAT_ID_ADD_ID_MUL"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ID_ADD_ID"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ID_MUL"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) && ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) && ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) { ctx->num_additional_fused_ops = 4; fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = false; + op_srcs_fused_elementwise[2] = false; + op_srcs_fused_elementwise[3] = false; + op_srcs_fused_elementwise[4] = false; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&& ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) { ctx->num_additional_fused_ops = 2; fusion_string = "RMS_NORM_MUL_ROPE"; + // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; fusion_string = "RMS_NORM_MUL"; + // rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before + // they are overwritten, and one workgroup per row. So close enough. + op_srcs_fused_elementwise[0] = true; + op_srcs_fused_elementwise[1] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) { + ctx->num_additional_fused_ops = 2; + fusion_string = "SSM_CONV_BIAS_SILU"; + // ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs. + // The downstream add and silu are elementwise on the conv output. + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; + } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) { + ctx->num_additional_fused_ops = 1; + fusion_string = "SSM_CONV_SILU"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { ctx->num_additional_fused_ops = 2; fusion_string = "ROPE_VIEW_SET_ROWS"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = false; + op_srcs_fused_elementwise[2] = false; + } else if (ggml_vk_can_fuse_snake(ctx, cgraph, i)) { + ctx->num_additional_fused_ops = 4; + fusion_string = "SNAKE"; + // elementwise=true: snake.comp is safe under exact aliasing because each + // thread reads data_x[idx] into a register before writing data_d[idx] + // with a data dependency on that register. The overlap check still + // rejects partial overlaps (different base or size). + std::fill_n(op_srcs_fused_elementwise, 5, true); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { @@ -13774,6 +15869,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM; fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) && ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) { @@ -13782,6 +15878,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 4; ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS; fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { @@ -13790,6 +15887,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX; fusion_string = "TOPK_MOE_EARLY_SOFTMAX"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { @@ -13798,6 +15896,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 1; ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX; fusion_string = "TOPK_MOE_LATE_SOFTMAX"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano. @@ -13805,11 +15904,73 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) { ctx->fused_topk_moe_scale = true; ctx->num_additional_fused_ops++; + op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false; } } } + GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0]))); ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; + // Check whether fusion would overwrite src operands while they're still in use. + // If so, disable fusion. + if (ctx->num_additional_fused_ops) { + // There are up to two output nodes - topk_moe has two. + uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops); + ggml_tensor *output_nodes[2] {}; + output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops]; + if (bits) { + int output_idx = find_first_set(bits); + GGML_ASSERT(bits == (1u << output_idx)); + output_nodes[1] = cgraph->nodes[i + output_idx]; + } + + bool need_disable = false; + + // topk_moe often overwrites the source, but for a given row all the src values are + // loaded before anything is stored. If there's only one row, this is safe, so treat + // this as a special case. + bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT && + ggml_nrows(cgraph->nodes[i]->src[0]) == 1; + + if (!is_topk_moe_single_row) { + for (int j = 0; j < 2; ++j) { + ggml_tensor *dst = output_nodes[j]; + if (!dst) { + continue; + } + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + ggml_tensor *src = cgraph->nodes[i + k]->src[s]; + if (!src || src->op == GGML_OP_NONE) { + continue; + } + if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) { + bool found = false; + for (int n = 0; n < k; ++n) { + if (cgraph->nodes[i + n] == src) { + found = true; + break; + } + } + if (!found) { + need_disable = true; + } + } + } + } + } + } + if (need_disable) { + ctx->num_additional_fused_ops = 0; + ctx->fused_ops_write_mask = 1; + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; + ctx->fused_topk_moe_scale = false; + } + } + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool submit = (submitted_nodes >= nodes_per_submit) || @@ -13820,18 +15981,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit); if (vk_perf_logger_enabled && enqueued) { - if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + compute_ctx = ggml_vk_get_compute_ctx(ctx); if (!vk_perf_logger_concurrent) { // track a single node/fusion for the current query ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; ctx->query_fusion_names[ctx->query_idx] = fusion_string; - compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); + ggml_vk_sync_buffers(ctx, compute_ctx); } else { // track a fusion string and number of fused ops for the current node_idx ctx->query_fusion_names[i] = fusion_string; @@ -13874,6 +16030,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_submit(compute_ctx, ctx->device->fence); VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); ctx->device->device.resetFences({ ctx->device->fence }); + ctx->compute_ctx.reset(); // Get the results and pass them to the logger std::vector<uint64_t> timestamps(cgraph->n_nodes + 1); @@ -13994,6 +16151,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (keep_pattern(topk_moe_late_softmax)) { continue; } + if (keep_pattern(snake_pattern)) { + continue; + } // First, grab the next unused node. current_set.push_back(first_unused); @@ -14016,7 +16176,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * if (match_pattern(topk_moe_early_softmax_norm, j) || match_pattern(topk_moe_sigmoid_norm_bias, j) || match_pattern(topk_moe_early_softmax, j) || - match_pattern(topk_moe_late_softmax, j)) { + match_pattern(topk_moe_late_softmax, j) || + match_pattern(snake_pattern, j)) { continue; } bool ok = true; @@ -14027,7 +16188,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) && !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) && - !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) { + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) { ok = false; break; } @@ -14110,6 +16273,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * } } } + // SSM_CONV + ADD + UNARY: pull the consuming UNARY forward + if (j > 0 && + graph->nodes[j]->op == GGML_OP_ADD && + graph->nodes[j-1]->op == GGML_OP_SSM_CONV) { + for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) { + if (graph->nodes[k]->op == GGML_OP_UNARY && + graph->nodes[k]->src[0] == graph->nodes[j]) { + current_set.push_back(k); + used[k] = true; + break; + } + } + } } } // Second pass grabs view nodes. @@ -14160,29 +16336,37 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context transfer_ctx; + ggml_vk_submit_transfer_ctx(ctx); + + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); + auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset - if (ctx->transfer_ctx.expired()) { - // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); + if (vkev->has_event) { + // Move existing event into submitted + vkev->events_submitted.push_back(vkev->event); + } + + // Grab the next event and record it, create one if necessary + if (vkev->events_free.empty()) { + vkev->event = ctx->device->device.createEvent({}); } else { - transfer_ctx = ctx->transfer_ctx.lock(); + vkev->event = vkev->events_free.back(); + vkev->events_free.pop_back(); } - // the backend interface doesn't have an explicit reset, so reset it here - // before we record the command to set it - ctx->device->device.resetEvent(vkev->event); - ctx->device->device.resetFences({ vkev->fence }); + vkev->has_event = true; - ggml_vk_set_event(transfer_ctx, vkev->event); + ggml_vk_set_event(compute_ctx, vkev->event); - ggml_vk_ctx_end(transfer_ctx); + vkev->tl_semaphore.value++; + compute_ctx->s->signal_semaphores.push_back(vkev->tl_semaphore); + ggml_vk_ctx_end(compute_ctx); - ggml_vk_submit(transfer_ctx, {vkev->fence}); + ggml_vk_submit(compute_ctx, {}); ctx->submit_pending = true; - ctx->transfer_ctx.reset(); + vkev->cmd_buffer = cmd_buf; + vkev->cmd_buffer_use_counter = cmd_buf->use_counter; + ctx->compute_ctx.reset(); } static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { @@ -14190,20 +16374,12 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context transfer_ctx; + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); - if (ctx->transfer_ctx.expired()) { - // Initialize new transfer context - transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->transfer_ctx = transfer_ctx; - ggml_vk_ctx_begin(ctx->device, transfer_ctx); - } else { - transfer_ctx = ctx->transfer_ctx.lock(); + if (vkev->has_event) { + // Wait for latest event + ggml_vk_wait_events(compute_ctx, { vkev->event }); } - - ggml_vk_wait_events(transfer_ctx, {vkev->event}); - ggml_vk_ctx_end(transfer_ctx); - ctx->transfer_ctx.reset(); } // TODO: enable async and synchronize @@ -14212,7 +16388,9 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .free = */ ggml_backend_vk_free, /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async, /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async, - /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, + /* .set_tensor_2d_async = */ ggml_backend_vk_set_tensor_2d_async, + /* .get_tensor_2d_async = */ ggml_backend_vk_get_tensor_2d_async, + /* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async, /* .synchronize = */ ggml_backend_vk_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, @@ -14413,13 +16591,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; const vk_device& device = ggml_vk_get_device(ctx->device); + const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) && + device->shader_int64 && device->buffer_device_address; + + auto const & tensor_size_supported = [&](size_t tensor_size) { + if (tensor_size > device->max_buffer_size) { + return false; + } + // For im2col shaders using BDA, maxStorageBufferRange limit doesn't apply. + // If shader64BitIndexing is enabled, maxStorageBufferRange limit doesn't apply. + if (!uses_bda && !device->shader_64b_indexing) { + if (tensor_size > device->properties.limits.maxStorageBufferRange) { + return false; + } + } + return true; + }; // reject any tensors larger than the max buffer size for (int i = 0; i < GGML_MAX_SRC; i++) { - if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) { + if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) { return false; } } - if (ggml_nbytes(op) > device->max_buffer_size) { + if (!tensor_size_supported(ggml_nbytes(op))) { return false; } @@ -14427,6 +16621,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: @@ -14445,6 +16640,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_SGN: return ggml_is_contiguous(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && @@ -14460,8 +16656,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_GLU_OP_SWIGLU_OAI: case GGML_GLU_OP_GEGLU_ERF: case GGML_GLU_OP_GEGLU_QUICK: - return ggml_is_contiguous(op->src[0]) && - (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type); default: @@ -14481,6 +16676,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -14501,6 +16697,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: break; default: return false; @@ -14549,42 +16746,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { return false; } - // It's straightforward to support different K/V dequant, but would - // significantly increase the number of pipelines - if (op->src[1]->type != op->src[2]->type) { - return false; - } - switch (op->src[1]->type) { - case GGML_TYPE_F16: - case GGML_TYPE_F32: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q8_0: - // supported in scalar and coopmat2 paths - break; - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently - //case GGML_TYPE_Q2_K: - //case GGML_TYPE_Q3_K: - //case GGML_TYPE_Q4_K: - //case GGML_TYPE_Q5_K: - //case GGML_TYPE_Q6_K: - //case GGML_TYPE_IQ1_S: - //case GGML_TYPE_IQ1_M: - //case GGML_TYPE_IQ2_XXS: - //case GGML_TYPE_IQ2_XS: - //case GGML_TYPE_IQ2_S: - //case GGML_TYPE_IQ3_XXS: - //case GGML_TYPE_IQ3_S: - //case GGML_TYPE_IQ4_XS: - case GGML_TYPE_IQ4_NL: - // currently supported only in coopmat2 path - if (!coopmat2) { + auto fa_kv_ok = [coopmat2](ggml_type t) { + switch (t) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_0: + return true; + case GGML_TYPE_Q1_0: + return coopmat2; + default: return false; } - break; - default: + }; + if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) { + return false; + } + if ((op->src[1]->type == GGML_TYPE_BF16) != (op->src[2]->type == GGML_TYPE_BF16)) { return false; } if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) { @@ -14599,6 +16781,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -14619,6 +16802,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: case GGML_TYPE_MXFP4: + case GGML_TYPE_NVFP4: case GGML_TYPE_I32: return true; default: @@ -14631,6 +16815,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -14654,6 +16839,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -14668,6 +16854,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (src1_type == GGML_TYPE_F32) { switch (src0_type) { case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -14703,10 +16891,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } case GGML_OP_REPEAT: - return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + return ggml_type_size(op->type) == ggml_type_size(op->src[0]->type) && + (ggml_type_size(op->type) == sizeof(float) || ggml_type_size(op->type) == 2); case GGML_OP_REPEAT_BACK: return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ROPE: + return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_ROPE_BACK: case GGML_OP_NONE: case GGML_OP_RESHAPE: @@ -14717,8 +16907,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: - case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); + case GGML_OP_L2_NORM: + return ggml_is_contiguous_rows(op->src[0]) && + op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -14781,7 +16973,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: - return op->src[0]->type == GGML_TYPE_F32; + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_SET: + return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32); case GGML_OP_CONCAT: return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32); case GGML_OP_ADD1: @@ -14789,8 +16984,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32) || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16); case GGML_OP_ARANGE: - case GGML_OP_FILL: return op->type == GGML_TYPE_F32; + case GGML_OP_FILL: + return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; case GGML_OP_SCALE: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: @@ -14855,6 +17051,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; // all inputs are contiguous, see ggml.c + case GGML_OP_GATED_DELTA_NET: + { + const uint32_t S_v = op->src[2]->ne[0]; + if (S_v != 32 && S_v != 64 && S_v != 128) { + return false; + } + for (int i = 0; i < 6; i++) { + if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) { + return false; + } + } + return op->type == GGML_TYPE_F32; + } case GGML_OP_SSM_SCAN: { for (int i = 0; i < 6; i++) { @@ -14926,11 +17135,25 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba return buft_ctx->device->idx == ctx->device; } +static int64_t ggml_vk_get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + return 0; + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context; - return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) || - (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + return ggml_vk_get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size; } static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) { @@ -14942,10 +17165,13 @@ static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t return nullptr; } - // The event/fence is expected to initially be in the signaled state. - vkev->event = device->device.createEvent({}); - vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled}); - device->device.setEvent(vkev->event); + // No events initially, they get created on demand + vkev->has_event = false; + + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vkev->tl_semaphore = { device->device.createSemaphore(ci), 0 }; return new ggml_backend_event { /* .device = */ dev, @@ -14959,8 +17185,16 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe vk_event *vkev = (vk_event *)event->context; - device->device.destroyFence(vkev->fence); - device->device.destroyEvent(vkev->event); + device->device.destroySemaphore(vkev->tl_semaphore.s); + for (auto& event : vkev->events_free) { + device->device.destroyEvent(event); + } + for (auto& event : vkev->events_submitted) { + device->device.destroyEvent(event); + } + if (vkev->has_event) { + device->device.destroyEvent(vkev->event); + } delete vkev; delete event; } @@ -14971,7 +17205,30 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm auto device = ggml_vk_get_device(ctx->device); vk_event *vkev = (vk_event *)event->context; - VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize"); + // Only do something if the event has actually been used + if (vkev->has_event) { + vk::Semaphore sem = vkev->tl_semaphore.s; + uint64_t val = vkev->tl_semaphore.value; + vk::SemaphoreWaitInfo swi{vk::SemaphoreWaitFlags{}, sem, val}; + VK_CHECK(device->device.waitSemaphores(swi, UINT64_MAX), "event_synchronize"); + + // Reset and move submitted events + for (auto& event : vkev->events_submitted) { + device->device.resetEvent(event); + } + vkev->events_free.insert(vkev->events_free.end(), vkev->events_submitted.begin(), vkev->events_submitted.end()); + vkev->events_submitted.clear(); + + // Finished using current command buffer so we flag for reuse + if (vkev->cmd_buffer) { + // Only flag for reuse if it hasn't been reused already + if (vkev->cmd_buffer_use_counter == vkev->cmd_buffer->use_counter) { + vkev->cmd_buffer->in_use = false; + vkev->cmd_buffer->buf.reset(); + } + vkev->cmd_buffer = nullptr; + } + } } static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) { @@ -15190,6 +17447,47 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope } } +static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) { + VkPhysicalDeviceProperties2 props = vkdev.getProperties2(); + + if (props.properties.vendorID != VK_VENDOR_ID_INTEL) { + return 0; + } + + const uint32_t device_id = props.properties.deviceID; + + switch (device_id) { + case 0x56A6: // A310 + return 6; + case 0x5693: // A370M + case 0x56A5: // A380 + case 0x56B1: // Pro A40/A50 + return 8; + case 0x5697: // A530M + return 12; + case 0x5692: // A550M + case 0x56B3: // Pro A60 + return 16; + case 0x56A2: // A580 + return 24; + case 0x5691: // A730M + case 0x56A1: // A750 + return 28; + case 0x56A0: // A770 + case 0x5690: // A770M + return 32; + case 0xE212: // Pro B50 + return 16; + case 0xE20C: // B570 + return 18; + case 0xE20B: // B580 + case 0xE211: // Pro B60 + return 20; + default: + return 0; + } +} + // checks #ifdef GGML_VULKAN_CHECK_RESULTS @@ -15403,7 +17701,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_arange(ggml_ctx, start, stop, step); } else if (tensor->op == GGML_OP_FILL) { const float value = ggml_get_op_params_f32(tensor, 0); - tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value); + tensor_clone = ggml_fill(ggml_ctx, src_clone[0], value); } else if (tensor->op == GGML_OP_SQR) { tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); } else if (tensor->op == GGML_OP_SQRT) { @@ -15432,6 +17730,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ACC) { tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + } else if (tensor->op == GGML_OP_SET) { + tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { @@ -15488,6 +17788,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_EXP: tensor_clone = ggml_exp(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_ELU: + tensor_clone = ggml_elu(ggml_ctx, src_clone[0]); + break; case GGML_UNARY_OP_SILU: tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); break; @@ -15546,6 +17849,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_UNARY_OP_TRUNC: tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]); break; + case GGML_UNARY_OP_SGN: + tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]); + break; default: std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; GGML_ABORT("fatal error"); @@ -15666,6 +17972,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->op == GGML_OP_RWKV_WKV7) { tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], src_clone[4], src_clone[5], src_clone[6]); + } else if (tensor->op == GGML_OP_GATED_DELTA_NET) { + tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4], src_clone[5], + ggml_get_op_params_i32(tensor, 0)); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = tensor->src[0]->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], @@ -15864,7 +18174,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * ggml_vk_print_graph_origin(tensor, done); } - if (avg_err > 0.5 || std::isnan(avg_err)) { + if (avg_err > 0.01 || std::isnan(avg_err)) { std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; if (src0 != nullptr) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt index e1f613fb4f6..10a9ea21025 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -11,6 +11,10 @@ if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) message(STATUS "Enabling coopmat2 glslc support") endif() +if (GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 decode_vector glslc support") +endif() if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) message(STATUS "Enabling dot glslc support") diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp index 5084a70ed49..6ba3d1d89e0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -3,6 +3,9 @@ #include "types.glsl" #include "generic_binary_head.glsl" +// false for SET, true for ACC +layout(constant_id = 1) const bool ACC = true; + layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; void main() { @@ -13,17 +16,22 @@ void main() { const uint offset = p.param3; const uint src1_i = idx - offset; - const uint oz = src1_i / p.nb02; - const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; - const uint ox = src1_i % p.nb01; + const uint i3 = src1_i / p.nb03; + const uint rem2 = src1_i - i3 * p.nb03; + const uint i2 = rem2 / p.nb02; + const uint rem1 = rem2 - i2 * p.nb02; + const uint i1 = rem1 / p.nb01; + const uint i0 = rem1 % p.nb01; uint i00, i01, i02, i03; - get_indices(idx, i00, i01, i02, i03); - if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) { + if (ACC) { + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)])); + } else { + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)])); + } } else { - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx])); } } - diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp index ca1a3ac25bd..b3b182fb084 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -19,7 +19,9 @@ void main() { if (idx + (num_iter-1)*num_threads < p.ne) { [[unroll]] for (uint i = 0; i < num_iter; ++i) { -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) @@ -35,7 +37,9 @@ void main() { continue; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + idx] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + idx]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + idx]); data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp index 875c012cd3b..1428ef68d81 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -7,6 +7,13 @@ #extension GL_KHR_memory_scope_semantics : enable #endif +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + #ifdef USE_COLLECTIVES # extension GL_KHR_shader_subgroup_shuffle : enable #endif @@ -77,6 +84,39 @@ layout(constant_id = 12) const uint d1 = 1; // Kernel spatial sizes layout(constant_id = 13) const uint KW = 1; layout(constant_id = 14) const uint KH = 1; +// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned) +layout(constant_id = 15) const uint aligned = 0; +// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this. +layout(constant_id = 16) const uint csh_store = 0; + +#ifdef COOPMAT +// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of +// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN == +// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size. +layout(constant_id = 17) const uint WM = 32; +layout(constant_id = 18) const uint WN = 32; +const uint TM = 16; +const uint TN = 16; +const uint TK = 16; +const uint cms_per_row = WM / TM; +const uint cms_per_col = WN / TN; +const uint warps_M = BS_K / WM; +const uint warps_N = BS_NPQ / WN; +#endif + +// without padding, H_idx/W_idx are in bounds by construction (non-TRANSPOSE only) +#ifdef TRANSPOSE +const bool hw_in_bounds = false; +#else +const bool hw_in_bounds = (p0 == 0) && (p1 == 0); +#endif + +// TRANSPOSE stride alignment is trivially satisfied for stride 1 +#ifdef TRANSPOSE +const bool stride_in_bounds = (s0 == 1) && (s1 == 1); +#else +const bool stride_in_bounds = true; +#endif uint32_t tid = gl_LocalInvocationID.x; const uint32_t WG_SIZE = gl_WorkGroupSize.x; @@ -94,7 +134,7 @@ uint32_t n_elems_out = K * NPQ; // Number of blocktiles per input uint32_t NB_CRS = splitWork(CRS, BS_CRS); -#ifdef COOPMAT2 +#if defined(COOPMAT2) || defined(COOPMAT) #define SHMEM_TYPE float16_t #else #define SHMEM_TYPE float @@ -112,6 +152,17 @@ const uint32_t Bsh_len = BS_CRS * Bsh_stride; shared SHMEM_TYPE Ash[Ash_len]; // K x CRS shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ +#if defined(COOPMAT2) || defined(COOPMAT) +// stage matC through shmem so global stores are row-major (NPQ-contiguous) +const uint32_t Csh_stride = BS_NPQ; +#ifdef COOPMAT +const uint32_t Csh_len = BS_K * Csh_stride; +#else +const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1; +#endif +shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ +#endif + // Threadtile sizes const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; @@ -161,7 +212,7 @@ ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_T uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { dst_data[dst_idx] = D_TYPE(elem); } return elem; @@ -176,6 +227,13 @@ void main() { #ifdef COOPMAT2 coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC; matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0); +#elif defined(COOPMAT) + coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col]; + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat<float16_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0); + } + const uint warp_r = gl_SubgroupID / warps_N; + const uint warp_c = gl_SubgroupID % warps_N; #else float regC[TS_K][TS_NPQ]; for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { @@ -228,12 +286,15 @@ void main() { uint32_t B_lx = Ac; uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ #ifdef TRANSPOSE - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03; #else - uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03; #endif + if (aligned == 0) { + knl_idx = min(knl_idx, K * CRS - 1); + } float val = knl_data[knl_idx]; - if (K_idx >= K || CRS_idx_a >= CRS) { + if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) { val = 0.0; } Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); @@ -282,15 +343,27 @@ void main() { uint32_t H_idx = OH_idx * s1 + KH_idx_b * d1 - p1; uint32_t W_idx = OW_idx * s0 + KW_idx_b * d0 - p0; #endif - uint32_t src_idx = - min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + uint32_t src_idx = W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13; + // skip clamp when address can't go OOB + if (aligned == 0 || !hw_in_bounds || !stride_in_bounds) { + src_idx = min(max(src_idx, 0), p.Cin * p.N * p.W * p.H - 1); + } float val = src_data[src_idx]; - if (CRS_idx_b >= CRS || NPQ_idx >= NPQ - || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. (idx >= 0x80000000 for such case) + bool oob = false; + if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) { + oob = true; + } + // also catches lower-bound underflow (idx wraps to 0x80000000+) + if (!hw_in_bounds && (H_idx >= p.H || W_idx >= p.W)) { + oob = true; + } #ifdef TRANSPOSE - || (H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0) + if (!stride_in_bounds && + ((H_idx_x_s1 - H_idx * s1 != 0) || (W_idx_x_s0 - W_idx * s0 != 0))) { + oob = true; + } #endif - ) { + if (oob) { val = 0.0; } Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); @@ -303,6 +376,23 @@ void main() { coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); matC = coopMatMulAdd(matA, matB, matC); +#elif defined(COOPMAT) + // each subgroup multiplies its grid of fragments per TK-sized CRS chunk + [[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) { + coopmat<float16_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a[cms_per_row]; + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK; + coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + } + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopmat<float16_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b; + const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } #else if (T_y * TS_K < K) { UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { @@ -325,8 +415,51 @@ void main() { barrier(); } /* Save C* */ +#if defined(COOPMAT2) || defined(COOPMAT) + // stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores +#ifdef COOPMAT + const bool use_staged_store = true; +#else + const bool use_staged_store = (csh_store != 0); +#endif + if (use_staged_store) { +#ifdef COOPMAT + // cm1: each subgroup stores its fragment grid into its Csh slot + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN; + coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } +#else + coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); +#endif + barrier(); + + // cooperative shmem->global: WG threads spread across BS_NPQ (the + // contiguous direction of dst), each iter covers store_rows_per_iter K-rows + const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ; + const uint32_t store_iters = BS_K / store_rows_per_iter; + const uint32_t k_thread_offset = tid / BS_NPQ; + const uint32_t npq_thread = tid % BS_NPQ; + [[unroll]] for (uint32_t i = 0; i < store_iters; i++) { + uint32_t k_local = i * store_rows_per_iter + k_thread_offset; + uint32_t K_idx = B_idx_K * BS_K + k_local; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]); + } + } + } #ifdef COOPMAT2 - coopMatPerElementNV(matC, matC, perElemOpStore); + else { + coopMatPerElementNV(matC, matC, perElemOpStore); + } +#endif #else if (T_y * TS_K < K) { for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { @@ -337,7 +470,7 @@ void main() { uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; - if (K_idx < K && NPQ_idx < NPQ) { + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { dst_data[dst_idx] = regC[T_ly][T_lx]; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp index 9f8bfd3c182..d55e13253a8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -12,7 +12,9 @@ void main() { return; } -#if defined(DATA_D_BF16) +#if defined(DATA_A_BF16) + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(bf16_to_fp32(uint32_t(data_a[get_aoffset() + src0_idx(idx)]))); +#elif defined(DATA_D_BF16) float f = float(data_a[get_aoffset() + src0_idx(idx)]); data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); #elif !defined(OPTIMIZATION_ERROR_WORKAROUND) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 06df5095258..6a692147478 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -4,7 +4,7 @@ #include "generic_unary_head.glsl" #include "dequant_funcs.glsl" -#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) // 16 invocations needed for init_iq_shmem layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; #else diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index b8c40eec102..710c15296da 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #if defined(SET_ROWS) && QUANT_K == 1 @@ -184,6 +183,31 @@ void quantize(uint dst_idx, uint src_idx) } #endif +#if defined(DATA_A_Q1_0) +void quantize(uint dst_idx, uint src_idx) +{ + float sum_abs = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0; j++) { + sum_abs += abs(data_s[src_idx + j]); + } + + const float d = sum_abs / QUANT_K_Q1_0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0 / 8; ++j) { + data_q[dst_idx].qs[j] = uint8_t(0); + } + + [[unroll]] for (int j = 0; j < QUANT_K_Q1_0; ++j) { + if (data_s[src_idx + j] >= 0.0) { + data_q[dst_idx].qs[j / 8] |= uint8_t(1 << (j % 8)); + } + } +} +#endif + #if defined(DATA_A_IQ4_NL) uint best_index(float x) { if (x <= kvalues_iq4nl[0]) return 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 7865a6bda79..e67299fdeca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -5,21 +5,60 @@ #include "types.glsl" #if defined(DATA_A_F32) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return data_a[a_offset + ib]; +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} + #endif #if defined(DATA_A_F16) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return data_a[a_offset + ib]; +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(data_a[a_offset + ib ], data_a[a_offset + ib + 1], + data_a[a_offset + ib + 2], data_a[a_offset + ib + 3]); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + const vec2 a = data_a_packed32[(a_offset + ib)/2]; + const vec2 b = data_a_packed32[(a_offset + ib)/2 + 1]; + return vec4(a, b); +} #endif #if defined(DATA_A_BF16) +FLOAT_TYPE dequantize1(uint ib, uint iqs, uint a_offset) { + return bf16_to_fp32(data_a[a_offset + ib]); +} vec2 dequantize(uint ib, uint iqs, uint a_offset) { return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); } +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + return vec4(bf16_to_fp32(data_a[a_offset + ib ]), bf16_to_fp32(data_a[a_offset + ib + 1]), + bf16_to_fp32(data_a[a_offset + ib + 2]), bf16_to_fp32(data_a[a_offset + ib + 3])); +} +vec4 dequantize4_2aligned(uint ib, uint iqs, uint a_offset) { + const uint a = data_a_packed32[(a_offset + ib)/2]; + const uint b = data_a_packed32[(a_offset + ib)/2 + 1]; + return vec4(uintBitsToFloat((a & 0x0000ffff) << 16), + uintBitsToFloat( a & 0xffff0000), + uintBitsToFloat((b & 0x0000ffff) << 16), + uintBitsToFloat( b & 0xffff0000)); +} #endif #if defined(DATA_A_Q4_0) @@ -87,6 +126,23 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_Q1_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u); + return vec2( + (bits & 1u) != 0u ? 1.0f : -1.0f, + (bits & 2u) != 0u ? 1.0f : -1.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint bits = uint(data_a[a_offset + ib].qs[iqs / 8u]) >> (iqs % 8u); + return vec4( + (bits & 1u) != 0u ? 1.0f : -1.0f, + (bits & 2u) != 0u ? 1.0f : -1.0f, + (bits & 4u) != 0u ? 1.0f : -1.0f, + (bits & 8u) != 0u ? 1.0f : -1.0f); +} +#endif + #if defined(DATA_A_IQ1_S) vec2 dequantize(uint ib, uint iqs, uint a_offset) { const uint ib32 = iqs / 32; @@ -433,6 +489,25 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) { } #endif +#if defined(DATA_A_NVFP4) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint sub = iqs >> 4; + const float d = ue4m3_to_fp32(data_a[a_offset + ib].d[sub]); + const uint j = iqs & 7; + const uint shift = (iqs & 8) >> 1; // 0 or 4 + const uint vui0 = uint(data_a[a_offset + ib].qs[sub * 8u + j]); + const uint vui1 = uint(data_a[a_offset + ib].qs[sub * 8u + j + 1]); + const uint qs0 = (vui0 >> shift) & 0xF; + const uint qs1 = (vui1 >> shift) & 0xF; + return vec2(float(kvalues_mxfp4[qs0]), float(kvalues_mxfp4[qs1])) * d * 0.5; +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 v1 = dequantize(ib, iqs + 2u, a_offset); + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) vec2 get_dm(uint ib, uint a_offset) { return vec2(0, 0); @@ -454,12 +529,25 @@ vec2 get_dm(uint ib, uint a_offset) { } #endif +#if defined(DATA_A_Q1_0) +vec2 get_dm(uint ib, uint a_offset) { + const float d = float(data_a[a_offset + ib].d); + return vec2(d, 0); +} +#endif + #if defined(DATA_A_MXFP4) vec2 get_dm(uint ib, uint a_offset) { return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0); } #endif +#if defined(DATA_A_NVFP4) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1.0, 0.0); +} +#endif + #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl index 8ac6482dc94..7171cbfa559 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -1,4 +1,12 @@ +// Each format defines a scalar dequantFunc<T> plus a V=4 dequantFunc<T>_v +// passed as the optional vector decoder to coopMatLoadTensorNV via +// GL_NV_cooperative_matrix_decode_vector. When the driver doesn't support +// the extension, ggml-vulkan.cpp strips it from the compiled SPIR-V. +#ifdef GL_NV_cooperative_matrix_decode_vector +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif + #include "types.glsl" layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufF32 { @@ -13,6 +21,31 @@ float16_t dequantFuncF32(const in decodeBufF32 bl, const in uint blockCoords[2], return vf16[idx]; } +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ1_0 { + block_q1_0 block; +}; + +float16_t dequantFuncQ1_0(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint bit = (uint(bl.block.qs[(idx & 0x78) >> 3]) >> (idx & 0x7)) & 1u; + return bit != 0u ? d : -d; +} + +f16vec4 dequantFuncQ1_0_v(const in decodeBufQ1_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t md = -d; + const uint idx = coordInBlock[1]; + const uint qs_nib = uint(bl.block.qs[idx >> 3]) >> (idx & 0x4u); + return f16vec4( + (qs_nib & 1u) != 0u ? d : md, + (qs_nib & 2u) != 0u ? d : md, + (qs_nib & 4u) != 0u ? d : md, + (qs_nib & 8u) != 0u ? d : md); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { block_q4_0_packed16 block; }; @@ -30,10 +63,28 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ4_0_v(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xE) >> 1; // even, in {0,2,4,6} + const uint qsw = uint32_t(bl.block.qs[qs_i ]) + | (uint32_t(bl.block.qs[qs_i + 1u]) << 16); + // shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte. + const uint q4 = (qsw >> shift) & 0x0F0F0F0Fu; + const u8vec4 q = unpack8(q4); + return f16vec4((vec4(q) - vec4(8.0)) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { block_q4_1 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1_packed32 { + block_q4_1_packed32 block; +}; + float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -48,10 +99,27 @@ float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ4_1_v(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_1_packed32 bl32 = decodeBufQ4_1_packed32(bl); + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4) + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + return f16vec4(vec4(q) * vec4(float(d)) + vec4(float(m))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { block_q5_0 block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0_packed16 { + block_q5_0_packed16 block; +}; + float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -70,10 +138,32 @@ float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ5_0_v(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_0_packed16 bl16 = decodeBufQ5_0_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6} + const uint qsw = uint32_t(bl16.block.qs[qs_i ]) + | (uint32_t(bl16.block.qs[qs_i + 1u]) << 16); + const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + + const uint uint_qh = uint(bl16.block.qh[1]) << 16 | uint(bl16.block.qh[0]); + const uint qh_pack = uint_qh >> idx; // bits 0..3 = element idx..idx+3 high bits + const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u; + + return f16vec4((vec4(ql) + vec4(qh_high) - vec4(16.0)) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { block_q5_1 block; }; +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1_packed32 { + block_q5_1_packed32 block; +}; + float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -93,6 +183,23 @@ float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ5_1_v(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_1_packed32 bl32 = decodeBufQ5_1_packed32(bl); + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_w = (idx & 0xC) >> 2; // iqs / 4 in [0,4) + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 ql = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + + const uint qh_pack = bl.block.qh >> idx; // bits 0..3 = element idx..idx+3 high bits + const uvec4 qh_high = (uvec4(qh_pack, qh_pack >> 1u, qh_pack >> 2u, qh_pack >> 3u) & uvec4(0x01u)) << 4u; + + return f16vec4((vec4(ql) + vec4(qh_high)) * vec4(float(d)) + vec4(float(m))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { block_q8_0_packed16 block; }; @@ -109,6 +216,17 @@ float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ8_0_v(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint base = idx >> 1u; + const uint w = uint(uint16_t(bl.block.qs[base])) + | (uint(uint16_t(bl.block.qs[base + 1u])) << 16u); + const i8vec4 qi = unpack8(int32_t(w)); + return f16vec4(vec4(qi) * vec4(float(d))); +} + layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { block_q2_K block; }; @@ -117,6 +235,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2 block_q2_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K_packed32 { + block_q2_K_packed32 block; +}; + float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); @@ -135,10 +257,36 @@ float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ2_K_v(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ2_K_packed32 bl32 = decodeBufQ2_K_packed32(bl); + const f16vec2 dm = bl.block.dm; + const uint idx = coordInBlock[1]; + + const uint scalesi = idx >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + // qs_i (packed16) = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1) is even for idx % 4 == 0, + // so qs_w (packed32) = qs_i / 2 = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2). + const uint qs_w = ((idx & 0x80) >> 4) + ((idx & 0x1Cu) >> 2); + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const uint qs4 = (qsw >> qsshift) & 0x03030303u; + const u8vec4 qi = unpack8(qs4); + + const uint scales = bl.block.scales[scalesi]; + const float16_t d_sub = dm.x * float16_t(scales & 0xF); + const float16_t m_sub = dm.y * float16_t(scales >> 4); + return f16vec4(vec4(qi) * vec4(float(d_sub)) - vec4(float(m_sub))); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { block_q3_K block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K_packed16 { + block_q3_K_packed16 block; +}; + float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const uint idx = coordInBlock[1]; @@ -167,6 +315,47 @@ float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ3_K_v(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ3_K_packed16 bl16 = decodeBufQ3_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint n = idx >> 7; // 0,1 + const uint is = idx >> 4; // 0..15 + const uint halfsplit = (idx & 0x60) >> 5; // 0,1,2,3 + const uint qsshift = halfsplit << 1; // 0,2,4,6 + const uint hbit = (n << 2) + halfsplit; // 0..7 (bit position in hmask byte) + + uint32_t scaleidx0 = (is < 8) ? is : (is - 8); + uint32_t scaleidx0shift = (is < 8) ? 0u : 4u; + uint32_t scaleidx1 = is + 8 - (is / 4) * 4; + uint32_t scaleidx1shift = (is / 4) * 2; + + const int8_t us = int8_t( + ((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | + (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); + const float16_t dl = bl.block.d * float16_t(int(us) - 32); + + // For idx % 4 == 0: (idx & 0x1F) == (idx & 0x1C) is a multiple of 4. + const uint qsi = (n << 5) + (idx & 0x1Cu); + const uint hmi = (idx & 0x1Cu); + + // Two adjacent uint16 packed16 reads, combined into a uint32 in registers. + // After this: byte j of qsw / hmw holds the data for element idx+j. + const uint qsw = uint32_t(bl16.block.qs[qsi >> 1]) + | (uint32_t(bl16.block.qs[(qsi >> 1) + 1u]) << 16); + const uint hmw = uint32_t(bl16.block.hmask[hmi >> 1]) + | (uint32_t(bl16.block.hmask[(hmi >> 1) + 1u]) << 16); + + // qsshift in {0,2,4,6} and hbit in {0..7}: per-byte masks isolate the wanted bits + // with no inter-byte leakage. + const uint ql4 = (qsw >> qsshift) & 0x03030303u; + const uint qh4 = (hmw >> hbit) & 0x01010101u; + + const ivec4 q = ivec4(unpack8(ql4 | (qh4 << 2))) - ivec4(4); + return f16vec4(vec4(q) * vec4(float(dl))); +} + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { block_q4_K block; }; @@ -175,6 +364,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 block_q4_K_packed16 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed32 { + block_q4_K_packed32 block; +}; + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { block_q4_K_packed128 block; }; @@ -322,6 +515,55 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 return float16_t(ret); } +f16vec4 dequantFuncQ4_K_v(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_K_packed32 bl32 = decodeBufQ4_K_packed32(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint is = idx >> 5; // 0..7 + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q4k[0]; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); +#endif + + // idx in [0,256); vector decode uses idx a multiple of 4. packed32 word index: + // (qs_i >> 1) == (idx >> 6) * 8 + ((idx & 0x1E) >> 2). sh is 0 or 4 only, so a + // single (w >> sh) & 0x0F0F0F0F isolates all four nibbles without inter-byte leakage. + const uint sh = (idx & 0x20u) >> 3u; + const uint w = uint32_t(bl32.block.qs[(idx >> 6) * 8u + ((idx & 0x1Eu) >> 2)]); + const u8vec4 q = unpack8((w >> sh) & 0x0F0F0F0Fu); + + return f16vec4(vec4(d) * vec4(q) - vec4(m)); +} + layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { block_q5_K block; }; @@ -334,6 +576,10 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5 block_q5_K_packed128 block; }; +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed32 { + block_q5_K_packed32 block; +}; + float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); @@ -387,6 +633,58 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 return float16_t(ret); } +f16vec4 dequantFuncQ5_K_v(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_K_packed32 bl32 = decodeBufQ5_K_packed32(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); + const uint idx = coordInBlock[1]; + const uint is = idx >> 5; + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q5k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); +#endif + + // sh is 0 or 4; mask 0x0F0F0F0F covers the four nibbles regardless (no inter-byte leakage). + const uint sh = (idx & 0x20u) >> 3u; + const uint qs_w = (idx >> 6) * 8u + ((idx & 0x1Eu) >> 2); + const uint qh_w = (idx & 0x1Eu) >> 2; + + const uint ql4 = (uint32_t(bl32.block.qs[qs_w]) >> sh) & 0x0F0F0F0Fu; + // qh stores bit `is` per element across 4 consecutive bytes; one shift+mask handles all 4. + const uint qh4 = ((uint32_t(bl32.block.qh[qh_w]) >> is) & 0x01010101u) << 4u; + + const u8vec4 qi = unpack8(ql4 | qh4); + return f16vec4(vec4(qi) * vec4(d) - vec4(m)); +} + layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { block_q6_K block; }; @@ -419,6 +717,35 @@ float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2 return ret; } +f16vec4 dequantFuncQ6_K_v(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x40) >> 6; + const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 + const uint is = idx >> 4; + const uint sh = b * 4; // 0 or 4 + + const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); + + const uint ql_i = ((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1); + const uint qh_i = ((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1); + + // Two adjacent uint16 packed16 reads, combined into a uint32 in registers. + // After this: byte j of qlw / qhw holds the data for element idx+j. + const uint qlw = uint32_t(bl16.block.ql[ql_i ]) | (uint32_t(bl16.block.ql[ql_i + 1]) << 16); + const uint qhw = uint32_t(bl16.block.qh[qh_i ]) | (uint32_t(bl16.block.qh[qh_i + 1]) << 16); + + // sh in {0,4} and qhshift in {0,2,4,6}: per-byte masks 0x0F / 0x03 keep only the + // wanted bits with no inter-byte leakage; place qh's 2 bits at nibble high position. + const uint ql4 = (qlw >> sh) & 0x0F0F0F0Fu; + const uint qh4 = ((qhw >> qhshift) & 0x03030303u) << 4u; + + const ivec4 qi = ivec4(unpack8(ql4 | qh4)); + return f16vec4((vec4(qi) - vec4(32.0f)) * vec4(float(dscale))); +} + #if defined(DATA_A_IQ1_S) layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S { block_iq1_s block; @@ -441,6 +768,29 @@ float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta)); return ret; } + +f16vec4 dequantFuncIQ1_S_v(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = idx >> 3; + const int i8b = int(idx & 4); // 0 or 4 + + const uint qh = bl.block.qh[ib32]; + const uint qs = bl.block.qs[ib8]; + const float dl = float(d) * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]; + + const ivec4 q = ivec4( + bitfieldExtract(int(grid), 2 * (i8b + 0), 2), + bitfieldExtract(int(grid), 2 * (i8b + 1), 2), + bitfieldExtract(int(grid), 2 * (i8b + 2), 2), + bitfieldExtract(int(grid), 2 * (i8b + 3), 2)); + return f16vec4((vec4(q) + vec4(delta)) * dl); +} #endif #if defined(DATA_A_IQ1_M) @@ -473,6 +823,33 @@ float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta)); return ret; } + +f16vec4 dequantFuncIQ1_M_v(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl); + const uint idx = coordInBlock[1]; + + uvec2 scales = unpack32(bl64.block.scales); + const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16))); + + const uint ib8 = idx >> 3; + const uint ib16 = idx >> 4; + const int i8b = int(idx & 4); // 0 or 4 -- i8 base for the V=4 group + + const uint sc = bl.block.scales[ib8 / 8]; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2.0 * float(bitfieldExtract(sc, 3 * int(ib16 & 3), 3)) + 1.0; + const float delta = ((qh & 8u) != 0u) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | ((qh & 7u) << 8)]; + + const ivec4 q = ivec4( + bitfieldExtract(int(grid), 2 * (i8b + 0), 2), + bitfieldExtract(int(grid), 2 * (i8b + 1), 2), + bitfieldExtract(int(grid), 2 * (i8b + 2), 2), + bitfieldExtract(int(grid), 2 * (i8b + 3), 2)); + return f16vec4((vec4(q) + vec4(delta)) * (float(d) * dl)); +} #endif #if defined(DATA_A_IQ2_XXS) @@ -508,6 +885,33 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); return float16_t(ret[idx & 1]); } + +f16vec4 dequantFuncIQ2_XXS_v(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = (idx & 0x18) >> 3; + const uint iqs = 8 * ib32 + ib8; + + const uint qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); + + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + const uint sb = sign >> (idx & 7u); + + const uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ2_XS) @@ -536,6 +940,31 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); return float16_t(ret[idx & 1]); } + +f16vec4 dequantFuncIQ2_XS_v(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint is = idx >> 5; + const uint sshift = (idx & 0x10) >> 2; + const uint iqs = idx >> 3; + + const uint16_t qs = bl.block.qs[iqs]; + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + const uint sb = sign >> (idx & 7u); + + const uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + dscale * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + dscale * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + dscale * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + dscale * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ2_S) @@ -564,6 +993,32 @@ float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords const vec2 v = db * vec2(sign01) * vec2(unpack8(g2)); return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ2_S_v(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; + const uint ib8 = idx >> 3; + const uint qhshift = 2 * (ib8 % 4); + + const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib32]; + const uint sb = uint(bl.block.qs[QUANT_K / 8 + ib8]) >> (idx & 0x6u); + + const float d = float(bl.block.d); + const float db = d * 0.25 * (0.5 + scale); + + const uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2]; + const u8vec4 g = unpack8(g2); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ3_XXS) @@ -597,6 +1052,32 @@ float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCo const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ3_XXS_v(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint iqs = idx >> 2; + const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3); + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint signs = pack32(u16vec2(bl16.block.qs[is/2+0], bl16.block.qs[is/2+1])); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + + const uint sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sb = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6u); + + const uint grid = iq3xxs_grid[qs]; + const u8vec4 g = unpack8(grid); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ3_S) @@ -623,6 +1104,30 @@ float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords return float16_t(v[idx & 1]); } + +f16vec4 dequantFuncIQ3_S_v(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + + const uint iqs = idx >> 2; + const uint iqh = idx >> 5; + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint qh = bl.block.qh[iqh]; + const uint sb = uint(bl.block.signs[iqs / 2]) >> (idx & 0x6u); + const uint scale = bl.block.scales[iqs / 16]; + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + + const uint grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const u8vec4 g = unpack8(grid); + + return f16vec4( + db * float(g.x) * ((sb & 1u) != 0u ? -1.0 : 1.0), + db * float(g.y) * ((sb & 2u) != 0u ? -1.0 : 1.0), + db * float(g.z) * ((sb & 4u) != 0u ? -1.0 : 1.0), + db * float(g.w) * ((sb & 8u) != 0u ? -1.0 : 1.0)); +} #endif #if defined(DATA_A_IQ4_XS) @@ -630,6 +1135,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4 block_iq4_xs block; }; +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufIQ4_XS_packed32 { + block_iq4_xs_packed32 block; +}; + float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -645,6 +1154,30 @@ float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoor float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]); return ret; } + +f16vec4 dequantFuncIQ4_XS_v(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ4_XS_packed32 bl32 = decodeBufIQ4_XS_packed32(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = idx >> 5; // 0..7 + const uint sl = (bl32.block.scales_l >> (4 * ib32)) & 0xF; + const uint sh = (uint(bl32.block.scales_h) >> (2 * ib32)) & 0x3; + const uint qshift = (idx & 0x10) >> 2; // {0, 4} + const uint qs_w = 4 * ib32 + ((idx & 0xC) >> 2); // iqs / 4, in [0,32) + + const float16_t dl = d * float16_t(int(sl | (sh << 4)) - 32); + + const uint qsw = bl32.block.qs[qs_w]; + const u8vec4 qv = unpack8((qsw >> qshift) & 0x0F0F0F0Fu); + const vec4 ret = vec4( + float(kvalues_iq4nl[qv.x]), + float(kvalues_iq4nl[qv.y]), + float(kvalues_iq4nl[qv.z]), + float(kvalues_iq4nl[qv.w])) * float(dl); + return f16vec4(ret); +} #endif #if defined(DATA_A_IQ4_NL) @@ -652,6 +1185,10 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4 block_iq4_nl block; }; +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL_packed16 { + block_iq4_nl_packed16 block; +}; + float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) { const float16_t d = bl.block.d; @@ -664,6 +1201,24 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; return ret; } + +f16vec4 dequantFuncIQ4_NL_v(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ4_NL_packed16 bl16 = decodeBufIQ4_NL_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; // 0 or 4 + const uint qs_i = (idx & 0xC) >> 1; // packed16 word index, in {0,2,4,6} + const uint qsw = uint32_t(bl16.block.qs[qs_i ]) + | (uint32_t(bl16.block.qs[qs_i + 1u]) << 16); + // shift in {0,4}: per-byte mask 0x0F isolates the wanted nibble in each byte. + const u8vec4 q = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + return f16vec4( + float(d) * float(kvalues_iq4nl[q.x]), + float(d) * float(kvalues_iq4nl[q.y]), + float(d) * float(kvalues_iq4nl[q.z]), + float(d) * float(kvalues_iq4nl[q.w])); +} #endif #if defined(DATA_A_MXFP4) @@ -683,52 +1238,139 @@ float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords float16_t ret = float16_t(kvalues_mxfp4[qs] * d * 0.5); return ret; } + +f16vec4 dequantFuncMXFP4_v(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float d = e8m0_to_fp32(bl.block.e); + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uvec4 qv = uvec4( + uint(bl.block.qs[iqs]), + uint(bl.block.qs[iqs + 1u]), + uint(bl.block.qs[iqs + 2u]), + uint(bl.block.qs[iqs + 3u])); + qv = (qv >> shift) & 0xFu; + const vec4 ret = vec4( + float(kvalues_mxfp4[qv.x]), + float(kvalues_mxfp4[qv.y]), + float(kvalues_mxfp4[qv.z]), + float(kvalues_mxfp4[qv.w])) * d * 0.5f; + return f16vec4(ret); +} +#endif + +#if defined(DATA_A_NVFP4) +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4 { + block_nvfp4 block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufNVFP4_packed32 { + block_nvfp4_packed32 block; +}; + +float16_t dequantFuncNVFP4(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint sub = (idx & 0x30) >> 4; + const uint iqs = ((idx & 0x30) >> 1) + (idx & 0x7); + const uint shift = (idx & 0x8) >> 1; + const float d = ue4m3_to_fp32(bl.block.d[sub]); + uint qs = uint(bl.block.qs[iqs]); + qs = (qs >> shift) & 0xF; + return float16_t(kvalues_mxfp4[qs] * d * 0.5); +} + +f16vec4 dequantFuncNVFP4_v(const in decodeBufNVFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufNVFP4_packed32 bl32 = decodeBufNVFP4_packed32(bl); + const uint idx = coordInBlock[1]; + const uint sub = idx >> 4; + const uint qs_w = ((idx & 0x30) >> 3) + ((idx & 0x4u) >> 2); // iqs / 4, in [0,8) + const uint shift = (idx & 0x8) >> 1; + const float d = ue4m3_to_fp32(bl.block.d[sub]); + + const uint qsw = uint32_t(bl32.block.qs[qs_w]); + const u8vec4 qv = unpack8((qsw >> shift) & 0x0F0F0F0Fu); + const vec4 ret = vec4( + float(kvalues_mxfp4[qv.x]), + float(kvalues_mxfp4[qv.y]), + float(kvalues_mxfp4[qv.z]), + float(kvalues_mxfp4[qv.w])) * d * 0.5f; + return f16vec4(ret); +} #endif -#if defined(DATA_A_Q4_0) +#if defined(DATA_A_Q1_0) +#define dequantFuncA dequantFuncQ1_0 +#define dequantFuncA_v dequantFuncQ1_0_v +#elif defined(DATA_A_Q4_0) #define dequantFuncA dequantFuncQ4_0 +#define dequantFuncA_v dequantFuncQ4_0_v #elif defined(DATA_A_Q4_1) #define dequantFuncA dequantFuncQ4_1 +#define dequantFuncA_v dequantFuncQ4_1_v #elif defined(DATA_A_Q5_0) #define dequantFuncA dequantFuncQ5_0 +#define dequantFuncA_v dequantFuncQ5_0_v #elif defined(DATA_A_Q5_1) #define dequantFuncA dequantFuncQ5_1 +#define dequantFuncA_v dequantFuncQ5_1_v #elif defined(DATA_A_Q8_0) #define dequantFuncA dequantFuncQ8_0 +#define dequantFuncA_v dequantFuncQ8_0_v #elif defined(DATA_A_Q2_K) #define dequantFuncA dequantFuncQ2_K +#define dequantFuncA_v dequantFuncQ2_K_v #elif defined(DATA_A_Q3_K) #define dequantFuncA dequantFuncQ3_K +#define dequantFuncA_v dequantFuncQ3_K_v #elif defined(DATA_A_Q4_K) #define dequantFuncA dequantFuncQ4_K +#define dequantFuncA_v dequantFuncQ4_K_v #define fetch_scales fetch_scalesQ4_K #define store_scales store_scalesQ4_K #elif defined(DATA_A_Q5_K) #define dequantFuncA dequantFuncQ5_K +#define dequantFuncA_v dequantFuncQ5_K_v #define fetch_scales fetch_scalesQ5_K #define store_scales store_scalesQ4_K #elif defined(DATA_A_Q6_K) #define dequantFuncA dequantFuncQ6_K +#define dequantFuncA_v dequantFuncQ6_K_v #elif defined(DATA_A_IQ1_S) #define dequantFuncA dequantFuncIQ1_S +#define dequantFuncA_v dequantFuncIQ1_S_v #elif defined(DATA_A_IQ1_M) #define dequantFuncA dequantFuncIQ1_M +#define dequantFuncA_v dequantFuncIQ1_M_v #elif defined(DATA_A_IQ2_XXS) #define dequantFuncA dequantFuncIQ2_XXS +#define dequantFuncA_v dequantFuncIQ2_XXS_v #elif defined(DATA_A_IQ2_XS) #define dequantFuncA dequantFuncIQ2_XS +#define dequantFuncA_v dequantFuncIQ2_XS_v #elif defined(DATA_A_IQ2_S) #define dequantFuncA dequantFuncIQ2_S +#define dequantFuncA_v dequantFuncIQ2_S_v #elif defined(DATA_A_IQ3_XXS) #define dequantFuncA dequantFuncIQ3_XXS +#define dequantFuncA_v dequantFuncIQ3_XXS_v #elif defined(DATA_A_IQ3_S) #define dequantFuncA dequantFuncIQ3_S +#define dequantFuncA_v dequantFuncIQ3_S_v #elif defined(DATA_A_IQ4_XS) #define dequantFuncA dequantFuncIQ4_XS +#define dequantFuncA_v dequantFuncIQ4_XS_v #elif defined(DATA_A_IQ4_NL) #define dequantFuncA dequantFuncIQ4_NL +#define dequantFuncA_v dequantFuncIQ4_NL_v #elif defined(DATA_A_MXFP4) #define dequantFuncA dequantFuncMXFP4 +#define dequantFuncA_v dequantFuncMXFP4_v +#elif defined(DATA_A_NVFP4) +#define dequantFuncA dequantFuncNVFP4 +#define dequantFuncA_v dequantFuncNVFP4_v #elif defined(DATA_A_F32) #define dequantFuncA dequantFuncF32 #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp new file mode 100644 index 00000000000..689089160b7 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_nvfp4 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint sub = tid / 16; + const uint ir = tid % 16; + const uint ib = 16 * i + ir; + if (ib >= p.nel / 64) { + return; + } + + const uint q_idx = 8 * sub; + const uint b_idx = 1024 * i + 64 * ir + 16 * sub; + + const float d = ue4m3_to_fp32(data_a[ib].d[sub]); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF])); + data_b[b_idx + l + 8] = D_TYPE(d * 0.5 * float(kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4])); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp new file mode 100644 index 00000000000..ca0bdbc63e0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q1_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid / 4; + const uint ir = tid % 4; + const uint ib = 4*i + ir; + if (ib >= p.nel / 128) { + return; + } + + const uint b_idx = 512*i + 128*ir + 8*il; + + const float d = float(data_a[ib].d); + const uint bits = uint(data_a[ib].qs[il]); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l] = D_TYPE((bits & (1u << l)) != 0u ? d : -d); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp index cd3f42f4911..79761324f55 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl new file mode 100644 index 00000000000..c474bfe09ce --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl @@ -0,0 +1,27 @@ +#ifdef DOT2_F16 +#extension GL_EXT_spirv_intrinsics : require + +spirv_instruction(extensions = ["SPV_VALVE_mixed_float_dot_product"], + capabilities = [6912], id = 6916) +float v_dot2_f32_f16(f16vec2 a, f16vec2 b, float acc); + +ACC_TYPE dot_product(f16vec4 a, f16vec4 b, ACC_TYPE acc) { + return ACC_TYPE(v_dot2_f32_f16(a.zw, b.zw, v_dot2_f32_f16(a.xy, b.xy, float(acc)))); +} + +ACC_TYPE dot_product(f16vec2 a, f16vec2 b, ACC_TYPE acc) { + return ACC_TYPE(v_dot2_f32_f16(a, b, float(acc))); +} + +#else + +ACC_TYPE dot_product(FLOAT_TYPEV4 a, FLOAT_TYPEV4 b, ACC_TYPE acc) { + return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), + fma(ACC_TYPE(a.z), ACC_TYPE(b.z), fma(ACC_TYPE(a.w), ACC_TYPE(b.w), acc)))); +} + +ACC_TYPE dot_product(FLOAT_TYPEV2 a, FLOAT_TYPEV2 b, ACC_TYPE acc) { + return fma(ACC_TYPE(a.x), ACC_TYPE(b.x), fma(ACC_TYPE(a.y), ACC_TYPE(b.y), acc)); +} + +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp new file mode 100644 index 00000000000..84dcbd8c88f --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + float x = float(data_a[i]); + + if (x < 0.0f) { + x = exp(x) - 1; + } + + data_d[i] = D_TYPE(x); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp index b69d4ddb096..c7cf5ec68f7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "generic_head.glsl" #include "types.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp new file mode 100644 index 00000000000..65e9c678401 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp @@ -0,0 +1,7 @@ +#version 460 + +#extension GL_NV_cooperative_matrix_decode_vector : require + +void main() +{ +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0379e5d5024..91fb07c93e7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -3,20 +3,35 @@ #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_subgroup_extended_types_float16 : require +#endif + +#ifdef MMQ +#extension GL_EXT_integer_dot_product : require +#extension GL_KHR_shader_subgroup_clustered : require + +#include "mul_mmq_shmem_types.glsl" +#endif + #extension GL_KHR_shader_subgroup_shuffle : enable #extension GL_KHR_shader_subgroup_vote : enable #include "types.glsl" +#include "dot_product_funcs.glsl" #include "flash_attn_base.glsl" +#include "flash_attn_dequant.glsl" const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; -const uint32_t cols_per_iter = WorkGroupSize / D_split; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; +const uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize; layout (binding = 0) readonly buffer Q {float data_q[];}; @@ -27,20 +42,41 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -// Store the output when doing grouped query attention. -// Rows index by Q's dimension 2, and the first N rows are valid. -D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - uint32_t offset = (iq2 + r) * HSV + c; - data_o[o_offset + offset] = D_TYPE(elem); - return elem; -} +// If SubGroupSize is set to 0 then only use shmem reductions +const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize; +shared float tmpsh[tmpsh_size]; +shared FLOAT_TYPEV4 tmpshv4[tmpsh_size]; + +const uint32_t masksh_stride = Br + 1; +shared FLOAT_TYPE masksh[Bc * masksh_stride]; + +#ifndef MMQ +const uint32_t qf_stride = HSK / 4 + 1; +shared FLOAT_TYPEV4 Qf[Br * qf_stride]; +#else -shared FLOAT_TYPE tmpsh[WorkGroupSize]; -shared vec4 tmpshv4[WorkGroupSize]; +const uint32_t qf_stride = HSK / 32; +shared block_b_cache Qf[Br * qf_stride]; +#endif + +#ifndef MMQ +const uint32_t D = HSK > HSV ? HSK : HSV; +#else +const uint32_t D = HSV; +#endif +const uint32_t kvsh_stride = D / 4 + 1; +shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1]; + +#ifdef MMQ -shared float masksh[Bc][Br]; -shared vec4 Qf[Br][HSK / 4]; +shared block_a_cache kblocksh[SHMEM_STAGING != 0 ? Bc * qf_stride : 1]; +#endif + +shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1]; + +#ifdef MMQ +#include "flash_attn_mmq_funcs.glsl" +#endif void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -50,210 +86,451 @@ void main() { init_indices(); const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup; const uint32_t d_tid = gl_LocalInvocationIndex % D_split; - const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + + if (LIMIT_OCCUPANCY_SHMEM > 0) { + // This just exists to avoid the occupancy_limiter array getting optimized out + occupancy_limiter[tid] = vec4(tid); + + barrier(); + + if (occupancy_limiter[tid] == vec4(99999.0)) { + data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]); + } + } + +#define tile_row(r) (row_tid * rows_per_thread + (r)) - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t r = (idx + tid) / (HSK / 4); - if (r < Br && d < HSK / 4 && - i * Br + r < N) { - Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + const bool is_in_bounds = r < Br && d < HSK / 4 && i * Br + r < N; +#ifndef MMQ + if (is_in_bounds) { + Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + } +#else + const uint buf_ib = r * qf_stride + d / 8; + const uint buf_iqs = d % 8; + + FLOAT_TYPEV4 vals = is_in_bounds ? FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale) : FLOAT_TYPEV4(0.0f); + const FLOAT_TYPEV4 abs_vals = abs(vals); + + const FLOAT_TYPE thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); + const FLOAT_TYPE amax = subgroupClusteredMax(thread_max, 8); + const FLOAT_TYPE qd = amax / FLOAT_TYPE(127.0); + const FLOAT_TYPE qd_inv = qd != FLOAT_TYPE(0.0) ? FLOAT_TYPE(1.0) / qd : FLOAT_TYPE(0.0); + vals = round(vals * qd_inv); + + Qf[buf_ib].qs[buf_iqs] = pack32(i8vec4(vals)); + + // Q8_0 K only needs (qd, _); the asymmetric Q4_*/Q5_* family also stores + // the row-sum scaled by qd, used in k_dot_correction. + if (FaTypeK == FA_TYPE_Q8_0) { + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, 0.0); + } + } else { + const FLOAT_TYPE thread_sum = vals.x + vals.y + vals.z + vals.w; + const FLOAT_TYPE sum = subgroupClusteredAdd(thread_sum, 8); + + if (buf_iqs == 0) { + Qf[buf_ib].ds = FLOAT_TYPEV2(qd, sum * qd); + } } +#endif } barrier(); - vec4 Of[Br][HSV_per_thread / 4]; + FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] = vec4(0.0); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = FLOAT_TYPEV4(0.0); } } - float Lf[Br], Mf[Br]; + float Lf[rows_per_thread], Mf[rows_per_thread]; // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Lf[r] = 0; Mf[r] = NEG_FLT_MAX_OVER_2; } - float slope[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - slope[r] = 1.0; + ACC_TYPE slope[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + slope[r] = ACC_TYPE(1.0); } // ALiBi if (p.max_bias > 0.0f) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2); } } -#if BLOCK_SIZE > 1 - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; -#else - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; -#endif - uint32_t m_offset = 0; + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + + // FaBlockBytesK/V == 2 for f16, 16 for f32, ggml block byte size for quants. + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + uint32_t mask_opt_bits = 0; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { + if (MASK_ENABLE) { + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - float max_mask = NEG_FLT_MAX_OVER_2; - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); - masksh[c][r] = m; - max_mask = max(max_mask, m); - } else { - masksh[c][r] = float(0); + float max_mask = NEG_FLT_MAX_OVER_2; + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { + FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c * masksh_stride + r] = m; + max_mask = max(max_mask, float(m)); + } else { + masksh[c * masksh_stride + r] = FLOAT_TYPE(0); + } } } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + + ACC_TYPE Sf[rows_per_thread][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = ACC_TYPE(0.0); } + } + + if (SHMEM_STAGING != 0) { barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); +#ifndef MMQ + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) { + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); + if (!KV_bounds_check || j * Bc + c < KV) { + if (USE_DECODE_K) { + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } + } + + kvsh[c * kvsh_stride + d] = K_Tf; + } } - if (max_mask <= NEG_FLT_MAX_OVER_2) { - continue; +#else // MMQ + const uint ints_per_block = 8u / fa_quant_r_mmq(FaTypeK); + const uint quant_iters = Bc * HSK / 32 * ints_per_block; + [[unroll]] for (uint32_t idx = 0; idx < quant_iters; idx += gl_WorkGroupSize.x) { + const uint32_t iqs = (idx + tid) % ints_per_block; + const uint32_t ib = (idx + tid) / ints_per_block; + const uint32_t c = ib / (HSK / 32); + const uint32_t block = ib % (HSK / 32); + if (idx + gl_WorkGroupSize.x <= quant_iters || c < Bc) { + const uint buf_ib = c * qf_stride + block; + if (!KV_bounds_check || j * Bc + c < KV) { + const uint global_ib = (j * Bc + c) * k_stride + block; + k_block_to_shmem(buf_ib, global_ib, iqs, k_offset); + } else { + k_block_to_shmem_zero(buf_ib, iqs); + } + } } +#endif // MMQ + barrier(); } - float Sf[Br][cols_per_thread]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { +#ifndef MMQ + // More d iterations means Q register caching becomes relevant + // Few iterations means the additional registers needed are worse than the speed-up from caching + if (HSK_per_thread / 4 > 4) { + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + FLOAT_TYPEV4 Q_cache[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + + FLOAT_TYPEV4 K_Tf; + if (SHMEM_STAGING != 0) { + K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else if (USE_DECODE_K) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Sf[r][c] = dot_product(Q_cache[r], K_Tf, Sf[r][c]); + } + } + } + } else { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Sf[r][c] = 0.0; + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + FLOAT_TYPEV4 K_Tf; + if (SHMEM_STAGING != 0) { + K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else if (USE_DECODE_K) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Sf[r][c] = dot_product(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf, Sf[r][c]); + } + } } } - +#else // MMQ + const uint hsk4 = HSK_per_thread / 4; + const uint d_per_step = (hsk4 % 8 == 0) ? 8 : + (hsk4 % 4 == 0) ? 4 : + (hsk4 % 2 == 0) ? 2 : 1; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); -#else - vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); -#endif - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + + [[unroll]] for (uint32_t d_block = 0; d_block < HSK_per_thread / 4; d_block += d_per_step) { + int32_t k_quants[d_per_step]; + ACC_TYPEV2 k_dm; + + // Q4_*/Q5_* take the block-8 fast path when one step covers a full + // block; Q8_0 always goes through the per-int get_k_qs* helpers + // (its qs is byte-packed, not nibble-packed). + const bool block8_fast = (d_per_step == 8) && (FaTypeK != FA_TYPE_Q8_0); + + if (SHMEM_STAGING != 0) { + const uint k_block_idx = (d_tid * (HSK_per_thread / 4) + d_block) / 8; + const uint buf_ib = (c * cols_per_iter + col_tid) * qf_stride + k_block_idx; + k_dm = ACC_TYPEV2(kblocksh[buf_ib].dm); + + if (block8_fast) { + const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1); + [[unroll]] for (uint32_t d = 0; d < 4; d++) { + uint vui = kblocksh[buf_ib].qs[d]; + k_quants[d ] = int32_t( vui & 0x0F0F0F0F); + k_quants[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); + if (has_qh) { + uint qh_lo = (kblocksh[buf_ib].qh >> (d * 4)) & 0xF; + uint qh_hi = (kblocksh[buf_ib].qh >> (d * 4 + 16)) & 0xF; + k_quants[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + k_quants[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); + } + } + } else { + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + k_quants[d] = get_k_qs_shmem(buf_ib, (d_tid * (HSK_per_thread / 4) + d_block) % 8 + d); + } + } + } else { + const uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE_K + 4 * (d_tid * (HSK_per_thread / 4) + d_block); + const uint ib = coord / BLOCK_SIZE_K; + const uint iqs = (coord % BLOCK_SIZE_K); + + k_dm = ACC_TYPEV2(get_k_scale(ib, k_offset)); + + if (block8_fast) { + fa_k_qs_block8 blk = get_k_qs_block8(ib, k_offset); + [[unroll]] for (uint32_t d = 0; d < 8; d++) { + k_quants[d] = blk.qs[d]; + } + } else { + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + k_quants[d] = get_k_qs(ib, iqs + d * 4, k_offset); + } + } + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint qib = tile_row(r) * qf_stride + (d_tid * (HSK_per_thread / 4) + d_block) / 8; + const uint qiqs = (d_tid * (HSK_per_thread / 4) + d_block) % 8; + + int32_t acc = 0; + [[unroll]] for (uint32_t d = 0; d < d_per_step; d++) { + acc += dotPacked4x8EXT(Qf[qib].qs[qiqs + d], k_quants[d]); + } + + Sf[r][c] += ACC_TYPE(acc) * ACC_TYPE(Qf[qib].ds.x) * k_dm.x; + if ((d_tid * (HSK_per_thread / 4) + d_block) % 8 == 0) { + Sf[r][c] += k_dot_correction(qib, k_dm); + } } } } +#endif // MMQ [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { // Compute sum across the D_split [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); } } } - if (p.logit_softcap != 0.0f) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (LOGIT_SOFTCAP) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c])); } } } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float mvf = masksh[c * cols_per_iter + col_tid][r]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)]; Sf[r][c] += slope[r]*mvf; } } - barrier(); } - float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - rowmaxf[r] = NEG_FLT_MAX_OVER_2; + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + rowmaxf = max(rowmaxf, float(Sf[r][c])); } - Moldf[r] = Mf[r]; + float Moldf = Mf[r]; // M = max(rowmax, Mold) // P = e^(S - M) // eM = e^(Mold - M) - Mf[r] = max(rowmaxf[r], Moldf[r]); - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Pf[r][c] = exp(Sf[r][c] - Mf[r]); - } - eMf[r] = exp(Moldf[r] - Mf[r]); + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + Lf[r] = eMf[r]*Lf[r]; + } - // Compute sum across row of P - rowsumf[r] = 0.0; - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; - } - rowsumf[r] += Pf[r][c]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d]; } - - Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] = eMf[r] * Of[r][d]; + if (SHMEM_STAGING != 0) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSV / 4); + uint32_t c = (idx + tid) / (HSV / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) { + FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); + if (!KV_bounds_check || j * Bc + c < KV) { + if (USE_DECODE_V) { + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } + } + + kvsh[c * kvsh_stride + d] = V_Tf; + } } + barrier(); } [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } + + FLOAT_TYPE Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r])); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); -#endif - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] += Pf[r][c] * Vf; + FLOAT_TYPEV4 Vf; + if (SHMEM_STAGING != 0) { + Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else if (USE_DECODE_V) { + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE_V + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else { + Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf); } } } - - barrier(); } // prevent race on tmpsh @@ -261,58 +538,115 @@ void main() { // reduce across threads - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float rowmaxf, eMf; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = Mf[r]; - tmpsh[tid] = Mf[r]; // Compute max across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + if (SubGroupSize > 0) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s)); + } + if (row_split == 1) { + // Reduce inside workgroup with shmem + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf; + } + barrier(); + rowmaxf = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]); + } } + } else { barrier(); + tmpsh[tid] = rowmaxf; + barrier(); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]); + } + barrier(); + } + rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid]; } - rowmaxf = tmpsh[d_tid]; - barrier(); float Moldf = Mf[r]; // M = max(rowmax, Mold) // eM = e^(Mold - M) Mf[r] = max(rowmaxf, Moldf); - eMf = exp(Moldf - Mf[r]); + float eMf = exp(Moldf - Mf[r]); Lf[r] = eMf*Lf[r]; - tmpsh[tid] = Lf[r]; - // Compute sum across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + if (SubGroupSize > 0) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + Lf[r] += subgroupShuffleXor(Lf[r], s); + } + if (row_split == 1) { + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r]; + } + barrier(); + Lf[r] = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Lf[r] += tmpsh[s * D_split + d_tid]; + } } + } else { + barrier(); + tmpsh[tid] = Lf[r]; barrier(); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s]; + } + barrier(); + } + Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid]; } - Lf[r] = tmpsh[d_tid]; - barrier(); [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] = FLOAT_TYPE(eMf) * Of[r][d]; - Of[r][d] = eMf * Of[r][d]; - tmpshv4[tid] = Of[r][d]; - - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - Of[r][d] += tmpshv4[tid + s]; - tmpshv4[tid] = Of[r][d]; + if (SubGroupSize > 0) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + if (!OLD_AMD_WINDOWS) { + Of[r][d] += subgroupShuffleXor(Of[r][d], s); + } else { + // Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below. + // Shuffle full vec4 as workaround. + // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697 + Of[r][d] += FLOAT_TYPEV4(subgroupShuffleXor(vec4(Of[r][d]), s)); + } } + if (row_split == 1) { + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d]; + } + barrier(); + Of[r][d] = tmpshv4[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Of[r][d] += tmpshv4[s * D_split + d_tid]; + } + } + } else { barrier(); + tmpshv4[tid] = Of[r][d]; + barrier(); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + Of[r][d] += tmpshv4[tid ^ s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid]; } - Of[r][d] = tmpshv4[d_tid]; - barrier(); } } @@ -320,32 +654,53 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); - - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N); } } } - } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { - perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); - perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { + perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } } - } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + const uint global_row = i * Br + row; + + if (global_row < N) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4; + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]); + } + } + if (global_row < N && d_tid == 0 && col_tid == 0) { + uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_offset + iq2] = D_TYPE(Lf[r]); + data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]); + } + } + } return; } if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); float ms = 1.0f; float vs = 1.0f; @@ -354,7 +709,7 @@ void main() { ms = exp(Mf[r] - sink); [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] *= ms; + Of[r][d] *= FLOAT_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -364,39 +719,37 @@ void main() { } } - float Lfrcp[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] *= Lfrcp[r]; -#if defined(ACC_TYPE_MAX) - Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= FLOAT_TYPE(Lfrcp[r]); +#if defined(FLOAT_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4; if (p.gqa_ratio > 1) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); - } + gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N); } } } } else { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (i * Br + r < N) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (i * Br + row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); - } + data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index eb93903c468..66dcf610219 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -1,13 +1,29 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (constant_id = 0) const uint32_t WorkGroupSize = 128; -layout (constant_id = 1) const uint32_t Br = 1; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t HSK = 32; -layout (constant_id = 4) const uint32_t HSV = 32; -layout (constant_id = 5) const uint32_t Clamp = 0; -layout (constant_id = 6) const uint32_t D_split = 16; +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t HSK = 32; +layout (constant_id = 4) const uint32_t HSV = 32; +layout (constant_id = 5) const uint32_t Clamp = 0; +layout (constant_id = 6) const uint32_t D_split = 16; +layout (constant_id = 7) const uint32_t row_split = 1; +layout (constant_id = 8) const uint32_t SubGroupSize = 32; +layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0; +layout (constant_id = 10) const uint32_t Flags = 0; +layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0; +// ggml_type enumerant for K/V +layout (constant_id = 12) const uint32_t FaTypeK = 0; +layout (constant_id = 13) const uint32_t FaTypeV = 0; +// sizeof(decode buffer): quants -> ggml block size; F32 -> 16 (decodeBufF32 vec4). +layout (constant_id = 14) const uint32_t FaBlockBytesK = 2; +layout (constant_id = 15) const uint32_t FaBlockBytesV = 2; + +const bool USE_MASK_OPT = (Flags & 1) != 0; +const bool MASK_ENABLE = (Flags & 2) != 0; +const bool LOGIT_SOFTCAP = (Flags & 4) != 0; +const bool OLD_AMD_WINDOWS = (Flags & 8) != 0; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; @@ -57,78 +73,82 @@ layout (push_constant) uniform parameter { } p; #define SINK_ENABLE_BIT (1<<24) -#define MASK_ENABLE_BIT (1<<16) #define N_LOG2_MASK 0xFFFF layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; +layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];}; + +layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; + +#define MASK_OPT_ALL_NEG_INF 1 +#define MASK_OPT_ALL_ZERO 2 #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 -#if defined(DATA_A_F32) -layout (binding = 1) readonly buffer K_PACKED {vec4 k_data_packed[];} k_packed; -layout (binding = 2) readonly buffer V_PACKED {vec4 v_data_packed[];} v_packed; -#elif defined(A_TYPE_PACKED16) -layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; -layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; -#endif -#if defined(DATA_A_F32) -#undef BLOCK_SIZE -#define BLOCK_SIZE 4 -#define BLOCK_BYTE_SIZE 16 - -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - // iqs is currently always zero in the flash attention shaders - if (binding_idx == BINDING_IDX_K) { - return k_packed.k_data_packed[a_offset + ib]; - } else { - return v_packed.v_data_packed[a_offset + ib]; - } -} +// FaTypeK / FaTypeV spec constant values. These mirror enum ggml_type so the +// host can pass the type directly. Keep in sync with ggml.h. +#define FA_TYPE_F32 0u +#define FA_TYPE_F16 1u +#define FA_TYPE_Q4_0 2u +#define FA_TYPE_Q4_1 3u +#define FA_TYPE_Q5_0 6u +#define FA_TYPE_Q5_1 7u +#define FA_TYPE_Q8_0 8u +#define FA_TYPE_BF16 30u +#define FA_TYPE_Q1_0 41u + +#if defined(BFLOAT16) +#define O_TYPE float +#define O_TYPEV4 vec4 +#else +#define O_TYPE FLOAT_TYPE +#define O_TYPEV4 FLOAT_TYPEV4 #endif -#if defined(DATA_A_Q4_0) -#define BLOCK_BYTE_SIZE 18 - -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); - } else { - uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; - - return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); +// Number of matrix elements per buffer block, derived from the K/V type spec +// constant. F32 is treated as a vec4 "block" of 4 floats. F16 uses block size 1 +// and bypasses the dequant path entirely. Quants follow their ggml block sizes. +uint fa_block_elems(uint ty) { + switch (ty) { + case FA_TYPE_F32: return 4u; + case FA_TYPE_F16: return 1u; + case FA_TYPE_Q4_0: return uint(QUANT_K_Q4_0); + case FA_TYPE_Q4_1: return uint(QUANT_K_Q4_1); + case FA_TYPE_Q5_0: return uint(QUANT_K_Q5_0); + case FA_TYPE_Q5_1: return uint(QUANT_K_Q5_1); + case FA_TYPE_Q8_0: return uint(QUANT_K_Q8_0); + case FA_TYPE_BF16: return 1u; + case FA_TYPE_Q1_0: return uint(QUANT_K_Q1_0); // cm2-only, harmless elsewhere + default: return 1u; } } -#endif -#if defined(DATA_A_Q8_0) -#define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - if (binding_idx == BINDING_IDX_K) { - const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); - } else { - const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - - return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); +// QUANT_R_MMQ for FA-eligible K types. Q4_*/Q5_* store two nibbles per byte +// (R==2); Q8_0 stores one byte per element (R==1). Used to derive the number +// of int32s per 32-element block on the MMQ K path: ints_per_block == 8 / R. +uint fa_quant_r_mmq(uint ty) { + switch (ty) { + case FA_TYPE_Q4_0: return uint(QUANT_R_Q4_0); + case FA_TYPE_Q4_1: return uint(QUANT_R_Q4_1); + case FA_TYPE_Q5_0: return uint(QUANT_R_Q5_0); + case FA_TYPE_Q5_1: return uint(QUANT_R_Q5_1); + case FA_TYPE_Q8_0: return uint(QUANT_R_Q8_0); + default: return 1u; } } -#endif + +// These can't be `const` globals because GLSL forbids function calls in global +// const initializers, even when the spec constants would let the driver fold +// them. Macros expand at the use site and fold after specialization. +#define BLOCK_SIZE_K fa_block_elems(FaTypeK) +#define BLOCK_SIZE_V fa_block_elems(FaTypeV) +// F16 reads f16 elements directly from the binding; everything else routes +// through dequantize4 / the MMQ helpers to unpack from the packed block layout. +#define USE_DECODE_K (FaTypeK != FA_TYPE_F16) +#define USE_DECODE_V (FaTypeV != FA_TYPE_F16) #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) @@ -165,7 +185,7 @@ ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC } uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, - iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + gqa_iq1, iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, q_stride, k_stride, v_stride, m_stride; void init_indices() @@ -173,12 +193,25 @@ void init_indices() N = p.N; KV = p.KV; - i = gl_WorkGroupID.x; - split_k_index = 0; - if (p.k_num > 1) { + if (p.gqa_ratio > 1) { + i = 0; + // batch and split_k share gl_WorkGroupID.x + gqa_iq1 = gl_WorkGroupID.x / p.k_num; + split_k_index = gl_WorkGroupID.x % p.k_num; + } else { + gqa_iq1 = 0; + split_k_index = gl_WorkGroupID.x % p.k_num; + i = gl_WorkGroupID.x / p.k_num; + } + } else if (p.gqa_ratio > 1) { i = 0; - split_k_index = gl_WorkGroupID.x; + gqa_iq1 = gl_WorkGroupID.x; + split_k_index = 0; + } else { + i = gl_WorkGroupID.x; + gqa_iq1 = 0; + split_k_index = 0; } Tr = CEIL_DIV(N, Br); @@ -218,3 +251,15 @@ void init_indices() // and breaking the alignment detection. m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; } + +// Bias applied to softmax to stay in fp16 range. +// Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606 +const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +void gqaStore(const in uint32_t r, const in uint32_t c, const in O_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * HSV / 4 + c; + data_ov4[o_offset + offset] = D_TYPEV4(elems); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index c995ab140ee..23ae3833e52 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -6,58 +6,61 @@ #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#if defined(BFLOAT16) +#extension GL_EXT_bfloat16 : enable +#endif + #extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable #extension GL_KHR_shader_subgroup_vote : enable #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #include "types.glsl" #include "flash_attn_base.glsl" +#if !defined(BFLOAT16) +#include "flash_attn_dequant.glsl" +#endif -const uint32_t HSK_per_thread = HSK / D_split; -const uint32_t HSV_per_thread = HSV / D_split; +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; -const uint32_t row_split = 4; const uint32_t rows_per_thread = Br / row_split; -const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; layout (binding = 0) readonly buffer Q {float data_q[];}; layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; -layout (binding = 1) readonly buffer K {float16_t data_k[];}; -layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; -layout (binding = 2) readonly buffer V {float16_t data_v[];}; -layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 1) readonly buffer K {FLOAT_TYPE data_k[];}; +layout (binding = 1) readonly buffer KV4 {FLOAT_TYPEV4 data_kv4[];}; +layout (binding = 2) readonly buffer V {FLOAT_TYPE data_v[];}; +layout (binding = 2) readonly buffer VV4 {FLOAT_TYPEV4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -// Store the output when doing grouped query attention. -// Rows index by Q's dimension 2, and the first N rows are valid. -D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - uint32_t offset = (iq2 + r) * HSV + c; - data_o[o_offset + offset] = D_TYPE(elem); - return elem; -} +shared float tmpsh[row_split]; -// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd -const uint32_t MatBr = 16; -const uint32_t MatBc = 16; - -shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; -shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; +const uint32_t qstride = HSK_pad / 4 + 2; +shared FLOAT_TYPEV4 Qf[Br * qstride]; -const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 -shared f16vec4 Qf[Br * qstride]; +const uint psh_stride = Br / 4 + 2; +shared FLOAT_TYPEV4 Psh[Bc * psh_stride]; // Avoid padding for hsk==256 to make it fit in 48KB shmem. -const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br; -shared ACC_TYPE sfsh[Bc * sfshstride]; +const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; +shared ACC_TYPEV4 sfsh[Bc * sfshstride]; -const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4 -shared f16vec4 ksh[Bc * kshstride]; +const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad; +const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; +const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups +const uint vsh_stride = v_cols; +shared FLOAT_TYPEV4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; -shared float slope[Br]; +const uint32_t osh_stride = row_split * MatBr / 4; +shared O_TYPEV4 pvsh[MatBc * osh_stride]; + +shared ACC_TYPE slope[Br]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -69,9 +72,9 @@ void main() { const uint32_t tid = gl_LocalInvocationIndex; const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup; const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; - const uint32_t d_tid = gl_LocalInvocationIndex % D_split; - const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup; #define tile_row(r) (row_tid * rows_per_thread + (r)) @@ -79,33 +82,28 @@ void main() { if ((HSK % 16) != 0) { [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { if (i + tid < Br * qstride) { - Qf[i + tid] = f16vec4(0); - } - } - [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) { - if (i + tid < Bc * kshstride) { - ksh[i + tid] = f16vec4(0); + Qf[i + tid] = FLOAT_TYPEV4(0); } } barrier(); } - uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02+iq3*p.nb03) / 4; [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { uint32_t d = (idx + tid) % (HSK / 4); uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + Qf[r * qstride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] = ACC_TYPEV4(0.0); + O_TYPEV4 Of[rows_per_thread][d_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { + Of[r][d] = O_TYPEV4(0.0); } } @@ -125,94 +123,188 @@ void main() { uint r = tid; slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); } - barrier(); } else { if (tid < Br) { uint r = tid; - slope[r] = 1.0; + slope[r] = ACC_TYPE(1.0); } - barrier(); } -#if BLOCK_SIZE > 1 - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; -#else - uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; - uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; -#endif - uint32_t m_offset = 0; + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + + // FaBlockBytesK/V == 2 for f16 (sizeof f16) and == 16 for f32 (vec4) and == ggml block size for quants. + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / FaBlockBytesK; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / FaBlockBytesV; + uint32_t m_offset = gqa_iq1*KV; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + uint32_t mask_opt_bits = 0; + f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - float mask_cache[Bc * Br / WorkGroupSize]; - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - float max_mask = NEG_FLT_MAX_OVER_2; - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); - mask_cache[idx / WorkGroupSize] = m; - max_mask = max(max_mask, m); - } - } - } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; - } - barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); + [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) { + mask_cache[idx] = f16vec4(0); + } + + if (MASK_ENABLE) { + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; } - if (max_mask <= NEG_FLT_MAX_OVER_2) { + mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block continue; } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + float max_mask = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + if ((!KV_bounds_check || j * Bc + c < KV)) { + f16vec4 m; + if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]); + max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3])); + } else if (i * Br + r * 4 + 2 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)], + 0.0); + max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])); + } else if (i * Br + r * 4 + 1 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)], + 0.0, + 0.0); + max_mask = max(max(max_mask, float(m[0])), float(m[1])); + } else if (i * Br + r * 4 < p.nem1) { + m = f16vec4(data_m[m_offset + (i * Br + r * 4 ) * m_stride + (j * Bc + c)], + 0.0, + 0.0, + 0.0); + max_mask = max(max_mask, float(m[0])); + } else { + m = f16vec4(0.0); + } + mask_cache[idx / WorkGroupSize] = m; + } + } + } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } + } } - [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (HSK / 4); - uint32_t c = (idx + tid) / (HSK / 4); - if (c < Bc && d < HSK / 4) { - f16vec4 K_Tf = f16vec4(0); - if (!KV_bounds_check || j * Bc + c < KV) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); -#else - K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + if (SHMEM_STAGING != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK_pad / 4); + uint32_t c = (idx + tid) / (HSK_pad / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); + if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { +#if !defined(BFLOAT16) + if (USE_DECODE_K) { + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE_K + 4 * d; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else #endif - } + { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); + } + } - ksh[c * kshstride + d] = K_Tf; + kvsh[c * kvsh_stride + d] = K_Tf; + } } + barrier(); } - barrier(); // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 // This is written transposed in order to allow for N being 8 if implementations need it coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); - coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat; - coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat; + coopmat<FLOAT_TYPE, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat; + coopmat<FLOAT_TYPE, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat; + + [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { + // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem + // If not, K is loaded directly from global memory if aligned, otherwise + // staged through a Bc * MatBr size staging buffer. + // If K is a quant type, then it is always staged for dequantization. + if (SHMEM_STAGING == 0) { + // For quants we always need to dequant into kvsh; for f16/bf16 we can load + // directly from global memory when alignment / bounds allow it. + const bool stage_k = USE_DECODE_K || KV_bounds_check || d * 16 + 16 > HSK; + if (stage_k) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) { + uint32_t col_vec = (idx + tid) % (MatBr / 4); + uint32_t row = (idx + tid) / (MatBr / 4); + if (idx + tid < Bc * MatBr / 4) { + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); + if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) { +#if !defined(BFLOAT16) + if (USE_DECODE_K) { + uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE_K + d * 16 + col_vec * 4; + uint ib = coord / BLOCK_SIZE_K; + uint iqs = (coord % BLOCK_SIZE_K); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + } else +#endif + { + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); + } + } - for (uint32_t d = 0; d < HSK_pad / 16; ++d) { - coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + kvsh[row * kvsh_stride + col_vec] = K_Tf; + } + } + barrier(); + } - uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; - coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + if (stage_k) { + uint coord = (gl_SubgroupID * MatBc) * kvsh_stride; + coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } else { + const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4; + coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } + } else { + uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4; + coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); SfMat = coopMatMulAdd(KMat, QMat, SfMat); } @@ -221,27 +313,27 @@ void main() { coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); barrier(); - if (p.logit_softcap != 0.0f) { - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) / Br; - uint32_t r = (idx + tid) % Br; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + if (LOGIT_SOFTCAP) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); } } barrier(); } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float f = mask_cache[idx / WorkGroupSize]; - sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f); + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / (Br / 4); + uint32_t r = (idx + tid) % (Br / 4); + if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) { + if (!KV_bounds_check || j * Bc + c < KV) { + // Mask nem1 bounds check is handled when loading masks + ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]); + ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]); + sfsh[c * sfshstride + r] += slopes * masks; } } } @@ -250,143 +342,237 @@ void main() { float eMf[rows_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint r_vec = tile_row(r) / 4; + const uint r_comp = tile_row(r) % 4; + float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp])); } float Moldf = Mf[r]; + // Compute max across the row + rowmaxf = subgroupMax(rowmaxf); + // M = max(rowmax, Mold) // P = e^(S - M) // eM = e^(Mold - M) Mf[r] = max(rowmaxf, Moldf); eMf[r] = exp(Moldf - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; + Of[r][d_local] = O_TYPE(eMf[r]) * Of[r][d_local]; } } - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Lf[r] = eMf[r]*Lf[r]; - } + // Calculate and store Pf in Psh [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; - } - float Pf[rows_per_thread]; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); - Lf[r] += Pf[r]; + const uint col = c * cols_per_iter + col_tid; + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) { + const uint row = tile_row(r); + if (KV_bounds_check && j * Bc + col >= KV) { + Psh[col * psh_stride + row / 4] = FLOAT_TYPEV4(0.0f); + } else { + const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]); + const FLOAT_TYPEV4 Pf = FLOAT_TYPEV4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec)); + [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) { + Lf[r + vec_idx] += Pf[vec_idx]; + } + Psh[col * psh_stride + row / 4] = Pf; + } } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { -#if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); -#else - vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); + } + + if (SHMEM_STAGING != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSV_pad / 4); + uint32_t c = (idx + tid) / (HSV_pad / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { + FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); + if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { +#if !defined(BFLOAT16) + if (USE_DECODE_V) { + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE_V + 4 * d; + uint ib = coord / BLOCK_SIZE_V; + uint iqs = (coord % BLOCK_SIZE_V); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else #endif - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf); + { + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); + } + } + + kvsh[c * kvsh_stride + d] = V_Tf; } } } - barrier(); - } - // prevent race on tmpsh - barrier(); + const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up + + // Each subgroup handles HSV/4 columns + [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { + const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; + + coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<O_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0); + + // Preload V tiles for [Bc, 16 * num subgroups] + const uint v_rows = Bc; + const uint v_total = v_rows * v_cols; + const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; + + // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem. + // If not, V is loaded directly from global memory if aligned, otherwise + // staged through a Bc * MatBr size staging buffer. + // If V is a quant type, then it is always staged for dequantization. + if (SHMEM_STAGING == 0) { + // For quants we always preload via kvsh. For f16/bf16 we only preload when + // alignment / bounds force it (otherwise we coopMatLoad direct from data_vv4). + const bool stage_v = USE_DECODE_V || KV_bounds_check; + if (stage_v) { + [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) { + const uint idx = i * gl_WorkGroupSize.x + tid; + const uint row = idx / v_cols; + const uint col = idx % v_cols; + + const uint v_row = j * Bc + row; + const uint v_col = hsv_tile * MatBc * row_split + col * 4; + + const uint coord = v_row * v_stride * BLOCK_SIZE_V + v_col; + const uint ib = coord / BLOCK_SIZE_V; + const uint iqs = coord % BLOCK_SIZE_V; + + if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { +#if !defined(BFLOAT16) + if (USE_DECODE_V) { + kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + } else +#endif + { + kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; + } + } else { + kvsh[row * vsh_stride + col] = FLOAT_TYPEV4(0.0f); + } + } + } + } + barrier(); - // reduce across threads + const uint o_offset = gl_SubgroupID * MatBr / 4; + + if (hsv_offset < HSV_pad) { + [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { + coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); + + if (SHMEM_STAGING == 0) { + if (!USE_DECODE_V && !KV_bounds_check) { + // F16/BF16 values can be loaded directly from global memory + const uint v_tile_row = j * Bc + bc_chunk * MatBc; + const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; + coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } else { + const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); + coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } else { + const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4); + coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + + PVMat = coopMatMulAdd(KMat, QMat, PVMat); + } + + // Store PVMat to pvsh and load into Of + coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); + } - float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - FLOAT_TYPE M = Mf[r]; - tmpsh[tid] = M; - // Compute max across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - M = max(M, tmpsh[tid ^ s]); - barrier(); - tmpsh[tid] = M; barrier(); - } - rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; - barrier(); - } - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Moldf[r] = Mf[r]; + const uint hsv_per_tile = row_split * MatBc; + const uint hsv_base = hsv_tile * hsv_per_tile; + const uint d_values_per_tile = hsv_per_tile / 4; - // M = max(rowmax, Mold) - // eM = e^(Mold - M) - Mf[r] = max(rowmaxf[r], Moldf[r]); - eMf[r] = exp(Moldf[r] - Mf[r]); + const uint d_start = hsv_tile * d_values_per_tile; + const uint d_end = min(d_start + d_values_per_tile, HSV / 4); - Lf[r] = eMf[r]*Lf[r]; - } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - FLOAT_TYPE L = Lf[r]; - tmpsh[tid] = L; - // Compute sum across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - L += tmpsh[tid ^ s]; - barrier(); - tmpsh[tid] = L; - barrier(); + [[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) { + const uint d = d_local * threads_per_rowgroup + col_tid; + const uint hsv_col = 4 * d; + + if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) { + const uint local_hsv = (hsv_col - hsv_base) / 4; + Of[r][d_local] += pvsh[row * osh_stride + local_hsv]; + } + } + } } - Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); } [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - - Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; - tmpshv4[tid] = Of[r][d]; - - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { - Of[r][d] += tmpshv4[tid ^ s]; - barrier(); - tmpshv4[tid] = Of[r][d]; - barrier(); - } - Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; - barrier(); - } + Lf[r] = subgroupAdd(Lf[r]); } // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + const uint d_local = d0 / threads_per_rowgroup; + gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N); } } } - } - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - if (tile_row(r) < N) { - perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); - perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + const uint global_row = i * Br + row; + + if (global_row < N) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4; + + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]); + } + } + + if (global_row < N && col_tid == 0) { + uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_offset + iq2] = D_TYPE(Lf[r]); + data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]); + } } } @@ -403,8 +589,9 @@ void main() { if (sink > Mf[r]) { ms = exp(Mf[r] - sink); - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] *= ACC_TYPE(ms); + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; + Of[r][d_local] *= O_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -419,34 +606,37 @@ void main() { Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d] *= ACC_TYPE(Lfrcp[r]); -#if defined(ACC_TYPE_MAX) - Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); + Of[r][d_local] *= O_TYPE(Lfrcp[r]); +#if defined(FLOAT_TYPE_MAX) + Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif } } - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); - } + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV / 4) break; + const uint d_local = d0 / threads_per_rowgroup; + gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N); } } } } else { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { if (i * Br + tile_row(r) < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); - } + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV / 4) break; + const uint d_local = d0 / threads_per_rowgroup; + data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 9a71996383d..b9c03fe499d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -8,17 +8,97 @@ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#if defined(BFLOAT16) +#extension GL_EXT_bfloat16 : enable +#endif + #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable +#ifdef GL_NV_cooperative_matrix_decode_vector +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable #extension GL_EXT_null_initializer : enable #include "types.glsl" -#include "dequant_funcs_cm2.glsl" #include "flash_attn_base.glsl" +#if !defined(BFLOAT16) +#include "dequant_funcs_cm2.glsl" +#endif + +// buffer_reference stride = sizeof(struct) = FaBlockBytesK/V. +layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_K { + uint8_t raw[FaBlockBytesK]; +}; +layout(buffer_reference, std430, buffer_reference_align = 1) buffer decodeBufFA_V { + uint8_t raw[FaBlockBytesV]; +}; + +#if !defined(BFLOAT16) +float16_t faDecodeK(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeK) { + case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return float16_t(0); + } +} + +float16_t faDecodeV(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeV) { + case FA_TYPE_F32: return dequantFuncF32 (decodeBufF32 (bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_0: return dequantFuncQ4_0(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q4_1: return dequantFuncQ4_1(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_0: return dequantFuncQ5_0(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q5_1: return dequantFuncQ5_1(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q8_0: return dequantFuncQ8_0(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case FA_TYPE_Q1_0: return dequantFuncQ1_0(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return float16_t(0); + } +} + +// V=4 vector decode for K/V; dispatches to per-format _v decoders. +f16vec4 faDecodeKVector(const decodeBufFA_K bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeK) { + case 0u: return f16vec4(decodeBufF32(bl_in).block); + case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return f16vec4(0); + } +} + +f16vec4 faDecodeVVector(const decodeBufFA_V bl_in, const uint blockCoords[2], const uint coordInBlock[2]) { + switch (FaTypeV) { + case 0u: return f16vec4(decodeBufF32(bl_in).block); + case 2u: return dequantFuncQ4_0_v(decodeBufQ4_0(bl_in), blockCoords, coordInBlock); + case 3u: return dequantFuncQ4_1_v(decodeBufQ4_1(bl_in), blockCoords, coordInBlock); + case 6u: return dequantFuncQ5_0_v(decodeBufQ5_0(bl_in), blockCoords, coordInBlock); + case 7u: return dequantFuncQ5_1_v(decodeBufQ5_1(bl_in), blockCoords, coordInBlock); + case 8u: return dequantFuncQ8_0_v(decodeBufQ8_0(bl_in), blockCoords, coordInBlock); + case 41u: return dequantFuncQ1_0_v(decodeBufQ1_0(bl_in), blockCoords, coordInBlock); + default: return f16vec4(0); + } +} + +#ifdef GL_NV_cooperative_matrix_decode_vector +#define FADECODEK , faDecodeK, faDecodeKVector +#define FADECODEV , faDecodeV, faDecodeVVector +#else +#define FADECODEK , faDecodeK +#define FADECODEV , faDecodeV +#endif +#endif layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; layout (binding = 1) readonly buffer K {uint8_t data_k[];}; @@ -55,12 +135,6 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele return max(elem0, elem1); } -#if defined(BLOCK_SIZE) -#define DECODEFUNC , DEQUANTFUNC -#else -#define DECODEFUNC -#endif - // Store the output when doing grouped query attention. // Rows index by Q's dimension 2, and the first N rows are valid. D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) @@ -72,11 +146,29 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } -void main() { -#ifdef NEEDS_INIT_IQ_SHMEM - init_iq_shmem(gl_WorkGroupSize); -#endif +// Store O values for non-GQA split_k. Rows are tokens, not heads. +D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) { + uint32_t global_row = i * Br + r; + if (global_row < N && c < HSV) { + uint32_t o_off = HSV * p.ne1 + * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[o_off + iq2 * HSV + c] = D_TYPE(elem); + } + return elem; +} +// Store L/M values for non-GQA split_k. +ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) { + uint32_t global_row = i * Br + r; + if (global_row < N && c == 0) { + uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_off + lm_base + iq2] = D_TYPE(elem); + } + return elem; +} + +void main() { init_indices(); tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); @@ -85,10 +177,10 @@ void main() { tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); -#if defined(BLOCK_SIZE) - tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); - tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); -#endif + const uint bs_k = fa_block_elems(FaTypeK); + const uint bs_v = fa_block_elems(FaTypeV); + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, bs_k); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, bs_v); tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK); tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK); @@ -98,10 +190,12 @@ void main() { if (Clamp != gl_CooperativeMatrixClampModeConstantNV) { q_stride &= ~7; -#if !defined(BLOCK_SIZE) - k_stride &= ~7; - v_stride &= ~7; -#endif + if (bs_k == 1u) { + k_stride &= ~7; + } + if (bs_v == 1u) { + v_stride &= ~7; + } m_stride &= ~7; } tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); @@ -109,15 +203,15 @@ void main() { tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseAccumulator> Q; - coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA> Qf16; - uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + uint32_t q_offset = gqa_iq1*p.nb01*4/*sizeof(float)*/ + iq2*p.nb02+iq3*p.nb03; coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); - Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q); - Qf16 *= float16_t(p.scale); + Q *= Q_TYPE(p.scale); + Qf16 = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, HSK_pad, gl_MatrixUseA>(Q); - coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0); + coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O = coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M; @@ -138,68 +232,97 @@ void main() { coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); } - uint32_t m_offset = 0; + const uint32_t mo_stride = CEIL_DIV(KV, 16 * Bc); + // mo_offset will point to the tile starting at row i*Br and col 0 + uint32_t mo_offset = mo_stride * i; + + uint32_t m_offset = gqa_iq1*KV * 2 /*sizeof(float16_t)*/; if (p.nem2 != 1 || p.nem3 != 1) { - m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + m_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + mo_offset += ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * CEIL_DIV(p.nem1, Br) * mo_stride; } + uint32_t mask_opt = 0; + uint32_t mask_opt_idx = ~0; + [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv; - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - - if (nem1_bounds_check) { - tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); - tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); - tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); - tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t - - coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax; - - coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); - - // skip the block if the mask is entirely -inf - coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); - if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { - continue; - } - } else { - tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); - // Don't clamp against nem1 when GQA is enabled - uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1; - tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); - tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); - - coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax; - - coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); + if (MASK_ENABLE) { - // skip the block if the mask is entirely -inf - coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); - if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { - continue; + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + if (nem1_bounds_check) { + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t + + coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax; + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + // skip the block if the mask is entirely -inf + coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); + if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { + continue; + } + } else { + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); + // Don't clamp against nem1 when GQA is enabled + uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1; + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + + coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax; + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + // skip the block if the mask is entirely -inf + coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); + if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { + continue; + } } } } coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); - coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T; uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; - coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); + // F16: bs_k==1 (direct load). F32: bs_k==4 (vec4 / dequantFuncF32). Q4/Q8 family: bs_k==32. Q1_0: bs_k==128. +#if defined(BFLOAT16) + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); +#else + const bool k_use_decode = (bs_k > 1u); + if (k_use_decode) { + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose FADECODEK); + } else { + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose); + } +#endif S = coopMatMulAdd(Qf16, K_T, S); - if (p.logit_softcap != 0.0f) { + if (LOGIT_SOFTCAP) { [[unroll]] for (int k = 0; k < S.length(); ++k) { S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); } } - if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + if (MASK_ENABLE) { S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv); } @@ -218,6 +341,8 @@ void main() { coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); + rowmax += coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(FATTN_KQ_MAX_OFFSET); + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M; // M = max(rowmax, Mold) @@ -238,17 +363,26 @@ void main() { coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); } - coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P); + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P); // compute rowsum by multiplying by matrix of all ones. - coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0); + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0); rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0); rowsum = coopMatMulAdd(P_A, One, rowsum); - coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V; + coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, Bc, HSV_pad, gl_MatrixUseB> V; uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; - coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC); +#if defined(BFLOAT16) + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); +#else + const bool v_use_decode = (bs_v > 1u); + if (v_use_decode) { + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) FADECODEV); + } else { + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad)); + } +#endif L = eM*L + rowsum; @@ -260,11 +394,8 @@ void main() { // resize eM by using smear/reduce coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); - // multiply with fp16 accumulation, then add to O. - coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(0); - PV = coopMatMulAdd(P_A, V, PV); - - O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(PV); + O *= coopmat<O_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(eMdiag); + O = coopMatMulAdd(P_A, V, O); } // If there is split_k, then the split_k resolve shader does the final @@ -272,12 +403,19 @@ void main() { if (p.k_num > 1) { coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O); - uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); - coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - - o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; - coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); - coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + } else { + coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N); + coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N); + coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N); + } return; } @@ -305,7 +443,7 @@ void main() { if (sink > Mr[i]) { ms = exp(Mr[i] - sink); - O[i] *= ms; + O[i] *= O_TYPE(ms); } else { vs = exp(sink - Mr[i]); } @@ -319,15 +457,16 @@ void main() { Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]); } - O = Ldiag*O; + coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O); + + O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(Ldiag)*O_D; #if defined(ACC_TYPE_MAX) - [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } + [[unroll]] for (uint i = 0; i < O_D.length(); ++i) { O_D[i] = clamp(O_D[i], D_TYPE(-ACC_TYPE_MAX), D_TYPE(ACC_TYPE_MAX)); } #endif - uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; - coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O); if (p.gqa_ratio > 1) { coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); } else { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl new file mode 100644 index 00000000000..8704479d960 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl @@ -0,0 +1,131 @@ +// Asymmetric K/V flash attention: aliased SSBO views of bindings 1 (K) and 2 (V) +// covering every supported FA element type, plus an uber dequantize4() that +// switches on FaTypeK / FaTypeV. After spec-constant specialization the driver +// folds away every path except the one matching the K/V type for this pipeline. +// +// Included by flash_attn.comp and flash_attn_cm1.comp. Not included by +// flash_attn_cm2.comp, which has its own buffer_reference-based decode path. +// +// We use macros (rather than per-quant decode functions taking a struct) on +// purpose: the FA shaders don't enable GL_EXT_shader_explicit_arithmetic_types_float16 +// when FLOAT16 isn't defined, which makes float16-containing struct values +// illegal to return from / pass to functions. Macros expand inline where the +// float16 stays in storage and is converted to FLOAT_TYPE at use. + +// F32 is fed as a vec4 "block" (4 floats), matching what dequant_funcs_cm2.glsl +// does for F32 in the cm2 shader. FaBlockBytesK/V == 16 for F32. +layout (binding = 1) readonly buffer K_PACKED_F32 { vec4 data[]; } k_packed_f32; +layout (binding = 2) readonly buffer V_PACKED_F32 { vec4 data[]; } v_packed_f32; + +layout (binding = 1) readonly buffer K_PACKED_Q4_0 { block_q4_0_packed16 data[]; } k_packed_q4_0; +layout (binding = 2) readonly buffer V_PACKED_Q4_0 { block_q4_0_packed16 data[]; } v_packed_q4_0; +layout (binding = 1) readonly buffer K_PACKED_Q4_1 { block_q4_1_packed16 data[]; } k_packed_q4_1; +layout (binding = 2) readonly buffer V_PACKED_Q4_1 { block_q4_1_packed16 data[]; } v_packed_q4_1; +layout (binding = 1) readonly buffer K_PACKED_Q5_0 { block_q5_0_packed16 data[]; } k_packed_q5_0; +layout (binding = 2) readonly buffer V_PACKED_Q5_0 { block_q5_0_packed16 data[]; } v_packed_q5_0; +layout (binding = 1) readonly buffer K_PACKED_Q5_1 { block_q5_1_packed16 data[]; } k_packed_q5_1; +layout (binding = 2) readonly buffer V_PACKED_Q5_1 { block_q5_1_packed16 data[]; } v_packed_q5_1; +layout (binding = 1) readonly buffer K_PACKED_Q8_0 { block_q8_0_packed16 data[]; } k_packed_q8_0; +layout (binding = 2) readonly buffer V_PACKED_Q8_0 { block_q8_0_packed16 data[]; } v_packed_q8_0; + +layout (binding = 1) readonly buffer K_PACKED_BF16 { u16vec4 data[]; } k_packed_bf16; +layout (binding = 2) readonly buffer V_PACKED_BF16 { u16vec4 data[]; } v_packed_bf16; + +// Q4_1 and Q5_1 packed32 views: aliased to the same memory as the packed16 +// views, used by the MMQ K-side hot path for fast 4-uint loads. +layout (binding = 1) readonly buffer K_PACKED_Q4_1_P32 { block_q4_1_packed32 data[]; } k_packed_q4_1_p32; +layout (binding = 1) readonly buffer K_PACKED_Q5_1_P32 { block_q5_1_packed32 data[]; } k_packed_q5_1_p32; + +// Per-quant decode bodies are expanded once for the K view set and once for +// the V view set. The macros take the buffer name as a parameter. +#define FA_DEQUANT4_F32(BUF) \ + return FLOAT_TYPEV4(BUF.data[a_offset + ib]); + +#define FA_DEQUANT4_Q4_0(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles - FLOAT_TYPE(8.0f)); \ +} + +#define FA_DEQUANT4_Q4_1(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * nibbles \ + + FLOAT_TYPE(BUF.data[a_offset + ib].m); \ +} + +#define FA_DEQUANT4_Q5_0(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + uint qh = uint(BUF.data[a_offset + ib].qh[0]) \ + | (uint(BUF.data[a_offset + ib].qh[1]) << 16); \ + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \ + (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \ + * FLOAT_TYPE(16.0f); \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb - FLOAT_TYPE(16.0f)); \ +} + +#define FA_DEQUANT4_Q5_1(BUF) { \ + uint vui_lo = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); \ + uint vui_hi = uint(BUF.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); \ + uint shift = (iqs & 0x10) >> 2; \ + vui_lo >>= shift; \ + vui_hi >>= shift; \ + uint qh = BUF.data[a_offset + ib].qh; \ + FLOAT_TYPEV4 hb = FLOAT_TYPEV4((qh >> iqs) & 1, (qh >> (iqs + 1)) & 1, \ + (qh >> (iqs + 2)) & 1, (qh >> (iqs + 3)) & 1) \ + * FLOAT_TYPE(16.0f); \ + FLOAT_TYPEV4 nibbles = FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, \ + vui_hi & 0xF, (vui_hi >> 8) & 0xF); \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * (nibbles + hb) \ + + FLOAT_TYPE(BUF.data[a_offset + ib].m); \ +} + +#define FA_DEQUANT4_Q8_0(BUF) { \ + const i8vec2 v0 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 ])).xy; \ + const i8vec2 v1 = unpack8(int32_t(BUF.data[a_offset + ib].qs[iqs / 2 + 1])).xy; \ + return FLOAT_TYPE(BUF.data[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); \ +} + +#define FA_DEQUANT4_BF16(BUF) \ + return FLOAT_TYPEV4(bf16_to_fp32(uvec4(BUF.data[(a_offset + ib) / 4]))); + +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + switch (FaTypeK) { + case FA_TYPE_F32: FA_DEQUANT4_F32 (k_packed_f32) + case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(k_packed_q4_0) + case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(k_packed_q4_1) + case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(k_packed_q5_0) + case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(k_packed_q5_1) + case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(k_packed_q8_0) + case FA_TYPE_BF16: FA_DEQUANT4_BF16(k_packed_bf16) + } + } else { + switch (FaTypeV) { + case FA_TYPE_F32: FA_DEQUANT4_F32 (v_packed_f32) + case FA_TYPE_Q4_0: FA_DEQUANT4_Q4_0(v_packed_q4_0) + case FA_TYPE_Q4_1: FA_DEQUANT4_Q4_1(v_packed_q4_1) + case FA_TYPE_Q5_0: FA_DEQUANT4_Q5_0(v_packed_q5_0) + case FA_TYPE_Q5_1: FA_DEQUANT4_Q5_1(v_packed_q5_1) + case FA_TYPE_Q8_0: FA_DEQUANT4_Q8_0(v_packed_q8_0) + case FA_TYPE_BF16: FA_DEQUANT4_BF16(v_packed_bf16) + } + } + return FLOAT_TYPEV4(0); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp new file mode 100644 index 00000000000..0e417708062 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp @@ -0,0 +1,162 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint NUM_SUBGROUPS = 4; +layout (constant_id = 2) const uint Br = 32; +layout (constant_id = 3) const uint Bc = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float16_t data_a[];}; +layout (binding = 0) readonly buffer Av4 {f16vec4 data_av4[];}; +layout (binding = 1) writeonly buffer D {uint data_d[];}; + +layout (push_constant) uniform parameter { + uint nem0; + uint nem1; + uint nem2; + uint nbm1; + uint nbm2; + uint nbm3; + uint nbd1; + uint nbd2; + uint nbd3; +}; + +#define MASK_OPT_ALL_NEG_INF 1 +#define MASK_OPT_ALL_ZERO 2 + +shared float minsh[NUM_SUBGROUPS]; +shared float maxsh[NUM_SUBGROUPS]; + +float FLT_MAX_OVER_2 = uintBitsToFloat(0x7EFFFFFF); + +void loadvec4(inout uint result, const uint i0, const uint i1, const uint i2, const uint i3, const bool need_bounds_check) { + const uint tid = gl_LocalInvocationIndex; + + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc / 4; i += BLOCK_SIZE) { + uint j0 = (i + tid) % (Bc / 4); + uint j1 = (i + tid) / (Bc / 4); + + j0 *= 4; + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + if (!need_bounds_check || j0 + 3 < nem0) { + vec4 f = vec4(data_av4[(j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3) / 4]); + [[unroll]] for (int c = 0; c < 4; ++c) { + min_v = min(min_v, f[c]); + max_v = max(max_v, f[c]); + } + } else { + [[unroll]] for (int c = 0; c < 4; ++c) { + if (j0 + c < nem0) { + float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]); + min_v = min(min_v, f); + max_v = max(max_v, f); + } + } + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } +} + +// For each Br x Bc block of the mask (input) buffer, read all values and check +// if it's all -inf or all zero. Write out a two-bit code indicating which it is +// (or zero for neither). Each workgroup processes 16 tiles and writes out a +// 32-bit result mask. +// +// TODO: This is a lot of work per workgroup, might make sense to split this into +// more workgroups in the future. +void main() { + // Each workgroup handles a row + const uint tid = gl_LocalInvocationIndex; + const uint i0 = gl_WorkGroupID.x; + const uint i1 = gl_WorkGroupID.y; + const uint i2 = gl_WorkGroupID.z % nem2; + const uint i3 = gl_WorkGroupID.z / nem2; + + uint result = 0; + + // Fast path for fully in-bounds blocks where we can do f16vec4 loads + if ((nem0 % Bc) == 0 && (nem1 % Br) == 0 && + ((Br * Bc) % (BLOCK_SIZE * 4)) == 0) { + if ((i0 + 1) * 16 * Bc <= nem0) { + loadvec4(result, i0, i1, i2, i3, false); + } else { + loadvec4(result, i0, i1, i2, i3, true); + } + } else { + [[unroll]] for (uint block_x = 0; block_x < 16; ++block_x) { + float min_v = FLT_MAX_OVER_2; + float max_v = -FLT_MAX_OVER_2; + [[unroll]] for (uint i = 0; i < Br * Bc; i += BLOCK_SIZE) { + if ((Br * Bc % BLOCK_SIZE) != 0 && i + tid >= Br * Bc) { + continue; + } + uint j0 = (i + tid) % Bc; + uint j1 = (i + tid) / Bc; + + j0 += (i0 * 16 + block_x) * Bc; + j1 += i1 * Br; + + if (j0 < nem0 && j1 < nem1) { + float f = float(data_a[j0 + j1 * nbm1 + i2 * nbm2 + i3 * nbm3]); + min_v = min(min_v, f); + max_v = max(max_v, f); + } + } + min_v = subgroupMin(min_v); + max_v = subgroupMax(max_v); + if (gl_SubgroupInvocationID == 0) { + minsh[gl_SubgroupID] = min_v; + maxsh[gl_SubgroupID] = max_v; + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint i = 0; i < NUM_SUBGROUPS; ++i) { + min_v = min(min_v, minsh[i]); + max_v = max(max_v, maxsh[i]); + } + if (max_v <= -FLT_MAX_OVER_2) { + result |= 1 << (2*block_x); + } + if (min_v == 0.0f && max_v == 0.0f) { + result |= 2 << (2*block_x); + } + } + barrier(); + } + } + + if (tid == 0) { + data_d[i0 + i1 * nbd1 + i2 * nbd2 + i3 * nbd3] = result; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl new file mode 100644 index 00000000000..6bf10a7cffd --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl @@ -0,0 +1,203 @@ +// MMQ K-side helpers, asymmetric form. Each function dispatches on FaTypeK and +// reads from the matching aliased K binding declared in flash_attn_dequant.glsl. +// Spec-constant specialization folds the unused paths. + +int32_t get_k_qs(uint ib, uint iqs, uint a_offset) { + switch (FaTypeK) { + case FA_TYPE_Q4_0: { + uint vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed_q4_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + return int32_t(vui & 0x0F0F0F0F); + } + case FA_TYPE_Q4_1: { // uses packed32 alias + uint vui = k_packed_q4_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + return int32_t(vui & 0x0F0F0F0F); + } + case FA_TYPE_Q5_0: { + uint vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 0], + k_packed_q5_0.data[a_offset + ib].qs[(iqs & 0xF) / 2 + 1])); + uint qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0], + k_packed_q5_0.data[a_offset + ib].qh[1])); + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q5_1: { // qs via packed32, qh via packed16 + uint vui = k_packed_q5_1_p32.data[a_offset + ib].qs[(iqs & 0xF) / 4]; + uint qh = k_packed_q5_1.data[a_offset + ib].qh; + uint shift = (iqs & 0x10) >> 2; + vui >>= shift; + uint qh_bits = (qh >> iqs) & 0xF; + return int32_t(vui & 0x0F0F0F0F) | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q8_0: { + return pack32(i16vec2(k_packed_q8_0.data[a_offset + ib].qs[iqs / 2], + k_packed_q8_0.data[a_offset + ib].qs[iqs / 2 + 1])); + } + default: return 0; + } +} + +// Per-block scale/min, packed as (d, m). Single-scale types (Q4_0, Q5_0, Q8_0) +// return (d, 0) so call sites always see the same shape. +FLOAT_TYPEV2 get_k_scale(uint ib, uint a_offset) { + switch (FaTypeK) { + case FA_TYPE_Q4_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + ib].d), 0.0); + case FA_TYPE_Q4_1: return FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + ib].dm); + case FA_TYPE_Q5_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + ib].d), 0.0); + case FA_TYPE_Q5_1: return FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + ib].dm); + case FA_TYPE_Q8_0: return FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + ib].d), 0.0); + default: return FLOAT_TYPEV2(0); + } +} + +void k_block_to_shmem(const uint buf_ib, const uint global_ib, const uint iqs, const uint a_offset) { + // kblocksh[].qs is int32_t for the unified MMQ struct; uint sources need + // explicit casts. The bit pattern is what we care about here -- the actual + // signed/unsigned interpretation happens downstream in the dot product. + switch (FaTypeK) { + case FA_TYPE_Q4_0: { + kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q4_0.data[a_offset + global_ib].qs[iqs * 2 + 1]))); + break; + } + case FA_TYPE_Q4_1: { + kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q4_1_p32.data[a_offset + global_ib].qs[iqs]); + break; + } + case FA_TYPE_Q5_0: { + kblocksh[buf_ib].qs[iqs] = int32_t(pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q5_0.data[a_offset + global_ib].qs[iqs * 2 + 1]))); + if (iqs == 0) { + kblocksh[buf_ib].qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + global_ib].qh[0], + k_packed_q5_0.data[a_offset + global_ib].qh[1])); + } + break; + } + case FA_TYPE_Q5_1: { + kblocksh[buf_ib].qs[iqs] = int32_t(k_packed_q5_1_p32.data[a_offset + global_ib].qs[iqs]); + if (iqs == 0) { + kblocksh[buf_ib].qh = k_packed_q5_1.data[a_offset + global_ib].qh; + } + break; + } + case FA_TYPE_Q8_0: { + kblocksh[buf_ib].qs[iqs] = pack32(i16vec2(k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2], + k_packed_q8_0.data[a_offset + global_ib].qs[iqs * 2 + 1])); + break; + } + } + + if (iqs == 0) { + // Q4_0/Q5_0/Q8_0 store dm.x = d; Q4_1/Q5_1 store dm = (d, m) pair. + switch (FaTypeK) { + case FA_TYPE_Q4_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q4_0.data[a_offset + global_ib].d), 0.0); break; + case FA_TYPE_Q4_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q4_1_p32.data[a_offset + global_ib].dm); break; + case FA_TYPE_Q5_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q5_0.data[a_offset + global_ib].d), 0.0); break; + case FA_TYPE_Q5_1: kblocksh[buf_ib].dm = FLOAT_TYPEV2(k_packed_q5_1_p32.data[a_offset + global_ib].dm); break; + case FA_TYPE_Q8_0: kblocksh[buf_ib].dm = FLOAT_TYPEV2(FLOAT_TYPE(k_packed_q8_0.data[a_offset + global_ib].d), 0.0); break; + } + } +} + +// d_per_step==8 hot path: read one full 32-element block worth of nibble-packed +// int32 quants. Equivalent to 8 calls to get_k_qs(ib, d*4, a_offset) but reads +// qh (Q5_*) and runs pack32 (Q4_0/Q5_0) once per block instead of per nibble +// quad. iqs is always 0 in this path (hsk4 % 8 == 0 implies block-aligned). +// Q8_0 takes the generic get_k_qs path because its qs layout (i8 pairs) doesn't +// share this nibble shape. +// +// Returned via a struct so the caller's k_quants array (sized from spec +// constants) doesn't need to match a fixed[8] out-parameter type. +struct fa_k_qs_block8 { + int32_t qs[8]; +}; + +fa_k_qs_block8 get_k_qs_block8(uint ib, uint a_offset) { + fa_k_qs_block8 r; + uint qh = 0; + if (FaTypeK == FA_TYPE_Q5_0) { + qh = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qh[0], + k_packed_q5_0.data[a_offset + ib].qh[1])); + } else if (FaTypeK == FA_TYPE_Q5_1) { + qh = k_packed_q5_1.data[a_offset + ib].qh; + } + const bool has_qh = (FaTypeK == FA_TYPE_Q5_0) || (FaTypeK == FA_TYPE_Q5_1); + [[unroll]] for (uint32_t d = 0; d < 4; d++) { + uint vui = 0; + switch (FaTypeK) { + case FA_TYPE_Q4_0: { // packed16 + vui = pack32(u16vec2(k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 0], + k_packed_q4_0.data[a_offset + ib].qs[d * 2 + 1])); + break; + } + case FA_TYPE_Q4_1: { // packed32 alias + vui = k_packed_q4_1_p32.data[a_offset + ib].qs[d]; + break; + } + case FA_TYPE_Q5_0: { // packed16 + vui = pack32(u16vec2(k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 0], + k_packed_q5_0.data[a_offset + ib].qs[d * 2 + 1])); + break; + } + case FA_TYPE_Q5_1: { // packed32 alias + vui = k_packed_q5_1_p32.data[a_offset + ib].qs[d]; + break; + } + } + r.qs[d ] = int32_t( vui & 0x0F0F0F0F); + r.qs[d + 4] = int32_t((vui >> 4) & 0x0F0F0F0F); + if (has_qh) { + uint qh_lo = (qh >> (d * 4)) & 0xFu; + uint qh_hi = (qh >> (d * 4 + 16)) & 0xFu; + r.qs[d ] |= int32_t((qh_lo * 0x02040810u) & 0x10101010u); + r.qs[d + 4] |= int32_t((qh_hi * 0x02040810u) & 0x10101010u); + } + } + return r; +} + +int32_t get_k_qs_shmem(const uint buf_ib, const uint pos) { + switch (FaTypeK) { + case FA_TYPE_Q4_0: + case FA_TYPE_Q4_1: { + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4u : 0u; + return int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu); + } + case FA_TYPE_Q5_0: + case FA_TYPE_Q5_1: { + uint sub = pos % 4; + uint shift = ((pos % 8) >= 4) ? 4u : 0u; + int32_t result = int32_t((uint(kblocksh[buf_ib].qs[sub]) >> shift) & 0x0F0F0F0Fu); + uint qh_bits = (kblocksh[buf_ib].qh >> (pos * 4u)) & 0xFu; + return result | int32_t((qh_bits * 0x02040810u) & 0x10101010u); + } + case FA_TYPE_Q8_0: { + return kblocksh[buf_ib].qs[pos]; + } + default: return 0; + } +} + +ACC_TYPE k_dot_correction(const uint qib, const ACC_TYPEV2 k_dm) { + switch (FaTypeK) { + case FA_TYPE_Q4_0: return -ACC_TYPE(8.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; + case FA_TYPE_Q5_0: return -ACC_TYPE(16.0) * ACC_TYPE(Qf[qib].ds.y) * k_dm.x; + case FA_TYPE_Q4_1: + case FA_TYPE_Q5_1: return ACC_TYPE(Qf[qib].ds.y) * k_dm.y; + default: return ACC_TYPE(0.0); + } +} + +void k_block_to_shmem_zero(const uint buf_ib, const uint iqs) { + kblocksh[buf_ib].qs[iqs] = 0; + if (iqs == 0) { + kblocksh[buf_ib].dm = FLOAT_TYPEV2(0.0f); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 4eaddd31a8f..68917fc0bb0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -12,7 +12,8 @@ layout (binding = 2) writeonly buffer D {float data_d[];}; layout (push_constant) uniform parameter { uint D; - uint N; + uint ne1; + uint ne2; uint ne3; uint k_num; uint sinks; @@ -24,15 +25,15 @@ void main() { // Each workgroup handles a row const uint n = gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; - const uint iq3 = gl_WorkGroupID.z; + const uint i2 = gl_WorkGroupID.z % p.ne2; + const uint i3 = gl_WorkGroupID.z / p.ne2; uint D = p.D; - uint N = p.N; uint k_num = p.k_num; - uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; - uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; - uint lm_stride = N * 2; + uint l_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + n; + uint m_offset = D * p.ne1 * p.ne2 * p.ne3 * k_num + p.ne1 * 2 * (0/*split_k_index*/ + p.k_num * (i2 + p.ne2 * i3)) + p.ne1 + n; + uint lm_stride = p.ne1 * 2; // Compute the max m value for the row float m_max = -1.0/0.0; @@ -99,7 +100,7 @@ void main() { if (d < D) { float O = 0.0; [[unroll]] for (uint k = 0; k < k_num; ++k) { - uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; + uint o_offset = D * p.ne1 * (k + p.k_num * (i2 + p.ne2 * i3)) + D * n + d; float m = data_a[m_offset + k * lm_stride]; O += exp(m - m_max) * data_a[o_offset]; } @@ -115,6 +116,6 @@ void main() { const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); O = clamp(O, -FLT_MAX, FLT_MAX); - data_d[iq3 * D * N + D * n + d] = O; + data_d[(i3 * p.ne2 + i2) * p.ne1 * D + D * n + d] = O; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp new file mode 100644 index 00000000000..a2069964adb --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp @@ -0,0 +1,115 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#ifndef FWHT_SHMEM +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_shuffle : enable +#endif + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(constant_id = 1) const uint N = 128; + +layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in; + +layout(push_constant) uniform parameter +{ + uint n_rows; + uint src_offset; + uint dst_offset; + float scale; +}; + +layout(binding = 0, std430) readonly buffer A { float data_a[]; }; +layout(binding = 1, std430) writeonly buffer D { float data_d[]; }; + +const uint EL_W = N / BLOCK_SIZE; + +#ifdef FWHT_SHMEM +shared float shmem[4 * N]; +#endif + +void main() { +#ifdef FWHT_SHMEM + const uint tid = gl_LocalInvocationID.x; + const uint shmem_base = gl_LocalInvocationID.y * N; + const uint row_id = gl_LocalInvocationID.y; +#else + const uint tid = gl_SubgroupInvocationID; + const uint row_id = gl_SubgroupID; +#endif + + for (uint base_row = gl_WorkGroupID.x * gl_WorkGroupSize.y; + base_row < n_rows; + base_row += gl_NumWorkGroups.x * gl_WorkGroupSize.y) { + const uint row = base_row + row_id; + const uint row_offset = row * N; + +#ifndef FWHT_SHMEM + if (row >= n_rows) { + continue; + } +#endif + + float reg[EL_W]; + + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + reg[i] = row < n_rows ? data_a[src_offset + row_offset + i * BLOCK_SIZE + tid] * scale : 0.0; + } + +#ifdef FWHT_SHMEM + [[unroll]] + for (uint h = 1; h < BLOCK_SIZE; h <<= 1) { + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + shmem[shmem_base + i * BLOCK_SIZE + tid] = reg[i]; + } + barrier(); + [[unroll]] + for (uint j = 0; j < EL_W; ++j) { + const float val = reg[j]; + const float other = shmem[shmem_base + j * BLOCK_SIZE + (tid ^ h)]; + reg[j] = (tid & h) == 0 ? val + other : other - val; + } + barrier(); + } +#else + [[unroll]] + for (uint h = 1; h < BLOCK_SIZE; h <<= 1) { + [[unroll]] + for (uint j = 0; j < EL_W; ++j) { + const float val = reg[j]; + const float val2 = subgroupShuffleXor(val, h); + reg[j] = (tid & h) == 0 ? val + val2 : val2 - val; + } + } +#endif + + [[unroll]] + for (uint h = BLOCK_SIZE; h < N; h <<= 1) { + const uint step = h / BLOCK_SIZE; + [[unroll]] + for (uint j = 0; j < EL_W; j += 2 * step) { + [[unroll]] + for (uint k = 0; k < step; ++k) { + const float x = reg[j + k]; + const float y = reg[j + k + step]; + reg[j + k] = x + y; + reg[j + k + step] = x - y; + } + } + } + +#ifdef FWHT_SHMEM + if (row < n_rows) { +#endif + [[unroll]] + for (uint i = 0; i < EL_W; ++i) { + data_d[dst_offset + row_offset + i * BLOCK_SIZE + tid] = reg[i]; + } +#ifdef FWHT_SHMEM + } + barrier(); +#endif + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp new file mode 100644 index 00000000000..0e384330b9b --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -0,0 +1,189 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable +#if USE_SUBGROUP_CLUSTERED +#extension GL_KHR_shader_subgroup_clustered : enable +#endif +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif + +// Caller guarantees valid spec constants: S_V % COLS_PER_WG == 0 and S_V % LANES_PER_COLUMN == 0, +// so no bounds checking is needed. +layout(constant_id = 0) const uint S_V = 128; +layout(constant_id = 1) const uint KDA = 0; +layout(constant_id = 2) const uint SUBGROUP_SIZE = 32; +layout(constant_id = 3) const uint LANES_PER_COLUMN = 32; + +const uint COLS_PER_WG = SUBGROUP_SIZE / LANES_PER_COLUMN; +const uint ROWS_PER_LANE = S_V / LANES_PER_COLUMN; + +layout(local_size_x_id = 2, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint H; + uint n_tokens; + uint n_seqs; + uint s_off; + uint sq1, sq2, sq3; + uint sv1, sv2, sv3; + uint sb1, sb2, sb3; + uint neq1, rq3; + float scale; + uint K; +}; + +layout(binding = 0) readonly buffer QBuf { FLOAT_TYPE data_q[]; }; +layout(binding = 1) readonly buffer KBuf { FLOAT_TYPE data_k[]; }; +layout(binding = 2) readonly buffer VBuf { FLOAT_TYPE data_v[]; }; +layout(binding = 3) readonly buffer GBuf { FLOAT_TYPE data_g[]; }; +layout(binding = 4) readonly buffer BetaBuf { FLOAT_TYPE data_beta[]; }; +layout(binding = 5) readonly buffer StateBuf { FLOAT_TYPE data_state[]; }; +layout(binding = 6) buffer DstBuf { FLOAT_TYPE data_dst[]; }; + +#if !USE_SUBGROUP_ADD && !USE_SUBGROUP_CLUSTERED +shared FLOAT_TYPE temp[SUBGROUP_SIZE]; + +// This does a reduction across groups of LANES_PER_COLUMN +FLOAT_TYPE reduce_add_shmem(FLOAT_TYPE partial) { + const uint lane = gl_SubgroupInvocationID; + temp[lane] = partial; + barrier(); + [[unroll]] for (uint s = LANES_PER_COLUMN / 2u; s > 0; s >>= 1u) { + FLOAT_TYPE other = temp[lane ^ s]; + barrier(); + temp[lane] += other; + barrier(); + } + const FLOAT_TYPE result = temp[lane]; + barrier(); + return result; +} +#endif + +// clusterSize for subgroupClusteredAdd must be a compile-time constant; branch on spec constant +FLOAT_TYPE reduce_partial(FLOAT_TYPE partial) { + switch (LANES_PER_COLUMN) { + case 1u: + return partial; +#if USE_SUBGROUP_CLUSTERED + // Workaround for GLSL requiring a literal constant for the cluster size. + // The branches should all fold away. + case 2u: + return subgroupClusteredAdd(partial, 2u); + case 4u: + return subgroupClusteredAdd(partial, 4u); + case 8u: + return subgroupClusteredAdd(partial, 8u); + case 16u: + return subgroupClusteredAdd(partial, 16u); + case 32u: + return subgroupClusteredAdd(partial, 32u); + case 64u: + return subgroupClusteredAdd(partial, 64u); +#endif + default: +#if USE_SUBGROUP_ADD + return subgroupAdd(partial); +#else + return reduce_add_shmem(partial); +#endif + } +} + +void main() { + const uint head_id = gl_WorkGroupID.x; + const uint seq_id = gl_WorkGroupID.y; + const uint lane = gl_SubgroupInvocationID % LANES_PER_COLUMN; + const uint col = gl_WorkGroupID.z * COLS_PER_WG + (gl_SubgroupInvocationID / LANES_PER_COLUMN); + + const uint iq1 = head_id % neq1; + const uint iq3 = seq_id / rq3; + + const uint state_size = S_V * S_V; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + const uint state_in_base = (seq_id * H + head_id) * state_size; + // output state layout per slot: same per-(seq,head) offset as the single-slot case. + const uint state_out_base = (seq_id * H + head_id) * state_size; + const uint state_size_per_snap = state_size * H * n_seqs; + + FLOAT_TYPE s_shard[ROWS_PER_LANE]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[r] = FLOAT_TYPE(data_state[state_in_base + col * S_V + r * LANES_PER_COLUMN + lane]); + } + + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + // When n_tokens < K, only slots 0..n_tokens-1 are written; older slots are caller-owned. + + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; + + for (uint t = 0; t < n_tokens; t++) { + const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1; + const uint k_off = q_off; + const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1; + const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1; + const FLOAT_TYPE beta_val = FLOAT_TYPE(data_beta[gb_off]); + + FLOAT_TYPE k_reg[ROWS_PER_LANE]; + FLOAT_TYPE q_reg[ROWS_PER_LANE]; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + k_reg[r] = FLOAT_TYPE(data_k[k_off + i]); + q_reg[r] = FLOAT_TYPE(data_q[q_off + i]); + } + + FLOAT_TYPE g_exp[ROWS_PER_LANE]; + if (KDA == 0) { + const FLOAT_TYPE g_val = exp(FLOAT_TYPE(data_g[gb_off])); + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + g_exp[r] = g_val; + } + } else { + const uint g_base = gb_off * S_V; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + const uint i = r * LANES_PER_COLUMN + lane; + g_exp[r] = exp(FLOAT_TYPE(data_g[g_base + i])); + } + } + + const FLOAT_TYPE v_val = FLOAT_TYPE(data_v[v_off + col]); + + FLOAT_TYPE kv_shard = 0.0; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + kv_shard += g_exp[r] * s_shard[r] * k_reg[r]; + } + FLOAT_TYPE kv_col = reduce_partial(kv_shard); + + FLOAT_TYPE delta_col = (v_val - kv_col) * beta_val; + + FLOAT_TYPE attn_partial = 0.0; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + s_shard[r] = g_exp[r] * s_shard[r] + k_reg[r] * delta_col; + attn_partial += s_shard[r] * q_reg[r]; + } + FLOAT_TYPE attn_col = reduce_partial(attn_partial); + + if (lane == 0) { + data_dst[attn_off + col] = attn_col * scale; + } + + attn_off += S_V * H; + + if (K > 1u) { + const int target_slot = int(n_tokens) - 1 - int(t); + if (target_slot >= 0 && target_slot < int(K)) { + const uint slot_base = s_off + uint(target_slot) * state_size_per_snap + state_out_base; + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[slot_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } + } + } + } + + if (K == 1u) { + [[unroll]] for (uint r = 0; r < ROWS_PER_LANE; r++) { + data_dst[s_off + state_out_base + col * S_V + r * LANES_PER_COLUMN + lane] = s_shard[r]; + } + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl index ba7909c4d38..dc657f3c708 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -1,7 +1,6 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.glsl" #include "utils.glsl" #if RMS_NORM_ROPE_FUSION #include "rope_params.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl index 2168989340b..d8fdd8f7b5e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -1,6 +1,5 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.glsl" layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; @@ -16,4 +15,14 @@ layout (push_constant) uniform parameter uint mode; float alpha; float limit; + uint nb01; + uint nb02; + uint nb03; + uint ne01; + uint ne02; + uint nb11; + uint nb12; + uint nb13; + uint ne11; + uint ne12; } p; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl index 85cf65a9eca..359461306a5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl @@ -8,22 +8,32 @@ void main() { const uint row = i / p.ne20; const uint col = i - row * p.ne20; + const uint i3 = row / (p.ne01 * p.ne02); + const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01; + const uint i1 = row % p.ne01; + const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col; + + const uint dst_i3 = row / (p.ne11 * p.ne12); + const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11; + const uint dst_i1 = row % p.ne11; + const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col; + if (p.mode == 0) { // Default const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); } else if (p.mode == 1) { // Swapped const uint offset = p.ne00 / 2; - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); } else { // Split - const uint idx = row * p.ne00 + col; + const uint idx = src_idx; - data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp index db14f5a3cf3..f4130d223b1 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -3,7 +3,6 @@ #extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_control_flow_attributes : require -#include "rte.glsl" #include "types.glsl" layout (push_constant) uniform parameter @@ -14,7 +13,7 @@ layout (push_constant) uniform parameter uint IW; uint IH; uint OW; uint OH; uint KW; uint KH; - uint pelements; + uint OH_batch; uint CHW; int s0; int s1; int p0; int p1; @@ -35,82 +34,105 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (buffer_reference) buffer D_ptr {D_TYPE d;}; #endif -void im2col(const uint y, const uint z) { - const uint gidx = gl_GlobalInvocationID.x; +void im2col(const uint ow, const uint z_idx) { + const uint oh = z_idx % p.OH; + const uint batch_idx = z_idx / p.OH; - const uint oh = y; - const uint batch = z / p.IC; - const uint ic = z % p.IC; + const uint gidx = gl_LocalInvocationID.x; + const uint src_batch = batch_idx * p.batch_offset; + const BDA_OFFSET_T dst_row = ((BDA_OFFSET_T(batch_idx) * p.OH + oh) * p.OW + ow) * p.CHW; - const uint src_base = ic * p.offset_delta + batch * p.batch_offset; - const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); - const int oh_s1 = int(oh) * p.s1; - const uint ksize = p.OW * p.KH; + const uint KHKW = p.KH * p.KW; - const uint base_linear_idx = gidx * NUM_ITER; + // Precompute base input coordinates + const int base_iw = int(ow * p.s0) - p.p0; + const int base_ih = int(oh * p.s1) - p.p1; - uint current_kx = base_linear_idx / ksize; - const uint rem = base_linear_idx - (current_kx * ksize); - uint current_ky = rem / p.OW; - uint current_ix = rem % p.OW; + // Precompute step deltas + const uint delta_ic = BLOCK_SIZE / KHKW; + const uint delta_rem = BLOCK_SIZE % KHKW; - A_TYPE values[NUM_ITER]; - BDA_OFFSET_T offset_dst[NUM_ITER]; - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { - values[idx] = A_TYPE(0); - } + const uint delta_ky = delta_rem / p.KW; + const uint delta_kx = delta_rem % p.KW; - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + const uint delta_ic_offset = delta_ic * p.offset_delta; - const uint linear_idx = base_linear_idx + idx; + // If using BDA mode, precompute the base pointer and step size +#if BDA + const BDA_STORAGE_T base_dst_addr = p.dst_addr + D_SIZE * dst_row; + const uint bda_step = D_SIZE * BLOCK_SIZE; +#endif - if (linear_idx >= p.pelements) { - continue; - } + uint wg_x = gl_WorkGroupID.x; + do { + const uint wg_offset = wg_x * 512; - const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; - const uint iih = oh_s1 + current_ky * p.d1 - p.p1; + uint chw_idx = wg_offset + gidx; - offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx; + uint ic = chw_idx / KHKW; + uint rem = chw_idx % KHKW; - if ((iih < p.IH) && (iiw < p.IW)) { - values[idx] = data_a[src_base + iih * p.IW + iiw]; - } + uint ky = rem / p.KW; + uint kx = rem % p.KW; - if (++current_ix == p.OW) { - current_ix = 0; - if (++current_ky == p.KH) { - current_ky = 0; - current_kx++; - } - } - } + uint ic_offset = src_batch + ic * p.offset_delta; - [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + // Initialize running pointer/index for the destination buffer +#if BDA + BDA_STORAGE_T current_dst_addr = base_dst_addr + D_SIZE * chw_idx; +#else + uint current_dst_idx = dst_row + chw_idx; +#endif - const uint linear_idx = base_linear_idx + idx; + [[unroll]] for (uint i = 0; i < NUM_ITER; ++i) { + if (chw_idx >= p.CHW) { + return; + } - if (linear_idx >= p.pelements) { - continue; - } + const int iiw = base_iw + int(kx * p.d0); + const int iih = base_ih + int(ky * p.d1); + + A_TYPE val = A_TYPE(0); + if (uint(iih) < p.IH && uint(iiw) < p.IW) { + val = data_a[ic_offset + uint(iih) * p.IW + uint(iiw)]; + } #if BDA - D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]); - dst_addr.d = D_TYPE(values[idx]); + D_ptr(current_dst_addr).d = D_TYPE(val); + current_dst_addr += bda_step; #else - data_d[offset_dst[idx]] = D_TYPE(values[idx]); + data_d[current_dst_idx] = D_TYPE(val); + current_dst_idx += BLOCK_SIZE; #endif - } + + chw_idx += BLOCK_SIZE; + ic_offset += delta_ic_offset; + kx += delta_kx; + ky += delta_ky; + + // Handle X axis wrap + uint kx_wrap = uint(kx >= p.KW); + kx -= kx_wrap * p.KW; + ky += kx_wrap; + + // Handle Y axis wrap + uint ky_wrap = uint(ky >= p.KH); + ky -= ky_wrap * p.KH; + ic_offset += ky_wrap * p.offset_delta; + } + + wg_x += gl_NumWorkGroups.x; + } while (wg_x * 512 < p.CHW); } void main() { - uint y = gl_GlobalInvocationID.y; - while (y < p.OH) { + uint ow = gl_GlobalInvocationID.y; + while (ow < p.OW) { uint z = gl_GlobalInvocationID.z; - while (z < p.batch_IC) { - im2col(y, z); + while (z < p.OH_batch) { + im2col(ow, z); z += gl_NumWorkGroups.z; } - y += gl_NumWorkGroups.y; + ow += gl_NumWorkGroups.y; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp index 4bf8b4ca046..93f61fd8543 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -4,7 +4,6 @@ #extension GL_EXT_control_flow_attributes : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require -#include "rte.glsl" #include "types.glsl" layout (push_constant) uniform parameter diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index 83ef2f87958..f9af46744df 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.glsl" +#include "generic_unary_head.glsl" #include "types.glsl" #extension GL_EXT_control_flow_attributes : enable @@ -8,19 +8,22 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared FLOAT_TYPE sum[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint i3 = row / (p.ne11 * p.ne12); + const uint i3_offset = i3 * p.ne12 * p.ne11; + const uint i2 = (row - i3_offset) / p.ne11; + const uint i2_offset = i2 * p.ne11; + const uint i1 = row - i3_offset - i2_offset; + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]); sum[tid] += xi * xi; } @@ -33,9 +36,9 @@ void main() { barrier(); } - const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); + const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1)); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0])); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp index ff2812d3d75..3cda6a63c45 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/log.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 2271be4021b..5a9d0e778fd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -10,12 +10,38 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) #define K_PER_ITER 8 #else -#define K_PER_ITER 2 +#define K_PER_ITER 4 #endif uint a_offset, b_offset, d_offset, y_offset; +vec4 load_b(const uint j, const uint iybs, const uint iqs, const bool lastiter, out bool OOB_y, out bool OOB_z, out bool OOB_w) { + // Check if the latter elements are OOB, and don't fetch B or accumulate it. + OOB_y = lastiter && (iybs + iqs + y_offset >= p.ncols); + OOB_z = lastiter && (iybs + iqs + y_offset*2 >= p.ncols); + OOB_w = lastiter && (iybs + iqs + y_offset*3 >= p.ncols); + + if (!OOB_w) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*3])); + } else if (!OOB_z) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset*2]), + 0); + } else if (!OOB_y) { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]), + 0, 0); + } else { + return vec4(FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]), + 0, 0, 0); + } +} + void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { @@ -25,6 +51,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const #if K_PER_ITER == 8 #if QUANT_R == 2 + // Note that we end up fetching bogus elements here, but its fine as they'll be + // within an accessible block. const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); @@ -34,18 +62,11 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); #endif #else - // Check if the second of the pair of elements is OOB, and don't fetch B or - // accumulate it. We still fetch a pair of elements for A, which is fine for - // quantized formats since they'll be within the same block. We should - // probably skip fetching the second element for F16/F32, but as of now we - // still do. - const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); - - FLOAT_TYPE b0 = 0, b1 = 0; - b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); - if (!OOB) { - b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); - } + bool OOB_y; + bool OOB_z; + bool OOB_w; + + const vec4 b = load_b(j, iybs, iqs, lastiter, OOB_y, OOB_z, OOB_w); #endif uint ibi = first_row*p.ncols; [[unroll]] for (uint n = 0; n < num_rows; ++n) { @@ -71,22 +92,60 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const temp[j][n] += rowtmp; #else - const vec2 v = dequantize(ib, iqs, a_offset); - - // matrix multiplication - temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); - if (!OOB) { - temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + if (!OOB_w) { + const vec4 v = dequantize4(ib, iqs, a_offset); + temp[j][n] += dot(v, b); + } else if (!OOB_z) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const FLOAT_TYPE v1 = dequantize1(ib + 2/QUANT_R, iqs, a_offset); + const vec3 v = vec3(v0.x, v0.y, v1); + const vec3 b0 = vec3(b.x, b.y, b.z); + temp[j][n] += dot(v, b0); + } else if (!OOB_y) { + const vec2 v0 = dequantize(ib, iqs, a_offset); + const vec2 b0 = vec2(b.x, b.y); + temp[j][n] += dot(v0, b0); + } else { + const FLOAT_TYPE v = dequantize1(ib, iqs, a_offset); + temp[j][n] = fma(v, b.x, temp[j][n]); } #endif } } } +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) +void iter_aligned_nonquant(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = 0; // quant index + const uint iybs = col; // y block start index + + const vec4 b = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]; + + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + + const vec4 v = dequantize4_2aligned(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] += dot(v, b); + } + } +} +#endif + void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { const uint tid = gl_LocalInvocationID.x; get_offsets(a_offset, b_offset, d_offset); + const bool is_aligned_nonquant = + p.batch_stride_b % 4 == 0 && b_offset % 4 == 0 && + p.ncols % 4 == 0 && BLOCK_SIZE % 4 == 0 && + K_PER_ITER == 4; y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; @@ -105,17 +164,26 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { int unroll_count = 4; uint unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 + uint i = 0; + +#if K_PER_ITER == 4 // If the K dimension is odd, we need lastiter==true on the last iteration // so OOB is computed correctly. Skip some unrolling to make that happen. - if ((p.ncols & 1) != 0 && + if ((p.ncols & 3) != 0 && unrolled_iters == num_iters && unrolled_iters > 0) { unrolled_iters -= unroll_count; } + if (is_aligned_nonquant) { + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + } else { #endif - - uint i = 0; while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -123,18 +191,30 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } +#if K_PER_ITER == 4 + } +#endif unroll_count = 2; unrolled_iters = num_iters & ~(unroll_count - 1); -#if K_PER_ITER == 2 - if ((p.ncols & 1) != 0 && +#if K_PER_ITER == 4 + if ((p.ncols & 3) != 0 && unrolled_iters == num_iters && unrolled_iters > 0) { unrolled_iters -= unroll_count; } -#endif + if (is_aligned_nonquant) { + while (i < unrolled_iters && is_aligned_nonquant) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + } else { +#endif while (i < unrolled_iters) { // Manually partially unroll the loop [[unroll]] for (uint k = 0; k < unroll_count; ++k) { @@ -142,10 +222,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { i++; } } +#if K_PER_ITER == 4 + } +#endif + +#if K_PER_ITER == 4 + if (is_aligned_nonquant) { + while (i < num_iters) { + iter_aligned_nonquant(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } else { +#endif while (i < num_iters) { iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); i++; } +#if K_PER_ITER == 4 + } +#endif reduce_result(temp, d_offset, first_row, num_rows, tid); } @@ -164,6 +259,6 @@ void main() { if (first_row >= p.stride_d) { return; } - compute_outputs(first_row, p.stride_d - first_row); + compute_outputs(first_row, min(NUM_ROWS, p.stride_d - first_row)); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index dfb78659362..4aeda68c7f2 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -29,7 +29,10 @@ layout (push_constant) uniform parameter #ifdef MUL_MAT_ID uint nei0; uint ne11; + uint expert_i1; + uint nbi1; #else + uint base_work_group_y; uint ne02; uint ne12; uint broadcast2; @@ -43,9 +46,9 @@ uint expert_id; void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.y; + const uint expert_i0 = gl_WorkGroupID.y; #else - const uint batch_idx = gl_GlobalInvocationID.y; + const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y; #endif #ifndef MUL_MAT_ID @@ -60,7 +63,7 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { batch_idx_a = i03 * p.ne02 + i02; } #else - expert_id = data_ids[expert_idx]; + expert_id = data_ids[expert_i0 + p.expert_i1 * p.nbi1]; #endif a_offset = @@ -71,13 +74,13 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #endif b_offset = #ifdef MUL_MAT_ID - (expert_idx % p.ne11) * p.stride_b; + (expert_i0 % p.ne11) * p.stride_b + p.expert_i1 * p.batch_stride_b; #else batch_idx * p.batch_stride_b; #endif d_offset = #ifdef MUL_MAT_ID - expert_idx * p.stride_d; + expert_i0 * p.stride_d + p.expert_i1 * p.batch_stride_d; #else batch_idx * p.batch_stride_d; #endif @@ -103,12 +106,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { @@ -158,12 +161,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs temp[j][n] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + temp[j][n] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { @@ -203,12 +206,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs tmpsh[j][n][0] += FLOAT_TYPE(data_fuse0[expert_id*p.stride_d + first_row + n]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE0) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse0[expert_i0]); } if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_SCALE1) != 0) { - const uint expert_idx = gl_GlobalInvocationID.y; - tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_idx]); + const uint expert_i0 = gl_GlobalInvocationID.y; + tmpsh[j][n][0] *= FLOAT_TYPE(data_fuse1[expert_i0]); } #else if ((p.fusion_flags & MAT_VEC_FUSION_FLAGS_BIAS0) != 0) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl index 337dbd796ad..e8d053cdd43 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl @@ -6,8 +6,8 @@ #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; -#if defined(A_TYPE_VEC4) -layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +#if defined(A_TYPEV4) +layout (binding = 0) readonly buffer AV4 {A_TYPEV4 data_a_v4[];}; #endif #if defined(A_TYPE_PACKED16) layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; @@ -17,11 +17,11 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32 #endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; -#ifdef B_TYPE_VEC2 -layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; +#ifdef B_TYPEV2 +layout (binding = 1) readonly buffer BV2 {B_TYPEV2 data_b_v2[];}; #endif -#ifdef B_TYPE_VEC4 -layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; +#ifdef B_TYPEV4 +layout (binding = 1) readonly buffer BV4 {B_TYPEV4 data_b_v4[];}; #endif layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp index 619de054cb8..975cec8013f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -41,7 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); - const FLOAT_TYPE_VEC2 dm = vec2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = vec2(data_a[ib0 + i].dm); [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp index 6af5a81587d..93fbacc6282 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp index 3695b47b98d..54d7e1bcdca 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -14,7 +14,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, [[unroll]] for (uint n = 0; n < num_rows; ++n) { const uint ib0 = a_offset + (first_row+n)*num_blocks_per_row; - const FLOAT_TYPE_VEC2 dm = FLOAT_TYPE_VEC2(data_a[ib0 + i].dm); + const FLOAT_TYPEV2 dm = FLOAT_TYPEV2(data_a[ib0 + i].dm); const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 6fe3e2dc043..fd84c3c91d8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -4,6 +4,7 @@ #extension GL_EXT_integer_dot_product : require #define MMQ +#define NEEDS_IQ1S_GRID_GPU #define B_TYPE block_q8_1_x4 #include "mul_mat_vec_base.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index 6ddbed309d7..73cf9c79955 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -11,8 +11,8 @@ FLOAT_TYPE get_dm(uint ib) { #endif #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) -FLOAT_TYPE_VEC2 get_dm(uint ib) { - return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +FLOAT_TYPEV2 get_dm(uint ib) { + return FLOAT_TYPEV2(data_a_packed32[ib].dm); } #endif @@ -23,9 +23,9 @@ FLOAT_TYPE get_dm(uint ib) { #endif #if defined(DATA_A_Q2_K) -FLOAT_TYPE_VEC2 get_dm(uint ib) { +FLOAT_TYPEV2 get_dm(uint ib) { const uint ib_k = ib / 8; - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + return FLOAT_TYPEV2(data_a_packed32[ib_k].dm); } #endif @@ -212,28 +212,40 @@ i32vec4 repack4(uint ib, uint iqs) { const uint qs_shift = ((iqs_k % 32) / 8) * 2; const uint hm_shift = iqs_k / 8; + const uvec4 qs = uvec4( uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].qs[qs_idx * 2 + 7]) << 16)); + + const uvec4 hmask = uvec4( uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 ]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].hmask[iqs * 2 + 7]) << 16)); + // bitwise OR to add 4 if hmask is set, subtract later - const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2)); - const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) | - unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2)); - - return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)), - pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)), - pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)), - pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4))); + const uint vals0 = (( qs.x >> qs_shift) & 0x03030303) | + (((hmask.x >> hm_shift) & 0x01010101) << 2); + const uint vals1 = (( qs.y >> qs_shift) & 0x03030303) | + (((hmask.y >> hm_shift) & 0x01010101) << 2); + const uint vals2 = (( qs.z >> qs_shift) & 0x03030303) | + (((hmask.z >> hm_shift) & 0x01010101) << 2); + const uint vals3 = (( qs.w >> qs_shift) & 0x03030303) | + (((hmask.w >> hm_shift) & 0x01010101) << 2); + + // Subtract 4 by twiddling bits rather than using re-packing as mesa + // compiles repacking poorly. + return i32vec4(int32_t(((vals0 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals1 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals2 ^ 0x80808080) - 0x04040404) ^ 0x80808080), + int32_t(((vals3 ^ 0x80808080) - 0x04040404) ^ 0x80808080)); } float get_d_scale(uint ib, uint iqs) { @@ -296,15 +308,24 @@ vec2 get_dm_scale(uint ib, uint iqs) { const uint ib_k = ib / 8; const uint iqs_k = (ib % 8) * 8 + iqs; const uint is = iqs_k / 8; - u8vec2 scale_dm; - if (is < 4) { - scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F); - } else { - scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2), - (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); - } - return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); + const uvec3 scales = uvec3(data_a_packed32[ib_k].scales[0], + data_a_packed32[ib_k].scales[1], + data_a_packed32[ib_k].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); + u8vec2 scale_dm = u8vec2(sc, mbyte); + + return FLOAT_TYPEV2(data_a_packed32[ib_k].dm) * FLOAT_TYPEV2(scale_dm); } FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { @@ -334,27 +355,39 @@ i32vec4 repack4(uint ib, uint iqs) { const uint qh_idx = (iqs_k / 32) * 8 + iqs; const uint qh_shift = ((iqs_k % 32) / 8) * 2; - const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) | - unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32); - - return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)), - pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)), - pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)), - pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y))); + const uvec4 ql = uvec4( uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].ql[ql_idx * 2 + 7]) << 16)); + + const uvec4 qh = uvec4( uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 ]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 1]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 2]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 3]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 4]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 5]) << 16), + uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 6]) | + (uint32_t(data_a_packed16[ib_k].qh[qh_idx * 2 + 7]) << 16)); + + const uint vals0 = (( ql.x >> ql_shift) & 0x0F0F0F0F) | + (((qh.x >> qh_shift) & 0x03030303) << 4); + const uint vals1 = (( ql.y >> ql_shift) & 0x0F0F0F0F) | + (((qh.y >> qh_shift) & 0x03030303) << 4); + const uint vals2 = (( ql.z >> ql_shift) & 0x0F0F0F0F) | + (((qh.z >> qh_shift) & 0x03030303) << 4); + const uint vals3 = (( ql.w >> ql_shift) & 0x0F0F0F0F) | + (((qh.w >> qh_shift) & 0x03030303) << 4); + + // Subtract 32 by twiddling bits rather than using re-packing as mesa + // compiles repacking poorly. + return i32vec4(int32_t(((vals0 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals1 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals2 ^ 0x80808080) - 0x20202020) ^ 0x80808080), + int32_t(((vals3 ^ 0x80808080) - 0x20202020) ^ 0x80808080)); } float get_d_scale(uint ib, uint iqs) { @@ -422,7 +455,7 @@ vec2 get_dm(uint ib, uint iqs) { const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); // the -1 cancels out the bias in iq1s_grid_gpu - return FLOAT_TYPE_VEC2(dl, dl * (delta - 1)); + return FLOAT_TYPEV2(dl, dl * (delta - 1)); } FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 775e9a70f6d..f39410d74f0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -29,6 +29,7 @@ #endif #include "types.glsl" +#include "dot_product_funcs.glsl" #ifndef LOAD_VEC_A #define LOAD_VEC_A 1 @@ -90,6 +91,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -123,8 +126,8 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit #define SHMEM_STRIDE (BK / 2 + 1) #endif -shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE]; -shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; +shared FLOAT_TYPEV2 buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPEV2 buf_b[BN * SHMEM_STRIDE]; #define NUM_WARPS (BLOCK_SIZE / WARP) @@ -139,7 +142,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -149,7 +152,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -256,17 +259,17 @@ void main() { sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f); } #else - ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN/2]; + ACC_TYPEV2 sums[WMITER * TM * WNITER * TN/2]; #if defined(DATA_A_F32) || defined(DATA_A_F16) - FLOAT_TYPE_VEC4 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC4 cache_b; + FLOAT_TYPEV4 cache_a[WMITER * TM]; + FLOAT_TYPEV4 cache_b; #else - FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; - FLOAT_TYPE_VEC2 cache_b; + FLOAT_TYPEV2 cache_a[WMITER * TM]; + FLOAT_TYPEV2 cache_b; #endif [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { - sums[i] = ACC_TYPE_VEC2(0.0f, 0.0f); + sums[i] = ACC_TYPEV2(0.0f, 0.0f); } #endif @@ -327,15 +330,8 @@ void main() { [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { // [WNITER][TN][WMITER][TM / 2] -> [wsic][cc][wsir][cr] const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; - #if defined(DATA_A_F32) || defined(DATA_A_F16) - sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), - fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].w), ACC_TYPE(cache_b.w), sums[sums_idx].x)))); - sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), - fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].z), ACC_TYPE(cache_b.z), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].w), ACC_TYPE(cache_b.w), sums[sums_idx].y)))); - #else - sums[sums_idx].x = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr ].y), ACC_TYPE(cache_b.y), sums[sums_idx].x)); - sums[sums_idx].y = fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].x), ACC_TYPE(cache_b.x), fma(ACC_TYPE(cache_a[wsir * TM + 2 * cr + 1].y), ACC_TYPE(cache_b.y), sums[sums_idx].y)); - #endif + sums[sums_idx].x = dot_product(cache_a[wsir * TM + 2 * cr ], cache_b, sums[sums_idx].x); + sums[sums_idx].y = dot_product(cache_a[wsir * TM + 2 * cr + 1], cache_b, sums[sums_idx].y); } } } @@ -366,7 +362,7 @@ void main() { const uint dc = ic * BN + warp_c * WN; #ifndef MUL_MAT_ID - const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif #ifdef COOPMAT @@ -375,6 +371,7 @@ void main() { [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + barrier(); [[unroll]] for (uint col = 0; col < TN; col += storestride) { const uint row_i = dc + cm_col * TN + col + store_c; if (row_i >= _ne1) break; @@ -385,6 +382,7 @@ void main() { data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); } } + barrier(); } } #else @@ -402,18 +400,22 @@ void main() { // Full coopMat is within bounds, but stride_d is not aligned coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease); [[unroll]] for (uint col = 0; col < TN; col += storestride) { data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); } + controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease); } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { // Partial coopMat is within bounds coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease); [[unroll]] for (uint col = 0; col < TN; col += storestride) { if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); } } + controlBarrier(gl_ScopeSubgroup, gl_ScopeSubgroup, gl_StorageSemanticsShared, gl_SemanticsAcquireRelease); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index b6614d2fc59..2656fe1c3e9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -11,6 +11,9 @@ #extension GL_KHR_memory_scope_semantics : enable #extension GL_KHR_cooperative_matrix : enable #extension GL_NV_cooperative_matrix2 : enable +#ifdef GGML_VULKAN_COOPMAT2_DECODE_VECTOR +#extension GL_NV_cooperative_matrix_decode_vector : enable +#endif #extension GL_EXT_buffer_reference : enable #extension GL_KHR_shader_subgroup_ballot : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -53,6 +56,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -67,12 +72,17 @@ layout (push_constant) uniform parameter layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) +layout (binding = 1) readonly buffer B4 {B_TYPEV4 data_b_v4[];}; +#endif #if QUANT_K > 1 -#define DECODEFUNCA , dequantFuncA - #include "dequant_funcs_cm2.glsl" - +#if defined(dequantFuncA_v) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) +#define DECODEFUNCA , dequantFuncA, dequantFuncA_v +#else +#define DECODEFUNCA , dequantFuncA +#endif #else #define DECODEFUNCA #endif @@ -109,11 +119,33 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i const uint row_i = blockCoords[0]; const u16vec4 row_idx = row_ids[row_i]; - B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) + // The decode-vector path gives B a K-dimension tensor-layout block size of BK. + const uint k = blockCoords[1] * BK + coordInBlock[1]; +#else + const uint k = blockCoords[1]; +#endif + B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k]; return ret; } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) +B_TYPEV4 decodeFuncB_v(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint row_i = blockCoords[0]; + + const u16vec4 row_idx = row_ids[row_i]; + const uint k = blockCoords[1] * BK + coordInBlock[1]; + const uint base = row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + k; + + return data_b_v4[base >> 2]; +} +#define DECODEFUNCB , decodeFuncB, decodeFuncB_v +#else +#define DECODEFUNCB , decodeFuncB +#endif + D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) { uint dr = ir * BM + r; @@ -165,7 +197,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { uint id = ids[iter++]; uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - ballots_sh[gl_SubgroupID] = ballot; + if (gl_SubgroupInvocationID == 0) { + ballots_sh[gl_SubgroupID] = ballot; + } barrier(); uint subgroup_base = 0; @@ -197,7 +231,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -215,7 +249,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -255,7 +289,7 @@ void main() { #else uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K); uint pos_b = batch_idx * p.batch_stride_b; - uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif uint stride_a = p.stride_a / QUANT_K; @@ -281,6 +315,9 @@ void main() { tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); #endif +#if defined(MUL_MAT_ID) && defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR) + tensorLayoutB = setTensorLayoutBlockSizeNV(tensorLayoutB, 1, BK); +#endif // Use end_k rather than p.K as the dimension because that's what // we need to bound check against when using split_k. @@ -493,7 +530,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } else { @@ -501,7 +538,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover4, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } @@ -537,7 +574,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } else { @@ -545,7 +582,7 @@ void main() { coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b; coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BNover2, block_k, BK), tensorViewTranspose DECODEFUNCB); sum = coopMatMulAdd(mat_a, mat_b, sum); } @@ -582,7 +619,7 @@ void main() { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB); #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif @@ -594,7 +631,7 @@ void main() { coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); #ifdef MUL_MAT_ID - coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose, decodeFuncB); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, 0, BN, block_k, BK), tensorViewTranspose DECODEFUNCB); #else coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index ce7f2d699a2..73595168984 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -3,7 +3,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #if LOAD_VEC_A == 8 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]); + FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); buf_a[buf_idx ] = aa[0].xy; buf_a[buf_idx + 1] = aa[0].zw; buf_a[buf_idx + 2] = aa[1].xy; @@ -11,38 +11,38 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #elif LOAD_VEC_A == 4 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]); + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); buf_a[buf_idx ] = aa.xy; buf_a[buf_idx + 1] = aa.zw; #else // LOAD_VEC_BATCH_A == 2 const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], - data_a[idx + 1]); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], + data_a[idx + 1]); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f); } else { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif #elif defined(DATA_A_BF16) #if LOAD_VEC_A == 4 const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx])); + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); buf_a[buf_idx ] = aa.xy; buf_a[buf_idx + 1] = aa.zw; #else // LOAD_VEC_BATCH_A == 2 const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), - TO_FLOAT_TYPE(data_a[idx + 1])); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), + TO_FLOAT_TYPE(data_a[idx + 1])); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); } else { - buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif #elif defined(DATA_A_Q4_0) @@ -57,10 +57,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy); - buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v0.zw); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v1.xy); + buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.zw); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -73,10 +73,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y; const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); - buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); - buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy); - buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xy); + buf_a[buf_idx + 1 ] = FLOAT_TYPEV2(v0.zw); + buf_a[buf_idx + 8 ] = FLOAT_TYPEV2(v1.xy); + buf_a[buf_idx + 9 ] = FLOAT_TYPEV2(v1.zw); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -92,8 +92,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint vui = uint(data_a_packed16[ib].qs[iqs]); const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v.yw); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -112,10 +112,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y; const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw); - buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v0.xz); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v1.xz); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(v0.yw); + buf_a[buf_idx + 9] = FLOAT_TYPEV2(v1.yw); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -128,8 +128,22 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); +#elif defined(DATA_A_Q1_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 16; + const uint iqs = idx & 0xfu; + + const float d = float(data_a[ib].d); + const uint bits = uint(data_a[ib].qs[iqs]); + + buf_a[buf_idx ] = FLOAT_TYPEV2((bits & 0x01u) != 0u ? d : -d, (bits & 0x02u) != 0u ? d : -d); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((bits & 0x04u) != 0u ? d : -d, (bits & 0x08u) != 0u ? d : -d); + buf_a[buf_idx + 2] = FLOAT_TYPEV2((bits & 0x10u) != 0u ? d : -d, (bits & 0x20u) != 0u ? d : -d); + buf_a[buf_idx + 3] = FLOAT_TYPEV2((bits & 0x40u) != 0u ? d : -d, (bits & 0x80u) != 0u ? d : -d); #elif defined(DATA_A_Q2_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -147,8 +161,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_Q3_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -171,8 +185,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 qs = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> qsshift) & 0x0303).xy); const vec2 hm = vec2(unpack8(((uint(data_a_packed16[ib].hmask[hmi / 2]) >> (4 * n + halfsplit)) & 0x0101 ^ 0x0101) << 2).xy); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * (qs.x - hm.x), - dl * (qs.y - hm.y)); + buf_a[buf_idx] = FLOAT_TYPEV2(dl * (qs.x - hm.x), + dl * (qs.y - hm.y)); #elif defined(DATA_A_Q4_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -187,27 +201,28 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 loadd = vec2(data_a[ib].dm); - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; + const uvec3 scales = uvec3(data_a_packed32[ib].scales[0], + data_a_packed32[ib].scales[1], + data_a_packed32[ib].scales[2]); + const uint scalesoffs = (is & 3) * 8; + + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); const float d = loadd.x * sc; const float m = -loadd.y * mbyte; const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); + buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -223,19 +238,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec2 loadd = vec2(data_a[ib].dm); - const uint scidx0 = (is < 4) ? is : (is + 4); - const uint scidx1 = (is < 4) ? is : (is - 4); - const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint scidxshift1 = (is < 4) ? 0 : 2; - const uint mbidx0 = is + 4; - const uint mbidx1 = (is < 4) ? is + 4 : is; - const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; - const uint mbidxshift0 = (is < 4) ? 0 : 4; - const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; - const uint mbidxshift1 = (is < 4) ? 0 : 2; + const uvec3 scales = uvec3(data_a_packed32[ib].scales[0], + data_a_packed32[ib].scales[1], + data_a_packed32[ib].scales[2]); + const uint scalesoffs = (is & 3) * 8; - const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); - const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const uint scidx0 = (is < 4) ? 0 : 2; + const uint scidxshift0 = scalesoffs; + const uint scidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + const uint mbidx0 = (is < 4) ? 1 : 2; + const uint mbidxshift0 = (is < 4) ? scalesoffs : scalesoffs + 4; + const uint mbidxshift1 = (is < 4) ? scalesoffs : scalesoffs + 2; + + const uint8_t sc = uint8_t(((scales[scidx0] >> scidxshift0) & 0xF) | ((scales[0] >> scidxshift1) & 0x30)); + const uint8_t mbyte = uint8_t(((scales[mbidx0] >> mbidxshift0) & 0xF) | ((scales[1] >> mbidxshift1) & 0x30)); const float d = loadd.x * sc; const float m = -loadd.y * mbyte; @@ -244,8 +260,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4; const vec4 q = vec4(unpack8(qs | qh)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); + buf_a[buf_idx ] = FLOAT_TYPEV2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -267,7 +283,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint qh = (uint(data_a_packed16[ib].qh[qhi]) >> qhshift) & 0x0303; const vec2 q = (vec2(unpack8(ql | (qh << 4)).xy) - 32) * dscale; - buf_a[buf_idx] = FLOAT_TYPE_VEC2(q.x, q.y); + buf_a[buf_idx] = FLOAT_TYPEV2(q.x, q.y); #elif defined(DATA_A_IQ1_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -284,8 +300,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); [[unroll]] for (int k = 0; k < 4; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), - dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); } #elif defined(DATA_A_IQ1_M) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; @@ -306,8 +322,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); [[unroll]] for (int k = 0; k < 4; ++k) { - buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), - dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + buf_a[buf_idx + k] = FLOAT_TYPEV2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); } #elif defined(DATA_A_IQ2_XXS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; @@ -332,14 +348,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_XS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -358,14 +374,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ2_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -386,14 +402,14 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const vec4 grid0 = vec4(unpack8(grid.x)); const vec4 grid1 = vec4(unpack8(grid.y)); - buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, - (sign & 2) != 0 ? -grid0.y : grid0.y); - buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, - (sign & 8) != 0 ? -grid0.w : grid0.w); - buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, - (sign & 32) != 0 ? -grid1.y : grid1.y); - buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, - (sign & 128) != 0 ? -grid1.w : grid1.w); + buf_a[buf_idx ] = db * FLOAT_TYPEV2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPEV2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPEV2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPEV2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); #elif defined(DATA_A_IQ3_XXS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -414,10 +430,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint grid = iq3xxs_grid[qs]; const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, - (sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, - (sign & 8) != 0 ? -v.w : v.w); + buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ3_S) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -436,27 +452,28 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; const vec4 v = db * vec4(unpack8(grid)); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, - (sign & 2) != 0 ? -v.y : v.y); - buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, - (sign & 8) != 0 ? -v.w : v.w); + buf_a[buf_idx ] = FLOAT_TYPEV2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPEV2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); #elif defined(DATA_A_IQ4_XS) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint ib32 = (idx % 128) / 16; // 0..7 - const uint iq = 16 * ib32 + 2 * (idx % 8); + const uint ib = idx / 64; // 4 values per idx + const uint ib32 = (idx % 64) / 8; // 0..7 + const uint iq = 4 * ib32 + (idx % 4); const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; - const uint qshift = (idx & 8) >> 1; - u8vec2 qs = unpack8((uint(data_a_packed16[ib].qs[iq/2]) >> qshift) & 0x0F0F).xy; + const uint qshift = idx & 4; + u8vec4 qs = unpack8((uint(data_a_packed32[ib].qs[iq]) >> qshift) & 0x0F0F0F0F); const float d = float(data_a[ib].d); - const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + const vec4 v = d * float(int(sl | (sh << 4)) - 32) * vec4(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx ] = FLOAT_TYPEV2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPEV2(v.zw); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -467,10 +484,10 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); const uint vui = uint(data_a_packed16[ib].qs[iqs]); - buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF], - kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); - buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], - kvalues_iq4nl[vui >> 12]); + buf_a[buf_idx ] = d * FLOAT_TYPEV2(kvalues_iq4nl[vui & 0xF], + kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); + buf_a[buf_idx + 8] = d * FLOAT_TYPEV2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], + kvalues_iq4nl[vui >> 12]); #elif defined(DATA_A_MXFP4) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -482,10 +499,27 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint vui = uint(data_a[ib].qs[iqs]); const uint vui2 = uint(data_a[ib].qs[iqs+1]); - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d, - kvalues_mxfp4[vui2 & 0xF] * d); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d, - kvalues_mxfp4[vui2 >> 4] * d); + buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 8] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); +#elif defined(DATA_A_NVFP4) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + // lo and hi nibbles are 8 elements apart, which doesn't quite line up with + // how the thread mapping and buf_idx calculation works for other types. + const uint buf_idx = col * SHMEM_STRIDE + (row & 3) + (row & ~3) * 2; + + const uint ib = idx / 16u; + const uint sub = (idx & 0xC) >> 2; + const uint iqs = (idx & 0xF) * 2; + const float d = ue4m3_to_fp32(data_a[ib].d[sub]) * 0.5; + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPEV2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 4] = FLOAT_TYPEV2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); #endif } @@ -495,7 +529,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin // Not supported for b_type bf16 because bf16mat2x4 does not exist const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); buf_b[buf_idx + 0] = bb[0].xy; buf_b[buf_idx + 1] = bb[0].zw; buf_b[buf_idx + 2] = bb[1].xy; @@ -504,9 +538,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); #endif buf_b[buf_idx + 0] = bb.xy; buf_b[buf_idx + 1] = bb.zw; @@ -514,12 +548,12 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + col * p.stride_b + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_n < p.N && block + row * 2 + 1 < end_k) { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); } else if (idx_n < p.N && block + row * 2 < end_k) { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); } else { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif } @@ -530,7 +564,7 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); buf_b[buf_idx + 0] = bb[0].xy; buf_b[buf_idx + 1] = bb[0].zw; buf_b[buf_idx + 2] = bb[1].xy; @@ -540,9 +574,9 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); #endif buf_b[buf_idx + 0] = bb.xy; buf_b[buf_idx + 1] = bb.zw; @@ -552,14 +586,14 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin if (row_i < _ne1 && block + row * 2 + 1 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); } else if (row_i < _ne1 && block + row * 2 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); } else { - buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } #endif } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl index 743004ff8ad..26c5c12a49a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl @@ -43,7 +43,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { uint id = ids[iter++]; uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - ballots_sh[gl_SubgroupID] = ballot; + if (gl_SubgroupInvocationID == 0) { + ballots_sh[gl_SubgroupID] = ballot; + } barrier(); uint subgroup_base = 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 335d7f6a682..aae1c2e8ae9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -57,6 +57,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -108,7 +110,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -118,7 +120,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -276,7 +278,7 @@ void main() { const uint dc = ic * BN + warp_c * WN; #ifndef MUL_MAT_ID - const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl index 7f32dadf17d..59931b04b94 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -21,7 +21,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm); } #endif } @@ -72,7 +72,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = data_a_packed32[ib].qs[iqs]; if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib].dm); buf_a[buf_ib].qh = data_a_packed32[ib].qh; } #endif @@ -203,7 +203,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6); if (iqs == 0) { - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(data_a_packed32[ib_k].dm); buf_a[buf_ib].scales = unpack8(uint32_t(data_a_packed16[ib_k].scales[iqs_k / 8])).xy; // vec4 used due to #12147 } } @@ -264,7 +264,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const i8vec2 scales = i8vec2(unpack8(uint32_t(((data_a_packed16[ib_k].scales[(is % 8 ) / 2] >> (4 * (is / 8))) & 0x0F0F) | (((data_a_packed16[ib_k].scales[(8 + (is % 4)) / 2] >> (2 * (is / 4))) & 0x0303) << 4))).xy); // vec4 used due to #12147 - buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales - 32); + buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales - 32)); } } @@ -334,7 +334,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { (data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2)); } - buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm); + buf_a[buf_ib].dm = FLOAT_TYPEV2(vec2(data_a_packed32[ib_k].dm) * vec2(scale_dm)); } } @@ -385,7 +385,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { const uint is = iqs_k / 4; const i8vec2 scales = unpack8(int32_t(data_a_packed16[ib_k].scales[is / 2])).xy; - buf_a[buf_ib].d_scales = FLOAT_TYPE(data_a_packed16[ib_k].d) * FLOAT_TYPE_VEC2(scales); + buf_a[buf_ib].d_scales = FLOAT_TYPEV2(float(data_a_packed16[ib_k].d) * vec2(scales)); } } @@ -426,7 +426,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo const uint ib_inner = ib % 4; if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + buf_b[buf_ib].ds = FLOAT_TYPEV2(data_b[ib_outer].ds[ib_inner]); } const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; @@ -436,7 +436,7 @@ void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bo buf_b[buf_ib].qs[iqs * 4 + 3] = values.w; } else { if (iqs == 0) { - buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f); + buf_b[buf_ib].ds = FLOAT_TYPEV2(0.0f); } buf_b[buf_ib].qs[iqs * 4 ] = 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl index 1c0f5306f38..79c933f40cf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl @@ -1,4 +1,13 @@ -#if defined(DATA_A_Q4_0) +#if defined(FA_MMQ_MIXED) +// Mixed-K flash attention MMQ: superset cache that fits Q4_0/Q4_1/Q5_0/Q5_1/Q8_0. +// Q4_*/Q5_* only use qs[0..3] and (for Q5_*) qh. Q8_0 uses qs[0..7]. Single-scale +// types (Q4_0/Q5_0/Q8_0) leave dm.y unused. +struct block_a_cache { + int32_t qs[8]; + uint32_t qh; + FLOAT_TYPEV2 dm; +}; +#elif defined(DATA_A_Q4_0) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[16/4]; @@ -8,7 +17,7 @@ struct block_a_cache { #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[16/4]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q5_0) #define QUANT_R_MMQ 2 @@ -22,7 +31,7 @@ struct block_a_cache { struct block_a_cache { uint32_t qs[16/4]; uint32_t qh; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q8_0) #define QUANT_R_MMQ 1 @@ -32,6 +41,12 @@ struct block_a_cache { int32_t qs[32/4]; FLOAT_TYPE dm; }; +#elif defined(DATA_A_IQ4_NL) +#define QUANT_R_MMQ 2 +struct block_a_cache { + int32_t qs[8]; + FLOAT_TYPE dm; +}; #elif defined(DATA_A_MXFP4) #define QUANT_R_MMQ 2 struct block_a_cache { @@ -43,36 +58,36 @@ struct block_a_cache { struct block_a_cache { uint32_t qs[2]; u8vec2 scales; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q3_K) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[4]; - FLOAT_TYPE_VEC2 d_scales; + FLOAT_TYPEV2 d_scales; }; #elif defined(DATA_A_Q4_K) #define QUANT_R_MMQ 2 struct block_a_cache { uint32_t qs[4]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q5_K) #define QUANT_R_MMQ 1 struct block_a_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 dm; + FLOAT_TYPEV2 dm; }; #elif defined(DATA_A_Q6_K) #define QUANT_R_MMQ 1 struct block_a_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 d_scales; + FLOAT_TYPEV2 d_scales; }; #endif struct block_b_cache { int32_t qs[8]; - FLOAT_TYPE_VEC2 ds; + FLOAT_TYPEV2 ds; }; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp index 10cf5202a4a..26d194e9e8d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -8,7 +8,6 @@ #extension GL_KHR_shader_subgroup_basic : enable #endif -#include "rte.glsl" #include "types.glsl" #include "utils.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp index 9d6d3665427..55b89f19a7a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -112,12 +112,11 @@ void rms_norm(uint num_iters) { #if RMS_NORM_ROPE_FUSION barrier(); rope_params rp = p.rope; - uint rope_row = (samp*nchannels + channel)*nrows + row; for (uint t = 2*tid; t < ncols; t += 2*BLOCK_SIZE) { if (rp.rope_mode == GGML_ROPE_TYPE_NEOX) { - rope_neox(t, rope_row, rp); + rope_neox(t, row, channel, samp, rp); } else if (rp.rope_mode == GGML_ROPE_TYPE_NORMAL) { - rope_norm(t, rope_row, rp); + rope_norm(t, row, channel, samp, rp); } } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl index aacec984696..03358793140 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl @@ -4,12 +4,12 @@ float rope_yarn_ramp(const float low, const float high, const uint i0) { return 1.0f - min(1.0f, max(0.0f, y)); } -uint rope_a_coord(const uint i0, const uint i01, const uint i02, rope_params p) { +uint rope_a_coord(const uint i0, const uint i01, const uint i02, const uint i03, rope_params p) { #if RMS_NORM_ROPE_FUSION // Per-row offset in shared memory const uint ix = i0; #else - const uint ix = i02*p.nb02 + i01*p.nb01 + i0; + const uint ix = p.a_offset + i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i0; #endif return ix; } @@ -34,27 +34,21 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out sin_theta = sin(theta) * mscale; } -void rope_norm(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { +void rope_norm(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0; - const uint ix = rope_a_coord(i0, i01, i02, p); + uint idst = i0 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0; + idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + 0] = ROPE_D_TYPE(rope_data_a[ix + 0]); @@ -63,7 +57,7 @@ void rope_norm(const uint i0, const uint i1, rope_params p) { return; } - const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; @@ -77,26 +71,21 @@ void rope_norm(const uint i0, const uint i1, rope_params p) { rope_data_d[idst + 1] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); } -void rope_neox(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - - if (i0 >= ne0) { +void rope_neox(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0/2; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0/2; + idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -105,7 +94,7 @@ void rope_neox(const uint i0, const uint i1, rope_params p) { return; } - const float theta_base = rope_data_pos[i02] * pow(p.theta_scale, i0/2.0f); + const float theta_base = rope_data_pos[i2] * pow(p.theta_scale, i0/2.0f); const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; @@ -120,27 +109,21 @@ void rope_neox(const uint i0, const uint i1, rope_params p) { } -void rope_multi(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { +void rope_multi(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + uint idst = i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in rope_data_i. if (p.set_rows_stride != 0) { - idst = i01*ne0 + i0/2; - idst += rope_data_i[i02].x * p.set_rows_stride; + idst = i1*p.nb11 + i0/2; + idst += rope_data_i[i2].x * p.set_rows_stride; } + idst += p.d_offset; if (i0 >= p.n_dims) { rope_data_d[idst + i0/2 + 0] = ROPE_D_TYPE(rope_data_a[ix + i0/2 + 0]); @@ -156,26 +139,26 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { float theta_base = 0.0; if (p.is_imrope != 0) { if (sector % 3 == 1 && sector < 3 * p.sections[1]) { - theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f); } else if (sector % 3 == 2 && sector < 3 * p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f); } else if (sector % 3 == 0 && sector < 3 * p.sections[0]) { - theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f); } else { - theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f); } } else { if (sector < p.sections[0]) { - theta_base = rope_data_pos[i02]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, i0/2.0f); } else if (sector >= p.sections[0] && sector < sec_w) { - theta_base = rope_data_pos[i02 + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 1]*pow(p.theta_scale, i0/2.0f); } else if (sector >= sec_w && sector < sec_w + p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 2]*pow(p.theta_scale, i0/2.0f); } else if (sector >= sec_w + p.sections[2]) { - theta_base = rope_data_pos[i02 + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + theta_base = rope_data_pos[i2 + p.ne02 * 3]*pow(p.theta_scale, i0/2.0f); } } @@ -191,20 +174,13 @@ void rope_multi(const uint i0, const uint i1, rope_params p) { rope_data_d[idst + p.n_dims/2] = ROPE_D_TYPE(x0*sin_theta + x1*cos_theta); } -void rope_vision(const uint i0, const uint i1, rope_params p) { - uint ne0 = p.ncols; - uint ne1 = p.p_delta_rows; - uint ne2 = p.ne02; - - if (i0 >= ne0) { +void rope_vision(const uint i0, const uint i1, const uint i2, const uint i3, rope_params p) { + if (i0 >= p.ne00) { return; } - const uint i01 = i1 % ne1; - const uint i02 = i1 / ne1; - - const uint idst = i1*ne0 + i0/2; - const uint ix = rope_a_coord(i0/2, i01, i02, p); + const uint idst = p.d_offset + i0/2 + i1 * p.nb11 + i2 * p.nb12 + i3 * p.nb13; + const uint ix = rope_a_coord(i0/2, i1, i2, i3, p); const int sect_dims = p.sections[0] + p.sections[1]; const int sec_w = p.sections[1] + p.sections[0]; @@ -213,11 +189,11 @@ void rope_vision(const uint i0, const uint i1, rope_params p) { float theta_base = 0.0; if (sector < p.sections[0]) { const uint p0 = sector; - theta_base = rope_data_pos[i02]*pow(p.theta_scale, p0); + theta_base = rope_data_pos[i2]*pow(p.theta_scale, p0); } else if (sector >= p.sections[0] && sector < sec_w) { const uint p0 = sector - p.sections[0]; - theta_base = rope_data_pos[i02 + ne2]*pow(p.theta_scale, p0); + theta_base = rope_data_pos[i2 + p.ne02]*pow(p.theta_scale, p0); } const float freq_factor = p.has_ff != 0 ? rope_data_ff[i0/2] : 1.0f; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl index d9b4d4c03f3..51a127bcd87 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -2,7 +2,6 @@ #extension GL_EXT_shader_16bit_storage : require -#include "rte.glsl" #include "rope_params.glsl" layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index f7587468a81..1528fbeeaec 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_multi(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_multi(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp index acb8ed78155..ad0896095db 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_neox(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_neox(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp index 0033cdb224f..11220817df0 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_norm(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_norm(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl index 939cf3c51cd..3602485b943 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl @@ -1,28 +1,34 @@ #if !defined(GGML_ROPE_PARAMS) #define GGML_ROPE_PARAMS -#include "rte.glsl" - struct rope_params { uint rope_mode; - uint ncols; uint nrows; uint n_dims; float freq_scale; - uint p_delta_rows; float freq_base; float ext_factor; float attn_factor; float corr_dims[2]; float theta_scale; uint has_ff; - uint ne02; - uint nb01; - uint nb02; int sections[4]; uint is_imrope; uint is_back; uint set_rows_stride; + + uint ne00; + uint ne01; + uint ne02; + uint nb01; + uint nb02; + uint nb03; + uint nb11; + uint nb12; + uint nb13; + + uint a_offset; + uint d_offset; }; #endif // !defined(GGML_ROPE_PARAMS) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp index d93800b5e76..ca71efb2f55 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -5,10 +5,13 @@ void main() { const uint i0 = 2*gl_GlobalInvocationID.y; - // i1 is actually i2*nb2+i1, but the rows are contiguous - const uint i1 = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; - if (i1 >= pc.nrows) { + const uint row = gl_GlobalInvocationID.x + 32768 * gl_GlobalInvocationID.z; + if (row >= pc.nrows) { return; } - rope_vision(i0, i1, pc); + const uint i3 = row / (pc.ne01*pc.ne02); + const uint i2 = (row - i3 * pc.ne01*pc.ne02) / pc.ne01; + const uint i1 = (row - i3 * pc.ne01*pc.ne02 - i2 * pc.ne01); + + rope_vision(i0, i1, i2, i3, pc); } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl deleted file mode 100644 index ad51c1e80b8..00000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +++ /dev/null @@ -1,5 +0,0 @@ - -#if RTE16 -#extension GL_EXT_spirv_intrinsics : enable -spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits -#endif // RTE16 diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp new file mode 100644 index 00000000000..a9c147bf9ac --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + data_d[i] = D_TYPE(sign(float(data_a[i]))); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp new file mode 100644 index 00000000000..8585538cbb0 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.glsl" + +// Fused snake activation: y = x + sin(b * x)^2 * c +// data_a [ne0, ne1] per element activation x (A_TYPE) +// data_b [1, ne1] per channel multiplier (float) +// data_c [1, ne1] per channel inverse scale (float, precomputed as 1 / freq) +// data_d [ne0, ne1] output y (D_TYPE) +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {float data_b[];}; +layout (binding = 2) readonly buffer C {float data_c[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t ne0; + uint32_t ne1; +} p; + +// Load A_TYPE to float +float load_val(uint32_t idx) { +#if defined(DATA_A_BF16) + return bf16_to_fp32(uint32_t(data_a[idx])); +#else + return float(data_a[idx]); +#endif +} + +// Store float as D_TYPE +void store_val(uint32_t idx, float v) { +#if defined(DATA_D_BF16) + data_d[idx] = D_TYPE(fp32_to_bf16(v)); +#else + data_d[idx] = D_TYPE(v); +#endif +} + +void main() { + const uint32_t i0 = gl_GlobalInvocationID.x; + const uint32_t i1 = gl_GlobalInvocationID.y; + if (i0 >= p.ne0 || i1 >= p.ne1) return; + + const uint32_t idx = i0 + i1 * p.ne0; + const float xi = load_val(idx); + const float s = sin(data_b[i1] * xi); + store_val(idx, xi + s * s * data_c[i1]); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp index d62696bcfae..4cd9b8da359 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp @@ -5,12 +5,16 @@ #include "types.glsl" layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(constant_id = 1) const uint TOKENS_PER_WG = 16; +layout(constant_id = 2) const bool APPLY_BIAS = false; +layout(constant_id = 3) const bool APPLY_SILU = false; -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in; layout(binding = 0) readonly buffer Src0 { float src0[]; }; layout(binding = 1) readonly buffer Src1 { float src1[]; }; -layout(binding = 2) buffer Dst { float dst[]; }; +layout(binding = 2) readonly buffer Bias { float bias[]; }; +layout(binding = 3) buffer Dst { float dst[]; }; layout(push_constant) uniform PushConstants { uint nb01; uint nb02; @@ -20,25 +24,37 @@ layout(push_constant) uniform PushConstants { }; void main() { - const uint global_thread_id = gl_GlobalInvocationID.x; - const uint i2 = gl_WorkGroupID.y; + const uint i1 = gl_GlobalInvocationID.x; + const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y; const uint i3 = gl_WorkGroupID.z; - if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) { + if (i1 >= nr || i2 >= n_t || i3 >= n_s) { return; } - const uint i1 = global_thread_id; const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4); const uint src1_base = i1 * (nb11 / 4); - const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; float sum = 0.0; - [[unroll]] for (uint i0 = 0; i0 < nc; i0++) { - const uint src0_idx = src0_base + i0; - const uint src1_idx = src1_base + i0; - sum += src0[src0_idx] * src1[src1_idx]; + + if (nc == 4) { + sum = dot( + vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]), + vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3]) + ); + } else { + [[unroll]] for (uint i0 = 0; i0 < nc; i0++) { + sum += src0[src0_base + i0] * src1[src1_base + i0]; + } + } + + if (APPLY_BIAS) { + sum += bias[i1]; + } + if (APPLY_SILU) { + sum = sum / (1.0f + exp(-sum)); } + const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1; dst[dst_idx] = sum; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp index e18d0ffa307..f9b78f96072 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp @@ -1,6 +1,5 @@ #version 450 -#include "rte.glsl" #include "types.glsl" #include "generic_unary_head.glsl" diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index bdb2c09259b..8c6b20c6889 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -31,6 +31,7 @@ #else #define A_TYPE float16_t #endif +#define A_TYPE_PACKED32 f16vec2 #endif #if defined(DATA_A_BF16) @@ -44,6 +45,7 @@ #else #define A_TYPE uint16_t #endif +#define A_TYPE_PACKED32 uint32_t #endif #define QUANT_K_Q4_0 32 @@ -188,6 +190,22 @@ struct block_q8_0_packed16 #define DATA_A_QUANT_LEGACY #endif +#define QUANT_K_Q1_0 128 +#define QUANT_R_Q1_0 1 + +struct block_q1_0 +{ + float16_t d; + uint8_t qs[QUANT_K_Q1_0 / 8]; +}; + +#if defined(DATA_A_Q1_0) +#define QUANT_K QUANT_K_Q1_0 +#define QUANT_R QUANT_R_Q1_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q1_0 +#endif + #define QUANT_K_Q8_1 32 #define QUANT_R_Q8_1 1 @@ -580,9 +598,10 @@ const uint[1024] iq1s_grid_const = { 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 }; +#if defined(NEEDS_IQ1S_GRID_GPU) // Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit // and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F -// and 0xF0F0F0F0). +// and 0xF0F0F0F0). This is only used by the q8_1/int-dot vector path. const uint32_t[2048] iq1s_grid_gpu_const = { 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, @@ -841,9 +860,12 @@ const uint32_t[2048] iq1s_grid_gpu_const = { 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, }; +#endif shared uint16_t iq1s_grid[2048]; +#if defined(NEEDS_IQ1S_GRID_GPU) shared uint32_t iq1s_grid_gpu[2048]; +#endif #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) @@ -857,12 +879,14 @@ void init_iq_shmem(uvec3 wgsize) iq1s_grid[2*idx+1] = g.y; } } +#if defined(NEEDS_IQ1S_GRID_GPU) [[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) { uint idx = i + gl_LocalInvocationIndex.x; if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) { iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx]; } } +#endif barrier(); } #endif @@ -1676,6 +1700,7 @@ struct block_iq4_nl_packed16 #if defined(DATA_A_IQ4_NL) #define QUANT_K QUANT_K_IQ4_NL #define QUANT_R QUANT_R_IQ4_NL +#define QUANT_AUXF 1 #define A_TYPE block_iq4_nl #define A_TYPE_PACKED16 block_iq4_nl_packed16 #endif @@ -1696,6 +1721,29 @@ struct block_mxfp4 #define A_TYPE block_mxfp4 #endif +#define QUANT_K_NVFP4 64 +#define QUANT_R_NVFP4 1 + +struct block_nvfp4 +{ + uint8_t d[QUANT_K_NVFP4 / 16]; + uint8_t qs[QUANT_K_NVFP4 / 2]; +}; + +struct block_nvfp4_packed32 +{ + uint32_t d[QUANT_K_NVFP4 / 16 / 4]; + uint32_t qs[QUANT_K_NVFP4 / 2 / 4]; +}; + +#if defined(DATA_A_NVFP4) +#define QUANT_K QUANT_K_NVFP4 +#define QUANT_R QUANT_R_NVFP4 +#define QUANT_AUXF 1 +#define A_TYPE block_nvfp4 +#define A_TYPE_PACKED32 block_nvfp4_packed32 +#endif + #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) const int8_t kvalues_iq4nl_const[16] = { int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), @@ -1715,7 +1763,7 @@ void init_iq_shmem(uvec3 wgsize) } #endif -#if defined(DATA_A_MXFP4) +#if defined(DATA_A_MXFP4) || defined(DATA_A_NVFP4) const int8_t kvalues_mxfp4_const[16] = { int8_t(0), int8_t(1), int8_t(2), int8_t(3), int8_t(4), int8_t(6), int8_t(8), int8_t(12), int8_t(0), int8_t(-1), int8_t(-2), int8_t(-3), int8_t(-4), int8_t(-6), int8_t(-8), int8_t(-12), @@ -1723,6 +1771,24 @@ const int8_t kvalues_mxfp4_const[16] = { shared int8_t kvalues_mxfp4[16]; +#if defined(DATA_A_NVFP4) +// UE4M3 scale in NVFP4 blocks use only 7 bits; sign (bit 7) is always zero. +shared float ue4m3_fp32_lut[128]; + +float ue4m3_to_fp32_build(uint u) { + if (u == 0u || u == 127u) { + return 0.0; + } + const uint exp = (u >> 3) & 15u; + const uint man = u & 7u; + if (exp == 0u) { + return float(man) * (1.0 / 512.0); + } + const uint bits = (exp + 120u) << 23 | (man << 20); + return uintBitsToFloat(bits); +} +#endif + #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) { @@ -1730,6 +1796,11 @@ void init_iq_shmem(uvec3 wgsize) for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) { kvalues_mxfp4[i] = kvalues_mxfp4_const[i]; } +#if defined(DATA_A_NVFP4) + for (uint i = gl_LocalInvocationIndex.x; i < 128u; i += wgsize.x) { + ue4m3_fp32_lut[i] = ue4m3_to_fp32_build(i); + } +#endif barrier(); } #endif @@ -1766,6 +1837,12 @@ float e8m0_to_fp32(uint8_t x) { return uintBitsToFloat(bits); } +#if defined(DATA_A_NVFP4) +float ue4m3_to_fp32(uint8_t x) { + return ue4m3_fp32_lut[uint(x)]; +} +#endif + #if BDA #extension GL_EXT_buffer_reference : enable diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index bbdbf9dcaaa..7bcb1460814 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -45,6 +45,7 @@ std::string target_cpp = ""; const std::vector<std::string> type_names = { "f32", "f16", + "q1_0", "q4_0", "q4_1", "q5_0", @@ -65,6 +66,7 @@ const std::vector<std::string> type_names = { "iq4_xs", "iq4_nl", "mxfp4", + "nvfp4", "bf16", }; @@ -137,6 +139,7 @@ void execute_command(std::vector<std::string>& command, std::string& stdout_str, pid_t pid = fork(); if (pid < 0) { + std::cerr << strerror(errno) << "\n"; throw std::runtime_error("Failed to fork process"); } @@ -330,10 +333,11 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p std::vector<std::string> cmd = {GLSLC, "-fshader-stage=compute", target_env, in_path, "-o", out_path}; #endif - // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 + // disable spirv-opt for coopmat shaders for https://github.com/ggml-org/llama.cpp/issues/10734 // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344 // disable spirv-opt for rope shaders for https://github.com/ggml-org/llama.cpp/issues/16860 - if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos) { + // disable spirv-opt for dot2 shaders (spirv-opt doesn't recognize SPV_VALVE_mixed_float_dot_product capability) + if (!coopmat && name.find("bf16") == std::string::npos && name.find("rope") == std::string::npos && name.find("_dot2") == std::string::npos) { cmd.push_back("-O"); } @@ -404,8 +408,8 @@ std::map<std::string, std::string> merge_maps(const std::map<std::string, std::s } static std::vector<std::future<void>> compiles; -void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { - name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); +void string_to_spv(std::string name, const std::string& source, const std::map<std::string, std::string>& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false, const std::string& suffix = "") { + name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")) + suffix; std::string out_path = join_paths(output_dir, name + ".spv"); if (input_filepath == "") { @@ -424,10 +428,11 @@ void string_to_spv(std::string name, const std::string& source, const std::map<s generate_dep_file = false; } -void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { +void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc, bool dot2 = false) { std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; + std::string dot2_sfx = dot2 ? "_dot2" : ""; std::map<std::string, std::string> base_dict; std::string shader_name = "matmul"; @@ -445,8 +450,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c base_dict["FLOAT16"] = "1"; } - base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; - base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2"; + base_dict["ACC_TYPE" ] = f16acc ? "float16_t" : "float"; + base_dict["ACC_TYPEV2"] = f16acc ? "f16vec2" : "vec2"; if (f16acc) { base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } @@ -454,6 +459,15 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (coopmat) { base_dict["COOPMAT"] = "1"; } +#if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT) + if (coopmat2) { + base_dict["GGML_VULKAN_COOPMAT2_DECODE_VECTOR"] = "1"; + } +#endif + + if (dot2) { + base_dict["DOT2_F16"] = "1"; + } const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; @@ -513,18 +527,18 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; const std::map<std::string, std::string> float_type_dict_f16 = { - {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")}, - {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")}, + {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, "f16")}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, "f16")}, + {"FLOAT_TYPEV8", FLOAT_TYPE(8, "f16")}, }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -535,9 +549,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; const std::map<std::string, std::string> float_type_dict_bf16 = { - {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")}, + {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, "bf16")}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, "bf16")}, }; // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader @@ -545,16 +559,18 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c if (!(coopmat || coopmat2)) #endif { - string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + if (!dot2) { + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } } } for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) + if ((tname == "q1_0") || (tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_xs") || (tname == "iq4_nl") || (tname == "mxfp4") || (tname == "nvfp4")) load_vec_quant = "4"; if (tname == "bf16") { @@ -568,26 +584,26 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; const std::map<std::string, std::string> float_type_dict = { - {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, - {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)}, - {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)}, - {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)}, + {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, + {"FLOAT_TYPEV2", FLOAT_TYPE(2, tname)}, + {"FLOAT_TYPEV4", FLOAT_TYPE(4, tname)}, + {"FLOAT_TYPEV8", FLOAT_TYPE(8, tname)}, }; // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - // Integer dot mmq performs better with f32 accumulators - if (!f16acc && !coopmat && !coopmat2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { + // Integer dot mmq performs better with f32 accumulators (different shader, skip for dot2) + if (!f16acc && !coopmat && !coopmat2 && !dot2 && (is_legacy_quant(tname) || is_k_quant(tname) || tname == "mxfp4")) { string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); } #endif @@ -595,8 +611,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } void process_shaders() { - std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; - // matmul for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) { // No coopmats @@ -607,6 +621,10 @@ void process_shaders() { matmul_shaders(true, matmul_id_type, false, false, false); matmul_shaders(true, matmul_id_type, false, false, true); + // dot2 variants (scalar fp16 only) + matmul_shaders(true, matmul_id_type, false, false, false, true); + matmul_shaders(true, matmul_id_type, false, false, true, true); + if (matmul_id_type != MatMulIdType::DEFAULT) { #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) // Coopmat, fp32acc and fp16acc @@ -622,77 +640,103 @@ void process_shaders() { } } - // flash attention - for (const auto& f16acc : {false, true}) { - std::map<std::string, std::string> fa_base_dict = base_dict; - fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; - fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; - if (f16acc) { - fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; + for (const bool& fp16 : {false, true}) { + std::map<std::string, std::string> base_dict; + if (fp16) { + base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV2", "f16vec2"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; + } else { + base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"FLOAT_TYPEV4", "vec4"}}; } - for (const auto& tname : type_names) { - if (tname == "bf16") continue; + // flash attention + for (const bool& f16acc : {false, true}) { + std::map<std::string, std::string> fa_base_dict = base_dict; + fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV2"] = fp16 && f16acc ? "f16vec2" : "vec2"; + fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4"; + if (fp16 && f16acc) { + fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; + } + if (fp16) { #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc); - } else { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); - } + string_to_spv("flash_attn_f32_f16", "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); #endif + #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } + string_to_spv("flash_attn_f32_f16", "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); #endif - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); } + + string_to_spv("flash_attn_f32_f16", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); + + if (fp16) { + string_to_spv("flash_attn_f32_f16_dot2", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DOT2_F16", "1"}}), fp16, false, false, f16acc); + } + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16", "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"MMQ", "1"}, {"FA_MMQ_MIXED", "1"}}), fp16, false, false, f16acc, "_int8"); +#endif } } + const std::map<std::string, std::string> fa_bf16_dict = { + {"FLOAT_TYPE", "bfloat16_t"}, + {"FLOAT_TYPEV2", "bf16vec2"}, + {"FLOAT_TYPEV4", "bf16vec4"}, + {"ACC_TYPE", "float"}, + {"ACC_TYPEV2", "vec2"}, + {"ACC_TYPEV4", "vec4"}, + {"BFLOAT16", "1"}, + }; + +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm1.comp", + merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), + true, true, false, false); +#endif + +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + string_to_spv("flash_attn_f32_f16_bf16", "flash_attn_cm2.comp", + merge_maps(fa_bf16_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), + true, false, true, false); +#endif + + std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}}; + for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; - string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV2", "f16vec2"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPEV2", "vec2"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") { - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); - string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); } #endif @@ -713,9 +757,9 @@ void process_shaders() { string_to_spv("get_rows_i32", "get_rows.comp", {{"TEMP_TYPE", "uint"}, {"A_TYPE", "uint"}, {"B_TYPE", "int"}, {"D_TYPE", "uint"}}); - string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); - string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); - string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPEV4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}); // Norms string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -723,7 +767,7 @@ void process_shaders() { string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rms_norm_mul_rope_f32_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float"}, {"RMS_NORM_ROPE_FUSION", "1"}})); - string_to_spv("rms_norm_mul_rope_f32_f16_rte", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}, {"RTE16", "1"}})); + string_to_spv("rms_norm_mul_rope_f32_f16", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RMS_NORM_ROPE_FUSION", "1"}})); string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -732,6 +776,7 @@ void process_shaders() { string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_bf16_f32","copy.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); @@ -739,23 +784,21 @@ void process_shaders() { string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("contig_cpy_bf16_f32","contig_copy.comp",{{"A_TYPE", "uint16_t"}, {"D_TYPE", "float"}, {"DATA_A_BF16", "1"}}); string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); string_to_spv("cpy_transpose_16", "copy_transpose.comp", {{"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); string_to_spv("cpy_transpose_32", "copy_transpose.comp", {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}); - for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + for (std::string t : {"q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } - for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { - string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); - string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + for (std::string t : {"f32", "f16", "bf16", "q1_0", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); } auto get_type_str = [](bool f16) { @@ -772,12 +815,10 @@ void process_shaders() { for (auto src0_f16 : {false, true}) { for (auto src1_f16 : {false, true}) { for (auto dst_f16 : {false, true}) { - for (auto rte : {false, true}) { auto source = op == "add_rms" ? std::string("add") : op; - auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); auto add_rms = op == "add_rms" ? "1" : "0"; - string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}}); - } + string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , add_rms}}); } } } @@ -790,6 +831,8 @@ void process_shaders() { string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + string_to_spv("fa_mask_opt", "flash_attn_mask_opt.comp", {}); + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}}); @@ -800,9 +843,11 @@ void process_shaders() { string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}}); string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}}); + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); @@ -823,14 +868,11 @@ void process_shaders() { string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); - for (auto rte : {false, true}) { - std::string suffix = rte ? "_rte" : ""; - string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? "1" : "0"}}); + string_to_spv("exp_f16", "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("exp_f32", "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("log_f16" + suffix, "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("log_f32" + suffix, "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - } + string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -853,8 +895,12 @@ void process_shaders() { string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("elu_f16", "elu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("elu_f32", "elu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("xielu_f16", "xielu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("xielu_f32", "xielu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sgn_f16", "sgn.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("sgn_f32", "sgn.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -869,6 +915,7 @@ void process_shaders() { string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("fill_f16", "fill.comp", {{"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}}); string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); @@ -880,21 +927,18 @@ void process_shaders() { string_to_spv("trunc_f16", "trunc.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("trunc_f32", "trunc.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - for (auto rte : {false, true}) { - std::string suffix = rte ? "_rte" : ""; - string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); - string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); - } + string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("swiglu_oai_f16", "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("swiglu_oai_f32", "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_erf_f16", "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_erf_f32", "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -914,25 +958,18 @@ void process_shaders() { string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_norm_f32_f16", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_norm_f32_f16_rte", "rope_norm.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_neox_f32_f16", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_neox_f32_f16_rte", "rope_neox.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_multi_f32_f16", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_multi_f32_f16_rte", "rope_multi.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"ROPE_D_TYPE", "float"}}); string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}}); - string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"ROPE_D_TYPE", "float16_t"}, {"RTE16", "1"}}); string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); string_to_spv("argsort_large_f32", "argsort_large.comp", {{"A_TYPE", "float"}}); @@ -942,6 +979,8 @@ void process_shaders() { string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("fwht_f32", "fwht.comp", {}); + string_to_spv("fwht_shmem_f32", "fwht.comp", {{"FWHT_SHMEM", "1"}}); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); @@ -955,7 +994,6 @@ void process_shaders() { std::string bda_def = bda ? "1" : "0"; string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}})); string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}})); - string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}})); } } @@ -963,12 +1001,20 @@ void process_shaders() { string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("snake_bf16", "snake.comp", {{"DATA_A_BF16", "1"}, {"DATA_D_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}}); + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "1"}})); + string_to_spv("gated_delta_net_f32_nocluster", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}, {"USE_SUBGROUP_CLUSTERED", "0"}})); + string_to_spv("gated_delta_net_f32_shmem", "gated_delta_net.comp", merge_maps(base_dict, {{"FLOAT_TYPE", "float"}, {"USE_SUBGROUP_ADD", "0"}, {"USE_SUBGROUP_CLUSTERED", "0"}})); + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); @@ -987,8 +1033,16 @@ void process_shaders() { string_to_spv(name + (unroll ? "_unroll" : ""), "conv2d_mm.comp", defines); #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (unroll) { - defines["COOPMAT2"] = "1"; - string_to_spv(name, "conv2d_mm.comp", defines, true, false, true); + auto cm2_defines = defines; + cm2_defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv2d_mm.comp", cm2_defines, true, false, true); + } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (unroll) { + auto cm1_defines = defines; + cm1_defines["COOPMAT"] = "1"; + string_to_spv(name, "conv2d_mm.comp", cm1_defines, true, true, false); } #endif } @@ -1004,8 +1058,8 @@ void process_shaders() { string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); - string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); - string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); + string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "0"}}); + string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"ADD_RMS" , "1"}}); string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); @@ -1058,8 +1112,8 @@ void write_output_files() { std::string suffixes[2] = {"_f32", "_f16"}; for (std::string op : {"add", "sub", "mul", "div", "add_rms"}) { - hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; - hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; + hdr << "extern const void * " << op << "_data[2][2][2];\n"; + hdr << "extern const uint64_t " << op << "_len[2][2][2];\n"; std::string op_file = op == "add_rms" ? "add.comp" : std::string(op) + ".comp"; if (basename(input_filepath) != op_file) { @@ -1067,8 +1121,8 @@ void write_output_files() { } std::stringstream data = make_generic_stringstream(); std::stringstream len = make_generic_stringstream(); - data << "const void * " << op << "_data[2][2][2][2] = "; - len << "const uint64_t " << op << "_len[2][2][2][2] = "; + data << "const void * " << op << "_data[2][2][2] = "; + len << "const uint64_t " << op << "_len[2][2][2] = "; for (uint32_t t0 = 0; t0 < 2; ++t0) { if (t0 == 0) { data << "{"; @@ -1084,20 +1138,10 @@ void write_output_files() { data << "{"; len << "{"; } - for (uint32_t rte = 0; rte < 2; ++rte) { - if (rte == 0) { - data << "{"; - len << "{"; - } - data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); - len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); - data << "_data,"; - len << "_len,"; - if (rte == 1) { - data << "}, "; - len << "}, "; - } - } + data << op << suffixes[t0] << suffixes[t1] << suffixes[t2]; + len << op << suffixes[t0] << suffixes[t1] << suffixes[t2]; + data << "_data,"; + len << "_len,"; if (t2 == 1) { data << "}, "; len << "}, "; diff --git a/ggml/src/ggml-webgpu/CMakeLists.txt b/ggml/src/ggml-webgpu/CMakeLists.txt index 3ccce58aa39..1503a1ef8ba 100644 --- a/ggml/src/ggml-webgpu/CMakeLists.txt +++ b/ggml/src/ggml-webgpu/CMakeLists.txt @@ -10,8 +10,11 @@ file(MAKE_DIRECTORY ${SHADER_OUTPUT_DIR}) message(STATUS "Shader output dir: ${SHADER_OUTPUT_DIR}") -# Find all WGSL files -file(GLOB WGSL_SHADER_FILES "${SHADER_DIR}/*.wgsl") +# Find all WGSL sources +file(GLOB WGSL_SHADER_FILES + "${SHADER_DIR}/*.wgsl" + "${SHADER_DIR}/*.tmpl" +) # Generate the header using a Python script add_custom_command( diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 7fdb4c8c8da..6f877f15ce9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1,169 +1,3293 @@ #ifndef GGML_WEBGPU_SHADER_LIB_HPP #define GGML_WEBGPU_SHADER_LIB_HPP +#include "ggml-impl.h" +#include "ggml-wgsl-shaders.hpp" #include "ggml.h" #include "pre_wgsl.hpp" +#include <webgpu/webgpu_cpp.h> + +#include <algorithm> +#include <memory> #include <string> +#include <unordered_map> #include <vector> #define GGML_WEBGPU_F16_SIZE_BYTES 2 #define GGML_WEBGPU_F32_SIZE_BYTES 4 +#define GGML_WEBGPU_I32_SIZE_BYTES 4 #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u +#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN 20u +#define GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE 32u +#define GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE 64u #define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u // Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing. #define GGML_WEBGPU_KV_SEQ_PAD 256u -struct ggml_webgpu_flash_attn_shader_lib_context { - ggml_type kv_type; +#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u + +// Matrix multiplication parameters + +// Register tiling parameters +#define WEBGPU_MUL_MAT_TILE_M 4 +#define WEBGPU_MUL_MAT_TILE_N 4 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_REG_TILE_K_FLOAT 8 +#define WEBGPU_MUL_MAT_REG_TILE_K_QUANT 32 + +// Subgroup matrix parameters +// The number of subgroups in the M dimension +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +// The number of subgroups in the N dimension +#define WEBGPU_MUL_MAT_SUBGROUP_N 4 +// The number of subgroup matrices each subgroup accumulates over +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 +#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT 32 +#define WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT 32 + +// Matrix-vector multiplication parameters +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 + +#define WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG 4 +#define WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG 4 +#define WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG 4 + +// default size for reg-tile matrix multiplication +#define WEBGPU_MUL_MAT_WG_SIZE 256 + +// Same hash combine function as in boost +template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { + seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +// Calculates base address of a tensor ignoring the fake base pointer +inline uintptr_t ggml_webgpu_tensor_addr(const ggml_tensor * tensor) { + const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor; + return (uintptr_t) base_tensor->data + tensor->view_offs; +} + +inline bool ggml_webgpu_tensor_equal(const ggml_tensor * a, const ggml_tensor * b) { + return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) == ggml_webgpu_tensor_addr(b); +} + +inline bool ggml_webgpu_tensor_overlap(const ggml_tensor * a, const ggml_tensor * b) { + return a->buffer == b->buffer && ggml_webgpu_tensor_addr(a) < ggml_webgpu_tensor_addr(b) + ggml_nbytes(b) && + ggml_webgpu_tensor_addr(b) < ggml_webgpu_tensor_addr(a) + ggml_nbytes(a); +} + +struct ggml_webgpu_shader_lib_context { + ggml_tensor * src0; + ggml_tensor * src1; + ggml_tensor * src2; + ggml_tensor * src3; + ggml_tensor * src4; + ggml_tensor * src5; + ggml_tensor * dst; + + uint32_t max_wg_size; + size_t wg_mem_limit_bytes = 0; + bool supports_subgroups = false; + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + uint32_t min_subgroup_size = 0; + uint32_t max_subgroup_size = 0; + bool supports_dot_product = false; + std::string vendor; +}; + +struct webgpu_pipeline { + wgpu::ComputePipeline pipeline; + std::string name; + std::shared_ptr<void> context = nullptr; +}; + +struct ggml_webgpu_generic_shader_decisions { + uint32_t wg_size = 0; + bool inplace = false; +}; + +struct ggml_webgpu_binary_shader_decisions { + uint32_t wg_size = 0; + bool inplace = false; + bool overlap = false; + bool src_overlap = false; +}; + +struct ggml_webgpu_processed_shader { + std::string wgsl; + std::string variant; + std::shared_ptr<void> decisions; +}; + +struct ggml_webgpu_ssm_conv_shader_decisions { + uint32_t block_size; + uint32_t tokens_per_wg; +}; + +struct ggml_webgpu_ssm_scan_pipeline_key { + int type; + int d_state; + bool xbc_overlap; + + bool operator==(const ggml_webgpu_ssm_scan_pipeline_key & other) const { + return type == other.type && d_state == other.d_state && xbc_overlap == other.xbc_overlap; + } +}; + +struct ggml_webgpu_ssm_scan_pipeline_key_hash { + size_t operator()(const ggml_webgpu_ssm_scan_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.d_state); + ggml_webgpu_hash_combine(seed, key.xbc_overlap); + return seed; + } +}; + +struct ggml_webgpu_ssm_scan_shader_decisions { + uint32_t wg_size; + uint32_t tokens_per_tile; + bool xbc_overlap = false; +}; + +/** Argsort **/ + +struct ggml_webgpu_argsort_shader_lib_context { + uint32_t max_wg_size; + size_t wg_mem_limit_bytes; + int32_t order; +}; + +/** Set Rows **/ + +struct ggml_webgpu_set_rows_pipeline_key { + int dst_type; + int vec4; + int i64_idx; + int pair_blocks; + + bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { + return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx && + pair_blocks == other.pair_blocks; + } +}; + +struct ggml_webgpu_set_rows_pipeline_key_hash { + size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.vec4); + ggml_webgpu_hash_combine(seed, key.i64_idx); + ggml_webgpu_hash_combine(seed, key.pair_blocks); + return seed; + } +}; + +struct ggml_webgpu_set_rows_shader_decisions { + bool vec4; + bool i64_idx; + bool pair_blocks; + uint32_t wg_size; +}; + +/** Set **/ + +struct ggml_webgpu_set_pipeline_key { + ggml_type type; + bool inplace; + + bool operator==(const ggml_webgpu_set_pipeline_key & other) const { + return type == other.type && inplace == other.inplace; + } +}; + +struct ggml_webgpu_set_pipeline_key_hash { + size_t operator()(const ggml_webgpu_set_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + +/** Get Rows **/ + +struct ggml_webgpu_get_rows_pipeline_key { + ggml_type src_type; + int vectorized; + + bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const { + return src_type == other.src_type && vectorized == other.vectorized; + } +}; + +struct ggml_webgpu_get_rows_pipeline_key_hash { + size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + +/** Row Norm **/ + +struct ggml_webgpu_row_norm_pipeline_key { + ggml_op op; + ggml_type src_type; + ggml_type dst_type; + bool inplace; + + bool operator==(const ggml_webgpu_row_norm_pipeline_key & other) const { + return op == other.op && src_type == other.src_type && dst_type == other.dst_type && inplace == other.inplace; + } +}; + +struct ggml_webgpu_row_norm_pipeline_key_hash { + size_t operator()(const ggml_webgpu_row_norm_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + +/** RMS_NORM + MUL **/ + +struct ggml_webgpu_rms_norm_mul_pipeline_key { + bool inplace; // rn_src == dst + bool overlap; // mul_src == dst + bool src_overlap; // rn_src == mul_src + + bool operator==(const ggml_webgpu_rms_norm_mul_pipeline_key & other) const { + return inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; + } +}; + +struct ggml_webgpu_rms_norm_mul_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rms_norm_mul_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); + ggml_webgpu_hash_combine(seed, key.src_overlap); + return seed; + } +}; + +struct ggml_webgpu_rms_norm_mul_shader_decisions { + uint32_t wg_size = 0; + bool inplace = false; + bool overlap = false; + bool src_overlap = false; +}; + +/** Pad **/ +struct ggml_webgpu_pad_pipeline_key { + bool circular; + + bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; } +}; + +struct ggml_webgpu_pad_pipeline_key_hash { + size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.circular); + return seed; + } +}; + +/** Solve Tri **/ +struct ggml_webgpu_solve_tri_pipeline_key { + int type; + int n; + int k; + + bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const { + return type == other.type && n == other.n && k == other.k; + } +}; + +struct ggml_webgpu_solve_tri_pipeline_key_hash { + size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.n); + ggml_webgpu_hash_combine(seed, key.k); + return seed; + } +}; + +/** SSM Conv **/ +struct ggml_webgpu_ssm_conv_pipeline_key { + int type; + int vectorized; + + bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const { + return type == other.type && vectorized == other.vectorized; + } +}; + +/** CONV 2D */ +struct ggml_webgpu_conv2d_pipeline_key { + ggml_type weight_type; + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_conv2d_pipeline_key & other) const { + return weight_type == other.weight_type && input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_conv2d_pipeline_key_hash { + size_t operator()(const ggml_webgpu_conv2d_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.weight_type); + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + +/** Im2Col **/ +struct ggml_webgpu_im2col_pipeline_key { + ggml_type input_type; + ggml_type output_type; + + bool operator==(const ggml_webgpu_im2col_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type; + } +}; + +struct ggml_webgpu_im2col_pipeline_key_hash { + size_t operator()(const ggml_webgpu_im2col_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + return seed; + } +}; + +/** Gated Delta Net **/ +struct ggml_webgpu_gated_delta_net_pipeline_key { + int type; + int s_v; + int kda; + + bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const { + return type == other.type && s_v == other.s_v && kda == other.kda; + } +}; + +struct ggml_webgpu_gated_delta_net_pipeline_key_hash { + size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.s_v); + ggml_webgpu_hash_combine(seed, key.kda); + return seed; + } +}; + +struct ggml_webgpu_ssm_conv_pipeline_key_hash { + size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + +/** Scale **/ + +struct ggml_webgpu_scale_pipeline_key { + int inplace; + + bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; } +}; + +struct ggml_webgpu_scale_pipeline_key_hash { + size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + +/** Upscale **/ + +struct ggml_webgpu_upscale_pipeline_key { + ggml_type input_type; + ggml_type output_type; + uint32_t base_mode; + bool antialias; + + bool operator==(const ggml_webgpu_upscale_pipeline_key & other) const { + return input_type == other.input_type && output_type == other.output_type && base_mode == other.base_mode && + antialias == other.antialias; + } +}; + +struct ggml_webgpu_upscale_pipeline_key_hash { + size_t operator()(const ggml_webgpu_upscale_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.input_type); + ggml_webgpu_hash_combine(seed, key.output_type); + ggml_webgpu_hash_combine(seed, key.base_mode); + ggml_webgpu_hash_combine(seed, key.antialias); + return seed; + } +}; + +/** Concat **/ + +struct ggml_webgpu_concat_pipeline_key { + int type; + bool src_overlap; + + bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { + return type == other.type && src_overlap == other.src_overlap; + } +}; + +struct ggml_webgpu_concat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.src_overlap); + return seed; + } +}; + +/** Repeat **/ + +struct ggml_webgpu_repeat_pipeline_key { + int type; + + bool operator==(const ggml_webgpu_repeat_pipeline_key & other) const { return type == other.type; } +}; + +struct ggml_webgpu_repeat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_repeat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + return seed; + } +}; + +/** Binary **/ + +struct ggml_webgpu_binary_pipeline_key { + int type; + int op; + bool inplace; + bool overlap; + bool src_overlap; + + bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && + src_overlap == other.src_overlap; + } +}; + +struct ggml_webgpu_binary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); + ggml_webgpu_hash_combine(seed, key.src_overlap); + return seed; + } +}; + +/* Add_Id */ + +struct ggml_webgpu_add_id_pipeline_key { + bool inplace; + + bool operator==(const ggml_webgpu_add_id_pipeline_key & other) const { return inplace == other.inplace; } +}; + +struct ggml_webgpu_add_id_pipeline_key_hash { + size_t operator()(const ggml_webgpu_add_id_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + +/** Unary **/ + +struct ggml_webgpu_unary_pipeline_key { + int type; + int op; + bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella + bool inplace; + ggml_tri_type ttype; // only used for GGML_OP_TRI + + bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { + return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace && + ttype == other.ttype; + } +}; + +struct ggml_webgpu_unary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.is_unary); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.ttype); + return seed; + } +}; + +/** FlashAttention */ + +struct ggml_webgpu_flash_attn_common_pipeline_key { + ggml_type q_type; + ggml_type k_type; + ggml_type v_type; + ggml_type dst_type; uint32_t head_dim_qk; uint32_t head_dim_v; bool kv_direct; + bool kv_overlap; bool has_mask; bool has_sinks; bool uses_logit_softcap; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; - size_t wg_mem_limit_bytes; - uint32_t max_subgroup_size; + + bool operator==(const ggml_webgpu_flash_attn_common_pipeline_key & other) const { + return q_type == other.q_type && k_type == other.k_type && v_type == other.v_type && + dst_type == other.dst_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && + kv_direct == other.kv_direct && kv_overlap == other.kv_overlap && has_mask == other.has_mask && + has_sinks == other.has_sinks && uses_logit_softcap == other.uses_logit_softcap; + } }; -struct ggml_webgpu_flash_attn_shader_decisions { - uint32_t q_tile = 0; - uint32_t kv_tile = 0; - uint32_t wg_size = 0; +inline void ggml_webgpu_flash_attn_hash_common_pipeline_key(size_t & seed, + const ggml_webgpu_flash_attn_common_pipeline_key & key) { + ggml_webgpu_hash_combine(seed, key.q_type); + ggml_webgpu_hash_combine(seed, key.k_type); + ggml_webgpu_hash_combine(seed, key.v_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.head_dim_qk); + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.kv_direct); + ggml_webgpu_hash_combine(seed, key.kv_overlap); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sinks); + ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); +} + +struct ggml_webgpu_flash_attn_vec_pipeline_key { + ggml_webgpu_flash_attn_common_pipeline_key common; + + bool operator==(const ggml_webgpu_flash_attn_vec_pipeline_key & other) const { return common == other.common; } }; -struct ggml_webgpu_processed_shader { - std::string wgsl; - std::string variant; - ggml_webgpu_flash_attn_shader_decisions decisions; +struct ggml_webgpu_flash_attn_vec_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_vec_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common); + return seed; + } }; -// This is exposed because it's necessary in supports_op -inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, - uint32_t kv_tile, - uint32_t head_dim_qk, - uint32_t head_dim_v, - bool has_mask, - bool kv_direct) { - const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); - size_t f16_elems = 0; - size_t f32_elems = 0; - f16_elems += q_tile * head_dim_qk; // q_shmem - if (!kv_direct) { - f16_elems += kv_tile * max_head_dim; // kv_shmem +struct ggml_webgpu_flash_attn_pipeline_key { + ggml_webgpu_flash_attn_common_pipeline_key common; + bool use_sg_matrix; + + bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const { + return common == other.common && use_sg_matrix == other.use_sg_matrix; } - f16_elems += q_tile * head_dim_v; // o_shmem - if (has_mask) { - f16_elems += q_tile * kv_tile; // mask_shmem +}; + +struct ggml_webgpu_flash_attn_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_flash_attn_hash_common_pipeline_key(seed, key.common); + ggml_webgpu_hash_combine(seed, key.use_sg_matrix); + return seed; } - f16_elems += q_tile * kv_tile; // inter_shmem - f32_elems += q_tile; // row_max_shmem - f32_elems += q_tile; // exp_sum_shmem - return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; +}; + +struct ggml_webgpu_flash_attn_vec_decisions { + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; + +struct ggml_webgpu_flash_attn_decisions { + bool use_sg_matrix = false; + uint32_t q_tile = 0; + uint32_t kv_tile = 0; + uint32_t wg_size = 0; +}; + +inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH = 4u; +inline constexpr uint32_t GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE = 4u; + +inline size_t ggml_webgpu_flash_attn_tensor_offset(const ggml_tensor * tensor) { + constexpr uintptr_t ptr_base_addr = 0x1000u; + const ggml_tensor * base = tensor->view_src != nullptr ? tensor->view_src : tensor; + return reinterpret_cast<uintptr_t>(base->data) - ptr_base_addr + tensor->view_offs; } -static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!context.kv_direct) { - bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v); - } - if (context.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, size_t storage_offset_alignment) { + const uint32_t offset_elems = + (uint32_t) ((ggml_webgpu_flash_attn_tensor_offset(K) & (storage_offset_alignment - 1)) / + ggml_type_size(K->type)); + return offset_elems % GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH == 0u; +} + +inline bool ggml_webgpu_flash_attn_float_vec4_aligned(const ggml_tensor * K, + const ggml_tensor * V, + size_t storage_offset_alignment) { + return ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment) && + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); +} + +inline bool ggml_webgpu_flash_attn_kv_direct(const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V, + uint32_t kv_direct_align) { + return K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && (Q->ne[0] % kv_direct_align == 0) && + (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); +} + +inline ggml_webgpu_flash_attn_common_pipeline_key ggml_webgpu_flash_attn_make_common_pipeline_key( + const ggml_webgpu_shader_lib_context & context, + uint32_t kv_direct_align) { + ggml_webgpu_flash_attn_common_pipeline_key key = {}; + key.q_type = context.src0->type; + key.k_type = context.src1->type; + key.v_type = context.src2->type; + key.dst_type = context.dst->type; + key.head_dim_qk = (uint32_t) context.src0->ne[0]; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.kv_direct = ggml_webgpu_flash_attn_kv_direct(context.src0, context.src1, context.src2, kv_direct_align); + key.kv_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src2); + key.has_mask = context.src3 != nullptr; + key.has_sinks = context.src4 != nullptr; + key.uses_logit_softcap = ggml_get_op_params_f32(context.dst, 2) != 0.0f; + return key; } -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_shader_lib_context & context) { +inline std::vector<std::string> ggml_webgpu_flash_attn_common_defines( + const ggml_webgpu_flash_attn_common_pipeline_key & key, + std::string & variant, + uint32_t q_tile, + uint32_t kv_tile, + uint32_t wg_size) { std::vector<std::string> defines; - std::string variant = "flash_attn"; - switch (context.kv_type) { + switch (key.k_type) { + case GGML_TYPE_F32: + defines.push_back("K_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("K_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("K_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("K_Q8_0"); + break; + default: + GGML_ABORT("Unsupported K type for flash attention shader"); + } + variant += std::string("_k") + ggml_type_name(key.k_type); + + switch (key.v_type) { case GGML_TYPE_F32: - defines.push_back("KV_F32"); + defines.push_back("V_F32"); break; case GGML_TYPE_F16: - defines.push_back("KV_F16"); + defines.push_back("V_F16"); break; case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); + defines.push_back("V_Q4_0"); break; case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); + defines.push_back("V_Q8_0"); + break; + default: + GGML_ABORT("Unsupported V type for flash attention shader"); + } + variant += std::string("_v") + ggml_type_name(key.v_type); + + switch (key.q_type) { + case GGML_TYPE_F32: + defines.push_back("Q_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("Q_F16"); + break; + default: + GGML_ABORT("Unsupported Q type for flash attention shader"); + } + variant += std::string("_q") + ggml_type_name(key.q_type); + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); break; default: - GGML_ABORT("Unsupported KV type for flash attention shader"); + GGML_ABORT("Unsupported dst type for flash attention shader"); } - variant += std::string("_") + ggml_type_name(context.kv_type); + variant += std::string("_dst") + ggml_type_name(key.dst_type); - if (context.has_mask) { + if (key.has_mask) { defines.push_back("MASK"); variant += "_mask"; } - if (context.has_sinks) { + if (key.has_sinks) { defines.push_back("SINKS"); variant += "_sinks"; } - if (context.uses_logit_softcap) { + if (key.uses_logit_softcap) { defines.push_back("LOGIT_SOFTCAP"); variant += "_lgsc"; } - - if (context.kv_direct) { + if (key.kv_direct) { defines.push_back("KV_DIRECT"); variant += "_kvdirect"; } + if (key.kv_overlap) { + defines.push_back("KV_OVERLAP"); + variant += "_kv_overlap"; + } - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.head_dim_qk); + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.head_dim_v); + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); - // For now these are not part of the variant name - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - // Add chosen Q/KV tile sizes - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - // Avoids having to use bounds-checks and decreasing performance for direct KV loads - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; + if (ggml_is_quantized(key.k_type) || ggml_is_quantized(key.v_type)) { + defines.push_back("U32_DEQUANT_HELPERS"); + } + + return defines; +} + +struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key { + uint32_t head_dim_v; + uint32_t wg_size; + ggml_type dst_type; +}; + +struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.head_dim_v); + ggml_webgpu_hash_combine(seed, key.wg_size); + ggml_webgpu_hash_combine(seed, key.dst_type); + return seed; + } +}; + +inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs, + const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) { + return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size && lhs.dst_type == rhs.dst_type; +} + +struct ggml_webgpu_flash_attn_blk_pipeline_key { + uint32_t kv_tile; + + bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const { return kv_tile == other.kv_tile; } +}; + +struct ggml_webgpu_flash_attn_blk_pipeline_key_hash { + size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.kv_tile); + return seed; + } +}; + +// Note: this will slightly overestimate memory usage for vec path +// since row_max and exp_sum shmem are not needed. +inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, + uint32_t kv_tile, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v); + size_t f16_elems = 0; + size_t f32_elems = 0; + + f32_elems += q_tile * head_dim_qk; // q_shmem + if (!kv_direct) { + f32_elems += kv_tile * max_head_dim; // kv_shmem + } + f32_elems += q_tile * head_dim_v; // o_shmem + if (has_mask) { + f32_elems += q_tile * kv_tile; // mask_shmem + } + f32_elems += q_tile * kv_tile; // inter_shmem + f32_elems += q_tile; // row_max_shmem + f32_elems += q_tile; // exp_sum_shmem + return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; +} + +inline uint32_t ggml_webgpu_flash_attn_max_kv_tile(size_t limit_bytes, + uint32_t q_tile, + uint32_t kv_granularity, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const size_t base_q_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 0, head_dim_qk, head_dim_v, has_mask, kv_direct); + if (limit_bytes <= base_q_bytes) { + return 0; + } + const size_t one_kv_bytes = + ggml_webgpu_flash_attn_wg_mem_bytes(q_tile, 1, head_dim_qk, head_dim_v, has_mask, kv_direct); + const size_t bytes_per_kv = one_kv_bytes - base_q_bytes; + if (bytes_per_kv == 0) { + return 0; + } + const size_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (uint32_t) ((max_kv_tile / kv_granularity) * kv_granularity); +} + +inline uint32_t ggml_webgpu_flash_attn_get_vec_kv_tile(size_t wg_mem_limit_bytes, + uint32_t head_dim_qk, + uint32_t head_dim_v, + bool has_mask, + bool kv_direct) { + const uint32_t max_kv_tile = + ggml_webgpu_flash_attn_max_kv_tile(wg_mem_limit_bytes, 1u, 1u, head_dim_qk, head_dim_v, has_mask, kv_direct); + GGML_ASSERT(max_kv_tile > 0); + + uint32_t kv_tile = std::min(GGML_WEBGPU_FLASH_ATTN_VEC_MAX_KV_TILE, max_kv_tile); + if (kv_direct) { + kv_tile = std::min(kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= 1u; } } - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + return kv_tile; +} - // workgroup size - uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); +inline bool ggml_webgpu_flash_attn_can_use_subgroup_matrix_path(bool supports_subgroup_matrix, + uint32_t sg_mat_k, + uint32_t sg_mat_n, + const ggml_tensor * Q, + const ggml_tensor * V) { + return supports_subgroup_matrix && Q->ne[0] % sg_mat_k == 0 && V->ne[0] % sg_mat_n == 0; +} - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); +/** Matrix Multiplication **/ + +struct ggml_webgpu_mul_mat_vec_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; + bool use_mmvq; + + bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + use_mmvq == other.use_mmvq; + } +}; + +struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_mmvq); + return seed; + } +}; + +struct ggml_webgpu_mul_mat_vec_shader_decisions { + uint32_t wg_size; + uint32_t outputs_per_wg; + uint32_t vec_size; +}; + +struct ggml_webgpu_quantize_q8_pipeline_key { + ggml_type src0_type; + + bool operator==(const ggml_webgpu_quantize_q8_pipeline_key & other) const { return src0_type == other.src0_type; } +}; + +struct ggml_webgpu_quantize_q8_pipeline_key_hash { + size_t operator()(const ggml_webgpu_quantize_q8_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + return seed; + } +}; + +struct ggml_webgpu_mul_mat_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; + int use_subgroup_matrix; - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - result.decisions.q_tile = q_tile; - result.decisions.kv_tile = kv_tile; - result.decisions.wg_size = wg_size; - return result; + bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + use_subgroup_matrix == other.use_subgroup_matrix; + } +}; + +struct ggml_webgpu_mul_mat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix); + return seed; + } +}; + +struct ggml_webgpu_mul_mat_shader_decisions { + uint32_t tile_k; + uint32_t wg_size_m; + uint32_t wg_size_n; + uint32_t wg_size; + uint32_t outputs_per_wg; + int use_subgroup_matrix; + + uint32_t tile_m; + uint32_t tile_n; + + // Subgroup matrix parameters + uint32_t subgroup_m; + uint32_t subgroup_n; + uint32_t subgroup_matrix_m; + uint32_t subgroup_matrix_n; + + uint32_t mul_mat_wg_size; +}; + +/** MUL_MAT_ID **/ + +struct ggml_webgpu_mul_mat_id_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + uint32_t n_experts; + int vectorized; + + bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts && + vectorized == other.vectorized; + } +}; + +struct ggml_webgpu_mul_mat_id_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_id_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.n_experts); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + +/** Cpy **/ + +struct ggml_webgpu_cpy_pipeline_key { + ggml_type src_type; + ggml_type dst_type; + + bool operator==(const ggml_webgpu_cpy_pipeline_key & other) const { + return src_type == other.src_type && dst_type == other.dst_type; + } +}; + +struct ggml_webgpu_cpy_pipeline_key_hash { + size_t operator()(const ggml_webgpu_cpy_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.dst_type); + return seed; + } +}; + +/** Glu **/ + +struct ggml_webgpu_glu_pipeline_key { + ggml_glu_op glu_op; + ggml_type type; + bool split; + + bool operator==(const ggml_webgpu_glu_pipeline_key & other) const { + return glu_op == other.glu_op && type == other.type && split == other.split; + } +}; + +struct ggml_webgpu_glu_pipeline_key_hash { + size_t operator()(const ggml_webgpu_glu_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.glu_op); + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.split); + return seed; + } +}; + +/** Rope **/ + +struct ggml_webgpu_rope_pipeline_key { + ggml_type type; + bool inplace; + bool has_ff; + + bool operator==(const ggml_webgpu_rope_pipeline_key & other) const { + return type == other.type && inplace == other.inplace && has_ff == other.has_ff; + } +}; + +struct ggml_webgpu_rope_pipeline_key_hash { + size_t operator()(const ggml_webgpu_rope_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.has_ff); + return seed; + } +}; + +/** SoftMax **/ + +struct ggml_webgpu_soft_max_pipeline_key { + ggml_type mask_type; + bool has_mask; + bool has_sink; + bool inplace; + + bool operator==(const ggml_webgpu_soft_max_pipeline_key & other) const { + return mask_type == other.mask_type && has_mask == other.has_mask && has_sink == other.has_sink && + inplace == other.inplace; + } +}; + +struct ggml_webgpu_soft_max_pipeline_key_hash { + size_t operator()(const ggml_webgpu_soft_max_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.mask_type); + ggml_webgpu_hash_combine(seed, key.has_mask); + ggml_webgpu_hash_combine(seed, key.has_sink); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + +/** MMVQ **/ + +inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0, + const ggml_tensor * src1, + bool supports_dot_product, + const std::string & vendor) { + if (src1->ne[1] == 1) { + bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia"; + if (supports_dp4a && supports_dot_product) { + switch (src1->type) { + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + return src0->ne[0] % 4 == 0; + default: + break; + } + break; + default: + break; + } + } + } + return false; } +class ggml_webgpu_shader_lib { + wgpu::Device device; + pre_wgsl::Preprocessor preprocessor; + + std::unordered_map<int, webgpu_pipeline> sum_rows_pipelines; // key is fixed, no variants yet + std::unordered_map<int, webgpu_pipeline> argmax_pipelines; // key is vec4 + std::unordered_map<int, webgpu_pipeline> argsort_pipelines; // key is order + std::unordered_map<int, webgpu_pipeline> argsort_merge_pipelines; // key is order + std::unordered_map<int, webgpu_pipeline> cumsum_pipelines; // key is fixed, no variants yet + std::unordered_map<ggml_webgpu_row_norm_pipeline_key, webgpu_pipeline, ggml_webgpu_row_norm_pipeline_key_hash> + row_norm_pipelines; // op/inplace + + std::unordered_map<ggml_webgpu_get_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_get_rows_pipeline_key_hash> + get_rows_pipelines; // src_type, vectorized + std::unordered_map<ggml_webgpu_unary_pipeline_key, webgpu_pipeline, ggml_webgpu_unary_pipeline_key_hash> + unary_pipelines; // type/op/inplace + std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash> + scale_pipelines; // inplace + std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash> + solve_tri_pipelines; // type + std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash> + ssm_conv_pipelines; // type/vectorized + std::unordered_map<ggml_webgpu_ssm_scan_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_scan_pipeline_key_hash> + ssm_scan_pipelines; // type/d_state + std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key, + webgpu_pipeline, + ggml_webgpu_gated_delta_net_pipeline_key_hash> + gated_delta_net_pipelines; // type/S_v/kda + std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash> + pad_pipelines; // circular/non-circular + std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash> + binary_pipelines; // type/op/inplace/overlap/src_overlap + std::unordered_map<ggml_webgpu_add_id_pipeline_key, webgpu_pipeline, ggml_webgpu_add_id_pipeline_key_hash> + add_id_pipelines; // inplace + std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash> + concat_pipelines; // type + std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash> + repeat_pipelines; // type + std::unordered_map<ggml_webgpu_flash_attn_vec_pipeline_key, + webgpu_pipeline, + ggml_webgpu_flash_attn_vec_pipeline_key_hash> + flash_attn_vec_pipelines; + std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash> + flash_attn_pipelines; + std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key, + webgpu_pipeline, + ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash> + flash_attn_vec_reduce_pipelines; + std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key, + webgpu_pipeline, + ggml_webgpu_flash_attn_blk_pipeline_key_hash> + flash_attn_blk_pipelines; + std::unordered_map<ggml_webgpu_mul_mat_vec_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_vec_pipeline_key_hash> + mul_mat_vec_pipelines; // fast mat-vec (n==1) + std::unordered_map<ggml_webgpu_mul_mat_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_pipeline_key_hash> + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + std::unordered_map<ggml_webgpu_quantize_q8_pipeline_key, webgpu_pipeline, ggml_webgpu_quantize_q8_pipeline_key_hash> + quantize_q8_pipelines; + std::unordered_map<int, webgpu_pipeline> mul_mat_id_gather_pipelines; // key is fixed + std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash> + mul_mat_id_pipelines; // src0_type/src1_type + std::unordered_map<ggml_webgpu_mul_mat_id_pipeline_key, webgpu_pipeline, ggml_webgpu_mul_mat_id_pipeline_key_hash> + mul_mat_id_vec_pipelines; // src0_type/src1_type + + std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash> + set_rows_pipelines; + std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines; + std::unordered_map<ggml_webgpu_cpy_pipeline_key, webgpu_pipeline, ggml_webgpu_cpy_pipeline_key_hash> cpy_pipelines; + std::unordered_map<ggml_webgpu_glu_pipeline_key, webgpu_pipeline, ggml_webgpu_glu_pipeline_key_hash> glu_pipelines; + std::unordered_map<ggml_webgpu_rope_pipeline_key, webgpu_pipeline, ggml_webgpu_rope_pipeline_key_hash> + rope_pipelines; + std::unordered_map<ggml_webgpu_soft_max_pipeline_key, webgpu_pipeline, ggml_webgpu_soft_max_pipeline_key_hash> + soft_max_pipelines; + std::unordered_map<ggml_webgpu_conv2d_pipeline_key, webgpu_pipeline, ggml_webgpu_conv2d_pipeline_key_hash> + conv2d_pipelines; + std::unordered_map<ggml_webgpu_im2col_pipeline_key, webgpu_pipeline, ggml_webgpu_im2col_pipeline_key_hash> + im2col_pipelines; + + std::unordered_map<ggml_webgpu_rms_norm_mul_pipeline_key, + webgpu_pipeline, + ggml_webgpu_rms_norm_mul_pipeline_key_hash> + rms_norm_mul_pipelines; + std::unordered_map<ggml_webgpu_upscale_pipeline_key, webgpu_pipeline, ggml_webgpu_upscale_pipeline_key_hash> + upscale_pipelines; + + public: + ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } + + webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = sum_rows_pipelines.find(1); + if (it != sum_rows_pipelines.end()) { + return it->second; + } + std::vector<std::string> defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_sum_rows, defines); + sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows"); + return sum_rows_pipelines[1]; + } + + webgpu_pipeline get_row_norm_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_row_norm_pipeline_key key = {}; + key.op = context.dst->op; + key.src_type = context.src0->type; + key.dst_type = context.dst->type; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + + auto it = row_norm_pipelines.find(key); + if (it != row_norm_pipelines.end()) { + return it->second; + } + std::vector<std::string> defines; + std::string variant; + + switch (key.op) { + case GGML_OP_RMS_NORM: + defines.push_back("RMS_NORM"); + variant = "rms_norm"; + break; + case GGML_OP_NORM: + defines.push_back("NORM"); + variant = "norm"; + break; + case GGML_OP_L2_NORM: + defines.push_back("L2_NORM"); + variant = "l2_norm"; + break; + default: + GGML_ABORT("Unsupported op for row_norm shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + if (key.src_type == GGML_TYPE_F32) { + defines.push_back("SRC_F32"); + variant += "_src_f32"; + } else if (key.src_type == GGML_TYPE_F16) { + defines.push_back("SRC_F16"); + variant += "_src_f16"; + } + + if (key.dst_type == GGML_TYPE_F32) { + defines.push_back("DST_F32"); + variant += "_dst_f32"; + } else if (key.dst_type == GGML_TYPE_F16) { + defines.push_back("DST_F16"); + variant += "_dst_f16"; + } + + const uint32_t row_norm_wg_size = 128u; + uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + auto processed = preprocessor.preprocess(wgsl_row_norm, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = wg_size; + decisions->inplace = key.inplace; + row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + row_norm_pipelines[key].context = decisions; + return row_norm_pipelines[key]; + } + + webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool vec4 = context.src0->ne[0] % 4 == 0; + + auto it = argmax_pipelines.find(vec4); + if (it != argmax_pipelines.end()) { + return it->second; + } + std::string variant = "argmax"; + std::vector<std::string> defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + if (vec4) { + defines.push_back("VEC4"); + variant += "_vec4"; + } + + auto processed = preprocessor.preprocess(wgsl_argmax, defines); + argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant); + return argmax_pipelines.at(vec4); + } + + webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool quantized = ggml_is_quantized(context.dst->type); + ggml_webgpu_set_rows_pipeline_key key = {}; + key.dst_type = context.dst->type; + key.vec4 = + (context.dst->type == GGML_TYPE_F32 || context.dst->type == GGML_TYPE_F16) && context.src0->ne[0] % 4 == 0; + key.i64_idx = context.src1->type == GGML_TYPE_I64; + key.pair_blocks = quantized && ((context.src0->ne[0] / ggml_blck_size(context.dst->type)) % 2 == 0); + + auto it = set_rows_pipelines.find(key); + if (it != set_rows_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "set_rows"; + + switch (context.dst->type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_dstf32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_dstf16"; + break; + case GGML_TYPE_Q8_0: + defines.push_back("DST_Q8_0"); + variant += "_dstq8_0"; + break; + case GGML_TYPE_Q4_0: + defines.push_back("DST_Q4_0"); + variant += "_dstq4_0"; + break; + default: + GGML_ABORT("Unsupported dst type for set_rows shader"); + } + + if (key.vec4) { + defines.push_back("VEC4"); + variant += "_vec4"; + } + if (key.i64_idx) { + defines.push_back("I64_IDX"); + variant += "_i64idx"; + } + if (key.pair_blocks) { + defines.push_back("PAIR_BLOCKS"); + variant += "_pair_blocks"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + const auto & shader_source = quantized ? wgsl_set_rows_quant : wgsl_set_rows; + auto processed = preprocessor.preprocess(shader_source, defines); + auto decisions = std::make_shared<ggml_webgpu_set_rows_shader_decisions>(); + decisions->vec4 = key.vec4; + decisions->i64_idx = key.i64_idx; + decisions->pair_blocks = key.pair_blocks; + decisions->wg_size = context.max_wg_size; + set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + set_rows_pipelines[key].context = decisions; + return set_rows_pipelines[key]; + } + + webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_set_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + + auto it = set_pipelines.find(key); + if (it != set_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "set"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported type for set shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_set, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + set_pipelines[key] = pipeline; + return set_pipelines[key]; + } + + webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = cumsum_pipelines.find(1); + if (it != cumsum_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_cumsum, defines); + cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum"); + return cumsum_pipelines[1]; + } + + webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool is_top_k = context.dst->op == GGML_OP_TOP_K; + // ascending order is 0, descending order is 1 + const int32_t order = + is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0); + + auto it = argsort_pipelines.find(order); + if (it != argsort_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "argsort"; + defines.push_back(std::string("ORDER=") + std::to_string(order)); + variant += std::string("_order") + std::to_string(order); + uint32_t wg_size = 1; + while (wg_size * 2 <= context.max_wg_size && + wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) { + wg_size *= 2; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + auto processed = preprocessor.preprocess(wgsl_argsort, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = wg_size; + argsort_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant); + argsort_pipelines[order].context = decisions; + return argsort_pipelines[order]; + } + + webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool is_top_k = context.dst->op == GGML_OP_TOP_K; + // ascending order is 0, descending order is 1 + const int32_t order = + is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0); + + auto it = argsort_merge_pipelines.find(order); + if (it != argsort_merge_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "argsort_merge"; + defines.push_back(std::string("ORDER=") + std::to_string(order)); + variant += std::string("_order") + std::to_string(order); + uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + auto processed = preprocessor.preprocess(wgsl_argsort_merge, defines); + argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant); + return argsort_merge_pipelines[order]; + } + + webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; + ggml_webgpu_get_rows_pipeline_key key = {}; + key.src_type = context.src0->type; + key.vectorized = (int) vectorized; + + auto it = get_rows_pipelines.find(key); + if (it != get_rows_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "get_rows"; + + const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type); + const char * type_str = type_traits->type_name; + + switch (key.src_type) { + case GGML_TYPE_F32: + defines.push_back("FLOAT_PARALLEL"); + if (key.vectorized) { + defines.push_back("F32_VEC"); + defines.push_back("SRC_TYPE=vec4<f32>"); + defines.push_back("DST_TYPE=vec4<f32>"); + defines.push_back("BLOCK_SIZE=4u"); + } else { + defines.push_back("F32"); + defines.push_back("SRC_TYPE=f32"); + defines.push_back("DST_TYPE=f32"); + defines.push_back("BLOCK_SIZE=1u"); + } + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("FLOAT_PARALLEL"); + defines.push_back("F16"); + defines.push_back("SRC_TYPE=f16"); + defines.push_back("DST_TYPE=f32"); + defines.push_back("BLOCK_SIZE=1u"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("FLOAT_PARALLEL"); + defines.push_back("I32"); + defines.push_back("SRC_TYPE=i32"); + defines.push_back("DST_TYPE=i32"); + defines.push_back("BLOCK_SIZE=1u"); + variant += "_i32"; + break; + default: + { + std::string type_upper = type_str; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + switch (key.src_type) { + case GGML_TYPE_Q1_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + { + // Quantized types using u32 buffers for portability. + defines.push_back("SRC_TYPE=u32"); + defines.push_back("U32_DEQUANT_HELPERS"); + break; + } + default: + { + defines.push_back(std::string("SRC_TYPE=") + type_str); + } + } + + defines.push_back("BYTE_HELPERS"); + defines.push_back(type_upper + "_T"); + defines.push_back(type_upper); + defines.push_back(type_upper + "_SCALE_MIN"); + defines.push_back(type_upper + "_TABLES"); + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_LUT"); + + variant += "_"; + variant += type_str; + + defines.push_back("DST_TYPE=f32"); + + if (key.src_type == GGML_TYPE_Q1_0) { + defines.push_back("BLOCK_SIZE=128u"); + } else if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || + key.src_type == GGML_TYPE_IQ4_NL || key.src_type == GGML_TYPE_MXFP4) { + defines.push_back("BLOCK_SIZE=32u"); + } else if (key.src_type >= GGML_TYPE_Q2_K) { + defines.push_back("BLOCK_SIZE=256u"); + } else { + defines.push_back("BLOCK_SIZE=1u"); + } + break; + } + } + + if (key.vectorized) { + variant += "_vec"; + } + + defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_get_rows, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + get_rows_pipelines[key] = pipeline; + return get_rows_pipelines[key]; + } + + webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_scale_pipeline_key key = {}; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + + auto it = scale_pipelines.find(key); + if (it != scale_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "scale"; + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_scale, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + scale_pipelines[key] = pipeline; + return scale_pipelines[key]; + } + + webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_solve_tri_pipeline_key key = {}; + key.type = context.dst->type; + key.n = (int) context.src0->ne[0]; + key.k = (int) context.src1->ne[0]; + + auto it = solve_tri_pipelines.find(key); + if (it != solve_tri_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "solve_tri"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for solve_tri shader"); + } + + const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size); + const uint32_t k_tile = wg_size; + const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES; + const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row); + + defines.push_back(std::string("N=") + std::to_string(key.n)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("K_TILE=") + std::to_string(k_tile)); + defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n)); + + auto processed = preprocessor.preprocess(wgsl_solve_tri, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + solve_tri_pipelines[key] = pipeline; + return solve_tri_pipelines[key]; + } + + webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_ssm_conv_pipeline_key key = {}; + key.type = context.dst->type; + key.vectorized = context.src1->ne[0] == 4; + + auto it = ssm_conv_pipelines.find(key); + if (it != ssm_conv_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "ssm_conv"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for ssm_conv shader"); + } + + if (key.vectorized) { + defines.push_back("VECTORIZED"); + variant += "_vec4"; + } + + constexpr uint32_t block_size = 32u; + constexpr uint32_t tokens_per_wg = 8u; + + defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u"); + defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u"); + + auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines); + auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>(); + decisions->block_size = block_size; + decisions->tokens_per_wg = tokens_per_wg; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + ssm_conv_pipelines[key] = pipeline; + return ssm_conv_pipelines[key]; + } + + webgpu_pipeline get_ssm_scan_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_ssm_scan_pipeline_key key = {}; + key.type = context.dst->type; + key.d_state = (int) context.src0->ne[0]; + key.xbc_overlap = ggml_webgpu_tensor_overlap(context.src1, context.src4) && + ggml_webgpu_tensor_overlap(context.src1, context.src5); + + auto it = ssm_scan_pipelines.find(key); + if (it != ssm_scan_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "ssm_scan"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for ssm_scan shader"); + } + + const uint32_t wg_size = (uint32_t) key.d_state; + + constexpr uint32_t tokens_per_tile = 4u; + + defines.push_back("WG_SIZE=" + std::to_string(wg_size) + "u"); + defines.push_back("TOKENS_PER_TILE=" + std::to_string(tokens_per_tile) + "u"); + + if (context.supports_subgroups) { + defines.push_back("USE_SUBGROUP_REDUCTION"); + variant += "_sg_reduce"; + } else { + variant += "_wg_reduce"; + } + + if (key.xbc_overlap) { + defines.push_back("XBC_OVERLAP"); + } + + variant += "_d" + std::to_string(key.d_state); + + auto processed = preprocessor.preprocess(wgsl_ssm_scan, defines); + auto decisions = std::make_shared<ggml_webgpu_ssm_scan_shader_decisions>(); + decisions->wg_size = wg_size; + decisions->tokens_per_tile = tokens_per_tile; + decisions->xbc_overlap = key.xbc_overlap; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + ssm_scan_pipelines[key] = pipeline; + return ssm_scan_pipelines[key]; + } + + webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_gated_delta_net_pipeline_key key = {}; + key.type = context.dst->type; + key.s_v = (int) context.src2->ne[0]; + key.kda = context.src3->ne[0] == context.src2->ne[0]; + + auto it = gated_delta_net_pipelines.find(key); + if (it != gated_delta_net_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "gated_delta_net"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for gated_delta_net shader"); + } + + if (key.kda) { + defines.push_back("KDA"); + variant += "_kda"; + } + + defines.push_back("S_V=" + std::to_string(key.s_v) + "u"); + defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u"); + + auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + gated_delta_net_pipelines[key] = pipeline; + return gated_delta_net_pipelines[key]; + } + + webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_pad_pipeline_key key = {}; + key.circular = ggml_get_op_params_i32(context.dst, 8) != 0; + + auto it = pad_pipelines.find(key); + if (it != pad_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "pad"; + + if (key.circular) { + defines.push_back("CIRCULAR"); + variant += "_circular"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_pad, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + pad_pipelines[key] = pipeline; + return pad_pipelines[key]; + } + + webgpu_pipeline get_quantize_q8_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_quantize_q8_pipeline_key key = {}; + key.src0_type = context.src0->type; + + auto it = quantize_q8_pipelines.find(key); + if (it != quantize_q8_pipelines.end()) { + return it->second; + } + const char * shader_src = wgsl_quantize_q8; + std::vector<std::string> defines; + std::string variant = "quantize_q8"; + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + + defines.push_back("SRC1_INNER_TYPE=f32"); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("Q8_1_T"); + + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + + auto processed = preprocessor.preprocess(shader_src, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + quantize_q8_pipelines[key] = pipeline; + return quantize_q8_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_vec_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_mmvq = + ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); + + auto it = mul_mat_vec_pipelines.find(key); + if (it != mul_mat_vec_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "mul_mat_vec"; + const char * shader_src = wgsl_mul_mat_vec; + + // src0 type (matrix row) + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f16"; + break; + default: + { + // Quantized types: use helpers but accumulate in f16 + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + if (key.use_mmvq) { + defines.push_back("LEGACY_QUANTS"); + } + break; + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q4_K: + if (key.use_mmvq) { + defines.push_back("K_QUANTS"); + } + break; + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; + default: + break; + } + break; + } + } + + // src1 type (vector) + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat_vec shader"); + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; + + if (key.src0_type == GGML_TYPE_Q1_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q2_K) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q4_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } + + if (key.use_mmvq) { + defines.push_back("MMVQ"); + defines.push_back("Q8_1_T"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + if (key.vectorized) { + variant += "_vectorized"; + } + + auto processed = preprocessor.preprocess(shader_src, defines); + auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>(); + decisions->wg_size = wg_size; + decisions->outputs_per_wg = outputs_per_wg; + decisions->vec_size = key.vectorized ? 4 : 1; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_vec_pipelines[key] = pipeline; + return mul_mat_vec_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + key.use_subgroup_matrix = context.supports_subgroup_matrix; + + auto it = mul_mat_fast_pipelines.find(key); + if (it != mul_mat_fast_pipelines.end()) { + return it->second; + } + + const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile; + std::vector<std::string> defines; + std::string variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile"; + + // src1 type + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("FLOAT"); + defines.push_back("MUL_ACC_FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("FLOAT"); + defines.push_back("MUL_ACC_FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f16"; + break; + default: + { + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("INIT_SRC0_SHMEM_" + type_upper); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; + default: + break; + } + + variant += std::string("_") + src0_name; + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + const bool is_quant = ggml_is_quantized(context.src0->type); + + uint32_t tile_k; + if (key.use_subgroup_matrix) { + tile_k = is_quant ? WEBGPU_MUL_MAT_SUBGROUP_TILE_K_QUANT : WEBGPU_MUL_MAT_SUBGROUP_TILE_K_FLOAT; + } else { + tile_k = is_quant ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + } + + // Tiles + defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); + defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); + + // Subgroup matrix specifics + if (key.use_subgroup_matrix) { + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); + defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); + defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); + defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u"); + defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u"); + defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u"); + defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u"); + defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u"); + } + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + if (key.vectorized) { + variant += "_vectorized"; + } + + if (!key.use_subgroup_matrix) { + defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); + defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); + } + + auto processed = preprocessor.preprocess(shader_src, defines); + + auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>(); + decisions->tile_k = tile_k; + decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; + decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; + decisions->use_subgroup_matrix = key.use_subgroup_matrix; + if (key.use_subgroup_matrix) { + decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M; + decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N; + decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M; + decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N; + decisions->wg_size = context.max_subgroup_size; + } else { + decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; + decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE; + } + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_fast_pipelines[key] = pipeline; + return mul_mat_fast_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_id_gather_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = mul_mat_id_gather_pipelines.find(1); + if (it != mul_mat_id_gather_pipelines.end()) { + return it->second; + } + std::vector<std::string> defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_mul_mat_id_gather, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, "mul_mat_id_gather"); + pipeline.context = decisions; + mul_mat_id_gather_pipelines[1] = pipeline; + return pipeline; + } + + webgpu_pipeline get_mul_mat_id_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_id_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.n_experts = context.src0->ne[2]; + key.vectorized = (context.src0->ne[0] % 4 == 0 && context.src0->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + + auto it = mul_mat_id_pipelines.find(key); + if (it != mul_mat_id_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "mul_mat_id"; + defines.push_back("MUL_MAT_ID"); + + // src1 type + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f16"; + break; + default: + { + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("INIT_SRC0_SHMEM_" + type_upper); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; + default: + break; + } + + variant += std::string("_") + src0_name; + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + // mul_mat_id is register-tile only. + const uint32_t tile_k = + ggml_is_quantized(context.src0->type) ? WEBGPU_MUL_MAT_REG_TILE_K_QUANT : WEBGPU_MUL_MAT_REG_TILE_K_FLOAT; + + // Tiles + defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); + defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(tile_k) + "u"); + + defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); + defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + if (key.vectorized) { + variant += "_vectorized"; + } + + auto processed = preprocessor.preprocess(wgsl_mul_mat_id, defines); + + auto decisions = std::make_shared<ggml_webgpu_mul_mat_shader_decisions>(); + decisions->tile_k = tile_k; + decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; + decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; + decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; + decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_id_pipelines[key] = pipeline; + return mul_mat_id_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_id_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_id_pipeline_key key = {}; + key.src0_type = context.src0->type; + key.src1_type = context.src1->type; + key.n_experts = context.src0->ne[2]; + key.vectorized = (context.src0->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0; + + auto it = mul_mat_id_vec_pipelines.find(key); + if (it != mul_mat_id_vec_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "mul_mat_id_vec"; + const char * shader_src = wgsl_mul_mat_id_vec; + + // src1 type + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("MUL_ACC_FLOAT"); + variant += "_f16"; + break; + default: + { + // Quantized types: use helpers but accumulate in f16 + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + variant += "_" + src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("U32_DEQUANT_HELPERS"); + defines.push_back("SRC0_INNER_TYPE=u32"); + switch (context.src0->type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + defines.push_back(type_upper + "_GRID"); + break; + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ3_XXS: + defines.push_back(type_upper + "_GRID"); + defines.push_back(type_upper + "_TABLES"); + break; + case GGML_TYPE_MXFP4: + defines.push_back(type_upper + "_LUT"); + break; + default: + break; + } + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_FLOAT_OUTPUTS_PER_WG; + + if (key.src0_type == GGML_TYPE_Q1_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q2_K) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_K_Q_OUTPUTS_PER_WG; + } else if (key.src0_type >= GGML_TYPE_Q4_0) { + outputs_per_wg = WEBGPU_MUL_MAT_VEC_LEGACY_Q_OUTPUTS_PER_WG; + } + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + defines.push_back(context.supports_subgroups ? "USE_SUBGROUP_REDUCTION" : "USE_WORKGROUP_REDUCTION"); + variant += context.supports_subgroups ? "_sg_reduce" : "_wg_reduce"; + if (key.vectorized) { + variant += "_vectorized"; + } + + defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts)); + + auto processed = preprocessor.preprocess(shader_src, defines); + + auto decisions = std::make_shared<ggml_webgpu_mul_mat_vec_shader_decisions>(); + decisions->wg_size = wg_size; + decisions->outputs_per_wg = outputs_per_wg; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_id_vec_pipelines[key] = pipeline; + return mul_mat_id_vec_pipelines[key]; + } + + webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool is_unary = context.dst->op == GGML_OP_UNARY; + const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; + ggml_webgpu_unary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = op; + key.is_unary = is_unary; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst) || context.dst->op == GGML_OP_FILL; + key.ttype = (ggml_tri_type) ggml_get_op_params_i32(context.dst, 0); + + auto it = unary_pipelines.find(key); + if (it != unary_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = + key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op); + defines.push_back(variant); + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for unary shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + if (op == GGML_OP_TRI) { + switch (key.ttype) { + case GGML_TRI_TYPE_LOWER: + defines.push_back("TRI_TYPE_LOWER"); + variant += "_tri_type_lower"; + break; + case GGML_TRI_TYPE_LOWER_DIAG: + defines.push_back("TRI_TYPE_LOWER_DIAG"); + variant += "_tri_type_lower_diag"; + break; + case GGML_TRI_TYPE_UPPER: + defines.push_back("TRI_TYPE_UPPER"); + variant += "_tri_type_upper"; + break; + case GGML_TRI_TYPE_UPPER_DIAG: + defines.push_back("TRI_TYPE_UPPER_DIAG"); + variant += "_tri_upper_diag"; + break; + default: + GGML_ABORT("Unsupported ggml_tri_type for unary shader"); + } + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_unary, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + unary_pipelines[key] = pipeline; + return unary_pipelines[key]; + } + + webgpu_pipeline get_rms_norm_mul_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rms_norm_mul_pipeline_key key = {}; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst); + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); + + auto it = rms_norm_mul_pipelines.find(key); + if (it != rms_norm_mul_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string op_name = "RMS_NORM_MUL"; + std::string variant = op_name; + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; + } else if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rms_norm_mul, defines); + auto pipeline_decisions = std::make_shared<ggml_webgpu_rms_norm_mul_shader_decisions>(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + pipeline_decisions->overlap = key.overlap; + pipeline_decisions->src_overlap = key.src_overlap; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = pipeline_decisions; + rms_norm_mul_pipelines[key] = pipeline; + return rms_norm_mul_pipelines[key]; + } + + webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_binary_pipeline_key key = {}; + key.type = context.dst->type; + key.op = context.dst->op; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + key.overlap = ggml_webgpu_tensor_equal(context.src1, context.dst); + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); + + auto it = binary_pipelines.find(key); + if (it != binary_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string op_name = ggml_op_name((ggml_op) key.op); + std::string variant = op_name; + + defines.push_back(std::string("OP_") + op_name); + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for binary shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; + } else if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_binary, defines); + auto pipeline_decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + pipeline_decisions->overlap = key.overlap; + pipeline_decisions->src_overlap = key.src_overlap; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = pipeline_decisions; + binary_pipelines[key] = pipeline; + return binary_pipelines[key]; + } + + webgpu_pipeline get_add_id_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_add_id_pipeline_key key = {}; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + + auto it = add_id_pipelines.find(key); + if (it != add_id_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "add_id"; + const char * shader_src = wgsl_add_id; + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(shader_src, defines); + auto pipeline_decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + pipeline_decisions->wg_size = context.max_wg_size; + pipeline_decisions->inplace = key.inplace; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = pipeline_decisions; + add_id_pipelines[key] = pipeline; + return pipeline; + } + + webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_concat_pipeline_key key = {}; + key.type = context.dst->type; + key.src_overlap = ggml_webgpu_tensor_overlap(context.src0, context.src1); + + auto it = concat_pipelines.find(key); + if (it != concat_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "concat"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported type for concat shader"); + } + + if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_concat, defines); + auto decisions = std::make_shared<ggml_webgpu_binary_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + decisions->src_overlap = key.src_overlap; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + concat_pipelines[key] = pipeline; + return concat_pipelines[key]; + } + + webgpu_pipeline get_repeat_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_repeat_pipeline_key key = {}; + key.type = context.dst->type; + + auto it = repeat_pipelines.find(key); + if (it != repeat_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "repeat"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + case GGML_TYPE_I16: + defines.push_back("TYPE_I16"); + variant += "_i16"; + break; + default: + GGML_ABORT("Unsupported type for repeat shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_repeat, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + repeat_pipelines[key] = pipeline; + return repeat_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool can_use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + context.supports_subgroup_matrix, context.sg_mat_k, context.sg_mat_n, context.src0, context.src2); + ggml_webgpu_flash_attn_decisions decisions = {}; + decisions.use_sg_matrix = can_use_subgroup_matrix; + decisions.q_tile = decisions.use_sg_matrix ? context.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + + ggml_webgpu_flash_attn_pipeline_key key = {}; + key.common = + ggml_webgpu_flash_attn_make_common_pipeline_key(context, decisions.use_sg_matrix ? context.sg_mat_k : 1u); + key.common.kv_direct = decisions.use_sg_matrix && key.common.kv_direct; + key.use_sg_matrix = decisions.use_sg_matrix; + + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + context.wg_mem_limit_bytes, decisions.q_tile, decisions.use_sg_matrix ? context.sg_mat_n : 1u, + key.common.head_dim_qk, key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); + GGML_ASSERT(max_kv_tile > 0); + + decisions.kv_tile = decisions.use_sg_matrix ? + std::min(max_kv_tile, context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES) : + std::min(GGML_WEBGPU_FLASH_ATTN_TILE_MAX_KV_TILE, max_kv_tile); + decisions.wg_size = + decisions.use_sg_matrix ? + std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE) : + std::min(context.max_wg_size, std::max(GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE, + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * context.max_subgroup_size)); + + if (key.common.kv_direct) { + decisions.kv_tile = std::min(decisions.kv_tile, GGML_WEBGPU_KV_SEQ_PAD); + while (GGML_WEBGPU_KV_SEQ_PAD % decisions.kv_tile != 0) { + decisions.kv_tile -= decisions.use_sg_matrix ? context.sg_mat_n : context.min_subgroup_size; + } + } + + auto it = flash_attn_pipelines.find(key); + if (it != flash_attn_pipelines.end()) { + return it->second; + } + + std::string variant = decisions.use_sg_matrix ? "flash_attn" : "flash_attn_tile"; + std::vector<std::string> defines = ggml_webgpu_flash_attn_common_defines(key.common, variant, decisions.q_tile, + decisions.kv_tile, decisions.wg_size); + const char * shader_src = nullptr; + if (!key.use_sg_matrix) { + shader_src = wgsl_flash_attn_tile; + defines.push_back("MIN_SUBGROUP_SIZE=" + std::to_string(context.min_subgroup_size) + "u"); + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); + variant += "_tile_sg" + std::to_string(context.min_subgroup_size) + "_" + + std::to_string(context.max_subgroup_size); + } else { + shader_src = wgsl_flash_attn; + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + } + auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_decisions>(decisions); + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant); + pipeline.context = pipeline_decisions; + flash_attn_pipelines[key] = pipeline; + return flash_attn_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_pipeline_key key = {}; + key.common = ggml_webgpu_flash_attn_make_common_pipeline_key(context, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + + auto it = flash_attn_vec_pipelines.find(key); + if (it != flash_attn_vec_pipelines.end()) { + return it->second; + } + + ggml_webgpu_flash_attn_vec_decisions decisions = {}; + decisions.kv_tile = + ggml_webgpu_flash_attn_get_vec_kv_tile(context.wg_mem_limit_bytes, key.common.head_dim_qk, + key.common.head_dim_v, key.common.has_mask, key.common.kv_direct); + decisions.wg_size = context.max_subgroup_size; + + std::string variant = "flash_attn_vec"; + std::vector<std::string> defines = + ggml_webgpu_flash_attn_common_defines(key.common, variant, 1u, decisions.kv_tile, decisions.wg_size); + if (key.common.has_mask) { + defines.push_back("BLK"); + variant.resize(variant.size() - (sizeof("_mask") - 1)); + variant += "_mask_blk"; + } + uint32_t vec_ne = 1u; + if (key.common.k_type == GGML_TYPE_F16 && key.common.v_type == GGML_TYPE_F16 && + key.common.head_dim_qk == key.common.head_dim_v) { + switch (key.common.head_dim_qk) { + case 64: + case 192: + case 576: + vec_ne = 2u; + break; + case 96: + vec_ne = 4u; + break; + default: + break; + } + } + defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u"); + + auto pipeline_decisions = std::make_shared<ggml_webgpu_flash_attn_vec_decisions>(decisions); + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_split, defines), variant); + pipeline.context = pipeline_decisions; + flash_attn_vec_pipelines[key] = pipeline; + return flash_attn_vec_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_shader_lib_context & context, uint32_t kv_tile) { + ggml_webgpu_flash_attn_blk_pipeline_key key = {}; + key.kv_tile = kv_tile; + auto it = flash_attn_blk_pipelines.find(key); + if (it != flash_attn_blk_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "flash_attn_vec_blk"; + + defines.push_back(std::string("KV_TILE=") + std::to_string(key.kv_tile)); + variant += std::string("_kvt") + std::to_string(key.kv_tile); + + uint32_t wg_size = 1; + while ((wg_size << 1) <= context.max_wg_size) { + wg_size <<= 1; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + variant += std::string("_wg") + std::to_string(wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_blk, defines), variant); + flash_attn_blk_pipelines[key] = pipeline; + return flash_attn_blk_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_vec_reduce_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_flash_attn_vec_reduce_pipeline_key key = {}; + key.head_dim_v = (uint32_t) context.src2->ne[0]; + key.dst_type = context.dst->type; + key.wg_size = context.max_wg_size; + auto it = flash_attn_vec_reduce_pipelines.find(key); + if (it != flash_attn_vec_reduce_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "flash_attn_vec_reduce"; + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + break; + default: + GGML_ABORT("Unsupported dst type for flash attention vec reduce shader"); + } + variant += std::string("_dst") + ggml_type_name(key.dst_type); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + variant += std::string("_wg") + std::to_string(context.max_wg_size); + + webgpu_pipeline pipeline = + ggml_webgpu_create_pipeline(device, preprocessor.preprocess(wgsl_flash_attn_vec_reduce, defines), variant); + flash_attn_vec_reduce_pipelines[key] = pipeline; + return flash_attn_vec_reduce_pipelines[key]; + } + + webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_cpy_pipeline_key key = {}; + key.src_type = context.src0->type; + key.dst_type = context.dst->type; + + auto it = cpy_pipelines.find(key); + if (it != cpy_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "cpy"; + + switch (key.src_type) { + case GGML_TYPE_F32: + defines.push_back("SRC_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src type for cpy shader"); + } + + switch (key.dst_type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("DST_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported dst type for cpy shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_cpy, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + cpy_pipelines[key] = pipeline; + return cpy_pipelines[key]; + } + + webgpu_pipeline get_glu_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_glu_pipeline_key key = {}; + key.glu_op = ggml_get_glu_op(context.dst); + key.type = context.dst->type; + key.split = (context.src1 != nullptr); + + auto it = glu_pipelines.find(key); + if (it != glu_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "glu"; + + switch (key.glu_op) { + case GGML_GLU_OP_REGLU: + defines.push_back("OP_REGLU"); + variant += "_reglu"; + break; + case GGML_GLU_OP_GEGLU: + defines.push_back("OP_GEGLU"); + variant += "_geglu"; + break; + case GGML_GLU_OP_SWIGLU: + defines.push_back("OP_SWIGLU"); + variant += "_swiglu"; + break; + case GGML_GLU_OP_SWIGLU_OAI: + defines.push_back("OP_SWIGLU_OAI"); + variant += "_swiglu_oai"; + break; + case GGML_GLU_OP_GEGLU_ERF: + defines.push_back("OP_GEGLU_ERF"); + variant += "_geglu_erf"; + break; + case GGML_GLU_OP_GEGLU_QUICK: + defines.push_back("OP_GEGLU_QUICK"); + variant += "_geglu_quick"; + break; + default: + GGML_ABORT("Unsupported GLU op"); + } + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for GLU shader"); + } + + if (key.split) { + variant += "_split"; + } else { + defines.push_back("NO_SPLIT"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_glu, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + glu_pipelines[key] = pipeline; + return glu_pipelines[key]; + } + + webgpu_pipeline get_rope_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_rope_pipeline_key key = {}; + key.type = context.dst->type; + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + key.has_ff = (context.src2 != nullptr); + + auto it = rope_pipelines.find(key); + if (it != rope_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "rope"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for ROPE shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + if (key.has_ff) { + defines.push_back("FF_FUNC"); + variant += "_ff"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_rope, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + rope_pipelines[key] = pipeline; + return rope_pipelines[key]; + } + + webgpu_pipeline get_soft_max_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_soft_max_pipeline_key key = {}; + key.mask_type = context.src1 ? context.src1->type : GGML_TYPE_F32; + key.has_mask = (context.src1 != nullptr); + key.has_sink = (context.src2 != nullptr); + key.inplace = ggml_webgpu_tensor_equal(context.src0, context.dst); + + auto it = soft_max_pipelines.find(key); + if (it != soft_max_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "soft_max"; + + if (key.has_mask) { + defines.push_back("HAS_MASK"); + switch (key.mask_type) { + case GGML_TYPE_F32: + defines.push_back("MASK_F32"); + variant += "_mask_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("MASK_F16"); + variant += "_mask_f16"; + break; + default: + GGML_ABORT("Unsupported type for SOFT_MAX shader"); + } + } + + if (key.has_sink) { + defines.push_back("HAS_SINK"); + variant += "_sink"; + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_soft_max, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + decisions->inplace = key.inplace; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + soft_max_pipelines[key] = pipeline; + return soft_max_pipelines[key]; + } + + webgpu_pipeline get_conv2d_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_conv2d_pipeline_key key = {}; + key.weight_type = context.src0->type; + key.input_type = context.src1->type; + key.output_type = context.dst->type; + + auto it = conv2d_pipelines.find(key); + if (it != conv2d_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "conv_2d"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for CONV_2D shader"); + } + }; + + push_type_defines("WEIGHT", key.weight_type); + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_conv2d, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + conv2d_pipelines[key] = pipeline; + return conv2d_pipelines[key]; + } + + webgpu_pipeline get_im2col_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_im2col_pipeline_key key = {}; + key.input_type = context.src1->type; + key.output_type = context.dst->type; + + auto it = im2col_pipelines.find(key); + if (it != im2col_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "im2col"; + + auto push_type_defines = [&](const char * prefix, ggml_type type) { + std::string s_prefix = prefix; + if (type == GGML_TYPE_F32) { + defines.push_back(s_prefix + "_F32"); + } else if (type == GGML_TYPE_F16) { + defines.push_back(s_prefix + "_F16"); + } else { + GGML_ABORT("Unsupported type for IM2COL shader"); + } + }; + + push_type_defines("INPUT", key.input_type); + push_type_defines("OUTPUT", key.output_type); + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_im2col, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + im2col_pipelines[key] = pipeline; + return im2col_pipelines[key]; + } + + webgpu_pipeline get_upscale_pipeline(const ggml_webgpu_shader_lib_context & context) { + const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(context.dst, 0); + const uint32_t base_mode = mode_flags & 0xFFu; + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS) != 0u; + + ggml_webgpu_upscale_pipeline_key key = {}; + key.input_type = context.src0->type; + key.output_type = context.dst->type; + key.base_mode = base_mode; + key.antialias = antialias; + + auto it = upscale_pipelines.find(key); + if (it != upscale_pipelines.end()) { + return it->second; + } + + std::vector<std::string> defines; + std::string variant = "upscale"; + + if (key.input_type == GGML_TYPE_F16) { + defines.push_back("SRC_F16"); + variant += "_src_f16"; + } else { + variant += "_src_f32"; + } + + if (key.output_type == GGML_TYPE_F16) { + defines.push_back("DST_F16"); + variant += "_dst_f16"; + } else { + variant += "_dst_f32"; + } + + switch (base_mode) { + case GGML_SCALE_MODE_NEAREST: + defines.push_back("NEAREST"); + variant += "_nearest"; + break; + case GGML_SCALE_MODE_BILINEAR: + defines.push_back("BILINEAR"); + variant += "_bilinear"; + break; + case GGML_SCALE_MODE_BICUBIC: + defines.push_back("BICUBIC"); + variant += "_bicubic"; + break; + default: + GGML_ABORT("Unsupported upscale mode"); + } + + if (antialias) { + defines.push_back("ANTIALIAS"); + variant += "_aa"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_upscale, defines); + auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + upscale_pipelines[key] = pipeline; + return upscale_pipelines[key]; + } + + private: + static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, + std::string shader_code, + std::string label) { + wgpu::ShaderSourceWGSL shader_source; + shader_source.code = shader_code.c_str(); + + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &shader_source; + + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + wgpu::ComputePipelineDescriptor pipeline_desc; + pipeline_desc.label = label.c_str(); + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout + return { device.CreateComputePipeline(&pipeline_desc), label }; + } +}; + #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 5b8f7f72d57..0b605fa86ba 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,8 +8,7 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" -#include "ggml-wgsl-shaders.hpp" -#include "pre_wgsl.hpp" +#include "ggml.h" #ifdef __EMSCRIPTEN__ # include <emscripten/emscripten.h> @@ -18,19 +17,37 @@ #include <webgpu/webgpu_cpp.h> #include <atomic> -#include <condition_variable> #include <cstdint> #include <cstring> -#include <iostream> -#include <map> +#ifdef GGML_WEBGPU_GPU_PROFILE +# include <iomanip> +#endif +#if defined(GGML_WEBGPU_DEBUG) || defined(GGML_WEBGPU_CPU_PROFILE) || defined(GGML_WEBGPU_GPU_PROFILE) +# include <iostream> +#endif +#include <memory> #include <mutex> #include <optional> #include <string> +#include <utility> #include <vector> #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1)) #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) +// Return a rectangular grid of workgroups with minimal over-provisioned workgroups. +// Assumes that the total number of workgroups does not exceed max_per_dim^2. +static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) { + wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim)); + wg_x = CEIL_DIV(total_wg, wg_y); +} + +static inline uint32_t ggml_webgpu_u32_from_f32(float value) { + uint32_t bits; + memcpy(&bits, &value, sizeof(bits)); + return bits; +} + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 512 @@ -47,7 +64,6 @@ double cpu_total_time_##id = \ std::chrono::duration<double, std::milli>(cpu_total_end_##id - cpu_total_start_##id).count(); \ (ctx)->cpu_time_ms[#id] += cpu_total_time_##id; - // fine-grained timing (not included in totals) # define WEBGPU_CPU_PROFILE_DETAIL_START(id) auto cpu_detail_start_##id = std::chrono::high_resolution_clock::now(); @@ -64,64 +80,29 @@ #endif // GGML_WEBGPU_CPU_PROFILE #ifdef GGML_WEBGPU_GPU_PROFILE -# define WEBGPU_NUM_TIMESTAMP_QUERY_BUFS 24 -# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES 16 // e.g. enough for two timestamps +# define WEBGPU_MAX_PROFILE_QUERY_COUNT 4096u +# define WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES (WEBGPU_MAX_PROFILE_QUERY_COUNT * sizeof(uint64_t)) #endif /* Constants */ -// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed. -#define WEBGPU_MAX_WG_SIZE 288 - -#define WEBGPU_MUL_MAT_WG_SIZE 256 -#define WEBGPU_NUM_PARAM_BUFS 32u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u -#define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 -// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool -#define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE -#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters -#define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 32 -#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 -#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 - -// For operations which process a row in parallel, this seems like a reasonable default -#define WEBGPU_ROW_SPLIT_WG_SIZE 64 - -// Matrix multiplication parameters - -// Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 8 -#define WEBGPU_MUL_MAT_TILE_N 8 -#define WEBGPU_MUL_MAT_WG_SIZE_M 8 -#define WEBGPU_MUL_MAT_WG_SIZE_N 8 -#define WEBGPU_MUL_MAT_TILE_K 32 - -// Subgroup matrix parameters -// The number of subgroups in the M dimension -#define WEBGPU_MUL_MAT_SUBGROUP_M 2 -// The number of subgroups in the N dimension -#define WEBGPU_MUL_MAT_SUBGROUP_N 2 -// The number of subgroup matrices each subgroup accumulates over -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 - -// Matrix-vector multiplication parameters -#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 -// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_TILE_K 256 +#define WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE 64u +#define WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN 10u +#define WEBGPU_RUNTIME_WAIT_TIMEOUT_MS 30000u +#define WEBGPU_RUNTIME_WAIT_TIMEOUT_NS (WEBGPU_RUNTIME_WAIT_TIMEOUT_MS * 1e6) +#define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters +#define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 +#define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 /* End Constants */ -// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. +// This is a "fake" base pointer, since WebGPU buffers do not have pointers to +// their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT -// Always returns the base offset of a tensor, regardless of views. -static uint64_t webgpu_tensor_offset(const ggml_tensor * tensor) { - if (tensor->view_src) { - return (uint8_t *) tensor->view_src->data - (uint8_t *) webgpu_ptr_base; - } - return (uint8_t *) tensor->data - (uint8_t *) webgpu_ptr_base; +static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { + const ggml_tensor * base_tensor = tensor->view_src ? tensor->view_src : tensor; + return (size_t) ((uintptr_t) base_tensor->data - (uintptr_t) webgpu_ptr_base) + tensor->view_offs; } /* Struct definitions */ @@ -133,303 +114,222 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, wgpu::BufferUsage usage, const char * label); -struct webgpu_pool_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; -}; - -// The futures to wait on for a single queue submission -struct webgpu_submission_futures { - std::vector<wgpu::FutureWaitInfo> futures; -}; - -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_buf_pool { - std::vector<webgpu_pool_bufs> free; - - std::mutex mutex; - - std::condition_variable cv; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); - free.push_back({ host_buf, dev_buf }); - } - } - - webgpu_pool_bufs alloc_bufs() { - std::unique_lock<std::mutex> lock(mutex); - cv.wait(lock, [this] { return !free.empty(); }); - webgpu_pool_bufs bufs = free.back(); - free.pop_back(); - return bufs; - } - - void free_bufs(std::vector<webgpu_pool_bufs> bufs) { - std::lock_guard<std::mutex> lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } - - void cleanup() { - std::lock_guard<std::mutex> lock(mutex); - for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); - } - free.clear(); +// Slot-based parameter arena for compute graph encoding. Each encoded kernel +// gets a unique uniform-buffer slice within the current batch, and the slot +// cursor is reset immediately after that batch is submitted. +struct webgpu_param_arena { + wgpu::Buffer buffer; + size_t slot_stride = 0; + size_t slot_size = 0; + uint32_t slot_count = 0; + uint32_t next_slot = 0; + + void init(wgpu::Device device, size_t slot_size, uint32_t slot_count, size_t alignment) { + this->slot_stride = ROUNDUP_POW2(slot_size, alignment); + this->slot_size = slot_size; + this->slot_count = slot_count; + this->next_slot = 0; + + ggml_webgpu_create_buffer(device, buffer, this->slot_stride * slot_count, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, "ggml_webgpu_param_arena"); } -}; - -#ifdef GGML_WEBGPU_GPU_PROFILE -struct webgpu_gpu_profile_bufs { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - wgpu::QuerySet query_set; -}; - -// Holds a pool of parameter buffers for WebGPU operations -struct webgpu_gpu_profile_buf_pool { - std::vector<webgpu_gpu_profile_bufs> free; - - std::mutex mutex; - - std::condition_variable cv; - - void init(wgpu::Device device, - int num_bufs, - size_t buf_size, - wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { - for (int i = 0; i < num_bufs; i++) { - wgpu::Buffer host_buf; - wgpu::Buffer dev_buf; - ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_profile_buf"); - ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_profile_buf"); - // Create a query set for 2 timestamps - wgpu::QuerySetDescriptor ts_query_set_desc = {}; - - ts_query_set_desc.type = wgpu::QueryType::Timestamp; - ts_query_set_desc.count = 2; - wgpu::QuerySet ts_query_set = device.CreateQuerySet(&ts_query_set_desc); - free.push_back({ host_buf, dev_buf, ts_query_set }); + size_t alloc_slot(size_t size) { + GGML_ASSERT(size <= slot_size); + if (next_slot >= slot_count) { + GGML_ABORT("ggml_webgpu: parameter arena exhausted while encoding a batch"); } - } - webgpu_gpu_profile_bufs alloc_bufs() { - std::unique_lock<std::mutex> lock(mutex); - cv.wait(lock, [this] { return !free.empty(); }); - webgpu_gpu_profile_bufs bufs = free.back(); - free.pop_back(); - return bufs; + return slot_stride * next_slot++; } - void free_bufs(std::vector<webgpu_gpu_profile_bufs> bufs) { - std::lock_guard<std::mutex> lock(mutex); - free.insert(free.end(), bufs.begin(), bufs.end()); - cv.notify_all(); - } + void reset() { next_slot = 0; } void cleanup() { - std::lock_guard<std::mutex> lock(mutex); - for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); - bufs.query_set.Destroy(); + if (buffer) { + buffer.Destroy(); + buffer = nullptr; } - free.clear(); } -}; -#endif -struct webgpu_pipeline { - wgpu::ComputePipeline pipeline; - std::string name; - void * context = nullptr; + ~webgpu_param_arena() { this->cleanup(); } }; -struct webgpu_command { - wgpu::CommandBuffer commands; - webgpu_pool_bufs params_bufs; - std::optional<webgpu_pool_bufs> set_rows_error_bufs; +struct webgpu_encoded_op { + uint32_t num_kernels = 0; #ifdef GGML_WEBGPU_GPU_PROFILE - webgpu_gpu_profile_bufs timestamp_query_bufs; - std::string pipeline_name; + std::vector<std::string> pipeline_names; #endif }; -struct flash_attn_pipeline_key { - int q_type; - int kv_type; - int dst_type; - uint32_t head_dim_qk; - uint32_t head_dim_v; - bool kv_direct; - bool has_mask; - bool has_sinks; - bool uses_logit_softcap; - - bool operator==(const flash_attn_pipeline_key & other) const { - return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type && - head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct && - has_mask == other.has_mask && has_sinks == other.has_sinks && - uses_logit_softcap == other.uses_logit_softcap; - } +struct webgpu_dispatch_desc { + webgpu_pipeline pipeline; + std::vector<uint32_t> params; + std::vector<wgpu::BindGroupEntry> bind_group_entries; + std::pair<uint32_t, uint32_t> workgroups = { 1, 1 }; }; -// Same hash combine function as in boost -template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { - seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); -} - -struct flash_attn_pipeline_key_hash { - size_t operator()(const flash_attn_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.q_type); - ggml_webgpu_hash_combine(seed, key.kv_type); - ggml_webgpu_hash_combine(seed, key.dst_type); - ggml_webgpu_hash_combine(seed, key.head_dim_qk); - ggml_webgpu_hash_combine(seed, key.head_dim_v); - ggml_webgpu_hash_combine(seed, key.kv_direct); - ggml_webgpu_hash_combine(seed, key.has_mask); - ggml_webgpu_hash_combine(seed, key.has_sinks); - ggml_webgpu_hash_combine(seed, key.uses_logit_softcap); - return seed; - } +struct webgpu_capabilities { + wgpu::Limits limits; + bool supports_subgroups = false; + bool supports_subgroup_matrix = false; + bool supports_dot_product = false; + + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + + uint32_t subgroup_size = 0; + uint32_t min_subgroup_size = 0; + uint32_t max_subgroup_size = 0; + size_t memset_bytes_per_thread; }; -// All the base objects needed to run operations on a WebGPU device -struct webgpu_context_struct { +// Stores global webgpu members +struct webgpu_global_context_struct { wgpu::Instance instance; wgpu::Adapter adapter; wgpu::Device device; wgpu::Queue queue; - wgpu::Limits limits; - - uint32_t max_subgroup_size; - - bool supports_subgroup_matrix = false; - uint32_t sg_mat_m; - uint32_t sg_mat_n; - uint32_t sg_mat_k; + uint32_t command_submit_batch_size = WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE; + uint32_t max_inflight_batches = UINT32_MAX; + webgpu_capabilities capabilities; + // Shared buffer to move data from device to host + wgpu::Buffer get_tensor_staging_buf; + // Global mutex for get_tensor std::recursive_mutex mutex; - std::atomic_uint inflight_threads = 0; - - webgpu_buf_pool param_buf_pool; - webgpu_buf_pool set_rows_error_buf_pool; - - pre_wgsl::Preprocessor p; - - std::map<int, webgpu_pipeline> memset_pipelines; // variant or type index - - std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines; // src0_type, src1_type, vectorized - std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> - mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines; + wgpu::Buffer memset_params_buf; + webgpu_pipeline memset_pipeline; - std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines; // dst_type, vectorized - std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized + std::string vendor; - std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type - std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace - std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace - std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace - std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace - - std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace - std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace - std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> glu_pipelines; // glu_op, type, split - std::map<int, webgpu_pipeline> scale_pipelines; // inplace - std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> soft_max_pipelines; // mask_type, has_sink, inplace - std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> unary_pipelines; // unary_op, type, inplace - - size_t memset_bytes_per_thread; - - // Staging buffer for reading data from the GPU - wgpu::Buffer get_tensor_staging_buf; + // TODO: We should rework the CPU profiling time handling to make it more useful. ref: https://github.com/ggml-org/llama.cpp/pull/22050 +#ifdef GGML_WEBGPU_CPU_PROFILE + // Profiling: labeled CPU time in ms (total) + std::unordered_map<std::string, double> cpu_time_ms; + // Profiling: detailed CPU time in ms + std::unordered_map<std::string, double> cpu_detail_ms; +#endif #ifdef GGML_WEBGPU_DEBUG wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; #endif -#ifdef GGML_WEBGPU_CPU_PROFILE - // Profiling: labeled CPU time in ms (total) - std::unordered_map<std::string, double> cpu_time_ms; - // Profiling: detailed CPU time in ms - std::unordered_map<std::string, double> cpu_detail_ms; + ~webgpu_global_context_struct() { + if (this->get_tensor_staging_buf) { + this->get_tensor_staging_buf.Destroy(); + this->get_tensor_staging_buf = nullptr; + } + if (this->memset_params_buf) { + this->memset_params_buf.Destroy(); + this->memset_params_buf = nullptr; + } +#ifdef GGML_WEBGPU_DEBUG + if (this->debug_host_buf) { + this->debug_host_buf.Destroy(); + this->debug_host_buf = nullptr; + } + if (this->debug_dev_buf) { + this->debug_dev_buf.Destroy(); + this->debug_dev_buf = nullptr; + } #endif + } +}; + +typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context; + +// All the base objects needed to run operations on a WebGPU device +struct webgpu_context_struct { + // Points to global instances owned by ggml_backend_webgpu_reg_context + webgpu_global_context global_ctx; + + std::unique_ptr<ggml_webgpu_shader_lib> shader_lib; + + webgpu_param_arena param_arena; + wgpu::Buffer set_rows_dev_error_buf; + wgpu::Buffer set_rows_host_error_buf; + wgpu::CommandEncoder active_command_encoder; + wgpu::ComputePassEncoder active_compute_pass; + bool batch_compute_passes = true; + + size_t memset_bytes_per_thread; #ifdef GGML_WEBGPU_GPU_PROFILE // Profiling: per-shader GPU time in ms std::unordered_map<std::string, double> shader_gpu_time_ms; - // Profiling: pool of timestamp query buffers (one per operation) - webgpu_gpu_profile_buf_pool timestamp_query_buf_pool; + wgpu::Buffer profile_timestamp_dev_buf; + wgpu::Buffer profile_timestamp_host_buf; + wgpu::QuerySet profile_timestamp_query_set; + uint32_t profile_timestamp_query_count = 0; +#endif + + ~webgpu_context_struct() { +#ifdef GGML_WEBGPU_GPU_PROFILE + if (this->profile_timestamp_host_buf) { + this->profile_timestamp_host_buf.Destroy(); + this->profile_timestamp_host_buf = nullptr; + } + if (this->profile_timestamp_dev_buf) { + this->profile_timestamp_dev_buf.Destroy(); + this->profile_timestamp_dev_buf = nullptr; + } + if (this->profile_timestamp_query_set) { + this->profile_timestamp_query_set.Destroy(); + this->profile_timestamp_query_set = nullptr; + } #endif + if (this->set_rows_host_error_buf) { + this->set_rows_host_error_buf.Destroy(); + this->set_rows_host_error_buf = nullptr; + } + if (this->set_rows_dev_error_buf) { + this->set_rows_dev_error_buf.Destroy(); + this->set_rows_dev_error_buf = nullptr; + } + } }; typedef std::shared_ptr<webgpu_context_struct> webgpu_context; +// Metadata required for the ggml backend registration/discovery interface struct ggml_backend_webgpu_reg_context { - webgpu_context webgpu_ctx; - size_t device_count; - const char * name; + // Since the Instance is a global entrypoint into the WebGPU API, it lives here + webgpu_global_context webgpu_global_ctx; + size_t device_count; + const char * name; }; +// Per-device struct for the global logical device interface struct ggml_backend_webgpu_device_context { - webgpu_context webgpu_ctx; - std::string device_name; - std::string device_desc; + webgpu_global_context webgpu_global_ctx; + std::string device_name; + std::string device_desc; }; +// Per-thread data required to actually run WebGPU operations in a backend instance struct ggml_backend_webgpu_context { webgpu_context webgpu_ctx; std::string name; }; +// Per-thread data related to buffers struct ggml_backend_webgpu_buffer_context { - webgpu_context webgpu_ctx; - wgpu::Buffer buffer; - std::string label; + wgpu::Buffer buffer; + std::string label; + webgpu_global_context global_ctx; - ggml_backend_webgpu_buffer_context(webgpu_context ctx, wgpu::Buffer buf, std::string lbl) : - webgpu_ctx(std::move(ctx)), + ggml_backend_webgpu_buffer_context(wgpu::Buffer buf, std::string lbl, webgpu_global_context global_ctx_) : buffer(std::move(buf)), - label(std::move(lbl)) {} + label(std::move(lbl)), + global_ctx(std::move(global_ctx_)) {} }; /* WebGPU object initializations */ -// Process a WGSL shader string, replacing tokens of the form {{KEY}} with -// the corresponding values provided in `repls`. -static std::string ggml_webgpu_process_shader_repls(const char * src, - const std::map<std::string, std::string> & repls) { - if (!src) { - return std::string(); - } - std::string s = src; - for (const auto & kv : repls) { - std::string token = "{{" + kv.first + "}}"; - size_t pos = 0; - while ((pos = s.find(token, pos)) != std::string::npos) { - s.replace(pos, token.length(), kv.second); - pos += kv.second.length(); - } - } - return s; -} - static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, const char * shader_code, const char * label, @@ -469,63 +369,158 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, buffer = device.CreateBuffer(&buffer_desc); } +static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { + ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + return ctx->buffer; +} + +static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & (ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} + +static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { + size_t offset = ggml_webgpu_tensor_offset(t); + return offset & ~(ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment - 1); +} + +static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { + return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); +} + +struct ggml_webgpu_merged_binding_range { + size_t offset; + size_t size; +}; + +static ggml_webgpu_merged_binding_range ggml_webgpu_tensor_merged_binding_range( + webgpu_context & ctx, + std::initializer_list<ggml_tensor *> tensors) { + size_t merged_offset = SIZE_MAX; + size_t merged_end = 0; + + for (ggml_tensor * tensor : tensors) { + const size_t bind_offset = ggml_webgpu_tensor_align_offset(ctx, tensor); + const size_t bind_end = bind_offset + ggml_webgpu_tensor_binding_size(ctx, tensor); + + merged_offset = std::min(merged_offset, bind_offset); + merged_end = std::max(merged_end, bind_end); + } + + return { merged_offset, merged_end - merged_offset }; +} + +static uint32_t ggml_webgpu_tensor_merged_element_offset(const ggml_tensor * tensor, + const ggml_webgpu_merged_binding_range & merged_range) { + return (uint32_t) ((ggml_webgpu_tensor_offset(tensor) - merged_range.offset) / ggml_type_size(tensor->type)); +} + +static wgpu::BindGroupEntry ggml_webgpu_make_bind_group_entry(uint32_t binding, + wgpu::Buffer buffer, + uint64_t offset, + uint64_t size) { + wgpu::BindGroupEntry entry = {}; + entry.binding = binding; + entry.buffer = std::move(buffer); + entry.offset = offset; + entry.size = size; + return entry; +} + +static wgpu::BindGroupEntry ggml_webgpu_make_tensor_bind_group_entry(webgpu_context & ctx, + uint32_t binding, + ggml_tensor * tensor) { + return ggml_webgpu_make_bind_group_entry(binding, ggml_webgpu_tensor_buf(tensor), + ggml_webgpu_tensor_align_offset(ctx, tensor), + ggml_webgpu_tensor_binding_size(ctx, tensor)); +} + /** End WebGPU object initializations */ /** WebGPU Actions */ -// Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_context & ctx, - std::vector<webgpu_submission_futures> & futures, - bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, - // inflight_max may be 0, meaning that we must wait on all futures. - uint64_t timeout_ms = block ? UINT64_MAX : 0; - uint32_t inflight_threads = ctx->inflight_threads; - uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); - while (futures.size() >= inflight_max && futures.size() > 0) { - ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); - futures.erase(futures.begin()); - } - size_t i = 0; - while (i < futures.size()) { - auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms); - switch (waitStatus) { - case wgpu::WaitStatus::Success: - futures.erase(futures.begin() + i); - break; - case wgpu::WaitStatus::TimedOut: - i++; - break; - case wgpu::WaitStatus::Error: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); - break; - default: - GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); - break; - } +template <typename T> +static void ggml_backend_webgpu_check_wait_status(wgpu::WaitStatus wait_status, + T callback_status, + T success_status, + const char * wait_name, + const char * failure_name, + const char * callback_message) { + if (wait_status == wgpu::WaitStatus::TimedOut) { + GGML_ABORT("ggml_webgpu: %s timed out after %u ms\n", wait_name, WEBGPU_RUNTIME_WAIT_TIMEOUT_MS); + } + if (wait_status == wgpu::WaitStatus::Error) { + GGML_ABORT("ggml_webgpu: %s failed\n", wait_name); } + if (callback_status != success_status) { + GGML_ABORT("ggml_webgpu: %s failed with status %d: %s\n", failure_name, static_cast<int>(callback_status), + callback_message); + } +} + +// TODO: these next two functions may want tuning across different platforms and workloads, +static uint32_t ggml_backend_webgpu_get_max_inflight_batches() { + return UINT32_MAX; +} + +static uint32_t ggml_backend_webgpu_get_command_submit_batch_size() { + return WEBGPU_DEFAULT_COMMAND_SUBMIT_BATCH_SIZE; +} + +static void ggml_backend_webgpu_wait_queue(webgpu_global_context & ctx) { + wgpu::QueueWorkDoneStatus callback_status = wgpu::QueueWorkDoneStatus::Error; + std::string callback_message; + + const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( + ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, + [&callback_status, &callback_message](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { + callback_status = status; + callback_message = std::string(message); + }), + WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + + ggml_backend_webgpu_check_wait_status(wait_status, callback_status, wgpu::QueueWorkDoneStatus::Success, + "Queue wait", "Queue work", callback_message.c_str()); +} + +static void ggml_backend_webgpu_map_buffer(webgpu_global_context & ctx, + wgpu::Buffer & buffer, + wgpu::MapMode mode, + size_t offset, + size_t size) { + wgpu::MapAsyncStatus callback_status = wgpu::MapAsyncStatus::Error; + std::string callback_message; + + const wgpu::WaitStatus wait_status = ctx->instance.WaitAny( + buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, + [&callback_status, &callback_message](wgpu::MapAsyncStatus status, wgpu::StringView message) { + callback_status = status; + callback_message = std::string(message); + }), + WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + + ggml_backend_webgpu_check_wait_status(wait_status, callback_status, wgpu::MapAsyncStatus::Success, + "Buffer map wait", "Buffer map", callback_message.c_str()); } -static void ggml_backend_webgpu_map_buffer(webgpu_context & ctx, - wgpu::Buffer & buffer, - wgpu::MapMode mode, - size_t offset, - size_t size) { - ctx->instance.WaitAny(buffer.MapAsync(mode, offset, size, wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map buffer: %s\n", - message.data); - } - }), - UINT64_MAX); +static void ggml_backend_webgpu_submit_commands(webgpu_context & ctx, + const wgpu::CommandBuffer commands, + uint32_t & num_inflight_batches) { + if (num_inflight_batches >= ctx->global_ctx->max_inflight_batches) { + ggml_backend_webgpu_wait_queue(ctx->global_ctx); + num_inflight_batches = 0; + } + + ctx->global_ctx->queue.Submit(1, &commands); + num_inflight_batches++; } #ifdef GGML_WEBGPU_DEBUG // This function adds debugging information to shaders, as WebGPU does not support printing directly. // To use, add a bind group entry to the setup for the shader you are debugging, add the buffer and // debug statements in the shader, and then call this function after encoding the commands and submitting them. -static void ggml_backend_webgpu_debug(webgpu_context & ctx) { +static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize()); wgpu::CommandBuffer commands = encoder.Finish(); @@ -537,172 +532,127 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) { } #endif -static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_context ctx, std::vector<webgpu_command> commands) { - std::vector<wgpu::CommandBuffer> command_buffers; - std::vector<webgpu_pool_bufs> params_bufs; - std::vector<webgpu_pool_bufs> set_rows_error_bufs; -#ifdef GGML_WEBGPU_GPU_PROFILE - std::vector<std::pair<std::string, webgpu_gpu_profile_bufs>> pipeline_name_and_ts_bufs; -#endif - - for (const auto & command : commands) { - command_buffers.push_back(command.commands); - params_bufs.push_back(command.params_bufs); - if (command.set_rows_error_bufs) { - set_rows_error_bufs.push_back(command.set_rows_error_bufs.value()); - } +static webgpu_encoded_op ggml_backend_webgpu_build_multi(webgpu_context & ctx, + const std::vector<webgpu_dispatch_desc> & dispatches) { + webgpu_encoded_op result = {}; + std::vector<wgpu::BindGroup> bind_groups; + std::vector<size_t> param_offsets; + result.num_kernels = dispatches.size(); + + for (size_t i = 0; i < dispatches.size(); i++) { + const webgpu_dispatch_desc & dispatch = dispatches[i]; + const size_t param_size = dispatch.params.size() * sizeof(uint32_t); + const size_t param_offset = ctx->param_arena.alloc_slot(param_size); + + std::vector<wgpu::BindGroupEntry> entries = dispatch.bind_group_entries; + uint32_t params_binding_num = entries.size(); + entries.push_back(ggml_webgpu_make_bind_group_entry(params_binding_num, ctx->param_arena.buffer, param_offset, + ctx->param_arena.slot_size)); + + wgpu::BindGroupDescriptor bind_group_desc; + bind_group_desc.layout = dispatch.pipeline.pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = entries.size(); + bind_group_desc.entries = entries.data(); + bind_group_desc.label = dispatch.pipeline.name.c_str(); + bind_groups.push_back(ctx->global_ctx->device.CreateBindGroup(&bind_group_desc)); + param_offsets.push_back(param_offset); } - ctx->queue.Submit(command_buffers.size(), command_buffers.data()); - - std::vector<wgpu::FutureWaitInfo> futures; - wgpu::Future p_f = ctx->queue.OnSubmittedWorkDone( - wgpu::CallbackMode::AllowSpontaneous, - [ctx, params_bufs](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", std::string(message).c_str()); - } - // Free the staged buffers - ctx->param_buf_pool.free_bufs({ params_bufs }); - }); - futures.push_back({ p_f }); - - for (const auto & bufs : set_rows_error_bufs) { - wgpu::Future f = bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, bufs](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map error buffer: %s\n", std::string(message).c_str()); - } else { - const uint32_t * error_data = (const uint32_t *) bufs.host_buf.GetConstMappedRange(); - if (*error_data) { - GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); - } - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->set_rows_error_buf_pool.free_bufs({ bufs }); - } - }); - futures.push_back({ f }); + for (size_t i = 0; i < param_offsets.size(); i++) { + ctx->global_ctx->queue.WriteBuffer(ctx->param_arena.buffer, param_offsets[i], dispatches[i].params.data(), + dispatches[i].params.size() * sizeof(uint32_t)); } #ifdef GGML_WEBGPU_GPU_PROFILE - for (const auto & command : commands) { - auto label = command.pipeline_name; - auto ts_bufs = command.timestamp_query_bufs; - - wgpu::Future f = ts_bufs.host_buf.MapAsync( - wgpu::MapMode::Read, 0, ts_bufs.host_buf.GetSize(), wgpu::CallbackMode::AllowSpontaneous, - [ctx, ts_bufs, label](wgpu::MapAsyncStatus status, wgpu::StringView message) { - if (status != wgpu::MapAsyncStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to map timestamp buffer: %s\n", std::string(message).c_str()); - } else { - const uint64_t * ts_data = (const uint64_t *) ts_bufs.host_buf.GetConstMappedRange(); - // WebGPU timestamps are in ns; convert to ms - double elapsed_ms = double(ts_data[1] - ts_data[0]) * 1e-6; - ctx->shader_gpu_time_ms[label] += elapsed_ms; - // We can't unmap in here due to WebGPU reentrancy limitations. - ctx->timestamp_query_buf_pool.free_bufs({ ts_bufs }); - } - }); - futures.push_back({ f }); + for (size_t i = 0; i < dispatches.size(); i++) { + GGML_ASSERT(ctx->profile_timestamp_query_count + 2 <= WEBGPU_MAX_PROFILE_QUERY_COUNT); + const uint32_t query_begin = ctx->profile_timestamp_query_count++; + const uint32_t query_end = ctx->profile_timestamp_query_count++; + + wgpu::PassTimestampWrites ts_writes = {}; + ts_writes.querySet = ctx->profile_timestamp_query_set; + ts_writes.beginningOfPassWriteIndex = query_begin; + ts_writes.endOfPassWriteIndex = query_end; + wgpu::ComputePassDescriptor pass_desc = {}; + pass_desc.timestampWrites = &ts_writes; + + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(&pass_desc); + + pass.SetPipeline(dispatches[i].pipeline.pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + pass.End(); + result.pipeline_names.push_back(dispatches[i].pipeline.name); + } +#else + for (size_t i = 0; i < dispatches.size(); i++) { + if (ctx->batch_compute_passes) { + ctx->active_compute_pass.SetPipeline(dispatches[i].pipeline.pipeline); + ctx->active_compute_pass.SetBindGroup(0, bind_groups[i]); + ctx->active_compute_pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, + 1); + } else { + wgpu::ComputePassEncoder pass = ctx->active_command_encoder.BeginComputePass(); + pass.SetPipeline(dispatches[i].pipeline.pipeline); + pass.SetBindGroup(0, bind_groups[i]); + pass.DispatchWorkgroups(dispatches[i].workgroups.first, dispatches[i].workgroups.second, 1); + pass.End(); + } } #endif - return { futures }; -} - -static webgpu_command ggml_backend_webgpu_build(webgpu_context & ctx, - webgpu_pipeline & pipeline, - std::vector<uint32_t> params, - std::vector<wgpu::BindGroupEntry> bind_group_entries, - uint32_t wg_x, - uint32_t wg_y = 1, - std::optional<webgpu_pool_bufs> set_rows_error_bufs = std::nullopt) { - webgpu_pool_bufs params_bufs = ctx->param_buf_pool.alloc_bufs(); - - ggml_backend_webgpu_map_buffer(ctx, params_bufs.host_buf, wgpu::MapMode::Write, 0, params_bufs.host_buf.GetSize()); - uint32_t * _params = (uint32_t *) params_bufs.host_buf.GetMappedRange(); - for (size_t i = 0; i < params.size(); i++) { - _params[i] = params[i]; - }; - params_bufs.host_buf.Unmap(); + return result; +} + +static webgpu_encoded_op ggml_backend_webgpu_build(webgpu_context & ctx, + webgpu_pipeline & pipeline, + std::vector<uint32_t> params, + std::vector<wgpu::BindGroupEntry> bind_group_entries, + uint32_t wg_x, + uint32_t wg_y = 1) { + return ggml_backend_webgpu_build_multi( + ctx, { + { pipeline, std::move(params), std::move(bind_group_entries), { wg_x, wg_y } }, + }); +} + +static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, + wgpu::Buffer & buf, + uint32_t value, + size_t offset, + size_t size) { + std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value }; + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry(0, buf, 0, buf.GetSize()) }; + size_t bytes_per_wg = + ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * ctx->capabilities.memset_bytes_per_thread; + uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); + + ctx->queue.WriteBuffer(ctx->memset_params_buf, 0, params.data(), params.size() * sizeof(uint32_t)); - uint32_t params_bufs_binding_num = bind_group_entries.size(); - bind_group_entries.push_back({ .binding = params_bufs_binding_num, - .buffer = params_bufs.dev_buf, - .offset = 0, - .size = params_bufs.dev_buf.GetSize() }); + wgpu::BindGroupEntry params_entry = {}; + params_entry.binding = 1; + params_entry.buffer = ctx->memset_params_buf; + params_entry.offset = 0; + params_entry.size = WEBGPU_PARAMS_BUF_SIZE_BYTES; + entries.push_back(params_entry); wgpu::BindGroupDescriptor bind_group_desc; - bind_group_desc.layout = pipeline.pipeline.GetBindGroupLayout(0); - bind_group_desc.entryCount = bind_group_entries.size(); - bind_group_desc.entries = bind_group_entries.data(); - bind_group_desc.label = pipeline.name.c_str(); + bind_group_desc.layout = ctx->memset_pipeline.pipeline.GetBindGroupLayout(0); + bind_group_desc.entryCount = entries.size(); + bind_group_desc.entries = entries.data(); + bind_group_desc.label = ctx->memset_pipeline.name.c_str(); wgpu::BindGroup bind_group = ctx->device.CreateBindGroup(&bind_group_desc); - wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); - -#ifdef GGML_WEBGPU_GPU_PROFILE - // --- Profiling: GPU timestamp queries --- - // Allocate a timestamp query buffer (2 timestamps: start/end) - webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs(); - if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - ts_bufs.host_buf.Unmap(); - } - - wgpu::PassTimestampWrites ts_writes = { .querySet = ts_bufs.query_set, - .beginningOfPassWriteIndex = 0, - .endOfPassWriteIndex = 1 }; - wgpu::ComputePassDescriptor pass_desc = { .timestampWrites = &ts_writes }; - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(&pass_desc); -#else - wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); -#endif - pass.SetPipeline(pipeline.pipeline); + wgpu::CommandEncoder encoder = ctx->device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(ctx->memset_pipeline.pipeline); pass.SetBindGroup(0, bind_group); - pass.DispatchWorkgroups(wg_x, wg_y, 1); + pass.DispatchWorkgroups(wg_x, 1, 1); pass.End(); -#ifdef GGML_WEBGPU_GPU_PROFILE - // Resolve the query set into the device buffer - encoder.ResolveQuerySet(ts_bufs.query_set, 0, 2, ts_bufs.dev_buf, 0); - encoder.CopyBufferToBuffer(ts_bufs.dev_buf, 0, ts_bufs.host_buf, 0, ts_bufs.host_buf.GetSize()); -#endif - - // If there are SET_ROWS operations in this submission, copy their error buffers to the host. - if (set_rows_error_bufs) { - encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, - set_rows_error_bufs->host_buf.GetSize()); - } - - wgpu::CommandBuffer commands = encoder.Finish(); - webgpu_command result = {}; - result.commands = commands; - result.params_bufs = params_bufs; - result.set_rows_error_bufs = set_rows_error_bufs; -#ifdef GGML_WEBGPU_GPU_PROFILE - result.timestamp_query_bufs = ts_bufs; - result.pipeline_name = pipeline.name; -#endif - return result; -} - -static void ggml_backend_webgpu_buffer_memset(webgpu_context & ctx, - wgpu::Buffer & buf, - uint32_t value, - size_t offset, - size_t size) { - std::vector<uint32_t> params = { (uint32_t) offset, (uint32_t) size, value }; - std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, .buffer = buf, .offset = 0, .size = buf.GetSize() } - }; - size_t bytes_per_wg = WEBGPU_MAX_WG_SIZE * ctx->memset_bytes_per_thread; - uint32_t wg_x = CEIL_DIV(size + 3, bytes_per_wg); - - webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command }) }; - ggml_backend_webgpu_wait(ctx, futures); + wgpu::CommandBuffer command = encoder.Finish(); + std::vector<wgpu::CommandBuffer> commands = { command }; + ctx->queue.Submit(commands.size(), commands.data()); } /** End WebGPU Actions */ @@ -714,7 +664,6 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) { return ctx->name.c_str(); } -// TODO: implement proper cleanup static void ggml_backend_webgpu_free(ggml_backend_t backend) { ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); @@ -722,19 +671,19 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { #ifdef GGML_WEBGPU_CPU_PROFILE std::cout << "\n[ggml_webgpu cpu profiling summary]\n"; double total_cpu = 0.0; - for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) { total_cpu += kv.second; } std::cout << "ggml_webgpu: total cpu time: " << total_cpu << " ms\n"; std::cout << "ggml_webgpu: cpu breakdown:\n"; - for (const auto & kv : ctx->webgpu_ctx->cpu_time_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_time_ms) { double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; } - if (ctx->webgpu_ctx->cpu_detail_ms.size() > 0) { + if (ctx->webgpu_ctx->global_ctx->cpu_detail_ms.size() > 0) { std::cout << "ggml_webgpu: cpu detailed breakdown:\n"; } - for (const auto & kv : ctx->webgpu_ctx->cpu_detail_ms) { + for (const auto & kv : ctx->webgpu_ctx->global_ctx->cpu_detail_ms) { double pct = (total_cpu > 0.0) ? (kv.second / total_cpu * 100.0) : 0.0; std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; } @@ -750,7 +699,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { std::cout << "\nggml_webgpu: gpu breakdown:\n"; for (const auto & kv : ctx->webgpu_ctx->shader_gpu_time_ms) { double pct = (total_gpu > 0.0) ? (kv.second / total_gpu * 100.0) : 0.0; - std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << pct << "%)\n"; + std::cout << "ggml_webgpu: " << kv.first << ": " << kv.second << " ms (" << std::fixed << std::setprecision(2) + << pct << "%)\n"; } #endif @@ -758,41 +708,20 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n"; #endif -#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE) - GGML_UNUSED(ctx); -#endif -} - -static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { - return webgpu_tensor_offset(tensor) + tensor->view_offs; -} - -static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) { - ggml_backend_webgpu_buffer_context * ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; - return ctx->buffer; -} - -static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1); + delete ctx; + delete backend; } -static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) { - size_t offset = ggml_webgpu_tensor_offset(t); - return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1); -} +static webgpu_encoded_op ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; -static size_t ggml_webgpu_tensor_binding_size(webgpu_context & ctx, ggml_tensor * t) { - return ROUNDUP_POW2(ggml_nbytes(t) + ggml_webgpu_tensor_misalignment(ctx, t), WEBGPU_STORAGE_BUF_BINDING_MULT); -} + webgpu_pipeline pipeline = ctx->shader_lib->get_cpy_pipeline(shader_lib_ctx); -// Used to determine if two tensors are the same for in-place operations -static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) { - return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) && - (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b)); -} + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); -static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { uint32_t ne = (uint32_t) ggml_nelements(dst); std::vector<uint32_t> params = { @@ -809,235 +738,1071 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g }; std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), }; - uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, ctx->cpy_pipelines[src->type][dst->type], params, entries, wg_x); + uint32_t wg_x; + uint32_t wg_y; + uint32_t total_wg = CEIL_DIV(ne, decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { - // For set rows specifically, we need to check if src and idx are empty tensors. - if (ggml_is_empty(src) || ggml_is_empty(idx)) { - return std::nullopt; - } +static webgpu_encoded_op ggml_webgpu_set(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; - webgpu_pool_bufs error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); - if (error_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { - error_bufs.host_buf.Unmap(); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); + + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + const bool inplace = decisions->inplace; + + const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst); + const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type); std::vector<uint32_t> params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)), - (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), - (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), - (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Shape of src - (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3], - // Shape of idx - (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) - }; + ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size), - std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - { .binding = 3, .buffer = error_bufs.dev_buf, .offset = 0, .size = error_bufs.dev_buf.GetSize() } + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + 1u, + (uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size), + (uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size), + (uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size), + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], }; - int vectorized = src->ne[0] % 4 == 0; - webgpu_pipeline pipeline = ctx->set_rows_pipelines[0][vectorized]; - uint32_t threads; - if (vectorized) { - threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); - } else { - threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; + std::vector<wgpu::BindGroupEntry> entries; + uint32_t binding_index = 0; + if (!inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + binding_index++; } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index + 1, dst)); - uint32_t wg_x = CEIL_DIV(threads, WEBGPU_MAX_WG_SIZE); - - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1, error_bufs); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, - ggml_tensor * src, - ggml_tensor * idx, - ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); + + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + + const uint32_t ne = (uint32_t) ggml_nelements(dst); + std::vector<uint32_t> params = { + ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)), - (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), + // Strides (in elements) + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + // Shapes + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + // Pad sizes + (uint32_t) ggml_get_op_params_i32(dst, 0), + (uint32_t) ggml_get_op_params_i32(dst, 1), + (uint32_t) ggml_get_op_params_i32(dst, 2), + (uint32_t) ggml_get_op_params_i32(dst, 3), + (uint32_t) ggml_get_op_params_i32(dst, 4), + (uint32_t) ggml_get_op_params_i32(dst, 5), + (uint32_t) ggml_get_op_params_i32(dst, 6), + (uint32_t) ggml_get_op_params_i32(dst, 7), + }; + + std::vector<wgpu::BindGroupEntry> entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), + }; + + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + +static webgpu_encoded_op ggml_webgpu_solve_tri(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + + webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx); + + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src1->ne[0], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + }; + + std::vector<wgpu::BindGroupEntry> entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), + }; + + const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); + const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_conv_2d(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector<wgpu::BindGroupEntry> entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_conv2d_pipeline(shader_lib_ctx); + + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + + uint32_t wg_x; + uint32_t wg_y; + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_im2col(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t p0 = ggml_get_op_params_i32(dst, 2); + const int32_t p1 = ggml_get_op_params_i32(dst, 3); + const int32_t d0 = ggml_get_op_params_i32(dst, 4); + const int32_t d1 = ggml_get_op_params_i32(dst, 5); + const bool is_2D = ggml_get_op_params_i32(dst, 6) == 1; + + const uint32_t KW = src0->ne[0]; + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t IC = is_2D ? src0->ne[2] : src0->ne[1]; + + const uint32_t IW = src1->ne[0]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t N = is_2D ? src1->ne[3] : src1->ne[2]; + + const uint32_t OW = dst->ne[1]; + const uint32_t OH = is_2D ? dst->ne[2] : 1; + + const uint32_t si0 = (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)); + const uint32_t si1 = is_2D ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0; + const uint32_t si2 = is_2D ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)); + const uint32_t si3 = is_2D ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)); + + const uint32_t so0 = (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)); + const uint32_t so1 = (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)); + const uint32_t so2 = is_2D ? (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)) : 0; + const uint32_t so3 = is_2D ? (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)) : + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + si0, + si1, + si2, + si3, + so0, + so1, + so2, + so3, + + KW, + KH, + IC, + + IW, + IH, + N, + + OW, + OH, + + (uint32_t) s0, + (uint32_t) s1, + (uint32_t) p0, + (uint32_t) p1, + (uint32_t) d0, + (uint32_t) d1, + }; + + std::vector<wgpu::BindGroupEntry> entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_im2col_pipeline(shader_lib_ctx); + + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + + uint32_t wg_x; + uint32_t wg_y; + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_ssm_conv(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_ssm_conv_shader_decisions *>(pipeline.context.get()); + + const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + + (uint32_t) src1->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + token_tiles, + }; + + std::vector<wgpu::BindGroupEntry> entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), + }; + + const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); + const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_ssm_scan(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * src6, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src4 = src4; + shader_lib_ctx.src5 = src5; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_scan_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_ssm_scan_shader_decisions *>(pipeline.context.get()); + const bool xbc_overlap = decisions->xbc_overlap; + + uint32_t offset_x = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + uint32_t offset_B = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src4) / ggml_type_size(src4->type)); + uint32_t offset_C = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src5) / ggml_type_size(src5->type)); + size_t xbc_bind_offset = 0; + size_t xbc_bind_size = 0; + if (xbc_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src1, src4, src5 }); + xbc_bind_offset = merged_range.offset; + xbc_bind_size = merged_range.size; + offset_x = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); + offset_B = ggml_webgpu_tensor_merged_element_offset(src4, merged_range); + offset_C = ggml_webgpu_tensor_merged_element_offset(src5, merged_range); + } + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + offset_x, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src3) / ggml_type_size(src3->type)), + offset_B, + offset_C, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src6) / ggml_type_size(src6->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)), + + (uint32_t) src3->ne[0], + (uint32_t) (src3->nb[1] / ggml_type_size(src3->type)), + + (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)), + + (uint32_t) (src5->nb[1] / ggml_type_size(src5->type)), + (uint32_t) (src5->nb[2] / ggml_type_size(src5->type)), + (uint32_t) (src5->nb[3] / ggml_type_size(src5->type)), + + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + (uint32_t) src4->ne[1], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + (uint32_t) ggml_nelements(src1), + }; + + std::vector<wgpu::BindGroupEntry> entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + }; + if (xbc_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), xbc_bind_offset, xbc_bind_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src6)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, src6)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 7, dst)); + } + + const uint32_t total_wg = (uint32_t) (src0->ne[1] * src0->ne[2] * src1->ne[3]); + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t wg_x; + uint32_t wg_y; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_gated_delta_net(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.src3 = src3; + shader_lib_ctx.src4 = src4; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx); + + const uint32_t s_v = (uint32_t) src2->ne[0]; + const uint32_t h = (uint32_t) src2->ne[1]; + const uint32_t n_tokens = (uint32_t) src2->ne[2]; + const uint32_t n_seqs = (uint32_t) src2->ne[3]; + const uint32_t K = (uint32_t) ggml_get_op_params_i32(dst, 0); + const float scale = 1.0f / sqrtf((float) s_v); + uint32_t scale_u32; + memcpy(&scale_u32, &scale, sizeof(scale_u32)); + + std::vector<uint32_t> params = { + h, + n_tokens, + n_seqs, + s_v * h * n_tokens * n_seqs, + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[3] / ggml_type_size(src2->type)), + + (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)), + + (uint32_t) src0->ne[1], + (uint32_t) (src2->ne[3] / src0->ne[3]), + K, + scale_u32, + }; + + std::vector<wgpu::BindGroupEntry> entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2), ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, src3), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 4, src4), ggml_webgpu_make_tensor_bind_group_entry(ctx, 5, src5), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 6, dst), + }; + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, h, n_seqs); +} + +static std::optional<webgpu_encoded_op> ggml_webgpu_set_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { + // For set rows specifically, we need to check if src and idx are empty + // tensors. + if (ggml_is_empty(src) || ggml_is_empty(idx)) { + return std::nullopt; + } + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = idx; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx); + + auto * decisions = static_cast<ggml_webgpu_set_rows_shader_decisions *>(pipeline.context.get()); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + // Convert byte-strides to element-strides + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Shape of dst - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], + // Shape of src + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3], // Shape of idx (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) }; std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(idx), - .offset = ggml_webgpu_tensor_align_offset(ctx, idx), - .size = ggml_webgpu_tensor_binding_size(ctx, idx) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst), }; - uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE); + if (decisions->i64_idx) { + entries.push_back(ggml_webgpu_make_bind_group_entry(3, ctx->set_rows_dev_error_buf, 0, + ctx->set_rows_dev_error_buf.GetSize())); + } + + uint32_t threads; + if (ggml_is_quantized(dst->type)) { + const uint32_t blocks_per_row = src->ne[0] / ggml_blck_size(dst->type); + threads = + (src->ne[1] * src->ne[2] * src->ne[3]) * (decisions->pair_blocks ? (blocks_per_row / 2) : blocks_per_row); + } else if (decisions->vec4) { + threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); + } else { + threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; + } + uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, 1); +} + +// Workgroup size is a common constant +static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) { + std::vector<wgpu::ConstantEntry> constants(1); + constants[0].key = "wg_size"; + constants[0].value = wg_size; + return constants; +} + +static webgpu_encoded_op ggml_webgpu_get_rows(webgpu_context & ctx, + ggml_tensor * src, + ggml_tensor * idx, + ggml_tensor * dst) { + const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + + std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + (uint32_t) (idx->ne[1]), + (uint32_t) (idx->ne[2]) }; + + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, idx), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst) }; + + uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1)); + uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]); + uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; + uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); - uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0; - webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized]; return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst) { +static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst, + std::vector<webgpu_dispatch_desc> & dispatches) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + + webgpu_pipeline qq8_pipeline = ctx->shader_lib->get_quantize_q8_pipeline(shader_lib_ctx); + + // quantize_q8 pipeline + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t q8_src1_align_offset = ROUNDUP_POW2( + dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t q8_src1_binding_size = + ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), + WEBGPU_STORAGE_BUF_BINDING_MULT); + + std::vector<uint32_t> q8_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + }; + + std::vector<wgpu::BindGroupEntry> q8_entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src1), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), q8_src1_align_offset, q8_src1_binding_size) + }; + + auto q8_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(qq8_pipeline.context.get()); + + uint32_t q8_wg_size = q8_decisions->wg_size; + uint32_t q8_wg_x = 1; + uint32_t q8_wg_y = 1; + const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size; + const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y); + + dispatches.push_back({ + qq8_pipeline, std::move(q8_params), std::move(q8_entries), { q8_wg_x, q8_wg_y } + }); +} + +static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + // Determine if this is a mat-vec operation + bool is_vec = (dst->ne[1] == 1); + + // use MMVQ path for mat-vec + bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product, + ctx->global_ctx->vendor); + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; + shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + shader_lib_ctx.supports_dot_product = ctx->global_ctx->capabilities.supports_dot_product; + shader_lib_ctx.vendor = ctx->global_ctx->vendor; + + // Get or create pipeline + webgpu_pipeline pipeline; + std::vector<webgpu_dispatch_desc> dispatches; + + if (is_vec) { + if (use_mmvq) { + ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches); + } + pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx); + } else { + pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); + } + + // Build params + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) src0->ne[0], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src0->ne[2], + (uint32_t) src0->ne[3], + (uint32_t) (src1->ne[2] / src0->ne[2]), + (uint32_t) (src1->ne[3] / src0->ne[3]) + }; + + // Build bind group entries + std::vector<wgpu::BindGroupEntry> entries = {}; + + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + if (use_mmvq) { + auto & mmvq_qq8_entry = dispatches[0].bind_group_entries[1]; + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), mmvq_qq8_entry.offset, + mmvq_qq8_entry.size)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + + // Calculate workgroup dimensions + uint32_t wg_x = 1; + uint32_t wg_y = 1; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + + if (is_vec) { + auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get()); + + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); + uint32_t total_wg = output_groups * batches; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + } else { + auto * decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(pipeline.context.get()); + + // Fast-path tiled/subgroup calculations + uint32_t wg_m; + uint32_t wg_n; + if (decisions->use_subgroup_matrix) { + uint32_t wg_m_sg_tile = + decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m; + wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); + uint32_t wg_n_sg_tile = + decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); + } else { + uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m; + uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n; + wg_m = CEIL_DIV(dst->ne[0], tile_m_s); + wg_n = CEIL_DIV(dst->ne[1], tile_n_s); + } + uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + } + + dispatches.push_back({ + pipeline, std::move(params), std::move(entries), { wg_x, wg_y } + }); + + return ggml_backend_webgpu_build_multi(ctx, dispatches); +} + +static webgpu_encoded_op ggml_webgpu_mul_mat_id_vec(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + const uint32_t param_n_expert = (uint32_t) src0->ne[2]; + const uint32_t param_n_expert_used = (uint32_t) dst->ne[1]; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_mul_mat_id_vec_pipeline(shader_lib_ctx); + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + param_n_expert, + param_n_expert_used, + (uint32_t) src1->ne[1], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + }; + + std::vector<wgpu::BindGroupEntry> entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + + uint32_t wg_x = 1; + uint32_t wg_y = 1; + + auto * decisions = static_cast<ggml_webgpu_mul_mat_vec_shader_decisions *>(pipeline.context.get()); + + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); + uint32_t total_wg = output_groups * param_n_expert_used; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_mul_mat_id(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + // we can use mat-vec fast path + if (dst->ne[2] == 1) { + return ggml_webgpu_mul_mat_id_vec(ctx, src0, src1, src2, dst); + } + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + // Get or create pipeline + webgpu_pipeline gather_pipeline; + webgpu_pipeline main_pipeline; + + std::vector<webgpu_dispatch_desc> dispatches; + + gather_pipeline = ctx->shader_lib->get_mul_mat_id_gather_pipeline(shader_lib_ctx); + main_pipeline = ctx->shader_lib->get_mul_mat_id_pipeline(shader_lib_ctx); + + const uint32_t param_n_expert = (uint32_t) src0->ne[2]; + const uint32_t param_n_expert_used = (uint32_t) dst->ne[1]; + const uint32_t param_n_tokens = (uint32_t) dst->ne[2]; + + // params for mul_mat_id_gather.wgsl + std::vector<uint32_t> gather_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + param_n_expert, + param_n_expert_used, + param_n_tokens, + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + }; + + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t gathered_buf_nbytes = src0->ne[2] * src1->ne[2] * sizeof(uint32_t); + + const size_t gathered_expert_used_align_offset = ROUNDUP_POW2( + dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t gathered_tokens_align_offset = + ROUNDUP_POW2(gathered_expert_used_align_offset + gathered_buf_nbytes, + ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t gathered_count_ids_align_offset = + ROUNDUP_POW2(gathered_tokens_align_offset + gathered_buf_nbytes, + ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + + const size_t gathered_binding_size = ROUNDUP_POW2(gathered_buf_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t gathered_count_ids_binding_size = + ROUNDUP_POW2(src0->ne[2] * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + + // bind group entries for mul_mat_id_gather.wgsl + std::vector<wgpu::BindGroupEntry> gather_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src2), ggml_webgpu_tensor_align_offset(ctx, src2), + ggml_webgpu_tensor_binding_size(ctx, src2)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), + }; + + // n_expert is much less than maxComputeWorkgroupsPerDimension (e.g., n_exeprt=256 at Qwen3.5-35B-A3B) + const uint32_t gather_wg_x = param_n_expert; + + dispatches.push_back({ + gather_pipeline, std::move(gather_params), std::move(gather_entries), { gather_wg_x, 1 } + }); + + // params for mul_mat_id.wgsl + std::vector<uint32_t> main_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) dst->ne[0], // number of rows in result (M, transposed) - (uint32_t) dst->ne[1], // number of columns in result (N) - (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) - (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 - (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 - (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2 - (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2 - (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3 - (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3 - (uint32_t) src0->ne[2], // batch size in dimension 2 - (uint32_t) src0->ne[3], // batch size in dimension 3 - (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2 - (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3 + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + param_n_expert, + param_n_expert_used, + param_n_tokens, + (uint32_t) src1->ne[1], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), }; - std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + // bind group entries for mul_mat_id.wgsl + std::vector<wgpu::BindGroupEntry> main_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1)), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + ggml_webgpu_make_bind_group_entry(3, ggml_webgpu_tensor_buf(dst), gathered_expert_used_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(4, ggml_webgpu_tensor_buf(dst), gathered_tokens_align_offset, + gathered_binding_size), + ggml_webgpu_make_bind_group_entry(5, ggml_webgpu_tensor_buf(dst), gathered_count_ids_align_offset, + gathered_count_ids_binding_size), }; - webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0]; - - uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE); + // Calculate workgroup dimensions + uint32_t wg_x = 1; uint32_t wg_y = 1; - bool use_fast = false; - switch (src1->type) { - case GGML_TYPE_F16: - use_fast = (src0->type == GGML_TYPE_F16); - break; - case GGML_TYPE_F32: - switch (src0->type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - use_fast = true; - break; - default: - break; - } - break; - default: - break; - } + auto * main_decisions = static_cast<ggml_webgpu_mul_mat_shader_decisions *>(main_pipeline.context.get()); - if (use_fast) { - int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; - if (dst->ne[1] == 1) { - // We don't support vectorized mul_mat_vec for quantized types - vectorized = vectorized && (src0->type < 2); - pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; - uint32_t batches = dst->ne[2] * dst->ne[3]; - uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG); - uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->limits.maxComputeWorkgroupsPerDimension; - wg_y = CEIL_DIV(total_wg, ctx->limits.maxComputeWorkgroupsPerDimension); - } else { - pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; - uint32_t wg_m; - uint32_t wg_n; -#ifndef __EMSCRIPTEN__ - if (ctx->supports_subgroup_matrix) { - // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m; - wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); - uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n; - wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); - } else { -#endif - uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; - uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; - wg_m = CEIL_DIV(dst->ne[0], tile_m_s); - wg_n = CEIL_DIV(dst->ne[1], tile_n_s); -#ifndef __EMSCRIPTEN__ - } -#endif + uint32_t wg_m; - wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; - } - } - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); + uint32_t tile_m_s = main_decisions->tile_m * main_decisions->wg_size_m; + uint32_t tile_n_s = main_decisions->tile_n * main_decisions->wg_size_n; + wg_m = CEIL_DIV(dst->ne[0], tile_m_s); + uint32_t total_gathered = dst->ne[1] * dst->ne[2]; + uint32_t max_active_experts = std::min((uint32_t) src0->ne[2], total_gathered); + uint32_t max_wg_n = CEIL_DIV(total_gathered, tile_n_s) + max_active_experts; + uint32_t total_wg = wg_m * max_wg_n; + + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + + dispatches.push_back({ + main_pipeline, std::move(main_params), std::move(main_entries), { wg_x, wg_y } + }); + + return ggml_backend_webgpu_build_multi(ctx, dispatches); } -static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, - ggml_tensor * Q, - ggml_tensor * K, - ggml_tensor * V, - ggml_tensor * mask, - ggml_tensor * sinks, - ggml_tensor * dst) { - float scale = *(float *) dst->op_params; - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float logit_softcap; - memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float)); +struct ggml_webgpu_flash_attn_op { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + std::vector<uint32_t> params; + std::vector<wgpu::BindGroupEntry> entries; + size_t kv_bind_offset = 0; + size_t kv_bind_size = 0; + bool has_mask = false; + bool has_sinks = false; + bool kv_overlap = false; +}; + +static bool ggml_webgpu_flash_attn_use_vec_path(const webgpu_global_context & global_ctx, + const ggml_tensor * Q, + const ggml_tensor * K, + const ggml_tensor * V) { + const size_t storage_offset_alignment = global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const bool k_float_vec4_aligned = (K->type != GGML_TYPE_F16 && K->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(K, storage_offset_alignment); + const bool v_float_vec4_aligned = (V->type != GGML_TYPE_F16 && V->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(V, storage_offset_alignment); + const bool k_vec_type_supported = + K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0; + const bool v_vec_type_supported = + V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16 || V->type == GGML_TYPE_Q4_0 || V->type == GGML_TYPE_Q8_0; + const uint32_t k_vec_head_align = (K->type == GGML_TYPE_F32 || K->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(K->type); + const uint32_t v_vec_head_align = (V->type == GGML_TYPE_F32 || V->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(V->type); + const bool kv_vec_head_dims_aligned = Q->ne[0] % k_vec_head_align == 0 && V->ne[0] % v_vec_head_align == 0; + + return global_ctx->capabilities.supports_subgroups && (Q->ne[1] < GGML_WEBGPU_FLASH_ATTN_VEC_MAX_SEQ_LEN) && + kv_vec_head_dims_aligned && k_vec_type_supported && v_vec_type_supported && k_float_vec4_aligned && + v_float_vec4_aligned; +} + +static ggml_webgpu_flash_attn_op ggml_webgpu_flash_attn_prepare(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { + float scale = ggml_get_op_params_f32(dst, 0); + float max_bias = ggml_get_op_params_f32(dst, 1); + float logit_softcap = ggml_get_op_params_f32(dst, 2); if (logit_softcap != 0.0f) { scale /= logit_softcap; } @@ -1045,15 +1810,43 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); - const int has_mask = (mask != nullptr); - const int has_sinks = (sinks != nullptr); + ggml_webgpu_flash_attn_op op = {}; + op.shader_lib_ctx.src0 = Q; + op.shader_lib_ctx.src1 = K; + op.shader_lib_ctx.src2 = V; + op.shader_lib_ctx.src3 = mask; + op.shader_lib_ctx.src4 = sinks; + op.shader_lib_ctx.dst = dst; + op.shader_lib_ctx.supports_subgroups = ctx->global_ctx->capabilities.supports_subgroups; + op.shader_lib_ctx.supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix; + op.shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + op.shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + op.shader_lib_ctx.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m; + op.shader_lib_ctx.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n; + op.shader_lib_ctx.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k; + op.shader_lib_ctx.min_subgroup_size = ctx->global_ctx->capabilities.min_subgroup_size; + op.shader_lib_ctx.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size; + + op.has_mask = mask != nullptr; + op.has_sinks = sinks != nullptr; + op.kv_overlap = ggml_webgpu_tensor_overlap(K, V); + + uint32_t offset_k = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)); + uint32_t offset_v = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)); + if (op.kv_overlap) { + const ggml_webgpu_merged_binding_range merged_range = ggml_webgpu_tensor_merged_binding_range(ctx, { K, V }); + op.kv_bind_offset = merged_range.offset; + op.kv_bind_size = merged_range.size; + offset_k = ggml_webgpu_tensor_merged_element_offset(K, merged_range); + offset_v = ggml_webgpu_tensor_merged_element_offset(V, merged_range); + } - std::vector<uint32_t> params = { + op.params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)), - has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, - has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, + offset_k, + offset_v, + op.has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0, + op.has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) Q->ne[2], // number of heads (uint32_t) Q->ne[1], // sequence length (Q) @@ -1067,170 +1860,362 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, (uint32_t) (V->nb[1] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 1 (uint32_t) (V->nb[2] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 2 (uint32_t) (V->nb[3] / ggml_type_size(V->type)), // stride (elements/blocks) of V in dimension 3 - has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 + op.has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0, // stride of mask dim 3 (uint32_t) (Q->ne[2] / K->ne[2]), // repeat factor for K/V in dim 2 (MHA/MQA/GQA) - *(uint32_t *) &scale, // scale (possibly adjusted for logit softcap) - *(uint32_t *) &max_bias, - *(uint32_t *) &logit_softcap, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 - - }; - std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(Q), - .offset = ggml_webgpu_tensor_align_offset(ctx, Q), - .size = ggml_webgpu_tensor_binding_size(ctx, Q) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(K), - .offset = ggml_webgpu_tensor_align_offset(ctx, K), - .size = ggml_webgpu_tensor_binding_size(ctx, K) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(V), - .offset = ggml_webgpu_tensor_align_offset(ctx, V), - .size = ggml_webgpu_tensor_binding_size(ctx, V) } + ggml_webgpu_u32_from_f32(scale), // scale (possibly adjusted for logit softcap) + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(logit_softcap), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) }; - uint32_t binding_index = 3; - if (has_mask) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(mask), - .offset = ggml_webgpu_tensor_align_offset(ctx, mask), - .size = ggml_webgpu_tensor_binding_size(ctx, mask) }); - } - if (has_sinks) { - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(sinks), - .offset = ggml_webgpu_tensor_align_offset(ctx, sinks), - .size = ggml_webgpu_tensor_binding_size(ctx, sinks) }); - } - entries.push_back({ .binding = binding_index++, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - bool kv_direct = - (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - - flash_attn_pipeline_key key = { - .q_type = Q->type, - .kv_type = K->type, - .dst_type = dst->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast<bool>(has_mask), - .has_sinks = static_cast<bool>(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, + op.entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, Q), }; + if (op.kv_overlap) { + op.entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size)); + } else { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, K)); + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, V)); + } + uint32_t binding_index = op.kv_overlap ? 2u : 3u; + if (op.has_mask) { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, mask)); + } + if (op.has_sinks) { + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, sinks)); + } + op.entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_index++, dst)); - webgpu_pipeline pipeline; - ggml_webgpu_flash_attn_shader_decisions decisions = {}; + return op; +} - auto it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context); - } else { - std::lock_guard<std::recursive_mutex> lock(ctx->mutex); - it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context); - } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast<bool>(has_mask), - .has_sinks = static_cast<bool>(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, - .sg_mat_m = ctx->sg_mat_m, - .sg_mat_n = ctx->sg_mat_n, - .sg_mat_k = ctx->sg_mat_k, - .wg_mem_limit_bytes = - ctx->limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->max_subgroup_size }; - - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); - pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions); - ctx->flash_attn_pipelines.emplace(key, pipeline); - decisions = processed.decisions; - } +static uint32_t ggml_webgpu_flash_attn_vec_nwg(uint32_t vec_nwg_cap, uint32_t kv_tile, uint32_t seq_len_kv) { + uint32_t nwg = 1u; + const uint64_t kv_span = (uint64_t) kv_tile; + while ((2u * nwg * kv_span) < (uint64_t) seq_len_kv && nwg < vec_nwg_cap) { + nwg <<= 1; } + return std::min(nwg, vec_nwg_cap); +} - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); - uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches - return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +static webgpu_encoded_op ggml_webgpu_flash_attn_direct(webgpu_context & ctx, const ggml_webgpu_flash_attn_op & op) { + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_flash_attn_decisions *>(pipeline.context.get()); + uint32_t wg_per_head = CEIL_DIV(op.shader_lib_ctx.src0->ne[1], decisions->q_tile); + uint32_t wg_x = wg_per_head * op.shader_lib_ctx.src0->ne[2] * op.shader_lib_ctx.src0->ne[3]; + return ggml_backend_webgpu_build(ctx, pipeline, op.params, op.entries, wg_x); } -static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - uint32_t ne = (uint32_t) ggml_nelements(dst); - ggml_unary_op unary_op = ggml_get_unary_op(dst); - uint32_t inplace = ggml_webgpu_tensor_equal(src, dst); +static webgpu_encoded_op ggml_webgpu_flash_attn_vec(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst, + ggml_webgpu_flash_attn_op op) { + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_vec_pipeline(op.shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_flash_attn_vec_decisions *>(pipeline.context.get()); + + wgpu::Buffer blk_buf = {}; + uint64_t blk_size_bytes = 0; + uint32_t blk_nblk0 = 0; + uint32_t blk_nblk1 = 0; + uint32_t blk_batch_count = 0; + + const uint32_t vec_nwg_cap = ctx->global_ctx->capabilities.min_subgroup_size; + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, decisions->kv_tile, (uint32_t) K->ne[1]); + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + const bool use_vec_reduce = nwg > 1u; + GGML_ASSERT(nrows <= UINT32_MAX); + + uint64_t tmp_stats_base = 0; + uint64_t tmp_size_bytes = 0; + wgpu::Buffer tmp_buf = {}; + uint64_t tmp_bind_offset = 0; + uint64_t tmp_bind_size = 0; + const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes); + + if (use_vec_reduce) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + tmp_stats_base = tmp_data_elems; + tmp_size_bytes = + ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT); + GGML_ASSERT(tmp_stats_base <= UINT32_MAX); + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } else { + // nwg==1 writes final dst directly in vec-split; bind tmp to a tiny non-overlapping scratch region. + tmp_size_bytes = WEBGPU_STORAGE_BUF_BINDING_MULT; + tmp_buf = ggml_webgpu_tensor_buf(dst); + tmp_bind_offset = scratch_offset; + tmp_bind_size = tmp_size_bytes; + scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes); + } - std::vector<uint32_t> params = { - ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[0] / ggml_type_size(src->type)), (uint32_t) (src->nb[1] / ggml_type_size(src->type)), - (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), - (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Logical shapes - (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], - (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] + webgpu_pipeline blk_pipeline; + std::vector<uint32_t> blk_params; + std::vector<wgpu::BindGroupEntry> blk_entries; + if (op.has_mask) { + blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile); + blk_nblk1 = (uint32_t) Q->ne[1]; + blk_buf = ggml_webgpu_tensor_buf(dst); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + const ggml_webgpu_shader_lib_context blk_shader_ctx = op.shader_lib_ctx; + blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx, decisions->kv_tile); + + blk_params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) K->ne[1], // seq_len_kv + stride_mask3, // stride_mask3 + blk_nblk0, // nblk0 + blk_nblk1, // nblk1 + }; + blk_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask)), + ggml_webgpu_make_bind_group_entry(1, blk_buf, scratch_offset, blk_size_bytes), + }; + scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes); + } + + std::vector<uint32_t> split_params = op.params; + if (op.has_mask) { + split_params.push_back(0u); // blk_base + split_params.push_back(blk_nblk0); // blk_nblk0 + split_params.push_back(blk_nblk1); // blk_nblk1 + } + split_params.push_back(0u); // tmp_data_base + split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base + split_params.push_back(nwg); // nwg + + std::vector<wgpu::BindGroupEntry> split_entries = { + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(Q), ggml_webgpu_tensor_align_offset(ctx, Q), + ggml_webgpu_tensor_binding_size(ctx, Q)), }; + if (op.kv_overlap) { + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), op.kv_bind_offset, op.kv_bind_size)); + } else { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(K), + ggml_webgpu_tensor_align_offset(ctx, K), + ggml_webgpu_tensor_binding_size(ctx, K))); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(V), + ggml_webgpu_tensor_align_offset(ctx, V), + ggml_webgpu_tensor_binding_size(ctx, V))); + } + uint32_t split_binding_index = op.kv_overlap ? 2u : 3u; + if (op.has_mask) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(mask), + ggml_webgpu_tensor_align_offset(ctx, mask), + ggml_webgpu_tensor_binding_size(ctx, mask))); + } + if (op.has_sinks) { + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(sinks), + ggml_webgpu_tensor_align_offset(ctx, sinks), + ggml_webgpu_tensor_binding_size(ctx, sinks))); + } + if (op.has_mask) { + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, blk_buf, blk_entries[1].offset, blk_size_bytes)); + } + split_entries.push_back( + ggml_webgpu_make_bind_group_entry(split_binding_index++, tmp_buf, tmp_bind_offset, tmp_bind_size)); + split_entries.push_back(ggml_webgpu_make_bind_group_entry(split_binding_index++, ggml_webgpu_tensor_buf(dst), + ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst))); + + webgpu_pipeline reduce_pipeline; + std::vector<uint32_t> reduce_params; + std::vector<wgpu::BindGroupEntry> reduce_entries; + if (use_vec_reduce) { + const uint32_t reduce_sg_size = ctx->global_ctx->capabilities.max_subgroup_size; + const uint32_t reduce_wg_size = std::max( + reduce_sg_size, + (uint32_t) std::min<uint64_t>((uint64_t) nwg * reduce_sg_size, + ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup)); + ggml_webgpu_shader_lib_context reduce_shader_ctx = op.shader_lib_ctx; + reduce_shader_ctx.max_wg_size = reduce_wg_size; + reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx); + + reduce_params = { + (uint32_t) nrows, // nrows + (uint32_t) Q->ne[1], // seq_len_q + (uint32_t) Q->ne[2], // n_heads + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst + nwg, // nwg + 0u, // tmp_data_base + (uint32_t) tmp_stats_base, // tmp_stats_base + }; + + reduce_entries = { + ggml_webgpu_make_bind_group_entry(0, tmp_buf, tmp_bind_offset, tmp_size_bytes), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), ggml_webgpu_tensor_align_offset(ctx, dst), + ggml_webgpu_tensor_binding_size(ctx, dst)), + }; + } - switch (unary_op) { - case GGML_UNARY_OP_XIELU: - { - // Get float parameters and reinterpret their bit patterns as uint32_t - // for passing through the params buffer - float alpha_n = ggml_get_op_params_f32(dst, 1); - float alpha_p = ggml_get_op_params_f32(dst, 2); - float beta = ggml_get_op_params_f32(dst, 3); - float eps = ggml_get_op_params_f32(dst, 4); - params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_n)); - params.push_back(*reinterpret_cast<const uint32_t *>(&alpha_p)); - params.push_back(*reinterpret_cast<const uint32_t *>(&beta)); - params.push_back(*reinterpret_cast<const uint32_t *>(&eps)); + uint32_t wg_x = Q->ne[1] * Q->ne[2] * Q->ne[3]; + const uint64_t split_wg_total = (uint64_t) wg_x * nwg; + GGML_ASSERT(split_wg_total <= UINT32_MAX); + + std::vector<webgpu_dispatch_desc> dispatches; + + if (op.has_mask) { + dispatches.push_back({ + blk_pipeline, std::move(blk_params), std::move(blk_entries), { blk_nblk0, blk_nblk1 * blk_batch_count } + }); + } + dispatches.push_back({ + pipeline, std::move(split_params), std::move(split_entries), { (uint32_t) split_wg_total, 1u } + }); + if (use_vec_reduce) { + dispatches.push_back({ + reduce_pipeline, std::move(reduce_params), std::move(reduce_entries), { (uint32_t) nrows, 1u } + }); + } + + return ggml_backend_webgpu_build_multi(ctx, dispatches); +} + +static webgpu_encoded_op ggml_webgpu_flash_attn(webgpu_context & ctx, + ggml_tensor * Q, + ggml_tensor * K, + ggml_tensor * V, + ggml_tensor * mask, + ggml_tensor * sinks, + ggml_tensor * dst) { + ggml_webgpu_flash_attn_op op = ggml_webgpu_flash_attn_prepare(ctx, Q, K, V, mask, sinks, dst); + if (ggml_webgpu_flash_attn_use_vec_path(ctx->global_ctx, Q, K, V)) { + return ggml_webgpu_flash_attn_vec(ctx, Q, K, V, mask, sinks, dst, std::move(op)); + } + return ggml_webgpu_flash_attn_direct(ctx, op); +} + +static webgpu_encoded_op ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + bool is_unary = dst->op == GGML_OP_UNARY; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); + + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + const bool inplace = decisions->inplace; + + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector<uint32_t> params = { ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2] }; + + ggml_tensor * effective_src = src; + if (is_unary) { + ggml_unary_op unary_op = ggml_get_unary_op(dst); + switch (unary_op) { + case GGML_UNARY_OP_XIELU: + { + // Get float parameters and reinterpret their bit patterns as uint32_t + // for passing through the params buffer + float alpha_n = ggml_get_op_params_f32(dst, 1); + float alpha_p = ggml_get_op_params_f32(dst, 2); + float beta = ggml_get_op_params_f32(dst, 3); + float eps = ggml_get_op_params_f32(dst, 4); + params.push_back(ggml_webgpu_u32_from_f32(alpha_n)); + params.push_back(ggml_webgpu_u32_from_f32(alpha_p)); + params.push_back(ggml_webgpu_u32_from_f32(beta)); + params.push_back(ggml_webgpu_u32_from_f32(eps)); + break; + } + default: break; - } - default: - break; + } + } else if (dst->op == GGML_OP_CLAMP) { + float clamp_min = ggml_get_op_params_f32(dst, 0); + float clamp_max = ggml_get_op_params_f32(dst, 1); + params.push_back(ggml_webgpu_u32_from_f32(clamp_min)); + params.push_back(ggml_webgpu_u32_from_f32(clamp_max)); + } else if (dst->op == GGML_OP_FILL) { + float fill_val = ggml_get_op_params_f32(dst, 0); + params.push_back(ggml_webgpu_u32_from_f32(fill_val)); + effective_src = dst; // fill simply fills dst } std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, effective_src), }; if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); } - uint32_t wg_x = CEIL_DIV(ne, WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, ctx->unary_pipelines[unary_op][dst->type][inplace], params, entries, wg_x); + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } -static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * dst, - webgpu_pipeline & pipeline, - bool inplace) { +static webgpu_encoded_op ggml_webgpu_binary_op(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get()); + + uint32_t ne = (uint32_t) ggml_nelements(dst); + + size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0); + size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1); + + uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)); + uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + size_t merged_offset = 0; + size_t merged_size = 0; + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range); + offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); + } + std::vector<uint32_t> params = { - (uint32_t) ggml_nelements(dst), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + ne, + offset_src0, + offset_src1, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), @@ -1244,30 +2229,263 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, (uint32_t) src1->ne[3], }; - std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } + std::vector<wgpu::BindGroupEntry> entries; + + if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } else { + entries.push_back(ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), + src0_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src0))); + entries.push_back(ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(src1), + src1_webgpu_tensor_align_offset, + ggml_webgpu_tensor_binding_size(ctx, src1))); + if (!decisions->inplace && !decisions->overlap) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + } + } + + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_add_id(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_add_id_pipeline(shader_lib_ctx); + + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + + std::vector<uint32_t> params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src2->nb[0] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], }; - if (!inplace) { - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + std::vector<wgpu::BindGroupEntry> entries; + + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); + + if (!decisions->inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 3, dst)); + } + + uint32_t wg_x = 1; + uint32_t wg_y = 1; + uint32_t total_wg = ggml_nrows(dst); + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_concat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); + uint32_t dim = (uint32_t) dst->op_params[0]; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_binary_shader_decisions *>(pipeline.context.get()); + + uint32_t offset_src0 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)); + uint32_t offset_src1 = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)); + size_t merged_offset = 0; + size_t merged_size = 0; + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { src0, src1 }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_src0 = ggml_webgpu_tensor_merged_element_offset(src0, merged_range); + offset_src1 = ggml_webgpu_tensor_merged_element_offset(src1, merged_range); + } + + std::vector<uint32_t> params = { ne, + offset_src0, + offset_src1, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + dim, + (uint32_t) src0->ne[dim] }; + + std::vector<wgpu::BindGroupEntry> entries = {}; + if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(src0), merged_offset, merged_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + +static webgpu_encoded_op ggml_webgpu_repeat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); + + std::vector<uint32_t> params = { ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / + ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src0->ne[0]), + (uint32_t) (src0->ne[1]), + (uint32_t) (src0->ne[2]), + (uint32_t) (src0->ne[3]), + (uint32_t) (dst->ne[0]), + (uint32_t) (dst->ne[1]), + (uint32_t) (dst->ne[2]) }; + + std::vector<wgpu::BindGroupEntry> entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst), + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_repeat_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - int inplace = ggml_webgpu_tensor_equal(src, dst); +static std::optional<webgpu_encoded_op> ggml_webgpu_rms_norm_mul(webgpu_context & ctx, + ggml_tensor * rn_src, + ggml_tensor * rn_dst, + ggml_tensor * mul_src0, + ggml_tensor * mul_src1, + ggml_tensor * dst) { + ggml_tensor * mul_src; + + if (ggml_webgpu_tensor_equal(rn_dst, mul_src0)) { + mul_src = mul_src1; + } else if (ggml_webgpu_tensor_equal(rn_dst, mul_src1)) { + mul_src = mul_src0; + } else { + GGML_ABORT("rms_norm must be equal to the one of mul_src0 and mul_src1"); + } + + uint32_t offset_rn_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, rn_src) / ggml_type_size(rn_src->type)); + uint32_t offset_mul_src = + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mul_src) / ggml_type_size(mul_src->type)); + size_t merged_offset = 0; + size_t merged_size = 0; + + std::vector<uint32_t> params = { + offset_rn_src, + offset_mul_src, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (rn_src->nb[1] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[2] / ggml_type_size(rn_src->type)), + (uint32_t) (rn_src->nb[3] / ggml_type_size(rn_src->type)), + (uint32_t) (mul_src->nb[1] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[2] / ggml_type_size(mul_src->type)), + (uint32_t) (mul_src->nb[3] / ggml_type_size(mul_src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) mul_src->ne[0], + (uint32_t) mul_src->ne[1], + (uint32_t) mul_src->ne[2], + (uint32_t) mul_src->ne[3], + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(rn_dst, 0)) // epsilon, treated as f32 in the shader + }; + + std::vector<wgpu::BindGroupEntry> entries; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = rn_src; + shader_lib_ctx.src1 = mul_src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rms_norm_mul_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_rms_norm_mul_shader_decisions *>(pipeline.context.get()); + + if (decisions->src_overlap) { + const ggml_webgpu_merged_binding_range merged_range = + ggml_webgpu_tensor_merged_binding_range(ctx, { rn_src, mul_src }); + merged_offset = merged_range.offset; + merged_size = merged_range.size; + offset_rn_src = ggml_webgpu_tensor_merged_element_offset(rn_src, merged_range); + offset_mul_src = ggml_webgpu_tensor_merged_element_offset(mul_src, merged_range); + params[0] = offset_rn_src; + params[1] = offset_mul_src; + } + + if (decisions->inplace || decisions->overlap) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); + } else if (decisions->src_overlap) { + entries.push_back( + ggml_webgpu_make_bind_group_entry(0, ggml_webgpu_tensor_buf(rn_src), merged_offset, merged_size)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } else { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, rn_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, mul_src)); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, dst)); + } + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); +} +static webgpu_encoded_op ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1281,38 +2499,53 @@ static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * s (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3], - *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)) // epsilon, treated as f32 in the shader }; - std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; - if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - } + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_row_norm_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); - return ggml_backend_webgpu_build(ctx, ctx->rms_norm_pipelines[inplace], params, entries, ggml_nrows(src)); + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; + if (!decisions->inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(src)); } -static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int has_freq_factor = (src2 != nullptr); +static webgpu_encoded_op ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_rope_pipeline(shader_lib_ctx); + + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + + const bool inplace = decisions->inplace; + const int has_freq_factor = (src2 != nullptr); const int n_dims = ((int32_t *) dst->op_params)[1]; const int mode = ((int32_t *) dst->op_params)[2]; const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); @@ -1345,49 +2578,47 @@ static webgpu_command ggml_webgpu_rope(webgpu_context & ctx, (uint32_t) src0->ne[2], (uint32_t) n_dims, (uint32_t) mode, - *(uint32_t *) &theta_scale, - *(uint32_t *) &attn_factor, - *(uint32_t *) &freq_scale, - *(uint32_t *) &ext_factor, - *(uint32_t *) &corr_dims[0], - *(uint32_t *) &corr_dims[1], + ggml_webgpu_u32_from_f32(theta_scale), + ggml_webgpu_u32_from_f32(attn_factor), + ggml_webgpu_u32_from_f32(freq_scale), + ggml_webgpu_u32_from_f32(ext_factor), + ggml_webgpu_u32_from_f32(corr_dims[0]), + ggml_webgpu_u32_from_f32(corr_dims[1]), (uint32_t) sections[0], (uint32_t) sections[1], (uint32_t) sections[2], (uint32_t) sections[3] }; - std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) } - }; - uint32_t dst_binding = 2; + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1) }; + uint32_t dst_binding = 2; if (has_freq_factor) { dst_binding = 3; - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 2, src2)); } if (!inplace) { - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); } - webgpu_pipeline pipeline = ctx->rope_pipelines[dst->type][has_freq_factor][inplace]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { +static webgpu_encoded_op ggml_webgpu_glu(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_glu_pipeline(shader_lib_ctx); + + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + const int split = (src1 != nullptr); std::vector<uint32_t> params = { @@ -1410,38 +2641,36 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], - (uint32_t) ((int32_t *) dst->op_params)[1], // swapped - *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai - *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 2)), // alpha, for swiglu_oai + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 3)), // limit, for swiglu_oai }; std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src0), }; uint32_t dst_binding = 1; if (split) { dst_binding = 2; - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); - } - entries.push_back({ .binding = dst_binding, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - - webgpu_pipeline pipeline = ctx->glu_pipelines[ggml_get_glu_op(dst)][dst->type][split]; - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, src1)); + } + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, dst_binding, dst)); + + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); } -static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - int inplace = ggml_webgpu_tensor_equal(src, dst); +static webgpu_encoded_op ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + + // params unchanged std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1455,52 +2684,57 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &dst->op_params[1] // bias + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 1)) // bias }; - std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src), - .offset = ggml_webgpu_tensor_align_offset(ctx, src), - .size = ggml_webgpu_tensor_binding_size(ctx, src) } - }; - if (!inplace) { - entries.push_back({ .binding = 1, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - } - - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx, ctx->scale_pipelines[inplace], params, entries, wg_x); -} - -static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, - ggml_tensor * src0, - ggml_tensor * src1, - ggml_tensor * src2, - ggml_tensor * dst) { - const int inplace = ggml_webgpu_tensor_equal(src0, dst); - const int mask_type = (src1 != nullptr) ? src1->type : 2; // use 2 for no mask here - const int has_sink = (src2 != nullptr); - float max_bias; - memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float)); - float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); - float m0 = powf(2.0f, -(max_bias) / n_head_log2); - float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // bindgroups unchanged + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src) }; + + if (!decisions->inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst)); + } + + uint32_t wg_x, wg_y; + uint32_t total_wg = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_encoded_op ggml_webgpu_soft_max(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src0; + shader_lib_ctx.src1 = src1; + shader_lib_ctx.src2 = src2; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_soft_max_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + + const bool inplace = decisions->inplace; + const int has_mask = (src1 != nullptr); + const int has_sink = (src2 != nullptr); + float max_bias = ggml_get_op_params_f32(dst, 1); + float n_head_log2 = float(1u << (uint32_t) floor(log2(src0->ne[2]))); + float m0 = powf(2.0f, -(max_bias) / n_head_log2); + float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, has_sink ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), - mask_type < 2 ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, - mask_type < 2 ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : 0, + has_mask ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : 0, (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), @@ -1508,53 +2742,333 @@ static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, (uint32_t) src0->ne[0], (uint32_t) src0->ne[1], (uint32_t) src0->ne[2], - mask_type < 2 ? (uint32_t) src1->ne[2] : 0, - mask_type < 2 ? (uint32_t) src1->ne[3] : 0, - *(uint32_t *) dst->op_params, // scale - *(uint32_t *) &max_bias, - *(uint32_t *) &n_head_log2, - *(uint32_t *) &m0, - *(uint32_t *) &m1 + has_mask ? (uint32_t) src1->ne[2] : 0, + has_mask ? (uint32_t) src1->ne[3] : 0, + ggml_webgpu_u32_from_f32(ggml_get_op_params_f32(dst, 0)), // scale + ggml_webgpu_u32_from_f32(max_bias), + ggml_webgpu_u32_from_f32(n_head_log2), + ggml_webgpu_u32_from_f32(m0), + ggml_webgpu_u32_from_f32(m1) + }; + + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_bind_group_entry( + 0, ggml_webgpu_tensor_buf(src0), ggml_webgpu_tensor_align_offset(ctx, src0), + ggml_webgpu_tensor_binding_size(ctx, src0)) }; + uint32_t binding_num = 1; + if (has_mask) { + entries.push_back(ggml_webgpu_make_bind_group_entry(binding_num, ggml_webgpu_tensor_buf(src1), + ggml_webgpu_tensor_align_offset(ctx, src1), + ggml_webgpu_tensor_binding_size(ctx, src1))); + binding_num++; + } + if (has_sink) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, src2)); + binding_num++; + } + if (!inplace) { + entries.push_back(ggml_webgpu_make_tensor_bind_group_entry(ctx, binding_num, dst)); + } + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, ggml_nrows(dst)); +} + +static webgpu_encoded_op ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src->ne[0] }; + + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); + uint32_t wg_x = ggml_nelements(dst); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + +static webgpu_encoded_op ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + bool is_top_k = dst->op == GGML_OP_TOP_K; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + shader_lib_ctx.wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize; + + webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx); + auto * argsort_decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(argsort_pipeline.context.get()); + + webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx); + + const uint32_t src_ne0 = (uint32_t) src->ne[0]; + const uint32_t nrows = (uint32_t) ggml_nrows(src); + const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size); + const uint32_t block_size = + is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size; + uint32_t out_ne0 = src_ne0; + if (is_top_k) { + if (npr > 1) { + const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size; + out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size); + } else { + out_ne0 = block_size; + } + } + + uint32_t merge_len = block_size; + uint32_t merge_passes = 0; + while (merge_len < out_ne0) { + merge_len <<= 1; + merge_passes++; + } + + const bool start_in_tmp = (merge_passes % 2) == 1; + + const size_t dst_offset = ggml_webgpu_tensor_offset(dst); + const size_t idx_nbytes = out_ne0 * ggml_nrows(dst) * sizeof(int32_t); + const size_t tmp_offset = + ROUNDUP_POW2(dst_offset + idx_nbytes, ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); + const size_t tmp_binding_size = ROUNDUP_POW2(idx_nbytes, WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t dst_binding_size = + ROUNDUP_POW2(idx_nbytes + ggml_webgpu_tensor_misalignment(ctx, dst), WEBGPU_STORAGE_BUF_BINDING_MULT); + + const uint32_t offset_src = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)); + const uint32_t offset_dst = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)); + const uint32_t offset_tmp = 0; + const uint32_t stride_src1 = (uint32_t) (src->nb[1] / ggml_type_size(src->type)); + const uint32_t stride_src2 = (uint32_t) (src->nb[2] / ggml_type_size(src->type)); + const uint32_t stride_src3 = (uint32_t) (src->nb[3] / ggml_type_size(src->type)); + const uint32_t stride_idx1 = out_ne0; + const uint32_t stride_idx2 = out_ne0 * (uint32_t) dst->ne[1]; + const uint32_t stride_idx3 = stride_idx2 * (uint32_t) dst->ne[2]; + + std::vector<webgpu_dispatch_desc> dispatches; + + const uint32_t init_offset = start_in_tmp ? offset_tmp : offset_dst; + const size_t init_align_offset = start_in_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst); + const size_t init_binding_size = start_in_tmp ? tmp_binding_size : dst_binding_size; + + std::vector<uint32_t> init_params = { + offset_src, init_offset, stride_src1, stride_src2, stride_src3, stride_idx1, + stride_idx2, stride_idx3, src_ne0, (uint32_t) src->ne[1], (uint32_t) src->ne[2], out_ne0, + block_size, npr, nrows + }; + + uint32_t wg_x_init; + uint32_t wg_y_init; + const uint32_t total_wg_init = npr * nrows; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + compute_2d_workgroups(total_wg_init, max_wg_per_dim, wg_x_init, wg_y_init); + + std::vector<wgpu::BindGroupEntry> init_entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), init_align_offset, init_binding_size) }; - std::vector<wgpu::BindGroupEntry> entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) } - }; - uint32_t binding_num = 1; - if (mask_type < 2) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); - binding_num++; + dispatches.push_back({ + argsort_pipeline, std::move(init_params), std::move(init_entries), { wg_x_init, wg_y_init } + }); + + if (merge_passes == 0) { + return ggml_backend_webgpu_build_multi(ctx, dispatches); + } + + bool in_is_tmp = start_in_tmp; + uint32_t len = block_size; + while (len < out_ne0) { + const uint32_t nm = CEIL_DIV(out_ne0, 2 * len); + + const bool out_is_tmp = !in_is_tmp; + const uint32_t offset_in = in_is_tmp ? offset_tmp : offset_dst; + const uint32_t offset_out = out_is_tmp ? offset_tmp : offset_dst; + const size_t align_in = in_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst); + const size_t align_out = out_is_tmp ? tmp_offset : ggml_webgpu_tensor_align_offset(ctx, dst); + const size_t size_in = in_is_tmp ? tmp_binding_size : dst_binding_size; + const size_t size_out = out_is_tmp ? tmp_binding_size : dst_binding_size; + const uint32_t top_k_out = (is_top_k && nm == 1) ? (uint32_t) dst->ne[0] : out_ne0; + const uint32_t stride_out1 = top_k_out; + const uint32_t stride_out2 = top_k_out * (uint32_t) dst->ne[1]; + const uint32_t stride_out3 = stride_out2 * (uint32_t) dst->ne[2]; + + std::vector<uint32_t> merge_params = { offset_src, + offset_in, + offset_out, + stride_src1, + stride_src2, + stride_src3, + stride_idx1, + stride_idx2, + stride_idx3, + stride_out1, + stride_out2, + stride_out3, + out_ne0, + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + top_k_out, + len, + nm, + nrows }; + + std::vector<wgpu::BindGroupEntry> merge_entries = { + ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_bind_group_entry(1, ggml_webgpu_tensor_buf(dst), align_in, size_in), + ggml_webgpu_make_bind_group_entry(2, ggml_webgpu_tensor_buf(dst), align_out, size_out) + }; + + uint32_t wg_x_merge; + uint32_t wg_y_merge; + const uint32_t total_wg_merge = nm * nrows; + compute_2d_workgroups(total_wg_merge, max_wg_per_dim, wg_x_merge, wg_y_merge); + + dispatches.push_back({ + argsort_merge_pipeline, std::move(merge_params), std::move(merge_entries), { wg_x_merge, wg_y_merge } + }); + + len <<= 1; + in_is_tmp = !in_is_tmp; + } + + return ggml_backend_webgpu_build_multi(ctx, dispatches); +} + +static webgpu_encoded_op ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) src->ne[0] }; + + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.src1 = nullptr; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); + uint32_t wg_x = ggml_nrows(dst); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + +static webgpu_encoded_op ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + bool total_sum = dst->op == GGML_OP_SUM; + std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + total_sum ? 0 : (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + total_sum ? 0 : (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + total_sum ? 0 : (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + total_sum ? static_cast<uint32_t>(ggml_nelements(src)) : (uint32_t) src->ne[0], + total_sum ? 1 : (uint32_t) src->ne[1], + total_sum ? 1 : (uint32_t) src->ne[2] }; + + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); + + uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x); +} + +static bool ggml_webgpu_can_fuse_rms_norm_mul(const struct ggml_cgraph * cgraph, int node_idx) { + if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + return false; } - if (has_sink) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(src2), - .offset = ggml_webgpu_tensor_align_offset(ctx, src2), - .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); - binding_num++; + + // additional constraints specific to this fusion + const ggml_tensor * rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor * mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || mul->src[1]->type != GGML_TYPE_F32 || mul->type != GGML_TYPE_F32) { + return false; } - if (!inplace) { - entries.push_back({ .binding = binding_num, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; } - return ggml_backend_webgpu_build(ctx, ctx->soft_max_pipelines[mask_type][has_sink][inplace], params, entries, - ggml_nrows(dst)); + return true; +} + +static webgpu_encoded_op ggml_webgpu_upscale(webgpu_context ctx, ggml_tensor * src, ggml_tensor * dst) { + const uint32_t mode_flags = (uint32_t) ggml_get_op_params_i32(dst, 0); + std::vector<uint32_t> params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src->nb[0] / ggml_type_size(src->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + + mode_flags }; + + std::vector<wgpu::BindGroupEntry> entries = { ggml_webgpu_make_tensor_bind_group_entry(ctx, 0, src), + ggml_webgpu_make_tensor_bind_group_entry(ctx, 1, dst) }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = {}; + shader_lib_ctx.src0 = src; + shader_lib_ctx.dst = dst; + shader_lib_ctx.max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + + webgpu_pipeline pipeline = ctx->shader_lib->get_upscale_pipeline(shader_lib_ctx); + auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get()); + + uint32_t wg_x; + uint32_t wg_y; + uint32_t total_wg = CEIL_DIV((uint32_t) ggml_nelements(dst), decisions->wg_size); + compute_2d_workgroups(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension, wg_x, wg_y); + + return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y); } // Returns the encoded command, or std::nullopt if the operation is a no-op -static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { +static std::optional<webgpu_encoded_op> ggml_webgpu_encode(webgpu_context ctx, + ggml_cgraph * cgraph, + int node_idx, + int & num_encoded_ops) { + ggml_tensor ** nodes = cgraph->nodes; + ggml_tensor * node = nodes[node_idx]; + if (ggml_is_empty(node)) { return std::nullopt; } - WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")"); + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return std::nullopt; + } + WEBGPU_LOG_DEBUG("ggml_webgpu_encode(" << node << ", " << ggml_op_name(node->op) << ")"); ggml_tensor * src0 = node->src[0]; ggml_tensor * src1 = node->src[1]; @@ -1571,36 +3085,40 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_CPY: case GGML_OP_CONT: return ggml_webgpu_cpy(ctx, src0, node); + case GGML_OP_SET: + return ggml_webgpu_set(ctx, src0, src1, node); case GGML_OP_SET_ROWS: return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: return ggml_webgpu_get_rows(ctx, src0, src1, node); case GGML_OP_MUL_MAT: return ggml_webgpu_mul_mat(ctx, src0, src1, node); + case GGML_OP_MUL_MAT_ID: + return ggml_webgpu_mul_mat_id(ctx, src0, src1, src2, node); case GGML_OP_FLASH_ATTN_EXT: return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node); case GGML_OP_ADD: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace); - } case GGML_OP_SUB: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace); - } case GGML_OP_MUL: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace); - } case GGML_OP_DIV: - { - int inplace = ggml_webgpu_tensor_equal(src0, node); - return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace); - } + return ggml_webgpu_binary_op(ctx, src0, src1, node); + case GGML_OP_ADD_ID: + return ggml_webgpu_add_id(ctx, src0, src1, src2, node); + case GGML_OP_CONCAT: + return ggml_webgpu_concat(ctx, src0, src1, node); + case GGML_OP_REPEAT: + return ggml_webgpu_repeat(ctx, src0, node); case GGML_OP_RMS_NORM: - return ggml_webgpu_rms_norm(ctx, src0, node); + if (ggml_webgpu_can_fuse_rms_norm_mul(cgraph, node_idx)) { + num_encoded_ops = 2; + ggml_tensor * mul_node = nodes[node_idx + 1]; + return ggml_webgpu_rms_norm_mul(ctx, src0, node, mul_node->src[0], mul_node->src[1], mul_node); + } else { + return ggml_webgpu_row_norm(ctx, src0, node); + } + case GGML_OP_NORM: + case GGML_OP_L2_NORM: + return ggml_webgpu_row_norm(ctx, src0, node); case GGML_OP_ROPE: return ggml_webgpu_rope(ctx, src0, src1, src2, node); case GGML_OP_GLU: @@ -1610,65 +3128,301 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_SOFT_MAX: return ggml_webgpu_soft_max(ctx, src0, src1, src2, node); case GGML_OP_UNARY: + case GGML_OP_CLAMP: + case GGML_OP_FILL: + case GGML_OP_LOG: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_DIAG: + case GGML_OP_TRI: return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SOLVE_TRI: + return ggml_webgpu_solve_tri(ctx, src0, src1, node); + case GGML_OP_SSM_CONV: + return ggml_webgpu_ssm_conv(ctx, src0, src1, node); + case GGML_OP_SSM_SCAN: + return ggml_webgpu_ssm_scan(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node->src[6], + node); + case GGML_OP_GATED_DELTA_NET: + return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); + case GGML_OP_PAD: + return ggml_webgpu_pad(ctx, src0, node); + case GGML_OP_ARGMAX: + return ggml_webgpu_argmax(ctx, src0, node); + case GGML_OP_ARGSORT: + case GGML_OP_TOP_K: + // we reuse the same argsort implementation for top_k + return ggml_webgpu_argsort(ctx, src0, node); + case GGML_OP_CUMSUM: + return ggml_webgpu_cumsum(ctx, src0, node); + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + return ggml_webgpu_sum_rows(ctx, src0, node); + case GGML_OP_CONV_2D: + return ggml_webgpu_conv_2d(ctx, src0, src1, node); + case GGML_OP_IM2COL: + return ggml_webgpu_im2col(ctx, src0, src1, node); + case GGML_OP_UPSCALE: + return ggml_webgpu_upscale(ctx, src0, node); default: return std::nullopt; } } +#ifdef GGML_WEBGPU_GPU_PROFILE +static void ggml_backend_webgpu_collect_profile_results(webgpu_context & ctx, + const std::vector<std::string> & pipeline_names, + uint32_t & num_inflight_batches) { + if (pipeline_names.empty()) { + return; + } + + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.ResolveQuerySet(ctx->profile_timestamp_query_set, 0, ctx->profile_timestamp_query_count, + ctx->profile_timestamp_dev_buf, 0); + encoder.CopyBufferToBuffer(ctx->profile_timestamp_dev_buf, 0, ctx->profile_timestamp_host_buf, 0, + ctx->profile_timestamp_query_count * sizeof(uint64_t)); + + wgpu::CommandBuffer profile_commands = encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, profile_commands, num_inflight_batches); + + const size_t mapped_size = ctx->profile_timestamp_query_count * sizeof(uint64_t); + GGML_ASSERT(ctx->profile_timestamp_query_count == 2 * pipeline_names.size()); + + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->profile_timestamp_host_buf, wgpu::MapMode::Read, 0, + mapped_size); + const uint64_t * ts_data = (const uint64_t *) ctx->profile_timestamp_host_buf.GetConstMappedRange(0, mapped_size); + + for (size_t i = 0; i < pipeline_names.size(); ++i) { + // WebGPU timestamps are in ns; convert to ms. + const double elapsed_ms = double(ts_data[2 * i + 1] - ts_data[2 * i]) * 1e-6; + ctx->shader_gpu_time_ms[pipeline_names[i]] += elapsed_ms; + } + + ctx->profile_timestamp_host_buf.Unmap(); +} +#endif + +// Don't bother checking set_rows index overflow for now, since practically the WebGPU doesn't need to support +// models that would require it right now. +static void ggml_backend_webgpu_check_set_rows(webgpu_context & ctx, uint32_t & num_inflight_batches) { +#ifdef GGML_WEBGPU_CHECK_SET_ROWS + wgpu::CommandEncoder encoder = ctx->global_ctx->device.CreateCommandEncoder(); + encoder.CopyBufferToBuffer(ctx->set_rows_dev_error_buf, 0, ctx->set_rows_host_error_buf, 0, + ctx->set_rows_host_error_buf.GetSize()); + wgpu::CommandBuffer commands = encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, commands, num_inflight_batches); + ggml_backend_webgpu_map_buffer(ctx->global_ctx, ctx->set_rows_host_error_buf, wgpu::MapMode::Read, 0, + ctx->set_rows_host_error_buf.GetSize()); + const uint32_t * error_data = (const uint32_t *) ctx->set_rows_host_error_buf.GetConstMappedRange(); + if (*error_data) { + GGML_ABORT("ggml_webgpu: SET_ROWS index > 2^32, unsupported."); + } + ctx->set_rows_host_error_buf.Unmap(); +#else + GGML_UNUSED(ctx); + GGML_UNUSED(num_inflight_batches); +#endif +} + static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_graph_compute(" << cgraph->n_nodes << " nodes)"); - ggml_backend_webgpu_context * backend_ctx = static_cast<ggml_backend_webgpu_context *>(backend->context); + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; webgpu_context ctx = backend_ctx->webgpu_ctx; WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - ctx->inflight_threads++; + std::vector<webgpu_encoded_op> commands; + + uint32_t num_batched_kernels = 0; + uint32_t num_inflight_batches = 0; + bool contains_set_rows = false; + int num_encoded_ops = 1; + int node_idx = 0; - std::vector<webgpu_command> commands; - std::vector<webgpu_submission_futures> futures; - for (int i = 0; i < cgraph->n_nodes; i++) { - if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { +#ifdef GGML_WEBGPU_GPU_PROFILE + ctx->profile_timestamp_query_count = 0; + std::vector<std::string> profile_pipeline_names; +#endif + + ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + if (ctx->batch_compute_passes) { + ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); + } + + while (node_idx < cgraph->n_nodes) { + if (cgraph->nodes[node_idx]->op == GGML_OP_SET_ROWS) { + contains_set_rows = true; + } + if (auto cmd = ggml_webgpu_encode(ctx, cgraph, node_idx, num_encoded_ops)) { commands.push_back(*cmd); + num_batched_kernels += cmd.value().num_kernels; +#ifdef GGML_WEBGPU_GPU_PROFILE + profile_pipeline_names.insert(profile_pipeline_names.end(), cmd->pipeline_names.begin(), + cmd->pipeline_names.end()); +#endif } - // compute the batch size based on the number of inflight threads - uint32_t inflight_threads = ctx->inflight_threads; - uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), - WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); - if (commands.size() >= batch_size) { - futures.push_back(ggml_backend_webgpu_submit(ctx, commands)); - // Process events and check for completed submissions - ctx->instance.ProcessEvents(); - ggml_backend_webgpu_wait(ctx, futures, false); + + if (num_batched_kernels >= ctx->global_ctx->command_submit_batch_size) { + if (ctx->active_compute_pass) { + ctx->active_compute_pass.End(); + } + num_batched_kernels = 0; + wgpu::CommandBuffer batch_commands = ctx->active_command_encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); + + // reset state for next batch + ctx->active_command_encoder = ctx->global_ctx->device.CreateCommandEncoder(); + if (ctx->batch_compute_passes) { + ctx->active_compute_pass = ctx->active_command_encoder.BeginComputePass(); + } + ctx->param_arena.reset(); commands.clear(); +#ifdef GGML_WEBGPU_GPU_PROFILE + // flush before the next batch can overflow the QuerySet + if (ctx->profile_timestamp_query_count + 2 * ctx->global_ctx->command_submit_batch_size >= + WEBGPU_MAX_PROFILE_QUERY_COUNT) { + ggml_backend_webgpu_collect_profile_results(ctx, profile_pipeline_names, num_inflight_batches); + // reset profile timestamp state + ctx->profile_timestamp_query_count = 0; + profile_pipeline_names.clear(); + } +#endif } + + node_idx += num_encoded_ops; + num_encoded_ops = 1; + } + + if (ctx->active_compute_pass) { + ctx->active_compute_pass.End(); + ctx->active_compute_pass = nullptr; + } + + if (num_batched_kernels > 0) { + wgpu::CommandBuffer batch_commands = ctx->active_command_encoder.Finish(); + ggml_backend_webgpu_submit_commands(ctx, batch_commands, num_inflight_batches); + ctx->param_arena.reset(); + commands.clear(); } - if (!commands.empty()) { - webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands); - futures.push_back(new_futures); + ctx->active_command_encoder = nullptr; + +#ifdef GGML_WEBGPU_GPU_PROFILE + ggml_backend_webgpu_collect_profile_results(ctx, profile_pipeline_names, num_inflight_batches); +#endif + + if (contains_set_rows) { + ggml_backend_webgpu_check_set_rows(ctx, num_inflight_batches); } - ggml_backend_webgpu_wait(ctx, futures); - ctx->inflight_threads--; - WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx); + WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } +struct ggml_backend_webgpu_event_context { + webgpu_global_context global_ctx; + wgpu::Future future; + bool recorded = false; +}; + +static ggml_backend_event_t ggml_backend_webgpu_device_event_new(ggml_backend_dev_t device) { + ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) device->context; + + auto * event_ctx = new ggml_backend_webgpu_event_context(); + event_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + + auto * event = new ggml_backend_event; + event->device = device; + event->context = event_ctx; + return event; +} + +static void ggml_backend_webgpu_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + delete static_cast<ggml_backend_webgpu_event_context *>(event->context); + delete event; +} + +static void ggml_backend_webgpu_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) { + GGML_UNUSED(dev); + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + if (!event_ctx->recorded) { + return; + } + wgpu::WaitStatus status = + event_ctx->global_ctx->instance.WaitAny(event_ctx->future, WEBGPU_RUNTIME_WAIT_TIMEOUT_NS); + if (status == wgpu::WaitStatus::TimedOut) { + GGML_ABORT("ggml_webgpu: event_synchronize timed out after %u ms\n", WEBGPU_RUNTIME_WAIT_TIMEOUT_MS); + } + event_ctx->recorded = false; +} + +static void ggml_backend_webgpu_event_record(ggml_backend_t backend, ggml_backend_event_t event) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_event_context * event_ctx = (ggml_backend_webgpu_event_context *) event->context; + + event_ctx->future = backend_ctx->webgpu_ctx->global_ctx->queue.OnSubmittedWorkDone( + wgpu::CallbackMode::AllowSpontaneous, [](wgpu::QueueWorkDoneStatus, wgpu::StringView) {}); + event_ctx->recorded = true; +} + +static void ggml_backend_webgpu_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { + GGML_UNUSED(backend); + ggml_backend_webgpu_device_event_synchronize(nullptr, event); +} + +static void ggml_backend_webgpu_set_tensor_async(ggml_backend_t backend, + ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + GGML_UNUSED(backend); + auto * buf_ctx = (ggml_backend_webgpu_buffer_context *) tensor->buffer->context; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; + + // Write aligned portion + buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); + + if (size % 4 != 0) { + // If size is not a multiple of 4, we need to memset the remaining bytes + size_t remaining_size = size % 4; + + // pack the remaining bytes into a uint32_t + uint32_t val32 = 0; + + for (size_t i = 0; i < remaining_size; i++) { + ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; + } + // memset the remaining bytes + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, + total_offset + (size - remaining_size), remaining_size); + } +} + +static void ggml_backend_webgpu_synchronize(ggml_backend_t backend) { + ggml_backend_webgpu_context * backend_ctx = (ggml_backend_webgpu_context *) backend->context; + ggml_backend_webgpu_wait_queue(backend_ctx->webgpu_ctx->global_ctx); +} + static ggml_backend_i ggml_backend_webgpu_i = { /* .get_name = */ ggml_backend_webgpu_name, /* .free = */ ggml_backend_webgpu_free, - /* .set_tensor_async = */ NULL, + /* .set_tensor_async = */ ggml_backend_webgpu_set_tensor_async, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, + /* .synchronize = */ ggml_backend_webgpu_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, /* .graph_plan_update = */ NULL, /* .graph_plan_compute = */ NULL, /* .graph_compute = */ ggml_backend_webgpu_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, + /* .event_record = */ ggml_backend_webgpu_event_record, + /* .event_wait = */ ggml_backend_webgpu_event_wait, /* .graph_optimize = */ NULL, }; @@ -1678,7 +3432,10 @@ static ggml_backend_i ggml_backend_webgpu_i = { static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context); - ctx->buffer.Destroy(); + if (ctx != nullptr && ctx->buffer != nullptr) { + ctx->buffer.Destroy(); + delete ctx; + } } // Returns the "fake" base pointer. @@ -1693,7 +3450,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe size_t offset, size_t size) { if (size == 0) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do."); + WEBGPU_LOG_DEBUG( + "ggml_backend_webgpu_buffer_memset_tensor: size is zero, " + "nothing to do."); return; } @@ -1704,12 +3463,12 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor(" << buf_ctx->label << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; // This is a trick to set all bytes of a u32 to the same 1 byte value. uint32_t val32 = (uint32_t) value * 0x01010101; - ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, val32, total_offset, size); - WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->webgpu_ctx); + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, total_offset, size); + WEBGPU_CPU_PROFILE_TOTAL_END(memset_tensor, buf_ctx->global_ctx); } static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -1718,15 +3477,14 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, size_t offset, size_t size) { WEBGPU_CPU_PROFILE_TOTAL_START(set_tensor); - ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; + ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_set_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; - webgpu_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); + buf_ctx->global_ctx->queue.WriteBuffer(buf_ctx->buffer, total_offset, data, (size / 4) * 4); if (size % 4 != 0) { // If size is not a multiple of 4, we need to memset the remaining bytes @@ -1739,21 +3497,10 @@ static void ggml_backend_webgpu_buffer_set_tensor(ggml_backend_buffer_t buffer, ((uint8_t *) &val32)[i] = ((const uint8_t *) data)[size - remaining_size + i]; } // memset the remaining bytes - ggml_backend_webgpu_buffer_memset(webgpu_ctx, buf_ctx->buffer, val32, total_offset + (size - remaining_size), - remaining_size); - } else { - // wait for WriteBuffer to complete - webgpu_ctx->instance.WaitAny( - webgpu_ctx->queue.OnSubmittedWorkDone(wgpu::CallbackMode::AllowSpontaneous, - [](wgpu::QueueWorkDoneStatus status, wgpu::StringView message) { - if (status != wgpu::QueueWorkDoneStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to submit commands: %s\n", - std::string(message).c_str()); - } - }), - UINT64_MAX); + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, val32, + total_offset + (size - remaining_size), remaining_size); } - WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, webgpu_ctx); + WEBGPU_CPU_PROFILE_TOTAL_END(set_tensor, buf_ctx->global_ctx); } static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, @@ -1765,53 +3512,56 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_get_tensor(" << buf_ctx->label << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); - webgpu_context webgpu_ctx = buf_ctx->webgpu_ctx; - wgpu::Device device = webgpu_ctx->device; + wgpu::Device device = buf_ctx->global_ctx->device; - size_t total_offset = webgpu_tensor_offset(tensor) + tensor->view_offs + offset; + size_t total_offset = ggml_webgpu_tensor_offset(tensor) + offset; size_t final_size = size; if (size % 4 != 0) { - // If size is not a multiple of 4, we need to round it up to the next multiple of 4 + // If size is not a multiple of 4, we need to round it up to the next + // multiple of 4 final_size = size + (4 - (size % 4)); } - std::lock_guard<std::recursive_mutex> lock(webgpu_ctx->mutex); + std::lock_guard<std::recursive_mutex> lock(buf_ctx->global_ctx->mutex); - if (webgpu_ctx->get_tensor_staging_buf == nullptr || webgpu_ctx->get_tensor_staging_buf.GetSize() < final_size) { + if (buf_ctx->global_ctx->get_tensor_staging_buf == nullptr || + buf_ctx->global_ctx->get_tensor_staging_buf.GetSize() < final_size) { // Create a new staging buffer if it doesn't exist or is too small - if (webgpu_ctx->get_tensor_staging_buf) { - webgpu_ctx->get_tensor_staging_buf.Destroy(); + if (buf_ctx->global_ctx->get_tensor_staging_buf) { + buf_ctx->global_ctx->get_tensor_staging_buf.Destroy(); } - ggml_webgpu_create_buffer(device, webgpu_ctx->get_tensor_staging_buf, final_size, + ggml_webgpu_create_buffer(device, buf_ctx->global_ctx->get_tensor_staging_buf, final_size, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "get_tensor_staging_buf"); } // Copy the data from the buffer to the staging buffer wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, webgpu_ctx->get_tensor_staging_buf, 0, final_size); + encoder.CopyBufferToBuffer(buf_ctx->buffer, total_offset, buf_ctx->global_ctx->get_tensor_staging_buf, 0, + final_size); wgpu::CommandBuffer commands = encoder.Finish(); // Submit the command buffer to the queue - webgpu_ctx->queue.Submit(1, &commands); + buf_ctx->global_ctx->queue.Submit(1, &commands); // Map the staging buffer to read the data - ggml_backend_webgpu_map_buffer(webgpu_ctx, webgpu_ctx->get_tensor_staging_buf, wgpu::MapMode::Read, 0, final_size); + ggml_backend_webgpu_map_buffer(buf_ctx->global_ctx, buf_ctx->global_ctx->get_tensor_staging_buf, + wgpu::MapMode::Read, 0, final_size); // Must specify size here since the staging buffer might be larger than the tensor size - const void * mapped_range = webgpu_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size); + const void * mapped_range = buf_ctx->global_ctx->get_tensor_staging_buf.GetConstMappedRange(0, final_size); // Copy the data from the mapped range to the output buffer std::memcpy(data, mapped_range, size); - webgpu_ctx->get_tensor_staging_buf.Unmap(); - WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, webgpu_ctx); + buf_ctx->global_ctx->get_tensor_staging_buf.Unmap(); + WEBGPU_CPU_PROFILE_TOTAL_END(get_tensor, buf_ctx->global_ctx); } static void ggml_backend_webgpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_clear(" << buffer << ", " << (uint32_t) value << ")"); WEBGPU_CPU_PROFILE_TOTAL_START(clear); ggml_backend_webgpu_buffer_context * buf_ctx = (ggml_backend_webgpu_buffer_context *) buffer->context; - ggml_backend_webgpu_buffer_memset(buf_ctx->webgpu_ctx, buf_ctx->buffer, value, 0, buffer->size); - WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->webgpu_ctx); + ggml_backend_webgpu_buffer_memset(buf_ctx->global_ctx, buf_ctx->buffer, value, 0, buffer->size); + WEBGPU_CPU_PROFILE_TOTAL_END(clear, buf_ctx->global_ctx); } static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { @@ -1821,9 +3571,12 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .memset_tensor = */ ggml_backend_webgpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_webgpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, - /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor + /* .reset = */ NULL, // TODO: optional, think it coordinates with + // .init_tensor }; /* End GGML Backend Buffer Interface */ @@ -1841,28 +3594,129 @@ static ggml_backend_buffer_t ggml_backend_webgpu_buffer_type_alloc_buffer(ggml_b int buffer_id = buffer_count++; std::string buf_name = "tensor_buf" + std::to_string(buffer_id); WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_type_alloc_buffer_" << buffer_id << ": " << size << " bytes"); - ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context); - wgpu::Buffer buf; - ggml_webgpu_create_buffer(ctx->webgpu_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), + ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context); + wgpu::Buffer buf; + ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, buf, ROUNDUP_POW2(size, WEBGPU_STORAGE_BUF_BINDING_MULT), wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, buf_name.c_str()); ggml_backend_webgpu_buffer_context * buf_ctx = - new ggml_backend_webgpu_buffer_context(ctx->webgpu_ctx, buf, buf_name); + new ggml_backend_webgpu_buffer_context(buf, buf_name, ctx->webgpu_global_ctx); return ggml_backend_buffer_init(buft, ggml_backend_webgpu_buffer_interface, buf_ctx, size); } static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { - ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context); - return ctx->webgpu_ctx->limits.minStorageBufferOffsetAlignment; + ggml_backend_webgpu_device_context * dev_ctx = + static_cast<ggml_backend_webgpu_device_context *>(buft->device->context); + return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; } -// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding. +// maxBufferSize might be larger, but you can't bind more than +// maxStorageBufferBindingSize to a single binding. static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_webgpu_device_context * dev_ctx = + static_cast<ggml_backend_webgpu_device_context *>(buft->device->context); + return dev_ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize; +} + +static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, + const ggml_tensor * tensor) { ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(buft->device->context); - return ctx->webgpu_ctx->limits.maxStorageBufferBindingSize; + size_t res = ggml_nbytes(tensor); + switch (tensor->op) { + case GGML_OP_ARGSORT: + res = ROUNDUP_POW2(res * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); + break; + case GGML_OP_TOP_K: + { + const ggml_tensor * src0 = tensor->src[0]; + if (src0) { + const size_t full = sizeof(int32_t) * ggml_nelements(src0); + res = ROUNDUP_POW2( + full * 2 + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; + case GGML_OP_FLASH_ATTN_EXT: + { + const ggml_tensor * Q = tensor->src[0]; + const ggml_tensor * K = tensor->src[1]; + const ggml_tensor * V = tensor->src[2]; + const ggml_tensor * mask = tensor->src[3]; + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + if (ggml_webgpu_flash_attn_use_vec_path(ctx->webgpu_global_ctx, Q, K, V)) { + const bool kv_direct = + ggml_webgpu_flash_attn_kv_direct(Q, K, V, GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH); + const uint32_t kv_tile = ggml_webgpu_flash_attn_get_vec_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, (uint32_t) Q->ne[0], (uint32_t) V->ne[0], + mask != nullptr, kv_direct); + + const uint32_t vec_nwg_cap = capabilities.min_subgroup_size; + uint32_t nwg = ggml_webgpu_flash_attn_vec_nwg(vec_nwg_cap, kv_tile, (uint32_t) K->ne[1]); + + const size_t align = capabilities.limits.minStorageBufferOffsetAlignment; + const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3]; + if (nwg > 1u) { + const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg; + const uint64_t tmp_stats_elems = nrows * 2u * nwg; + const size_t tmp_size_bytes = ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), + WEBGPU_STORAGE_BUF_BINDING_MULT); + res += tmp_size_bytes + align; + } else { + res += WEBGPU_STORAGE_BUF_BINDING_MULT + align; + } + if (mask != nullptr) { + const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile); + const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u); + const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)); + const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u; + const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count; + const size_t blk_size_bytes = + ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT); + res += blk_size_bytes + align; + } + res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; + case GGML_OP_MUL_MAT: + { + const ggml_tensor * src0 = tensor->src[0]; + const ggml_tensor * src1 = tensor->src[1]; + bool use_mmvq = + ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product, + ctx->webgpu_global_ctx->vendor); + if (use_mmvq) { + const size_t q8_src1_size = + src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); + res = ROUNDUP_POW2(res + q8_src1_size + + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, + WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; + case GGML_OP_MUL_MAT_ID: + { + const ggml_tensor * src0 = tensor->src[0]; + const ggml_tensor * src1 = tensor->src[1]; + if (src0 && src1) { + const size_t gathered_size = sizeof(uint32_t) * tensor->src[0]->ne[2] * tensor->src[1]->ne[2]; + const size_t gathered_count_ids_size = sizeof(uint32_t) * tensor->src[0]->ne[2]; + res = ROUNDUP_POW2( + res + gathered_size * 2 + gathered_count_ids_size + + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment * 3, + WEBGPU_STORAGE_BUF_BINDING_MULT); + } + } + break; + default: + break; + } + return res; } /* End GGML Backend Buffer Type Interface */ @@ -1874,673 +3728,279 @@ static const char * ggml_backend_webgpu_device_get_name(ggml_backend_dev_t dev) return ctx->device_name.c_str(); } -static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) { - ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context); - return ctx->device_desc.c_str(); -} +static const char * ggml_backend_webgpu_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context); + return ctx->device_desc.c_str(); +} + +static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { + ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context); + // TODO: for now, return maxBufferSize as both free and total memory + // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates. + uint64_t max_buffer_size = ctx->webgpu_global_ctx->capabilities.limits.maxBufferSize; + // If we're on a 32-bit system, clamp to UINTPTR_MAX +#if UINTPTR_MAX < UINT64_MAX + uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX); + if (max_buffer_size > max_ptr_size) { + max_buffer_size = max_ptr_size; + } +#endif + *free = static_cast<size_t>(max_buffer_size); + *total = static_cast<size_t>(max_buffer_size); +} + +static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) { + GGML_UNUSED(dev); + return GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + props->name = ggml_backend_webgpu_device_get_name(dev); + props->description = ggml_backend_webgpu_device_get_description(dev); + props->type = ggml_backend_webgpu_device_get_type(dev); + ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ false, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; +} + +static ggml_guid_t ggml_backend_webgpu_guid(void) { + static ggml_guid guid = { 0x67, 0xc7, 0xa4, 0xb1, 0x78, 0x74, 0x4f, 0x51, + 0x9d, 0x65, 0x44, 0x6d, 0xe4, 0x1b, 0x82, 0x9a }; + return &guid; +} + +static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { + // we use the maximum workgroup size for the memset pipeline + size_t max_threads = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup * + ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; + // Size the bytes_per_thread so that the largest buffer size can be handled + ctx->capabilities.memset_bytes_per_thread = + CEIL_DIV(ctx->capabilities.limits.maxStorageBufferBindingSize, max_threads); + std::vector<wgpu::ConstantEntry> constants(2); + constants[0].key = "wg_size"; + constants[0].value = ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + constants[1].key = "bytes_per_thread"; + constants[1].value = ctx->capabilities.memset_bytes_per_thread; + ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); +} + +static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { + wgpu::RequestAdapterOptions options = {}; + +#ifndef __EMSCRIPTEN__ + // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 + const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; + wgpu::DawnTogglesDescriptor adapterTogglesDesc; + adapterTogglesDesc.enabledToggles = adapterEnabledToggles; + adapterTogglesDesc.enabledToggleCount = 2; + options.nextInChain = &adapterTogglesDesc; +#endif + + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + ctx->webgpu_global_ctx->adapter = std::move(adapter); + }), + UINT64_MAX); + GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr); + + ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits); + + wgpu::AdapterInfo info{}; +#ifndef __EMSCRIPTEN__ + wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; + if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + info.nextInChain = &subgroup_matrix_configs; + } +#endif + ctx->webgpu_global_ctx->adapter.GetInfo(&info); + ctx->webgpu_global_ctx->command_submit_batch_size = ggml_backend_webgpu_get_command_submit_batch_size(); + ctx->webgpu_global_ctx->max_inflight_batches = ggml_backend_webgpu_get_max_inflight_batches(); + ctx->webgpu_global_ctx->vendor = info.vendor; + ctx->webgpu_global_ctx->capabilities.supports_subgroups = + ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::Subgroups); + // for dot4I8packed + ctx->webgpu_global_ctx->capabilities.supports_dot_product = ctx->webgpu_global_ctx->instance.HasWGSLLanguageFeature( + wgpu::WGSLLanguageFeatureName::Packed4x8IntegerDotProduct); + + bool valid_subgroup_matrix_config = false; +#ifndef __EMSCRIPTEN__ + // Accept f16 subgroup matrix configurations (square or non-square). + // NVIDIA GPUs typically report square configs (e.g. 16x16x16), + // while Intel Xe2 GPUs report non-square configs (e.g. 8x16x16). + // The shaders are already parameterized to handle any M/N/K dimensions. + if (ctx->webgpu_global_ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { + for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { + const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; + if (config.componentType == wgpu::SubgroupMatrixComponentType::F16 && + config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { + ctx->webgpu_global_ctx->capabilities.sg_mat_m = config.M; + ctx->webgpu_global_ctx->capabilities.sg_mat_n = config.N; + ctx->webgpu_global_ctx->capabilities.sg_mat_k = config.K; + valid_subgroup_matrix_config = true; + break; + } + } + } +#endif + ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix = valid_subgroup_matrix_config; + + // Runtime subgroup size can be any supported size in this range. Shaders + // that allocate per-lane register arrays must size them for the minimum. + ctx->webgpu_global_ctx->capabilities.min_subgroup_size = info.subgroupMinSize; + ctx->webgpu_global_ctx->capabilities.max_subgroup_size = info.subgroupMaxSize; + // Initialize device + std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 }; + +#ifndef __EMSCRIPTEN__ + required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); + if (ctx->webgpu_global_ctx->capabilities.supports_subgroup_matrix) { + required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); + } +#endif + + if (ctx->webgpu_global_ctx->capabilities.supports_subgroups) { + required_features.push_back(wgpu::FeatureName::Subgroups); + } + +#ifdef GGML_WEBGPU_GPU_PROFILE + required_features.push_back(wgpu::FeatureName::TimestampQuery); +#endif + + wgpu::DeviceDescriptor dev_desc; + dev_desc.requiredLimits = &ctx->webgpu_global_ctx->capabilities.limits; + dev_desc.requiredFeatures = required_features.data(); + dev_desc.requiredFeatureCount = required_features.size(); + dev_desc.SetDeviceLostCallback( + wgpu::CallbackMode::AllowSpontaneous, + [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + if (reason == wgpu::DeviceLostReason::Destroyed) { + return; + } + GGML_UNUSED(device); + GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), + std::string(message).c_str()); + }); + dev_desc.SetUncapturedErrorCallback( + [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { + GGML_UNUSED(device); + GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), + std::string(message).c_str()); + }); + +#ifndef __EMSCRIPTEN__ + // Enable Dawn-specific toggles to increase native performance + // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, + // only for native performance? + const char * const deviceEnabledToggles[] = { "disable_robustness", "disable_workgroup_init", + "disable_polyfills_on_integer_div_and_mod" }; + const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; + wgpu::DawnTogglesDescriptor deviceTogglesDesc; + deviceTogglesDesc.enabledToggles = deviceEnabledToggles; + deviceTogglesDesc.enabledToggleCount = 3; + deviceTogglesDesc.disabledToggles = deviceDisabledToggles; + deviceTogglesDesc.disabledToggleCount = 1; -static void ggml_backend_webgpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context); - // TODO: for now, return maxBufferSize as both free and total memory - // Track https://github.com/gpuweb/gpuweb/issues/5505 for updates. - uint64_t max_buffer_size = ctx->webgpu_ctx->limits.maxBufferSize; - // If we're on a 32-bit system, clamp to UINTPTR_MAX -#if UINTPTR_MAX < UINT64_MAX - uint64_t max_ptr_size = static_cast<uint64_t>(UINTPTR_MAX); - if (max_buffer_size > max_ptr_size) { - max_buffer_size = max_ptr_size; - } + dev_desc.nextInChain = &deviceTogglesDesc; #endif - *free = static_cast<size_t>(max_buffer_size); - *total = static_cast<size_t>(max_buffer_size); -} -static enum ggml_backend_dev_type ggml_backend_webgpu_device_get_type(ggml_backend_dev_t dev) { - GGML_UNUSED(dev); - return GGML_BACKEND_DEVICE_TYPE_GPU; -} + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->adapter.RequestDevice( + &dev_desc, wgpu::CallbackMode::AllowSpontaneous, + [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { + if (status != wgpu::RequestDeviceStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", std::string(message).c_str()); + return; + } + ctx->webgpu_global_ctx->device = std::move(device); + }), + UINT64_MAX); + GGML_ASSERT(ctx->webgpu_global_ctx->device != nullptr); -static void ggml_backend_webgpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { - props->name = ggml_backend_webgpu_device_get_name(dev); - props->description = ggml_backend_webgpu_device_get_description(dev); - props->type = ggml_backend_webgpu_device_get_type(dev); - ggml_backend_webgpu_device_get_memory(dev, &props->memory_free, &props->memory_total); - props->caps = { - /* .async = */ false, - /* .host_buffer = */ false, - /* .buffer_from_host_ptr = */ false, - /* .events = */ false, - }; -} + ggml_webgpu_init_memset_pipeline(ctx->webgpu_global_ctx); + ggml_webgpu_create_buffer(ctx->webgpu_global_ctx->device, ctx->webgpu_global_ctx->memset_params_buf, + WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, + "memset_params_buf"); + ctx->webgpu_global_ctx->queue = ctx->webgpu_global_ctx->device.GetQueue(); -static ggml_guid_t ggml_backend_webgpu_guid(void) { - static const char * guid_str = "__ggml_webgpu :)"; - return reinterpret_cast<ggml_guid_t>((void *) guid_str); + GGML_LOG_INFO( + "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " + "device_desc: %s\n", + info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, + std::string(info.device).c_str(), std::string(info.description).c_str()); } -// Workgroup size is a common constant -static std::vector<wgpu::ConstantEntry> ggml_webgpu_wg_size_entry(uint32_t wg_size) { - std::vector<wgpu::ConstantEntry> constants(1); - constants[0].key = "wg_size"; - constants[0].value = wg_size; - return constants; -} +static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { + ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; + webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>(); + webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + webgpu_ctx->shader_lib = std::make_unique<ggml_webgpu_shader_lib>(dev_ctx->webgpu_global_ctx->device); + webgpu_ctx->param_arena.init( + webgpu_ctx->global_ctx->device, WEBGPU_PARAMS_BUF_SIZE_BYTES, + webgpu_ctx->global_ctx->command_submit_batch_size + WEBGPU_NUM_PARAM_SLOT_SAFETY_MARGIN, + webgpu_ctx->global_ctx->capabilities.limits.minUniformBufferOffsetAlignment); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_dev_error_buf, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "set_rows_dev_error_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->set_rows_host_error_buf, + WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "set_rows_host_error_buf"); -static void ggml_webgpu_init_memset_pipeline(webgpu_context & webgpu_ctx) { - // we use the maximum workgroup size for the memset pipeline - size_t max_threads = WEBGPU_MAX_WG_SIZE * webgpu_ctx->limits.maxComputeWorkgroupsPerDimension; - // Size the bytes_per_thread so that the largest buffer size can be handled - webgpu_ctx->memset_bytes_per_thread = CEIL_DIV(webgpu_ctx->limits.maxStorageBufferBindingSize, max_threads); - std::vector<wgpu::ConstantEntry> constants(2); - constants[0].key = "wg_size"; - constants[0].value = WEBGPU_MAX_WG_SIZE; - constants[1].key = "bytes_per_thread"; - constants[1].value = webgpu_ctx->memset_bytes_per_thread; - webgpu_ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_memset, "memset", constants); -} - -static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - // Q4/Q5/Q8 classic quantizations - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); - - // K-quantizations - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); - - // IQ quantizations (2-, 3-, 4-bit variants) - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); - - // 1-bit and 4-bit IQ variants - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); - - std::string proc_mul_mat_f32_f32; - std::string proc_mul_mat_f32_f32_vec; - std::string proc_mul_mat_f16_f32; - std::string proc_mul_mat_f16_f32_vec; - std::string proc_mul_mat_f16_f16; - std::string proc_mul_mat_f16_f16_vec; - std::string proc_mul_mat_q4_0_f32; - std::string proc_mul_mat_q4_0_f32_vec; - - std::vector<wgpu::ConstantEntry> mul_mat_constants; -#ifndef __EMSCRIPTEN__ - if (webgpu_ctx->supports_subgroup_matrix) { - std::map<std::string, std::string> sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size); - sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); - sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); - sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); - sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); - sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->sg_mat_m); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->sg_mat_n); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->sg_mat_k); - - proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); - proc_mul_mat_f32_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); - proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); - proc_mul_mat_f16_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); - proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); - proc_mul_mat_f16_f16_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - proc_mul_mat_q4_0_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); - proc_mul_mat_q4_0_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); - } else { +#ifdef GGML_WEBGPU_GPU_PROFILE + webgpu_ctx->batch_compute_passes = false; + ggml_webgpu_create_buffer( + webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_dev_buf, WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, "profile_timestamp_dev_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->profile_timestamp_host_buf, + WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "profile_timestamp_host_buf"); + wgpu::QuerySetDescriptor query_set_desc = {}; + query_set_desc.type = wgpu::QueryType::Timestamp; + query_set_desc.count = WEBGPU_MAX_PROFILE_QUERY_COUNT; + webgpu_ctx->profile_timestamp_query_set = webgpu_ctx->global_ctx->device.CreateQuerySet(&query_set_desc); #endif - mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K }); - mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); - mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); - - std::map<std::string, std::string> reg_repls; - reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); - reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); - - proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); - proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); - proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); - proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); - proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); - proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); - proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); - proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); -#ifndef __EMSCRIPTEN__ - } + +#ifdef GGML_WEBGPU_DEBUG + // Initialize debug buffers + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_host_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf"); + ggml_webgpu_create_buffer(webgpu_ctx->global_ctx->device, webgpu_ctx->global_ctx->debug_dev_buf, + WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), + wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf"); #endif + return webgpu_ctx; +} - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); - - std::vector<wgpu::ConstantEntry> mul_mat_vec_constants(3); - mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; - mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE; - mul_mat_vec_constants[1].key = "TILE_K"; - mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K; - mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; - mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); -} - -static void ggml_webgpu_init_set_rows_pipeline(webgpu_context & webgpu_ctx) { - webgpu_ctx->set_rows_pipelines[0][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_set_rows_f16, "set_rows_f16", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE)); - webgpu_ctx->set_rows_pipelines[0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_set_rows_f16_vec, "set_rows_f16_vec", ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE)); -} - -static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); -} - -static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f32, "cpy_f32_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F32][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f32_f16, "cpy_f32_f16", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F32] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f32, "cpy_f16_f32", constants); - webgpu_ctx->cpy_pipelines[GGML_TYPE_F16][GGML_TYPE_F16] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants); -} - -static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32, "add_f32", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16, "add_f16", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants); - webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants); -} - -static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32, "sub_f32", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16, "sub_f16", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants); - webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants); -} - -static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32, "mul_f32", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16, "mul_f16", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants); - webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants); -} - -static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32, "div_f32", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16, "div_f16", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants); - webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants); -} - -static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - webgpu_ctx->rms_norm_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm, "rms_norm", constants); - webgpu_ctx->rms_norm_pipelines[1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rms_norm_inplace, "rms_norm_inplace", constants); -} - -static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32, "rope_f32", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_inplace, "rope_f32_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff, "rope_f32_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F32][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); - - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16, "rope_f16", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_inplace, "rope_f16_inplace", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff, "rope_f16_ff", constants); - webgpu_ctx->rope_pipelines[GGML_TYPE_F16][1][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); -} - -static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - // REGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32, "reglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16, "reglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f32_split, "reglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_reglu_f16_split, "reglu_f16_split", constants); - - // GEGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32, "geglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16, "geglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f32_split, "geglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_f16_split, "geglu_f16_split", constants); - - // SWIGLU - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32, "swiglu_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16, "swiglu_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f32_split, "swiglu_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_f16_split, "swiglu_f16_split", constants); - - // SWIGLU_OAI - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); - - // GEGLU_ERF - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32, "geglu_erf_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16, "geglu_erf_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); - - // GEGLU_QUICK - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32, "geglu_quick_f32", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16, "geglu_quick_f16", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); - webgpu_ctx->glu_pipelines[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); -} - -static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - // ABS - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f32, "abs_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_f16, "abs_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f32, "abs_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ABS][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_abs_inplace_f16, "abs_inplace_f16", constants); - - // SGN - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f32, "sgn_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_f16, "sgn_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f32, "sgn_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SGN][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sgn_inplace_f16, "sgn_inplace_f16", constants); - - // NEG - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f32, "neg_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_f16, "neg_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f32, "neg_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_NEG][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_neg_inplace_f16, "neg_inplace_f16", constants); - - // STEP - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f32, "step_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_f16, "step_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f32, "step_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_STEP][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_step_inplace_f16, "step_inplace_f16", constants); - - // TANH - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f32, "tanh_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_f16, "tanh_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f32, "tanh_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_TANH][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_tanh_inplace_f16, "tanh_inplace_f16", constants); - - // ELU - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f32, "elu_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_f16, "elu_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f32, "elu_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_ELU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_elu_inplace_f16, "elu_inplace_f16", constants); - - // RELU - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f32, "relu_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_f16, "relu_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f32, "relu_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_RELU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_relu_inplace_f16, "relu_inplace_f16", constants); - - // SIGMOID - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f32, "sigmoid_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_f16, "sigmoid_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f32, "sigmoid_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SIGMOID][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_sigmoid_inplace_f16, "sigmoid_inplace_f16", constants); - - // GELU - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f32, "gelu_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_f16, "gelu_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f32, "gelu_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_inplace_f16, "gelu_inplace_f16", constants); - - // GELU_QUICK - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f32, "gelu_quick_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_quick_f16, "gelu_quick_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_gelu_quick_inplace_f32, "gelu_quick_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_QUICK][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_gelu_quick_inplace_f16, "gelu_quick_inplace_f16", constants); - - // SILU - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f32, "silu_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_f16, "silu_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f32, "silu_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_SILU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_silu_inplace_f16, "silu_inplace_f16", constants); - - // HARDSWISH - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f32, "hardswish_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_f16, "hardswish_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f32, "hardswish_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSWISH][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardswish_inplace_f16, "hardswish_inplace_f16", constants); - - // HARDSIGMOID - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f32, "hardsigmoid_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_hardsigmoid_f16, "hardsigmoid_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_hardsigmoid_inplace_f32, "hardsigmoid_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_HARDSIGMOID][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_hardsigmoid_inplace_f16, "hardsigmoid_inplace_f16", constants); - - // EXP - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f32, "exp_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_f16, "exp_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f32, "exp_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_EXP][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_exp_inplace_f16, "exp_inplace_f16", constants); - - // GELU_ERF - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f32, "gelu_erf_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_f16, "gelu_erf_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f32, "gelu_erf_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_GELU_ERF][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_gelu_erf_inplace_f16, "gelu_erf_inplace_f16", constants); - - // XIELU - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f32, "xielu_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_f16, "xielu_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f32, "xielu_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f16, "xielu_inplace_f16", constants); - - // CEIL - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f32, "ceil_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f16, "ceil_f16", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f32, "ceil_inplace_f32", constants); - webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f16, "ceil_inplace_f16", constants); -} - -static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->scale_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32, "scale_f32", constants); - webgpu_ctx->scale_pipelines[1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_scale_f32_inplace, "scale_f32_inplace", constants); -} - -static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { - std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); - - // f32 (no mask) - webgpu_ctx->soft_max_pipelines[2][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32, "soft_max_f32", constants); - webgpu_ctx->soft_max_pipelines[2][0][1] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_inplace, "soft_max_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[2][1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_sink, "soft_max_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[2][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_sink_inplace, "soft_max_f32_sink_inplace", constants); - - // f32 mask (mask_type = 0) - webgpu_ctx->soft_max_pipelines[0][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f32, "soft_max_f32_mask_f32", constants); - webgpu_ctx->soft_max_pipelines[0][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_inplace, "soft_max_f32_mask_f32_inplace", constants); - webgpu_ctx->soft_max_pipelines[0][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink, "soft_max_f32_mask_f32_sink", constants); - webgpu_ctx->soft_max_pipelines[0][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f32_sink_inplace, "soft_max_f32_mask_f32_sink_inplace", constants); - - // f16 mask (mask_type = 1) - webgpu_ctx->soft_max_pipelines[1][0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_soft_max_f32_mask_f16, "soft_max_f32_mask_f16", constants); - webgpu_ctx->soft_max_pipelines[1][0][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_inplace, "soft_max_f32_mask_f16_inplace", constants); - webgpu_ctx->soft_max_pipelines[1][1][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink, "soft_max_f32_mask_f16_sink", constants); - webgpu_ctx->soft_max_pipelines[1][1][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants); -} - -// TODO: move most initialization logic here -static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { +static ggml_backend_t ggml_backend_webgpu_backend_init(ggml_backend_dev_t dev, const char * params) { GGML_UNUSED(params); - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_device_init()"); + WEBGPU_LOG_DEBUG("ggml_backend_webgpu_backend_init()"); - ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context); - webgpu_context webgpu_ctx = dev_ctx->webgpu_ctx; + ggml_backend_webgpu_device_context * dev_ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context); - static ggml_backend_webgpu_context backend_ctx; - backend_ctx.name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; - backend_ctx.webgpu_ctx = webgpu_ctx; + auto * backend_ctx = new ggml_backend_webgpu_context(); + backend_ctx->name = GGML_WEBGPU_NAME + std::string(": ") + dev_ctx->device_name; + backend_ctx->webgpu_ctx = initialize_webgpu_context(dev); // See GGML Backend Interface section - static ggml_backend backend = { + auto * backend = new ggml_backend(); + *backend = { /* .guid = */ ggml_backend_webgpu_guid(), /* .interface = */ ggml_backend_webgpu_i, /* .device = */ dev, - /* .context = */ &backend_ctx, + /* .context = */ backend_ctx, }; - return &backend; + return backend; } static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggml_backend_dev_t dev) { @@ -2548,16 +4008,16 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { - /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, - /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes - /* .is_host = */ NULL, // defaults to false + /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size, + /* .is_host = */ NULL, // defaults to false }, /* .device = */ dev, - /* .context = */ NULL, + /* .context = */ NULL }; return &ggml_backend_webgpu_buffer_type; @@ -2570,6 +4030,7 @@ static bool ggml_backend_webgpu_device_supports_buft(ggml_backend_dev_t dev, ggm static bool ggml_webgpu_supported_qtype(ggml_type type) { switch (type) { + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -2589,6 +4050,7 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: return true; default: return false; @@ -2598,16 +4060,16 @@ static bool ggml_webgpu_supported_qtype(ggml_type type) { static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { ggml_backend_webgpu_device_context * ctx = static_cast<ggml_backend_webgpu_device_context *>(dev->context); - webgpu_context webgpu_ctx = ctx->webgpu_ctx; - ggml_tensor * src0 = op->src[0]; ggml_tensor * src1 = op->src[1]; ggml_tensor * src2 = op->src[2]; // on smaller devices (or CI), tensors may be larger than the max storage buffer size - if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || - (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize || + (src0 != nullptr && + ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src1 != nullptr && + ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) { return false; } @@ -2624,23 +4086,38 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE - // see https://github.com/ggml-org/llama.cpp/pull/16857 supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && - (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1); + (src1->type == op->type); + break; + case GGML_OP_ADD_ID: + supports_op = src0->type == GGML_TYPE_F32; + break; + case GGML_OP_CONCAT: + supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); + break; + case GGML_OP_REPEAT: + supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32 || src0->type == GGML_TYPE_I16); break; case GGML_OP_CPY: case GGML_OP_CONT: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && - (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) || + (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32); + break; + case GGML_OP_SET: + supports_op = src0->type == src1->type && src0->type == op->type && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32); break; case GGML_OP_SET_ROWS: - supports_op = (op->type == GGML_TYPE_F16 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I64); + supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_Q8_0 || + op->type == GGML_TYPE_Q4_0) && + src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); break; case GGML_OP_GET_ROWS: - if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 || - ggml_webgpu_supported_qtype(src0->type)) { + if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_webgpu_supported_qtype(src0->type)) { supports_op = (op->type == GGML_TYPE_F32); + } else if (src0->type == GGML_TYPE_I32) { + supports_op = op->type == GGML_TYPE_I32; } break; case GGML_OP_MUL_MAT: @@ -2653,6 +4130,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -2672,6 +4150,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: supports_op = true; break; default: @@ -2682,30 +4161,110 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } break; } + case GGML_OP_MUL_MAT_ID: + switch (src1->type) { + case GGML_TYPE_F16: + supports_op |= (src0->type == GGML_TYPE_F16); + break; + case GGML_TYPE_F32: + switch (src0->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_Q1_0: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: + supports_op = true; + break; + default: + break; + } + break; + default: + break; + } + break; case GGML_OP_FLASH_ATTN_EXT: { - if (!webgpu_ctx->supports_subgroup_matrix) { + // conservative support checks for whether the more resource-intensive shader paths + // can be used, to avoid cases where flash_attn is assigned to the CPU later on + supports_op = src0->type == GGML_TYPE_F32 && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || + src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && + (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16 || + src2->type == GGML_TYPE_Q4_0 || src2->type == GGML_TYPE_Q8_0) && + op->type == GGML_TYPE_F32; + if (!supports_op) { break; } - // Head dimensions must fit in workgroup memory with minimum tile sizes - size_t limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize; - const bool has_mask = op->src[3] != nullptr; - const bool kv_direct = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 && - (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0; - const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes( - webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0], - has_mask, kv_direct); - if (min_bytes > limit_bytes) { + if (ggml_webgpu_tensor_overlap(src1, src2) && src1->type != src2->type && + !ggml_is_quantized(src1->type) && !ggml_is_quantized(src2->type)) { + supports_op = false; break; } - - supports_op = src0->type == GGML_TYPE_F32 && - (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 || - src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) && - src2->type == src1->type && op->type == GGML_TYPE_F32; + const auto & capabilities = ctx->webgpu_global_ctx->capabilities; + const size_t storage_offset_alignment = capabilities.limits.minStorageBufferOffsetAlignment; + + // subgroup matrix path requirements + const bool use_subgroup_matrix = ggml_webgpu_flash_attn_can_use_subgroup_matrix_path( + capabilities.supports_subgroup_matrix, capabilities.sg_mat_k, capabilities.sg_mat_n, src0, src2); + + // tile path requirements + const bool float_vec4_aligned = + ((src1->type != GGML_TYPE_F16 && src1->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(src1, storage_offset_alignment)) && + ((src2->type != GGML_TYPE_F16 && src2->type != GGML_TYPE_F32) || + ggml_webgpu_flash_attn_float_vec4_aligned(src2, storage_offset_alignment)); + const uint32_t k_tile_head_align = (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src1->type); + const uint32_t v_tile_head_align = (src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16) ? + GGML_WEBGPU_FLASH_ATTN_TILE_KV_VEC_WIDTH : + (uint32_t) ggml_blck_size(src2->type); + const bool tile_kv_head_dims_aligned = + src0->ne[0] % k_tile_head_align == 0 && src2->ne[0] % v_tile_head_align == 0; + const bool tile_can_dispatch_all_q_rows = + capabilities.limits.maxComputeInvocationsPerWorkgroup >= + GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE * capabilities.max_subgroup_size; + const bool use_tile = !use_subgroup_matrix && capabilities.supports_subgroups && float_vec4_aligned && + tile_kv_head_dims_aligned && tile_can_dispatch_all_q_rows; + + if (!use_subgroup_matrix && !use_tile) { + supports_op = false; + break; + } + const uint32_t q_tile = + use_subgroup_matrix ? capabilities.sg_mat_m : GGML_WEBGPU_FLASH_ATTN_TILE_Q_TILE; + const uint32_t kv_granularity = use_subgroup_matrix ? capabilities.sg_mat_n : 1u; + const bool kv_direct = use_subgroup_matrix ? + ggml_webgpu_flash_attn_kv_direct(src0, src1, src2, capabilities.sg_mat_k) : + false; + const uint32_t max_kv_tile = ggml_webgpu_flash_attn_max_kv_tile( + capabilities.limits.maxComputeWorkgroupStorageSize, q_tile, kv_granularity, (uint32_t) src0->ne[0], + (uint32_t) src2->ne[0], op->src[3] != nullptr, kv_direct); + supports_op = max_kv_tile > 0; break; } case GGML_OP_RMS_NORM: + case GGML_OP_NORM: + case GGML_OP_L2_NORM: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; case GGML_OP_ROPE: @@ -2753,9 +4312,14 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_UNARY_OP_HARDSIGMOID: case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_GELU_ERF: - case GGML_UNARY_OP_XIELU: + case GGML_UNARY_OP_SOFTPLUS: + case GGML_UNARY_OP_EXPM1: + case GGML_UNARY_OP_FLOOR: case GGML_UNARY_OP_CEIL: - supports_op = supports_op = + case GGML_UNARY_OP_ROUND: + case GGML_UNARY_OP_TRUNC: + case GGML_UNARY_OP_XIELU: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; default: @@ -2763,14 +4327,94 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } } break; - + case GGML_OP_TRI: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_DIAG: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_SOLVE_TRI: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; + break; + case GGML_OP_CONV_2D: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16) && + (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + break; + case GGML_OP_IM2COL: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; + case GGML_OP_SSM_CONV: + supports_op = op->type == GGML_TYPE_F32; + break; + case GGML_OP_SSM_SCAN: + supports_op = op->type == GGML_TYPE_F32 && + src0->ne[0] <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + break; + case GGML_OP_GATED_DELTA_NET: + { + const uint32_t s_v = (uint32_t) src2->ne[0]; + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && + src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 && + op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 && + s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + } + break; + case GGML_OP_CLAMP: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_FILL: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_LOG: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_SQR: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_SQRT: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_SIN: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_COS: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); + break; + case GGML_OP_PAD: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_ARGMAX: + supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_ARGSORT: + supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0); + break; + case GGML_OP_TOP_K: + supports_op = op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(src0); + break; + case GGML_OP_CUMSUM: + supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type; + break; + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + supports_op = op->type == GGML_TYPE_F32 && src0->type == op->type && ggml_is_contiguous_rows(src0); + break; + case GGML_OP_UPSCALE: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; default: break; } - if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || - (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src1 != nullptr && ggml_nbytes(src1) > webgpu_ctx->limits.maxStorageBufferBindingSize) || - (src2 != nullptr && ggml_nbytes(src2) > webgpu_ctx->limits.maxStorageBufferBindingSize)) { + if (ggml_nbytes(op) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize || + (src0 != nullptr && + ggml_nbytes(src0) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src1 != nullptr && + ggml_nbytes(src1) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize) || + (src2 != nullptr && + ggml_nbytes(src2) > ctx->webgpu_global_ctx->capabilities.limits.maxStorageBufferBindingSize)) { supports_op = false; WEBGPU_LOG_DEBUG("ggml_webgpu op not supported due to size: "); } @@ -2795,16 +4439,16 @@ static struct ggml_backend_device_i ggml_backend_webgpu_device_i = { /* .get_memory = */ ggml_backend_webgpu_device_get_memory, /* .get_type = */ ggml_backend_webgpu_device_get_type, /* .get_props = */ ggml_backend_webgpu_device_get_props, - /* .init_backend = */ ggml_backend_webgpu_device_init, + /* .init_backend = */ ggml_backend_webgpu_backend_init, /* .get_buffer_type = */ ggml_backend_webgpu_device_get_buffer_type, /* .get_host_buffer_type = */ NULL, /* .buffer_from_host_ptr = */ NULL, /* .supports_op = */ ggml_backend_webgpu_device_supports_op, /* .supports_buft = */ ggml_backend_webgpu_device_supports_buft, /* .offload_op = */ NULL, - /* .event_new = */ NULL, - /* .event_free = */ NULL, - /* .event_synchronize = */ NULL, + /* .event_new = */ ggml_backend_webgpu_device_event_new, + /* .event_free = */ ggml_backend_webgpu_device_event_free, + /* .event_synchronize = */ ggml_backend_webgpu_device_event_synchronize, }; /* End GGML Backend Device Interface */ @@ -2821,8 +4465,6 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) { return ctx->device_count; } -// TODO: Does this need to be thread safe? Is it only called once? -// TODO: move most logic to device_init function so backend can be freed/initialized properly // Only one device is supported for now static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ASSERT(index == 0); @@ -2832,191 +4474,12 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_backend_webgpu_reg_context * reg_ctx = static_cast<ggml_backend_webgpu_reg_context *>(reg->context); - webgpu_context ctx = reg_ctx->webgpu_ctx; - - wgpu::RequestAdapterOptions options = {}; - -#ifndef __EMSCRIPTEN__ - // TODO: track need for these toggles: https://issues.chromium.org/issues/42251215 - const char * const adapterEnabledToggles[] = { "vulkan_enable_f16_on_nvidia", "use_vulkan_memory_model" }; - wgpu::DawnTogglesDescriptor adapterTogglesDesc; - adapterTogglesDesc.enabledToggles = adapterEnabledToggles; - adapterTogglesDesc.enabledToggleCount = 2; - options.nextInChain = &adapterTogglesDesc; -#endif - - ctx->instance.WaitAny(ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - ctx->adapter = std::move(adapter); - }), - UINT64_MAX); - GGML_ASSERT(ctx->adapter != nullptr); - - ctx->adapter.GetLimits(&ctx->limits); - - wgpu::AdapterInfo info{}; -#ifndef __EMSCRIPTEN__ - wgpu::AdapterPropertiesSubgroupMatrixConfigs subgroup_matrix_configs{}; - if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { - info.nextInChain = &subgroup_matrix_configs; - } -#endif - ctx->adapter.GetInfo(&info); - - wgpu::SupportedFeatures features; - ctx->adapter.GetFeatures(&features); - // we require f16 support - GGML_ASSERT(ctx->adapter.HasFeature(wgpu::FeatureName::ShaderF16)); - -#ifndef __EMSCRIPTEN__ - // Only support square f16 matrices of size 8 or 16 for now - bool valid_subgroup_matrix_config = false; - if (ctx->adapter.HasFeature(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix)) { - for (size_t i = 0; i < subgroup_matrix_configs.configCount; i++) { - const wgpu::SubgroupMatrixConfig config = subgroup_matrix_configs.configs[i]; - if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) && - config.componentType == wgpu::SubgroupMatrixComponentType::F16 && - config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) { - ctx->sg_mat_m = config.M; - ctx->sg_mat_n = config.N; - ctx->sg_mat_k = config.K; - valid_subgroup_matrix_config = true; - break; - } - } - } - - ctx->supports_subgroup_matrix = valid_subgroup_matrix_config; -#endif - // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate. - // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter. - ctx->max_subgroup_size = info.subgroupMaxSize; - - // Initialize device - std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 }; - -#ifndef __EMSCRIPTEN__ - required_features.push_back(wgpu::FeatureName::ImplicitDeviceSynchronization); - if (ctx->supports_subgroup_matrix) { - required_features.push_back(wgpu::FeatureName::Subgroups); - required_features.push_back(wgpu::FeatureName::ChromiumExperimentalSubgroupMatrix); - } -#endif - -#ifdef GGML_WEBGPU_GPU_PROFILE - required_features.push_back(wgpu::FeatureName::TimestampQuery); -#endif - - wgpu::DeviceDescriptor dev_desc; - dev_desc.requiredLimits = &ctx->limits; - dev_desc.requiredFeatures = required_features.data(); - dev_desc.requiredFeatureCount = required_features.size(); - dev_desc.SetDeviceLostCallback( - wgpu::CallbackMode::AllowSpontaneous, - [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_UNUSED(reason); - GGML_UNUSED(message); - //TODO: uncomment once proper free logic is in place - //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason), - //std::string(message).c_str()); - }); - dev_desc.SetUncapturedErrorCallback( - [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { - GGML_UNUSED(device); - GGML_ABORT("ggml_webgpu: Device error! Reason: %d, Message: %s\n", static_cast<int>(reason), - std::string(message).c_str()); - }); - -#ifndef __EMSCRIPTEN__ - // Enable Dawn-specific toggles to increase native performance - // TODO: Maybe WebGPU needs a "fast" mode where you can request compilers skip adding checks like these, - // only for native performance? - const char * const deviceEnabledToggles[] = { "skip_validation", "disable_robustness", "disable_workgroup_init", - "disable_polyfills_on_integer_div_and_mod" }; - const char * const deviceDisabledToggles[] = { "timestamp_quantization" }; - wgpu::DawnTogglesDescriptor deviceTogglesDesc; - deviceTogglesDesc.enabledToggles = deviceEnabledToggles; - deviceTogglesDesc.enabledToggleCount = 4; - deviceTogglesDesc.disabledToggles = deviceDisabledToggles; - deviceTogglesDesc.disabledToggleCount = 1; - - dev_desc.nextInChain = &deviceTogglesDesc; -#endif - - ctx->instance.WaitAny(ctx->adapter.RequestDevice( - &dev_desc, wgpu::CallbackMode::AllowSpontaneous, - [ctx](wgpu::RequestDeviceStatus status, wgpu::Device device, wgpu::StringView message) { - if (status != wgpu::RequestDeviceStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get a device: %s\n", - std::string(message).c_str()); - return; - } - ctx->device = std::move(device); - }), - UINT64_MAX); - GGML_ASSERT(ctx->device != nullptr); - - // Initialize (compute) queue - ctx->queue = ctx->device.GetQueue(); - - // Create buffer pool for shader parameters - ctx->param_buf_pool.init(ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); - -#ifdef GGML_WEBGPU_GPU_PROFILE - // Initialize buffer pool for timestamp queries (profiling) - ctx->timestamp_query_buf_pool.init(ctx->device, WEBGPU_NUM_TIMESTAMP_QUERY_BUFS, - WEBGPU_TIMESTAMP_QUERY_BUF_SIZE_BYTES, - wgpu::BufferUsage::QueryResolve | wgpu::BufferUsage::CopySrc, - wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst); -#endif - - ctx->set_rows_error_buf_pool.init(ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - - ggml_webgpu_init_memset_pipeline(ctx); - ggml_webgpu_init_mul_mat_pipeline(ctx); - ggml_webgpu_init_set_rows_pipeline(ctx); - ggml_webgpu_init_get_rows_pipeline(ctx); - ggml_webgpu_init_cpy_pipeline(ctx); - ggml_webgpu_init_add_pipeline(ctx); - ggml_webgpu_init_sub_pipeline(ctx); - ggml_webgpu_init_mul_pipeline(ctx); - ggml_webgpu_init_div_pipeline(ctx); - ggml_webgpu_init_rms_norm_pipeline(ctx); - ggml_webgpu_init_rope_pipeline(ctx); - ggml_webgpu_init_glu_pipeline(ctx); - ggml_webgpu_init_scale_pipeline(ctx); - ggml_webgpu_init_soft_max_pipeline(ctx); - ggml_webgpu_init_unary_pipeline(ctx); - -#ifdef GGML_WEBGPU_DEBUG - // Initialize debug buffers - ggml_webgpu_create_buffer(ctx->device, ctx->debug_host_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead, "debug_host_buf"); - ggml_webgpu_create_buffer(ctx->device, ctx->debug_dev_buf, WEBGPU_DEBUG_BUF_ELEMS * sizeof(uint32_t), - wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc, "debug_dev_buf"); -#endif + create_webgpu_device(reg_ctx); static ggml_backend_webgpu_device_context device_ctx; - device_ctx.webgpu_ctx = ctx; - device_ctx.device_name = GGML_WEBGPU_NAME; - device_ctx.device_desc = info.description; - - GGML_LOG_INFO( - "ggml_webgpu: adapter_info: vendor_id: %u | vendor: %s | architecture: %s | device_id: %u | name: %s | " - "device_desc: %s\n", - info.vendorID, std::string(info.vendor).c_str(), std::string(info.architecture).c_str(), info.deviceID, - std::string(info.device).c_str(), std::string(info.description).c_str()); - + device_ctx.device_name = GGML_WEBGPU_NAME; + device_ctx.device_desc = GGML_WEBGPU_NAME; + device_ctx.webgpu_global_ctx = reg_ctx->webgpu_global_ctx; // See GGML Backend Device Interface section static ggml_backend_device device = { /* .iface = */ ggml_backend_webgpu_device_i, @@ -3024,7 +4487,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t /* .context = */ &device_ctx, }; - WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, ctx); + WEBGPU_CPU_PROFILE_TOTAL_END(reg_get_device, reg_ctx->webgpu_global_ctx); return &device; } @@ -3040,12 +4503,25 @@ static const struct ggml_backend_reg_i ggml_backend_webgpu_reg_i = { ggml_backend_reg_t ggml_backend_webgpu_reg() { WEBGPU_LOG_DEBUG("ggml_backend_webgpu_reg()"); - webgpu_context webgpu_ctx = std::make_shared<webgpu_context_struct>(); + // Intentionally leak the global registry context to avoid crashing inside + // Dawn/Vulkan static teardown during process exit. + static ggml_backend_webgpu_reg_context * ctx = new ggml_backend_webgpu_reg_context(); + + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_webgpu_reg_i, + /* .context = */ ctx, + }; + + ctx->name = GGML_WEBGPU_NAME; + ctx->device_count = 0; - static ggml_backend_webgpu_reg_context ctx; - ctx.webgpu_ctx = webgpu_ctx; - ctx.name = GGML_WEBGPU_NAME; - ctx.device_count = 1; + // Keep one Dawn/WebGPU instance alive for the lifetime of the static backend + // registry. Recreating it on repeated registry lookups can invalidate + // adapter/device references that are still held by the backend/device layer. + if (ctx->webgpu_global_ctx != nullptr && ctx->webgpu_global_ctx->instance != nullptr) { + return ® + } wgpu::InstanceDescriptor instance_descriptor{}; std::vector<wgpu::InstanceFeatureName> instance_features = { wgpu::InstanceFeatureName::TimedWaitAny }; @@ -3060,28 +4536,48 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { instance_descriptor.nextInChain = &instanceTogglesDesc; #endif - webgpu_ctx->instance = wgpu::CreateInstance(&instance_descriptor); - -#ifdef __EMSCRIPTEN__ - if (webgpu_ctx->instance == nullptr) { - GGML_LOG_ERROR("ggml_webgpu: Failed to create WebGPU instance. Make sure either -sASYNCIFY or -sJSPI is set\n"); - return nullptr; + wgpu::Instance inst = wgpu::CreateInstance(&instance_descriptor); + ctx->webgpu_global_ctx = webgpu_global_context(new webgpu_global_context_struct()); + ctx->webgpu_global_ctx->instance = std::move(inst); + + // Probe for adapter support + wgpu::Adapter adapter; + if (ctx->webgpu_global_ctx->instance != nullptr) { + wgpu::RequestAdapterOptions options = {}; + + // probe for adapter support + ctx->webgpu_global_ctx->instance.WaitAny( + ctx->webgpu_global_ctx->instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + adapter = std::move(_adapter); + }), + UINT64_MAX); } + + // WebGPU backend requires f16 support and, on native, implicit device synchronization. + if (adapter != nullptr && adapter.HasFeature(wgpu::FeatureName::ShaderF16) +#ifndef __EMSCRIPTEN__ + && adapter.HasFeature(wgpu::FeatureName::ImplicitDeviceSynchronization) #endif - GGML_ASSERT(webgpu_ctx->instance != nullptr); + ) { + ctx->device_count = 1; + } - static ggml_backend_reg reg = { - /* .api_version = */ GGML_BACKEND_API_VERSION, - /* .iface = */ ggml_backend_webgpu_reg_i, - /* .context = */ &ctx, - }; return ® } ggml_backend_t ggml_backend_webgpu_init(void) { - ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_webgpu_reg(), 0); - - return ggml_backend_webgpu_device_init(dev, nullptr); + ggml_backend_reg_t reg = ggml_backend_webgpu_reg(); + if (ggml_backend_reg_dev_count(reg) == 0) { + return nullptr; + } + ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, 0); + return ggml_backend_webgpu_backend_init(dev, nullptr); } GGML_BACKEND_DL_IMPL(ggml_backend_webgpu_reg) diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp index 4d4359463ca..fb41a961d74 100644 --- a/ggml/src/ggml-webgpu/pre_wgsl.hpp +++ b/ggml/src/ggml-webgpu/pre_wgsl.hpp @@ -37,15 +37,33 @@ static std::string trim(const std::string & s) { } static std::string trim_value(std::istream & is) { - std::string str; - std::getline(is, str); - return trim(str); + std::ostringstream ss; + ss << is.rdbuf(); + return trim(ss.str()); } static bool isIdentChar(char c) { return std::isalnum(static_cast<unsigned char>(c)) || c == '_'; } +static bool endsWithContinuation(const std::string & line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char) line[i - 1])) { + i--; + } + return i > 0 && line[i - 1] == '\\'; +} + +static void stripContinuation(std::string & line) { + size_t i = line.size(); + while (i > 0 && std::isspace((unsigned char) line[i - 1])) { + i--; + } + if (i > 0 && line[i - 1] == '\\') { + line.erase(i - 1); + } +} + static std::string expandMacrosRecursiveInternal(const std::string & line, const std::unordered_map<std::string, std::string> & macros, std::unordered_set<std::string> & visiting); @@ -595,19 +613,31 @@ class Preprocessor { std::string line; while (std::getline(in, line)) { - std::string t = trim(line); + std::string logical = line; + std::string t = trim(logical); + if (!t.empty() && t[0] == '#') { + while (endsWithContinuation(logical)) { + stripContinuation(logical); + if (!std::getline(in, line)) { + break; + } + logical += "\n"; + logical += line; + } + t = trim(logical); + } if (!t.empty() && t[0] == '#') { bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode); if (mode == DirectiveMode::IncludesOnly && !handled) { - out << line << "\n"; + out << logical << "\n"; } } else { if (mode == DirectiveMode::IncludesOnly) { - out << line << "\n"; + out << logical << "\n"; } else if (condActive(cond)) { // Expand macros in the line before outputting - std::string expanded = expandMacrosRecursive(line, macros); + std::string expanded = expandMacrosRecursive(logical, macros); out << expanded << "\n"; } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl new file mode 100644 index 00000000000..2573926cb89 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl @@ -0,0 +1,64 @@ +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_ids: u32, + offset_dst: u32, + + nb01: u32, + nb02: u32, + nb11: u32, + nb20: u32, + nb21: u32, + + ne0: u32, + ne1: u32, + ne2: u32, +}; + +@group(0) @binding(0) var<storage, read_write> src0: array<f32>; // [n_embd, n_experts_used, n_token] +@group(0) @binding(1) var<storage, read_write> src1: array<f32>; // [n_embd, n_experts] +@group(0) @binding(2) var<storage, read_write> ids: array<i32>; // [n_experts_used, n_token] + +#ifdef INPLACE + +@group(0) @binding(3) +var<uniform> params: Params; + +#else + +@group(0) @binding(3) +var<storage, read_write> dst: array<f32>; + +@group(0) @binding(4) +var<uniform> params: Params; + +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>, + @builtin(local_invocation_id) local_id: vec3<u32>) { + + let wg_linear = wg_id.x + wg_id.y * num_wg.x; + + if (wg_linear < params.ne1 * params.ne2) { + let thread_id = local_id.x; + let i2 = wg_linear / params.ne1; + let i1 = wg_linear % params.ne1; + + let i11 = u32(ids[params.offset_ids + i1 * params.nb20 + i2 * params.nb21]); + + let src0_row = params.offset_src0 + i1 * params.nb01 + i2 * params.nb02; + let src1_row = params.offset_src1 + i11 * params.nb11; + let dst_row = params.offset_dst + i1 * params.ne0 + i2 * (params.ne0 * params.ne1); + + for (var i = thread_id;i < params.ne0; i += WG_SIZE) { +#ifdef INPLACE + src0[src0_row + i] = src0[src0_row + i] + src1[src1_row + i]; +#else + dst[dst_row + i] = src0[src0_row + i] + src1[src1_row + i]; +#endif + } + } + +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl new file mode 100644 index 00000000000..ca5bfcc4d4c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl @@ -0,0 +1,72 @@ +@group(0) @binding(0) +#ifdef VEC4 +var<storage, read_write> src: array<vec4<f32>>; +#define VEC_SIZE 4 +#else +var<storage, read_write> src: array<f32>; +#define VEC_SIZE 1 +#endif + +@group(0) @binding(1) +var<storage, read_write> dst: array<i32>; + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + ne0: u32, +}; + +@group(0) @binding(2) +var<uniform> params: Params; + +const FLOAT_MIN: f32 = -1.0e9; + +struct Pair { + value: f32, + index: i32 +}; + +var<workgroup> shared_max: array<Pair, WG_SIZE>; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3<u32>, + @builtin(local_invocation_id) lid: vec3<u32>) { + let row_idx = params.offset_src + wid.x * params.ne0; + var local_pair = Pair(FLOAT_MIN, -1); +#ifdef VEC4 + for (var col = lid.x; col < params.ne0/VEC_SIZE; col += WG_SIZE) { + let vec_val = src[row_idx / VEC_SIZE + col]; + for (var v = 0u; v < VEC_SIZE; v++) { + let val = vec_val[v]; + if (val >= local_pair.value) { + local_pair = Pair(val, i32(col * VEC_SIZE + v)); + } + } + } +#else + for (var col = lid.x; col < params.ne0; col += WG_SIZE) { + if (src[row_idx + col] >= local_pair.value) { + local_pair = Pair(src[row_idx + col], i32(col)); + } + } +#endif + shared_max[lid.x] = local_pair; + workgroupBarrier(); + var offset: u32 = WG_SIZE >> 1; + while (offset > 0) { + if (lid.x < offset) { + let a = shared_max[lid.x]; + let b = shared_max[lid.x + offset]; + if (b.value > a.value) { + shared_max[lid.x] = b; + } else if (b.value == a.value && b.index > a.index) { + shared_max[lid.x] = b; + } + } + workgroupBarrier(); + offset >>= 1; + } + if (lid.x == 0u) { + dst[params.offset_dst + wid.x] = shared_max[0].index; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl new file mode 100644 index 00000000000..46ed19fc775 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl @@ -0,0 +1,106 @@ +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<i32>; + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // src/dst dimensions + src_ne0: u32, + ne1: u32, + ne2: u32, + + ne0: u32, + top_k: u32, + + npr: u32, // tiles per row + nrows: u32 +}; + +@group(0) @binding(2) +var<uniform> params: Params; + +var<workgroup> shmem_idx: array<u32, WG_SIZE>; + +#if ORDER == 0 +#define EXTREME_VALUE 1e30 +#define SWAP_COMPARE_UP > +#define SWAP_COMPARE_DOWN < +#else +#define EXTREME_VALUE -1e30 +#define SWAP_COMPARE_UP < +#define SWAP_COMPARE_DOWN > +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>, + @builtin(local_invocation_id) lid: vec3<u32>) { + let linear = wid.x + wid.y * num_wg.x; + // guard against overprovisioned workgroups + if (linear >= params.npr * params.nrows) { + return; + } + let tile = linear % params.npr; + var row = linear / params.npr; + let i3 = row / (params.ne2 * params.ne1); + row = row % (params.ne2 * params.ne1); + let i2 = row / params.ne1; + let i1 = row % params.ne1; + + let row_base = params.offset_src + + i1 * params.stride_src1 + + i2 * params.stride_src2 + + i3 * params.stride_src3; + + let tile_base = tile * WG_SIZE; + let idx = tile_base + lid.x; + shmem_idx[lid.x] = select(params.src_ne0, idx, idx < params.src_ne0); + workgroupBarrier(); + + var k = 2u; + while (k <= WG_SIZE) { + var j = k >> 1; + while (j > 0) { + let ixj = lid.x ^ j; + if (ixj > lid.x) { + let dir_up = (lid.x & k) == 0; + let a_idx = shmem_idx[lid.x]; + let b_idx = shmem_idx[ixj]; + let a_val = select(EXTREME_VALUE, src[row_base + a_idx], a_idx < params.src_ne0); + let b_val = select(EXTREME_VALUE, src[row_base + b_idx], b_idx < params.src_ne0); + let should_swap = select( + (a_val SWAP_COMPARE_DOWN b_val), + (a_val SWAP_COMPARE_UP b_val), + dir_up); + if (should_swap) { + shmem_idx[lid.x] = b_idx; + shmem_idx[ixj] = a_idx; + } + } + workgroupBarrier(); + j >>= 1; + } + k <<= 1; + } + + let out_idx = tile * params.top_k + lid.x; + if (out_idx < params.ne0 && lid.x < params.top_k) { + let row_dst = params.offset_dst + + i1 * params.stride_dst1 + + i2 * params.stride_dst2 + + i3 * params.stride_dst3; + dst[row_dst + out_idx] = i32(shmem_idx[lid.x]); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl new file mode 100644 index 00000000000..9a77f6eca74 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl @@ -0,0 +1,134 @@ +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> idx_in: array<i32>; + +@group(0) @binding(2) +var<storage, read_write> idx_out: array<i32>; + +struct Params { + offset_src: u32, // in elements + offset_in: u32, // in elements + offset_out: u32, // in elements + + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx1: u32, + stride_idx2: u32, + stride_idx3: u32, + + stride_out1: u32, + stride_out2: u32, + stride_out3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + + top_k: u32, + + len: u32, + nm: u32, + nrows: u32 +}; + +@group(0) @binding(3) +var<uniform> params: Params; + +fn take_left(a_idx: i32, b_idx: i32, row_base: u32) -> bool { + let a_val = src[row_base + u32(a_idx)]; + let b_val = src[row_base + u32(b_idx)]; +#if ORDER == 0 + return a_val <= b_val; +#else + return a_val >= b_val; +#endif +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>, + @builtin(local_invocation_id) lid: vec3<u32>) { + let linear = wid.x + wid.y * num_wg.x; + // guard against overprovisioned workgroups + if (linear >= params.nm * params.nrows) { + return; + } + + let start = (linear % params.nm) * params.len * 2; + let len0 = min(params.len, params.ne0 - start); + let rem1 = select(0, params.ne0 - (start + params.len), params.ne0 > (start + params.len)); + let len1 = min(params.len, rem1); + let total = len0 + len1; + let chunk = (total + WG_SIZE - 1u) / WG_SIZE; + let k0 = lid.x * chunk; + let k1 = min(min(k0 + chunk, total), params.top_k); + // guard against overprovisioned threads + if (k0 >= params.top_k || k0 >= total) { + return; + } + + var row = linear / params.nm; + let i3 = row / (params.ne2 * params.ne1); + row = row % (params.ne2 * params.ne1); + let i2 = row / params.ne1; + let i1 = row % params.ne1; + + let row_src = params.offset_src + + i1 * params.stride_src1 + + i2 * params.stride_src2 + + i3 * params.stride_src3; + + let row_in = params.offset_in + + i1 * params.stride_idx1 + + i2 * params.stride_idx2 + + i3 * params.stride_idx3; + + let row_out = params.offset_out + + i1 * params.stride_out1 + + i2 * params.stride_out2 + + i3 * params.stride_out3; + + + var low: u32 = select(0, k0 - len1, k0 > len1); + var high: u32 = min(k0, len0); + + while (low < high) { + let mid = (low + high) >> 1; + let idx0 = idx_in[row_in + start + mid]; + let idx1 = idx_in[row_in + start + params.len + (k0 - mid - 1)]; + if (take_left(idx0, idx1, row_src)) { + low = mid + 1; + } else { + high = mid; + } + } + + var i = low; + var j = k0 - i; + var k = k0; + while (k < k1) { + var take_l = false; + if (i >= len0) { + take_l = false; + } else if (j >= len1) { + take_l = true; + } else { + let idx0 = idx_in[row_in + start + i]; + let idx1 = idx_in[row_in + start + params.len + j]; + take_l = take_left(idx0, idx1, row_src); + } + + let out_idx = select( + idx_in[row_in + start + params.len + j], + idx_in[row_in + start + i], + take_l); + idx_out[row_out + start + k] = out_idx; + i = select(i, i + 1, take_l); + j = select(j + 1, j, take_l); + k += 1; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl deleted file mode 100644 index 1ce4d83fa8e..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +++ /dev/null @@ -1,188 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "add_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "+" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "add_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "+" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "add_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "+" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "add_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "+" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "mul_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "*" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "mul_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "*" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "mul_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "*" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "mul_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "*" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sub_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "-" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sub_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "-" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sub_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "-" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sub_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "-" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "div_f32", - "REPLS": { - "TYPE" : "f32", - "OP": "/" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "div_f16", - "REPLS": { - "TYPE" : "f16", - "OP": "/" - }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "div_f32_inplace", - "REPLS": { - "TYPE" : "f32", - "OP": "/" - }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "div_f16_inplace", - "REPLS": { - "TYPE" : "f16", - "OP": "/" - }, - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) - -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { - dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; -} - -@group(0) @binding(2) -var<storage, read_write> dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var<uniform> params: Params; - -#enddecl(NOT_INPLACE) - -#decl(INPLACE) - -fn update(dst_i: u32, src0_i: u32, src1_i: u32) { - src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; -} - -@group(0) @binding(2) -var<uniform> params: Params; - -#enddecl(INPLACE) - -#end(DECLS) - - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var<storage, read_write> src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var<storage, read_write> src1: array<{{TYPE}}>; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x < params.ne) { - update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl new file mode 100644 index 00000000000..f262c4a8f6a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -0,0 +1,142 @@ +enable f16; + +struct Params { + ne: u32, + + // offsets in elements + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + a_ne0: u32, + a_ne1: u32, + a_ne2: u32, + + b_ne0: u32, + b_ne1: u32, + b_ne2: u32, + b_ne3: u32, +}; + +fn src0_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + return a_i0 * params.stride_src0_0 + + a_i1 * params.stride_src0_1 + + a_i2 * params.stride_src0_2 + + a_i3 * params.stride_src0_3; +} + +fn src1_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + // handle repetition of b + // index loops back to the beginning and repeats after elements are exhausted = modulo + let b_i0 = a_i0 % params.b_ne0; + let b_i1 = a_i1 % params.b_ne1; + let b_i2 = a_i2 % params.b_ne2; + let b_i3 = a_i3 % params.b_ne3; + + // compute index for position in b's flat array + return b_i0 * params.stride_src1_0 + + b_i1 * params.stride_src1_1 + + b_i2 * params.stride_src1_2 + + b_i3 * params.stride_src1_3; +} + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +#ifdef SRC_OVERLAP +@group(0) @binding(0) +var<storage, read_write> merged_src: array<DataType>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(2) +var<uniform> params: Params; +#else +@group(0) @binding(0) +var<storage, read_write> src0: array<DataType>; + +@group(0) @binding(1) +var<storage, read_write> src1 : array<DataType>; +#if defined(INPLACE) || defined(OVERLAP) +@group(0) @binding(2) +var<uniform> params: Params; + +#else +@group(0) @binding(2) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(3) +var<uniform> params: Params; +#endif +#endif + +fn op(a: DataType, b: DataType) -> DataType { +#ifdef OP_ADD + return a + b; +#elif defined(OP_SUB) + return a - b; +#elif defined(OP_MUL) + return a * b; +#elif defined(OP_DIV) + return a / b; +#endif +} + +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { +#ifdef SRC_OVERLAP + let result = op(merged_src[src0_i], merged_src[src1_i]); +#else + let result = op(src0[src0_i], src1[src1_i]); +#endif + +#ifdef INPLACE + src0[src0_i] = result; +#elif defined(OVERLAP) + src1[src1_i] = result; +#else + dst[dst_i] = result; +#endif +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + let i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (i < params.ne) { + let src0_i = params.offset_src0 + src0_index(i); + let src1_i = params.offset_src1 + src1_index(i); + update(params.offset_dst + i, src0_i, src1_i); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl deleted file mode 100644 index 4b254f468d6..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +++ /dev/null @@ -1,45 +0,0 @@ -struct Params { - ne: u32, - - // offsets in elements - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - stride_src1_0: u32, - stride_src1_1: u32, - stride_src1_2: u32, - stride_src1_3: u32, - - a_ne0: u32, - a_ne1: u32, - a_ne2: u32, - - b_ne0: u32, - b_ne1: u32, - b_ne2: u32, - b_ne3: u32, -}; - -fn src1_index(_i: u32) -> u32 { - var i = _i; - let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); - i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); - let a_i2 = i / (params.a_ne1 * params.a_ne0); - i = i % (params.a_ne1 * params.a_ne0); - let a_i1 = i / params.a_ne0; - let a_i0 = i % params.a_ne0; - - // handle repetition of b - // index loops back to the beginning and repeats after elements are exhausted = modulo - let b_i0 = a_i0 % params.b_ne0; - let b_i1 = a_i1 % params.b_ne1; - let b_i2 = a_i2 % params.b_ne2; - let b_i3 = a_i3 % params.b_ne3; - - // compute index for position in b's flat array - return b_i0 * params.stride_src1_0 + - b_i1 * params.stride_src1_1 + - b_i2 * params.stride_src1_2 + - b_i3 * params.stride_src1_3; -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 389c97bb51b..758efa17d77 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -1,5 +1,4 @@ -#decl(BYTE_HELPERS) - +#ifdef BYTE_HELPERS fn get_byte(value: u32, index: u32) -> u32 { return (value >> (index * 8)) & 0xFF; } @@ -7,76 +6,114 @@ fn get_byte(value: u32, index: u32) -> u32 { fn get_byte_i32(value: u32, index: u32) -> i32 { return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24; } +#endif + +#ifdef U32_DEQUANT_HELPERS +#ifdef DECLARE_BYTE_LOADERS_SRC +fn load_u16_at_src(byte_offset: u32) -> u32 { + let word = src[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + return (word >> shift) & 0xFFFFu; +} -#enddecl(BYTE_HELPERS) +fn load_u32_at_src(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 0x3u) * 8u; + let lo = src[word_idx]; + let hi = src[word_idx + 1u]; + let shifted = (lo >> shift) | (hi << (32u - shift)); + return select(shifted, lo, shift == 0u); +} -#decl(Q4_0_T) -struct q4_0 { - d: f16, - qs: array<f16, 8> -}; -#enddecl(Q4_0_T) +fn load_f16_at_src(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at_src(byte_offset)); + return f16(packed[0]); +} + +fn load_f16_as_f32_at_src(byte_offset: u32) -> f32 { + let word = src[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + let d_bits = (word >> shift) & 0xFFFFu; + return unpack2x16float(d_bits)[0]; +} +#endif -#decl(Q4_1_T) +#ifdef DECLARE_BYTE_LOADERS_SRC0 +fn load_u16_at_src0(byte_offset: u32) -> u32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +// Always reads the 4-byte-aligned word containing byte_offset. +// Caller extracts the 16-bit half it needs via & 0xFFFFu or >> 16u. +// this is used in k-quants for better performance +fn load_u32_at_src0_aligned(byte_offset: u32) -> u32 { + return src0[(byte_offset & ~3u) / 4u]; +} + +fn load_u32_at_src0(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 0x3u) * 8u; + let lo = src0[word_idx]; + let hi = src0[word_idx + 1u]; + let shifted = (lo >> shift) | (hi << (32u - shift)); + return select(shifted, lo, shift == 0u); +} + +fn load_f16_at_src0(byte_offset: u32) -> f16 { + let packed = unpack2x16float(load_u16_at_src0(byte_offset)); + return f16(packed[0]); +} + +fn load_f16_as_f32_at_src0(byte_offset: u32) -> f32 { + let word = src0[byte_offset / 4u]; + let shift = (byte_offset & 0x2u) * 8u; + let d_bits = (word >> shift) & 0xFFFFu; + return unpack2x16float(d_bits)[0]; +} +#endif +#endif + + + +#ifdef Q4_1_T struct q4_1 { d: f16, m: f16, qs: array<u32, 4> }; -#enddecl(Q4_1_T) +#endif -#decl(Q5_0_T) -struct q5_0 { - d: f16, - qh: array<f16, 2>, - qs: array<f16, 8> -}; -#enddecl(Q5_0_T) -#decl(Q5_1_T) +#ifdef Q5_1_T struct q5_1 { d: f16, m: f16, qh: u32, qs: array<u32, 4> }; -#enddecl(Q5_1_T) +#endif -#decl(Q8_0_T) -struct q8_0 { - d: f16, - qs: array<f16, 16> -}; -#enddecl(Q8_0_T) - -#decl(Q8_1_T) +#ifdef Q8_1_T struct q8_1 { d: f16, - m: f16, + s: f16, // d * sum(qs[i]) qs: array<u32, 8> }; -#enddecl(Q8_1_T) +#endif -#decl(Q2_K_T) -struct q2_k { +#ifdef Q2_K_T +struct q2_K { scales: array<u32, 4>, qs: array<u32, 16>, d: f16, dmin: f16 }; -#enddecl(Q2_K_T) - -#decl(Q3_K_T) -struct q3_k { - hmask: array<f16, 16>, - qs: array<f16, 32>, - scales: array<f16, 6>, - d: f16 -}; -#enddecl(Q3_K_T) +#endif -#decl(Q45_K_SCALE_MIN) +#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN) fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> { if (is < 4) { let sc_byte = get_byte(scales[is / 4], is % 4); @@ -91,111 +128,43 @@ fn get_scale_min(is: u32, scales: array<u32, 3>) -> vec2<f32> { return vec2(f32(sc), f32(m)); } } - -#enddecl(Q45_K_SCALE_MIN) - -#decl(Q4_K_T) -struct q4_k { +#endif +#ifdef Q4_K_T +struct q4_K { d: f16, dmin: f16, scales: array<u32, 3>, qs: array<u32, 32> }; -#enddecl(Q4_K_T) +#endif -#decl(Q5_K_T) -struct q5_k { +#ifdef Q5_K_T +struct q5_K { d: f16, dmin: f16, scales: array<u32, 3>, qh: array<u32, 8>, qs: array<u32, 32> }; -#enddecl(Q5_K_T) - -#decl(Q6_K_T) -struct q6_k { - ql: array<f16, 64>, - qh: array<f16, 32>, - scales: array<f16, 8>, - d: f16 -}; -#enddecl(Q6_K_T) - -#decl(IQ2_XXS_T) -struct iq2_xxs { - d: f16, - qs: array<f16, 32> -}; -#enddecl(IQ2_XXS_T) - -#decl(IQ2_XS_T) -struct iq2_xs { - d: f16, - qs: array<f16, 32>, - scales: array<f16, 4> -}; -#enddecl(IQ2_XS_T) - -#decl(IQ2_S_T) -struct iq2_s { - d: f16, - qs: array<f16, 32>, - qh: array<f16, 4>, - scales: array<f16, 4> -}; -#enddecl(IQ2_S_T) - -#decl(IQ3_XSS_T) -struct iq3_xxs { - d: f16, - qs: array<f16, 48> -}; -#enddecl(IQ3_XSS_T) - -#decl(IQ3_S_T) -struct iq3_s { - d: f16, - qs: array<f16, 32>, - qh: array<f16, 4>, - signs: array<f16, 16>, - scales: array<f16, 2> -}; -#enddecl(IQ3_S_T) +#endif -#decl(IQ1_S_T) -struct iq1_s { - d: f16, - qs: array<f16, 16>, - qh: array<f16, 8> -}; -#enddecl(IQ1_S_T) - -#decl(IQ1_M_T) +#ifdef IQ1_M_T struct iq1_m { qs: array<u32, 8>, qh: array<u32, 4>, scales: array<u32, 2> }; -#enddecl(IQ1_M_T) - -#decl(IQ4_NL_T) -struct iq4_nl { - d: f16, - qs: array<f16, 8>, -}; -#enddecl(IQ4_NL_T) +#endif -#decl(IQ4_XS_T) +#ifdef IQ4_XS_T struct iq4_xs { - d: f16, - scales_h: f16, + d_scales_h: u32, scales_l: u32, qs: array<u32, 32> }; -#enddecl(IQ4_XS_T) +#endif -#decl(IQ23_TABLES) +#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES) const kmask_iq2xs : array<u32, 2> = array<u32, 2>( 0x08040201u, // 1, 2, 4, 8 0x80402010u // 16, 32, 64, 128 @@ -211,9 +180,9 @@ const ksigns_iq2xs: array<u32, 32> = array<u32, 32>( 0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c, 0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc ); -#enddecl(IQ23_TABLES) +#endif -#decl(IQ2_XXS_GRID) +#ifdef IQ2_XXS_GRID const iq2xxs_grid = array<u32, 512>( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808, @@ -280,9 +249,9 @@ const iq2xxs_grid = array<u32, 512>( 0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819, 0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19 ); -#enddecl(IQ2_XXS_GRID) +#endif -#decl(IQ2_XS_GRID) +#ifdef IQ2_XS_GRID const iq2xs_grid = array<u32, 1024>( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, @@ -413,9 +382,9 @@ const iq2xs_grid = array<u32, 1024>( 0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19, 0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b ); -#enddecl(IQ2_XS_GRID) +#endif -#decl(IQ2_S_GRID) +#ifdef IQ2_S_GRID const iq2s_grid = array<u32, 2048>( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, @@ -674,10 +643,9 @@ const iq2s_grid = array<u32, 2048>( 0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b, 0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b ); -#enddecl(IQ2_S_GRID) - -#decl(IQ3_XSS_GRID) +#endif +#ifdef IQ3_XXS_GRID const iq3xxs_grid = array<u32, 256>( 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, @@ -712,10 +680,9 @@ const iq3xxs_grid = array<u32, 256>( 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04 ); -#enddecl(IQ3_XSS_GRID) - -#decl(IQ3_S_GRID) +#endif +#ifdef IQ3_S_GRID const iq3s_grid = array<u32, 512>( 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, @@ -782,9 +749,9 @@ const iq3s_grid = array<u32, 512>( 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101 ); -#enddecl(IQ3_S_GRID) +#endif -#decl(IQ1_GRID) +#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID) const IQ1_DELTA: f32 = 0.125; @@ -919,12 +886,19 @@ const iq1_grid = array<u32, 1024>( 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 ); -#enddecl(IQ1_GRID) +#endif -#decl(IQ4_GRID) +#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID) const kvalues_iq4nl = array<i32, 16>( -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 ); -#enddecl(IQ4_GRID) +#endif + +#ifdef MXFP4_LUT +const kvalues_mxfp4 = array<i32, 16>( + 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12 +); +#endif + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl new file mode 100644 index 00000000000..eb901bf0547 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl @@ -0,0 +1,93 @@ +struct Params { + ne: u32, + + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + dim: u32, + src0_nedim: u32 +}; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_I32 +#define DataType i32 +#endif + +#ifdef SRC_OVERLAP +@group(0) @binding(0) +var<storage, read_write> merged_src: array<DataType>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(2) +var<uniform> params: Params; +#else +@group(0) @binding(0) +var<storage, read_write> src0: array<DataType>; + +@group(0) @binding(1) +var<storage, read_write> src1 : array<DataType>; + +@group(0) @binding(2) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(3) +var<uniform> params: Params; +#endif +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + + if (gid.x < params.ne) { + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + var ni = array<u32, 4>(i0, i1, i2, i3); + + if (ni[params.dim] < params.src0_nedim) { + let src_i = ni[0] * params.stride_src0_0 + + ni[1] * params.stride_src0_1 + + ni[2] * params.stride_src0_2 + + ni[3] * params.stride_src0_3; +#ifdef SRC_OVERLAP + dst[params.offset_dst + gid.x] = merged_src[params.offset_src0 + src_i]; +#else + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i]; +#endif + } else { + ni[params.dim] -= params.src0_nedim; + let src_i = ni[0] * params.stride_src1_0 + + ni[1] * params.stride_src1_1 + + ni[2] * params.stride_src1_2 + + ni[3] * params.stride_src1_3; +#ifdef SRC_OVERLAP + dst[params.offset_dst + gid.x] = merged_src[params.offset_src1 + src_i]; +#else + dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i]; +#endif + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl new file mode 100644 index 00000000000..9eb131dc221 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl @@ -0,0 +1,165 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(WEIGHT_F32) +var<storage, read_write> weights: array<f32>; +#elif defined(WEIGHT_F16) +var<storage, read_write> weights: array<f16>; +#endif + +@group(0) @binding(1) +#if defined(INPUT_F32) +var<storage, read_write> input: array<f32>; +#elif defined(INPUT_F16) +var<storage, read_write> input: array<f16>; +#endif + +@group(0) @binding(2) +#if defined(OUTPUT_F32) +var<storage, read_write> output: array<f32>; +#elif defined(OUTPUT_F16) +var<storage, read_write> output: array<f16>; +#endif + +struct Params { + offset_w: u32, + offset_i: u32, + offset_o: u32, + + // element strides + sw0: u32, sw1: u32, sw2: u32, sw3: u32, + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + // kernel dimensions + KW: u32, KH: u32, IC: u32, + // input dimensions + IW: u32, IH: u32, + // output dimensions + OW: u32, OH: u32, OC_out: u32, N_out: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +}; + +@group(0) @binding(3) +var<uniform> params: Params; + +fn load_weight(idx: u32) -> f32 { + #if defined(WEIGHT_F32) + return weights[idx]; + #elif defined(WEIGHT_F16) + return f32(weights[idx]); + #endif +} + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +fn ceil_div_u32(x: u32, y: u32) -> u32 { + return (x + y - 1) / y; +} + +// returns the first valid kernel index k such that base + k * step >= 0 +fn first_valid_k(base: i32, step: u32) -> u32 { + if (base >= 0) { + return 0; + } + + return ceil_div_u32(u32(-base), step); +} + +// returns the first invalid kernel index k such that base + k * step >= limit so valid k are in [0, end_valid_k) +fn end_valid_k(base: i32, step: u32, limit: u32, k_max: u32) -> u32 { + let remaining = i32(limit) - base; + if (remaining <= 0) { + return 0; + } + + return min(k_max, ceil_div_u32(u32(remaining), step)); +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32> +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let n_out = params.OW * params.OH * params.OC_out * params.N_out; + + var sum: f32 = 0.0; + if (i_out >= n_out) { + return; + } + + // Kernel layout: [KW, KH, IC, ..] + // Input layout: [IW, IH, .., ..] + // Output layout: [OW, OH, OC, N] + + var i = i_out; + let n = i / (params.OC_out * params.OH * params.OW); + i = i % (params.OC_out * params.OH * params.OW); + let oc = i / (params.OH * params.OW); + i = i % (params.OH * params.OW); + let oh = i / params.OW; + let ow = i % params.OW; + + let ow_base = i32(ow * params.s0) - i32(params.p0); + let oh_base = i32(oh * params.s1) - i32(params.p1); + + // clip the valid kernel window once + let kw_begin = first_valid_k(ow_base, params.d0); + let kw_end = end_valid_k(ow_base, params.d0, params.IW, params.KW); + let kh_begin = first_valid_k(oh_base, params.d1); + let kh_end = end_valid_k(oh_base, params.d1, params.IH, params.KH); + + // entire receptive field is out of bounds + if (kw_begin >= kw_end || kh_begin >= kh_end) { + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, 0.0); + return; + } + + let weight_oc_base = params.offset_w + oc * params.sw3; + let input_n_base = params.offset_i + n * params.si3; + + for (var ic: u32 = 0; ic < params.IC; ic += 1) { + let w_base_ic = ic * params.sw2 + weight_oc_base; + let in_base = ic * params.si2 + input_n_base; + + for (var kh: u32 = kh_begin; kh < kh_end; kh += 1) { + let ih = u32(oh_base + i32(kh * params.d1)); + let w_row_base = w_base_ic + kh * params.sw1; + let in_row_base = in_base + ih * params.si1; + for (var kw: u32 = kw_begin; kw < kw_end; kw += 1) { + let iw = u32(ow_base + i32(kw * params.d0)); + let w_idx = w_row_base + kw * params.sw0; + let in_idx = in_row_base + iw * params.si0; + sum += load_weight(w_idx) * load_input(in_idx); + } + } + } + + let out_idx = params.offset_o + ow * params.so0 + oh * params.so1 + oc * params.so2 + n * params.so3; + store_output(out_idx, sum); +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl similarity index 60% rename from ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl index db1aa34903b..67f1dc0928f 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl @@ -1,60 +1,41 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f32" - } - }, - { - "REPLS": { - "SRC_TYPE": "f32", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f16" - } - }, - { - "REPLS": { - "SRC_TYPE": "f16", - "DST_TYPE": "f32" - } - } -] +enable f16; -#end(VARIANTS) +#ifdef SRC_F32 +#define SRC_TYPE f32 +#elif defined(SRC_F16) +#define SRC_TYPE f16 +#endif -#define(SHADER) -enable f16; +#ifdef DST_F32 +#define DST_TYPE f32 +#elif defined(DST_F16) +#define DST_TYPE f16 +#elif defined(DST_I32) +#define DST_TYPE i32 +#endif @group(0) @binding(0) -var<storage, read_write> src: array<{{SRC_TYPE}}>; +var<storage, read_write> src: array<SRC_TYPE>; @group(0) @binding(1) -var<storage, read_write> dst: array<{{DST_TYPE}}>; +var<storage, read_write> dst: array<DST_TYPE>; -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements +struct Params{ + ne: u32, + offset_src: u32, + offset_dst: u32, - // Strides (in elements) — may be permuted stride_src0: u32, stride_src1: u32, stride_src2: u32, stride_src3: u32, + stride_dst0: u32, stride_dst1: u32, stride_dst2: u32, stride_dst3: u32, - // Logical shapes src_ne0: u32, src_ne1: u32, src_ne2: u32, @@ -67,9 +48,10 @@ struct Params { @group(0) @binding(2) var<uniform> params: Params; -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3<u32>, +) { if (gid.x >= params.ne) { return; } @@ -96,6 +78,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + j2 * params.stride_dst2 + j3 * params.stride_dst3; - dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx])); + dst[params.offset_dst + dst_idx] = DST_TYPE((src[params.offset_src + src_idx])); } -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl new file mode 100644 index 00000000000..e622552c421 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl @@ -0,0 +1,66 @@ +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<f32>; + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + ne0: u32, +}; + +@group(0) @binding(2) +var<uniform> params: Params; + +var<workgroup> shared_sum: array<f32, WG_SIZE>; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3<u32>, + @builtin(local_invocation_id) lid: vec3<u32>) { + let row_idx = params.offset_src + wid.x * params.ne0; + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + var local_sum: f32 = 0.0; + for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) { + local_sum += src[row_idx + col]; + } + shared_sum[lid.x] = local_sum; + workgroupBarrier(); + + // upsweep + var offset = 1u; + while (offset < WG_SIZE) { + let idx = (lid.x + 1) * offset * 2 - 1; + if (idx < WG_SIZE) { + shared_sum[idx] = shared_sum[idx] + shared_sum[idx - offset]; + } + workgroupBarrier(); + offset <<= 1; + } + + // set last to 0 for exclusive sum + if (lid.x == 0) { + shared_sum[WG_SIZE - 1] = 0.0; + } + workgroupBarrier(); + + // downsweep + offset = WG_SIZE >> 1; + while (offset > 0) { + let idx = (lid.x + 1) * offset * 2 - 1; + if (idx < WG_SIZE) { + let t = shared_sum[idx - offset]; + shared_sum[idx - offset] = shared_sum[idx]; + shared_sum[idx] = shared_sum[idx] + t; + } + workgroupBarrier(); + offset = offset >> 1; + } + + // shared_sum[lid] is exclusive prefix sum up to this thread. + var running_sum = shared_sum[lid.x]; + for (var col = lid.x * elems; col < (lid.x + 1) * elems && col < params.ne0; col ++) { + running_sum += src[row_idx + col]; + dst[params.offset_dst + wid.x * params.ne0 + col] = running_sum; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index d61df5bb9e5..79a3a9597ab 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -1,41 +1,8 @@ import os import re -import ast import argparse -def extract_block(text, name): - pattern = rf'#define\({name}\)\s*(.*?)#end\({name}\)' - match = re.search(pattern, text, re.DOTALL) - if not match: - raise ValueError(f"Missing block: {name}") - return match.group(1).strip() - - -def parse_decls(decls_text): - decls = {} - for name, code in re.findall(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', decls_text, re.DOTALL): - decls[name.strip()] = code.strip() - return decls - - -def replace_repl_placeholders(variant, template_map): - for repl, code in variant["REPLS"].items(): - for key, val in template_map.items(): - # Match "key" and avoid matching subsequences using by using \b - code = re.sub(rf'\b{re.escape(str(key))}\b', str(val), code) - variant["REPLS"][repl] = code - return variant - - -def replace_placeholders(shader_text, replacements): - for key, val in replacements.items(): - # Match {{KEY}} literally, where KEY is escaped - pattern = r'{{\s*' + re.escape(key) + r'\s*}}' - shader_text = re.sub(pattern, str(val), shader_text) - return shader_text - - def expand_includes(shader, input_dir): """ Replace #include "file" lines in the text with the contents of that file. @@ -56,91 +23,66 @@ def replacer(match): return include_pattern.sub(replacer, shader) -def write_shader(shader_name, shader_code, output_dir, outfile): +def chunk_shader(shader_code, max_chunk_len=60000): + """Split shader_code into safe raw-string sized chunks.""" + return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)] + + +def raw_delim(shader_code): + """Pick a raw-string delimiter that does not appear in the shader.""" + delim = "wgsl" + while f"){delim}\"" in shader_code: + delim += "_x" + return delim + + +def write_shader(shader_name, shader_code, output_dir, outfile, input_dir): + shader_code = expand_includes(shader_code, input_dir) + if output_dir: wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl") with open(wgsl_filename, "w", encoding="utf-8") as f_out: f_out.write(shader_code) - outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n') - -def generate_variants(fname, input_dir, output_dir, outfile): - shader_path = os.path.join(input_dir, fname) - shader_base_name = fname.split(".")[0] + delim = raw_delim(shader_code) + chunks = chunk_shader(shader_code) - with open(shader_path, "r", encoding="utf-8") as f: - text = f.read() - - try: - variants = ast.literal_eval(extract_block(text, "VARIANTS")) - except ValueError: - write_shader(shader_base_name, text, output_dir, outfile) + if len(chunks) == 1: + outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n') else: - try: - decls_map = parse_decls(extract_block(text, "DECLS")) - except ValueError: - decls_map = {} - try: - templates_map = ast.literal_eval(extract_block(text, "REPL_TEMPLATES")) - except ValueError: - templates_map = {} - - for fname in sorted(os.listdir(input_dir)): - if fname.endswith(".tmpl"): - tmpl_path = os.path.join(input_dir, fname) - with open(tmpl_path, "r", encoding="utf-8") as f_tmpl: - decls = f_tmpl.read() - decls_map.update(parse_decls(decls)) - - shader_template = extract_block(text, "SHADER") - for variant in variants: - if "DECLS" in variant: - decls = variant["DECLS"] - else: - decls = [] - decls_code = "" - for key in decls: - if key not in decls_map: - raise ValueError(f"DECLS key '{key}' not found.") - decls_code += decls_map[key] + "\n\n" - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) - if "REPLS" in variant: - variant = replace_repl_placeholders(variant, templates_map) - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - # second run to expand placeholders in repl_template - final_shader = replace_placeholders(final_shader, variant["REPLS"]) - final_shader = expand_includes(final_shader, input_dir) - - if "SHADER_NAME" in variant: - output_name = variant["SHADER_NAME"] - elif "SHADER_SUFFIX" in variant: - output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"] - elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) - elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]]) - elif "REPLS" in variant and "TYPE" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] - else: - output_name = shader_base_name - write_shader(output_name, final_shader, output_dir, outfile) + for idx, chunk in enumerate(chunks): + outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R"{delim}({chunk}){delim}";\n\n') + outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\n') + outfile.write(' static const std::string s = []{\n') + outfile.write(' std::string tmp;\n') + outfile.write(f' tmp.reserve({len(shader_code)});\n') + for idx in range(len(chunks)): + outfile.write(f' tmp.append(wgsl_{shader_name}_part{idx});\n') + outfile.write(' return tmp;\n') + outfile.write(' }();\n') + outfile.write(' return s;\n') + outfile.write('}\n') + outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n') def main(): parser = argparse.ArgumentParser() parser.add_argument("--input_dir", required=True) parser.add_argument("--output_file", required=True) - parser.add_argument("--output_dir") args = parser.parse_args() - if args.output_dir: - os.makedirs(args.output_dir, exist_ok=True) - with open(args.output_file, "w", encoding="utf-8") as out: - out.write("// Auto-generated shader embedding\n\n") + out.write("// Auto-generated shader embedding\n") + out.write("#include <string>\n\n") for fname in sorted(os.listdir(args.input_dir)): if fname.endswith(".wgsl"): - generate_variants(fname, args.input_dir, args.output_dir, out) + shader_path = os.path.join(args.input_dir, fname) + shader_name = fname.replace(".wgsl", "") + + with open(shader_path, "r", encoding="utf-8") as f: + shader_code = f.read() + + write_shader(shader_name, shader_code, None, out, args.input_dir) if __name__ == "__main__": diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl index de7c132a624..9767ca3d754 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl @@ -4,10 +4,23 @@ enable f16; enable subgroups; enable chromium_experimental_subgroup_matrix; -#ifdef KV_F32 -#define KV_TYPE f32 +#define BYTE_HELPERS +#include "common_decls.tmpl" + +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 +#else +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 #else -#define KV_TYPE f16 +#define V_TYPE f16 #endif // Default values @@ -28,33 +41,6 @@ enable chromium_experimental_subgroup_matrix; // Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE. #define KV_BLOCKS (KV_TILE / SG_MAT_N) -// Quantization constants/helpers -#define BLOCK_SIZE 32 -#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) -#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) -// number of quantized elements processed per thread -#if defined(KV_Q4_0) -#define NQ 16 -// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights -#define F16_PER_BLOCK 9 -#define WEIGHTS_PER_F16 4 -#elif defined(KV_Q8_0) -#define NQ 8 -// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights -#define F16_PER_BLOCK 17 -#define WEIGHTS_PER_F16 2 -#endif -#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16) - -// Ok not to put these in a define block, compiler will remove if unused -fn get_byte(value: u32, index: u32) -> u32 { - return (value >> (index * 8)) & 0xFF; -} - -fn get_byte_i32(value: u32, index: u32) -> i32 { - return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24; -} - struct Params { offset_q: u32, offset_k: u32, @@ -93,28 +79,57 @@ struct Params { }; @group(0) @binding(0) var<storage, read_write> Q: array<f32>; -@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>; -@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>; +#ifdef KV_OVERLAP +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#define V K +#else +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; +#endif #if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var<storage, read_write> mask: array<f16>; +@group(0) @binding(3) var<storage, read_write> sinks: array<f32>; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else @group(0) @binding(3) var<storage, read_write> mask: array<f16>; @group(0) @binding(4) var<storage, read_write> sinks: array<f32>; #define DST_BINDING 5 #define PARAMS_BINDING 6 +#endif #elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var<storage, read_write> mask: array<f16>; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else @group(0) @binding(3) var<storage, read_write> mask: array<f16>; #define DST_BINDING 4 #define PARAMS_BINDING 5 +#endif #elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var<storage, read_write> sinks: array<f32>; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else @group(0) @binding(3) var<storage, read_write> sinks: array<f32>; #define DST_BINDING 4 #define PARAMS_BINDING 5 +#endif +#else +#ifdef KV_OVERLAP +#define DST_BINDING 2 +#define PARAMS_BINDING 3 #else #define DST_BINDING 3 #define PARAMS_BINDING 4 #endif +#endif -@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<f32>; +@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>; @group(0) @binding(PARAMS_BINDING) var<uniform> params: Params; // Just a very small float value. @@ -160,14 +175,58 @@ fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 { return v; } +fn load_f32x4(buf: ptr<storage, array<vec4<f32>>, read_write>, scalar_index: u32) -> vec4<f32> { + return (*buf)[scalar_index >> 2u]; +} + +fn load_kx4(buf: ptr<storage, array<vec4<K_TYPE>>, read_write>, scalar_index: u32) -> vec4<K_TYPE> { + return (*buf)[scalar_index >> 2u]; +} + +#ifndef KV_DIRECT +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + kv_shmem[elem_idx] = f16(select( + 0.0, + K[global_k_row_offset + k_col], + global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + kv_shmem[elem_idx] = f16(select( + 0.0, + V[global_v_row_offset + v_col], + global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); + } +} +#endif +#endif @compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, - @builtin(local_invocation_id) local_id: vec3<u32>, - @builtin(subgroup_id) subgroup_id: u32, - @builtin(subgroup_size) subgroup_size: u32, - @builtin(num_subgroups) num_subgroups: u32, - @builtin(subgroup_invocation_id) sg_inv_id: u32) { + @builtin(local_invocation_id) local_id: vec3<u32>, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { // initialize row max for online softmax for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) { @@ -230,127 +289,92 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); // clear inter_shmem to ensure zero-initialized accumulators - for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { - inter_shmem[elem_idx] = 0.0; - } + for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = 0.0; + } // load k tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; // scale - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let k_row = blck_idx / BLOCKS_K; - let global_k_row = kv_tile + k_row; - let block_k = blck_idx % BLOCKS_K; - let row_offset = k_row * HEAD_DIM_QK; - - if (global_k_row < params.seq_len_kv) { - let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = K[base_idx]; // scale - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = K[base_idx + 1u + block_offset + j]; - let q_1 = K[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { - let k_row = elem_idx / HEAD_DIM_QK; - let k_col = elem_idx % HEAD_DIM_QK; - let global_k_row = kv_tile + k_row; - let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; - kv_shmem[elem_idx] = f16(select( - 0.0, - K[global_k_row_offset + k_col], - global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK)); - } +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); #endif workgroupBarrier(); // accumulate q block * k block into registers across the entire KV tile // TODO: this loop seems to be the current largest bottleneck - for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { - let inter_offset = kv_block * SG_MAT_N; - var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad< - subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE); + // this bracket exists to scope the lifetime of variables, reducing register pressure + { #ifdef KV_DIRECT - let k_block_row = kv_tile + kv_block * SG_MAT_N; - let k_global_offset = k_head_offset + k_block_row * params.stride_k1; + let k_block_row = kv_tile + subgroup_id * SG_MAT_N; + var k_global_offset = k_head_offset + k_block_row * params.stride_k1; #else - let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK; + var k_block_offset = subgroup_id * SG_MAT_N * HEAD_DIM_QK; #endif - for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) { - // load q submatrix from shared memory - var q_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>( - &q_shmem, - head_dim_block, - false, - HEAD_DIM_QK - ); + for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) { + let inter_offset = kv_block * SG_MAT_N; + var acc: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>(&inter_shmem, inter_offset, false, KV_TILE); + + var q_cur = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, 0u, false, HEAD_DIM_QK); - // load k submatrix from device or shared memory #ifdef KV_DIRECT - var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>( - &K, - k_global_offset + head_dim_block, - true, - params.stride_k1 - ); + var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + 0u, true, params.stride_k1); #else - var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>( - &kv_shmem, - k_block_offset + head_dim_block, - true, - HEAD_DIM_QK - ); + var k_cur = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + 0u, true, HEAD_DIM_QK); #endif - acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc); - } - // store acc to shared memory for softmax (S matrix from paper) - subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); + var t: u32 = 1u; + for (; t + 1u < HEAD_DIM_QK / SG_MAT_K; t += 2u) { + let h0 = t * SG_MAT_K; + var q0 = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h0, false, HEAD_DIM_QK); +#ifdef KV_DIRECT + var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h0, true, params.stride_k1); +#else + var k0 = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h0, true, HEAD_DIM_QK); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + q_cur = q0; + k_cur = k0; + + let h1 = (t + 1u) * SG_MAT_K; + var q1g = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h1, false, HEAD_DIM_QK); +#ifdef KV_DIRECT + var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h1, true, params.stride_k1); +#else + var k1g = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h1, true, HEAD_DIM_QK); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + q_cur = q1g; + k_cur = k1g; + } + + // handle odd tail + if (t < HEAD_DIM_QK / SG_MAT_K) { + let h = t * SG_MAT_K; + var qn = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>(&q_shmem, h, false, HEAD_DIM_QK); +#ifdef KV_DIRECT + var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&K, k_global_offset + h, true, params.stride_k1); +#else + var kn = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>(&kv_shmem, k_block_offset + h, true, HEAD_DIM_QK); +#endif + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + q_cur = qn; + k_cur = kn; + } + + acc = subgroupMatrixMultiplyAccumulate(q_cur, k_cur, acc); + +#ifdef KV_DIRECT + k_global_offset += num_subgroups * SG_MAT_N * params.stride_k1; +#else + k_block_offset += num_subgroups * SG_MAT_N * HEAD_DIM_QK; +#endif + subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE); + } } + #ifdef MASK // load mask tile into shared memory for this KV block // TODO: optimize and skip if mask is -INF for the entire tile @@ -412,73 +436,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, } // load v tile into shared memory -#if defined(KV_Q4_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; // scale - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_lo; - kv_shmem[row_offset + idx + 16u] = q_hi; - } - } - } - } -#elif defined(KV_Q8_0) - for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) { - let blck_idx = elem_idx / BLOCK_SIZE; - let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16; - let v_row = blck_idx / BLOCKS_V; - let global_v_row = kv_tile + v_row; - let block_k = blck_idx % BLOCKS_V; - let row_offset = v_row * HEAD_DIM_V; - - if (global_v_row < params.seq_len_kv) { - let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; - let base_idx = global_block_idx * F16_PER_BLOCK; - let d = V[base_idx]; // scale - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = V[base_idx + 1u + block_offset + j]; - let q_1 = V[base_idx + 1u + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f16(q_byte) * d; - let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k; - kv_shmem[row_offset + idx] = q_val; - } - } - } - } -#elif defined(KV_DIRECT) - // Direct global loads for KV -#else - for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) { - let v_row = elem_idx / HEAD_DIM_V; - let v_col = elem_idx % HEAD_DIM_V; - let global_v_row = kv_tile + v_row; - let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; - kv_shmem[elem_idx] = f16(select( - 0.0, - V[global_v_row_offset + v_col], - global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V)); - } +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); #endif workgroupBarrier(); @@ -489,16 +448,15 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, head_dim_block < HEAD_DIM_V; head_dim_block += num_subgroups * SG_MAT_N) { // load O submatrix from shared memory - var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>( + var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_N, SG_MAT_M>>( &o_shmem, head_dim_block, false, HEAD_DIM_V ); - for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) { let p_offset = kv_block * SG_MAT_N; - var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>( + var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_K, SG_MAT_M>>( &inter_shmem, p_offset, false, @@ -509,7 +467,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, #ifdef KV_DIRECT let v_block_row = kv_tile + kv_block * SG_MAT_N; let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block; - var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>( + var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>( &V, v_global_offset, false, @@ -517,7 +475,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, ); #else let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V; - var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>( + var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_N, SG_MAT_K>>( &kv_shmem, v_block_offset + head_dim_block, false, @@ -527,11 +485,9 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, // O += P * V o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat); } - // store O back to shared memory subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V); } - workgroupBarrier(); } @@ -566,26 +522,38 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, o_shmem[idx] = f16(val); } } - workgroupBarrier(); #endif - - // write output back to global memory for (var q_tile_row = subgroup_id; - q_tile_row < Q_TILE; - q_tile_row += num_subgroups) { - let global_q_row = q_row_start + q_tile_row; - if (global_q_row >= params.seq_len_q) { - break; - } + q_tile_row < Q_TILE; + q_tile_row += num_subgroups) { - let exp_sum = exp_sum_shmem[q_tile_row]; - let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0); + let global_q_row = q_row_start + q_tile_row; + if (global_q_row >= params.seq_len_q) { break; } - for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { - let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx]; - let scaled = f32(o_val) * scale; - dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled; - } + let exp_sum = exp_sum_shmem[q_tile_row]; + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + + let row_base: u32 = dst_global_offset + q_tile_row * dst2_stride; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u); + let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u); + let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u); + let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u); + + let v = vec4<f32>( + f32(o_shmem[i0]) * scale, + f32(o_shmem[i1]) * scale, + f32(o_shmem[i2]) * scale, + f32(o_shmem[i3]) * scale + ); + + let dst_vec_index: u32 = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = v; + } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl new file mode 100644 index 00000000000..8f41eb7bfdb --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl @@ -0,0 +1,124 @@ +#define BLOCK_SIZE 32 +#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE) +#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE) + +#if defined(K_Q4_0) +#define K_NQ 16 +#define K_BLOCK_SIZE_BYTES 18u +#define K_BYTES_PER_THREAD 8u +#define K_BYTES_PER_INNER_LOOP 4u +#elif defined(K_Q8_0) +#define K_NQ 16 +#define K_BLOCK_SIZE_BYTES 34u +#define K_BYTES_PER_THREAD 16u +#define K_BYTES_PER_INNER_LOOP 4u +#endif + +#if defined(V_Q4_0) +#define V_NQ 16 +#define V_BLOCK_SIZE_BYTES 18u +#define V_BYTES_PER_THREAD 8u +#define V_BYTES_PER_INNER_LOOP 4u +#elif defined(V_Q8_0) +#define V_NQ 16 +#define V_BLOCK_SIZE_BYTES 34u +#define V_BYTES_PER_THREAD 16u +#define V_BYTES_PER_INNER_LOOP 4u +#endif + +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_u16_at(byte_offset: u32) -> u32 { + let word = K[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_k_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = K[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = K[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} +#endif + +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_u16_at(byte_offset: u32) -> u32 { + let word = V[byte_offset / 4u]; + let shift = (byte_offset & 2u) * 8u; + return (word >> shift) & 0xFFFFu; +} + +fn load_v_u32_at(byte_offset: u32) -> u32 { + let word_idx = byte_offset / 4u; + let shift = (byte_offset & 3u) * 8u; + let lo = V[word_idx]; + if (shift == 0u) { + return lo; + } + let hi = V[word_idx + 1u]; + return (lo >> shift) | (hi << (32u - shift)); +} +#endif + +fn f16_from_u16(bits: u32) -> f16 { + let packed = unpack2x16float(bits); + return f16(packed[0]); +} + +#if defined(K_Q4_0) || defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x * K_NQ; elem_idx < kv_count * HEAD_DIM_QK; elem_idx += WG_SIZE * K_NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / K_NQ; + let k_row = blck_idx / BLOCKS_K; + let global_k_row = kv_tile + k_row; + let block_k = blck_idx % BLOCKS_K; + let row_offset = k_row * HEAD_DIM_QK; + let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k; + let block_byte_base = global_block_idx * K_BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_k_u16_at(block_byte_base)); + let thread_byte_offset = block_offset * K_BYTES_PER_THREAD; + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; + for (var j = 0u; j < K_BYTES_PER_THREAD / K_BYTES_PER_INNER_LOOP; j += 1u) { + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * K_BYTES_PER_INNER_LOOP; + let q_packed = load_k_u32_at(q_byte_offset); +#if defined(K_Q4_0) + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); +#elif defined(K_Q8_0) + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * K_BYTES_PER_INNER_LOOP); +#endif + } + } +} +#endif + +#if defined(V_Q4_0) || defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x * V_NQ; elem_idx < kv_count * HEAD_DIM_V; elem_idx += WG_SIZE * V_NQ) { + let blck_idx = elem_idx / BLOCK_SIZE; + let block_offset = (elem_idx % BLOCK_SIZE) / V_NQ; + let v_row = blck_idx / BLOCKS_V; + let global_v_row = kv_tile + v_row; + let block_k = blck_idx % BLOCKS_V; + let row_offset = v_row * HEAD_DIM_V; + let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k; + let block_byte_base = global_block_idx * V_BLOCK_SIZE_BYTES; + let d = f16_from_u16(load_v_u16_at(block_byte_base)); + let thread_byte_offset = block_offset * V_BYTES_PER_THREAD; + let shmem_idx = row_offset + block_k * BLOCK_SIZE + thread_byte_offset; + for (var j = 0u; j < V_BYTES_PER_THREAD / V_BYTES_PER_INNER_LOOP; j += 1u) { + let q_byte_offset = block_byte_base + 2u + thread_byte_offset + j * V_BYTES_PER_INNER_LOOP; + let q_packed = load_v_u32_at(q_byte_offset); +#if defined(V_Q4_0) + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); +#elif defined(V_Q8_0) + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * V_BYTES_PER_INNER_LOOP); +#endif + } + } +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl new file mode 100644 index 00000000000..e68934113fc --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl @@ -0,0 +1,397 @@ +enable f16; +enable subgroups; + +#define BYTE_HELPERS +#include "common_decls.tmpl" + +#ifdef Q_F16 +#define Q_TYPE f16 +#else +#define Q_TYPE f32 +#endif + +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 +#else +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 +#else +#define V_TYPE f16 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 +#define Q_TILE 4 +#define KV_TILE 64 +#define WG_SIZE 128 +#ifndef MIN_SUBGROUP_SIZE +#define MIN_SUBGROUP_SIZE MAX_SUBGROUP_SIZE +#endif + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + q_per_kv: u32, + + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>; +#ifdef KV_OVERLAP +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#else +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; +#endif +#define V K +#else +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#else +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; +#endif +#if defined(V_Q4_0) || defined(V_Q8_0) +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; +#else +@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>; +#endif +#endif + +#if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var<storage, read_write> mask: array<f16>; +@group(0) @binding(3) var<storage, read_write> sinks: array<f32>; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else +@group(0) @binding(3) var<storage, read_write> mask: array<f16>; +@group(0) @binding(4) var<storage, read_write> sinks: array<f32>; +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var<storage, read_write> mask: array<f16>; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else +@group(0) @binding(3) var<storage, read_write> mask: array<f16>; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var<storage, read_write> sinks: array<f32>; +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else +@group(0) @binding(3) var<storage, read_write> sinks: array<f32>; +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#else +#ifdef KV_OVERLAP +#define DST_BINDING 2 +#define PARAMS_BINDING 3 +#else +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#endif +#endif + +@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<DST_TYPE>>; +@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params; + +const FLOAT_MIN: f32 = -1.0e9; +const Q_CHUNKS: u32 = HEAD_DIM_QK / 4u; +const V_CHUNKS: u32 = HEAD_DIM_V / 4u; +const SCORE_REGS_PER_LANE: u32 = (KV_TILE + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; +const OUT_REGS_PER_LANE: u32 = (V_CHUNKS + MIN_SUBGROUP_SIZE - 1u) / MIN_SUBGROUP_SIZE; +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); + +var<workgroup> q_shmem: array<Q_TYPE, Q_TILE * HEAD_DIM_QK>; +var<workgroup> kv_shmem: array<f16, kv_shmem_size>; +var<workgroup> p_shmem: array<f16, Q_TILE * KV_TILE>; + +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var vec_idx_local = local_x; vec_idx_local < kv_count * Q_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / Q_CHUNKS; + let chunk = vec_idx_local % Q_CHUNKS; + let global_k_row = kv_tile + kv_local; + let k_vec_index = (k_head_offset + global_k_row * params.stride_k1 + chunk * 4u) >> 2u; + let k4 = K[k_vec_index]; + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; + kv_shmem[kv_off + 0u] = f16(k4.x); + kv_shmem[kv_off + 1u] = f16(k4.y); + kv_shmem[kv_off + 2u] = f16(k4.z); + kv_shmem[kv_off + 3u] = f16(k4.w); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var vec_idx_local = local_x; vec_idx_local < kv_count * V_CHUNKS; vec_idx_local += WG_SIZE) { + let kv_local = vec_idx_local / V_CHUNKS; + let chunk = vec_idx_local % V_CHUNKS; + let global_v_row = kv_tile + kv_local; + let v_vec_index = (v_head_offset + global_v_row * params.stride_v1 + chunk * 4u) >> 2u; + let v4 = V[v_vec_index]; + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; + kv_shmem[kv_off + 0u] = f16(v4.x); + kv_shmem[kv_off + 1u] = f16(v4.y); + kv_shmem[kv_off + 2u] = f16(v4.z); + kv_shmem[kv_off + 3u] = f16(v4.w); + } +} +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(local_invocation_id) local_id: vec3<u32>, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + if (subgroup_size == 0u || num_subgroups < Q_TILE) { + return; + } + + let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + let batch_idx = wg_id.x / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride; + let wg_in_batch = wg_id.x % wg_per_batch; + + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_offset = v_batch_offset + k_head_idx * params.stride_v2; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head * Q_TILE; + let global_q_row = q_row_start + subgroup_id; + let row_active = subgroup_id < Q_TILE && global_q_row < params.seq_len_q; + +#ifdef MASK + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V; + + let head = f32(head_idx); + let slope = select(1.0, + select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), + pow(params.m0, head + 1.0), + head < params.n_head_log2), + params.max_bias > 0.0); + + for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) { + let q_tile_row = elem_idx / HEAD_DIM_QK; + let q_col = elem_idx % HEAD_DIM_QK; + let head_q_row = q_row_start + q_tile_row; + let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1; + q_shmem[elem_idx] = select( + 0.0, + Q_TYPE(Q[global_q_row_offset + q_col]) * params.scale, + head_q_row < params.seq_len_q); + } + + workgroupBarrier(); + + var row_max = FLOAT_MIN; + var exp_sum = 0.0; + var out_regs: array<vec4<f32>, OUT_REGS_PER_LANE>; + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] = vec4<f32>(0.0); + } + + let q_base = subgroup_id * HEAD_DIM_QK; + let subgroup_p_offset = subgroup_id * KV_TILE; + + for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); + let score_slots = min(SCORE_REGS_PER_LANE, (kv_count + subgroup_size - 1u) / subgroup_size); + let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size); + var local_scores: array<f32, SCORE_REGS_PER_LANE>; + for (var slot = 0u; slot < SCORE_REGS_PER_LANE; slot += 1u) { + local_scores[slot] = FLOAT_MIN; + } + +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); +#endif + + workgroupBarrier(); + + var local_max = FLOAT_MIN; + if (row_active) { + for (var slot = 0u; slot < score_slots; slot += 1u) { + let kv_local = sg_inv_id + slot * subgroup_size; + if (kv_local >= kv_count) { + continue; + } + + let global_k_row = kv_tile + kv_local; + var dot_val = 0.0; + for (var chunk = 0u; chunk < Q_CHUNKS; chunk += 1u) { + let q_off = q_base + chunk * 4u; + let qv = vec4<Q_TYPE>( + q_shmem[q_off + 0u], + q_shmem[q_off + 1u], + q_shmem[q_off + 2u], + q_shmem[q_off + 3u]); + let kv_off = kv_local * HEAD_DIM_QK + chunk * 4u; + let kv = vec4<f16>( + kv_shmem[kv_off + 0u], + kv_shmem[kv_off + 1u], + kv_shmem[kv_off + 2u], + kv_shmem[kv_off + 3u]); + dot_val += dot(vec4<f32>(qv), vec4<f32>(kv)); + } +#ifdef LOGIT_SOFTCAP + dot_val = params.logit_softcap * tanh(dot_val); +#endif +#ifdef MASK + let mask_idx = mask_global_offset + subgroup_id * params.seq_len_kv + global_k_row; + dot_val += slope * f32(mask[mask_idx]); +#endif + local_scores[slot] = dot_val; + local_max = max(local_max, dot_val); + } + } + + let tile_max = subgroupMax(local_max); + let new_max = max(row_max, tile_max); + let cur_exp = exp(row_max - new_max); + exp_sum *= cur_exp; + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] *= cur_exp; + } + + var local_sum = 0.0; + for (var slot = 0u; slot < score_slots; slot += 1u) { + let kv_local = sg_inv_id + slot * subgroup_size; + if (row_active && kv_local < kv_count) { + let p = exp(local_scores[slot] - new_max); + p_shmem[subgroup_p_offset + kv_local] = f16(p); + local_sum += p; + } + } + + workgroupBarrier(); + +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); +#endif + + workgroupBarrier(); + + let tile_sum = subgroupAdd(local_sum); + exp_sum += tile_sum; + row_max = new_max; + + if (row_active) { + for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) { + let chunk = sg_inv_id + reg_idx * subgroup_size; + if (chunk >= V_CHUNKS) { + continue; + } + + var acc = out_regs[reg_idx]; + for (var kv_local = 0u; kv_local < kv_count; kv_local += 1u) { + let p = f32(p_shmem[subgroup_p_offset + kv_local]); + let kv_off = kv_local * HEAD_DIM_V + chunk * 4u; + let v4 = vec4<f16>( + kv_shmem[kv_off + 0u], + kv_shmem[kv_off + 1u], + kv_shmem[kv_off + 2u], + kv_shmem[kv_off + 3u]); + acc += p * vec4<f32>(v4); + } + out_regs[reg_idx] = acc; + } + } + + workgroupBarrier(); + } + +#ifdef SINKS + if (row_active) { + let sink_score = sinks[params.offset_sinks + head_idx]; + let sink_max = max(row_max, sink_score); + let sink_scale = exp(row_max - sink_max); + for (var reg_idx = 0u; reg_idx < OUT_REGS_PER_LANE; reg_idx += 1u) { + out_regs[reg_idx] *= sink_scale; + } + exp_sum = exp_sum * sink_scale + exp(sink_score - sink_max); + row_max = sink_max; + } +#endif + + if (row_active) { + let inv_exp_sum = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + let row_base = dst_global_offset + subgroup_id * dst2_stride; + let out_slots = min(OUT_REGS_PER_LANE, (V_CHUNKS + subgroup_size - 1u) / subgroup_size); + for (var reg_idx = 0u; reg_idx < out_slots; reg_idx += 1u) { + let chunk = sg_inv_id + reg_idx * subgroup_size; + if (chunk >= V_CHUNKS) { + continue; + } + let dst_vec_index = (row_base + chunk * 4u) >> 2u; + dst[dst_vec_index] = vec4<DST_TYPE>(out_regs[reg_idx] * inv_exp_sum); + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl new file mode 100644 index 00000000000..b4f7c16c35d --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl @@ -0,0 +1,101 @@ +diagnostic(off, subgroup_uniformity); +enable f16; + +#define KV_TILE 32 +#define WG_SIZE 32 + +struct Params { + offset_mask: u32, + seq_len_q: u32, + seq_len_kv: u32, + stride_mask3: u32, + // Number of KV blocks and Q blocks per batch. + // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = seq_len_q. + nblk0: u32, + nblk1: u32, +}; + +@group(0) @binding(0) var<storage, read_write> mask: array<f16>; +@group(0) @binding(1) var<storage, read_write> blk: array<u32>; +@group(0) @binding(2) var<uniform> params: Params; + +const MASK_MIN: f32 = -65504.0; +const MASK_MAX: f32 = 65504.0; +var<workgroup> wg_min: array<f32, WG_SIZE>; +var<workgroup> wg_max: array<f32, WG_SIZE>; +var<workgroup> wg_any: array<u32, WG_SIZE>; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(local_invocation_id) local_id: vec3<u32>) { + // Dispatch mapping: + // - x indexes KV blocks + // - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk + let kv_blk = wg_id.x; + let y = wg_id.y; + let q_blk = y % params.nblk1; + let batch_idx = y / params.nblk1; + if (kv_blk >= params.nblk0) { + return; + } + + let q_start = q_blk; + let k_start = kv_blk * KV_TILE; + + let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u); + let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3; + + // We keep min/max to classify: + // - fully masked (max <= MASK_MIN) + // - all-zero mask (min == 0 && max == 0) + // - mixed/general mask + var local_min = MASK_MAX; + var local_max = -MASK_MAX; + var local_any = 0u; + + let q_row = q_start; + if (q_row < params.seq_len_q) { + let row_base = mask_batch_base + q_row * params.seq_len_kv; + for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) { + let k_col = k_start + k_rel; + if (k_col >= params.seq_len_kv) { + continue; + } + let mv = f32(mask[row_base + k_col]); + local_min = min(local_min, mv); + local_max = max(local_max, mv); + local_any = 1u; + } + } + + wg_min[local_id.x] = local_min; + wg_max[local_id.x] = local_max; + wg_any[local_id.x] = local_any; + workgroupBarrier(); + + // Thread 0 writes one state per block. + if (local_id.x == 0u) { + var mmin = wg_min[0]; + var mmax = wg_max[0]; + var many = wg_any[0]; + for (var i = 1u; i < WG_SIZE; i += 1u) { + mmin = min(mmin, wg_min[i]); + mmax = max(mmax, wg_max[i]); + many = max(many, wg_any[i]); + } + + var state = 0u; + if (many != 0u) { + if (mmax <= MASK_MIN) { + state = 0u; + } else if (mmin == 0.0 && mmax == 0.0) { + state = 2u; + } else { + state = 1u; + } + } + + let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk; + blk[blk_idx] = state; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl new file mode 100644 index 00000000000..1091d744073 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl @@ -0,0 +1,84 @@ +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + +// Default values +#define HEAD_DIM_V 64 +#define WG_SIZE 128 + +struct Params { + nrows: u32, + seq_len_q: u32, + n_heads: u32, + offset_dst: u32, + nwg: u32, + tmp_data_base: u32, + tmp_stats_base: u32, +}; + +@group(0) @binding(0) var<storage, read_write> tmp: array<f32>; +@group(0) @binding(1) var<storage, read_write> dst: array<vec4<DST_TYPE>>; +@group(0) @binding(2) var<uniform> params: Params; + +const FLOAT_MIN: f32 = -1.0e9; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + let rid = wg_id.x; + if (rid >= params.nrows) { + return; + } + + let rows_per_batch = params.n_heads * params.seq_len_q; + let batch_idx = rid / rows_per_batch; + let rem = rid % rows_per_batch; + let head_idx = rem / params.seq_len_q; + let q_row = rem % params.seq_len_q; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V; + + let thread = sg_inv_id; + if (params.nwg > subgroup_size) { + return; + } + + let stats_base = params.tmp_stats_base + rid * (2u * params.nwg); + let active_thread = thread < params.nwg; + let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread); + let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread); + let m = subgroupMax(mi); + let ms = select(0.0, exp(mi - m), active_thread); + let s = subgroupAdd(si * ms); + let inv_s = select(0.0, 1.0 / s, s != 0.0); + + let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg); + for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) { + var weighted = vec4<f32>(0.0, 0.0, 0.0, 0.0); + if (active_thread) { + let src = row_tmp_base + thread * HEAD_DIM_V + elem_base; + weighted = vec4<f32>(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms; + } + + let sum_x = subgroupAdd(weighted.x); + let sum_y = subgroupAdd(weighted.y); + let sum_z = subgroupAdd(weighted.z); + let sum_w = subgroupAdd(weighted.w); + + if (thread == 0u) { + let dst_vec_index = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = vec4<DST_TYPE>(vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s); + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl new file mode 100644 index 00000000000..30ed97cca0c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl @@ -0,0 +1,619 @@ +diagnostic(off, subgroup_uniformity); +enable f16; +enable subgroups; + +#define BYTE_HELPERS +#include "common_decls.tmpl" + +#ifdef K_F32 +#define K_TYPE f32 +#elif defined(K_Q4_0) || defined(K_Q8_0) +#define K_TYPE u32 +#else +#define K_TYPE f16 +#endif + +#ifdef V_F32 +#define V_TYPE f32 +#elif defined(V_Q4_0) || defined(V_Q8_0) +#define V_TYPE u32 +#else +#define V_TYPE f16 +#endif + +#ifdef Q_F16 +#define Q_TYPE f16 +#else +#define Q_TYPE f32 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + +#define HEAD_DIM_QK 64 +#define HEAD_DIM_V 64 + +#define KV_GRANULARITY 8 +#define KV_TILE 16 +#define WG_SIZE 64 +#ifndef VEC_NE +#define VEC_NE 4u +#endif + +#define KV_BLOCKS (KV_TILE / KV_GRANULARITY) + +struct Params { + offset_q: u32, + offset_k: u32, + offset_v: u32, + offset_mask: u32, + offset_sinks: u32, + offset_dst: u32, + + // shapes of Q/K/V + n_heads: u32, + seq_len_q: u32, + seq_len_kv: u32, + + // strides (in elements) + stride_q1: u32, + stride_q2: u32, + stride_q3: u32, + stride_k1: u32, + stride_k2: u32, + stride_k3: u32, + stride_v1: u32, + stride_v2: u32, + stride_v3: u32, + stride_mask3: u32, + + // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA + q_per_kv: u32, + + // softmax params + scale: f32, + max_bias: f32, + logit_softcap: f32, + n_head_log2: f32, + m0: f32, + m1: f32, + +#ifdef BLK + blk_base: u32, + blk_nblk0: u32, + blk_nblk1: u32, +#endif + + tmp_data_base: u32, + tmp_stats_base: u32, + nwg: u32, +}; + +@group(0) @binding(0) var<storage, read_write> Q: array<Q_TYPE>; +#ifdef KV_OVERLAP +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#else +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; +#endif +#define V K +#else +#if defined(K_Q4_0) || defined(K_Q8_0) +@group(0) @binding(1) var<storage, read_write> K: array<K_TYPE>; +#else +@group(0) @binding(1) var<storage, read_write> K: array<vec4<K_TYPE>>; +#endif +#if defined(V_Q4_0) || defined(V_Q8_0) +@group(0) @binding(2) var<storage, read_write> V: array<V_TYPE>; +#else +@group(0) @binding(2) var<storage, read_write> V: array<vec4<V_TYPE>>; +#endif +#endif +#if defined(MASK) && defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var<storage, read_write> mask: array<f16>; +@group(0) @binding(3) var<storage, read_write> sinks: array<f32>; +#ifdef BLK +#define BLK_BINDING 4 +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#else +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#else +@group(0) @binding(3) var<storage, read_write> mask: array<f16>; +@group(0) @binding(4) var<storage, read_write> sinks: array<f32>; +#ifdef BLK +#define BLK_BINDING 5 +#define TMP_BINDING 6 +#define DST_BINDING 7 +#define PARAMS_BINDING 8 +#else +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#endif +#endif +#elif defined(MASK) +#ifdef KV_OVERLAP +@group(0) @binding(2) var<storage, read_write> mask: array<f16>; +#ifdef BLK +#define BLK_BINDING 3 +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#else +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#else +@group(0) @binding(3) var<storage, read_write> mask: array<f16>; +#ifdef BLK +#define BLK_BINDING 4 +#define TMP_BINDING 5 +#define DST_BINDING 6 +#define PARAMS_BINDING 7 +#else +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#endif +#elif defined(SINKS) +#ifdef KV_OVERLAP +@group(0) @binding(2) var<storage, read_write> sinks: array<f32>; +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#else +@group(0) @binding(3) var<storage, read_write> sinks: array<f32>; +#define TMP_BINDING 4 +#define DST_BINDING 5 +#define PARAMS_BINDING 6 +#endif +#else +#ifdef KV_OVERLAP +#define TMP_BINDING 2 +#define DST_BINDING 3 +#define PARAMS_BINDING 4 +#else +#define TMP_BINDING 3 +#define DST_BINDING 4 +#define PARAMS_BINDING 5 +#endif +#endif + +#ifdef BLK +@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>; +#endif +@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>; +@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<DST_TYPE>>; +@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params; + +// Just a very small float value. +const FLOAT_MIN: f32 = -1.0e9; + +var<workgroup> q_shmem: array<f32, HEAD_DIM_QK>; + +#ifndef KV_DIRECT +const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V); +// we can reuse the same shmem for K and V since we only need one at a time +var<workgroup> kv_shmem: array<f32, kv_shmem_size>; +#endif + +var<workgroup> o_shmem: array<f32, HEAD_DIM_V>; + +#ifdef MASK +// storage for mask values +var<workgroup> mask_shmem: array<f32, KV_TILE>; +#endif + +// note that we reuse the same storage for both since we only need one at a time +var<workgroup> inter_shmem: array<f32, KV_TILE>; + +// Storage for row max and exp sum during online softmax +fn calc_softmax_term(kv_idx: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 { + var v = select(FLOAT_MIN, + inter_shmem[kv_idx] * params.scale, + kv_idx < KV_TILE); +#ifdef LOGIT_SOFTCAP + v = params.logit_softcap * tanh(v); +#endif +#ifdef MASK + if (apply_mask) { + var mask_val = select(0.0, mask_shmem[kv_idx], kv_idx < KV_TILE); + v += select(mask_val, slope * mask_val, has_bias); + } +#endif + return v; +} + +#ifndef KV_DIRECT +#define QUANT_SHMEM kv_shmem +#define QUANT_OUT_TYPE f32 +#include "quant_inner_loops.tmpl" +#include "flash_attn_quant_staging.tmpl" + +#if !defined(K_Q4_0) && !defined(K_Q8_0) +fn load_k_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, k_head_offset: u32) { + for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) { + let k_row = elem_idx / HEAD_DIM_QK; + let k_col = elem_idx % HEAD_DIM_QK; + let global_k_row = kv_tile + k_row; + let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1; + let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK; + let vec_idx = (global_k_row_offset + k_col) >> 2u; + let k4 = select(vec4<K_TYPE>(0.0), K[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f32(k4.x); + kv_shmem[elem_idx + 1u] = f32(k4.y); + kv_shmem[elem_idx + 2u] = f32(k4.z); + kv_shmem[elem_idx + 3u] = f32(k4.w); + } +} +#endif + +#if !defined(V_Q4_0) && !defined(V_Q8_0) +fn load_v_tile_block(local_x: u32, kv_count: u32, kv_tile: u32, v_head_offset: u32) { + for (var elem_idx = local_x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) { + let v_row = elem_idx / HEAD_DIM_V; + let v_col = elem_idx % HEAD_DIM_V; + let global_v_row = kv_tile + v_row; + let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1; + let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V; + let vec_idx = (global_v_row_offset + v_col) >> 2u; + let v4 = select(vec4<V_TYPE>(0.0), V[vec_idx], in_bounds); + kv_shmem[elem_idx + 0u] = f32(v4.x); + kv_shmem[elem_idx + 1u] = f32(v4.y); + kv_shmem[elem_idx + 2u] = f32(v4.z); + kv_shmem[elem_idx + 3u] = f32(v4.w); + } +} +#endif +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(local_invocation_id) local_id: vec3<u32>, + @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_size) subgroup_size: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_invocation_id) sg_inv_id: u32) { + // Vec path processes exactly one query row per workgroup, so subgroup 0 can + // keep the running softmax state in private storage. + var row_max = FLOAT_MIN; + var exp_sum = 0.0; + + for (var i = local_id.x; i < HEAD_DIM_V; i += WG_SIZE) { + o_shmem[i] = 0.0; + } + + // workgroups per head/batch + let wg_per_head = params.seq_len_q; + let wg_per_batch = wg_per_head * params.n_heads; + + let dst2_stride = HEAD_DIM_V * params.n_heads; + let dst3_stride = dst2_stride * params.seq_len_q; + + let iwg = wg_id.x % params.nwg; + let base_wg_id = wg_id.x / params.nwg; + + // batch index + let batch_idx = base_wg_id / wg_per_batch; + let q_batch_offset = params.offset_q + batch_idx * params.stride_q3; + let k_batch_offset = params.offset_k + batch_idx * params.stride_k3; + let v_batch_offset = params.offset_v + batch_idx * params.stride_v3; + let wg_in_batch = base_wg_id % wg_per_batch; + + // head index + let head_idx = wg_in_batch / wg_per_head; + let q_head_offset = q_batch_offset + head_idx * params.stride_q2; + let k_head_idx = head_idx / params.q_per_kv; + let v_head_idx = k_head_idx; + let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2; + let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2; + + // Vec path handles one Q row per workgroup. + let wg_in_head = wg_in_batch % wg_per_head; + let q_row_start = wg_in_head; + +#ifdef MASK + // mask offset + let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv; +#endif + + let head = f32(head_idx); + let has_bias = params.max_bias > 0.0; + let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias); + + // load the single Q row into shared memory + for (var elem_idx = local_id.x; elem_idx < HEAD_DIM_QK; elem_idx += WG_SIZE) { + let global_q_row_offset = q_head_offset + q_row_start * params.stride_q1; + q_shmem[elem_idx] = select( + 0.0, + f32(Q[global_q_row_offset + elem_idx]), + q_row_start < params.seq_len_q); + } + + for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) { + let kv_count = min(KV_TILE, params.seq_len_kv - kv_tile); +#ifdef BLK + let q_blk = q_row_start; + let kv_blk = kv_tile / KV_TILE; + let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u); + let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk; + let blk_state_local = blk[blk_idx]; +#else + let blk_state_local = 1u; +#endif + let blk_state = blk_state_local; + let skip_tile = blk_state == 0u; + for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { + inter_shmem[elem_idx] = 0.0; + } + + // load k tile into shared memory +#ifndef KV_DIRECT + load_k_tile_block(local_id.x, kv_count, kv_tile, k_head_offset); +#endif + + workgroupBarrier(); + + // accumulate q block * k block into registers across the entire KV tile + if (!skip_tile) { + let num_of_threads = subgroup_size / VEC_NE; + let tx = sg_inv_id % num_of_threads; + let ty = sg_inv_id / num_of_threads; + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { + for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) { + let kv_idx = kv_base + ty; + var partial_sum: f32 = 0.0; + let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv; + if (kv_valid) { + for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) { + let q_off = i * 4u; + + let qv = vec4<f32>( + q_shmem[q_off + 0u], + q_shmem[q_off + 1u], + q_shmem[q_off + 2u], + q_shmem[q_off + 3u]); +#ifdef KV_DIRECT + let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u); + let kv = vec4<f32>(K[idx >> 2u]); +#else + let idx = kv_idx * HEAD_DIM_QK + (i * 4u); + let kv = vec4<f32>( + kv_shmem[idx + 0u], + kv_shmem[idx + 1u], + kv_shmem[idx + 2u], + kv_shmem[idx + 3u]); +#endif + partial_sum += dot(qv, kv); + } + } + var sum = partial_sum; + // Reduce over tx threads (NL) for this ty stripe. + var tx_delta = num_of_threads >> 1u; + loop { + if (tx_delta == 0u) { + break; + } + let sh = subgroupShuffleDown(sum, tx_delta); + if (tx < tx_delta) { + sum += sh; + } + tx_delta >>= 1u; + } + + let sum_bcast = subgroupShuffle(sum, num_of_threads * ty); + if (tx == 0u && kv_valid) { + inter_shmem[kv_idx] = sum_bcast; + } + } + } + } + + +#ifdef MASK + let apply_mask = !skip_tile && (blk_state != 2u); + if (apply_mask) { + // load mask tile into shared memory for this KV block + for (var elem_idx = local_id.x; elem_idx < KV_TILE; elem_idx += WG_SIZE) { + let global_k_col = kv_tile + elem_idx; + let mask_in_bounds = q_row_start < params.seq_len_q && global_k_col < params.seq_len_kv; + let mask_idx = mask_global_offset + global_k_col; + mask_shmem[elem_idx] = select(0.0f, f32(mask[mask_idx]), mask_in_bounds); + } + } +#else + let apply_mask = false; +#endif + + workgroupBarrier(); + + // online softmax + if (!skip_tile && subgroup_id == 0u && q_row_start < params.seq_len_q) { + var prev_max = row_max; + var final_max = prev_max; + // pass 1: compute final max across the full KV tile in chunks + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE; + let softmax_term = select(FLOAT_MIN, + calc_softmax_term(kv_idx, slope, has_bias, apply_mask), + kv_valid); + final_max = subgroupMax(max(final_max, softmax_term)); + } + + var total_exp_term: f32 = 0.0; + // pass 2: compute exp sum and write P using final_max + for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) { + let kv_idx = kv_offset + sg_inv_id; + let softmax_term = calc_softmax_term(kv_idx, slope, has_bias, apply_mask); + let cur_p = select(0.0, + exp(softmax_term - final_max), + kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE); + total_exp_term += subgroupAdd(cur_p); + if (kv_idx < KV_TILE) { + inter_shmem[kv_idx] = cur_p; + } + } + + let cur_exp = exp(prev_max - final_max); + + row_max = final_max; + exp_sum = exp_sum * cur_exp + total_exp_term; + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + o_shmem[elem_idx] = o_shmem[elem_idx] * cur_exp; + } + } + + // load v tile into shared memory +#ifndef KV_DIRECT + load_v_tile_block(local_id.x, kv_count, kv_tile, v_head_offset); +#endif + + workgroupBarrier(); + + if (!skip_tile) { + // we have P (KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem + // we want to compute O += P * V across the full KV tile + let ne_threads : u32 = VEC_NE; + let nl_threads = max(1u, subgroup_size / ne_threads); + let tx_pv = sg_inv_id % nl_threads; + let ty_pv = sg_inv_id / nl_threads; + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { + for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) { + var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0); + for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) { + let kv_idx = cc * ne_threads + ty_pv; + let v_row = kv_tile + kv_idx; + if (v_row >= params.seq_len_kv) { + continue; + } + + let p = inter_shmem[kv_idx]; +#ifdef KV_DIRECT + let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u; + let v4 = vec4<f32>(V[v_idx >> 2u]); +#else + let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u; + let v4 = vec4<f32>( + kv_shmem[v_idx + 0u], + kv_shmem[v_idx + 1u], + kv_shmem[v_idx + 2u], + kv_shmem[v_idx + 3u]); +#endif + lo += p * v4; + } + + var lo_x = lo.x; + var lo_y = lo.y; + var lo_z = lo.z; + var lo_w = lo.w; + // Reduce over ty threads (NE) for this tx thread. + var ty_delta = ne_threads >> 1u; + loop { + if (ty_delta == 0u) { + break; + } + let thread_delta = ty_delta * nl_threads; + let shx = subgroupShuffleDown(lo_x, thread_delta); + let shy = subgroupShuffleDown(lo_y, thread_delta); + let shz = subgroupShuffleDown(lo_z, thread_delta); + let shw = subgroupShuffleDown(lo_w, thread_delta); + if (ty_pv < ty_delta) { + lo_x += shx; + lo_y += shy; + lo_z += shz; + lo_w += shw; + } + ty_delta >>= 1u; + } + + if (ty_pv == 0u) { + let elem_base = vec_col * 4u; + o_shmem[elem_base + 0u] = o_shmem[elem_base + 0u] + lo_x; + o_shmem[elem_base + 1u] = o_shmem[elem_base + 1u] + lo_y; + o_shmem[elem_base + 2u] = o_shmem[elem_base + 2u] + lo_z; + o_shmem[elem_base + 3u] = o_shmem[elem_base + 3u] + lo_w; + } + } + } + } + + workgroupBarrier(); + } + + +#ifdef SINKS + // Sinks are global terms and must be applied exactly once across split workgroups. + if (iwg == 0u && subgroup_id == 0u && q_row_start < params.seq_len_q) { + var prev_max = row_max; + + // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum + let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0u); + let new_max = subgroupMax(max(prev_max, sink_val)); + let max_exp = exp(prev_max - new_max); + let sink_exp = exp(sink_val - new_max); + + let sink_exp_sum = subgroupAdd(sink_exp); + + row_max = new_max; + exp_sum = exp_sum * max_exp + sink_exp_sum; + + for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) { + o_shmem[elem_idx] = o_shmem[elem_idx] * max_exp; + } + } + workgroupBarrier(); +#endif + let rows_per_batch = params.n_heads * params.seq_len_q; + if (subgroup_id == 0u && q_row_start < params.seq_len_q) { + if (params.nwg == 1u) { + let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0); + let row_base: u32 = params.offset_dst + batch_idx * dst3_stride + q_row_start * dst2_stride + + head_idx * HEAD_DIM_V; + + for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) { + let v = vec4<f32>( + f32(o_shmem[elem_base + 0u]) * scale, + f32(o_shmem[elem_base + 1u]) * scale, + f32(o_shmem[elem_base + 2u]) * scale, + f32(o_shmem[elem_base + 3u]) * scale + ); + + let dst_vec_index: u32 = (row_base + elem_base) >> 2u; + dst[dst_vec_index] = vec4<DST_TYPE>(v); + } + } else { + let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + q_row_start; + let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V; + let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg; + + for (var elem_base = sg_inv_id * 4u; + elem_base < HEAD_DIM_V; + elem_base += subgroup_size * 4u) { + + let tbase = tmp_row_data_base + elem_base; + tmp[tbase + 0u] = f32(o_shmem[elem_base + 0u]); + tmp[tbase + 1u] = f32(o_shmem[elem_base + 1u]); + tmp[tbase + 2u] = f32(o_shmem[elem_base + 2u]); + tmp[tbase + 3u] = f32(o_shmem[elem_base + 3u]); + } + + if (sg_inv_id == 0u) { + tmp[tmp_row_stats_base + 0u] = exp_sum; + tmp[tmp_row_stats_base + 1u] = row_max; + } + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl new file mode 100644 index 00000000000..7d7b3475549 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -0,0 +1,149 @@ +@group(0) @binding(0) +var<storage, read_write> src_q: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> src_k: array<f32>; + +@group(0) @binding(2) +var<storage, read_write> src_v: array<f32>; + +@group(0) @binding(3) +var<storage, read_write> src_g: array<f32>; + +@group(0) @binding(4) +var<storage, read_write> src_beta: array<f32>; + +@group(0) @binding(5) +var<storage, read_write> src_state: array<f32>; + +@group(0) @binding(6) +var<storage, read_write> dst: array<f32>; + +struct Params { + h: u32, + n_tokens: u32, + n_seqs: u32, + s_off: u32, + + sq1: u32, + sq2: u32, + sq3: u32, + + sv1: u32, + sv2: u32, + sv3: u32, + + sb1: u32, + sb2: u32, + sb3: u32, + + neq1: u32, + rq3: u32, + K: u32, + scale: f32, +}; + +@group(0) @binding(7) +var<uniform> params: Params; + +var<workgroup> sh_k: array<f32, S_V>; +var<workgroup> sh_q: array<f32, S_V>; +var<workgroup> sh_g: array<f32, S_V>; + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(workgroup_id) workgroup_id: vec3<u32>, + @builtin(local_invocation_id) local_id: vec3<u32> +) { + let head_id = workgroup_id.x; + let seq_id = workgroup_id.y; + let col = local_id.x; + + let iq1 = head_id % params.neq1; + let iq3 = seq_id / params.rq3; + + let state_size = S_V * S_V; + // input state holds s0 only [S_v, S_v, H, n_seqs]: per-seq stride is H*D. + let state_in_base = (seq_id * params.h + head_id) * state_size; + let state_out_base = (seq_id * params.h + head_id) * state_size; + let state_size_per_snap = state_size * params.h * params.n_seqs; + + var state: array<f32, S_V>; + for (var i = 0u; i < S_V; i++) { + state[i] = src_state[state_in_base + col * S_V + i]; + } + + var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V; + + for (var t = 0u; t < params.n_tokens; t++) { + let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1; + let k_off = q_off; + let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1; + let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1; + + sh_q[col] = src_q[q_off + col]; + sh_k[col] = src_k[k_off + col]; + +#ifdef KDA + let g_base = gb_off * S_V; + sh_g[col] = exp(src_g[g_base + col]); +#endif + + workgroupBarrier(); + + let v_val = src_v[v_off + col]; + let beta_val = src_beta[gb_off]; + + var kv_col = 0.0; + var delta_col = 0.0; + var attn_col = 0.0; + +#ifdef KDA + for (var i = 0u; i < S_V; i++) { + kv_col += (sh_g[i] * state[i]) * sh_k[i]; + } + + delta_col = (v_val - kv_col) * beta_val; + + for (var i = 0u; i < S_V; i++) { + state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col; + attn_col += state[i] * sh_q[i]; + } +#else + let g_val = exp(src_g[gb_off]); + + for (var i = 0u; i < S_V; i++) { + kv_col += state[i] * sh_k[i]; + } + + delta_col = (v_val - g_val * kv_col) * beta_val; + + for (var i = 0u; i < S_V; i++) { + state[i] = g_val * state[i] + sh_k[i] * delta_col; + attn_col += state[i] * sh_q[i]; + } +#endif + + dst[attn_off + col] = attn_col * params.scale; + attn_off += S_V * params.h; + + if (params.K > 1u) { + // snapshot slot mapping: slot 0 = most recent state, slot s = s tokens back. + let target_slot = i32(params.n_tokens) - 1 - i32(t); + if (target_slot >= 0 && target_slot < i32(params.K)) { + let slot_base = params.s_off + u32(target_slot) * state_size_per_snap + state_out_base; + for (var i = 0u; i < S_V; i++) { + dst[slot_base + col * S_V + i] = state[i]; + } + } + } + + workgroupBarrier(); + } + + if (params.K == 1u) { + for (var i = 0u; i < S_V; i++) { + dst[params.s_off + state_out_base + col * S_V + i] = state[i]; + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl similarity index 69% rename from ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index f80ce1fc550..78d61a93d28 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -1,240 +1,70 @@ -#define(VARIANTS) - -[ - { - "SHADER_SUFFIX": "f32_vec", - "REPLS": { - "TYPE" : "vec4<f32>", - "DST_TYPE": "vec4<f32>", - "BLOCK_SIZE": 4 - }, - "DECLS": ["F32_VEC"] - }, - { - "REPLS": { - "TYPE" : "f32", - "DST_TYPE": "f32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F32"] - }, - { - "REPLS": { - "TYPE" : "f16", - "DST_TYPE": "f32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F16"] - }, - { - "REPLS": { - "TYPE" : "i32", - "DST_TYPE": "i32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["I32"] - }, - { - "REPLS": { - "TYPE" : "q4_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] - }, - { - "REPLS": { - "TYPE" : "q4_1", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] - }, - { - "REPLS": { - "TYPE" : "q5_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] - }, - { - "REPLS": { - "TYPE" : "q5_1", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] - }, - { - "REPLS": { - "TYPE" : "q8_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] - }, - { - "REPLS": { - "TYPE" : "q2_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] - }, - { - "REPLS": { - "TYPE" : "q3_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] - }, - { - "REPLS": { - "TYPE" : "q4_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] - }, - { - "REPLS": { - "TYPE" : "q5_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] - }, - { - "REPLS": { - "TYPE" : "q6_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] - }, - { - "REPLS": { - "TYPE" : "iq2_xxs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] - }, - { - "REPLS": { - "TYPE" : "iq2_xs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] - }, - { - "REPLS": { - "TYPE": "iq2_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] - }, - { - "REPLS": { - "TYPE": "iq3_xxs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] - }, - { - "REPLS": { - "TYPE": "iq3_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] - }, - { - "REPLS": { - "TYPE": "iq1_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] - }, - { - "REPLS": { - "TYPE": "iq1_m", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] - }, - { - "REPLS": { - "TYPE": "iq4_nl", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] - }, - { - "REPLS": { - "TYPE": "iq4_xs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(F32_VEC) +enable f16; +#define DECLARE_BYTE_LOADERS_SRC +#include "common_decls.tmpl" + + +#ifdef F32_VEC fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; } -#enddecl(F32_VEC) +#endif -#decl(F32) +#ifdef F32 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = src[src_base + offset]; } -#enddecl(F32) +#endif -#decl(F16) +#ifdef F16 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = f32(src[src_base + offset]); } -#enddecl(F16) +#endif -#decl(I32) +#ifdef I32 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = src[src_base + offset]; } -#enddecl(I32) +#endif -#decl(Q4_0) +#ifdef Q1_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q4_0 = src[src_base + offset]; - let d = f32(block_q4_0.d); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); + let block_byte_base = (src_base + offset) * 18; + let d = load_f16_as_f32_at_src(block_byte_base); + for (var j: u32 = 0u; j < 4u; j++) { + let q_packed = load_u32_at_src(block_byte_base + 2u + j * 4u); + let dst_base128 = dst_base + offset * 128u + j * 32u; + for (var k: u32 = 0; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + for (var bit: u32 = 0; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + dst[dst_base128 + k * 8u + bit] = w; + } + } + } +} +#endif + +#ifdef Q4_0 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at_src(block_byte_base); + for (var j: u32 = 0u; j < 4; j++) { + let q_byte_offset = block_byte_base + 2 + j * 4; + let q_packed = load_u32_at_src(q_byte_offset); for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; + let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; let dst_offset = dst_base + offset * 32 + j * 4 + k; dst[dst_offset] = q_lo; - dst[dst_offset + 16] = q_hi; + dst[dst_offset + 16u] = q_hi; } } } -#enddecl(Q4_0) +#endif -#decl(Q4_1) +#ifdef Q4_1 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q4_1 = src[src_base + offset]; let d = f32(block_q4_1.d); @@ -251,31 +81,35 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_1) +#endif -#decl(Q5_0) +#ifdef Q5_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q5_0 = src[src_base + offset]; - let d = f32(block_q5_0.d); - let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); + let block_byte_base = (src_base + offset) * 22; // Block stride: 22 bytes + let d = load_f16_as_f32_at_src(block_byte_base); + let qh_packed = load_u32_at_src(block_byte_base + 2); for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); + let q_byte_offset = block_byte_base + 6 + j * 4; + let q_packed = load_u32_at_src(q_byte_offset); + for (var k: u32 = 0; k < 4; k++) { let q_byte = get_byte(q_packed, k); + let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; dst[dst_offset] = q_lo; dst[dst_offset + 16] = q_hi; } } } +#endif -#enddecl(Q5_0) - -#decl(Q5_1) +#ifdef Q5_1 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q5_1 = src[src_base + offset]; let d = f32(block_q5_1.d); @@ -294,25 +128,26 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q5_1) +#endif -#decl(Q8_0) +#ifdef Q8_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block_q8_0 = src[src_base + offset]; - let d = f32(block_q8_0.d); - for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { + let block_byte_base = (src_base + offset) * 34; // Block stride: 34 bytes + let d = load_f16_as_f32_at_src(block_byte_base); + for (var j: u32 = 0u; j < 8u; j++) { + let q_byte_offset = block_byte_base + 2u + j * 4u; + let q_packed = load_u32_at_src(q_byte_offset); + for (var k: u32 = 0u; k < 4u; k++) { let q_byte = get_byte_i32(q_packed, k); let q_val = f32(q_byte) * d; - let dst_offset = dst_base + offset * 32 + j * 4 + k; + let dst_offset = dst_base + offset * 32u + j * 4u + k; dst[dst_offset] = q_val; } } } -#enddecl(Q8_0) +#endif -#decl(Q2_K) +#ifdef Q2_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -340,40 +175,46 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q2_K) +#endif -#decl(Q3_K) +#ifdef Q3_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes + + // Bytes 108-109: f16 scale 'd' + let d = load_f16_as_f32_at_src(block_byte_base + 108); - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes + // Bytes 96-107: 12 bytes of scales (3 u32s) let kmask1: u32 = 0x03030303; let kmask2: u32 = 0x0f0f0f0f; + var scale_vals: array<u32, 4>; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } + scale_vals[0] = load_u32_at_src(block_byte_base + 96); + scale_vals[1] = load_u32_at_src(block_byte_base + 100); + scale_vals[2] = load_u32_at_src(block_byte_base + 104); + var tmp: u32 = scale_vals[2]; scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - // convert arrays of f16 -> u32 + // Bytes 0-31: 32 bytes of hmask (8 u32s) var hmask_vals: array<u32, 8>; for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); + hmask_vals[i] = load_u32_at_src(block_byte_base + i * 4); } + + // Bytes 32-95: 64 bytes of qs (16 u32s) var qs_vals: array<u32, 16>; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1])); + for (var i: u32 = 0u; i < 16; i++) { + qs_vals[i] = load_u32_at_src(block_byte_base + 32 + i * 4); } var dst_i = dst_base + offset * 256; var is: u32 = 0; var m: u32 = 1; + // 2 halves of the block (128 elements each) for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { // 4 groups (each group has 2 blocks of 16 elements) @@ -383,11 +224,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sc = get_byte(scale_vals[is / 4], is % 4); is++; let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { + + for (var l: u32 = 0; l < 16; l++) { let q_idx = q_b_idx + k + l; let hm_idx = k + l; let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); + let hm = select(4.0, 0.0, (hmask_byte & m) != 0); let qs_val = (q_byte >> shift) & 3; dst[dst_i] = (f32(qs_val) - hm) * dl; @@ -398,9 +241,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q3_K) +#endif -#decl(Q4_K) +#ifdef Q4_K // 8 blocks of 32 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -425,9 +268,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_K) +#endif -#decl(Q5_K) +#ifdef Q5_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -455,26 +298,32 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q5_K) +#endif -#decl(Q6_K) +#ifdef Q6_K // 16 blocks of 16 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 210; // Block stride: 210 bytes - // convert arrays of f16 -> u32 + // Bytes 208-209: f16 scale 'd' + let d = load_f16_as_f32_at_src(block_byte_base + 208); + + // Bytes 0-127: 128 bytes of ql (32 u32s) var ql_vals: array<u32, 32>; for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1])); + ql_vals[i] = load_u32_at_src(block_byte_base + i * 4); } + + // Bytes 128-191: 64 bytes of qh (16 u32s) var qh_vals: array<u32, 16>; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1])); + for (var i: u32 = 0; i < 16u; i++) { + qh_vals[i] = load_u32_at_src(block_byte_base + 128 + i * 4u); } + + // Bytes 192-207: 16 bytes of scales (4 u32s) var scale_vals: array<u32, 4>; for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1])); + scale_vals[i] = load_u32_at_src(block_byte_base + 192 + i * 4); } var dst_i = dst_base + offset * 256; @@ -511,17 +360,18 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { sc_b_idx += 8; } } +#endif -#enddecl(Q6_K) - -#decl(IQ2_XXS) +#ifdef IQ2_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 66; // Block stride: 66 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3])); + let aux0_offset = block_byte_base + 2 + ib * 2; + let aux1_offset = block_byte_base + 2 + (ib + 2) * 2; + let aux0 = load_u32_at_src(aux0_offset); + let aux1 = load_u32_at_src(aux1_offset); let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; for (var l: u32 = 0; l < 4; l++) { let ig = get_byte(aux0, l) * 8; @@ -536,17 +386,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ2_XXS) +#endif + + -#decl(IQ2_XS) +#ifdef IQ2_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 74; // Block stride: 74 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; + var scale_vals = array<u32, 2>( - bitcast<u32>(vec2(block.scales[0], block.scales[1])), - bitcast<u32>(vec2(block.scales[2], block.scales[3])) + load_u32_at_src(block_byte_base + 66), + load_u32_at_src(block_byte_base + 70) ); + for (var ib: u32 = 0; ib < 32; ib += 4) { let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); let db = array<f32, 2>( @@ -554,7 +408,8 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { d * (0.5 + f32(s >> 4)) * 0.25 ); for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0)); + let qs_offset = block_byte_base + 2 + (ib + l) * 2; + let qs_val = load_u32_at_src(qs_offset) & 0xFFFF; let ig = (qs_val & 511) * 8; let is = qs_val >> 9; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); @@ -568,25 +423,27 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ2_XS) +#endif -#decl(IQ2_S) +#ifdef IQ2_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 82; // Block stride: 82 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; + var qs_vals : array<u32, 16>; for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs_vals[i] = load_u32_at_src(block_byte_base + 2 + i * 4); } - var qh_vals = array<u32, 2>( - bitcast<u32>(vec2(block.qh[0], block.qh[1])), - bitcast<u32>(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array<u32, 2>( - bitcast<u32>(vec2(block.scales[0], block.scales[1])), - bitcast<u32>(vec2(block.scales[2], block.scales[3])) - ); + + var qh_vals: array<u32, 2>; + qh_vals[0] = load_u32_at_src(block_byte_base + 66); + qh_vals[1] = load_u32_at_src(block_byte_base + 70); + + var scale_vals: array<u32, 2>; + scale_vals[0] = load_u32_at_src(block_byte_base + 74); + scale_vals[1] = load_u32_at_src(block_byte_base + 78); + for (var ib: u32 = 0; ib < 8; ib ++) { let s = get_byte(scale_vals[ib / 4], ib % 4); let db = array<f32, 2>( @@ -608,21 +465,21 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ2_S) - -#decl(IQ3_XSS) +#ifdef IQ3_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 98; // Block stride: 98 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33])); + let sc_sign_offset = block_byte_base + 2 + (ib + 32) * 2; + let sc_sign = load_u32_at_src(sc_sign_offset); let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; for (var l: u32 = 0; l < 4; l++) { let is = (sc_sign >> (7 * l)) & 127; let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0)); + let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 2 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0); let ig2 = get_byte(ig_val, 1); for (var j: u32 = 0; j < 4; j++) { @@ -638,22 +495,26 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ3_XSS) +#endif -#decl(IQ3_S) +#ifdef IQ3_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 110; // Block stride: 110 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; + var qh_vals = array<u32, 2>( - bitcast<u32>(vec2(block.qh[0], block.qh[1])), - bitcast<u32>(vec2(block.qh[2], block.qh[3])) + load_u32_at_src(block_byte_base + 66), + load_u32_at_src(block_byte_base + 70) ); + var sign_vals: array<u32, 8>; for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); + sign_vals[i] = load_u32_at_src(block_byte_base + 74 + i * 4); } - var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1])); + + var scale_vals = load_u32_at_src(block_byte_base + 106); + for (var ib: u32 = 0; ib < 4; ib++) { let s = get_byte(scale_vals, ib); let db = array<f32, 2>( @@ -666,7 +527,7 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let sign_w = sign_vals[ib * 2 + k]; for (var l: u32 = 0; l < 4; l++) { let signs = get_byte(sign_w, l); - let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); + let ig_val = load_u32_at_src(block_byte_base + 2 + (ib * 8 + k * 4 + l) * 2) & 0xFFFF; let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); for (var j: u32 = 0; j < 4; j++) { @@ -683,18 +544,18 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ3_S) +#endif -#decl(IQ1_S) +#ifdef IQ1_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 50; // Block stride: 50 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast<u32>(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); + let qh = load_u32_at_src(block_byte_base + 34 + ib * 2) & 0xFFFF; + let dl = d * (2.0 * f32((qh >> 12) & 7) + 1.0); let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); + let qs_w = load_u32_at_src(block_byte_base + 2 + ib * 4); for (var l: u32 = 0; l < 4; l++) { let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; for (var j: u32 = 0; j < 8; j++) { @@ -707,10 +568,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ1_S) - -#decl(IQ1_M) +#ifdef IQ1_M fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -751,17 +611,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ1_M) - -#decl(IQ4_NL) +#ifdef IQ4_NL fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { - let block = src[src_base + offset]; - let d = f32(block.d); + let block_byte_base = (src_base + offset) * 18; // Block stride: 18 bytes + let d = load_f16_as_f32_at_src(block_byte_base); var dst_i = dst_base + offset * 32; var qs: array<u32, 4>; for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); + qs[i] = load_u32_at_src(block_byte_base + 2 + i * 4); } for (var j: u32 = 0; j < 16; j++) { let qsb = get_byte(qs[j / 4], j % 4); @@ -770,13 +629,13 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst_i++; } } -#enddecl(IQ4_NL) +#endif -#decl(IQ4_XS) +#ifdef IQ4_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; - let d = f32(block.d); - let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0)); + let d = unpack2x16float(block.d_scales_h)[0]; + let scales_h = block.d_scales_h >> 16; var dst_i = dst_base + offset * 256; for (var ib: u32 = 0; ib < 8; ib++) { let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); @@ -791,24 +650,37 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst_i += 16; } } -#enddecl(IQ4_XS) - -#end(DECLS) +#endif -#define(SHADER) - -enable f16; +#ifdef MXFP4 +fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { + let block_byte_base = (src_base + offset) * 17; + let eu8 = get_byte(load_u32_at_src(block_byte_base), 0); + let d = ldexp(1.0, i32(eu8) - 128); + for (var j: u32 = 0u; j < 4; j++) { + let q_byte_offset = block_byte_base + 1 + j * 4; + let q_packed = load_u32_at_src(q_byte_offset); + for (var k: u32 = 0; k < 4; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * d; + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * d; + let dst_offset = dst_base + offset * 32 + j * 4 + k; + dst[dst_offset] = q_lo; + dst[dst_offset + 16u] = q_hi; + } + } +} +#endif -DECLS @group(0) @binding(0) -var<storage, read_write> src: array<{{TYPE}}>; +var<storage, read_write> src: array<SRC_TYPE>; @group(0) @binding(1) var<storage, read_write> idx: array<i32>; @group(0) @binding(2) -var<storage, read_write> dst: array<{{DST_TYPE}}>; +var<storage, read_write> dst: array<DST_TYPE>; struct Params { offset_src: u32, // in elements @@ -842,9 +714,37 @@ struct Params { @group(0) @binding(3) var<uniform> params: Params; -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { +#ifdef FLOAT_PARALLEL + let blocks_per_row = params.ne0 / BLOCK_SIZE; + let row_count = params.n_rows * params.ne2 * params.ne3; + + if (gid.x >= blocks_per_row * row_count) { + return; + } + + let block_idx = gid.x % blocks_per_row; + var row_idx = gid.x / blocks_per_row; + let i_dst3 = row_idx / (params.ne2 * params.n_rows); + + row_idx = row_idx % (params.ne2 * params.n_rows); + let i_dst2 = row_idx / params.n_rows; + let i_dst1 = row_idx % params.n_rows; + + let i_idx2 = i_dst3 % params.idx2; + let i_idx1 = i_dst2 % params.idx1; + let i_idx0 = i_dst1; + + let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + + let idx_val = u32(idx[i_idx]); + + let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; + let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; + + copy_elements(i_src_row, i_dst_row, block_idx); +#else if (gid.x >= params.n_rows * params.ne2 * params.ne3) { return; } @@ -866,9 +766,8 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; - for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) { + for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) { copy_elements(i_src_row, i_dst_row, i); } +#endif } - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl deleted file mode 100644 index 03fcd548689..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +++ /dev/null @@ -1,323 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "reglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "reglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "REGLU"] - }, - { - "SHADER_NAME": "geglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "geglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU"] - }, - { - "SHADER_NAME": "swiglu_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "SWIGLU"] - }, - { - "SHADER_NAME": "swiglu_oai_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "swiglu_oai_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "SWIGLU_OAI"] - }, - { - "SHADER_NAME": "geglu_erf_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_erf_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_ERF"] - }, - { - "SHADER_NAME": "geglu_quick_f32", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f32_split", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] - }, - { - "SHADER_NAME": "geglu_quick_f16_split", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["SPLIT", "GEGLU_QUICK"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(REGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return max(a, 0) * b; -} -#enddecl(REGLU) - -#decl(GEGLU) -const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876; -const GELU_COEF_A: {{TYPE}} = 0.044715; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); - return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b; -} -#enddecl(GEGLU) - -#decl(SWIGLU) -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a / (1.0 + exp(-a)) * b; -} -#enddecl(SWIGLU) - -#decl(SWIGLU_OAI) -fn op(a: f32, b: f32) -> f32 { - let xi = min(a, params.limit); - let gi = max(min(b, params.limit), -params.limit); - var out_glu = xi / (1.0 + exp(-xi * params.alpha)); - out_glu = out_glu * (1.0 + gi); - return out_glu; -} -#enddecl(SWIGLU_OAI) - -#decl(GEGLU_ERF) -const p_erf: {{TYPE}} = 0.3275911; -const a1_erf: {{TYPE}} = 0.254829592; -const a2_erf: {{TYPE}} = -0.284496736; -const a3_erf: {{TYPE}} = 1.421413741; -const a4_erf: {{TYPE}} = -1.453152027; -const a5_erf: {{TYPE}} = 1.061405429; -const SQRT_2_INV: {{TYPE}} = 0.7071067811865476; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - let a_div_sqr2 = a * SQRT_2_INV; - let sign_x = sign(a_div_sqr2); - let x = abs(a_div_sqr2); - let t = 1.0 / (1.0 + p_erf * x); - let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); - let erf_approx = sign_x * y; - return 0.5 * a * (1.0 + erf_approx) * b; -} -#enddecl(GEGLU_ERF) - -#decl(GEGLU_QUICK) -const GELU_QUICK_COEF: {{TYPE}} = -1.702; - -fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { - return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; -} -#enddecl(GEGLU_QUICK) - -#decl(NO_SPLIT) -@group(0) @binding(1) -var<storage, read_write> dst: array<{{TYPE}}>; - -@group(0) @binding(2) -var<uniform> params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(0, params.ne0, params.swapped != 0); - return src0[base + offset]; -} - -fn b_value(base: u32) -> {{TYPE}} { - let offset: u32 = select(params.ne0, 0, params.swapped != 0); - return src0[base + offset]; -} -#enddecl(NO_SPLIT) - -#decl(SPLIT) -@group(0) @binding(1) -var<storage, read_write> src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var<storage, read_write> dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var<uniform> params: Params; - -fn a_value(base: u32) -> {{TYPE}} { - return src0[base]; -} - -fn b_value(base: u32) -> {{TYPE}} { - return src1[base]; -} -#enddecl(SPLIT) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - swapped: u32, - alpha: f32, - limit: f32, -} - -@group(0) @binding(0) -var<storage, read_write> src0: array<{{TYPE}}>; - -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; - let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; - let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; - - dst[i_dst] = op(a_value(i_a), b_value(i_b)); -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl new file mode 100644 index 00000000000..e6d7608cec5 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl @@ -0,0 +1,155 @@ +enable f16; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + +#ifdef OP_REGLU +fn op(a: DataType, b: DataType) -> DataType { + return max(a, 0) * b; +} +#endif + +#ifdef OP_GEGLU +const SQRT_2_OVER_PI: DataType = 0.79788456080286535587989211986876; +const GELU_COEF_A: DataType = 0.044715; + +fn op(a: DataType, b: DataType) -> DataType { + let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); + return 0.5 * a * (2.0 - 2.0/ (exp(2* val) + 1)) * b; +} +#endif + +#ifdef OP_SWIGLU +fn op(a: DataType, b: DataType) -> DataType { + return a / (1.0 + exp(-a)) * b; +} +#endif +#ifdef OP_SWIGLU_OAI +fn op(a: f32, b: f32) -> f32 { + let xi = min(a, params.limit); + let gi = max(min(b, params.limit), -params.limit); + var out_glu = xi / (1.0 + exp(-xi * params.alpha)); + out_glu = out_glu * (1.0 + gi); + return out_glu; +} +#endif +#ifdef OP_GEGLU_ERF +const p_erf: DataType = 0.3275911; +const a1_erf: DataType = 0.254829592; +const a2_erf: DataType = -0.284496736; +const a3_erf: DataType = 1.421413741; +const a4_erf: DataType = -1.453152027; +const a5_erf: DataType = 1.061405429; +const SQRT_2_INV: DataType = 0.7071067811865476; + +fn op(a: DataType, b: DataType) -> DataType { + let a_div_sqr2 = a * SQRT_2_INV; + let sign_x = sign(a_div_sqr2); + let x = abs(a_div_sqr2); + let t = 1.0 / (1.0 + p_erf * x); + let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); + let erf_approx = sign_x * y; + return 0.5 * a * (1.0 + erf_approx) * b; +} +#endif +#ifdef OP_GEGLU_QUICK +const GELU_QUICK_COEF: DataType = -1.702; + +fn op(a: DataType, b: DataType) -> DataType { + return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; +} +#endif + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + swapped: u32, + alpha: f32, + limit: f32, +} + +@group(0) @binding(0) +var<storage, read_write> src0: array<DataType>; + +#ifdef NO_SPLIT +@group(0) @binding(1) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(2) +var<uniform> params: Params; + +fn a_value(base: u32) -> DataType { + let offset: u32 = select(0, params.ne0, params.swapped != 0); + return src0[base + offset]; +} + +fn b_value(base: u32) -> DataType { + let offset: u32 = select(params.ne0, 0, params.swapped != 0); + return src0[base + offset]; +} + +#else +@group(0) @binding(1) +var<storage, read_write> src1: array<DataType>; + +@group(0) @binding(2) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(3) +var<uniform> params: Params; + +fn a_value(base: u32) -> DataType { + return src0[base]; +} + +fn b_value(base: u32) -> DataType { + return src1[base]; +} + +#endif + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; + let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + dst[i_dst] = op(a_value(i_a), b_value(i_b)); +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl new file mode 100644 index 00000000000..386ebab879f --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl @@ -0,0 +1,101 @@ +#include "common_decls.tmpl" +enable f16; + +@group(0) @binding(0) +#if defined(INPUT_F32) +var<storage, read_write> input: array<f32>; +#elif defined(INPUT_F16) +var<storage, read_write> input: array<f16>; +#endif + +@group(0) @binding(1) +#if defined(OUTPUT_F32) +var<storage, read_write> output: array<f32>; +#elif defined(OUTPUT_F16) +var<storage, read_write> output: array<f16>; +#endif + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + KW: u32, KH: u32, IC: u32, + IW: u32, IH: u32, N: u32, + OW: u32, OH: u32, + + // stride + s0: u32, s1: u32, + // padding + p0: u32, p1: u32, + // dilation + d0: u32, d1: u32, +} + +@group(0) @binding(2) +var<uniform> params: Params; + +fn load_input(idx: u32) -> f32 { + #if defined(INPUT_F32) + return input[idx]; + #elif defined(INPUT_F16) + return f32(input[idx]); + #endif +} + +fn store_output(idx: u32, val: f32) { + #if defined(OUTPUT_F32) + output[idx] = val; + #elif defined(OUTPUT_F16) + output[idx] = f16(val); + #endif +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32> +) { + + let threads_per_group = u32(WG_SIZE); + let i_out = gid.x + (num_wg.x * threads_per_group) * gid.y; + let K = params.KW * params.KH * params.IC; + let M = params.OW * params.OH; + let total = K * M * params.N; + + if (i_out >= total) { + return; + } + + // decode (k, m, n) + var i = i_out; + let n = i / (K * M); + i = i % (K * M); + let m = i / K; + let k = i % K; + + // decode (oh, ow) + let oh = m / params.OW; + let ow = m % params.OW; + + // decode (kw, kh, ic) + let kw = k % params.KW; + let tmp = k / params.KW; + let kh = tmp % params.KH; + let ic = tmp / params.KH; + + let iw_i32 = i32(ow * params.s0 + kw * params.d0) - i32(params.p0); + let ih_i32 = i32(oh * params.s1 + kh * params.d1) - i32(params.p1); + + if (iw_i32 >= 0 && iw_i32 < i32(params.IW) && ih_i32 >= 0 && ih_i32 < i32(params.IH)) { + let iw = u32(iw_i32); + let ih = u32(ih_i32); + let in_idx = params.offset_i + iw * params.si0 + ih * params.si1 + ic * params.si2 + n * params.si3; + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, load_input(in_idx)); + } else { + store_output(params.offset_o + k * params.so0 + ow * params.so1 + oh * params.so2 + n * params.so3, 0.0); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl deleted file mode 100644 index 0f8e6e5ac3d..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ /dev/null @@ -1,907 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_1", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_1", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] - }, - { - "REPLS": { - "SRC0_TYPE": "q8_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q2_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q3_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q6_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_xxs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_xs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq3_xxs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq3_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq1_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq1_m", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq4_nl", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq4_xs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(FLOAT) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); -} -#enddecl(FLOAT) - -#decl(Q4_0) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_0 = src0[src0_idx_base + offset]; - let d = f32(block_q4_0.d); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast<u32>(vec2(block_q4_0.qs[2 * j], block_q4_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0f) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0f) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#enddecl(Q4_0) - -#decl(Q4_1) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q4_1 = src0[src0_idx_base + offset]; - let d = f32(block_q4_1.d); - let m = f32(block_q4_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q4_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = f32((q_byte >> 4) & 0xF) * d + m; - let q_lo = f32(q_byte & 0xF) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#enddecl(Q4_1) - -#decl(Q5_0) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_0 = src0[src0_idx_base + offset]; - let d = f32(block_q5_0.d); - var sum: f32 = 0.0; - let qh_packed = bitcast<u32>(vec2(block_q5_0.qh[0], block_q5_0.qh[1])); - for (var j: u32 = 0; j < 4; j++) { - let q_packed = bitcast<u32>(vec2(block_q5_0.qs[2 * j], block_q5_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (qh_packed >> (j * 4 + k + 12)) & 0x10; - let q_hi = (f32(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; - let qh_lo = ((qh_packed >> (j * 4 + k)) << 4) & 0x10; - let q_lo = (f32((q_byte & 0xF) | qh_lo) - 16.0) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#enddecl(Q5_0) - -#decl(Q5_1) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q5_1 = src0[src0_idx_base + offset]; - let d = f32(block_q5_1.d); - let m = f32(block_q5_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 4; j++) { - let q_packed = block_q5_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let qh_hi = (block_q5_1.qh >> (j * 4 + k + 12)) & 0x10; - let q_hi = f32(((q_byte >> 4) & 0xF) | qh_hi) * d + m; - let qh_lo = ((block_q5_1.qh >> (j * 4 + k)) << 4) & 0x10; - let q_lo = f32((q_byte & 0xF) | qh_lo) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_lo * f32(src1[src1_offset]); - sum += q_hi * f32(src1[src1_offset + 16]); - } - } - return sum; -} -#enddecl(Q5_1) - -#decl(Q8_0) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_0 = src0[src0_idx_base + offset]; - let d = f32(block_q8_0.d); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_packed = bitcast<u32>(vec2(block_q8_0.qs[2 * j], block_q8_0.qs[2 * j + 1])); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#enddecl(Q8_0) - -#decl(Q8_1) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block_q8_1 = src0[src0_idx_base + offset]; - let d = f32(block_q8_1.d); - let m = f32(block_q8_1.m); - var sum: f32 = 0.0; - for (var j: u32 = 0; j < 8; j++) { - let q_packed = block_q8_1.qs[j]; - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte_i32(q_packed, k); - let q_val = f32(q_byte) * d + m; - let src1_offset = src1_idx_base + offset * 32 + j * 4 + k; - sum += q_val * f32(src1[src1_offset]); - } - } - return sum; -} -#enddecl(Q8_1) - -#decl(Q2_K) -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(block.scales[is / 4], is % 4); - is++; - let dl = d * f32(sc & 0xF); - let ml = m * f32(sc >> 4); - for (var l: u32 = 0u; l < 16; l++) { - let q_idx = q_b_idx + k + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 3; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - } - return sum; -} - -#enddecl(Q2_K) - -#decl(Q3_K) -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - - // extract 6-bit scales, which consist of 4-bits from first 8 bytes of scale, - // and 2-bits from the last 4 bytes - let kmask1: u32 = 0x03030303; - let kmask2: u32 = 0x0f0f0f0f; - var scale_vals: array<u32, 4>; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } - var tmp: u32 = scale_vals[2]; - scale_vals[2] = ((scale_vals[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4); - scale_vals[3] = ((scale_vals[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4); - scale_vals[0] = (scale_vals[0] & kmask2) | ((tmp & kmask1) << 4); - scale_vals[1] = (scale_vals[1] & kmask2) | (((tmp >> 2) & kmask1) << 4); - - // convert arrays of f16 -> u32 - var hmask_vals: array<u32, 8>; - for (var i: u32 = 0; i < 8; i++) { - hmask_vals[i] = bitcast<u32>(vec2(block.hmask[2 * i], block.hmask[2 * i + 1])); - } - var qs_vals: array<u32, 16>; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast<u32>(vec2(block.qs[2 * i], block.qs[2 * i + 1])); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var m: u32 = 1; - // 2 halves of the block (128 elements each) - for (var q_b_idx: u32 = 0; q_b_idx < 64; q_b_idx += 32) { - // 4 groups (each group has 2 blocks of 16 elements) - for (var shift: u32 = 0; shift < 8; shift += 2) { - // 2 blocks - for (var k: u32 = 0; k < 32; k += 16) { - let sc = get_byte(scale_vals[is / 4], is % 4); - is++; - let dl = d * (f32(sc) - 32.0); - for (var l: u32 = 0u; l < 16u; l++) { - let q_idx = q_b_idx + k + l; - let hm_idx = k + l; - let q_byte = get_byte(qs_vals[q_idx / 4], q_idx % 4); - let hmask_byte = get_byte(hmask_vals[hm_idx / 4], hm_idx % 4); - let hm = select(4.0, 0.0, (hmask_byte & m) != 0); - let qs_val = (q_byte >> shift) & 3; - sum += ((f32(qs_val) - hm) * dl) * src1[src1_i]; - src1_i++; - } - } - m <<= 1; - } - } - return sum; -} - -#enddecl(Q3_K) - -#decl(Q4_K) -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qs_val = (q_byte >> shift) & 0xF; - sum += (f32(qs_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - -#enddecl(Q4_K) - -#decl(Q5_K) -// 8 blocks of 32 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let m = f32(block.dmin); - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var is: u32 = 0; - var u: u32 = 1; - // 2 blocks each iteration - for (var q_b_idx: u32 = 0; q_b_idx < 128; q_b_idx += 32) { - for (var shift: u32 = 0; shift < 8; shift += 4) { - let scale_min = get_scale_min(is, block.scales); - is++; - let dl = d * scale_min.x; - let ml = m * scale_min.y; - for (var l: u32 = 0; l < 32; l++) { - let q_idx = q_b_idx + l; - let q_byte = get_byte(block.qs[q_idx / 4], q_idx % 4); - let qh_byte = get_byte(block.qh[l / 4], l % 4); - let qs_val = (q_byte >> shift) & 0xF; - let qh_val = select(0.0, 16.0, (qh_byte & u) != 0); - sum += ((f32(qs_val) + qh_val) * dl - ml) * src1[src1_i]; - src1_i++; - } - u <<= 1; - } - } - return sum; -} - -#enddecl(Q5_K) - -#decl(Q6_K) -// 16 blocks of 16 elements each -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - - // convert arrays of f16 -> u32 - var ql_vals: array<u32, 32>; - for (var i: u32 = 0; i < 32; i++) { - ql_vals[i] = bitcast<u32>(vec2(block.ql[2 * i], block.ql[2 * i + 1])); - } - var qh_vals: array<u32, 16>; - for (var i: u32 = 0; i < 16; i++) { - qh_vals[i] = bitcast<u32>(vec2(block.qh[2 * i], block.qh[2 * i + 1])); - } - var scale_vals: array<u32, 4>; - for (var i: u32 = 0; i < 4; i++) { - scale_vals[i] = bitcast<u32>(vec2(block.scales[2 * i], block.scales[2 * i + 1])); - } - - var sum = 0.0; - var src1_i = src1_idx_base + offset * 256; - var qh_b_idx: u32 = 0; - var sc_b_idx: u32 = 0; - for (var ql_b_idx: u32 = 0; ql_b_idx < 128; ql_b_idx += 64) { - for (var l: u32 = 0; l < 32; l++) { - let ql13_b = get_byte(ql_vals[(ql_b_idx + l) / 4], (ql_b_idx + l) % 4); - let ql24_b = get_byte(ql_vals[(ql_b_idx + l + 32) / 4], (ql_b_idx + l + 32) % 4); - let qh_b = get_byte(qh_vals[(qh_b_idx + l) / 4], (qh_b_idx + l) % 4); - - let q1 = f32((ql13_b & 0xF) | ((qh_b & 3) << 4)) - 32.0; - let q2 = f32((ql24_b & 0xF) | (((qh_b >> 2) & 3) << 4)) - 32.0; - let q3 = f32((ql13_b >> 4) | (((qh_b >> 4) & 3) << 4)) - 32.0; - let q4 = f32((ql24_b >> 4) | (((qh_b >> 6) & 3) << 4)) - 32.0; - - let is = l/16; - let is1 = sc_b_idx + is; - let sc1 = get_byte_i32(scale_vals[is1 / 4], is1 % 4); - let is2 = sc_b_idx + is + 2; - let sc2 = get_byte_i32(scale_vals[is2 / 4], is2 % 4); - let is3 = sc_b_idx + is + 4; - let sc3 = get_byte_i32(scale_vals[is3 / 4], is3 % 4); - let is4 = sc_b_idx + is + 6; - let sc4 = get_byte_i32(scale_vals[is4 / 4], is4 % 4); - - sum += d * f32(sc1) * q1 * src1[src1_i + l]; - sum += d * f32(sc2) * q2 * src1[src1_i + l + 32]; - sum += d * f32(sc3) * q3 * src1[src1_i + l + 64]; - sum += d * f32(sc4) * q4 * src1[src1_i + l + 96]; - } - src1_i += 128; - qh_b_idx += 32; - sc_b_idx += 8; - } - return sum; -} - -#enddecl(Q6_K) - -#decl(IQ2_XXS) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let aux0 = bitcast<u32>(vec2(block.qs[ib], block.qs[ib + 1])); - let aux1 = bitcast<u32>(vec2(block.qs[ib + 2], block.qs[ib + 3])); - let db = d * (0.5 + f32(aux1 >> 28)) * 0.25; - for (var l: u32 = 0; l < 4; l++) { - let ig = get_byte(aux0, l) * 8; - let is = (aux1 >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xxs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += db * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - -#enddecl(IQ2_XXS) - -#decl(IQ2_XS) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var scale_vals = array<u32, 2>( - bitcast<u32>(vec2(block.scales[0], block.scales[1])), - bitcast<u32>(vec2(block.scales[2], block.scales[3])) - ); - var sum = 0.0; - for (var ib: u32 = 0; ib < 32; ib += 4) { - let s = get_byte(scale_vals[ib / 16], (ib % 16) / 4); - let db = array<f32, 2>( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - for (var l: u32 = 0; l < 4; l++) { - let qs_val = bitcast<u32>(vec2(block.qs[ib + l], 0.0)); - let ig = (qs_val & 511) * 8; - let is = qs_val >> 9; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2xs_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - -#enddecl(IQ2_XS) - -#decl(IQ2_S) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var qs_vals : array<u32, 16>; - for (var i: u32 = 0; i < 16; i++) { - qs_vals[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); - } - var qh_vals = array<u32, 2>( - bitcast<u32>(vec2(block.qh[0], block.qh[1])), - bitcast<u32>(vec2(block.qh[2], block.qh[3])) - ); - var scale_vals = array<u32, 2>( - bitcast<u32>(vec2(block.scales[0], block.scales[1])), - bitcast<u32>(vec2(block.scales[2], block.scales[3])) - ); - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib ++) { - let s = get_byte(scale_vals[ib / 4], ib % 4); - let db = array<f32, 2>( - d * (0.5 + f32(s & 0xF)) * 0.25, - d * (0.5 + f32(s >> 4)) * 0.25 - ); - let qs_w = qs_vals[ib]; - for (var l: u32 = 0; l < 4; l++) { - let qh_b = (get_byte(qh_vals[ib / 4], ib % 4) << (8 - 2 * l)) & 0x300; - let ig = (get_byte(qs_w, l) | qh_b) * 8; - let signs = get_byte(qs_vals[ib + 8], l); - let dl = db[l/2]; - for (var j: u32 = 0; j < 8; j++) { - let g = get_byte(iq2s_grid[(ig + j) / 4], (ig + j) % 4); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4], j % 4) & signs) != 0); - sum += dl * f32(g) * m * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - - -#enddecl(IQ2_S) - -#decl(IQ3_XSS) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 16; ib += 2) { - let sc_sign = bitcast<u32>(vec2(block.qs[ib + 32], block.qs[ib + 33])); - let db = d * (0.5 + f32(sc_sign >> 28)) * 0.5; - for (var l: u32 = 0; l < 4; l++) { - let is = (sc_sign >> (7 * l)) & 127; - let signs = get_byte(ksigns_iq2xs[is / 4], is % 4); - let ig_val = bitcast<u32>(vec2(block.qs[ib * 2 + l], 0.0)); - let ig1 = get_byte(ig_val, 0); - let ig2 = get_byte(ig_val, 1); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3xxs_grid[ig1], j); - let g2 = get_byte(iq3xxs_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += db * f32(g1) * m1 * src1[src1_i]; - sum += db * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - return sum; -} - -#enddecl(IQ3_XSS) - -#decl(IQ3_S) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var qh_vals = array<u32, 2>( - bitcast<u32>(vec2(block.qh[0], block.qh[1])), - bitcast<u32>(vec2(block.qh[2], block.qh[3])) - ); - var sign_vals: array<u32, 8>; - for (var i: u32 = 0; i < 8; i++) { - sign_vals[i] = bitcast<u32>(vec2(block.signs[i * 2], block.signs[i * 2 + 1])); - } - var scale_vals = bitcast<u32>(vec2(block.scales[0], block.scales[1])); - var sum = 0.0; - for (var ib: u32 = 0; ib < 4; ib++) { - let s = get_byte(scale_vals, ib); - let db = array<f32, 2>( - d * (1.0 + 2.0 * f32(s & 0xF)), - d * (1.0 + 2.0 * f32(s >> 4)) - ); - for (var k: u32 = 0; k < 2; k++) { - let dl = db[k]; - let qh_byte = get_byte(qh_vals[ib / 2], (ib % 2) * 2 + k); - let sign_w = sign_vals[ib * 2 + k]; - for (var l: u32 = 0; l < 4; l++) { - let signs = get_byte(sign_w, l); - let ig_val = bitcast<u32>(vec2(block.qs[ib * 8 + k * 4 + l], 0.0)); - let ig1 = get_byte(ig_val, 0) | ((qh_byte << ((8 - (2 * l)))) & 256); - let ig2 = get_byte(ig_val, 1) | ((qh_byte << ((7 - (2 * l)))) & 256); - for (var j: u32 = 0; j < 4; j++) { - let g1 = get_byte(iq3s_grid[ig1], j); - let g2 = get_byte(iq3s_grid[ig2], j); - let m1 = select(1.0, -1.0, (get_byte(kmask_iq2xs[0], j) & signs) != 0); - let m2 = select(1.0, -1.0, (get_byte(kmask_iq2xs[1], j) & signs) != 0); - sum += dl * f32(g1) * m1 * src1[src1_i]; - sum += dl * f32(g2) * m2 * src1[src1_i + 4]; - src1_i++; - } - src1_i += 4; - } - } - } - return sum; -} -#enddecl(IQ3_S) - -#decl(IQ1_S) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let qh = bitcast<u32>(vec2(block.qh[ib], 0.0)); - let dl = d * (2 * f32((qh >> 12) & 7) + 1); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000) != 0); - let qs_w = bitcast<u32>(vec2(block.qs[ib * 2], block.qs[ib * 2 + 1])); - for (var l: u32 = 0; l < 4; l++) { - let ig = (get_byte(qs_w, l) | (((qh >> (3 * l)) & 7) << 8)) * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast<i32>(g << 30) >> 30; - sum += dl * (f32(gs) + delta) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - -#enddecl(IQ1_S) - -#decl(IQ1_M) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - - let scale = ((block.scales[0] >> 12) & 0xF) | ((block.scales[0] >> 24) & 0x00F0) | ((block.scales[1] >> 4) & 0x0F00) | ((block.scales[1] >> 16) & 0xF000); - let d = f32(bitcast<vec2<f16>>(scale).x); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let sw = (block.scales[ib / 4] >> (16 * ((ib / 2) % 2))) & 0xFFFF; - let s1 : u32 = (sw >> (6 * (ib % 2))) & 0x7; - let s2 : u32 = (sw >> (6 * (ib % 2) + 3)) & 0x7; - var dl = array<f32, 2>( - d * f32(2 * s1 + 1), - d * f32(2 * s2 + 1) - ); - - let qh = block.qh[ib / 2] >> (16 * (ib % 2)); - var idx = array<u32, 4>( - get_byte(block.qs[ib], 0) | ((qh << 8) & 0x700), - get_byte(block.qs[ib], 1) | ((qh << 4) & 0x700), - get_byte(block.qs[ib], 2) | ((qh) & 0x700), - get_byte(block.qs[ib], 3) | ((qh >> 4) & 0x700) - ); - var delta = array<f32, 4>( - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x80) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x08) != 0), - select(IQ1_DELTA, -IQ1_DELTA, ((qh >> 8) & 0x80) != 0) - ); - for (var l: u32 = 0; l < 4; l++) { - let ig = idx[l] * 8; - for (var j: u32 = 0; j < 8; j++) { - let gw = iq1_grid[(ig + j) / 16]; - let g = (gw >> (((ig + j) % 16) * 2)) & 3; - let gs = bitcast<i32>(g << 30) >> 30; - sum += dl[l/2] * (f32(gs) + delta[l]) * src1[src1_i]; - src1_i++; - } - } - } - return sum; -} - -#enddecl(IQ1_M) - -#decl(IQ4_NL) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - var src1_i = src1_idx_base + offset * 32; - var sum = 0.0; - var qs: array<u32, 4>; - for (var i: u32 = 0; i < 4; i++) { - qs[i] = bitcast<u32>(vec2(block.qs[i * 2], block.qs[i * 2 + 1])); - } - for (var j: u32 = 0; j < 16; j++) { - let qsb = get_byte(qs[j / 4], j % 4); - sum += d * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += d * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - return sum; -} - -#enddecl(IQ4_NL) - -#decl(IQ4_XS) -fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { - let block = src0[src0_idx_base + offset]; - let d = f32(block.d); - let scales_h = bitcast<u32>(vec2(block.scales_h, 0.0)); - var src1_i = src1_idx_base + offset * 256; - var sum = 0.0; - for (var ib: u32 = 0; ib < 8; ib++) { - let ls = ((get_byte(block.scales_l, ib / 2) >> (4 * (ib % 2))) & 0xF) | (((scales_h >> (2 * ib)) & 3) << 4); - let dl = d * (f32(ls) - 32.0); - for (var j: u32 = 0; j < 16; j++) { - let iqs = ib * 16 + j; - let qsb = get_byte(block.qs[iqs / 4], iqs % 4); - sum += dl * f32(kvalues_iq4nl[qsb & 0xF]) * src1[src1_i]; - sum += dl * f32(kvalues_iq4nl[qsb >> 4]) * src1[src1_i + 16]; - src1_i++; - } - src1_i += 16; - } - return sum; -} - -#enddecl(IQ4_XS) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -DECLS - -struct MulMatParams { - offset_src0: u32, // in elements/blocks - offset_src1: u32, // in elements/blocks - offset_dst: u32, // in elements/blocks - m: u32, - n: u32, - k: u32, - // all strides are in elements/blocks - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var<storage, read_write> dst: array<f32>; // M rows, N columns - -@group(0) @binding(3) var<uniform> params: MulMatParams; - -@compute @workgroup_size(256) -fn main(@builtin(global_invocation_id) global_id: vec3<u32>) { - let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_id.x >= total) { - return; - } - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = global_id.x / dst3_stride; - let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension - let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_id.x % dst3_stride; - - let dst2_idx = dst3_rem / dst2_stride; - let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension - let src12_idx = dst2_idx; // src1 is not broadcast - - let dst2_rem = dst3_rem % dst2_stride; - - let row = dst2_rem / params.m; // output row - let col = dst2_rem % params.m; // output column - - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + col * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; - - var sum = 0.0; - for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { - sum += multiply_add(src0_idx_base, src1_idx_base, i); - } - dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 109ff8d6159..ed4a6b13bbf 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -1,97 +1,926 @@ -#decl(SHMEM_VEC) +#ifdef VEC +#define VEC_SIZE 4 +#define SHMEM_TYPE vec4<f16> +#define DST_TYPE vec4<f32> +#define SRC0_TYPE vec4<SRC0_INNER_TYPE> +#define SRC1_TYPE vec4<SRC1_INNER_TYPE> + fn store_shmem(val: vec4<f16>, idx: u32) { shmem[idx] = val.x; shmem[idx + 1] = val.y; shmem[idx + 2] = val.z; shmem[idx + 3] = val.w; } -#enddecl(SHMEM_VEC) +#endif // VEC + +#ifdef SCALAR +#define VEC_SIZE 1 +#define SHMEM_TYPE f16 +#define DST_TYPE f32 +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE -#decl(SHMEM_SCALAR) fn store_shmem(val: f16, idx: u32) { shmem[idx] = val; } -#enddecl(SHMEM_SCALAR) +#endif // SCALAR -#decl(INIT_SRC0_SHMEM_FLOAT) +#define QUANT_SHMEM shmem +#define QUANT_OUT_TYPE f16 +#include "quant_inner_loops.tmpl" +#ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_m = offset_m + tile_m; let global_k = k_outer + tile_k; let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let src0_val = select( // taking a slight performance hit to avoid oob - {{SRC0_TYPE}}(0.0), - src0[src0_idx/{{VEC_SIZE}}], + SRC0_TYPE(0.0), + src0[src0_idx/VEC_SIZE], global_m < params.m && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); + store_shmem(SHMEM_TYPE(src0_val), elem_idx); } } +#endif // INIT_SRC0_SHMEM_FLOAT -#enddecl(INIT_SRC0_SHMEM_FLOAT) - -#decl(INIT_SRC1_SHMEM) - +#ifndef MUL_MAT_ID +#ifdef INIT_SRC1_SHMEM_FLOAT fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let tile_n = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_n = offset_n + tile_n; let global_k = k_outer + tile_k; let src1_idx = batch_offset + global_n * params.stride_11 + global_k; let src1_val = select( - {{SRC1_TYPE}}(0.0), - src1[src1_idx/{{VEC_SIZE}}], + SRC1_TYPE(0.0), + src1[src1_idx/VEC_SIZE], global_n < params.n && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); + store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); } } +#endif // INIT_SRC1_SHMEM_FLOAT +#endif -#enddecl(INIT_SRC1_SHMEM) +#ifdef INIT_SRC0_SHMEM_Q1_0 +const BLOCK_SIZE = 128u; +const BLOCK_SIZE_BYTES = 18u; +const NQ = 8u; // 8 weights (1 byte of qs) per thread per iteration -#decl(INIT_SRC0_SHMEM_Q4_0) +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { + let tile_m = i / TILE_K; + let tile_k_start = i % TILE_K; + let global_m = offset_m + tile_m; + let global_k_start = k_outer + tile_k_start; + + if (global_m >= params.m) { + break; + } + let block_k = global_k_start / BLOCK_SIZE; + let byte_in_block = (global_k_start % BLOCK_SIZE) / 8u; + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + let q_byte = load_u32_at_src0(block_byte_base + 2u + byte_in_block) & 0xFFu; + + for (var bit = 0u; bit < NQ; bit++) { + let global_k = global_k_start + bit; + if (global_k < params.k) { + shmem[i + bit] = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + } + } + } +} +#endif // INIT_SRC0_SHMEM_Q1_0 + +#if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4) const BLOCK_SIZE = 32u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; const NQ = 16u; -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; +#if defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) +const BYTES_PER_THREAD = 16u; // NQ(16) weights use 16 bytes of q +#else +const BYTES_PER_THREAD = 8u; // NQ(16) weights use 8 bytes of q +#endif +const BYTES_PER_INNER_LOOP = 4u; // == sizeof(q_packed) fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { for (var i = thread_id * NQ; i < TILE_SRC0_SHMEM; i += TOTAL_WORKGROUP_SIZE * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; + let block_idx = i / BLOCK_SIZE; + let block_offset = (i % BLOCK_SIZE) / NQ; + let shmem_idx = block_idx * BLOCK_SIZE + block_offset * BYTES_PER_THREAD; - let tile_m = blck_idx / BLOCKS_K; + let tile_m = block_idx / BLOCKS_K; let global_m = offset_m + tile_m; - let block_k = blck_idx % BLOCKS_K; - let global_k = k_outer / BLOCK_SIZE + block_k; + let block_k = block_idx % BLOCKS_K; + let global_block_k = k_outer / BLOCK_SIZE + block_k; + + if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { + let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; + +#ifdef INIT_SRC0_SHMEM_Q4_0 + let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u; + let d = load_f16_at_src0(block_byte_base); - if (global_m < params.m && global_k < params.k / BLOCK_SIZE) { - let src0_idx = batch_offset + global_m * params.stride_01 + global_k; - let scale_idx = src0_idx * F16_PER_BLOCK; - let d = src0[scale_idx]; + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); + } +#elif INIT_SRC0_SHMEM_Q4_1 + let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u; + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1u + block_offset + j]; - let q_1 = src0[scale_idx + 1u + block_offset + j + 1]; + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k = 0u; k < 4u; k++) { + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { let q_byte = get_byte(q_packed, k); - let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f16(q_byte & 0xF) - 8.0) * d; - shmem[shmem_idx + j * 2 + k] = q_lo; - shmem[shmem_idx + j * 2 + k + 16u] = q_hi; + let q_lo = f16(q_byte & 0xF) * d + m; + let q_hi = f16((q_byte >> 4) & 0xF) * d + m; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } +#elif INIT_SRC0_SHMEM_Q5_0 + let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u; + + let d = load_f16_at_src0(block_byte_base); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); + + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 6u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { + let q_byte = get_byte(q_packed, k); + + let byte_idx = block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP + k; + let qh_hi = (qh_packed >> (byte_idx + 12u)) & 0x10; + let q_hi = (f16(((q_byte >> 4) & 0xF) | qh_hi) - 16.0) * d; + let qh_lo = ((qh_packed >> byte_idx) << 4) & 0x10; + let q_lo = (f16((q_byte & 0xF) | qh_lo) - 16.0) * d; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; + } + } +#elif INIT_SRC0_SHMEM_Q5_1 + let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u; + + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); + let qh_packed = load_u32_at_src0_aligned(block_byte_base + 4u); + + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 8u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0_aligned(q_byte_offset); + + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { + let q_byte = get_byte(q_packed, k); + + let byte_idx = block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP + k; + let qh_hi = (qh_packed >> (byte_idx + 12u)) & 0x10; + let q_hi = f16(((q_byte >> 4) & 0xF) | qh_hi) * d + m; + let qh_lo = ((qh_packed >> byte_idx) << 4) & 0x10; + let q_lo = f16((q_byte & 0xF) | qh_lo) * d + m; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_lo; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; + } + } +#elif INIT_SRC0_SHMEM_Q8_0 + let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u; + let d = load_f16_at_src0(block_byte_base); + + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 2u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); + } +#elif INIT_SRC0_SHMEM_Q8_1 + let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u; + let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); + let d = f16(dm[0]); + let m = f16(dm[1]); + + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 4u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = f16(q_byte) * d + m; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; + } + } +#elif INIT_SRC0_SHMEM_MXFP4 + let block_byte_base = src0_idx * 17u; + let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u); + let e = ldexp(1.0, i32(eu8) - 128); + + // load NQ(16) weights + for (var j = 0u; j < BYTES_PER_THREAD / BYTES_PER_INNER_LOOP; j += 1) { + let q_byte_offset = block_byte_base + 1u + block_offset * BYTES_PER_THREAD + j * BYTES_PER_INNER_LOOP; + let q_packed = load_u32_at_src0(q_byte_offset); + for (var k = 0u; k < BYTES_PER_INNER_LOOP; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4) & 0xF]) * e; + let q_lo = f32(kvalues_mxfp4[q_byte & 0xF]) * e; + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = f16(q_lo); + shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); + } + } +#endif + } + } +} +#endif + +// k-quants +#if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K) +const BLOCK_SIZE = 256u; +const NQ = 4u; + +fn store_shmem_kquants(val: vec4<f16>, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 3] = val.w; +} + +fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 { + return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u); +} + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * NQ) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + store_shmem_kquants(vec4<f16>(f16(0.0), f16(0.0), f16(0.0), f16(0.0)), elem_idx); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + +#ifdef INIT_SRC0_SHMEM_Q2_K + let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u; + let scales_byte_base = block_byte_base; + let qs_byte_base = block_byte_base + 16u; + let dm_byte_base = block_byte_base + 80u; + + let d_packed = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(d_packed[0]); + let dmin = f16(d_packed[1]); + + let chunk = k_in_block / 128u; + let pos_in_chunk = k_in_block % 32u; + let sub_block = k_in_block / 16u; + let shift_phase = (k_in_block % 128u) / 32u; + + // whole 2 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_vec4 = vec4<f16>( + f16((qs_word >> (2u * shift_phase + 0u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 8u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 16u)) & 0x3u), + f16((qs_word >> (2u * shift_phase + 24u)) & 0x3u), + ); + + let scale = load_byte_at_src0_aligned(scales_byte_base + sub_block); + + let dl = d * f16(scale & 0xFu); + let ml = dmin * f16(scale >> 4u); + + store_shmem_kquants(qs_vec4 * dl - ml, elem_idx); +#elif INIT_SRC0_SHMEM_Q3_K + let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u; + let hmask_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 32u; + let scales_byte_base = block_byte_base + 96u; + + let d_all = load_f16_at_src0(block_byte_base + 108u); + + let chunk = k_in_block / 128u; + let pos_in_chunk = k_in_block % 32u; + let sub_block = k_in_block / 16u; + let shift_phase = (k_in_block % 128u) / 32u; + + let hmask_block = pos_in_chunk; + let hmask_shift_phase = k_in_block / 32u; + + // low 2 bits (4 elems) + let q_lo2_word = load_u32_at_src0(qs_byte_base + 32u * chunk + 1u * hmask_block); + let q_lo2_vec4 = vec4<f16>( + f16((q_lo2_word >> (2u * shift_phase + 0u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 8u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 16u)) & 3u), + f16((q_lo2_word >> (2u * shift_phase + 24u)) & 3u) + ); + + // high 1 bit (4 elems) + let q_hi1_word = load_u32_at_src0(hmask_byte_base + pos_in_chunk); + let q_hi1_vec4 = vec4<f16>( + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 0u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 8u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 16u)) & 1u) == 1u)), + f16(select(4.0, 0.0, ((q_hi1_word >> (1u * hmask_shift_phase + 24u)) & 1u) == 1u)) + ); + + let q_vec4 = q_lo2_vec4 - q_hi1_vec4; + + let scale_low4 = (load_byte_at_src0_aligned(scales_byte_base + (sub_block % 8u)) >> (4u * (sub_block / 8u))) & 0xFu; + let scale_hi2 = (load_byte_at_src0_aligned(scales_byte_base + 8u + (sub_block % 4u)) >> (2u * (sub_block / 4u))) & 3u; + let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0); + + store_shmem_kquants(dl * q_vec4, elem_idx); +#elif INIT_SRC0_SHMEM_Q4_K + let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u; + let dm_byte_base = block_byte_base + 0u; + let scale_byte_base = block_byte_base + 4u; + let qs_byte_base = block_byte_base + 16u; + + let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(dm[0]); + let dmin = f16(dm[1]); + + let chunk = k_in_block / 64u; + let pos_in_chunk = (k_in_block % 64u) % 32u; + let sub_block = k_in_block / 32u; + let shift_phase = sub_block & 1u; + + // whole 4 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_vec4 = vec4<f16>( + f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu) + ); + + var sc: u32; + var mn: u32; + + if (sub_block < 4u) { + let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u); + let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; + } else { + let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + } + + let dl = d * f16(sc); + let ml = dmin * f16(mn); + + store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx); +#elif INIT_SRC0_SHMEM_Q5_K + let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u; + let dm_byte_base = block_byte_base + 0u; + let scale_byte_base = block_byte_base + 4u; + let qh_byte_base = block_byte_base + 16u; + let qs_byte_base = block_byte_base + 48u; + + let dm = unpack2x16float(load_u32_at_src0_aligned(dm_byte_base)); + let d = f16(dm[0]); + let dmin = f16(dm[1]); + + let chunk = k_in_block / 64u; + let pos_in_chunk = (k_in_block % 64u) % 32u; + let sub_block = k_in_block / 32u; + let shift_phase = sub_block & 1u; + + let qh_block = k_in_block % 32u; + let qh_shift_phase = sub_block; + + // low 4 bits (4 elems) + let qs_word = load_u32_at_src0_aligned(qs_byte_base + 32u * chunk + 1u * pos_in_chunk); + let qs_lo4_vec4 = vec4<f16>( + f16((qs_word >> (4u * shift_phase + 0u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 8u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 16u)) & 0xFu), + f16((qs_word >> (4u * shift_phase + 24u)) & 0xFu) + ); + + // high 1 bit (4 elems) + let qh_word = load_u32_at_src0_aligned(qh_byte_base + qh_block); + let qh_vec4 = vec4<f16>( + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 0u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 8u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 16u)) & 1u) == 1u)), + f16(select(0.0, 16.0, ((qh_word >> (1u * qh_shift_phase + 24u)) & 1u) == 1u)) + ); + + var sc: u32; + var mn: u32; + + if (sub_block < 4u) { + let sc_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base), sub_block % 4u); + let min_byte = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = sc_byte & 63u; + mn = min_byte & 63u; + } else { + let sc_min_lo = get_byte(load_u32_at_src0_aligned(scale_byte_base + 8), (sub_block + 4u) % 4u); + let sc_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base), (sub_block - 4u) % 4u); + let min_hi = get_byte(load_u32_at_src0_aligned(scale_byte_base + 4), sub_block % 4u); + sc = (sc_min_lo & 0xFu) | ((sc_hi >> 6u) << 4u); + mn = (sc_min_lo >> 4u) | ((min_hi >> 6u) << 4u); + } + + let dl = d * f16(sc); + let ml = dmin * f16(mn); + + store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4<f16>(ml, ml, ml, ml), elem_idx); +#elif INIT_SRC0_SHMEM_Q6_K + let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u; + let ql_byte_base = block_byte_base; + let qh_byte_base = block_byte_base + 128u; + let scales_byte_base = block_byte_base + 192u; + let d_byte_base = block_byte_base + 208u; + + let d = load_f16_at_src0(d_byte_base); + + let chunk = k_in_block / 128u; + let ql_pos_in_chunk = (k_in_block % 128u) % 64u; + let qh_pos_in_chunk = (k_in_block % 128u) % 32u; + let sub_block = k_in_block / 16u; + let ql_shift_phase = (k_in_block % 128u) / 64u; + let qh_shift_phase = (k_in_block % 128u) / 32u; + + // low 4 bits (4 elems) + let ql_word = load_u32_at_src0(ql_byte_base + 64u * chunk + 1u * ql_pos_in_chunk); + let ql_lo4_vec4 = vec4<u32>( + (ql_word >> (4u * ql_shift_phase + 0u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 8u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 16u)) & 0xFu, + (ql_word >> (4u * ql_shift_phase + 24u)) & 0xFu + ); + + // hi 2 bits (4 elems) + let qh_word = load_u32_at_src0(qh_byte_base + 32u * chunk + 1u * qh_pos_in_chunk); + let qh_hi2_vec4 = vec4<u32>( + ((qh_word >> (2u * qh_shift_phase + 0u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 8u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 16u)) & 0x3u) << 4u, + ((qh_word >> (2u * qh_shift_phase + 24u)) & 0x3u) << 4u, + ); + + let q_vec4 = vec4<f16>(qh_hi2_vec4 | ql_lo4_vec4) - vec4<f16>(32.0, 32.0, 32.0, 32.0); + + let scale_byte = scales_byte_base + 1u * sub_block; + let scale_word = load_u32_at_src0_aligned(scale_byte); + let scale = get_byte_i32(scale_word, scale_byte & 3u); + + store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx); +#endif + } +} +#endif // k-quants + +#ifdef INIT_SRC0_SHMEM_IQ4_NL +const BLOCK_SIZE = 32u; +const BLOCK_SIZE_BYTES = 18u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_at_src0(block_byte_base); + + let pos = k_in_block % 16u; + let nib_shift = (k_in_block / 16u) * 4u; + let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u); + let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu; + + shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]); + } +} +#endif // INIT_SRC0_SHMEM_IQ4_NL + +#ifdef INIT_SRC0_SHMEM_IQ4_XS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 136u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + + let d_scales_h = load_u32_at_src0(block_byte_base); + let d = bitcast<vec2<f16>>(d_scales_h).x; + let scales_h = d_scales_h >> 16u; + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu; + let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u; + let dl = d * f16(i32(ls_lo | ls_hi) - 32); + + let iqs = ib * 16u + (pos % 16u); + let nib_shift = (pos / 16u) * 4u; + let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u); + let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu; + + shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]); + } +} +#endif // INIT_SRC0_SHMEM_IQ4_XS + +#ifdef INIT_SRC0_SHMEM_IQ1_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 50u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + let l = pos / 8u; + let j = pos % 8u; + + let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu; + let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + + let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u); + let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + + let gw = iq1_grid[(ig + j) / 16u]; + let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; + let gs = bitcast<i32>(g << 30u) >> 30u; + + shmem[elem_idx] = f16(dl * (f32(gs) + delta)); + } +} +#endif // INIT_SRC0_SHMEM_IQ1_S + +#ifdef INIT_SRC0_SHMEM_IQ1_M +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 56u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + + let scales0 = load_u32_at_src0(block_byte_base + 48u); + let scales1 = load_u32_at_src0(block_byte_base + 52u); + let scale_packed = ((scales0 >> 12u) & 0xFu) | + ((scales0 >> 24u) & 0x00F0u) | + ((scales1 >> 4u) & 0x0F00u) | + ((scales1 >> 16u) & 0xF000u); + let d = f32(bitcast<vec2<f16>>(scale_packed).x); + + let ib = k_in_block / 32u; + let pos = k_in_block % 32u; + let l = pos / 8u; + let j = pos % 8u; + + let scales = select(scales0, scales1, ib >= 4u); + let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu; + let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u; + let dl = d * f32(2u * s_pair + 1u); + + let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u); + let qh = qh_word >> (16u * (ib % 2u)); + let qh_nib = (qh >> (4u * l)) & 0xFu; + + let qs_w = load_u32_at_src0(block_byte_base + ib * 4u); + let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u); + + let ig = idx * 8u; + let gw = iq1_grid[(ig + j) / 16u]; + let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; + let gs = bitcast<i32>(g << 30u) >> 30u; + + shmem[elem_idx] = f16(dl * (f32(gs) + delta)); + } +} +#endif // INIT_SRC0_SHMEM_IQ1_M + +#ifdef INIT_SRC0_SHMEM_IQ2_XXS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 66u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let entry_idx = k_in_block / 8u; + let j = k_in_block % 8u; + + let ib = entry_idx & ~3u; + let l = entry_idx & 3u; + + let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u); + let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u); + let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25; + + let ig = get_byte(aux0, l) * 8u; + let is = (aux1 >> (7u * l)) & 127u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(db * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_XXS + +#ifdef INIT_SRC0_SHMEM_IQ2_XS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 74u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let entry_idx = k_in_block / 8u; + let j = k_in_block % 8u; + + let ib = entry_idx & ~3u; + let l = entry_idx & 3u; + + let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u); + let s = get_byte(scales_word, (ib % 16u) / 4u); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); + let dl = d * (0.5 + f32(s_nib)) * 0.25; + + let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u); + let qs_val = qs_word & 0xFFFFu; + let ig = (qs_val & 511u) * 8u; + let is = qs_val >> 9u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ2_XS + +#ifdef INIT_SRC0_SHMEM_IQ2_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 82u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 32u; + let l = (k_in_block % 32u) / 8u; + let j = k_in_block % 8u; + + let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u); + let s = get_byte(scales_word, ib % 4u); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); + let dl = d * (0.5 + f32(s_nib)) * 0.25; + + let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u); + let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u; + let ig = (get_byte(qs_word, l) | qh_b) * 8u; + + let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); } } +#endif // INIT_SRC0_SHMEM_IQ2_S + +#ifdef INIT_SRC0_SHMEM_IQ3_XXS +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 98u; -#enddecl(INIT_SRC0_SHMEM_Q4_0) +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib_pair = k_in_block / 32u; + let in_pair = k_in_block % 32u; + let l = in_pair / 8u; + let in_l = in_pair % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let ib = ib_pair * 2u; + let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u; + let sc_sign = load_u32_at_src0(sc_sign_off); + let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5; + let is = (sc_sign >> (7u * l)) & 127u; + let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu; + let ig_byte = get_byte(ig_word, k2); + let g = get_byte(iq3xxs_grid[ig_byte], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(db * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_XXS + +#ifdef INIT_SRC0_SHMEM_IQ3_S +const BLOCK_SIZE = 256u; +const BLOCK_SIZE_BYTES = 110u; + +fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { + for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + let tile_m = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + let global_m = offset_m + tile_m; + let global_k = k_outer + tile_k; + + if (global_m >= params.m || global_k >= params.k) { + shmem[elem_idx] = f16(0.0); + continue; + } + + let block_k = global_k / BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; + let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let d = load_f16_as_f32_at_src0(block_byte_base); + + let ib = k_in_block / 64u; + let rest = k_in_block % 64u; + let k = rest / 32u; + let in_k = rest % 32u; + let l = in_k / 8u; + let in_l = in_k % 8u; + let k2 = in_l / 4u; + let j = in_l % 4u; + + let scales_word = load_u32_at_src0(block_byte_base + 106u); + let s = get_byte(scales_word, ib); + let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u); + let dl = d * (1.0 + 2.0 * f32(s_nib)); + + let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u); + let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k); + + let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu; + let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u); + let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u); + let ig = select(ig_lo, ig_hi, k2 != 0u); + + let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u); + let signs = get_byte(signs_word, l); + + let g = get_byte(iq3s_grid[ig], j); + let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + + shmem[elem_idx] = f16(dl * f32(g) * m); + } +} +#endif // INIT_SRC0_SHMEM_IQ3_S diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl new file mode 100644 index 00000000000..91039ff2546 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl @@ -0,0 +1,195 @@ +enable f16; + +#define DECLARE_BYTE_LOADERS_SRC0 +#include "common_decls.tmpl" + +#include "mul_mat_decls.tmpl" + +#ifdef VEC +fn store_val(acc: array<array<f16, TILE_M>, TILE_N>, tn: u32, tm: u32) -> vec4<f32> { + return vec4<f32>(f32(acc[tn][tm]), f32(acc[tn][tm + 1]), f32(acc[tn][tm + 2]), f32(acc[tn][tm + 3])); +} +#endif + +#ifdef SCALAR +fn store_val(acc: array<array<f16, TILE_M>, TILE_N>, tn: u32, tm: u32) -> f32 { + return f32(acc[tn][tm]); +} +#endif + +struct MulMatIdParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + k: u32, + m: u32, + n_expert: u32, + n_expert_used: u32, + n_tokens: u32, + b_ne1: u32, + + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, +}; + +@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // [cols, rows, n_expert] +@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // [cols, b_ne1, n_tokens] +@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // [rows, n_expert_used, n_tokens] +@group(0) @binding(3) var<storage, read_write> global_gathered_expert_used: array<u32>; // [n_expert][n_tokens] +@group(0) @binding(4) var<storage, read_write> global_gathered_tokens: array<u32>; // [n_expert][n_tokens] +@group(0) @binding(5) var<storage, read_write> gathered_count_ids: array<u32>; // [n_expert] + +@group(0) @binding(6) var<uniform> params: MulMatIdParams; + +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; +} +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; +} + +const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + +var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>; +var<workgroup> gathered_expert_used: array<u32, TILE_N * WORKGROUP_SIZE_N>; +var<workgroup> gathered_tokens: array<u32, TILE_N * WORKGROUP_SIZE_N>; + +#ifdef INIT_SRC1_SHMEM_FLOAT +fn init_shmem_id_src1(thread_id: u32, offset_src1: u32, rest_token_n: u32, k_outer: u32) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { + let tile_n = elem_idx / TILE_K; + let tile_k = elem_idx % TILE_K; + if (tile_n < rest_token_n) { + let global_src10 = k_outer + tile_k; + let expert_used_idx = gathered_expert_used[tile_n] % params.b_ne1; + let token_idx = gathered_tokens[tile_n]; + let src1_idx = offset_src1 + token_idx * params.stride_12 + expert_used_idx * params.stride_11 + global_src10; + let src1_val = select( + SRC1_TYPE(0.0), + src1[src1_idx/VEC_SIZE], + global_src10 < params.k); + store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); + } else { + store_shmem(SHMEM_TYPE(0.0), TILE_SRC0_SHMEM + elem_idx); + } + } +} +#endif // INIT_SRC1_SHMEM_FLOAT + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(local_invocation_id) local_id: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + + let thread_id = local_id.x; + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); + + var expert_idx:u32 = 0xFFFFFFFFu; + var wg_in_batch:u32 = 0; + var wg_sum:u32 = 0; + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + for (var i = 0u;i < params.n_expert;i += 1) { + let wg_n_count = (gathered_count_ids[i] + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_per_matrix = wg_m_count * wg_n_count; + if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) { + expert_idx = i; + wg_in_batch = wg_linear - wg_sum; + break; + } + wg_sum += wg_per_matrix; + } + + let is_valid = expert_idx != 0xFFFFFFFFu; + + var wg_m: u32 = 0; + var wg_n: u32 = 0; + var offset_wg_m: u32 = 0; + var offset_wg_n: u32 = 0; + var rest_token_n: u32 = 0; + var src0_batch_offset: u32 = 0; + + wg_m = wg_in_batch % wg_m_count; + wg_n = wg_in_batch / wg_m_count; + + offset_wg_m = wg_m * WORKGROUP_SIZE_M * TILE_M; + offset_wg_n = wg_n * WORKGROUP_SIZE_N * TILE_N; + + if (is_valid) { + rest_token_n = gathered_count_ids[expert_idx] - offset_wg_n; + let global_gathered_base = expert_idx * params.n_tokens + offset_wg_n; + for (var i = thread_id; i < TILE_N * WORKGROUP_SIZE_N && offset_wg_n + i < gathered_count_ids[expert_idx]; i += TOTAL_WORKGROUP_SIZE) { + gathered_expert_used[i] = global_gathered_expert_used[global_gathered_base + i]; + gathered_tokens[i] = global_gathered_tokens[global_gathered_base + i]; + } + src0_batch_offset = params.offset_src0 + expert_idx * params.stride_02; + } + + workgroupBarrier(); + + let output_row_base = offset_wg_m + local_m * TILE_M; + let output_col_base = offset_wg_n + local_n * TILE_N; + + let dst2_stride = params.m * params.n_expert_used; + let dst1_stride = params.m; + + var acc: array<array<f16, TILE_M>, TILE_N>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + if (is_valid) { + init_shmem_src0(thread_id, src0_batch_offset, offset_wg_m, k_outer); + init_shmem_id_src1(thread_id, params.offset_src1, rest_token_n, k_outer); + } + + workgroupBarrier(); + + if (is_valid) { + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner++) { + var src0_tile: array<f16, TILE_M>; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = shmem[src0_idx]; + } + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; + let src1_idx = src1_n * TILE_K + k_inner; + let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tn][tm] += src0_tile[tm] * src1_val; + } + } + } + } + + workgroupBarrier(); + } + + if (is_valid) { + for (var tn = 0u; tn < TILE_N; tn++) { + let n_idx = output_col_base + tn; + if (n_idx < gathered_count_ids[expert_idx]) { + let dst1_idx = gathered_expert_used[n_idx - offset_wg_n]; + let dst2_idx = gathered_tokens[n_idx - offset_wg_n]; + let dst12_offset = params.offset_dst + dst2_idx * dst2_stride + dst1_idx * dst1_stride; + for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) { + let global_row = output_row_base + tm; + if (global_row < params.m) { + let dst_idx = dst12_offset + global_row; + dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm); + } + } + } + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl new file mode 100644 index 00000000000..581e922709d --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl @@ -0,0 +1,52 @@ +enable f16; + +struct MulMatIdGatherParams { + offset_ids: u32, + + n_expert: u32, + n_expert_used: u32, + n_tokens: u32, + + stride_ids_1: u32, +}; + +@group(0) @binding(0) var<storage, read_write> ids: array<i32>; // [n_expert_used, n_tokens] +@group(0) @binding(1) var<storage, read_write> global_gathered_expert_used: array<u32>; // [n_expert][n_tokens] +@group(0) @binding(2) var<storage, read_write> global_gathered_tokens: array<u32>; // [n_expert][n_tokens] +@group(0) @binding(3) var<storage, read_write> gathered_count_ids: array<u32>; // [n_expert] + +@group(0) @binding(4) var<uniform> params: MulMatIdGatherParams; + +var<workgroup> count:atomic<u32>; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(local_invocation_id) local_id: vec3<u32>) { + + let thread_id = local_id.x; + let own_expert = wg_id.x; // the expert assigned to this workgroup + + if (thread_id == 0u) { + atomicStore(&count, 0); + } + + workgroupBarrier(); + + for (var i = thread_id;i < params.n_expert_used * params.n_tokens;i += WG_SIZE) { + let row = i / params.n_expert_used; + let col = i % params.n_expert_used; + let expert = u32(ids[params.offset_ids + row * params.stride_ids_1 + col]); + if (own_expert == expert) { + let pos = atomicAdd(&count, 1u); + let gathered_id = own_expert * params.n_tokens + pos; + global_gathered_expert_used[gathered_id] = col; + global_gathered_tokens[gathered_id] = row; + } + } + + workgroupBarrier(); + + if (thread_id == 0u) { + gathered_count_ids[own_expert] = atomicLoad(&count); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl new file mode 100644 index 00000000000..6ff9bcf2df0 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl @@ -0,0 +1,154 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif +enable f16; + +#define DECLARE_BYTE_LOADERS_SRC0 +#include "common_decls.tmpl" + +#include "mul_mat_vec_acc.tmpl" + +struct MulMatIdVecParams { + offset_src0: u32, + offset_src1: u32, + offset_ids: u32, + offset_dst: u32, + + k: u32, + m: u32, + n_expert: u32, + n_expert_used: u32, + b_ne1: u32, + + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, +}; + +@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // [cols, rows, n_expert] +@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // [cols, b_ne1, n_tokens(1)] +@group(0) @binding(2) var<storage, read_write> ids: array<u32>; // [n_experd_used, n_tokens(1)] +@group(0) @binding(3) var<storage, read_write> dst: array<f32>; // [rows, n_expert_used, n_tokens(1)] + +// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 +@group(0) @binding(4) var<uniform> params: MulMatIdVecParams; + +// Flattened as [row][thread] to keep each row's reduction contiguous in memory. +var<workgroup> partial_sums: array<f32, OUTPUTS_PER_WG * WG_SIZE>; + +fn partial_index(row: u32, thread: u32) -> u32 { + return row * WG_SIZE + thread; +} + +var<workgroup> gathered_count_ids: array<u32, N_EXPERTS>; +var<workgroup> gathered_expert_used: array<u32, N_EXPERTS>; + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3<u32>, + @builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32> +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 +#endif +) { + + let thread_id = local_id.x; + + for (var i = thread_id;i < params.n_expert;i += WG_SIZE) { + gathered_count_ids[i] = 0; + } + + workgroupBarrier(); + + // gather the selected experts for the target token. + for (var col = thread_id;col < params.n_expert_used;col += WG_SIZE) { + let expert = ids[params.offset_ids + col]; + gathered_count_ids[expert] = 1; + gathered_expert_used[expert] = col; + } + + workgroupBarrier(); + + let output_groups:u32 = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + var own_expert:u32 = 0; + var wg_in_batch:u32 = 0; + var wg_sum:u32 = 0; + + for (var i = 0u;i < params.n_expert;i += 1) { + let wg_vec_count = gathered_count_ids[i]; // 1 or 0 + let wg_per_matrix = output_groups * wg_vec_count; + if (wg_sum <= wg_linear && wg_linear < wg_sum + wg_per_matrix) { + own_expert = i; + wg_in_batch = wg_linear - wg_sum; + break; + } + wg_sum += wg_per_matrix; + } + + let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; + let dst1_stride = params.m; + + let src0_batch_offset = params.offset_src0 + own_expert * params.stride_02; + let src1_idx_base = params.offset_src1 + (gathered_expert_used[own_expert] % params.b_ne1) * params.stride_11; + let dst_idx_base = params.offset_dst + gathered_expert_used[own_expert] * dst1_stride + row_base; + + let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); + +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } + + workgroupBarrier(); + + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + row] = row_total; + } + } +#endif + +#ifdef USE_WORKGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[row]; + } + + workgroupBarrier(); + + var stride:u32 = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2; + } + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } + } +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl deleted file mode 100644 index 6b1dd26cd9e..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ /dev/null @@ -1,247 +0,0 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4<f32>", - "SRC1_TYPE" : "vec4<f32>", - "DST_TYPE" : "vec4<f32>", - "SHMEM_TYPE" : "vec4<f16>", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4<f16>", - "SRC1_TYPE" : "vec4<f32>", - "DST_TYPE" : "vec4<f32>", - "SHMEM_TYPE" : "vec4<f16>", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4<f16>", - "SRC1_TYPE" : "vec4<f16>", - "DST_TYPE" : "vec4<f32>", - "SHMEM_TYPE" : "vec4<f16>", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32_vec", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "vec4<f32>", - "DST_TYPE" : "vec4<f32>", - "SHMEM_TYPE" : "vec4<f16>", - "VEC_SIZE" : 4, - }, - "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(VEC) -fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> { - return vec4<f32>(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); -} -#enddecl(VEC) - -#decl(SCALAR) -fn store_val(acc: array<array<f16, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 { - return f32(acc[tm][tn]); -} -#enddecl(SCALAR) - -#end(DECLS) - -#define(SHADER) -enable f16; - -struct MulMatParams { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - m: u32, - n: u32, - k: u32, - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) - -@group(0) @binding(3) var<uniform> params: MulMatParams; - -DECLS - -fn get_local_n(thread_id: u32) -> u32 { - return thread_id / WORKGROUP_SIZE_M; -} -fn get_local_m(thread_id: u32) -> u32 { - return thread_id % WORKGROUP_SIZE_M; -} - -// TILE_M must be multiple of 4 for vec4 loads -const TILE_M = {{WEBGPU_TILE_M}}u; -const TILE_N = {{WEBGPU_TILE_N}}u; - -override WORKGROUP_SIZE_M: u32; -override WORKGROUP_SIZE_N: u32; -override TILE_K: u32; - -override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; -override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; -override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; - -var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>; - -@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) -fn main(@builtin(workgroup_id) wg_id: vec3<u32>, - @builtin(local_invocation_id) local_id: vec3<u32>) { - - let thread_id = local_id.x; - let local_m = get_local_m(thread_id); - let local_n = get_local_n(thread_id); - - let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); - let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); - let wg_per_matrix = wg_m_count * wg_n_count; - - let batch_idx = wg_id.x / wg_per_matrix; - - let wg_in_batch = wg_id.x % wg_per_matrix; - let wg_m = wg_in_batch % wg_m_count; - let wg_n = wg_in_batch / wg_m_count; - - let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; - let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; - - let dst2_stride = params.m * params.n; - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - - let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); - let src03_idx = dst3_idx / params.broadcast3; - let src13_idx = dst3_idx; - let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); - let src02_idx = dst2_idx / params.broadcast2; - let src12_idx = dst2_idx; - - let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; - let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - - let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; - let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; - - var acc: array<array<f16, TILE_N>, TILE_M>; - - for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { - - // see mul_mat_decls.tmpl - init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); - init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); - - workgroupBarrier(); - - let k_end = min(TILE_K, params.k - k_outer); - - for (var k_inner = 0u; k_inner < k_end; k_inner++) { - var src0_tile: array<f16, TILE_M>; - for (var tm = 0u; tm < TILE_M; tm++) { - let src0_m = local_m * TILE_M + tm; - let src0_idx = k_inner + src0_m * TILE_K; - src0_tile[tm] = shmem[src0_idx]; - } - for (var tn = 0u; tn < TILE_N; tn++) { - let src1_n = local_n * TILE_N + tn; - let src1_idx = src1_n * TILE_K + k_inner; - let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; - for (var tm = 0u; tm < TILE_M; tm++) { - acc[tm][tn] += src0_tile[tm] * src1_val; - } - } - } - - workgroupBarrier(); - } - - let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; - - for (var tn = 0u; tn < TILE_N; tn++) { - let global_col = output_col_base + tn; - if (global_col < params.n) { - for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { - let global_row = output_row_base + tm; - if (global_row < params.m) { - let dst_idx = dst_batch_offset + global_col * params.m + global_row; - dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); - } - } - } - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl new file mode 100644 index 00000000000..98bbdeb83ba --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -0,0 +1,149 @@ +enable f16; + +#define DECLARE_BYTE_LOADERS_SRC0 +#include "common_decls.tmpl" + +#include "mul_mat_decls.tmpl" + +#ifdef VEC +fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> vec4<f32> { + return vec4<f32>(acc[tm][tn], acc[tm + 1][tn], acc[tm + 2][tn], acc[tm + 3][tn]); +} +#endif + +#ifdef SCALAR +fn store_val(acc: array<array<f32, TILE_N>, TILE_M>, tn: u32, tm: u32) -> f32 { + return acc[tm][tn]; +} +#endif + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns +@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed) +@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed) + +@group(0) @binding(3) var<uniform> params: MulMatParams; + +fn get_local_n(thread_id: u32) -> u32 { + return thread_id / WORKGROUP_SIZE_M; +} +fn get_local_m(thread_id: u32) -> u32 { + return thread_id % WORKGROUP_SIZE_M; +} + +const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; + +var<workgroup> shmem: array<f16, TILE_SRC0_SHMEM + TILE_SRC1_SHMEM>; + +@compute @workgroup_size(TOTAL_WORKGROUP_SIZE) +fn main(@builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(local_invocation_id) local_id: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + + let thread_id = local_id.x; + let local_m = get_local_m(thread_id); + let local_n = get_local_n(thread_id); + + let wg_n_count = (params.n + WORKGROUP_SIZE_N * TILE_N - 1u) / (WORKGROUP_SIZE_N * TILE_N); + let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); + let wg_per_matrix = wg_m_count * wg_n_count; + + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + let batch_idx = wg_linear / wg_per_matrix; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (batch_idx >= total_batches) { + return; + } + + let wg_in_batch = wg_linear % wg_per_matrix; + let wg_m = wg_in_batch % wg_m_count; + let wg_n = wg_in_batch / wg_m_count; + + let output_row_base = wg_m * WORKGROUP_SIZE_M * TILE_M + local_m * TILE_M; + let output_col_base = wg_n * WORKGROUP_SIZE_N * TILE_N + local_n * TILE_N; + + let dst2_stride = params.m * params.n; + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let src1_batch_offset = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + + let offset_m = wg_m * WORKGROUP_SIZE_M * TILE_M; + let offset_n = wg_n * WORKGROUP_SIZE_N * TILE_N; + + var acc: array<array<f32, TILE_N>, TILE_M>; + + for (var k_outer = 0u; k_outer < params.k; k_outer += TILE_K) { + + // see mul_mat_decls.tmpl + init_shmem_src0(thread_id, src0_batch_offset, offset_m, k_outer); + init_shmem_src1(thread_id, src1_batch_offset, offset_n, k_outer); + + workgroupBarrier(); + + let k_end = min(TILE_K, params.k - k_outer); + + for (var k_inner = 0u; k_inner < k_end; k_inner++) { + var src0_tile: array<f16, TILE_M>; + for (var tm = 0u; tm < TILE_M; tm++) { + let src0_m = local_m * TILE_M + tm; + let src0_idx = k_inner + src0_m * TILE_K; + src0_tile[tm] = shmem[src0_idx]; + } + for (var tn = 0u; tn < TILE_N; tn++) { + let src1_n = local_n * TILE_N + tn; + let src1_idx = src1_n * TILE_K + k_inner; + let src1_val = shmem[TILE_SRC0_SHMEM + src1_idx]; + for (var tm = 0u; tm < TILE_M; tm++) { + acc[tm][tn] += f32(src0_tile[tm]) * f32(src1_val); + } + } + } + + workgroupBarrier(); + } + + let dst_batch_offset = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride; + + for (var tn = 0u; tn < TILE_N; tn++) { + let global_col = output_col_base + tn; + if (global_col < params.n) { + for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) { + let global_row = output_row_base + tm; + if (global_row < params.m) { + let dst_idx = dst_batch_offset + global_col * params.m + global_row; + dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm); + } + } + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl similarity index 64% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 47c8ce36ab3..d86a72ce6e0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -1,100 +1,17 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4<f32>", - "SRC1_TYPE" : "vec4<f32>", - "DST_TYPE" : "vec4<f32>", - "SHMEM_TYPE" : "vec4<f16>", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4<f16>", - "SRC1_TYPE" : "vec4<f32>", - "DST_TYPE" : "vec4<f32>", - "SHMEM_TYPE" : "vec4<f16>", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4<f16>", - "SRC1_TYPE" : "vec4<f16>", - "DST_TYPE" : "vec4<f32>", - "SHMEM_TYPE" : "vec4<f16>", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32_vec", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "vec4<f32>", - "DST_TYPE" : "vec4<f32>", - "SHMEM_TYPE" : "vec4<f16>", - "VEC_SIZE" : 4, - }, - "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(VEC) +diagnostic(off, chromium.subgroup_matrix_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; + +#define DECLARE_BYTE_LOADERS_SRC0 +#include "common_decls.tmpl" + +#include "mul_mat_decls.tmpl" + +// TODO: this shader path does not work with some models like qwen2.5 on Metal devices, f16 accumulation causes NaNs. +// See https://github.com/ggml-org/llama.cpp/issues/21602 + +#ifdef VEC fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = vec4<f32>( f32(shmem[shmem_idx]), @@ -103,21 +20,13 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) { f32(shmem[shmem_idx + 3]) ); } -#enddecl(VEC) +#endif -#decl(SCALAR) +#ifdef SCALAR fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = f32(shmem[shmem_idx]); } -#enddecl(SCALAR) - -#end(DECLS) - -#define(SHADER) -diagnostic(off, chromium.subgroup_matrix_uniformity); -enable f16; -enable subgroups; -enable chromium_experimental_subgroup_matrix; +#endif struct MulMatParams { offset_src0: u32, @@ -138,36 +47,19 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) +// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; // M rows, K columns +@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; // K rows, N columns (transposed) +@group(0) @binding(2) var<storage, read_write> dst: array<DST_TYPE>; // M rows, N columns (transposed) @group(0) @binding(3) var<uniform> params: MulMatParams; -DECLS +const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; +const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; -// Note: These are string interpolated at build time, cannot use override constants due to limitations in -// current Dawn version type definitions/matrix load requirements for constant memory sizes. -const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u; -const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u; // For portability we assume the max subgroup size, meaning some subgroups will be masked out if the // runtime subgroup size is smaller. -const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u; - const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; - -const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u; -const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u; -const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u; - -const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u; -const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u; - -const TILE_K = {{WEBGPU_TILE_K}}u; - -const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; -const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; @@ -182,7 +74,8 @@ var<workgroup> shmem: array<f16, SHMEM_SIZE>; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3<u32>, @builtin(local_invocation_id) local_id: vec3<u32>, - @builtin(subgroup_id) subgroup_id: u32) { + @builtin(subgroup_id) subgroup_id: u32, + @builtin(num_workgroups) num_wg: vec3<u32>) { let thread_id = local_id.x; let subgroup_m = subgroup_id % SUBGROUP_M; @@ -192,9 +85,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_id.x / wg_per_matrix; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + let batch_idx = wg_linear / wg_per_matrix; - let wg_in_batch = wg_id.x % wg_per_matrix; + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (batch_idx >= total_batches) { + return; + } + + let wg_in_batch = wg_linear % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; @@ -285,7 +185,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let local_row = idx % WG_TILE_STRIDE; let local_col = idx / WG_TILE_STRIDE; @@ -294,9 +194,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3<u32>, if (global_col < params.n && global_row < params.m) { let dst_idx = dst_batch_offset + global_col * params.m + global_row; - store_dst(idx, dst_idx/{{VEC_SIZE}}); + store_dst(idx, dst_idx/VEC_SIZE); } } } - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl deleted file mode 100644 index ffbb6403285..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +++ /dev/null @@ -1,267 +0,0 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4<f32>", - "SRC1_TYPE" : "vec4<f32>", - "DST_TYPE": "vec4<f32>", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4<f16>", - "SRC1_TYPE" : "vec4<f32>", - "DST_TYPE": "vec4<f32>", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4<f16>", - "SRC1_TYPE" : "vec4<f16>", - "DST_TYPE": "vec4<f32>", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(VEC) -fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { - return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); -} - -fn store_val(group_base: u32) -> vec4<f32> { - return vec4<f32>(partial_sums[group_base], - partial_sums[group_base + THREADS_PER_OUTPUT], - partial_sums[group_base + THREADS_PER_OUTPUT * 2], - partial_sums[group_base + THREADS_PER_OUTPUT * 3]); -} -#enddecl(VEC) - -#decl(SCALAR) -fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { - return f32(src0_val) * f32(src1_val); -} - -fn store_val(group_base: u32) -> f32 { - return partial_sums[group_base]; -} -#enddecl(SCALAR) - -#decl(MUL_ACC_FLOAT) - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { - let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}]; - let b = shared_vector[i / {{VEC_SIZE}}]; - local_sum += inner_dot(a, b); - } - return local_sum; -} - -#enddecl(MUL_ACC_FLOAT) - -#decl(MUL_ACC_Q4_0) - -const BLOCK_SIZE = 32; -const NQ = 16u; // number of weights per thread -const F16_PER_BLOCK = 9u; // 1 scale + 8x4 packed weights -const WEIGHTS_PER_F16 = 4u; // 4 weights per f16 -const F16_PER_THREAD = NQ / WEIGHTS_PER_F16; - -fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { - var local_sum = 0.0; - for (var i = tig * NQ; i < tile_size; i += THREADS_PER_OUTPUT * NQ) { - let blck_idx = i / BLOCK_SIZE; - let block_offset = (i % BLOCK_SIZE) / WEIGHTS_PER_F16; - let scale_idx = (idx_base + k_outer / BLOCK_SIZE + blck_idx) * F16_PER_BLOCK; - // each f16 contains offsets [block_offset, block_offset + 1] and [block_offset + 16, block_offset + 17] - let shmem_idx = blck_idx * BLOCK_SIZE + block_offset * 2u; - let d = f32(src0[scale_idx]); - for (var j = 0u; j < F16_PER_THREAD; j += 2) { - let q_0 = src0[scale_idx + 1 + block_offset + j]; - let q_1 = src0[scale_idx + 1 + block_offset + j + 1]; - let q_packed = bitcast<u32>(vec2(q_0, q_1)); - for (var k: u32 = 0; k < 4; k++) { - let q_byte = get_byte(q_packed, k); - let q_hi = (f32((q_byte >> 4) & 0xF) - 8.0) * d; - let q_lo = (f32(q_byte & 0xF) - 8.0) * d; - local_sum += q_lo * shared_vector[shmem_idx + j * 2 + k]; - local_sum += q_hi * shared_vector[shmem_idx + j * 2 + k + 16]; - } - } - } - return local_sum; -} - -#enddecl(MUL_ACC_Q4_0) - -#end(DECLS) - -#define(SHADER) -enable f16; - -DECLS - -struct MulMatParams { - offset_src0: u32, - offset_src1: u32, - offset_dst: u32, - m: u32, - n: u32, - k: u32, - stride_01: u32, - stride_11: u32, - stride_02: u32, - stride_12: u32, - stride_03: u32, - stride_13: u32, - bs02: u32, - bs03: u32, - broadcast2: u32, - broadcast3: u32 -}; - -@group(0) @binding(0) var<storage, read_write> src0: array<{{SRC0_TYPE}}>; // Matrix (M x K) -@group(0) @binding(1) var<storage, read_write> src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed) -@group(0) @binding(2) var<storage, read_write> dst: array<{{DST_TYPE}}>; // Result vector (transposed) - -@group(0) @binding(3) var<uniform> params: MulMatParams; - -override WORKGROUP_SIZE: u32; -override TILE_K: u32; -override OUTPUTS_PER_WG: u32; -override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG; - -// Shared memory for collaborative loading and reduction -var<workgroup> shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile -var<workgroup> partial_sums: array<f32, WORKGROUP_SIZE>; // For reduction - -@compute @workgroup_size(WORKGROUP_SIZE) -fn main( - @builtin(local_invocation_id) local_id: vec3<u32>, - @builtin(workgroup_id) wg_id: vec3<u32>, - @builtin(num_workgroups) num_wg: vec3<u32>) { - let thread_id = local_id.x; - - // Handle batch dimensions - let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; - let batch_idx = wg_linear / output_groups; - if (batch_idx >= total_batches) { - return; - } - - // Which of the outputs does this thread belong to? - let thread_group = thread_id / THREADS_PER_OUTPUT; - let thread_in_group = thread_id % THREADS_PER_OUTPUT; - - // Each workgroup computes OUTPUTS_PER_WG consecutive outputs - let output_row = (wg_linear % output_groups) * OUTPUTS_PER_WG + thread_group; - - let dst2_stride = params.m * params.n; - let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); - let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); - let src03_idx = dst3_idx / params.broadcast3; - let src13_idx = dst3_idx; - let src02_idx = dst2_idx / params.broadcast2; - let src12_idx = dst2_idx; - - let src0_idx_base = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02 + output_row * params.stride_01; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; - let dst_idx = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + output_row; - - var local_sum = 0.0; - - // Each thread processes multiple K elements and accumulates - for (var k_tile = 0u; k_tile < params.k; k_tile += TILE_K) { - let tile_size = min(TILE_K, params.k - k_tile); - - // Cooperatively load vector tile into shared memory (all threads) - for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) { - shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}]; - } - - workgroupBarrier(); - - if (output_row < params.m) { - local_sum += mul_acc(thread_in_group, tile_size, src0_idx_base, k_tile); - } - - workgroupBarrier(); - } - - // Store partial sums and reduce within each partition - partial_sums[thread_id] = local_sum; - workgroupBarrier(); - let group_base = thread_group * THREADS_PER_OUTPUT; - let thread_base = group_base + thread_in_group; - var offset = THREADS_PER_OUTPUT / 2; - while (offset > 0) { - if (thread_in_group < offset) { - partial_sums[thread_base] += partial_sums[thread_base + offset]; - } - offset = offset / 2; - workgroupBarrier(); - } - - // Store back to global memory - if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) { - dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base); - } -} -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl new file mode 100644 index 00000000000..f0a7fbd059a --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -0,0 +1,151 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif +enable f16; + +#ifdef MMVQ +requires packed_4x8_integer_dot_product; +#endif + +#define DECLARE_BYTE_LOADERS_SRC0 +#include "common_decls.tmpl" + +#ifdef MMVQ +#include "mul_mat_vec_q_acc.tmpl" +#else +#include "mul_mat_vec_acc.tmpl" +#endif + +struct MulMatParams { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + m: u32, + n: u32, + k: u32, + stride_01: u32, + stride_11: u32, + stride_02: u32, + stride_12: u32, + stride_03: u32, + stride_13: u32, + bs02: u32, + bs03: u32, + broadcast2: u32, + broadcast3: u32 +}; + +@group(0) @binding(0) var<storage, read_write> src0: array<SRC0_TYPE>; + +#ifdef MMVQ +@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>; +#else +@group(0) @binding(1) var<storage, read_write> src1: array<SRC1_TYPE>; +#endif + +@group(0) @binding(2) var<storage, read_write> dst: array<f32>; +// "mul_mat_vec_acc.tmpl" requires params.k, params.m, params.stride_01 +@group(0) @binding(3) var<uniform> params: MulMatParams; + +// Flattened as [row][thread] to keep each row's reduction contiguous in memory. +var<workgroup> partial_sums: array<f32, OUTPUTS_PER_WG * WG_SIZE>; + +fn partial_index(row: u32, thread: u32) -> u32 { + return row * WG_SIZE + thread; +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3<u32>, + @builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32> +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32, + @builtin(subgroup_size) subgroup_size: u32 +#endif +) { + let thread_id = local_id.x; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let output_groups = (params.m + OUTPUTS_PER_WG - 1u) / OUTPUTS_PER_WG; + let batch_idx = wg_linear / output_groups; + if (batch_idx >= total_batches) { + return; + } + + let row_base = (wg_linear % output_groups) * OUTPUTS_PER_WG; + + let dst2_stride = params.m * params.n; + let dst2_idx = batch_idx % (params.bs02 * params.broadcast2); + let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; + let dst3_idx = batch_idx / (params.bs02 * params.broadcast2); + let src03_idx = dst3_idx / params.broadcast3; + let src13_idx = dst3_idx; + let src02_idx = dst2_idx / params.broadcast2; + let src12_idx = dst2_idx; + + let src0_batch_offset = params.offset_src0 + src03_idx * params.stride_03 + src02_idx * params.stride_02; + let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; + +#ifdef MMVQ + let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u); + let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base); +#else + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); +#endif + +#ifdef USE_SUBGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } + + workgroupBarrier(); + + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + row] = row_total; + } + } +#endif + +#ifdef USE_WORKGROUP_REDUCTION + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[row]; + } + + workgroupBarrier(); + + var stride = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2; + } + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } + } +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl new file mode 100644 index 00000000000..08753b9d643 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -0,0 +1,1432 @@ +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif + +#ifdef VEC +#define VEC_SIZE 4u +#define SRC0_TYPE vec4<SRC0_INNER_TYPE> +#define SRC1_TYPE vec4<SRC1_INNER_TYPE> + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(dot(SRC1_TYPE(src0_val), src1_val)); +} +#endif + +#ifdef SCALAR +#define VEC_SIZE 1u +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(src0_val) * f32(src1_val); +} +#endif + +#ifdef MUL_ACC_FLOAT +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let k_vec = params.k / VEC_SIZE; + let src1_idx_base_vec = src1_idx_base / VEC_SIZE; + + // Each thread walks K, loads from the vector, and updates + // a small block of output rows held in registers. + for (var k = thread_id; k < k_vec; k += WG_SIZE) { + let x = src1[src1_idx_base_vec + k]; + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; + acc[row] += inner_dot(src0[src0_idx], x); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q1_0 +#define BLOCK_SIZE 128 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 16 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array<f32, ELEMS_PER_THREAD>; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[bit]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array<f32, ELEMS_PER_THREAD>; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 20 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array<f32, ELEMS_PER_THREAD>; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 22 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array<f32, ELEMS_PER_THREAD>; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qh_packed = load_u32_at_src0(block_byte_base + 2u); + let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 24 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array<f32, ELEMS_PER_THREAD>; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + let qh_packed = load_u32_at_src0(block_byte_base + 4u); + let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); + let qh_shift = thread_within_block * 4u; + var row_sum = 0.0; + + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 34 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array<f32, ELEMS_PER_THREAD>; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q8_1 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 36 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; + var x_block: array<f32, ELEMS_PER_THREAD>; + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let m = f32(load_f16_at_src0(block_byte_base + 2u)); + var row_sum = 0.0; + + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; + row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 84 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let iq = lane / 4u; + let ir = lane % 4u; + let is = ir / 2u; + + let y_offset = 128u * iq + 8u * ir + 4u * phase; + let sc0_byte = 8u * iq + is; + let sc2_byte = 8u * iq + is + 2u; + let sc4_byte = 8u * iq + is + 4u; + let sc6_byte = 8u * iq + is + 6u; + let qs_byte = 16u + (16u * iq + 4u * ir) * 2u + 4u * phase; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 64u + i]); + x_block[i + 12u] = f32(src1[x_base + 96u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let dall = f32(load_f16_at_src0(block_byte_base + 80u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 82u)) * (1.0 / 16.0); + + let sc0 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc0_byte), sc0_byte & 3u); + let sc2 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc2_byte), sc2_byte & 3u); + let sc4 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc4_byte), sc4_byte & 3u); + let sc6 = byte_of(load_u32_at_src0_aligned(block_byte_base + sc6_byte), sc6_byte & 3u); + + let q_u32 = load_u32_at_src0_aligned(block_byte_base + qs_byte); + let qs0 = q_u32 & 0xFFFFu; + let qs1 = q_u32 >> 16u; + + var sumy = vec4<f32>(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4<f32>(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4<f32>(0.0, 0.0, 0.0, 0.0); + + sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; + sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; + sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; + sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + + acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + + acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q3_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let lane = tid / 2u; + let phase = tid % 2u; + let ip = lane / 4u; + let il = 2u * ((lane % 4u) / 2u); + let ir = lane % 2u; + let l0 = 8u * ir; + + let q_byte = 32u + 32u * ip + l0 + 16u * phase; + let h_byte = l0 + 16u * phase; + let y_offset = 128u * ip + 32u * il + l0 + 16u * phase; + + let s_shift1 = 4u * ip; + let s_shift2 = s_shift1 + il; + + let v1 = select(64.0, 4.0, il == 0u); + let v2 = 4.0 * v1; + let shift = 2u * il; + + var qm0: u32; var qm1: u32; var qm2: u32; var qm3: u32; + if (il == 0u) { + qm0 = 0x0003u; qm1 = 0x0300u; qm2 = 0x000Cu; qm3 = 0x0C00u; + } else { + qm0 = 0x0030u; qm1 = 0x3000u; qm2 = 0x00C0u; qm3 = 0xC000u; + } + + let mm_idx = 2u * ip + il / 2u; + var hm0: u32; var hm1: u32; var hm2: u32; var hm3: u32; + switch (mm_idx) { + case 0u: { hm0=0x0001u; hm1=0x0100u; hm2=0x0002u; hm3=0x0200u; } + case 1u: { hm0=0x0004u; hm1=0x0400u; hm2=0x0008u; hm3=0x0800u; } + case 2u: { hm0=0x0010u; hm1=0x1000u; hm2=0x0020u; hm3=0x2000u; } + default: { hm0=0x0040u; hm1=0x4000u; hm2=0x0080u; hm3=0x8000u; } + } + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 8u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 8u] = f32(src1[x_base + 32u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 108u)); + let a_base = 96u; + let a_il0 = load_u16_at_src0(block_byte_base + a_base + il * 2u); + let a_il1 = load_u16_at_src0(block_byte_base + a_base + (il + 1u) * 2u); + let a_4 = load_u16_at_src0(block_byte_base + a_base + 8u); + let a_5 = load_u16_at_src0(block_byte_base + a_base + 10u); + + var scales32 = a_4 | (a_5 << 16u); + let aux32 = ((scales32 >> s_shift2) << 4u) & 0x30303030u; + scales32 = a_il0 | (a_il1 << 16u); + scales32 = ((scales32 >> s_shift1) & 0x0F0F0F0Fu) | aux32; + + let scale0 = f32(i32(byte_of(scales32, phase + 0u)) - 32); + let scale1 = f32(i32(byte_of(scales32, phase + 2u)) - 32); + + let q_u32_0 = load_u32_at_src0(block_byte_base + q_byte + 0u); + let q_u32_1 = load_u32_at_src0(block_byte_base + q_byte + 4u); + let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); + let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); + + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + + s1 += x_block[l + 0u] * f32(qs & qm0); + s2 += x_block[l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[l + 1u], (hv & hm1) == 0u); + s4 += x_block[l + 8u] * f32(qs & qm2); + s5 += x_block[l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 144 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 32u * im + l0; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let scale0 = f32(sc16_0 & 0xFFu); + let scale1 = f32((sc16_0 >> 8u) & 0xFFu); + let min0 = f32(sc16_1 & 0xFFu); + let min1 = f32((sc16_1 >> 8u) & 0xFFu); + let scale2 = f32(sc16_2 & 0xFFu); + let scale3 = f32((sc16_2 >> 8u) & 0xFFu); + let min2 = f32(sc16_3 & 0xFFu); + let min3 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); + + var dot = vec4<f32>(0.0, 0.0, 0.0, 0.0); + var sumx = vec4<f32>(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[i] * f32(q1b & 0x0Fu); + dot[1] += x_block[i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[i]; + sumx[1] += x_block[i + 4u]; + sumx[2] += x_block[i + 8u]; + sumx[3] += x_block[i + 12u]; + } + + acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q5_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 176 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let il = tid / 4u; + let ir = tid % 4u; + let im = il / 2u; + let in = il % 2u; + let l0 = 4u * (2u * ir + in); + + let y_offset = 64u * im + l0; + let q_offset = 48u + 32u * im + l0; + let qh_offset = 16u + 8u * ir + 4u * in; + let sc0_byte = 4u + im * 2u; + let sc2_byte = 4u + (im + 2u) * 2u; + let sc4_byte = 4u + (im + 4u) * 2u; + + let hm1 = 1u << (2u * im); + let hm2 = hm1 << 1u; + let hm3 = hm1 << 4u; + let hm4 = hm2 << 4u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 4u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + 32u + i]); + x_block[i + 8u] = f32(src1[x_base + 128u + i]); + x_block[i + 12u] = f32(src1[x_base + 160u + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 0u)); + let dmin = f32(load_f16_at_src0(block_byte_base + 2u)); + + let sc0_u32 = load_u32_at_src0_aligned(block_byte_base + sc0_byte); + let sc0 = select(sc0_u32 & 0xFFFFu, sc0_u32 >> 16u, (sc0_byte & 2u) != 0u); + let sc2_u32 = load_u32_at_src0_aligned(block_byte_base + sc2_byte); + let sc2 = select(sc2_u32 & 0xFFFFu, sc2_u32 >> 16u, (sc2_byte & 2u) != 0u); + let sc4_u32 = load_u32_at_src0_aligned(block_byte_base + sc4_byte); + let sc4 = select(sc4_u32 & 0xFFFFu, sc4_u32 >> 16u, (sc4_byte & 2u) != 0u); + + let sc16_0 = sc0 & 0x3F3Fu; + let sc16_1 = sc2 & 0x3F3Fu; + let sc16_2 = (sc4 & 0x0F0Fu) | ((sc0 & 0xC0C0u) >> 2u); + let sc16_3 = ((sc4 >> 4u) & 0x0F0Fu) | ((sc2 & 0xC0C0u) >> 2u); + + let f0 = f32(sc16_0 & 0xFFu); + let f1 = f32((sc16_0 >> 8u) & 0xFFu); + let m0 = f32(sc16_1 & 0xFFu); + let m1 = f32((sc16_1 >> 8u) & 0xFFu); + let f4 = f32(sc16_2 & 0xFFu); + let f5 = f32((sc16_2 >> 8u) & 0xFFu); + let m4 = f32(sc16_3 & 0xFFu); + let m5 = f32((sc16_3 >> 8u) & 0xFFu); + + let q1_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset); + let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); + let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); + + var vals = vec4<f32>(0.0, 0.0, 0.0, 0.0); + var sumy = vec4<f32>(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); + + let yl0 = x_block[i]; + let yl8 = x_block[i + 4u]; + let yh0 = x_block[i + 8u]; + let yh8 = x_block[i + 12u]; + + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; + + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q6_K +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 210 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let ip = tid / 8u; + let il = tid % 8u; + let l0 = 4u * il; + let is = 8u * ip + l0 / 16u; + + let y_offset = 128u * ip + l0; + let q_offset_l = 64u * ip + l0; + let q_offset_h = 32u * ip + l0; + + let num_blocks = params.k / BLOCK_SIZE; + let sc_base_byte = 192u + (is & ~3u); + let sc_byte_pos = is & 3u; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var l = 0u; l < 4u; l++) { + x_block[l] = f32(src1[x_base + l]); + x_block[l + 4u] = f32(src1[x_base + 32u + l]); + x_block[l + 8u] = f32(src1[x_base + 64u + l]); + x_block[l + 12u] = f32(src1[x_base + 96u + l]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base + 208u)); + let ql1_u32 = load_u32_at_src0(block_byte_base + q_offset_l); + let ql2_u32 = load_u32_at_src0(block_byte_base + q_offset_l + 32u); + let qh_u32 = load_u32_at_src0(block_byte_base + 128u + q_offset_h); + let sc_u32_0 = load_u32_at_src0(block_byte_base + sc_base_byte); + let sc_u32_1 = load_u32_at_src0(block_byte_base + sc_base_byte + 4u); + + let sc0 = sbyte_of(sc_u32_0, sc_byte_pos); + let sc2 = sbyte_of(sc_u32_0, sc_byte_pos + 2u); + let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); + let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); + + var sums = vec4<f32>(0.0, 0.0, 0.0, 0.0); + + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); + + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + + sums[0] += x_block[l] * dq0; + sums[1] += x_block[l + 4u] * dq1; + sums[2] += x_block[l + 8u] * dq2; + sums[3] += x_block[l + 12u] * dq3; + } + + acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ1_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 50 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let d = f32(load_f16_at_src0(block_byte_base)); + let qh = load_u32_at_src0(block_byte_base + 34u + sub_blk * 2u) & 0xFFFFu; + let dl = d * f32(2u * ((qh >> 12u) & 7u) + 1u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ1_M +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 56 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + + let sc_lo = load_u32_at_src0(block_byte_base + 48u); + let sc_hi = load_u32_at_src0(block_byte_base + 52u); + let sc0 = sc_lo & 0xFFFFu; + let sc1 = (sc_lo >> 16u) & 0xFFFFu; + let sc2 = sc_hi & 0xFFFFu; + let sc3 = (sc_hi >> 16u) & 0xFFFFu; + let d_bits = (sc0 >> 12u) | ((sc1 >> 8u) & 0xF0u) | ((sc2 >> 4u) & 0xF00u) | (sc3 & 0xF000u); + let d = f32(bitcast<vec2<f16>>(d_bits)[0]); + + let sc_u16 = select(select(sc2, sc3, sub_blk >= 6u), + select(sc0, sc1, sub_blk >= 2u), + sub_blk < 4u); + + let qs_w = load_u32_at_src0(block_byte_base + sub_blk * 4u); + let qh = load_u32_at_src0(block_byte_base + 32u + sub_blk * 2u) & 0xFFFFu; + let qh_lo = qh & 0xFFu; + let qh_hi = (qh >> 8u) & 0xFFu; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = (sc_u16 >> bit_off) & 0x7u; + let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 66 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let aux_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let aux_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let ls = aux_hi >> 28u; + let db = d * (0.5 + f32(ls)) * 0.25; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 74 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(scales_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xs_grid[grid_idx * 2u]; + let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ2_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 82 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); + let sg_w = load_u32_at_src0(block_byte_base + 34u + sub_blk * 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); + let scales_byte = get_byte(sc_word, sub_blk % 4u); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[ll * 8u + j]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ3_XXS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 98 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let aux = load_u32_at_src0(block_byte_base + 66u + sub_blk * 4u); + let ls = aux >> 28u; + let db = d * (0.5 + f32(ls)) * 0.5; + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ3_S +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 110 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let slot0 = half * 2u; + let y_offset = sub_blk * 32u + slot0 * 8u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let qs_lo = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u); + let qs_hi = load_u32_at_src0(block_byte_base + 2u + sub_blk * 8u + 4u); + let qh_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); + let qh_byte = get_byte(qh_word, sub_blk % 4u); + let sg_w = load_u32_at_src0(block_byte_base + 74u + sub_blk * 4u); + let sc_word = load_u32_at_src0(block_byte_base + 106u); + let scales_byte = get_byte(sc_word, sub_blk / 2u); + let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let db = d * (1.0 + 2.0 * f32(sub_scale)); + + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + } + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ4_NL +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 18 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % THREADS_PER_BLOCK; + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; + var x_block: array<f32, ELEMS_PER_THREAD>; + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4u] = f32(src1[x_base + i + 16u]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + var row_sum = 0.0; + + let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_IQ4_XS +#define BLOCK_SIZE 256 +#define BLOCK_SIZE_BYTES 136 +#define THREADS_PER_BLOCK 16 +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + let block_group = thread_id / THREADS_PER_BLOCK; + let num_block_groups: u32 = WG_SIZE / THREADS_PER_BLOCK; + + let sub_blk = tid / 2u; + let half = tid % 2u; + let y_offset = sub_blk * 32u + half * 16u; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = block_group; block < num_blocks; block += num_block_groups) { + let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; + var x_block: array<f32, 16>; + for (var i = 0u; i < 16u; i++) { + x_block[i] = f32(src1[x_base + i]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let d = f32(load_f16_at_src0(block_byte_base)); + let scales_h = load_u16_at_src0(block_byte_base + 2u); + let scales_l_word = load_u32_at_src0(block_byte_base + 4u); + let sl_byte = get_byte(scales_l_word, sub_blk / 2u); + let sl = (sl_byte >> (4u * (sub_blk % 2u))) & 0xFu; + let sh_bits = (scales_h >> (2u * sub_blk)) & 3u; + let ls = i32(sl | (sh_bits << 4u)); + let dl = d * f32(ls - 32); + + let qs_byte_off = 8u + sub_blk * 16u; + let q_w0 = load_u32_at_src0(block_byte_base + qs_byte_off); + let q_w1 = load_u32_at_src0(block_byte_base + qs_byte_off + 4u); + let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); + let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); + + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_MXFP4 +#define BLOCK_SIZE 32 +#define BLOCK_SIZE_BYTES 17 +#define THREADS_PER_BLOCK 4 +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + let thread_within_block = thread_id % 4; + for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { + let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; + var x_block: array<f32, ELEMS_PER_THREAD>; + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[i] = f32(src1[x_base + i]); + x_block[i + 4] = f32(src1[x_base + i + 16]); + } + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); + let e = ldexp(1.0, i32(eu8) - 128); + var row_sum = 0.0; + let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block); + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; + row_sum += q_lo * x_block[byte_idx]; + row_sum += q_hi * x_block[byte_idx + 4u]; + } + acc[row] += row_sum; + } + } + } + + return acc; +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl new file mode 100644 index 00000000000..3ef2f77ebe0 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl @@ -0,0 +1,303 @@ +#ifdef U32_DEQUANT_HELPERS +#define SRC0_TYPE u32 + +fn byte_of(v: u32, b: u32) -> u32 { + return (v >> (b * 8u)) & 0xFFu; +} + +fn sbyte_of(v: u32, b: u32) -> i32 { + let raw = i32((v >> (b * 8u)) & 0xFFu); + return select(raw, raw - 256, raw >= 128); +} +#endif + +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +#ifdef LEGACY_QUANTS +#define BLOCK_SIZE 32 +#define THREADS_PER_BLOCK 4 +#elif K_QUANTS +#define BLOCK_SIZE 256 +#define THREADS_PER_BLOCK 16 +#endif + +#define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) +#define Q8_BLOCK_SIZE 32 + +#ifdef MUL_ACC_Q4_0 +#define BLOCK_SIZE_BYTES 18 +#define B_DS_TYPE vec2<f32> +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> { + let qs_packed = load_u32_at_src0(block_byte_base + 2u + 4u * inner_id); + + return vec2<u32>( + qs_packed & 0x0F0F0F0Fu, + (qs_packed >> 4u) & 0x0F0F0F0Fu + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> { + return vec2<u32>( + src1q[block].qs[inner_id], + src1q[block].qs[inner_id + 4u], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[block].d), + f32(src1q[block].s) + ); +} +fn get_dm(block_byte_base: u32) -> f32 { + return f32(load_f16_at_src0(block_byte_base)); +} +fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; +} +#endif + +#ifdef MUL_ACC_Q4_1 +#define BLOCK_SIZE_BYTES 20 +#define B_DS_TYPE vec2<f32> +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> { + let qs_packed = load_u32_at_src0(block_byte_base + 4u + 4u * inner_id); + + return vec2<u32>( + qs_packed & 0x0F0F0F0Fu, + (qs_packed >> 4u) & 0x0F0F0F0Fu + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> { + return vec2<u32>( + src1q[block].qs[inner_id], + src1q[block].qs[inner_id + 4u], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[block].d), + f32(src1q[block].s) + ); +} +fn get_dm(block_byte_base: u32) -> vec2<f32> { + return vec2<f32>( + f32(load_f16_at_src0(block_byte_base)), + f32(load_f16_at_src0(block_byte_base + 2u)) + ); +} +fn mul_q8_1(row_sum: i32, dma: vec2<f32>, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK; +} +#endif + +#ifdef MUL_ACC_Q8_0 +#define BLOCK_SIZE_BYTES 34 +#define B_DS_TYPE f32 +fn repack_a(block_byte_base: u32, inner_id: u32) -> vec2<u32> { + return vec2<u32>( + load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u)), + load_u32_at_src0(block_byte_base + 2u + 4u * (inner_id * 2u + 1)) + ); +} +fn repack_b_qs(block:u32, inner_id: u32) -> vec2<u32> { + return vec2<u32>( + src1q[block].qs[inner_id * 2u], + src1q[block].qs[inner_id * 2u + 1], + ); +} +fn repack_b_dm(block: u32) -> B_DS_TYPE { + return B_DS_TYPE(src1q[block].d); +} +fn get_dm(block_byte_base: u32) -> f32 { + return f32(load_f16_at_src0(block_byte_base)); +} +fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { + return f32(row_sum) * (da * b_ds); +} +#endif + +#ifdef LEGACY_QUANTS +fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2<u32>, b_ds: B_DS_TYPE) -> f32 { + var row_sum = 0; + let a_repacked = repack_a(a_byte_base, b_inner_id); + + row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]); + row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]); + + return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds); +} + +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let num_blocks = params.k / BLOCK_SIZE; + + for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { + let b_inner_id = thread_id % THREADS_PER_BLOCK; + let b_block_idx = src1q_idx_base + block; + + let b_repacked = repack_b_qs(b_block_idx, b_inner_id); + let b_ds = repack_b_dm(b_block_idx); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds); + } + } + } + + return acc; +} +#endif + +#ifdef MUL_ACC_Q2_K +#define BLOCK_SIZE_BYTES 84 +#define B_DS_TYPE f32 +fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> { + let ih2 = tid / 8u; + let phase = tid % 2u; + let iq4_idx = 2u * ih2 + phase; + let qs_byte_base = block_byte_base + 16u + 16u * iq4_idx; + let qs_shift = tid & 6u; + return vec4<u32>( + (load_u32_at_src0_aligned(qs_byte_base) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 4u) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 8u) >> qs_shift) & 0x03030303u, + (load_u32_at_src0_aligned(qs_byte_base + 12u) >> qs_shift) & 0x03030303u, + ); +} +fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> { + let phase = tid % 2u; + return vec4<u32>( + src1q[q8_block_idx].qs[4u * phase], + src1q[q8_block_idx].qs[4u * phase + 1u], + src1q[q8_block_idx].qs[4u * phase + 2u], + src1q[q8_block_idx].qs[4u * phase + 3u], + ); +} +fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE { + return B_DS_TYPE(src1q[q8_block_idx].d); +} +fn get_dm(block_byte_base: u32) -> vec2<f32> { + return vec2<f32>( + f32(load_f16_at_src0(block_byte_base + 80u)), + f32(load_f16_at_src0(block_byte_base + 82u)), + ); +} +fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> { + let scale_byte = block_byte_base + tid; + let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u); + return vec2<f32>(f32(scale & 0xFu), f32(scale >> 4u)); +} +fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 { + let a_repacked = repack_a(a_byte_base, tid); + let dm = get_dm(a_byte_base); + let scale_min = get_scale_min(a_byte_base, tid); + + let scale_q = i32(scale_min.x); + let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; + + let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) + + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; + let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) + + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); + + return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); +} +#endif + +#ifdef MUL_ACC_Q4_K +#define BLOCK_SIZE_BYTES 144 +#define B_DS_TYPE vec2<f32> +fn repack_a(block_byte_base: u32, tid: u32) -> vec4<u32> { + let iq4 = tid / 4u; + let phase = tid % 2u; + let nibble = (tid >> 1u) % 2u; + let q_qs_byte_base = block_byte_base + 16u + 32u * iq4 + 16u * phase; + let qs_shift = 4u * nibble; + return vec4<u32>( + (load_u32_at_src0_aligned(q_qs_byte_base) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 4u) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 8u) >> qs_shift) & 0x0F0F0F0Fu, + (load_u32_at_src0_aligned(q_qs_byte_base + 12u) >> qs_shift) & 0x0F0F0F0Fu, + ); +} +fn repack_b_qs(q8_block_idx: u32, tid: u32) -> vec4<u32> { + let phase = tid % 2u; + return vec4<u32>( + src1q[q8_block_idx].qs[4u * phase], + src1q[q8_block_idx].qs[4u * phase + 1u], + src1q[q8_block_idx].qs[4u * phase + 2u], + src1q[q8_block_idx].qs[4u * phase + 3u], + ); +} +fn repack_b_dm(q8_block_idx: u32) -> B_DS_TYPE { + return B_DS_TYPE( + f32(src1q[q8_block_idx].d), + f32(src1q[q8_block_idx].s), + ); +} +fn get_dm(block_byte_base: u32) -> vec2<f32> { + return vec2<f32>( + f32(load_f16_at_src0(block_byte_base + 0u)), + f32(load_f16_at_src0(block_byte_base + 2u)), + ); +} +fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2<f32> { + let sc_m_idx = tid / 2u; + let scales_byte_base = block_byte_base + 4u; + let scales0_3 = load_u32_at_src0_aligned(scales_byte_base); + let scales4_7 = load_u32_at_src0_aligned(scales_byte_base + 4u); + let scales8_11 = load_u32_at_src0_aligned(scales_byte_base + 8u); + + let byte_idx = sc_m_idx & 3u; + let is_high = sc_m_idx >= 4u; + + let sc_low = byte_of(scales0_3, byte_idx) & 0x3Fu; + let sc_high = (byte_of(scales8_11, byte_idx) & 0x0Fu) | ((byte_of(scales0_3, byte_idx) & 0xC0u) >> 2u); + let scale = f32(select(sc_low, sc_high, is_high)); + + let mn_low = byte_of(scales4_7, byte_idx) & 0x3Fu; + let mn_high = (byte_of(scales8_11, byte_idx) >> 4u) | ((byte_of(scales4_7, byte_idx) & 0xC0u) >> 2u); + let min_val = f32(select(mn_low, mn_high, is_high)); + + return vec2<f32>(scale, min_val); +} +fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4<u32>, b_ds: B_DS_TYPE) -> f32 { + let a_repacked = repack_a(a_byte_base, tid); + let dm = get_dm(a_byte_base); + let scale_min = get_scale_min(a_byte_base, tid); + + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) + + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); + + // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. + return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); +} +#endif + +#ifdef K_QUANTS +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array<f32, OUTPUTS_PER_WG> { + var acc: array<f32, OUTPUTS_PER_WG>; + + let tid = thread_id % THREADS_PER_BLOCK; + + for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) { + let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; + let b_repacked = repack_b_qs(src1q_idx, tid); + let b_ds = repack_b_dm(src1q_idx); + + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let output_row = row_base + row; + if (output_row < params.m) { + let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; + acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds); + } + } + } + + return acc; +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl new file mode 100644 index 00000000000..ea63b9a731c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl @@ -0,0 +1,86 @@ +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<f32>; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + src_ne3: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32, + dst_ne3: u32, + + // Pad sizes (in elements) + lp0: u32, + rp0: u32, + lp1: u32, + rp1: u32, + lp2: u32, + rp2: u32, + lp3: u32, + rp3: u32, +}; + +@group(0) @binding(2) +var<uniform> params: Params; + +fn wrap_around(idx: i32, n: u32) -> u32 { + return u32(idx + i32(n)) % n; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let dst_plane = params.dst_ne2 * params.dst_ne1 * params.dst_ne0; + let i3 = i / dst_plane; + i = i % dst_plane; + let i2 = i / (params.dst_ne1 * params.dst_ne0); + i = i % (params.dst_ne1 * params.dst_ne0); + let i1 = i / params.dst_ne0; + let i0 = i % params.dst_ne0; + + var value: f32 = 0.0; + +#ifdef CIRCULAR + let ci0 = wrap_around(i32(i0) - i32(params.lp0), params.src_ne0); + let ci1 = wrap_around(i32(i1) - i32(params.lp1), params.src_ne1); + let ci2 = wrap_around(i32(i2) - i32(params.lp2), params.src_ne2); + let ci3 = wrap_around(i32(i3) - i32(params.lp3), params.src_ne3); + let circular_src_idx = ci0 * params.stride_src0 + ci1 * params.stride_src1 + + ci2 * params.stride_src2 + ci3 * params.stride_src3; + value = src[params.offset_src + circular_src_idx]; +#else + let is_src = + (i0 >= params.lp0 && i0 < params.dst_ne0 - params.rp0) && + (i1 >= params.lp1 && i1 < params.dst_ne1 - params.rp1) && + (i2 >= params.lp2 && i2 < params.dst_ne2 - params.rp2) && + (i3 >= params.lp3 && i3 < params.dst_ne3 - params.rp3); + if (is_src) { + let src_idx = (i0 - params.lp0) * params.stride_src0 + (i1 - params.lp1) * params.stride_src1 + + (i2 - params.lp2) * params.stride_src2 + (i3 - params.lp3) * params.stride_src3; + value = src[params.offset_src + src_idx]; + } +#endif + + dst[params.offset_dst + gid.x] = value; +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl new file mode 100644 index 00000000000..d1da4608434 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl @@ -0,0 +1,21 @@ +#ifdef U32_DEQUANT_HELPERS +fn dequant_q4_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) { + let scale = QUANT_OUT_TYPE(d); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte(q_packed, k); + let q_hi = (QUANT_OUT_TYPE((q_byte >> 4) & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale; + let q_lo = (QUANT_OUT_TYPE(q_byte & 0xFu) - QUANT_OUT_TYPE(8.0)) * scale; + QUANT_SHMEM[dst_idx + k] = q_lo; + QUANT_SHMEM[dst_idx + k + 16u] = q_hi; + } +} + +fn dequant_q8_0_packed_to_shmem(q_packed: u32, d: f16, dst_idx: u32) { + let scale = QUANT_OUT_TYPE(d); + for (var k = 0u; k < 4u; k++) { + let q_byte = get_byte_i32(q_packed, k); + let q_val = QUANT_OUT_TYPE(q_byte) * scale; + QUANT_SHMEM[dst_idx + k] = q_val; + } +} +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl new file mode 100644 index 00000000000..b3f1fa04b80 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl @@ -0,0 +1,173 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif +enable f16; + +requires packed_4x8_integer_dot_product; + +#include "common_decls.tmpl" + +struct Params { + offset_src1: u32, + stride_12: u32, + stride_13: u32, + ne0: u32, + ne2: u32, + ne3: u32, +}; + +#define SRC1_TYPE vec4<SRC1_INNER_TYPE> + +@group(0) @binding(0) var<storage, read_write> src1: array<SRC1_TYPE>; +@group(0) @binding(1) var<storage, read_write> src1q: array<q8_1>; + +@group(0) @binding(2) var<uniform> params: Params; + +#ifdef USE_SUBGROUP_REDUCTION +fn cluster_max_8(v: f32) -> f32 { + var r = v; + r = max(r, subgroupShuffleXor(r, 1u)); + r = max(r, subgroupShuffleXor(r, 2u)); + r = max(r, subgroupShuffleXor(r, 4u)); + return r; +} + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) +fn cluster_add_i4x8(v: i32) -> i32 { + var r= v; + r += subgroupShuffleXor(r, 1u); + r += subgroupShuffleXor(r, 2u); + r += subgroupShuffleXor(r, 4u); + return r; +} +#endif +#endif + +#ifdef USE_WORKGROUP_REDUCTION +#define CLUSTER_SIZE 8 + +var<workgroup> partial_amaxs: array<array<f32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>; +var<workgroup> partial_sums: array<array<i32, CLUSTER_SIZE>, WG_SIZE / CLUSTER_SIZE>; +#endif + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3<u32>, + @builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32> +) { + let thread_id = local_id.x; + let num_vec4 = params.ne0 / 4u; + + let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE; + let total_batches = wg_per_vec * params.ne2 * params.ne3; + + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + if (wg_linear >= total_batches) { + return; + } + + let src13_idx = wg_linear / (params.ne2 * wg_per_vec); + let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec; + let src11_wg_idx = wg_linear % wg_per_vec; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let src1_idx_vec4_base = src1_idx_base / 4u; + + let blocks_per_row = params.ne0 / 32u; + let blocks_per_wg = (WG_SIZE * 4u) / 32u; + let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row; + let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u; + let qs_idx = thread_id % 8u; + + // reduction + var q4 = vec4<f32>(0.0); + var q4_quants = 0u; + var thread_amax = 0.0; + + let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id; + let is_valid = src11_vec4_idx < num_vec4; + +#ifdef USE_SUBGROUP_REDUCTION + + var d = 0.0; + + if (is_valid) { + q4 = src1[src1_idx_vec4_base + src11_vec4_idx]; + let abs_q4 = abs(q4); + thread_amax = max(max(abs_q4[0u], abs_q4[1u]), max(abs_q4[2], abs_q4[3])); + } + + d = cluster_max_8(thread_amax) / 127.0; + + if (is_valid) { + let id = select(0.0, 1.0 / d, d > 0.0); + q4_quants = pack4xI8(vec4<i32>(round(q4 * id))); + if (qs_idx == 0u) { + src1q[src1q_idx].d = f16(d); + } + src1q[src1q_idx].qs[qs_idx] = q4_quants; + } + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) + let q4_quants_sum = dot4I8Packed(q4_quants, 0x01010101u); + let s = f16(d * f32(cluster_add_i4x8(q4_quants_sum))); + + if (is_valid) { + if (qs_idx == 0u) { + src1q[src1q_idx].s = s; + } + } +#endif +#endif + +#ifdef USE_WORKGROUP_REDUCTION + + var d = 0.0; + let cluster_id = thread_id / 8u; + + if (is_valid) { + q4 = src1[src1_idx_vec4_base + src11_vec4_idx]; + let abs_q4 = abs(q4); + thread_amax = max(max(abs_q4[0], abs_q4[1]), max(abs_q4[2], abs_q4[3])); + partial_amaxs[cluster_id][qs_idx] = thread_amax; + } + + workgroupBarrier(); + + if (is_valid) { + let amax = max( + max( + max(partial_amaxs[cluster_id][0], partial_amaxs[cluster_id][1]), max(partial_amaxs[cluster_id][2], partial_amaxs[cluster_id][3])), + max( + max(partial_amaxs[cluster_id][4], partial_amaxs[cluster_id][5]), max(partial_amaxs[cluster_id][6], partial_amaxs[cluster_id][7])) + ); + + d = amax / 127.0; + let id = select(0.0f, 1.0f / d, d > 0.0f); + + q4_quants = pack4xI8(vec4<i32>(round(q4 * id))); + src1q[src1q_idx].qs[qs_idx] = q4_quants; + + if (qs_idx == 0u) { + src1q[src1q_idx].d = f16(d); + } + } + +#if defined(MUL_ACC_Q4_0) || defined(MUL_ACC_Q4_1) || defined(MUL_ACC_Q4_K) + + partial_sums[cluster_id][qs_idx] = dot4I8Packed(q4_quants, 0x01010101u); + + workgroupBarrier(); + + if (is_valid) { + if (qs_idx == 0u) { + let s = d * f32(partial_sums[cluster_id][0] + partial_sums[cluster_id][1] + partial_sums[cluster_id][2] + partial_sums[cluster_id][3] + + partial_sums[cluster_id][4] + partial_sums[cluster_id][5] + partial_sums[cluster_id][6] + partial_sums[cluster_id][7]); + src1q[src1q_idx].s = f16(s); + } + } + +#endif +#endif + +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl new file mode 100644 index 00000000000..6e2a1a8b614 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl @@ -0,0 +1,67 @@ +enable f16; + +struct Params { + ne: u32, + + offset_src0: u32, + offset_dst: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, + + a_ne0: u32, + a_ne1: u32, + a_ne2: u32, + a_ne3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, +}; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_I32 +#define DataType i32 +#endif +#ifdef TYPE_I16 +// same size (16-bit) is sufficient for repeat +#define DataType f16 +#endif + +@group(0) @binding(0) +var<storage, read_write> src0: array<DataType>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(2) +var<uniform> params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x < params.ne) { + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let a_i0 = i0 % params.a_ne0; + let a_i1 = i1 % params.a_ne1; + let a_i2 = i2 % params.a_ne2; + let a_i3 = i3 % params.a_ne3; + + let a_index = a_i0 * params.stride_src0_0 + + a_i1 * params.stride_src0_1 + + a_i2 * params.stride_src0_2 + + a_i3 * params.stride_src0_3; + + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + a_index]; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl deleted file mode 100644 index 712b921f1ab..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ /dev/null @@ -1,123 +0,0 @@ -#define(VARIANTS) - -[ - { - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_SUFFIX": "inplace", - "DECLS": ["INPLACE"] - }, -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) - -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - dst[dst_offset] = scale * src[src_offset]; -} - -@group(0) @binding(1) -var<storage, read_write> dst: array<f32>; - -@group(0) @binding(2) -var<uniform> params: Params; - -#enddecl(NOT_INPLACE) - -#decl(INPLACE) - -fn update(src_offset: u32, dst_offset: u32, scale: f32) { - src[dst_offset] = scale * src[src_offset]; -} - -@group(0) @binding(1) -var<uniform> params: Params; - -#enddecl(INPLACE) - -#end(DECLS) - -#define(SHADER) - -struct Params { - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Shape of src/dst - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, - - eps: f32 -}; - -@group(0) @binding(0) -var<storage, read_write> src: array<f32>; - -DECLS - -override wg_size: u32; -var<workgroup> scratch: array<f32, wg_size>; - -@compute @workgroup_size(wg_size) -fn main(@builtin(workgroup_id) wid: vec3<u32>, - @builtin(local_invocation_id) lid: vec3<u32>) { - - // one thread per row - var i = wid.x; - let i3 = i / (params.ne2 * params.ne1); - i = i % (params.ne2 * params.ne1); - let i2 = i / params.ne1; - let i1 = i % params.ne1; - let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; - let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - - let elems = (params.ne0 + wg_size - 1) / wg_size; - - var sum = 0.0f; - var col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - sum += pow(src[i_src_row + col], 2.0); - col += wg_size; - } - - scratch[lid.x] = sum; - workgroupBarrier(); - var offset = wg_size / 2; - while (offset > 0) { - if (lid.x < offset) { - scratch[lid.x] += scratch[lid.x + offset]; - } - offset = offset / 2; - workgroupBarrier(); - } - sum = scratch[0]; - - let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); - col = lid.x; - for (var j: u32 = 0; j < elems; j++) { - if (col >= params.ne0) { - break; - } - update(i_src_row + col, i_dst_row + col, scale); - col += wg_size; - } -} -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl new file mode 100644 index 00000000000..fd20a4e54c9 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl @@ -0,0 +1,152 @@ +#ifdef OVERLAP + +@group(0) @binding(0) +var<storage, read_write> rn_src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> mul_src: array<f32>; + +@group(0) @binding(2) +var<uniform> params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + mul_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#elif INPLACE + +@group(0) @binding(0) +var<storage, read_write> rn_src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> mul_src: array<f32>; + +@group(0) @binding(2) +var<uniform> params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + rn_src[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#elif SRC_OVERLAP + +@group(0) @binding(0) +var<storage, read_write> merged_src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<f32>; + +@group(0) @binding(2) +var<uniform> params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + dst[dst_offset] = scale * merged_src[rn_src_offset] * merged_src[mul_src_offset]; +} + +#else + +@group(0) @binding(0) +var<storage, read_write> rn_src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> mul_src: array<f32>; + +@group(0) @binding(2) +var<storage, read_write> dst: array<f32>; + +@group(0) @binding(3) +var<uniform> params: Params; + +fn update(rn_src_offset: u32, dst_offset: u32, scale: f32, mul_src_offset: u32) { + dst[dst_offset] = scale * rn_src[rn_src_offset] * mul_src[mul_src_offset]; +} + +#endif + +struct Params { + offset_rn_src: u32, + offset_mul_src: u32, + offset_dst: u32, + + stride_rn_src1: u32, + stride_rn_src2: u32, + stride_rn_src3: u32, + + stride_mul_src1: u32, + stride_mul_src2: u32, + stride_mul_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + mul_src_ne0: u32, + mul_src_ne1: u32, + mul_src_ne2: u32, + mul_src_ne3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + eps: f32 +}; + +var<workgroup> scratch: array<f32, WG_SIZE>; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3<u32>, + @builtin(local_invocation_id) lid: vec3<u32>) { + + // one thread per row + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_rn_src_row = params.offset_rn_src + i3 * params.stride_rn_src3 + i2 * params.stride_rn_src2 + i1 * params.stride_rn_src1; + let i_mul_src_row = params.offset_mul_src + (i3 % params.mul_src_ne3) * params.stride_mul_src3 + (i2 % params.mul_src_ne2) * params.stride_mul_src2 + (i1 % params.mul_src_ne1) * params.stride_mul_src1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + + var sum = 0.0f; + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } +#ifdef SRC_OVERLAP + sum += pow(merged_src[i_rn_src_row + col], 2.0); +#else + sum += pow(rn_src[i_rn_src_row + col], 2.0); +#endif + col += WG_SIZE; + } + + scratch[lid.x] = sum; + + workgroupBarrier(); + + var offset: u32 = WG_SIZE / 2; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset = offset / 2; + workgroupBarrier(); + } + sum = scratch[0]; + + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); + + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + update(i_rn_src_row + col, i_dst_row + col, scale, i_mul_src_row + col % params.mul_src_ne0); + col += WG_SIZE; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl similarity index 73% rename from ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl index 84dc8dbff61..1c874e14240 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rope.wgsl @@ -1,138 +1,12 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f32_ff", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f32_ff_inplace", - "REPLS": { - "TYPE" : "f32", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - }, - { - "SHADER_SUFFIX": "f16_ff", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] - }, - { - "SHADER_SUFFIX": "f16_ff_inplace", - "REPLS": { - "TYPE" : "f16", - }, - "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(ROTATE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - dst[i_dst0] = {{TYPE}}(out0); - dst[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE) - -#decl(ROTATE_INPLACE) -fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { - src0[i_dst0] = {{TYPE}}(out0); - src0[i_dst1] = {{TYPE}}(out1); -} -#enddecl(ROTATE_INPLACE) - -#decl(NO_FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return 1.0f; -} -#enddecl(NO_FF_FUNC) - -#decl(FF_FUNC) -fn freq_factor(i: u32) -> f32 { - return src2[params.offset_src2 + i/2]; -} -#enddecl(FF_FUNC) - -#decl(NO_FF_BINDINGS) - -@group(0) @binding(2) -var<storage, read_write> dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var<uniform> params: Params; - -#enddecl(NO_FF_BINDINGS) - -#decl(NO_FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var<uniform> params: Params; - -#enddecl(NO_FF_BINDINGS_INPLACE) - -#decl(FF_BINDINGS) - -@group(0) @binding(2) -var<storage, read_write> src2: array<f32>; - -@group(0) @binding(3) -var<storage, read_write> dst: array<{{TYPE}}>; - -@group(0) @binding(4) -var<uniform> params: Params; - -#enddecl(FF_BINDINGS) - -#decl(FF_BINDINGS_INPLACE) - -@group(0) @binding(2) -var<storage, read_write> src2: array<f32>; - -@group(0) @binding(3) -var<uniform> params: Params; - -#enddecl(FF_BINDINGS_INPLACE) - -#end(DECLS) - -#define(SHADER) - enable f16; +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_F16 +#define DataType f16 +#endif + struct Params { offset_src0: u32, offset_src1: u32, @@ -168,12 +42,69 @@ struct Params { }; @group(0) @binding(0) -var<storage, read_write> src0: array<{{TYPE}}>; - +var<storage, read_write> src0: array<DataType>; @group(0) @binding(1) var<storage, read_write> src1: array<i32>; -DECLS +#ifdef INPLACE + +#ifdef FF_FUNC + +@group(0) @binding(2) +var<storage, read_write> src2: array<f32>; + +@group(0) @binding(3) +var<uniform> params: Params; + +#else + +@group(0) @binding(2) +var<uniform> params: Params; + +#endif + +#else + +#ifdef FF_FUNC +@group(0) @binding(2) +var<storage, read_write> src2: array<f32>; + +@group(0) @binding(3) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(4) +var<uniform> params: Params; + +#else +@group(0) @binding(2) +var<storage, read_write> dst: array<DataType>; + +@group(0) @binding(3) +var<uniform> params: Params; +#endif +#endif + +#ifdef FF_FUNC +fn freq_factor(i: u32) -> f32 { + return src2[params.offset_src2 + i/2]; +} + +#else +fn freq_factor(i: u32) -> f32 { + return 1.0f; +} +#endif +#ifdef INPLACE +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + src0[i_dst0] = DataType(out0); + src0[i_dst1] = DataType(out1); +} +#else +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + dst[i_dst0] = DataType(out0); + dst[i_dst1] = DataType(out1); +} +#endif fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { let y = (f32(i / 2) - low) / max(0.001f, high - low); @@ -184,7 +115,7 @@ fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { // TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row fn rope_yarn(theta_extrap: f32, i: u32) -> vec2<f32> { var mscale = params.attn_factor; - var theta = params.freq_scale * theta_extrap; + var theta = params.freq_scale * theta_extrap; if (params.ext_factor != 0.0f) { let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; @@ -211,10 +142,9 @@ fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { } } -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - // two elements per thread + // two elements per n_threads if (gid.x >= params.n_threads) { return; } @@ -290,6 +220,5 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let x0 = f32(src0[i_src]); let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]); rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x); -} -#end(SHADER) +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl new file mode 100644 index 00000000000..5eaf5e7bbe5 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -0,0 +1,153 @@ +#if defined(SRC_F16) || defined(DST_F16) +enable f16; +#endif + +#ifdef SRC_F16 +#define SRC_TYPE f16 +#else +#define SRC_TYPE f32 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src/dst + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + eps: f32 +}; + +@group(0) @binding(0) +var<storage, read_write> src: array<SRC_TYPE>; + +#ifdef INPLACE +@group(0) @binding(1) +var<uniform> params: Params; +#else +@group(0) @binding(1) +var<storage, read_write> dst: array<DST_TYPE>; + +@group(0) @binding(2) +var<uniform> params: Params; +#endif + +var<workgroup> scratch: array<f32, WG_SIZE * 2u>; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3<u32>, + @builtin(local_invocation_id) lid: vec3<u32>) { + + // one thread per row + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; + + var sum = 0.0f; + var col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let v = f32(src[i_src_row + col]); +#ifdef NORM + sum += v; +#else + sum += v * v; +#endif + col += WG_SIZE; + } + + scratch[lid.x] = sum; + workgroupBarrier(); + + var offset: u32 = WG_SIZE / 2u; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset /= 2u; + workgroupBarrier(); + } + sum = scratch[0]; + +#ifdef NORM + let mean = sum / f32(params.ne0); + var sq_sum = 0.0f; + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let v = f32(src[i_src_row + col]); + let d = v - mean; + sq_sum += d * d; + col += WG_SIZE; + } + + workgroupBarrier(); + scratch[lid.x] = sq_sum; + workgroupBarrier(); + offset = WG_SIZE / 2u; + while (offset > 0) { + if (lid.x < offset) { + scratch[lid.x] += scratch[lid.x + offset]; + } + offset /= 2u; + workgroupBarrier(); + } + + let variance = scratch[0] / f32(params.ne0); + let scale = 1.0 / sqrt(variance + params.eps); +#elif defined(RMS_NORM) + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); +#elif defined(L2_NORM) + let scale = 1.0/max(sqrt(sum), params.eps); +#endif + +#ifdef NORM + let mean_val = mean; +#else + let mean_val = 0.0f; +#endif + + col = lid.x; + for (var j: u32 = 0; j < elems; j++) { + if (col >= params.ne0) { + break; + } + let i_src = i_src_row + col; + let i_dst = i_dst_row + col; + let v = src[i_src]; +#ifdef INPLACE + src[i_dst] = scale * (v - mean_val); +#else + dst[i_dst] = scale * (v - mean_val); +#endif + col += WG_SIZE; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl similarity index 71% rename from ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl index 040e80dfea2..6c76ed69e45 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl @@ -1,44 +1,21 @@ -#define(VARIANTS) - -[ - { - "SHADER_NAME": "scale_f32", - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "scale_f32_inplace", - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) +#ifdef INPLACE @group(0) @binding(1) -var<storage, read_write> dst: array<f32>; - -@group(0) @binding(2) var<uniform> params: Params; fn store_scale(val: f32, offset: u32) { - dst[offset] = val; + src[offset] = val; } -#enddecl(NOT_INPLACE) - -#decl(INPLACE) +#else @group(0) @binding(1) +var<storage, read_write> dst: array<f32>; + +@group(0) @binding(2) var<uniform> params: Params; fn store_scale(val: f32, offset: u32) { - src[offset] = val; + dst[offset] = val; } -#enddecl(INPLACE) - -#end(DECLS) - -#define(SHADER) +#endif struct Params { offset_src: u32, @@ -65,16 +42,15 @@ struct Params { @group(0) @binding(0) var<storage, read_write> src: array<f32>; -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne) { +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + var i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (i >= params.ne) { return; } - - var i = gid.x; let i3 = i / (params.ne2 * params.ne1 * params.ne0); i = i % (params.ne2 * params.ne1 * params.ne0); let i2 = i / (params.ne1 * params.ne0); @@ -87,4 +63,3 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { store_scale(src[i_src] * params.scale + params.bias, i_dst); } -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl new file mode 100644 index 00000000000..0a7ae9bdb2c --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl @@ -0,0 +1,109 @@ +#ifdef TYPE_I32 +#define TYPE i32 +#else +#define TYPE f32 +#endif + +#ifndef INPLACE +@group(0) @binding(0) +var<storage, read_write> src0: array<TYPE>; +#define SRC1_BINDING 1 +#else +#define SRC1_BINDING 0 +#endif + +#define DST_BINDING SRC1_BINDING + 1 +#define PARAMS_BINDING SRC1_BINDING + 2 + +@group(0) @binding(SRC1_BINDING) +var<storage, read_write> src1: array<TYPE>; + +@group(0) @binding(DST_BINDING) +var<storage, read_write> dst: array<TYPE>; + +struct Params { + ne: u32, + offset_src0: u32, + offset_src1: u32, + offset_view: u32, + + stride_src10: u32, + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst10: u32, + stride_dst11: u32, + stride_dst12: u32, + stride_dst13: u32, + + src1_ne0: u32, + src1_ne1: u32, + src1_ne2: u32, + src1_ne3: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var<uniform> params: Params; + +fn decode_src1_coords(idx: u32) -> vec4<u32> { + var i = idx; + let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0; + let i3 = i / plane; + i = i % plane; + let row = params.src1_ne1 * params.src1_ne0; + let i2 = i / row; + i = i % row; + let i1 = i / params.src1_ne0; + let i0 = i % params.src1_ne0; + return vec4<u32>(i0, i1, i2, i3); +} + +fn decode_view_coords(rel: u32) -> vec4<u32> { + let i3 = rel / params.stride_dst13; + let rem3 = rel % params.stride_dst13; + let i2 = rem3 / params.stride_dst12; + let rem2 = rem3 % params.stride_dst12; + let i1 = rem2 / params.stride_dst11; + let i0 = rem2 % params.stride_dst11; + return vec4<u32>(i0, i1, i2, i3); +} + +fn view_rel_from_coords(coords: vec4<u32>) -> u32 { + return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 + + coords.z * params.stride_dst12 + coords.w * params.stride_dst13; +} + +fn src1_idx_from_coords(coords: vec4<u32>) -> u32 { + return coords.x * params.stride_src10 + coords.y * params.stride_src11 + + coords.z * params.stride_src12 + coords.w * params.stride_src13; +} + +fn in_set_view(rel: u32, coords: vec4<u32>) -> bool { + return view_rel_from_coords(coords) == rel; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + if (gid.x >= params.ne) { + return; + } + +#ifdef INPLACE + let coords = decode_src1_coords(gid.x); + + let src1_idx = params.offset_src1 + src1_idx_from_coords(coords); + let dst_idx = params.offset_view + view_rel_from_coords(coords); + + dst[dst_idx] = src1[src1_idx]; +#else + let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view); + let coords = decode_view_coords(rel); + + if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) { + dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)]; + } else { + dst[gid.x] = src0[params.offset_src0 + gid.x]; + } +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl deleted file mode 100644 index fca3be6bc27..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +++ /dev/null @@ -1,112 +0,0 @@ -#define(VARIANTS) - -[ - { - "SHADER_SUFFIX": "f16_vec", - "REPLS": { - "TYPE" : "vec4<f32>", - "DST_TYPE": "vec4<f16>", - "VEC_SIZE": 4 - } - }, - { - "SHADER_SUFFIX": "f16", - "REPLS": { - "TYPE" : "f32", - "DST_TYPE": "f16", - "VEC_SIZE": 1 - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -@group(0) @binding(0) -var<storage, read_write> src: array<{{TYPE}}>; - -@group(0) @binding(1) -var<storage, read_write> idx: array<u32>; - -@group(0) @binding(2) -var<storage, read_write> dst: array<{{DST_TYPE}}>; - -@group(0) @binding(3) -var<storage, read_write> error: atomic<u32>; - -struct Params { - offset_src: u32, // in elements - offset_idx: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_idx0: u32, - stride_idx1: u32, - stride_idx2: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Shape of src - ne0: u32, - n_rows: u32, - ne2: u32, - ne3: u32, - - // Shape of idx - idx1: u32, - idx2: u32, -}; - -@group(0) @binding(4) -var<uniform> params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / {{VEC_SIZE}}) { - return; - } - - // getting the row from gid - let elems_per_row = params.ne0 / {{VEC_SIZE}}; - var i = gid.x / elems_per_row; - - let i_src3 = i / (params.ne2 * params.n_rows); - - i = i % (params.ne2 * params.n_rows); - let i_src2 = i / params.n_rows; - let i_src1 = i % params.n_rows; - - let i_idx2 = i_src3 % params.idx2; - let i_idx1 = i_src2 % params.idx1; - let i_idx0 = i_src1; - - let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2; - - let idx_high_val = idx[idx_high]; - let idx_low_val = idx[idx_high + 1]; - - if (idx_low_val != 0) { - // Upper bits of index are not zero, output will be incorrect - atomicStore(&error, 1); - return; - } - - let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; - let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - - let col_idx = (gid.x % elems_per_row); - dst[i_dst_row/{{VEC_SIZE}} + col_idx] = {{DST_TYPE}}(src[i_src_row/{{VEC_SIZE}} + col_idx]); -} - -#end(SHADER) - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl index 3567713dc21..09f2f0eddb3 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl @@ -1,16 +1,37 @@ enable f16; +#ifdef DST_F32 +#define DST_INNER_TYPE f32 +#else +#define DST_INNER_TYPE f16 +#endif + +#ifdef VEC4 +#define SRC_TYPE vec4<f32> +#define DST_TYPE vec4<DST_INNER_TYPE> +#define VEC_SIZE 4 +#else +#define SRC_TYPE f32 +#define DST_TYPE DST_INNER_TYPE +#define VEC_SIZE 1 +#endif + @group(0) @binding(0) -var<storage, read_write> src: array<f32>; +var<storage, read_write> src: array<SRC_TYPE>; @group(0) @binding(1) var<storage, read_write> idx: array<u32>; @group(0) @binding(2) -var<storage, read_write> dst: array<f16>; +var<storage, read_write> dst: array<DST_TYPE>; +#ifdef I64_IDX @group(0) @binding(3) var<storage, read_write> error: atomic<u32>; +#define PARAMS_BINDING 4 +#else +#define PARAMS_BINDING 3 +#endif struct Params { offset_src: u32, // in elements @@ -41,16 +62,18 @@ struct Params { idx2: u32, }; -@group(0) @binding(4) +@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params; -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.n_rows * params.ne2 * params.ne3) { + if (gid.x >= (params.ne3 * params.ne2 * params.n_rows * params.ne0) / VEC_SIZE) { return; } - var i = gid.x; + + let elems_per_row = params.ne0 / VEC_SIZE; + var i = gid.x / elems_per_row; + let i_src3 = i / (params.ne2 * params.n_rows); i = i % (params.ne2 * params.n_rows); @@ -61,9 +84,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { let i_idx1 = i_src2 % params.idx1; let i_idx0 = i_src1; +#ifdef I64_IDX let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2; - let idx_high_val = idx[idx_high]; + let idx_val = idx[idx_high]; let idx_low_val = idx[idx_high + 1]; if (idx_low_val != 0) { @@ -71,11 +95,14 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) { atomicStore(&error, 1); return; } +#else + let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + let idx_val = idx[idx_i]; +#endif - let i_dst_row = params.offset_dst + idx_high_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; + let i_dst_row = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; let i_src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; - for (var i: u32 = 0; i < params.ne0; i++) { - dst[i_dst_row + i] = f16(src[i_src_row + i]); - } + let col_idx = gid.x % elems_per_row; + dst[i_dst_row / VEC_SIZE + col_idx] = DST_TYPE(src[i_src_row / VEC_SIZE + col_idx]); } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl new file mode 100644 index 00000000000..876e65b6ae1 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl @@ -0,0 +1,224 @@ +#ifdef DST_Q8_0 +#define BLOCK_SIZE 32u +#define BLOCK_BYTES 34u +#define QS_WORDS 8u +#elif defined(DST_Q4_0) +#define BLOCK_SIZE 32u +#define BLOCK_BYTES 18u +#define QS_WORDS 4u +#endif + +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> idx: array<u32>; + +@group(0) @binding(2) +#ifdef PAIR_BLOCKS +var<storage, read_write> dst: array<u32>; +#else +var<storage, read_write> dst: array<atomic<u32>>; +#endif + +#ifdef I64_IDX +@group(0) @binding(3) +var<storage, read_write> error: atomic<u32>; +#define PARAMS_BINDING 4 +#else +#define PARAMS_BINDING 3 +#endif + +struct Params { + offset_src: u32, // in elements + offset_idx: u32, // in elements + offset_dst: u32, // in blocks + + // Strides (in elements / blocks) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_idx0: u32, + stride_idx1: u32, + stride_idx2: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Shape of src + ne0: u32, + n_rows: u32, + ne2: u32, + ne3: u32, + + // Shape of idx + idx1: u32, + idx2: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var<uniform> params: Params; + +// if the quantization type is unaligned and there are an odd number of blocks per row, we need to store atomically +#ifndef PAIR_BLOCKS +fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) { + loop { + let old = atomicLoad(&dst[word_idx]); + let merged = (old & ~mask) | (bits & mask); + let result = atomicCompareExchangeWeak(&dst[word_idx], old, merged); + if (result.exchanged) { + return; + } + } +} +#else +fn merge_store_dst_word(word_idx: u32, mask: u32, bits: u32) { + let old = dst[word_idx]; + dst[word_idx] = (old & ~mask) | (bits & mask); +} +#endif + +fn store_u16(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) { + let total_byte_offset = block_byte_offset + byte_offset; + let word_idx = dst_word_idx + total_byte_offset / 4u; + let shift = (total_byte_offset & 2u) * 8u; + let mask = 0xFFFFu << shift; + merge_store_dst_word(word_idx, mask, (value & 0xFFFFu) << shift); +} + +fn store_u32(dst_word_idx: u32, block_byte_offset: u32, byte_offset: u32, value: u32) { + let total_byte_offset = block_byte_offset + byte_offset; + let word_idx = dst_word_idx + total_byte_offset / 4u; + let shift = (total_byte_offset & 3u) * 8u; + + if (shift == 0u) { +#ifdef PAIR_BLOCKS + dst[word_idx] = value; +#else + atomicStore(&dst[word_idx], value); +#endif + return; + } + + let lo_mask = 0xFFFFFFFFu << shift; + let hi_mask = (1u << shift) - 1u; + merge_store_dst_word(word_idx, lo_mask, value << shift); + merge_store_dst_word(word_idx + 1u, hi_mask, value >> (32u - shift)); +} + +fn quantize_block_params(src_block: u32) -> vec2<f32> { +#ifdef DST_Q8_0 + var amax = 0.0; + for (var j: u32 = 0u; j < BLOCK_SIZE; j++) { + amax = max(amax, abs(src[src_block + j])); + } + + let d = amax / 127.0; + let id = select(0.0, 1.0 / d, d > 0.0); + return vec2(d, id); +#elif defined(DST_Q4_0) + var amax = 0.0; + var max_val = 0.0; + for (var j: u32 = 0u; j < BLOCK_SIZE; j++) { + let v = src[src_block + j]; + let av = abs(v); + if (amax < av) { + amax = av; + max_val = v; + } + } + + let d = max_val / -8.0; + let id = select(0.0, 1.0 / d, d != 0.0); + return vec2(d, id); +#endif +} + +fn quantize_block_word(src_block: u32, j: u32, id: f32) -> u32 { +#ifdef DST_Q8_0 + let base = src_block + j * 4u; + return (u32(i32(round(src[base + 0u] * id)) & 0xFF) << 0u) | + (u32(i32(round(src[base + 1u] * id)) & 0xFF) << 8u) | + (u32(i32(round(src[base + 2u] * id)) & 0xFF) << 16u) | + (u32(i32(round(src[base + 3u] * id)) & 0xFF) << 24u); +#elif defined(DST_Q4_0) + var packed_q = 0u; + for (var k: u32 = 0u; k < 4u; k++) { + let x0 = src[src_block + j * 4u + k] * id; + let x1 = src[src_block + 16u + j * 4u + k] * id; + let q0 = u32(clamp(i32(x0 + 8.5), 0, 15)); + let q1 = u32(clamp(i32(x1 + 8.5), 0, 15)); + packed_q |= (q0 & 0xFu) << (8u * k); + packed_q |= (q1 & 0xFu) << (8u * k + 4u); + } + return packed_q; +#endif +} + +fn quantize_block(src_block: u32, dst_word_idx: u32, block_byte_offset: u32) { + let params = quantize_block_params(src_block); + let d = params.x; + let id = params.y; + let packed_d = pack2x16float(vec2(d, 0.0)) & 0xFFFFu; + store_u16(dst_word_idx, block_byte_offset, 0u, packed_d); + + for (var j: u32 = 0u; j < QS_WORDS; j++) { + store_u32(dst_word_idx, block_byte_offset, 2u + j * 4u, quantize_block_word(src_block, j, id)); + } +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + let blocks_per_row = params.ne0 / BLOCK_SIZE; +#ifdef PAIR_BLOCKS + let blocks_per_invocation = 2u; +#else + let blocks_per_invocation = 1u; +#endif + let invocations_per_row = blocks_per_row / blocks_per_invocation; + let total_invocations = params.ne3 * params.ne2 * params.n_rows * invocations_per_row; + if (gid.x >= total_invocations) { + return; + } + + var i = gid.x / invocations_per_row; + let block_in_row = (gid.x % invocations_per_row) * blocks_per_invocation; + + let i_src3 = i / (params.ne2 * params.n_rows); + i = i % (params.ne2 * params.n_rows); + let i_src2 = i / params.n_rows; + let i_src1 = i % params.n_rows; + + let i_idx2 = i_src3 % params.idx2; + let i_idx1 = i_src2 % params.idx1; + let i_idx0 = i_src1; + +#ifdef I64_IDX + let idx_high = (params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2) * 2u; + let idx_val = idx[idx_high]; + let idx_low_val = idx[idx_high + 1u]; + + if (idx_low_val != 0u) { + atomicStore(&error, 1u); + return; + } +#else + let idx_i = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + let idx_val = idx[idx_i]; +#endif + + let dst_row_blocks = params.offset_dst + idx_val * params.stride_dst1 + i_src2 * params.stride_dst2 + i_src3 * params.stride_dst3; + let src_row = params.offset_src + i_src1 * params.stride_src1 + i_src2 * params.stride_src2 + i_src3 * params.stride_src3; + let src_block = src_row + block_in_row * BLOCK_SIZE; + let dst_block_byte = (dst_row_blocks + block_in_row) * BLOCK_BYTES; + + let dst_word_idx = dst_block_byte / 4u; +#ifdef PAIR_BLOCKS + quantize_block(src_block, dst_word_idx, 0u); + quantize_block(src_block + BLOCK_SIZE, dst_word_idx, BLOCK_BYTES); +#else + quantize_block(src_block, dst_word_idx, dst_block_byte & 3u); +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl similarity index 59% rename from ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl index c74dc4cc923..10edf136048 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.wgsl @@ -1,262 +1,162 @@ -#define(VARIANTS) -[ - { - "SHADER_NAME": "soft_max_f32", - "DECLS": ["BASE_BINDINGS", "NOT_INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_inplace", - "DECLS": ["BASE_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink", - "DECLS": ["SINK_BINDINGS", "NOT_INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_sink_inplace", - "DECLS": ["SINK_BINDINGS_INPLACE", "INPLACE", "NO_MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS", "NOT_INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_BINDINGS_INPLACE", "INPLACE", "MASK", "NO_SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f32_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f32", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS", "NOT_INPLACE", "MASK", "SINK"] - }, - { - "SHADER_NAME": "soft_max_f32_mask_f16_sink_inplace", - "REPLS": { - "MASK_TYPE" : "f16", - }, - "DECLS": ["MASK_SINK_BINDINGS_INPLACE", "INPLACE", "MASK", "SINK"] - } -] -#end(VARIANTS) - -#define(DECLS) - -#decl(BASE_BINDINGS) -@group(0) @binding(1) -var<storage, read_write> dst: array<f32>; +enable f16; -@group(0) @binding(2) -var<uniform> params: Params; -#enddecl(BASE_BINDINGS) +#ifdef MASK_F32 +#define MaskType f32 +#endif +#ifdef MASK_F16 +#define MaskType f16 +#endif -#decl(BASE_BINDINGS_INPLACE) -@group(0) @binding(1) -var<uniform> params: Params; -#enddecl(BASE_BINDINGS_INPLACE) +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_sinks: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of src0/dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, -#decl(SINK_BINDINGS) + // shape of src1 + ne12: u32, + ne13: u32, + + scale: f32, + max_bias: f32, + n_head_log2: f32, + m0: f32, + m1: f32, +}; + +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +#ifdef HAS_MASK +#ifdef HAS_SINK @group(0) @binding(1) +var<storage, read_write> mask: array<MaskType>; +@group(0) @binding(2) var<storage, read_write> sinks: array<f32>; -@group(0) @binding(2) -var<storage, read_write> dst: array<f32>; +#ifdef INPLACE +@group(0) @binding(3) +var<uniform> params: Params; +#else @group(0) @binding(3) +var<storage, read_write> dst: array<f32>; +@group(0) @binding(4) var<uniform> params: Params; -#enddecl(SINK_BINDINGS) +#endif -#decl(SINK_BINDINGS_INPLACE) +#else @group(0) @binding(1) -var<storage, read_write> sinks: array<f32>; +var<storage, read_write> mask: array<MaskType>; +#ifdef INPLACE @group(0) @binding(2) var<uniform> params: Params; -#enddecl(SINK_BINDINGS_INPLACE) - -#decl(MASK_BINDINGS) -@group(0) @binding(1) -var<storage, read_write> mask: array<{{MASK_TYPE}}>; +#else @group(0) @binding(2) var<storage, read_write> dst: array<f32>; - @group(0) @binding(3) var<uniform> params: Params; -#enddecl(MASK_BINDINGS) +#endif +#endif -#decl(MASK_BINDINGS_INPLACE) +#else +#ifdef HAS_SINK @group(0) @binding(1) -var<storage, read_write> mask: array<{{MASK_TYPE}}>; +var<storage, read_write> sinks: array<f32>; +#ifdef INPLACE @group(0) @binding(2) var<uniform> params: Params; -#enddecl(MASK_BINDINGS_INPLACE) - -#decl(MASK_SINK_BINDINGS) -@group(0) @binding(1) -var<storage, read_write> mask: array<{{MASK_TYPE}}>; +#else @group(0) @binding(2) -var<storage, read_write> sinks: array<f32>; - -@group(0) @binding(3) var<storage, read_write> dst: array<f32>; - -@group(0) @binding(4) +@group(0) @binding(3) var<uniform> params: Params; -#enddecl(MASK_SINK_BINDINGS) +#endif -#decl(MASK_SINK_BINDINGS_INPLACE) +#else +#ifdef INPLACE @group(0) @binding(1) -var<storage, read_write> mask: array<{{MASK_TYPE}}>; - +var<uniform> params: Params; +#else +@group(0) @binding(1) +var<storage, read_write> dst: array<f32>; @group(0) @binding(2) -var<storage, read_write> sinks: array<f32>; - -@group(0) @binding(3) var<uniform> params: Params; -#enddecl(MASK_SINK_BINDINGS_INPLACE) +#endif +#endif +#endif -#decl(NOT_INPLACE) +#ifdef INPLACE fn inter_value(i: u32) -> f32 { - return dst[i]; + return src[i]; } - fn update(i: u32, val: f32) { - dst[i] = val; + src[i] = val; } -#enddecl(NOT_INPLACE) -#decl(INPLACE) +#else fn inter_value(i: u32) -> f32 { - return src[i]; + return dst[i]; } - fn update(i: u32, val: f32) { - src[i] = val; + dst[i] = val; } -#enddecl(INPLACE) +#endif -#decl(NO_MASK) +#ifdef HAS_MASK fn mask_val(i: u32) -> f32 { - return 0.0; + return f32(mask[i]); } -#enddecl(NO_MASK) -#decl(MASK) +#else fn mask_val(i: u32) -> f32 { - return f32(mask[i]); + return 0.0; } -#enddecl(MASK) +#endif -#decl(NO_SINK) +#ifdef HAS_SINK fn lower_max_bound(i2: u32) -> f32 { - return -1e30; + return sinks[params.offset_sinks + i2]; } - fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val; + return val + exp(sinks[params.offset_sinks + i2] - max_val); } -#enddecl(NO_SINK) - -#decl(SINK) +#else fn lower_max_bound(i2: u32) -> f32 { - return sinks[params.offset_sinks + i2]; + return -1e30; } - fn add_sinks(val: f32, i2: u32, max_val: f32) -> f32 { - return val + exp(sinks[params.offset_sinks + i2] - max_val); + return val; } -#enddecl(SINK) - -#end(DECLS) - -#define(SHADER) -enable f16; - -struct Params { - offset_src0: u32, - offset_src1: u32, - offset_sinks: u32, - offset_dst: u32, - - // Strides (in elements) - stride_src01: u32, - stride_src02: u32, - stride_src03: u32, - - stride_src11: u32, - stride_src12: u32, - stride_src13: u32, - - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // shape of src0/dst - ne: u32, - ne0: u32, - ne1: u32, - ne2: u32, - - // shape of src1 - ne12: u32, - ne13: u32, - - scale: f32, - max_bias: f32, - n_head_log2: f32, - m0: f32, - m1: f32, -}; - -@group(0) @binding(0) -var<storage, read_write> src: array<f32>; - -DECLS +#endif const CACHE_SIZE: u32 = 16; +var<workgroup> scratch: array<f32, WG_SIZE>; -override wg_size: u32; -var<workgroup> scratch: array<f32, wg_size>; - -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(workgroup_id) wid: vec3<u32>, @builtin(local_invocation_id) lid: vec3<u32>) { @@ -268,7 +168,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>, let i_src0_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; let i_src1_row = params.offset_src1 + (i3 % params.ne13) * params.stride_src13 + (i2 % params.ne12) * params.stride_src12 + i1 * params.stride_src11; let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; - let elems = (params.ne0 + wg_size - 1) / wg_size; + let elems = (params.ne0 + WG_SIZE - 1) / WG_SIZE; let head = f32(i2); let slope = select(1, select(pow(params.m1, 2 * (head - params.n_head_log2) + 1), pow(params.m0, head + 1), head < params.n_head_log2), params.max_bias > 0); @@ -286,12 +186,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>, if (col < CACHE_SIZE) { cache[col] = val; } - col += wg_size; + col += WG_SIZE; } scratch[lid.x] = max_val; workgroupBarrier(); - var offset = wg_size / 2; + var offset: u32 = WG_SIZE / 2; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] = max(scratch[lid.x], scratch[lid.x + offset]); @@ -317,12 +217,12 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>, } else { update(i_dst_row + col, ex); } - col += wg_size; + col += WG_SIZE; } scratch[lid.x] = sum; workgroupBarrier(); - offset = wg_size / 2; + offset = WG_SIZE / 2; while (offset > 0) { if (lid.x < offset) { scratch[lid.x] += scratch[lid.x + offset]; @@ -339,7 +239,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>, break; } update(i_dst_row + col, select(inter_value(i_dst_row + col), cache[col], col < CACHE_SIZE) * sum_recip); - col += wg_size; + col += WG_SIZE; } } -#end(SHADER) + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl new file mode 100644 index 00000000000..9d5d902cb1e --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl @@ -0,0 +1,121 @@ +@group(0) @binding(0) +var<storage, read_write> src0: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> src1: array<f32>; + +@group(0) @binding(2) +var<storage, read_write> dst: array<f32>; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src00: u32, + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src10: u32, + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + k: u32, + ne2: u32, + ne3: u32, +}; + +@group(0) @binding(3) +var<uniform> params: Params; + +var<workgroup> shA: array<f32, BATCH_N * N>; +var<workgroup> shB: array<f32, BATCH_N * K_TILE>; + +fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_src0 + + col * params.stride_src00 + + row * params.stride_src01 + + i2 * params.stride_src02 + + i3 * params.stride_src03; +} + +fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_src1 + + col * params.stride_src10 + + row * params.stride_src11 + + i2 * params.stride_src12 + + i3 * params.stride_src13; +} + +fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_dst + + col * params.stride_dst0 + + row * params.stride_dst1 + + i2 * params.stride_dst2 + + i3 * params.stride_dst3; +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(workgroup_id) workgroup_id: vec3<u32>, + @builtin(local_invocation_id) local_id: vec3<u32> +) { + let batch = workgroup_id.y; + let col = workgroup_id.x * WG_SIZE + local_id.x; + let i3 = batch / params.ne2; + let i2 = batch % params.ne2; + let active_lane = local_id.x < K_TILE; + let active_col = active_lane && col < params.k; + + var X: array<f32, N>; + + for (var row_base = 0u; row_base < N; row_base += BATCH_N) { + let cur_n = min(BATCH_N, N - row_base); + + for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) { + let tile_row = i / N; + let tile_col = i % N; + shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)]; + } + + for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) { + let tile_row = i / K_TILE; + let tile_col = i % K_TILE; + let global_col = workgroup_id.x * WG_SIZE + tile_col; + let sh_idx = tile_row * K_TILE + tile_col; + + if (global_col < params.k) { + shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)]; + } else { + shB[sh_idx] = 0.0; + } + } + + workgroupBarrier(); + + if (active_col) { + for (var row_offset = 0u; row_offset < cur_n; row_offset++) { + let r = row_base + row_offset; + var b = shB[row_offset * K_TILE + local_id.x]; + let a_row = row_offset * N; + + for (var t = 0u; t < r; t++) { + b -= shA[a_row + t] * X[t]; + } + + let x = b / shA[a_row + r]; + X[r] = x; + dst[dst_idx(r, col, i2, i3)] = x; + } + } + + workgroupBarrier(); + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl new file mode 100644 index 00000000000..11511305ed8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl @@ -0,0 +1,65 @@ +@group(0) @binding(0) +var<storage, read_write> src0: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> src1: array<f32>; + +@group(0) @binding(2) +var<storage, read_write> dst: array<f32>; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src01: u32, + stride_src02: u32, + stride_src11: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + + nc: u32, + nr: u32, + n_t: u32, + n_s: u32, + token_tiles: u32, +}; + +@group(0) @binding(3) +var<uniform> params: Params; + +@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG) +fn main(@builtin(global_invocation_id) gid: vec3<u32>) { + let i1 = gid.x; + let tile_y = gid.y / TOKENS_PER_WG; + let local_token = gid.y % TOKENS_PER_WG; + let i3 = tile_y / params.token_tiles; + let token_tile = tile_y % params.token_tiles; + let i2 = token_tile * TOKENS_PER_WG + local_token; + + if (i1 >= params.nr || i2 >= params.n_t || i3 >= params.n_s) { + return; + } + + let src0_base = params.offset_src0 + i3 * params.stride_src02 + i2 + i1 * params.stride_src01; + let src1_base = params.offset_src1 + i1 * params.stride_src11; + + var sum = 0.0; + +#ifdef VECTORIZED + sum = + src0[src0_base + 0u] * src1[src1_base + 0u] + + src0[src0_base + 1u] * src1[src1_base + 1u] + + src0[src0_base + 2u] * src1[src1_base + 2u] + + src0[src0_base + 3u] * src1[src1_base + 3u]; +#else + for (var i0 = 0u; i0 < params.nc; i0++) { + sum += src0[src0_base + i0] * src1[src1_base + i0]; + } +#endif + + let dst_idx = params.offset_dst + i3 * params.stride_dst2 + i2 * params.stride_dst1 + i1 * params.stride_dst0; + dst[dst_idx] = sum; +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl new file mode 100644 index 00000000000..05761dec353 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl @@ -0,0 +1,193 @@ +#ifdef USE_SUBGROUP_REDUCTION +enable subgroups; +#endif + +struct Params { + offset_s: u32, + offset_x: u32, + offset_dt: u32, + offset_A: u32, + offset_B: u32, + offset_C: u32, + offset_ids: u32, + offset_dst: u32, + + stride_s1: u32, + stride_s2: u32, + stride_s3: u32, + + stride_x1: u32, + stride_x2: u32, + stride_x3: u32, + + stride_dt1: u32, + stride_dt2: u32, + + a_ne0: u32, + stride_A1: u32, + + stride_B1: u32, + stride_B2: u32, + stride_B3: u32, + + stride_C1: u32, + stride_C2: u32, + stride_C3: u32, + + d_state: u32, + d_inner: u32, + n_head: u32, + n_group: u32, + n_seq_tokens: u32, + n_seqs: u32, + + y_elems: u32, +}; + +@group(0) @binding(0) var<storage, read_write> s_in: array<f32>; +#ifdef XBC_OVERLAP +@group(0) @binding(1) var<storage, read_write> x_B_C_merged: array<f32>; +@group(0) @binding(2) var<storage, read_write> dt: array<f32>; +@group(0) @binding(3) var<storage, read_write> A: array<f32>; +@group(0) @binding(4) var<storage, read_write> ids: array<i32>; +@group(0) @binding(5) var<storage, read_write> dst: array<f32>; +@group(0) @binding(6) var<uniform> params: Params; +#else +@group(0) @binding(1) var<storage, read_write> x: array<f32>; +@group(0) @binding(2) var<storage, read_write> dt: array<f32>; +@group(0) @binding(3) var<storage, read_write> A: array<f32>; +@group(0) @binding(4) var<storage, read_write> B: array<f32>; +@group(0) @binding(5) var<storage, read_write> C: array<f32>; +@group(0) @binding(6) var<storage, read_write> ids: array<i32>; +@group(0) @binding(7) var<storage, read_write> dst: array<f32>; +@group(0) @binding(8) var<uniform> params: Params; +#endif + +var<workgroup> shared_x_dt: array<f32, TOKENS_PER_TILE>; +var<workgroup> shared_dtsp: array<f32, TOKENS_PER_TILE>; +var<workgroup> shared_reduce: array<f32, TOKENS_PER_TILE * WG_SIZE>; + +fn reduce_base(token_in_tile: u32) -> u32 { + return token_in_tile * WG_SIZE; +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(local_invocation_id) local_id: vec3<u32>, + @builtin(workgroup_id) wg_id: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32> +#ifdef USE_SUBGROUP_REDUCTION + , @builtin(subgroup_id) subgroup_id: u32, + @builtin(subgroup_invocation_id) subgroup_invocation_id: u32, + @builtin(num_subgroups) num_subgroups: u32 +#endif +) { + let tid = local_id.x; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + + let i1 = wg_linear % params.d_inner; + let head_seq = wg_linear / params.d_inner; + let ir = head_seq % params.n_head; + let i3 = head_seq / params.n_head; + + let state_slot = u32(ids[params.offset_ids + i3]); + let g = ir / (params.n_head / params.n_group); + + let s_idx = params.offset_s + tid + i1 * params.stride_s1 + ir * params.stride_s2 + state_slot * params.stride_s3; + var s_prev = s_in[s_idx]; + + let A0 = A[params.offset_A + (tid % params.a_ne0) + ir * params.stride_A1]; + + for (var token_base = 0u; token_base < params.n_seq_tokens; token_base += TOKENS_PER_TILE) { + if (tid < TOKENS_PER_TILE) { + let token = token_base + tid; + if (token < params.n_seq_tokens) { + let x_idx = params.offset_x + i1 + ir * params.stride_x1 + token * params.stride_x2 + i3 * params.stride_x3; + let dt_idx = params.offset_dt + ir + token * params.stride_dt1 + i3 * params.stride_dt2; + let dt0 = dt[dt_idx]; + let dtsp = select(log(1.0 + exp(dt0)), dt0, dt0 > 20.0); + shared_dtsp[tid] = dtsp; +#ifdef XBC_OVERLAP + shared_x_dt[tid] = x_B_C_merged[x_idx] * dtsp; +#else + shared_x_dt[tid] = x[x_idx] * dtsp; +#endif + } + } + + workgroupBarrier(); + + for (var token_in_tile = 0u; token_in_tile < TOKENS_PER_TILE; token_in_tile++) { + let token = token_base + token_in_tile; + if (token >= params.n_seq_tokens) { + break; + } + + let x_dt = shared_x_dt[token_in_tile]; + let dA = exp(shared_dtsp[token_in_tile] * A0); + let reduce_idx = reduce_base(token_in_tile) + tid; + + let b_idx = params.offset_B + tid + g * params.stride_B1 + token * params.stride_B2 + i3 * params.stride_B3; + let c_idx = params.offset_C + tid + g * params.stride_C1 + token * params.stride_C2 + i3 * params.stride_C3; +#ifdef XBC_OVERLAP + let s = s_prev * dA + x_B_C_merged[b_idx] * x_dt; +#else + let s = s_prev * dA + B[b_idx] * x_dt; +#endif + s_prev = s; + +#ifdef USE_SUBGROUP_REDUCTION +#ifdef XBC_OVERLAP + let subgroup_partial = subgroupAdd(s * x_B_C_merged[c_idx]); +#else + let subgroup_partial = subgroupAdd(s * C[c_idx]); +#endif + if (subgroup_invocation_id == 0u) { + shared_reduce[reduce_idx - tid + subgroup_id] = subgroup_partial; + } +#else +#ifdef XBC_OVERLAP + shared_reduce[reduce_idx] = s * x_B_C_merged[c_idx]; +#else + shared_reduce[reduce_idx] = s * C[c_idx]; +#endif +#endif + + workgroupBarrier(); + +#ifdef USE_SUBGROUP_REDUCTION + if (tid == 0u) { + var sum = 0.0; + for (var sg = 0u; sg < num_subgroups; sg++) { + sum += shared_reduce[reduce_base(token_in_tile) + sg]; + } + let y_idx = + params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) + + i3 * (params.n_seq_tokens * params.n_head * params.d_inner); + dst[y_idx] = sum; + } +#else + for (var stride = WG_SIZE / 2u; stride > 0u; stride >>= 1u) { + if (tid < stride) { + shared_reduce[reduce_idx] += shared_reduce[reduce_idx + stride]; + } + workgroupBarrier(); + } + + if (tid == 0u) { + let y_idx = + params.offset_dst + i1 + ir * params.d_inner + token * (params.n_head * params.d_inner) + + i3 * (params.n_seq_tokens * params.n_head * params.d_inner); + dst[y_idx] = shared_reduce[reduce_base(token_in_tile)]; + } +#endif + + workgroupBarrier(); + } + } + + let state_idx = + params.offset_dst + params.y_elems + tid + i1 * params.d_state + ir * (params.d_state * params.d_inner) + + i3 * (params.d_state * params.d_inner * params.n_head); + dst[state_idx] = s_prev; +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl new file mode 100644 index 00000000000..6ea2de9b7c6 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl @@ -0,0 +1,55 @@ +@group(0) @binding(0) +var<storage, read_write> src: array<f32>; + +@group(0) @binding(1) +var<storage, read_write> dst: array<f32>; + +struct Params { + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + ne0: u32, + ne1: u32, + ne2: u32 +}; + +@group(0) @binding(2) +var<uniform> params: Params; + +var<workgroup> shared_sum: array<f32, WG_SIZE>; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(workgroup_id) wid: vec3<u32>, + @builtin(local_invocation_id) lid: vec3<u32>) { + + var i = wid.x; + let i3 = i / (params.ne2 * params.ne1); + i = i % (params.ne2 * params.ne1); + let i2 = i / params.ne1; + let i1 = i % params.ne1; + let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1; + var local_sum: f32 = 0.0; + for (var col = lid.x; col < params.ne0; col += WG_SIZE) { + local_sum += src[i_src_row + col]; + } + shared_sum[lid.x] = local_sum; + workgroupBarrier(); + // reduce within workgroup + var offset: u32 = WG_SIZE >> 1; + while (offset > 0) { + if (lid.x < offset) { + shared_sum[lid.x] = shared_sum[lid.x] + shared_sum[lid.x + offset]; + } + workgroupBarrier(); + offset >>= 1; + } + + if (lid.x == 0) { + dst[params.offset_dst + wid.x] = shared_sum[0]; + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl new file mode 100644 index 00000000000..cb342c47263 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl @@ -0,0 +1,213 @@ +#ifdef TYPE_F16 +enable f16; +#define TYPE f16 +#else +#define TYPE f32 +#endif + +@group(0) @binding(0) +var<storage, read_write> src: array<TYPE>; + +#ifndef INPLACE +@group(0) @binding(1) +var<storage, read_write> dst: array<TYPE>; +#define PARAMS_BINDING 2 +#else +#define PARAMS_BINDING 1 +#endif + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + // Logical shapes + ne0: u32, + ne1: u32, + ne2: u32, +#ifdef CLAMP + clamp_min: f32, + clamp_max: f32, +#endif +#ifdef FILL + fill_val: f32, +#endif +#ifdef XIELU + alpha_n: f32, + alpha_p: f32, + beta: f32, + eps: f32, +#endif + +}; + +@group(0) @binding(PARAMS_BINDING) +var<uniform> params: Params; + +fn erf_approx(x: TYPE) -> TYPE { + let x_f32 = f32(x); + let s = select(-1.0, 1.0, x_f32 >= 0.0); + let ax = abs(x_f32); + + let t = 1.0 / (1.0 + 0.3275911 * ax); + + let y = 1.0 - + (((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t + - 0.284496736) * t + 0.254829592) * t) * + exp(-ax * ax); + + return TYPE(s * y); +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32>) { + let threads_per_group = u32(WG_SIZE); + let flat_i = gid.x + (num_wg.x * threads_per_group) * gid.y; + if (flat_i >= params.ne) { + return; + } + var i = flat_i; + let ne2 = params.ne2; +#ifdef DIAG + let ne1 = params.ne0; +#else + let ne1 = params.ne1; +#endif + let ne0 = params.ne0; + + let i3 = i / (ne2 * ne1 * ne0); + i = i % (ne2 * ne1 * ne0); + let i2 = i / (ne1 * ne0); + i = i % (ne1 * ne0); + let i1 = i / ne0; + let i0 = i % ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + i2 * params.stride_src2 + i3 * params.stride_src3; + +#ifdef ABS + let res = abs(src[params.offset_src + src_idx]); +#endif +#ifdef SGN + let res = select(TYPE(select(0.0, -1.0, src[params.offset_src + src_idx] < 0.0)), TYPE(1.0), src[params.offset_src + src_idx] > 0.0); +#endif +#ifdef NEG + let res = -src[params.offset_src + src_idx]; +#endif +#ifdef STEP + let res = TYPE(select(0.0, 1.0, src[params.offset_src + src_idx] > 0.0)); +#endif +#ifdef TANH + let res = tanh(clamp(src[params.offset_src + src_idx], -9.010913, 9.010913)); +#endif +#ifdef RELU + let res = select(0.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0); +#endif +#ifdef ELU + let res = select(exp(src[params.offset_src + src_idx]) - 1.0, src[params.offset_src + src_idx], src[params.offset_src + src_idx] > 0.0); +#endif +#ifdef HARDSIGMOID + let res = min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); +#endif +#ifdef SIGMOID + let res = 1.0 / (1.0 + exp(-src[params.offset_src + src_idx])); +#endif +#ifdef SILU + let res = src[params.offset_src + src_idx] / (1.0 + exp(-src[params.offset_src + src_idx])); +#endif +#ifdef EXP + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(exp(src_f32)); +#endif +#ifdef LOG + let res = TYPE(log(f32(src[params.offset_src + src_idx]))); +#endif +#ifdef CLAMP + let res = clamp(src[params.offset_src + src_idx], TYPE(params.clamp_min), TYPE(params.clamp_max)); +#endif +#ifdef FILL + let res = TYPE(params.fill_val); +#endif +#ifdef HARDSWISH + let res = src[params.offset_src + src_idx] * min(1.0, max(0.0, (src[params.offset_src + src_idx] + 3.0) / 6.0)); +#endif +#ifdef GELU + let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + tanh(clamp(0.7978845608028654 * (src[params.offset_src + src_idx] + 0.044715 * src[params.offset_src + src_idx] * src[params.offset_src + src_idx] * src[params.offset_src + src_idx]), -9.010913, 9.010913))); +#endif +#ifdef GELU_QUICK + let res = src[params.offset_src + src_idx] * (1.0 / (1.0 + exp(clamp(-1.702 * src[params.offset_src + src_idx], -80.0, 80.0)))); +#endif +#ifdef GELU_ERF + let res = 0.5 * src[params.offset_src + src_idx] * (1.0 + erf_approx(src[params.offset_src + src_idx] * 0.7071067811865476)); +#endif +#ifdef XIELU + let val = f32(src[params.offset_src + src_idx]); + let res = + TYPE(select( + ((exp(min(val, params.eps)) - 1.0) - val) * params.alpha_n + params.beta * val, + params.alpha_p * val * val + params.beta * val, + val > 0.0)); +#endif +#ifdef SOFTPLUS + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(select(log(1.0 + exp(src_f32)), src_f32, src_f32 > 20.0)); +#endif +#ifdef EXPM1 + let src_f32 = f32(src[params.offset_src + src_idx]); + let res = TYPE(exp(src_f32) - 1.0); +#endif +#ifdef FLOOR + let res = floor(src[params.offset_src + src_idx]); +#endif +#ifdef CEIL + let res = ceil(src[params.offset_src + src_idx]); +#endif +#ifdef ROUND + let src_f32 = f32(src[params.offset_src + src_idx]); + let result = select(ceil(src_f32 - 0.5), floor(src_f32 + 0.5), src_f32 >= 0.0); + let res = TYPE(result); +#endif +#ifdef TRUNC + let res = trunc(src[params.offset_src + src_idx]); +#endif +#ifdef SQR + let res = src[params.offset_src + src_idx] * src[params.offset_src + src_idx]; +#endif +#ifdef SQRT + let res = TYPE(sqrt(f32(src[params.offset_src + src_idx]))); +#endif +#ifdef SIN + let res_f32 = sin(f32(src[params.offset_src + src_idx])); + let res = TYPE(res_f32); +#endif +#ifdef COS + let res_f32 = cos(f32(src[params.offset_src + src_idx])); + let res = TYPE(res_f32); +#endif +#ifdef DIAG + let res = select(0.0, src[params.offset_src + i0 + i2 * params.stride_src2 + i3 * params.stride_src3], i0 == i1); +#endif +#ifdef TRI +#ifdef TRI_TYPE_LOWER + let res = select(0.0, src[params.offset_src + src_idx], i0 < i1); +#elif TRI_TYPE_LOWER_DIAG + let res = select(0.0, src[params.offset_src + src_idx], i0 <= i1); +#elif TRI_TYPE_UPPER + let res = select(0.0, src[params.offset_src + src_idx], i0 > i1); +#elif TRI_TYPE_UPPER_DIAG + let res = select(0.0, src[params.offset_src + src_idx], i0 >= i1); +#endif +#endif + +#ifdef INPLACE + src[params.offset_src + src_idx] = res; +#else + dst[params.offset_dst + flat_i] = res; +#endif +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl deleted file mode 100644 index 25fe2854518..00000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +++ /dev/null @@ -1,483 +0,0 @@ -#define(REPL_TEMPLATES) - -{ - "XIELU_FUNC": "{{MUTATE}}[dst_i] = select(((exp(min(src[src_i], {{TYPE}}(params.eps))) - 1.0) - src[src_i]) * {{TYPE}}(params.alpha_n) + {{TYPE}}(params.beta) * src[src_i], {{TYPE}}(params.alpha_p) * src[src_i] * src[src_i] + {{TYPE}}(params.beta) * src[src_i], src[src_i] > 0.0);", - "ABS_FUNC": "{{MUTATE}}[dst_i] = abs(src[src_i]);", - "SGN_FUNC": "{{MUTATE}}[dst_i] = select({{TYPE}}(select(0.0, -1.0, src[src_i] < 0.0)), {{TYPE}}(1.0), src[src_i] > 0.0);", - "NEG_FUNC": "{{MUTATE}}[dst_i] = -src[src_i];", - "STEP_FUNC": "{{MUTATE}}[dst_i] = {{TYPE}}(select(0.0, 1.0, src[src_i] > 0.0));", - "TANH_FUNC": "{{MUTATE}}[dst_i] = tanh(clamp(src[src_i], -9.010913, 9.010913)); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", - "RELU_FUNC": "{{MUTATE}}[dst_i] = select(0.0, src[src_i], src[src_i] > 0.0);", - "ELU_FUNC": "{{MUTATE}}[dst_i] = select(exp(src[src_i]) - 1.0, src[src_i], src[src_i] > 0.0);", - "HARDSIGMOID_FUNC": "{{MUTATE}}[dst_i] = min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", - "SIGMOID_FUNC": "{{MUTATE}}[dst_i] = 1.0 / (1.0 + exp(-src[src_i]));", - "SILU_FUNC": "{{MUTATE}}[dst_i] = src[src_i] / (1.0 + exp(-src[src_i]));", - "EXP_FUNC": "{{MUTATE}}[dst_i] = exp(src[src_i]);", - "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", - "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", - "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", - "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", - "CEIL_FUNC": "{{MUTATE}}[dst_i] = ceil(src[src_i]);" -} - -#end(REPL_TEMPLATES) - -#define(VARIANTS) - -[ - { - "SHADER_NAME": "abs_f32", - "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "abs_f16", - "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "abs_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "abs_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "ABS_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "sgn_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sgn_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sgn_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sgn_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SGN_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "neg_f32", - "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "neg_f16", - "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "neg_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "neg_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "NEG_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "step_f32", - "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "step_f16", - "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "step_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "step_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "STEP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "tanh_f32", - "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "tanh_f16", - "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "tanh_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "tanh_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "TANH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "elu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "elu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "elu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "elu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "ELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "relu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "relu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "relu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "relu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "RELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "sigmoid_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "sigmoid_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "silu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "silu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "silu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "silu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "SILU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "exp_f32", - "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "exp_f16", - "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "exp_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "exp_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "EXP_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "hardsigmoid_f32", - "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_f16", - "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardsigmoid_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "HARDSIGMOID_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "hardswish_f32", - "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardswish_f16", - "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "hardswish_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "hardswish_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "HARDSWISH_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "gelu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "gelu_quick_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_quick_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_QUICK_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "xielu_f32", - "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "xielu_f16", - "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "xielu_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "xielu_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "XIELU_FUNC", "EXT_PARAMS": "alpha_n: f32, alpha_p: f32, beta: f32, eps: f32", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "gelu_erf_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - - { - "SHADER_NAME": "ceil_f32", - "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "ceil_f16", - "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "ceil_inplace_f32", - "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - }, - { - "SHADER_NAME": "ceil_inplace_f16", - "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(INPLACE) - -@group(0) @binding(1) -var<uniform> params: Params; - -#enddecl(INPLACE) - -#decl(NOT_INPLACE) - -@group(0) @binding(1) -var<storage, read_write> dst: array<{{TYPE}}>; - -@group(0) @binding(2) -var<uniform> params: Params; - -#enddecl(NOT_INPLACE) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -fn update(dst_i: u32, src_i: u32) { - {{FUNC}} -} - -@group(0) @binding(0) -var<storage, read_write> src: array<{{TYPE}}>; - -DECLS - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shapes - src_ne0: u32, - src_ne1: u32, - src_ne2: u32, - - dst_ne0: u32, - dst_ne1: u32, - dst_ne2: u32, - - {{EXT_PARAMS}} -}; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3<u32>) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); - i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); - let i2 = i / (params.src_ne1 * params.src_ne0); - i = i % (params.src_ne1 * params.src_ne0); - let i1 = i / params.src_ne0; - let i0 = i % params.src_ne0; - - var j = gid.x; - let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); - let j2 = j / (params.dst_ne1 * params.dst_ne0); - j = j % (params.dst_ne1 * params.dst_ne0); - let j1 = j / params.dst_ne0; - let j0 = j % params.dst_ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + - j2 * params.stride_dst2 + j3 * params.stride_dst3; - - - update(params.offset_dst + dst_idx, params.offset_src + src_idx); -} - -#end(SHADER) - diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl new file mode 100644 index 00000000000..e9ef8822644 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl @@ -0,0 +1,240 @@ +#if defined(SRC_F16) || defined(DST_F16) +enable f16; +#endif + +#ifdef SRC_F16 +#define SRC_TYPE f16 +#else +#define SRC_TYPE f32 +#endif + +#ifdef DST_F16 +#define DST_TYPE f16 +#else +#define DST_TYPE f32 +#endif + +@group(0) @binding(0) +var<storage, read_write> input: array<SRC_TYPE>; + +@group(0) @binding(1) +var<storage, read_write> output: array<DST_TYPE>; + +struct Params { + offset_i: u32, + offset_o: u32, + + // element strides + si0: u32, si1: u32, si2: u32, si3: u32, + so0: u32, so1: u32, so2: u32, so3: u32, + + src_w: u32, + src_h: u32, + src_z: u32, + src_n: u32, + + dst_w: u32, + dst_h: u32, + dst_z: u32, + dst_n: u32, + + mode_flags: u32, +}; + +@group(0) @binding(2) +var<uniform> params: Params; + +const GGML_SCALE_FLAG_ALIGN_CORNERS: u32 = 1u << 8u; + +fn get_clamped_input(x: i32, y: i32, z: u32, n: u32) -> f32 { + let cx = u32(clamp(x, 0, i32(params.src_w) - 1)); + let cy = u32(clamp(y, 0, i32(params.src_h) - 1)); + let i = params.offset_i + cx * params.si0 + cy * params.si1 + z * params.si2 + n * params.si3; + return f32(input[i]); +} + +fn cubic_weight(t: f32, a: f32) -> f32 { + let at = abs(t); + if (at <= 1.0) { + return (a + 2.0) * at * at * at - (a + 3.0) * at * at + 1.0; + } else if (at <= 2.0) { + return a * at * at * at - 5.0 * a * at * at + 8.0 * a * at - 4.0 * a; + } else { + return 0.0; + } +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(global_invocation_id) gid: vec3<u32>, + @builtin(num_workgroups) num_wg: vec3<u32> +) { + + let i_out = gid.x + (num_wg.x * u32(WG_SIZE)) * gid.y; + let total = params.dst_w * params.dst_h * params.dst_z * params.dst_n; + + if (i_out >= total) { + return; + } + + // decode (x, y, z, n) + var i = i_out; + let x_dst = i % params.dst_w; + i = i / params.dst_w; + let y_dst = i % params.dst_h; + i = i / params.dst_h; + let z_dst = i % params.dst_z; + let n_dst = i / params.dst_z; + + // scale factors + var sf0 = f32(params.dst_w) / f32(params.src_w); + var sf1 = f32(params.dst_h) / f32(params.src_h); + var sf2 = f32(params.dst_z) / f32(params.src_z); + var sf3 = f32(params.dst_n) / f32(params.src_n); + + let align_corners = (params.mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) != 0; + + // pixel_offset: 0.5 for half-pixel-center (default), 0.0 for align_corners + var pixel_offset = 0.5; + if (align_corners) { + pixel_offset = 0.0; + if (params.dst_w > 1 && params.src_w > 1) { + sf0 = f32(params.dst_w - 1) / f32(params.src_w - 1); + } + if (params.dst_h > 1 && params.src_h > 1) { + sf1 = f32(params.dst_h - 1) / f32(params.src_h - 1); + } + } + + let z_src = min(params.src_z - 1, u32(floor(f32(z_dst) / sf2))); + let n_src = min(params.src_n - 1, u32(floor(f32(n_dst) / sf3))); + + var result = 0.0; + +#if defined(NEAREST) + + let x_src = min(params.src_w - 1, u32(floor(f32(x_dst) / sf0))); + let y_src = min(params.src_h - 1, u32(floor(f32(y_dst) / sf1))); + + result = get_clamped_input(i32(x_src), i32(y_src), z_src, n_src); + +#elif defined(BILINEAR) + +#if defined(ANTIALIAS) + + // Antialiased bilinear: triangle filter over a variable support region. + let support0 = max(1.0f / sf0, 1.0f); + let support1 = max(1.0f / sf1, 1.0f); + let invscale0 = 1.0 / support0; + let invscale1 = 1.0 / support1; + + let fx = (f32(x_dst) + pixel_offset) / sf0; + let fy = (f32(y_dst) + pixel_offset) / sf1; + + let x_min = max(i32(fx - support0 + pixel_offset), 0); + let y_min = max(i32(fy - support1 + pixel_offset), 0); + let x_max = min(i32(fx + support0 + pixel_offset), i32(params.src_w)); + let y_max = min(i32(fy + support1 + pixel_offset), i32(params.src_h)); + + var weighted_sum = 0.0; + var total_weight = 0.0; + + for (var x = x_min; x < x_max; x += 1) { + let wx = max(1.0 - abs(f32(x) - fx + pixel_offset) * invscale0, 0.0); + for (var y = y_min; y < y_max; y += 1) { + let wy = max(1.0 - abs(f32(y) - fy + pixel_offset) * invscale1, 0.0); + let w = wx * wy; + if (w > 0.0) { + weighted_sum += get_clamped_input(x, y, z_src, n_src) * w; + total_weight += w; + } + } + } + + if (total_weight > 0.0) { + result = weighted_sum / total_weight; + } + +#else + + let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; + let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; + let x0 = i32(floor(fx)); + let y0 = i32(floor(fy)); + let dx = clamp(fx - f32(x0), 0.0, 1.0); + let dy = clamp(fy - f32(y0), 0.0, 1.0); + let a = get_clamped_input(x0, y0, z_src, n_src); + let b = get_clamped_input(x0 + 1, y0, z_src, n_src); + let c = get_clamped_input(x0, y0 + 1, z_src, n_src); + let d = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); + + let wa = (1.0 - dx) * (1.0 - dy); + let wb = dx * (1.0 - dy); + let wc = (1.0 - dx) * dy; + let wd = dx * dy; + + result = a * wa + b * wb + c * wc + d * wd; + +#endif + +#elif defined(BICUBIC) + + // bicubic convolution with alpha = -0.75 (PyTorch default) + let alpha = -0.75; + let fx = (f32(x_dst) + pixel_offset) / sf0 - pixel_offset; + let fy = (f32(y_dst) + pixel_offset) / sf1 - pixel_offset; + + let x0 = i32(floor(fx)); + let y0 = i32(floor(fy)); + let dx = fx - f32(x0); + let dy = fy - f32(y0); + + // horizontal weights for offsets -1, 0, 1, 2 + let wx0 = cubic_weight(dx + 1.0, alpha); + let wx1 = cubic_weight(dx, alpha); + let wx2 = cubic_weight(1.0 - dx, alpha); + let wx3 = cubic_weight(2.0 - dx, alpha); + + // vertical weights for offsets -1, 0, 1, 2 + let wy0 = cubic_weight(dy + 1.0, alpha); + let wy1 = cubic_weight(dy, alpha); + let wy2 = cubic_weight(1.0 - dy, alpha); + let wy3 = cubic_weight(2.0 - dy, alpha); + + // intermediate horizontal interpolation for 4x4 grid of pixels + // x0-1, x0, x0+1, x0+2, y0-1 + let p0 = get_clamped_input(x0 - 1, y0 - 1, z_src, n_src); + let p1 = get_clamped_input(x0, y0 - 1, z_src, n_src); + let p2 = get_clamped_input(x0 + 1, y0 - 1, z_src, n_src); + let p3 = get_clamped_input(x0 + 2, y0 - 1, z_src, n_src); + let row0 = p0 * wx0 + p1 * wx1 + p2 * wx2 + p3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0 + let q0 = get_clamped_input(x0 - 1, y0, z_src, n_src); + let q1 = get_clamped_input(x0, y0, z_src, n_src); + let q2 = get_clamped_input(x0 + 1, y0, z_src, n_src); + let q3 = get_clamped_input(x0 + 2, y0, z_src, n_src); + let row1 = q0 * wx0 + q1 * wx1 + q2 * wx2 + q3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0+1 + let r0 = get_clamped_input(x0 - 1, y0 + 1, z_src, n_src); + let r1 = get_clamped_input(x0, y0 + 1, z_src, n_src); + let r2 = get_clamped_input(x0 + 1, y0 + 1, z_src, n_src); + let r3 = get_clamped_input(x0 + 2, y0 + 1, z_src, n_src); + let row2 = r0 * wx0 + r1 * wx1 + r2 * wx2 + r3 * wx3; + + // x0-1, x0, x0+1, x0+2, y0+2 + let s0 = get_clamped_input(x0 - 1, y0 + 2, z_src, n_src); + let s1 = get_clamped_input(x0, y0 + 2, z_src, n_src); + let s2 = get_clamped_input(x0 + 1, y0 + 2, z_src, n_src); + let s3 = get_clamped_input(x0 + 2, y0 + 2, z_src, n_src); + let row3 = s0 * wx0 + s1 * wx1 + s2 * wx2 + s3 * wx3; + + // final vertical interpolation + result = row0 * wy0 + row1 * wy1 + row2 * wy2 + row3 * wy3; + +#endif + + let dst_idx = params.offset_o + x_dst * params.so0 + y_dst * params.so1 + z_dst * params.so2 + n_dst * params.so3; + output[dst_idx] = DST_TYPE(result); +} diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index edbeb8eef24..639b818d128 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -58,6 +58,10 @@ static enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgr continue; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + bool ok = ggml_zdnn_compute_forward(ctx, node); if (!ok) { GGML_LOG_ERROR("%s: unsupported op %s (%s)\n", @@ -309,6 +313,8 @@ static ggml_backend_buffer_i ggml_backend_zdnn_buffer_i = { /* .memset_tensor = */ ggml_backend_zdnn_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_zdnn_buffer_set_tensor, /* .get_tensor = */ ggml_backend_zdnn_buffer_get_tensor, + /* .set_tensor_2d = */ NULL, + /* .get_tensor_2d = */ NULL, /* .cpy_tensor = */ NULL, /* .clear = */ ggml_backend_zdnn_buffer_clear, /* .reset = */ NULL, @@ -368,7 +374,8 @@ static size_t ggml_backend_zdnn_buffer_type_get_alignment(ggml_backend_buffer_ty } static bool ggml_backend_zdnn_buffer_type_is_host(ggml_backend_buffer_type_t buft) { - return true; + /* while it resides in host memory, additional transformation is needed */ + return false; GGML_UNUSED(buft); } @@ -412,20 +419,22 @@ static enum ggml_status ggml_backend_zdnn_graph_compute(ggml_backend_t backend, } static ggml_backend_i ggml_backend_zdnn_i = { - /* .get_name = */ ggml_backend_zdnn_name, - /* .free = */ ggml_backend_zdnn_free, - /* .set_tensor_async = */ NULL, - /* .get_tensor_async = */ NULL, - /* .cpy_tensor_async = */ NULL, - /* .synchronize = */ NULL, - /* .graph_plan_create = */ NULL, - /* .graph_plan_free = */ NULL, - /* .graph_plan_update = */ NULL, - /* .graph_plan_compute = */ NULL, - /* .graph_compute = */ ggml_backend_zdnn_graph_compute, - /* .event_record = */ NULL, - /* .event_wait = */ NULL, - /* .graph_optimize = */ NULL, + /* .get_name = */ ggml_backend_zdnn_name, + /* .free = */ ggml_backend_zdnn_free, + /* .set_tensor_async = */ NULL, + /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, + /* .cpy_tensor_async = */ NULL, + /* .synchronize = */ NULL, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_zdnn_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ NULL, }; static ggml_guid_t ggml_backend_zdnn_guid(void) { diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index bdbfc74369f..e4ba9cfbd0f 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -1,12 +1,19 @@ ggml_add_backend_library(ggml-zendnn ggml-zendnn.cpp) -# Get ZenDNN path if (NOT DEFINED ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "") set(ZENDNN_ROOT "$ENV{ZENDNN_ROOT}") endif() -# Check if path is still empty or OFF +if (BUILD_SHARED_LIBS) + set(ZENDNN_SHARED_LIB ON) + set(ZENDNN_ARCHIVE_LIB OFF) +else() + set(ZENDNN_SHARED_LIB OFF) + set(ZENDNN_ARCHIVE_LIB ON) +endif() + +# Download and build ZenDNN if not provided if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") message(STATUS "ZENDNN_ROOT not set. Automatically downloading and building ZenDNN...") message(STATUS "This will take several minutes on first build...") @@ -21,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG zendnnl + GIT_TAG 253b94ce0d7e9284c265fefb485714944caff9d3 # ZenDNN-2026-WW19 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} @@ -32,7 +39,9 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") -DZENDNNL_BUILD_DOXYGEN=OFF -DZENDNNL_BUILD_GTEST=OFF -DZENDNNL_BUILD_BENCHDNN=OFF - # Enable ALL matmul algorithm backends + -DZENDNNL_DEPENDS_FBGEMM=OFF + -DZENDNNL_LIB_BUILD_ARCHIVE=${ZENDNN_ARCHIVE_LIB} + -DZENDNNL_LIB_BUILD_SHARED=${ZENDNN_SHARED_LIB} -DZENDNNL_DEPENDS_AOCLDLP=ON -DZENDNNL_DEPENDS_ONEDNN=ON -DZENDNNL_DEPENDS_LIBXSMM=ON @@ -45,47 +54,37 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") LOG_INSTALL ON ) - # Add dependency so ZenDNN builds before our library add_dependencies(ggml-zendnn zendnn) - - # Set ZENDNN_ROOT to the installation directory set(ZENDNN_ROOT ${ZENDNN_INSTALL_DIR}) - message(STATUS "ZenDNN will be built to: ${ZENDNN_ROOT}") else() message(STATUS "Using custom ZenDNN installation at: ${ZENDNN_ROOT}") endif() -# ZenDNN headers + libs target_include_directories(ggml-zendnn PRIVATE ${ZENDNN_ROOT}/zendnnl/include - ${ZENDNN_ROOT}/deps/aocldlp/include - ${ZENDNN_ROOT}/deps/aoclutils/include ${ZENDNN_ROOT}/deps/json/include - ${ZENDNN_ROOT}/deps/libxsmm/include + ${ZENDNN_ROOT}/deps/aoclutils/include + ${ZENDNN_ROOT}/deps/aocldlp/include ${ZENDNN_ROOT}/deps/onednn/include -) + ${ZENDNN_ROOT}/deps/libxsmm/include) -target_link_directories(ggml-zendnn PRIVATE - ${ZENDNN_ROOT}/zendnnl/lib - ${ZENDNN_ROOT}/deps/aocldlp/lib - ${ZENDNN_ROOT}/deps/aoclutils/lib - ${ZENDNN_ROOT}/deps/libxsmm/lib - ${ZENDNN_ROOT}/deps/onednn/lib -) +if (ZENDNN_SHARED_LIB) + target_link_directories(ggml-zendnn PRIVATE ${ZENDNN_ROOT}/zendnnl/lib) + target_link_libraries(ggml-zendnn PRIVATE zendnnl) +elseif (ZENDNN_ARCHIVE_LIB) + target_link_libraries(ggml-zendnn PRIVATE + ${ZENDNN_ROOT}/zendnnl/lib/libzendnnl_archive.a + ${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libaoclutils.a + ${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libau_cpuid.a + ${ZENDNN_ROOT}/deps/aocldlp/lib/libaocl-dlp.a + ${ZENDNN_ROOT}/deps/onednn/${CMAKE_INSTALL_LIBDIR}/libdnnl.a + ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmm.a + ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmext.a + ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmnoblas.a) +endif() -target_link_libraries(ggml-zendnn PRIVATE - zendnnl_archive # ZenDNN main - aocl-dlp # AOCL libraries - aoclutils - au_cpuid - dnnl # OneDNN - xsmm # libxsmm small matrix math - xsmmext - xsmmnoblas - m - pthread -) +target_link_libraries(ggml-zendnn PRIVATE m pthread) if (GGML_OPENMP) target_link_libraries(ggml-zendnn PRIVATE OpenMP::OpenMP_CXX) diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index fd07f983da7..3c33dcb11a0 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -2,7 +2,10 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" -#include "ggml-cpu.h" + +#define GGML_COMMON_DECL_CPP +#include "ggml-common.h" + #include "zendnnl.hpp" #include <cstring> @@ -20,6 +23,8 @@ zendnnl::common::data_type_t ggml_to_zendnn_type() { return zendnnl::common::data_type_t::f32; } else if constexpr (std::is_same_v<T, ggml_bf16_t>) { return zendnnl::common::data_type_t::bf16; + } else if constexpr (std::is_same_v<T, block_q8_0>) { + return zendnnl::common::data_type_t::s8; } else { return zendnnl::common::data_type_t::none; } @@ -42,13 +47,25 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C, int64_t ldc) { - zendnnl::lowoha::lowoha_params params; + zendnnl::lowoha::matmul::matmul_params params; params.dtypes.src = ggml_to_zendnn_type<TB>(); params.dtypes.wei = ggml_to_zendnn_type<TA>(); params.dtypes.dst = ggml_to_zendnn_type<TC>(); params.num_threads = ctx->n_threads; - zendnnl::lowoha::status_t status = zendnnl::lowoha::matmul_direct( + zendnnl::lowoha::matmul::matmul_batch_params_t batch_params; + + if constexpr (std::is_same_v<TA, block_q8_0>) { + params.dtypes.compute = zendnnl::common::data_type_t::s8; + const int64_t num_groups = k / QK8_0; + params.dynamic_quant = true; + params.quant_params.src_scale.buff = nullptr; + params.quant_params.src_scale.dt = zendnnl::common::data_type_t::bf16; + params.quant_params.src_scale.dims = {n, num_groups}; + params.packing.pack_format_b = 1; + } + + zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct( 'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major) n, // M: rows of B and C m, // N: cols of A^T and C @@ -60,18 +77,18 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int 0.0f, // beta C, ldc, // output C[n,m] true, // is_weights_const - {}, // batch_params + batch_params, // batch_params params // params ); - if (status != zendnnl::lowoha::status_t::success) { + if (status != zendnnl::error_handling::status_t::success) { GGML_LOG_ERROR("%s, ZenDNN matmul failed: status=%d\n", __func__, static_cast<int>(status)); return false; } return true; } -static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k, +static bool ggml_zendnn_gemm(ggml_backend_zendnn_context * ctx, int64_t m, int64_t n, int64_t k, const void * A, int64_t lda, const void * B, int64_t ldb, void * C, int64_t ldc, int Atype, int Btype, int Ctype) { @@ -108,6 +125,14 @@ static bool ggml_zendnn_sgemm(ggml_backend_zendnn_context * ctx, int64_t m, int6 (const ggml_bf16_t *)B, ldb, (float *)C, ldc); return false; + case GGML_TYPE_Q8_0: + if (Btype != GGML_TYPE_F32 || Ctype != GGML_TYPE_F32) + return false; + return ggml_zendnn_matmul<block_q8_0, float, float>( + ctx, m, n, k, + (const block_q8_0 *)A, lda, + (const float *)B, ldb, + (float *)C, ldc); default: return false; // unsupported type } @@ -122,8 +147,8 @@ static void ggml_zendnn_compute_forward_mul_mat( GGML_TENSOR_BINARY_OP_LOCALS - ggml_type const vec_dot_type = ggml_get_type_traits_cpu(src0->type)->vec_dot_type; - ggml_from_float_t const from_float = ggml_get_type_traits_cpu(vec_dot_type)->from_float; + ggml_type const vec_dot_type = src0->type; + ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref; GGML_ASSERT(ne0 == ne01); GGML_ASSERT(ne1 == ne11); @@ -145,7 +170,9 @@ static void ggml_zendnn_compute_forward_mul_mat( const int64_t r3 = ne13/ne03; void * work_data = ctx->work_data.get(); - if (src1->type != vec_dot_type) { + + // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped + if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) { const size_t nbw1 = ggml_row_size(vec_dot_type, ne10); const size_t nbw2 = nbw1 * ne11; const size_t nbw3 = nbw2 * ne12; @@ -171,9 +198,9 @@ static void ggml_zendnn_compute_forward_mul_mat( for (int64_t i13 = 0; i13 < ne13; i13++) { for (int64_t i12 = 0; i12 < ne12; i12++) { - const void* wdata = src1->type == vec_dot_type ? src1->data : work_data; + const void* wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; const size_t row_size = ggml_row_size(vec_dot_type, ne10); - if (!ggml_zendnn_sgemm(ctx, + if (!ggml_zendnn_gemm(ctx, ne01, // m ne11, // n ne10, // k @@ -184,9 +211,179 @@ static void ggml_zendnn_compute_forward_mul_mat( static_cast<char *>(dst->data) + i12*nb2 + i13*nb3, ne01, // ldc src0->type, - vec_dot_type, + src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, dst->type)) - GGML_ABORT("%s: ZenDNN sgemm failed\n", __func__); + GGML_ABORT("%s: ZenDNN gemm failed\n", __func__); + } + } +} + +struct mmid_row_mapping { + int32_t i1; + int32_t i2; +}; + +static void ggml_zendnn_compute_forward_mul_mat_id( + ggml_backend_zendnn_context * ctx, + ggml_tensor * dst) { + + const ggml_tensor * src0 = dst->src[0]; // expert weights + const ggml_tensor * src1 = dst->src[1]; // inputs + const ggml_tensor * ids = dst->src[2]; // expert ids + + GGML_TENSOR_BINARY_OP_LOCALS + + // exit for no tokens to process + if (ne2 == 0 || ne11 == 0) { + return; + } + + ggml_type const vec_dot_type = src0->type; + ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float_ref; + + // we don't support permuted src0 or src1 + GGML_ASSERT(nb00 == ggml_type_size(src0->type)); + GGML_ASSERT(nb10 == ggml_type_size(src1->type)); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + GGML_ASSERT(ne03 == 1); + GGML_ASSERT(ne13 == 1); + GGML_ASSERT(ne3 == 1); + + // row groups + const int n_ids = ids->ne[0]; // n_expert_used + const int n_as = ne02; // n_experts + + std::vector<int64_t> matrix_row_counts(n_as, 0); + std::vector<std::vector<mmid_row_mapping>> matrix_rows(n_as); + + int64_t max_rows = 0; + // group rows by expert (preprocessing step) + for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { + for (int id = 0; id < n_ids; ++id) { + const int32_t i02 = *(const int32_t *)((const char *)ids->data + iid1*ids->nb[1] + id*ids->nb[0]); + + GGML_ASSERT(i02 >= 0 && i02 < n_as); + + matrix_rows[i02].push_back({id, iid1}); + matrix_row_counts[i02]++; + if (matrix_row_counts[i02] > max_rows) { + max_rows = matrix_row_counts[i02]; + } + } + } + + if (max_rows == 0) { + return; // no rows to process + } + + const size_t row_size = ggml_row_size(vec_dot_type, ne10); + + // size for converting src1 rows to vec_dot_type if needed + const size_t nbw1 = row_size; + const size_t nbw2 = nbw1 * ne11; + const size_t nbw3 = nbw2 * ne12; + const size_t src1_conv_size = (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) ? ne13 * nbw3 : 0; + + // For Q8_0, src1 is always F32; the gather buffer must hold F32 rows (ne10*4 bytes), + // not Q8_0-encoded rows (row_size ≈ ne10/32*34 bytes) — they differ by ~4x. + const size_t f32_row_size = (size_t)ne10 * sizeof(float); + const size_t gather_row_size = (src0->type == GGML_TYPE_Q8_0) ? f32_row_size : row_size; + + // size for MoE gather/scatter buffers + const size_t wdata_cur_size = max_rows * gather_row_size; + const size_t dst_cur_size = max_rows * ggml_row_size(dst->type, ne01); + + // allocate single buffer for all needs + const size_t total_size = src1_conv_size + wdata_cur_size + dst_cur_size; + if (ctx->work_size < total_size) { + ctx->work_data.reset(new char[total_size]); + ctx->work_size = total_size; + } + + // partition the buffer + char * work_data = ctx->work_data.get(); + char * wdata_cur = work_data + src1_conv_size; + char * dst_cur = wdata_cur + wdata_cur_size; + + // ZenDNN requires FP32 for dynamic quantization, so conversion is skipped + if (src1->type != vec_dot_type && src0->type != GGML_TYPE_Q8_0) { + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + #pragma omp parallel for collapse(3) num_threads(ctx->n_threads) schedule(static) + for (int64_t i13 = 0; i13 < ne13; ++i13) { + for (int64_t i12 = 0; i12 < ne12; ++i12) { + for (int64_t i11 = 0; i11 < ne11; ++i11) { + const float * src1_f32 = (float *)((char *)src1->data + i11*nb11 + i12*nb12 + i13*nb13); + void * src1_conv = (char *)work_data + i11*nbw1 + i12*nbw2 + i13*nbw3; + from_float(src1_f32, src1_conv, ne10); + } + } + } + } + + const void * wdata = (src1->type == vec_dot_type || src0->type == GGML_TYPE_Q8_0) ? src1->data : work_data; + + // process each expert with gather -> gemm -> scatter pattern + for (int64_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int64_t cne1 = matrix_row_counts[cur_a]; + + if (cne1 == 0) { + continue; + } + + const char * src0_cur = (const char *) src0->data + cur_a*nb02; + + // gather input rows for this expert + #pragma omp parallel for num_threads(ctx->n_threads) schedule(static) + for (int64_t ir1 = 0; ir1 < cne1; ++ir1) { + const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1]; + const int64_t id = row_mapping.i1; + const int64_t i11 = id % ne11; + const int64_t i12 = row_mapping.i2; + + std::memcpy( + wdata_cur + ir1 * gather_row_size, + (const char *) wdata + (i11 + i12*ne11) * gather_row_size, + gather_row_size + ); + } + + // batched gemm for all tokens in this expert + if (!ggml_zendnn_gemm(ctx, + ne01, // m + cne1, // n + ne10, // k + src0_cur, + ne00, // lda + wdata_cur, + ne10, // ldb + dst_cur, + ne01, // ldc + src0->type, + src0->type == GGML_TYPE_Q8_0 ? GGML_TYPE_F32 : vec_dot_type, + dst->type)) { + GGML_ABORT("%s: ZenDNN gemm failed\n", __func__); + } + + // scatter output rows to destination + #pragma omp parallel for num_threads(ctx->n_threads) schedule(static) + for (int64_t ir1 = 0; ir1 < cne1; ++ir1) { + const mmid_row_mapping & row_mapping = matrix_rows[cur_a][ir1]; + const int64_t id = row_mapping.i1; + const int64_t i1 = id; + const int64_t i2 = row_mapping.i2; + + std::memcpy( + (char *) dst->data + i1*nb1 + i2*nb2, + dst_cur + ir1 * ggml_row_size(dst->type, ne01), + ggml_row_size(dst->type, ne01) + ); } } } @@ -211,10 +408,17 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * node = cgraph->nodes[i]; + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + continue; + } + switch (node->op) { case GGML_OP_MUL_MAT: ggml_zendnn_compute_forward_mul_mat(ctx, node); break; + case GGML_OP_MUL_MAT_ID: + ggml_zendnn_compute_forward_mul_mat_id(ctx, node); + break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -237,6 +441,8 @@ static struct ggml_backend_i ggml_backend_zendnn_i = { /* .free = */ ggml_backend_zendnn_free, /* .set_tensor_async = */ NULL, /* .get_tensor_async = */ NULL, + /* .set_tensor_2d_async = */ NULL, + /* .get_tensor_2d_async = */ NULL, /* .cpy_tensor_async = */ NULL, /* .synchronize = */ NULL, /* .graph_plan_create = */ NULL, @@ -348,6 +554,12 @@ static ggml_backend_buffer_t ggml_backend_zendnn_device_buffer_from_host_ptr(ggm GGML_UNUSED(max_tensor_size); } +static bool ggml_zendnn_adaptive_fallback_enabled() { + static const bool enabled = std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK") == nullptr || + std::atoi(std::getenv("GGML_ZENDNN_ADAPTIVE_FALLBACK")) != 0; + return enabled; +} + static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) { switch (op->op) { case GGML_OP_NONE: @@ -358,6 +570,7 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const return true; case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: { const ggml_tensor * weights = op->src[0]; const ggml_tensor * inputs = op->src[1]; @@ -365,15 +578,39 @@ static bool ggml_backend_zendnn_device_supports_op(ggml_backend_dev_t dev, const const int64_t ne10 = inputs->ne[0]; const int64_t ne0 = op->ne[0]; const int64_t ne1 = op->ne[1]; - const int64_t min_batch = 1; - if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs) || - ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) { + + if(!ggml_is_contiguous(weights) || !ggml_is_contiguous(inputs)) { + return false; + } + + if (ggml_zendnn_adaptive_fallback_enabled()) { + const int64_t K = inputs->ne[0]; + const int64_t N = (inputs->ne[1]*inputs->ne[2]*inputs->ne[3]); + const int64_t M = weights->ne[1]; + if(K <= 256 || N <= 128 || M <= 96) { return false; + } + } + else if (ne0 < min_batch || ne1 < min_batch || ne10 < min_batch) { + return false; + } + + // MUL_MAT_ID performs best with a moderate number of experts due to its + // gather + batched matmul + scatter approach. Future versions will leverage + // ZenDNN's grouped_gemm for better scalability with larger expert counts: + // https://github.com/amd/ZenDNN/blob/main/docs/operator/lowoha_group_gemm_operator.md + if (op->op == GGML_OP_MUL_MAT_ID) { + const int64_t n_experts = weights->ne[2]; + const int64_t max_experts = 32; + if (n_experts > max_experts) { + return false; + } } switch (weights->type) { case GGML_TYPE_F32: case GGML_TYPE_BF16: + case GGML_TYPE_Q8_0: return true; default: return false; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 09b8eb466d3..b43016c87d2 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -53,6 +53,21 @@ #define UNUSED GGML_UNUSED +uint64_t ggml_graph_next_uid(void) { +#ifdef _MSC_VER +#if defined(_WIN32) + static volatile LONG counter = 1; + return (uint64_t) InterlockedIncrement(&counter) - 1; +#else + static volatile long long counter = 1; + return (uint64_t) _InterlockedIncrement64(&counter) - 1; +#endif +#else + static uint64_t counter = 1; + return __atomic_fetch_add(&counter, 1, __ATOMIC_RELAXED); +#endif +} + // Needed for ggml_fp32_to_bf16_row() #if defined(__AVX512BF16__) #if defined(_MSC_VER) @@ -651,6 +666,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row, .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row, }, + [GGML_TYPE_Q1_0] = { + .type_name = "q1_0", + .blck_size = QK1_0, + .type_size = sizeof(block_q1_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q1_0, + .from_float_ref = (ggml_from_float_t) quantize_row_q1_0_ref, + }, [GGML_TYPE_Q4_0] = { .type_name = "q4_0", .blck_size = QK4_0, @@ -718,6 +741,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_mxfp4, .from_float_ref = (ggml_from_float_t)quantize_row_mxfp4_ref, }, + [GGML_TYPE_NVFP4] = { + .type_name = "nvfp4", + .blck_size = QK_NVFP4, + .type_size = sizeof(block_nvfp4), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_nvfp4, + .from_float_ref = (ggml_from_float_t)quantize_row_nvfp4_ref, + }, [GGML_TYPE_Q2_K] = { .type_name = "q2_K", .blck_size = QK_K, @@ -899,7 +930,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { }; const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) { - GGML_ASSERT(type < GGML_TYPE_COUNT); + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return &type_traits[type]; } @@ -999,6 +1031,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "IM2COL", "IM2COL_BACK", "IM2COL_3D", + "COL2IM_1D", "CONV_2D", "CONV_3D", "CONV_2D_DW", @@ -1030,6 +1063,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GATED_LINEAR_ATTN", "RWKV_WKV7", "SOLVE_TRI", + "GATED_DELTA_NET", "UNARY", @@ -1047,7 +1081,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1108,6 +1142,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "im2col(x)", "im2col_back(x)", "im2col_3d(x)", + "col2im_1d(x)", "conv_2d(x)", "conv_3d(x)", "conv_2d_dw(x)", @@ -1139,6 +1174,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "gated_linear_attn(k, v, q, gate, s)", "rwkv_wkv7(r, w, k, v, a, b, s)", "A X = B, A triangular, solve X", + "gated_delta_net(q, k, v, g, beta, s)", "unary(x)", @@ -1156,7 +1192,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1265,27 +1301,39 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { } int64_t ggml_blck_size(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return type_traits[type].blck_size; } size_t ggml_type_size(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return type_traits[type].type_size; } size_t ggml_row_size(enum ggml_type type, int64_t ne) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); assert(ne % ggml_blck_size(type) == 0); return ggml_type_size(type)*ne/ggml_blck_size(type); } double ggml_type_sizef(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; } const char * ggml_type_name(enum ggml_type type) { - return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE"; + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); + return type_traits[type].type_name; } bool ggml_is_quantized(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return type_traits[type].is_quantized; } @@ -1361,10 +1409,12 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q1_0: wtype = GGML_TYPE_Q1_0; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break; + case GGML_FTYPE_MOSTLY_NVFP4: wtype = GGML_TYPE_NVFP4; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; @@ -1403,16 +1453,14 @@ static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) { } next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - if (tensor->ne[i] != 1) { - if (i > n) { - if (tensor->nb[i] != next_nb) { - return false; - } - next_nb *= tensor->ne[i]; - } else { - // this dimension does not need to be contiguous - next_nb = tensor->ne[i]*tensor->nb[i]; + if (i > n) { + if (tensor->ne[i] != 1 && tensor->nb[i] != next_nb) { + return false; } + next_nb *= tensor->ne[i]; + } else { + // this dimension does not need to be contiguous + next_nb = tensor->ne[i]*tensor->nb[i]; } } return true; @@ -1496,6 +1544,10 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso (t0->nb[3] == t1->nb[3]); } +bool ggml_is_view(const struct ggml_tensor * t) { + return ggml_impl_is_view(t); +} + // check if t1 can be represented as a repetition of t0 bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -1625,11 +1677,23 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml const size_t cur_end = cur_offs + cur_size; // align to GGML_MEM_ALIGN + GGML_ASSERT(size <= SIZE_MAX - (GGML_MEM_ALIGN - 1)); size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN); char * const mem_buffer = ctx->mem_buffer; struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); + // integer overflow checks + if (cur_end > SIZE_MAX - size_needed) { + GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu)\n", __func__, cur_end, size_needed); + return NULL; + } + if (cur_end + size_needed > SIZE_MAX - GGML_OBJECT_SIZE) { + GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu) + GGML_OBJECT_SIZE (%zu)\n", __func__, + cur_end, size_needed, (size_t) GGML_OBJECT_SIZE); + return NULL; + } + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size); @@ -1698,6 +1762,8 @@ static struct ggml_tensor * ggml_new_tensor_impl( obj_alloc_size = data_size; } + GGML_ASSERT(GGML_TENSOR_SIZE <= SIZE_MAX - obj_alloc_size); + struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); GGML_ASSERT(obj_new); @@ -3200,6 +3266,16 @@ void ggml_mul_mat_set_prec( ggml_set_op_params_i32(a, 0, prec_i32); } +void ggml_mul_mat_set_hint( + struct ggml_tensor * a, + enum ggml_op_hint hint) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + + const int32_t hint_i32 = (int32_t) hint; + + ggml_set_op_params_i32(a, 1, hint_i32); +} + // ggml_mul_mat_id /* @@ -3441,7 +3517,8 @@ struct ggml_tensor * ggml_cast( result->op = GGML_OP_CPY; result->src[0] = a; - result->src[1] = result; + result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some + // backends for consistency with ggml_cpy_impl() above return result; } @@ -4466,6 +4543,41 @@ struct ggml_tensor * ggml_conv_1d_dw_ph( return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0); } +// ggml_col2im_1d + +struct ggml_tensor * ggml_col2im_1d( + struct ggml_context * ctx, + struct ggml_tensor * a, + int s0, + int oc, + int p0) { + GGML_ASSERT(ggml_is_matrix(a)); + GGML_ASSERT(ggml_is_contiguous(a)); + GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16 || a->type == GGML_TYPE_BF16); + GGML_ASSERT(s0 > 0); + GGML_ASSERT(oc > 0); + GGML_ASSERT(p0 >= 0); + + const int64_t K_OC = a->ne[0]; + const int64_t T_in = a->ne[1]; + const int64_t K = K_OC / oc; + const int64_t T_out = (T_in - 1) * s0 + K - 2 * p0; + + GGML_ASSERT(K_OC == K * oc); // a->ne[0] must be a whole number of oc blocks + GGML_ASSERT(K > 0 && T_out > 0); + + const int64_t ne[4] = { T_out, oc, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, a->type, 2, ne); + + int32_t params[] = { s0, (int32_t)oc, (int32_t)p0 }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_COL2IM_1D; + result->src[0] = a; + + return result; +} + // ggml_conv_transpose_1d static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) { @@ -4838,6 +4950,8 @@ struct ggml_tensor * ggml_pool_1d( a->ne[2], a->ne[3], }; + GGML_ASSERT(ne[0] > 0); + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); int32_t params[] = { op, k0, s0, p0 }; @@ -4868,6 +4982,9 @@ struct ggml_tensor * ggml_pool_2d( a->ne[2], a->ne[3], }; + GGML_ASSERT(ne[0] > 0); + GGML_ASSERT(ne[1] > 0); + result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); int32_t params[] = { op, k0, k1, s0, s1, p0, p1 }; @@ -4916,6 +5033,7 @@ static struct ggml_tensor * ggml_interpolate_impl( GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT); // TODO: implement antialias for modes other than bilinear GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR); + GGML_ASSERT(a->type == GGML_TYPE_F32); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); @@ -5142,7 +5260,7 @@ static struct ggml_tensor * ggml_fill_impl( struct ggml_tensor * a, float c, bool inplace) { - GGML_ASSERT(a->type == GGML_TYPE_F32); + GGML_ASSERT(a->type == GGML_TYPE_F32 || a->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -5261,6 +5379,7 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(q->ne[3] == v->ne[3]); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); @@ -5743,7 +5862,7 @@ static struct ggml_tensor * ggml_unary_impl( struct ggml_tensor * a, enum ggml_unary_op op, bool inplace) { - GGML_ASSERT(ggml_is_contiguous_1(a)); + GGML_ASSERT(ggml_is_contiguous_rows(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); @@ -6095,6 +6214,63 @@ struct ggml_tensor * ggml_solve_tri( return result; } +// ggml_gated_delta_net + +struct ggml_tensor * ggml_gated_delta_net( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * g, + struct ggml_tensor * beta, + struct ggml_tensor * state, + int64_t K) { + GGML_ASSERT(ggml_is_contiguous_rows(q)); + GGML_ASSERT(ggml_is_contiguous_rows(k)); + GGML_ASSERT(ggml_is_contiguous_rows(v)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(beta)); + GGML_ASSERT(ggml_is_contiguous(state)); + + GGML_ASSERT(q->type == GGML_TYPE_F32); + GGML_ASSERT(k->type == GGML_TYPE_F32); + GGML_ASSERT(v->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(beta->type == GGML_TYPE_F32); + GGML_ASSERT(state->type == GGML_TYPE_F32); + + const int64_t S_v = v->ne[0]; + const int64_t H = v->ne[1]; + const int64_t n_tokens = v->ne[2]; + const int64_t n_seqs = v->ne[3]; + + // gate: scalar [1, H, T, B] or vector [S_v, H, T, B] (KDA) + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT(beta->ne[0] == 1); + + // state holds the initial state s0 only: [S_v, S_v, H, n_seqs]. K (snapshot slot count) is an op param. + GGML_ASSERT(state->ne[0] == S_v); + GGML_ASSERT(state->ne[1] == S_v); + GGML_ASSERT(state->ne[2] == H); + GGML_ASSERT(state->ne[3] == n_seqs); + GGML_ASSERT(K >= 1); + const int64_t state_rows = K * S_v * n_seqs; + const int64_t ne[4] = { S_v * H, n_tokens * n_seqs + state_rows, 1, 1 }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); + + ggml_set_op_params_i32(result, 0, (int32_t) K); + + result->op = GGML_OP_GATED_DELTA_NET; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = g; + result->src[4] = beta; + result->src[5] = state; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { @@ -6556,7 +6732,7 @@ static void ggml_compute_backward( case GGML_OP_DIAG_MASK_INF: { if (src0_needs_grads) { /* ggml_diag_mask_inf_impl() shouldn't be here */ - /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ + /* ref: https://github.com/ggml-org/llama.cpp/pull/4203#discussion_r1412377992 */ const int n_past = ((const int32_t *) tensor->op_params)[0]; ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); } @@ -6720,20 +6896,35 @@ static void ggml_compute_backward( GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); } -static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { - // check if already visited - size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); +static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) { + if (node->op != GGML_OP_NONE && compute) { + node->flags |= GGML_TENSOR_FLAG_COMPUTE; + } + + const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node); GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL); - if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { - // This is the first time we see this node in the current graph. - cgraph->visited_hash_set.keys[node_hash_pos] = node; - ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos); - cgraph->use_counts[node_hash_pos] = 0; - } else { + + if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { // already visited + + if (compute) { + // update the compute flag regardless + for (int i = 0; i < GGML_MAX_SRC; ++i) { + struct ggml_tensor * src = node->src[i]; + if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) { + ggml_visit_parents_graph(cgraph, src, true); + } + } + } + return node_hash_pos; } + // This is the first time we see this node in the current graph. + cgraph->visited_hash_set.keys[node_hash_pos] = node; + ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos); + cgraph->use_counts[node_hash_pos] = 0; + for (int i = 0; i < GGML_MAX_SRC; ++i) { const int k = (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i : @@ -6742,7 +6933,7 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor struct ggml_tensor * src = node->src[k]; if (src) { - size_t src_hash_pos = ggml_visit_parents(cgraph, src); + const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute); // Update the use count for this operand. cgraph->use_counts[src_hash_pos]++; @@ -6773,17 +6964,17 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor return node_hash_pos; } -static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) { +static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) { if (!expand) { // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand ggml_graph_clear(cgraph); } - const int n0 = cgraph->n_nodes; + const int n_old = cgraph->n_nodes; - ggml_visit_parents(cgraph, tensor); + ggml_visit_parents_graph(cgraph, tensor, compute); - const int n_new = cgraph->n_nodes - n0; + const int n_new = cgraph->n_nodes - n_old; GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new); if (n_new > 0) { @@ -6792,8 +6983,22 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten } } +struct ggml_tensor * ggml_build_forward_select( + struct ggml_cgraph * cgraph, + struct ggml_tensor ** tensors, + int n_tensors, + int idx) { + GGML_ASSERT(idx >= 0 && idx < n_tensors); + + for (int i = 0; i < n_tensors; i++) { + ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? true : false); + } + + return tensors[idx]; +} + void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) { - ggml_build_forward_impl(cgraph, tensor, true); + ggml_build_forward_impl(cgraph, tensor, true, true); } void ggml_build_backward_expand( @@ -6961,6 +7166,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.use_counts =*/ use_counts_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, + /*.uid =*/ 0, }; ggml_hash_set_reset(&cgraph->visited_hash_set); @@ -6988,6 +7194,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) /*.use_counts =*/ cgraph0->use_counts, /*.visited_hash_set =*/ cgraph0->visited_hash_set, /*.order =*/ cgraph0->order, + /*.uid =*/ 0 }; return cgraph; @@ -7224,6 +7431,10 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph, return false; } + if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) { + return false; + } + if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) { continue; } @@ -7305,7 +7516,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, label); } -void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) { +void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) { char color[16]; FILE * fp = ggml_fopen(filename, "w"); @@ -7326,7 +7537,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph if (node->flags & GGML_TENSOR_FLAG_PARAM) { snprintf(color, sizeof(color), "yellow"); } else if (grad) { - if (ggml_graph_find(gf, node)) { + if (ggml_graph_find(cgraph, node)) { snprintf(color, sizeof(color), "green"); } else { snprintf(color, sizeof(color), "lightblue"); @@ -7478,8 +7689,11 @@ void ggml_quantize_free(void) { iq2xs_free_impl(GGML_TYPE_IQ2_XXS); iq2xs_free_impl(GGML_TYPE_IQ2_XS); + iq2xs_free_impl(GGML_TYPE_IQ2_S); iq2xs_free_impl(GGML_TYPE_IQ1_S); + iq2xs_free_impl(GGML_TYPE_IQ1_M); iq3xs_free_impl(256); + iq3xs_free_impl(512); ggml_critical_section_end(); } @@ -7500,7 +7714,7 @@ size_t ggml_quantize_chunk( int64_t nrows, int64_t n_per_row, const float * imatrix) { - const int64_t n = (int64_t) nrows * n_per_row; + const int64_t n = nrows * n_per_row; if (ggml_quantize_requires_imatrix(type)) { GGML_ASSERT(imatrix != NULL); @@ -7517,19 +7731,21 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { - case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q1_0: result = quantize_q1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0: result = quantize_q4_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_1: result = quantize_q4_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_0: result = quantize_q5_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_1: result = quantize_q5_1 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_0: result = quantize_q8_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_NVFP4: result = quantize_nvfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q2_K: result = quantize_q2_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q3_K: result = quantize_q3_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_K: result = quantize_q4_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q5_K: result = quantize_q5_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q6_K: result = quantize_q6_K (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ1_0: result = quantize_tq1_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ2_0: result = quantize_tq2_0 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; @@ -7594,9 +7810,9 @@ struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads) { } bool ggml_threadpool_params_match(const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1) { - if (p0->n_threads != p1->n_threads ) return false; - if (p0->prio != p1->prio ) return false; - if (p0->poll != p1->poll ) return false; - if (p0->strict_cpu != p1->strict_cpu ) return false; + if (p0->n_threads != p1->n_threads ) return false; + if (p0->prio != p1->prio ) return false; + if (p0->poll != p1->poll ) return false; + if (p0->strict_cpu != p1->strict_cpu ) return false; return memcmp(p0->cpumask, p1->cpumask, GGML_MAX_N_THREADS) == 0; } diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index 53504399c57..5e198618251 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -15,6 +15,17 @@ #include <string> #include <vector> +#define GGUF_MAX_STRING_LENGTH (1024*1024*1024) +#define GGUF_MAX_ARRAY_ELEMENTS (1024*1024*1024) + +#ifdef _WIN32 +# define gguf_ftell _ftelli64 +# define gguf_fseek _fseeki64 +#else +# define gguf_ftell ftello +# define gguf_fseek fseeko +#endif + template <typename T> struct type_to_gguf_type; @@ -217,17 +228,71 @@ struct gguf_context { }; struct gguf_reader { - FILE * file; + gguf_reader( + gguf_reader_callback_t callback, + void * userdata, + size_t max_chunk_read, + uint64_t data_offset = 0, + uint64_t nbytes_remain = 0) + : callback(callback), + userdata(userdata), + max_chunk_read(max_chunk_read), + data_offset(data_offset), + nbytes_remain(nbytes_remain) { + GGML_ASSERT(max_chunk_read > 0); + } - gguf_reader(FILE * file) : file(file) {} + // helper for remaining bytes in a file + static uint64_t file_remain(FILE * file) { + const int64_t cur = gguf_ftell(file); + if (cur < 0) { + return 0; + } + if (gguf_fseek(file, 0, SEEK_END) != 0) { + gguf_fseek(file, cur, SEEK_SET); + + return 0; + } + const int64_t end = gguf_ftell(file); + if (end < 0) { + gguf_fseek(file, cur, SEEK_SET); + + return 0; + } + gguf_fseek(file, cur, SEEK_SET); + return static_cast<uint64_t>(end - cur); + } template <typename T> bool read(T & dst) const { - return fread(&dst, 1, sizeof(dst), file) == sizeof(dst); + const size_t size = sizeof(dst); + if (size > nbytes_remain) { + return false; + } + return read_raw(&dst, size) == size; } template <typename T> bool read(std::vector<T> & dst, const size_t n) const { + if (n > GGUF_MAX_ARRAY_ELEMENTS) { + return false; + } + if constexpr (std::is_same<T, std::string>::value) { + // strings are prefixed with their length, so we need to account for that + if (n > SIZE_MAX / sizeof(uint64_t)) { + return false; + } + if (nbytes_remain < n * sizeof(uint64_t)) { + return false; + } + } else { + if (n > SIZE_MAX / sizeof(T)) { + return false; + } + if (nbytes_remain < n * sizeof(T)) { + return false; + } + } dst.resize(n); for (size_t i = 0; i < dst.size(); ++i) { if constexpr (std::is_same<T, bool>::value) { @@ -273,17 +338,84 @@ struct gguf_reader { } bool read(std::string & dst) const { - uint64_t size = -1; + uint64_t size = 0; if (!read(size)) { return false; } - dst.resize(size); - return fread(dst.data(), 1, dst.length(), file) == dst.length(); + if (size > GGUF_MAX_STRING_LENGTH) { + GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds maximum %" PRIu64 "\n", __func__, size, (uint64_t) GGUF_MAX_STRING_LENGTH); + return false; + } + if (size > nbytes_remain) { + GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds remaining file size %" PRIu64 " bytes\n", __func__, size, nbytes_remain); + return false; + } + dst.resize(static_cast<size_t>(size)); + return read_raw(dst.data(), static_cast<size_t>(size)) == size; } bool read(void * dst, const size_t size) const { - return fread(dst, 1, size, file) == size; + if (size > nbytes_remain) { + return false; + } + return read_raw(dst, size) == size; + } + + uint64_t tell() const { + return data_offset; + } + + bool seek(uint64_t absolute_offset) const { + const uint64_t end_offset = uint64_t(data_offset) + nbytes_remain; + if (absolute_offset > end_offset) { + return false; + } + + data_offset = absolute_offset; + nbytes_remain = end_offset - absolute_offset; + + return true; + } + +private: + size_t read_raw(void * dst, size_t size) const { + if (callback == nullptr || size == 0) { + return 0; + } + + uint8_t * data = static_cast<uint8_t *>(dst); + size_t total_nread = 0; + bool reached_eof = false; + + while (total_nread < size) { + const size_t chunk_size = std::min(max_chunk_read, size - total_nread); + if (data_offset + total_nread < data_offset) { + break; + } + const size_t nread = callback(userdata, static_cast<void *>(data + total_nread), data_offset + total_nread, chunk_size); + total_nread += nread; + if (nread != chunk_size) { + reached_eof = true; + break; + } + } + + data_offset += total_nread; + GGML_ASSERT(total_nread <= nbytes_remain); + nbytes_remain -= total_nread; + + if (reached_eof) { + nbytes_remain = 0; + } + + return total_nread; } + + gguf_reader_callback_t callback = nullptr; + void * userdata = nullptr; + size_t max_chunk_read = 0; + mutable uint64_t data_offset = 0; + mutable uint64_t nbytes_remain = 0; }; struct gguf_context * gguf_init_empty(void) { @@ -316,8 +448,7 @@ bool gguf_read_emplace_helper(const struct gguf_reader & gr, std::vector<struct return true; } -struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params) { - const struct gguf_reader gr(file); +static struct gguf_context * gguf_init_from_reader(const struct gguf_reader & gr, struct gguf_init_params params) { struct gguf_context * ctx = new gguf_context; bool ok = true; @@ -523,7 +654,7 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // tensor shape { - uint32_t n_dims = -1; + uint32_t n_dims = 0; ok = ok && gr.read(n_dims); if (n_dims > GGML_MAX_DIMS) { GGML_LOG_ERROR("%s: tensor '%s' has invalid number of dimensions: %" PRIu32 " > %" PRIu32 "\n", @@ -568,8 +699,8 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that tensor type is within defined range if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) { - GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n", - __func__, info.t.name, info.t.type, ggml_type_name(info.t.type)); + GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d. should be in [0, %d)\n", + __func__, info.t.name, info.t.type, GGML_TYPE_COUNT); ok = false; break; } @@ -585,6 +716,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par break; } + // check that the size of the tensor in bytes is representable + if (ok && uint64_t(ggml_nelements(&info.t)/ggml_blck_size(info.t.type)) > SIZE_MAX/ggml_type_size(info.t.type)) { + GGML_LOG_ERROR("%s: tensor '%s' with shape (%" PRIi64 ", %" PRIi64 ", %" PRIi64 ", %" PRIi64 ") has a size in bytes > %zu\n", + __func__, info.t.name, info.t.ne[0], info.t.ne[1], info.t.ne[2], info.t.ne[3], SIZE_MAX); + ok = false; + break; + } + // calculate byte offsets given the tensor shape and type info.t.nb[0] = type_size; info.t.nb[1] = info.t.nb[0]*(info.t.ne[0]/blck_size); @@ -610,14 +749,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors); // we require the data section to be aligned, so take into account any padding - if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) { + if (n_tensors > 0 && !gr.seek(GGML_PAD(gr.tell(), ctx->alignment))) { GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__); gguf_free(ctx); return nullptr; } // store the current file offset - this is where the data section starts - ctx->offset = ftell(file); + ctx->offset = gr.tell(); // compute the total size of the data section, taking into account the alignment { @@ -649,10 +788,34 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // the ggml_tensor structs to the appropriate locations in the binary blob // compute the exact size needed for the new ggml_context - const size_t mem_size = - params.no_alloc ? - (n_tensors )*ggml_tensor_overhead() : - (n_tensors + 1)*ggml_tensor_overhead() + ctx->size; + size_t mem_size = 0; + if (params.no_alloc) { + if (n_tensors != 0 && SIZE_MAX / n_tensors < ggml_tensor_overhead()) { + GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__); + gguf_free(ctx); + return nullptr; + } + + const size_t overhead = n_tensors * ggml_tensor_overhead(); + + mem_size = overhead; + } else { + if ((n_tensors + 1) != 0 && SIZE_MAX / (n_tensors + 1) < ggml_tensor_overhead()) { + GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__); + gguf_free(ctx); + return nullptr; + } + + const size_t overhead = (n_tensors + 1) * ggml_tensor_overhead(); + + if (SIZE_MAX - overhead < ctx->size) { + GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__); + gguf_free(ctx); + return nullptr; + } + + mem_size = overhead + ctx->size; + } struct ggml_init_params pdata = { /*mem_size =*/ mem_size, @@ -730,15 +893,98 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par return ctx; } +struct gguf_context * gguf_init_from_callback(gguf_reader_callback_t callback, void * userdata, size_t max_chunk_read, uint64_t max_expected_size, struct gguf_init_params params) { + if (callback == nullptr) { + return nullptr; + } + + const struct gguf_reader gr(callback, userdata, max_chunk_read == 0 ? SIZE_MAX : max_chunk_read, 0, max_expected_size); + return gguf_init_from_reader(gr, params); +} + +struct gguf_file_reader { + FILE * file; + uint64_t offset; +}; + +static size_t gguf_file_reader_callback(void * userdata, void * output, uint64_t offset, size_t len) { + GGML_ASSERT(len > 0); + + gguf_file_reader & reader = *static_cast<gguf_file_reader *>(userdata); + + if (reader.offset != offset) { + if (offset > INT64_MAX || gguf_fseek(reader.file, static_cast<int64_t>(offset), SEEK_SET) != 0) { + return 0; + } + + reader.offset = offset; + } + + const size_t nread = fread(static_cast<uint8_t *>(output), 1, len, reader.file); + reader.offset += nread; + return nread; +} + +struct gguf_context * gguf_init_from_file_ptr(FILE * file, struct gguf_init_params params) { + if (!file) { + return nullptr; + } + + const int64_t cur = gguf_ftell(file); + if (cur < 0) { + return nullptr; + } + + gguf_file_reader reader = { + /*.file = */ file, + /*.offset = */ static_cast<uint64_t>(cur), + }; + const struct gguf_reader gr(gguf_file_reader_callback, &reader, SIZE_MAX, reader.offset, gguf_reader::file_remain(file)); + return gguf_init_from_reader(gr, params); +} + +struct gguf_buffer_reader { + const uint8_t * data; + size_t size; +}; + +static size_t gguf_buffer_reader_callback(void * userdata, void * output, uint64_t offset, size_t len) { + GGML_ASSERT(len > 0); + + const gguf_buffer_reader & reader = *static_cast<gguf_buffer_reader *>(userdata); + + if (offset > reader.size || len > reader.size - offset) { + return 0; + } + + const size_t data_offset = static_cast<size_t>(offset); + const size_t nread = std::min(len, reader.size - data_offset); + memcpy(static_cast<uint8_t *>(output), reader.data + data_offset, nread); + return nread; +} + +struct gguf_context * gguf_init_from_buffer(const void * data, size_t size, struct gguf_init_params params) { + if (data == nullptr || size == 0) { + return nullptr; + } + + gguf_buffer_reader reader = { + /*.data = */ static_cast<const uint8_t *>(data), + /*.size = */ size, + }; + const struct gguf_reader gr(gguf_buffer_reader_callback, &reader, SIZE_MAX, 0, size); + return gguf_init_from_reader(gr, params); +} + struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { FILE * file = ggml_fopen(fname, "rb"); if (!file) { - GGML_LOG_ERROR("%s: failed to open GGUF file '%s'\n", __func__, fname); + GGML_LOG_ERROR("%s: failed to open GGUF file '%s' (%s)\n", __func__, fname, strerror(errno)); return nullptr; } - struct gguf_context * result = gguf_init_from_file_impl(file, params); + struct gguf_context * result = gguf_init_from_file_ptr(file, params); fclose(file); return result; } @@ -1166,50 +1412,51 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const } -struct gguf_writer { - std::vector<int8_t> & buf; +struct gguf_writer_base { + size_t written_bytes {0u}; + + ~gguf_writer_base(void) = default; - gguf_writer(std::vector<int8_t> & buf) : buf(buf) {} + // we bet on devirtualization + virtual void write(int8_t val) = 0; + virtual void write(const std::vector<int8_t> & val) = 0; + virtual void write_tensor_data(const struct gguf_tensor_info & info, size_t offset_data, size_t alignment) = 0; template <typename T> - void write(const T & val) const { + void write(const T & val) { for (size_t i = 0; i < sizeof(val); ++i) { - buf.push_back(reinterpret_cast<const int8_t *>(&val)[i]); + write(reinterpret_cast<const int8_t *>(&val)[i]); } } - void write(const std::vector<int8_t> & val) const { - buf.insert(buf.end(), val.begin(), val.end()); - } - - void write(const bool & val) const { + void write(const bool & val) { const int8_t val8 = val ? 1 : 0; write(val8); } - void write(const std::string & val) const { + void write(const std::string & val) { { const uint64_t n = val.length(); write(n); } for (size_t i = 0; i < val.length(); ++i) { - buf.push_back(reinterpret_cast<const int8_t *>(val.data())[i]); + write((val.data())[i]); } } - void write(const char * val) const { + void write(const char * val) { write(std::string(val)); } - void write(const enum ggml_type & val) const { + void write(const enum ggml_type & val) { write(int32_t(val)); } - void write(const enum gguf_type & val) const { + void write(const enum gguf_type & val) { write(int32_t(val)); } - void write(const struct gguf_kv & kv) const { + void write(const struct gguf_kv & kv) { const uint64_t ne = kv.get_ne(); write(kv.get_key()); @@ -1250,7 +1497,7 @@ struct gguf_writer { } } - void write_tensor_meta(const struct gguf_tensor_info & info) const { + void write_tensor_meta(const struct gguf_tensor_info & info) { write(info.t.name); const uint32_t n_dims = ggml_n_dims(&info.t); @@ -1263,14 +1510,33 @@ struct gguf_writer { write(info.offset); } - void pad(const size_t alignment) const { - while (buf.size() % alignment != 0) { + void pad(const size_t alignment) { + while (written_bytes % alignment != 0) { const int8_t zero = 0; write(zero); } } +}; + +// vector buffer based writer +struct gguf_writer_buf final : public gguf_writer_base { + std::vector<int8_t> & buf; + + gguf_writer_buf(std::vector<int8_t> & buf) : buf(buf) {} + + using gguf_writer_base::write; - void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const { + void write(const int8_t val) override { + buf.push_back(val); + written_bytes++; + } + + void write(const std::vector<int8_t> & val) override { + buf.insert(buf.end(), val.begin(), val.end()); + written_bytes += val.size(); + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override { GGML_ASSERT(buf.size() - offset_data == info.offset); GGML_ASSERT(ggml_is_contiguous(&info.t)); @@ -1284,14 +1550,58 @@ struct gguf_writer { GGML_ASSERT(info.t.data); memcpy(buf.data() + offset, info.t.data, nbytes); } + written_bytes += nbytes; pad(alignment); } }; -void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) { - const struct gguf_writer gw(buf); +// file based writer +struct gguf_writer_file final : public gguf_writer_base { + FILE * file; + + gguf_writer_file(FILE* file) : file(file) {} + + using gguf_writer_base::write; + + void write(const int8_t val) override { + const auto real_val = static_cast<uint8_t>(val); + const auto ret = fputc(real_val, file); + written_bytes++; + if (ret != real_val) { + throw std::runtime_error("unexpected fputc result '" + std::to_string(ret) + "' instead of '" + std::to_string((int)real_val) + "'"); + } + } + + void write(const std::vector<int8_t> & val) override { + const auto ret = fwrite(val.data(), 1, val.size(), file); + written_bytes += val.size(); + if (ret != val.size()) { + throw std::runtime_error("unexpected fwrite number of bytes written, '" + std::to_string(ret) + "' instead of '" + std::to_string(val.size()) + "'"); + } + } + + void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override { + GGML_ASSERT(written_bytes - offset_data == info.offset); + + GGML_ASSERT(ggml_is_contiguous(&info.t)); + const size_t nbytes = ggml_nbytes(&info.t); + + std::vector<int8_t> buf(nbytes); + if (info.t.buffer) { + ggml_backend_tensor_get(&info.t, buf.data(), 0, nbytes); + } else { + GGML_ASSERT(info.t.data); + memcpy(buf.data(), info.t.data, nbytes); + } + write(buf); + + pad(alignment); + } +}; +template <typename writer_t> +static void gguf_write_out(const struct gguf_context * ctx, writer_t & gw, bool only_meta) { const int64_t n_kv = gguf_get_n_kv(ctx); const int64_t n_tensors = gguf_get_n_tensors(ctx); @@ -1321,7 +1631,7 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & bu return; } - const size_t offset_data = gw.buf.size(); + const size_t offset_data = gw.written_bytes; // write tensor data for (int64_t i = 0; i < n_tensors; ++i) { @@ -1329,6 +1639,24 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & bu } } +void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) { + gguf_writer_buf gw(buf); + gguf_write_out(ctx, gw, only_meta); +} + +bool gguf_write_to_file_ptr(const struct gguf_context * ctx, FILE * file, bool only_meta) { + GGML_ASSERT(file); + + try { + gguf_writer_file gw(file); + gguf_write_out(ctx, gw, only_meta); + } catch (const std::runtime_error& ex) { + GGML_LOG_ERROR("%s: failed to write GGUF data: %s\n", __func__, ex.what()); + return false; + } + return true; +} + bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { FILE * file = ggml_fopen(fname, "wb"); @@ -1337,11 +1665,13 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo return false; } - std::vector<int8_t> buf; - gguf_write_to_buf(ctx, buf, only_meta); - const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size(); + const bool success = gguf_write_to_file_ptr(ctx, file, only_meta); + if (!success) { + GGML_LOG_ERROR("%s: failed to write GGUF data into '%s'\n", __func__, fname); + } + fclose(file); - return ok; + return success; } size_t gguf_get_meta_size(const struct gguf_context * ctx) { diff --git a/include/parakeet.h b/include/parakeet.h new file mode 100644 index 00000000000..d35aa870adb --- /dev/null +++ b/include/parakeet.h @@ -0,0 +1,342 @@ +#ifndef PARAKEET_H +#define PARAKEET_H + +#include "ggml.h" +#include "ggml-cpu.h" + +#include <stddef.h> +#include <stdint.h> +#include <stdbool.h> + +#ifdef __GNUC__ +# define PARAKEET_DEPRECATED(func, hint) func __attribute__((deprecated(hint))) +#elif defined(_MSC_VER) +# define PARAKEET_DEPRECATED(func, hint) __declspec(deprecated(hint)) func +#else +# define PARAKEET_DEPRECATED(func, hint) func +#endif + +#ifdef PARAKEET_SHARED +# ifdef _WIN32 +# ifdef PARAKEET_BUILD +# define PARAKEET_API __declspec(dllexport) +# else +# define PARAKEET_API __declspec(dllimport) +# endif +# else +# define PARAKEET_API __attribute__ ((visibility ("default"))) +# endif +#else +# define PARAKEET_API +#endif + +#define PARAKEET_SAMPLE_RATE 16000 +#define PARAKEET_HOP_LENGTH 160 + +#ifdef __cplusplus +extern "C" { +#endif + + struct parakeet_context; + struct parakeet_state; + struct parakeet_full_params; + + typedef int32_t parakeet_pos; + typedef int32_t parakeet_token; + typedef int32_t parakeet_seq_id; + + struct parakeet_context_params { + bool use_gpu; + int gpu_device; // CUDA device + }; + + typedef struct parakeet_token_data { + parakeet_token id; // the BPE subword ID (0-8191) + + int duration_idx; // index into the models durations array + int duration_value; // actual duration value + int frame_index; + + float p; + float plog; + + int64_t t0; + int64_t t1; + + bool is_word_start; + } parakeet_token_data; + + typedef struct parakeet_model_loader { + void * context; + + size_t (*read)(void * ctx, void * output, size_t read_size); + bool (*eof)(void * ctx); + void (*close)(void * ctx); + } parakeet_model_loader; + + PARAKEET_API const char * parakeet_version(void); + + // Various functions for loading a ggml parakeet model. + // Allocate (almost) all memory needed for the model. + // Return NULL on failure + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + // These are the same as the above, but the internal state of the context is not allocated automatically + // It is the responsibility of the caller to allocate the state using parakeet_init_state() (#523) + PARAKEET_API struct parakeet_context * parakeet_init_from_file_with_params_no_state (const char * path_model, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params); + PARAKEET_API struct parakeet_context * parakeet_init_with_params_no_state (struct parakeet_model_loader * loader, struct parakeet_context_params params); + + PARAKEET_API struct parakeet_state * parakeet_init_state(struct parakeet_context * ctx); + + // Frees all allocated memory + PARAKEET_API void parakeet_free (struct parakeet_context * ctx); + PARAKEET_API void parakeet_free_state(struct parakeet_state * state); + PARAKEET_API void parakeet_free_params(struct parakeet_full_params * params); + PARAKEET_API void parakeet_free_context_params(struct parakeet_context_params * params); + + // Convert RAW PCM audio to log mel spectrogram. + // The resulting spectrogram is stored inside the default state of the provided parakeet context. + // Returns 0 on success + PARAKEET_API int parakeet_pcm_to_mel( + struct parakeet_context * ctx, + const float * samples, + int n_samples, + int n_threads); + + PARAKEET_API int parakeet_pcm_to_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * samples, + int n_samples, + int n_threads); + + // This can be used to set a custom log mel spectrogram inside the default state of the provided parakeet context. + // Use this instead of parakeet_pcm_to_mel() if you want to provide your own log mel spectrogram. + // n_mel must be 128 + // Returns 0 on success + PARAKEET_API int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel); + + PARAKEET_API int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel); + + // Run the Parakeet encoder on the log mel spectrogram stored inside the default state in the provided parakeet context. + // Make sure to call parakeet_pcm_to_mel() or parakeet_set_mel() first. + // offset can be used to specify the offset of the first frame in the spectrogram. + // Returns 0 on success + PARAKEET_API int parakeet_encode( + struct parakeet_context * ctx, + int offset, + int n_threads); + + PARAKEET_API int parakeet_encode_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + int offset, + int n_threads); + + // Convert the provided text into tokens. + // The tokens pointer must be large enough to hold the resulting tokens. + // Returns the number of tokens on success, no more than n_max_tokens + // Returns a negative number on failure - the number of tokens that would have been returned + // TODO: not sure if correct + PARAKEET_API int parakeet_tokenize( + struct parakeet_context * ctx, + const char * text, + parakeet_token * tokens, + int n_max_tokens); + + // Return the number of tokens in the provided text + // Equivalent to: -parakeet_tokenize(ctx, text, NULL, 0) + int parakeet_token_count(struct parakeet_context * ctx, const char * text); + + PARAKEET_API int parakeet_n_len (struct parakeet_context * ctx); // mel length + PARAKEET_API int parakeet_n_len_from_state(struct parakeet_state * state); // mel length + PARAKEET_API int parakeet_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_n_audio_ctx (struct parakeet_context * ctx); + + PARAKEET_API int parakeet_model_n_vocab (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_ctx (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_state(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_head (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_audio_layer(struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_n_mels (struct parakeet_context * ctx); + PARAKEET_API int parakeet_model_ftype (struct parakeet_context * ctx); + + // Token logits obtained from the last call to parakeet_full/parakeet_chunk + // The logits for the last token are stored in the last row + // Rows: n_tokens + // Cols: n_vocab + PARAKEET_API float * parakeet_get_logits (struct parakeet_context * ctx); + PARAKEET_API float * parakeet_get_logits_from_state(struct parakeet_state * state); + + // Token Id -> String. Uses the vocabulary in the provided context + PARAKEET_API const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token); + + PARAKEET_API int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len); + + // Special tokens + PARAKEET_API parakeet_token parakeet_token_blank(struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_unk (struct parakeet_context * ctx); + PARAKEET_API parakeet_token parakeet_token_bos (struct parakeet_context * ctx); + + // Performance information from the default state. + struct parakeet_timings { + float sample_ms; + float encode_ms; + float decode_ms; + }; + PARAKEET_API struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_print_timings(struct parakeet_context * ctx); + PARAKEET_API void parakeet_reset_timings(struct parakeet_context * ctx); + + // Print system information + PARAKEET_API const char * parakeet_print_system_info(void); + + // Available sampling strategies + enum parakeet_sampling_strategy { + PARAKEET_SAMPLING_GREEDY, + }; + + // Token callback. + // Called for each new predicted token. + // Use the parakeet_full_...() functions to obtain the text segments + typedef void (*parakeet_new_token_callback)( + struct parakeet_context * ctx, + struct parakeet_state * state, + const parakeet_token_data * token_data, + void * user_data); + + // Text segment callback + // Called on every newly generated text segment + // Use the parakeet_full_...() functions to obtain the text segments + typedef void (*parakeet_new_segment_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int n_new, void * user_data); + + // Progress callback + typedef void (*parakeet_progress_callback)(struct parakeet_context * ctx, struct parakeet_state * state, int progress, void * user_data); + + // Encoder begin callback + // If not NULL, called before the encoder starts + // If it returns false, the computation is aborted + typedef bool (*parakeet_encoder_begin_callback)(struct parakeet_context * ctx, struct parakeet_state * state, void * user_data); + + // Parameters for the parakeet_full() function + // If you change the order or add new parameters, make sure to update the default values in parakeet.cpp: + // parakeet_full_default_params() + struct parakeet_full_params { + enum parakeet_sampling_strategy strategy; + + int n_threads; + int offset_ms; // start offset in ms + int duration_ms; // audio duration to process in ms + + bool no_context; // do not use past transcription (if any) as context + + int audio_ctx; // overwrite the audio context size (0 = use default) + + // called for every newly generated text segment + parakeet_new_segment_callback new_segment_callback; + void * new_segment_callback_user_data; + + // called for every newly generated token + parakeet_new_token_callback new_token_callback; + void * new_token_callback_user_data; + + // called on each progress update + parakeet_progress_callback progress_callback; + void * progress_callback_user_data; + + // called each time before the encoder starts + parakeet_encoder_begin_callback encoder_begin_callback; + void * encoder_begin_callback_user_data; + + // called each time before ggml computation starts + ggml_abort_callback abort_callback; + void * abort_callback_user_data; + }; + + // NOTE: this function allocates memory, and it is the responsibility of the caller to free the pointer - see parakeet_free_context_params() & parakeet_free_params() + PARAKEET_API struct parakeet_context_params * parakeet_context_default_params_by_ref(void); + PARAKEET_API struct parakeet_context_params parakeet_context_default_params (void); + + PARAKEET_API struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy); + PARAKEET_API struct parakeet_full_params parakeet_full_default_params (enum parakeet_sampling_strategy strategy); + + // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + // Not thread safe for same context + PARAKEET_API int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + PARAKEET_API int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + // Process a single chunk of audio data that fits within the model's audio context window. + // This is more efficient than parakeet_full() for short audio clips. + PARAKEET_API int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples); + + // Number of generated text segments + PARAKEET_API int parakeet_full_n_segments (struct parakeet_context * ctx); + PARAKEET_API int parakeet_full_n_segments_from_state(struct parakeet_state * state); + + // Get the start and end time of the specified segment + PARAKEET_API int64_t parakeet_full_get_segment_t0 (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment); + + PARAKEET_API int64_t parakeet_full_get_segment_t1 (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment); + + // Get the text of the specified segment + PARAKEET_API const char * parakeet_full_get_segment_text (struct parakeet_context * ctx, int i_segment); + PARAKEET_API const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment); + + // Get number of tokens in the specified segment + PARAKEET_API int parakeet_full_n_tokens (struct parakeet_context * ctx, int i_segment); + PARAKEET_API int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment); + + // Get the token text of the specified token in the specified segment + PARAKEET_API const char * parakeet_full_get_token_text (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token); + + // Get the token id of the specified token in the specified segment + PARAKEET_API parakeet_token parakeet_full_get_token_id (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Get token data for the specified token in the specified segment + PARAKEET_API parakeet_token_data parakeet_full_get_token_data (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Get the probability of the specified token in the specified segment + PARAKEET_API float parakeet_full_get_token_p (struct parakeet_context * ctx, int i_segment, int i_token); + PARAKEET_API float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token); + + // Control logging output; default behavior is to print to stderr + + PARAKEET_API void parakeet_log_set(ggml_log_callback log_callback, void * user_data); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/include/whisper.h b/include/whisper.h index f4cc6bf7abd..b5dcdb2917a 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -695,6 +695,16 @@ extern "C" { const float * samples, int n_samples); + // Like whisper_vad_detect_speech, but does not reset LSTM state. + // Use for streaming: call whisper_vad_reset_state() between utterances. + WHISPER_API bool whisper_vad_detect_speech_no_reset( + struct whisper_vad_context * vctx, + const float * samples, + int n_samples); + + // Reset LSTM hidden/cell states to zero. + WHISPER_API void whisper_vad_reset_state(struct whisper_vad_context * vctx); + WHISPER_API int whisper_vad_n_probs(struct whisper_vad_context * vctx); WHISPER_API float * whisper_vad_probs (struct whisper_vad_context * vctx); diff --git a/media/matmul.png b/media/matmul.png new file mode 100644 index 00000000000..786a20492c0 Binary files /dev/null and b/media/matmul.png differ diff --git a/models/convert-parakeet-to-ggml.py b/models/convert-parakeet-to-ggml.py new file mode 100755 index 00000000000..2d6a6d01554 --- /dev/null +++ b/models/convert-parakeet-to-ggml.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# Convert Parakeet TDT model from NeMo format to ggml format +# +# Usage: python convert-parakeet-to-ggml.py --model parakeet-model.nemo --output-dir output-dir [--use-f32] +# +# The NeMo file is a tar archive containing: +# - model_weights.ckpt (PyTorch checkpoint) +# - model_config.yaml (model configuration) +# - tokenizer files +# +# This script extracts the NeMo archive, loads the model weights and configuration, +# and saves them in ggml format compatible with whisper.cpp. +# + +import torch +import argparse +import io +import os +import sys +import struct +import tarfile +import tempfile +import shutil +import yaml +import numpy as np +from pathlib import Path +from typing import Optional + +def hz_to_mel(freq): + return 2595.0 * np.log10(1.0 + freq / 700.0) + +def mel_to_hz(mel): + return 700.0 * (10.0**(mel / 2595.0) - 1.0) + +def extract_nemo_archive(nemo_path, extract_dir): + print(f"Extracting {nemo_path} to {extract_dir}") + with tarfile.open(nemo_path, 'r') as tar: + tar.extractall(path=extract_dir) + print("Extraction complete") + +def load_model_config(config_path): + with open(config_path, 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + return config + +def load_tokenizer(extract_dir, config): + tokenizer_model_path = None + tokenizer_vocab_path = None + + for file in os.listdir(extract_dir): + if file.endswith('_tokenizer.model'): + tokenizer_model_path = os.path.join(extract_dir, file) + elif file.endswith('tokenizer.vocab'): + tokenizer_vocab_path = os.path.join(extract_dir, file) + + if not tokenizer_model_path: + raise FileNotFoundError("Tokenizer model file not found") + + if not tokenizer_vocab_path: + raise FileNotFoundError("Tokenizer vocab file not found") + + tokens = {} + with open(tokenizer_vocab_path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + parts = line.strip().split('\t') + if len(parts) >= 1: + token = parts[0] + tokens[token.encode('utf-8')] = idx + + print(f"Loaded {len(tokens)} tokens from {os.path.basename(tokenizer_vocab_path)}") + + if len(tokens) != 8192: + print(f"WARNING: Expected 8192 tokens, got {len(tokens)}") + + return tokens + +def write_tensor(fout, name, data, use_f16=True, force_f32=False): + if 'pre_encode.conv' in name and 'bias' in name and len(data.shape) == 1: + data = data.reshape(1, -1, 1, 1) + print(f" Reshaped conv bias {name} to {data.shape}") + + n_dims = len(data.shape) + + ftype = 1 if use_f16 and not force_f32 else 0 + if force_f32: + data = data.astype(np.float32) + elif use_f16: + if n_dims < 2 or 'bias' in name or 'norm' in name or \ + ('pre_encode.conv' in name and n_dims == 4) or \ + 'depthwise_conv.weight' in name: + data = data.astype(np.float32) + ftype = 0 + else: + data = data.astype(np.float16) + else: + data = data.astype(np.float32) + + dims_reversed = [data.shape[n_dims - 1 - i] for i in range(n_dims)] + print(f"Processing: {name} {list(data.shape)}, dtype: {data.dtype}, n_dims: {n_dims}, reversed: {dims_reversed}") + name_bytes = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(name_bytes) + + data.tofile(fout) + +def convert_parakeet_to_ggml(nemo_path, output_dir, use_f16=True, out_name=None): + nemo_path = Path(nemo_path) + output_dir = Path(output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + + # Create temporary directory for extraction + with tempfile.TemporaryDirectory() as temp_dir: + extract_nemo_archive(nemo_path, temp_dir) + + config_path = os.path.join(temp_dir, 'model_config.yaml') + config = load_model_config(config_path) + + print("Model configuration:") + print(f" Sample rate: {config['sample_rate']}") + print(f" Encoder layers: {config['encoder']['n_layers']}") + print(f" Encoder d_model: {config['encoder']['d_model']}") + print(f" Mel features: {config['preprocessor']['features']}") + + weights_path = os.path.join(temp_dir, 'model_weights.ckpt') + print(f"\nLoading model weights from {weights_path}") + checkpoint = torch.load(weights_path, map_location='cpu') + + # Extract state dict + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + else: + state_dict = checkpoint + + print(f"Loaded {len(state_dict)} tensors") + + # Load tokenizer + print("\nLoading tokenizer...") + tokens = load_tokenizer(temp_dir, config) + print(f"Loaded {len(tokens)} tokens") + + # Prepare hyperparameters for the Parakeet ggml format. + hparams = { + 'n_audio_ctx': 5000, + 'n_audio_state': config['encoder']['d_model'], + 'n_audio_head': config['encoder']['n_heads'], + 'n_audio_layer': config['encoder']['n_layers'], + 'n_mels': config['preprocessor']['features'], + 'n_fft': config['preprocessor']['n_fft'], + 'subsampling_factor': config['encoder']['subsampling_factor'], + 'n_subsampling_channels': config['encoder']['subsampling_conv_channels'], + 'n_conv_kernel': config['encoder']['conv_kernel_size'], + + 'n_pred_dim': config['decoder']['prednet']['pred_hidden'], + 'n_pred_layers': config['decoder']['prednet']['pred_rnn_layers'], + 'n_vocab': config['decoder']['vocab_size'], + 'n_tdt_durations': config['model_defaults']['num_tdt_durations'], + 'n_max_tokens': config['decoding']['greedy']['max_symbols'], + } + + print("\nGGML hyperparameters:") + for key, value in hparams.items(): + print(f" {key}: {value}") + + # Create output file + if out_name: + fname_out = output_dir / out_name + else: + fname_out = output_dir / ("ggml-model-f32.bin" if not use_f16 else "ggml-model.bin") + print(f"\nWriting to {fname_out}") + + with open(fname_out, 'wb') as fout: + # Write magic number + fout.write(struct.pack("i", 0x67676d6c)) # 'ggml' in hex + + # Write hyperparameters + fout.write(struct.pack("i", hparams['n_vocab'])) + fout.write(struct.pack("i", hparams['n_audio_ctx'])) + fout.write(struct.pack("i", hparams['n_audio_state'])) + fout.write(struct.pack("i", hparams['n_audio_head'])) + fout.write(struct.pack("i", hparams['n_audio_layer'])) + fout.write(struct.pack("i", hparams['n_mels'])) + fout.write(struct.pack("i", 1 if use_f16 else 0)) + fout.write(struct.pack("i", hparams['n_fft'])) + fout.write(struct.pack("i", hparams['subsampling_factor'])) + fout.write(struct.pack("i", hparams['n_subsampling_channels'])) + fout.write(struct.pack("i", hparams['n_conv_kernel'])) + fout.write(struct.pack("i", hparams['n_pred_dim'])) + fout.write(struct.pack("i", hparams['n_pred_layers'])) + fout.write(struct.pack("i", hparams['n_tdt_durations'])) + fout.write(struct.pack("i", hparams['n_max_tokens'])) + + # Extract mel filterbank from model + fb_key = None + for key in state_dict.keys(): + if 'featurizer.fb' in key or 'filterbank' in key.lower(): + fb_key = key + break + + if not fb_key: + print("\nERROR: Mel filterbank not found in model!") + print("Expected tensor with 'featurizer.fb' or 'filterbank' in name") + print("\nAvailable preprocessor tensors:") + for key in sorted(state_dict.keys()): + if 'preprocessor' in key or 'featurizer' in key: + print(f" {key}: {state_dict[key].shape}") + raise ValueError("Mel filterbank tensor not found in model") + + print(f"\nUsing model's mel filterbank from: {fb_key}") + mel_filters = state_dict[fb_key].squeeze().numpy().astype(np.float32) + print(f" Filterbank shape: {mel_filters.shape}") + print(f" Filterbank min/max values: {mel_filters.min():.6f} / {mel_filters.max():.6f}") + print(f" Filterbank non-zero elements: {np.count_nonzero(mel_filters)} / {mel_filters.size}") + print(f" First row sum: {mel_filters[0].sum():.6f}") + + if len(mel_filters.shape) != 2: + raise ValueError(f"Expected 2D filterbank, got shape {mel_filters.shape}") + + n_mels, n_freqs = mel_filters.shape + fout.write(struct.pack("i", n_mels)) # n_mel + fout.write(struct.pack("i", n_freqs)) # n_fb (frequency bins) + + # Write mel filterbank + for i in range(n_mels): + for j in range(n_freqs): + fout.write(struct.pack("f", mel_filters[i, j])) + + # Extract window function from model + window_key = None + for key in state_dict.keys(): + if 'featurizer.window' in key or 'preproc' in key and 'window' in key: + window_key = key + break + + if not window_key: + print("\nERROR: Window function not found in model!") + print("Expected tensor with 'featurizer.window' in name") + raise ValueError("Window function tensor not found in model") + + print(f"\nUsing model's window function from: {window_key}") + window = state_dict[window_key].squeeze().numpy().astype(np.float32) + print(f" Window shape: {window.shape}") + print(f" Window min/max values: {window.min():.6f} / {window.max():.6f}") + print(f" Window non-zero elements: {np.count_nonzero(window)} / {window.size}") + print(f" Window sum: {window.sum():.6f}") + + if len(window.shape) != 1: + raise ValueError(f"Expected 1D window, got shape {window.shape}") + + n_window = window.shape[0] + fout.write(struct.pack("i", n_window)) + + # Write window function + for i in range(n_window): + fout.write(struct.pack("f", window[i])) + + # Write TDT durations + tdt_durations = config['model_defaults']['tdt_durations'] + if len(tdt_durations) != hparams['n_tdt_durations']: + raise ValueError(f"TDT durations count mismatch: {len(tdt_durations)} vs {hparams['n_tdt_durations']}") + + for duration in tdt_durations: + fout.write(struct.pack("I", duration)) + + fout.write(struct.pack("i", len(tokens))) + for token_bytes, idx in sorted(tokens.items(), key=lambda x: x[1]): + fout.write(struct.pack("i", len(token_bytes))) + fout.write(token_bytes) + + # Pre-collect prediction LSTM input-hidden biases so they can be + # folded into the hidden-hidden bias during the main write loop. + lstm_prefix = 'decoder.prediction.dec_rnn.lstm' + pred_bias_ih = {} + for key, t in state_dict.items(): + if f'{lstm_prefix}.bias_ih_l' in key: + layer_idx = int(key.rsplit('bias_ih_l', 1)[1]) + pred_bias_ih[layer_idx] = t.squeeze().numpy().astype(np.float32) + + print("\nConverting model weights...") + for name, tensor in state_dict.items(): + # Skip the filterbank and window - already written in preprocessing section + if name == fb_key: + continue + if name == window_key: + continue + + # bias_ih is folded into bias_hh below; skip writing it separately + if f'{lstm_prefix}.bias_ih_l' in name: + continue + + # Don't squeeze Conv2d weights - they need to preserve all 4 dimensions + if 'conv' in name and 'weight' in name and len(tensor.shape) == 4: + data = tensor.numpy() + else: + data = tensor.squeeze().numpy() + + # For prediction LSTM weights/biases: + # Fold bias_ih into bias_hh (bias_ih already skipped above). + # Reorder gates (input, forget, cell, output) from PyTorch layout + # [i, f, g, o] to [i, f, o, g] so the three sigmoid-gated outputs + # (i, f, o) are contiguous. + if name.startswith(f'{lstm_prefix}.'): + if f'{lstm_prefix}.bias_hh_l' in name: + layer_idx = int(name.rsplit('bias_hh_l', 1)[1]) + data = data.astype(np.float32) + pred_bias_ih[layer_idx] + name = name.replace('bias_hh_l', 'bias_h_l') + h = data.shape[0] // 4 + data = np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0) + + write_tensor(fout, name, data, use_f16=use_f16) + + print(f"\nConversion complete!") + print(f"Output file: {fname_out}") + print(f"File size: {fname_out.stat().st_size / (1024**2):.2f} MB") + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Convert Parakeet TDT model from NeMo format to ggml format' + ) + parser.add_argument('--model', type=str, required=True, + help='Path to Parakeet .nemo model file') + parser.add_argument('--out-dir', type=str, required=True, + help='Directory to write ggml model file') + parser.add_argument('--use-f32', action='store_true', default=False, + help='Use f32 instead of f16 (default: f16)') + parser.add_argument('--out-name', type=str, default=None, + help='Output file name (default: ggml-model.bin or ggml-model-f32.bin)') + + args = parser.parse_args() + + if not os.path.exists(args.model): + print(f"Error: {args.model} not found") + sys.exit(1) + + use_f16 = not args.use_f32 + convert_parakeet_to_ggml(args.model, args.out_dir, use_f16, args.out_name) diff --git a/models/convert-whisper-to-coreml.py b/models/convert-whisper-to-coreml.py index 66827b6d420..7cf07754a89 100644 --- a/models/convert-whisper-to-coreml.py +++ b/models/convert-whisper-to-coreml.py @@ -8,10 +8,19 @@ from typing import Dict from typing import Optional from ane_transformers.reference.layer_norm import LayerNormANE as LayerNormANEBase -from coremltools.models.neural_network.quantization_utils import quantize_weights from whisper.model import Whisper, AudioEncoder, TextDecoder, ResidualAttentionBlock, MultiHeadAttention, ModelDimensions from whisper import load_model + +def _str_to_bool(v): + if isinstance(v, bool): + return v + if v.lower() in ("true", "1", "yes"): + return True + if v.lower() in ("false", "0", "no"): + return False + raise argparse.ArgumentTypeError(f"boolean value expected, got '{v}'") + # Disable PyTorch Scaled Dot-Product Attention (SDPA) to avoid compatibility issues. # The Whisper implementation expects a specific behavior from # torch.nn.functional.scaled_dot_product_attention that differs between PyTorch @@ -258,11 +267,9 @@ def convert_encoder(hparams, model, quantize=False): inputs=[ct.TensorType(name="logmel_data", shape=input_shape)], outputs=[ct.TensorType(name="output")], compute_units=ct.ComputeUnit.ALL, + compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32, ) - if quantize: - model = quantize_weights(model, nbits=16) - return model def convert_decoder(hparams, model, quantize=False): @@ -283,20 +290,18 @@ def convert_decoder(hparams, model, quantize=False): ct.TensorType(name="token_data", shape=tokens_shape, dtype=int), ct.TensorType(name="audio_data", shape=audio_shape) ], + compute_precision=ct.precision.FLOAT16 if quantize else ct.precision.FLOAT32, ) - if quantize: - model = quantize_weights(model, nbits=16) - return model if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--model", type=str, help="model to convert (e.g. tiny, tiny.en, base, base.en, small, small.en, medium, medium.en, large-v1, large-v2, large-v3, large-v3-turbo)", required=True) - parser.add_argument("--encoder-only", type=bool, help="only convert encoder", default=False) - parser.add_argument("--quantize", type=bool, help="quantize weights to F16", default=False) - parser.add_argument("--optimize-ane", type=bool, help="optimize for ANE execution (currently broken)", default=False) + parser.add_argument("--encoder-only", type=_str_to_bool, help="only convert encoder", default=False) + parser.add_argument("--quantize", type=_str_to_bool, help="quantize weights to F16", default=False) + parser.add_argument("--optimize-ane", type=_str_to_bool, help="optimize for ANE execution", default=False) args = parser.parse_args() if args.model not in ["tiny", "tiny.en", "base", "base.en", "small", "small.en", "small.en-tdrz", "medium", "medium.en", "large-v1", "large-v2", "large-v3", "large-v3-turbo"]: diff --git a/models/convert-whisper-to-openvino.py b/models/convert-whisper-to-openvino.py index 3124dd3d7cf..a17e535550d 100644 --- a/models/convert-whisper-to-openvino.py +++ b/models/convert-whisper-to-openvino.py @@ -2,7 +2,6 @@ import torch from whisper import load_model import os -from openvino.tools import mo from openvino.frontend import FrontEndManager from openvino.runtime import serialize import shutil diff --git a/models/download-ggml-model.sh b/models/download-ggml-model.sh index f1394e98484..0539c8afb3d 100755 --- a/models/download-ggml-model.sh +++ b/models/download-ggml-model.sh @@ -120,7 +120,13 @@ fi if [ -x "$(command -v wget2)" ]; then wget2 --no-config --progress bar -O ggml-"$model".bin $src/$pfx-"$model".bin elif [ -x "$(command -v curl)" ]; then - curl -L --output ggml-"$model".bin $src/$pfx-"$model".bin + curl -L --fail \ + --retry 5 \ + --retry-delay 5 \ + --retry-all-errors \ + --retry-connrefused \ + ${HF_TOKEN:+--header "Authorization: Bearer $HF_TOKEN"} \ + --output ggml-"$model".bin $src/$pfx-"$model".bin elif [ -x "$(command -v wget)" ]; then wget --no-config --quiet --show-progress -O ggml-"$model".bin $src/$pfx-"$model".bin else diff --git a/models/for-tests-ggml-parakeet-tdt.bin b/models/for-tests-ggml-parakeet-tdt.bin new file mode 100644 index 00000000000..8b1dda1feba Binary files /dev/null and b/models/for-tests-ggml-parakeet-tdt.bin differ diff --git a/models/generate-parakeet-test-model.py b/models/generate-parakeet-test-model.py new file mode 100755 index 00000000000..192a96ce627 --- /dev/null +++ b/models/generate-parakeet-test-model.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +import struct +import sys +import numpy as np +from pathlib import Path + +def write_tensor(fout, name, data): + n_dims = len(data.shape) + data = data.astype(np.float32) + ftype = 0 # GGML_TYPE_F32 + + name_bytes = name.encode('utf-8') + fout.write(struct.pack("iii", n_dims, len(name_bytes), ftype)) + for i in range(n_dims): + fout.write(struct.pack("i", data.shape[n_dims - 1 - i])) + fout.write(name_bytes) + data.tofile(fout) + +def generate(output_path): + rng = np.random.default_rng(42) + + hparams = { + 'n_vocab': 10, + 'n_audio_ctx': 3200, + 'n_audio_state': 8, + 'n_audio_head': 2, + 'n_audio_layer': 1, + 'n_mels': 16, + 'ftype': 0, + 'n_fft': 64, + 'subsampling_factor': 8, + 'n_subsampling_channels': 4, + 'n_conv_kernel': 3, + 'n_pred_dim': 8, + 'n_pred_layers': 1, + 'n_tdt_durations': 2, + 'n_max_tokens': 5, + } + + n_vocab = hparams['n_vocab'] + n_state = hparams['n_audio_state'] + n_head = hparams['n_audio_head'] + n_layer = hparams['n_audio_layer'] + n_mels = hparams['n_mels'] + n_fft = hparams['n_fft'] + n_sub_fac = hparams['subsampling_factor'] + n_sub_ch = hparams['n_subsampling_channels'] + n_conv_ker = hparams['n_conv_kernel'] + dec_dim = hparams['n_pred_dim'] + n_pred_l = hparams['n_pred_layers'] + n_tdt = hparams['n_tdt_durations'] + + n_pre_enc = (n_mels // n_sub_fac) * n_sub_ch + n_head_dim = n_state // n_head + n_pred_embed = n_vocab + 1 + n_lstm_gates = 4 * dec_dim + n_joint_out = n_vocab + n_tdt + 1 + n_freqs = n_fft // 2 + 1 + + def f32(*shape): + return rng.standard_normal(shape).astype(np.float32) + + with open(output_path, 'wb') as fout: + fout.write(struct.pack("I", 0x67676d6c)) + + for key in ['n_vocab', + 'n_audio_ctx', + 'n_audio_state', + 'n_audio_head', + 'n_audio_layer', + 'n_mels', + 'ftype', + 'n_fft', + 'subsampling_factor', + 'n_subsampling_channels', + 'n_conv_kernel', + 'n_pred_dim', + 'n_pred_layers', + 'n_tdt_durations', + 'n_max_tokens']: + fout.write(struct.pack("i", hparams[key])) + + fout.write(struct.pack("i", n_mels)) + fout.write(struct.pack("i", n_freqs)) + f32(n_mels, n_freqs).tofile(fout) + + fout.write(struct.pack("i", n_fft)) + f32(n_fft).tofile(fout) + + for d in range(n_tdt): + fout.write(struct.pack("I", d)) + + tokens = ['<unk>', '<s>', '</s>'] + [chr(ord('a') + i) for i in range(n_vocab - 3)] + assert len(tokens) == n_vocab + fout.write(struct.pack("i", n_vocab)) + for tok in tokens: + tok_bytes = tok.encode('utf-8') + fout.write(struct.pack("i", len(tok_bytes))) + fout.write(tok_bytes) + + write_tensor(fout, "encoder.pre_encode.out.weight", f32(n_state, n_pre_enc)) + write_tensor(fout, "encoder.pre_encode.out.bias", f32(n_state)) + + write_tensor(fout, "encoder.pre_encode.conv.0.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.0.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.2.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.2.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.3.weight", f32(n_sub_ch, n_sub_ch, 1, 1)) + write_tensor(fout, "encoder.pre_encode.conv.3.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.5.weight", f32(n_sub_ch, 1, 3, 3)) + write_tensor(fout, "encoder.pre_encode.conv.5.bias", f32(1, n_sub_ch, 1, 1)) + + write_tensor(fout, "encoder.pre_encode.conv.6.weight", f32(n_sub_ch, n_sub_ch, 1, 1)) + write_tensor(fout, "encoder.pre_encode.conv.6.bias", f32(1, n_sub_ch, 1, 1)) + + for i in range(n_layer): + p = f"encoder.layers.{i}" + + write_tensor(fout, f"{p}.norm_feed_forward1.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_feed_forward1.bias", f32(n_state)) + write_tensor(fout, f"{p}.feed_forward1.linear1.weight", f32(4*n_state, n_state)) + write_tensor(fout, f"{p}.feed_forward1.linear2.weight", f32(n_state, 4*n_state)) + + write_tensor(fout, f"{p}.norm_conv.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_conv.bias", f32(n_state)) + write_tensor(fout, f"{p}.conv.pointwise_conv1.weight", f32(2*n_state, n_state)) + write_tensor(fout, f"{p}.conv.depthwise_conv.weight", f32(n_state, n_conv_ker)) + write_tensor(fout, f"{p}.conv.batch_norm.weight", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.bias", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.running_mean", f32(n_state)) + write_tensor(fout, f"{p}.conv.batch_norm.running_var", np.abs(f32(n_state))) + num_batches = np.zeros(1, dtype=np.int32) + write_tensor(fout, f"{p}.conv.batch_norm.num_batches_tracked", num_batches) + write_tensor(fout, f"{p}.conv.pointwise_conv2.weight", f32(n_state, n_state)) + + write_tensor(fout, f"{p}.norm_self_att.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_self_att.bias", f32(n_state)) + + write_tensor(fout, f"{p}.self_attn.pos_bias_u", f32(n_head, n_head_dim)) + write_tensor(fout, f"{p}.self_attn.pos_bias_v", f32(n_head, n_head_dim)) + write_tensor(fout, f"{p}.self_attn.linear_q.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_k.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_v.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_out.weight", f32(n_state, n_state)) + write_tensor(fout, f"{p}.self_attn.linear_pos.weight", f32(n_state, n_state)) + + write_tensor(fout, f"{p}.norm_feed_forward2.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_feed_forward2.bias", f32(n_state)) + write_tensor(fout, f"{p}.feed_forward2.linear1.weight", f32(4*n_state, n_state)) + write_tensor(fout, f"{p}.feed_forward2.linear2.weight", f32(n_state, 4*n_state)) + + write_tensor(fout, f"{p}.norm_out.weight", f32(n_state)) + write_tensor(fout, f"{p}.norm_out.bias", f32(n_state)) + + write_tensor(fout, "decoder.prediction.embed.weight", f32(n_pred_embed, dec_dim)) + + def reorder_gates(data): + h = data.shape[0] // 4 + return np.concatenate([data[:h], data[h:2*h], data[3*h:], data[2*h:3*h]], axis=0) + + for i in range(n_pred_l): + base = f"decoder.prediction.dec_rnn.lstm" + write_tensor(fout, f"{base}.weight_ih_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim))) + write_tensor(fout, f"{base}.weight_hh_l{i}", reorder_gates(f32(n_lstm_gates, dec_dim))) + write_tensor(fout, f"{base}.bias_h_l{i}", reorder_gates(f32(n_lstm_gates) + f32(n_lstm_gates))) + + write_tensor(fout, "joint.pred.weight", f32(dec_dim, dec_dim)) + write_tensor(fout, "joint.pred.bias", f32(dec_dim)) + write_tensor(fout, "joint.enc.weight", f32(dec_dim, n_state)) + write_tensor(fout, "joint.enc.bias", f32(dec_dim)) + write_tensor(fout, "joint.joint_net.2.weight", f32(n_joint_out, dec_dim)) + write_tensor(fout, "joint.joint_net.2.bias", f32(n_joint_out)) + + size = Path(output_path).stat().st_size + print(f"Generated {output_path} ({size / 1024:.1f} KB)") + +if __name__ == '__main__': + output = sys.argv[1] if len(sys.argv) > 1 else 'models/for-tests-ggml-parakeet-tdt.bin' + generate(output) diff --git a/models/requirements-openvino.txt b/models/requirements-openvino.txt index 5bfd95db88e..707fa58ab30 100644 --- a/models/requirements-openvino.txt +++ b/models/requirements-openvino.txt @@ -1,2 +1,2 @@ -openvino-dev[pytorch,onnx] -openai-whisper \ No newline at end of file +openvino>=2023.3.0 +openai-whisper diff --git a/models/requirements-parakeet.txt b/models/requirements-parakeet.txt new file mode 100644 index 00000000000..5239ae0af5d --- /dev/null +++ b/models/requirements-parakeet.txt @@ -0,0 +1,3 @@ +torch +numpy +pyyaml diff --git a/scripts/bench-all-gg.txt b/scripts/bench-all-gg.txt index 32a0908306c..1b65fc7d778 100644 --- a/scripts/bench-all-gg.txt +++ b/scripts/bench-all-gg.txt @@ -111,61 +111,61 @@ make -j && ./scripts/bench-all.sh 1 1 0 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M2 ULTRA | METAL | tiny | 1 | 0 | 8.80 | 1.13 | 0.28 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 9.34 | 1.09 | 0.28 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 9.29 | 1.09 | 0.29 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q8_0 | 1 | 0 | 9.00 | 1.12 | 0.28 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | base | 1 | 0 | 15.92 | 1.60 | 0.43 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 17.01 | 1.53 | 0.43 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 17.02 | 1.53 | 0.44 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q8_0 | 1 | 0 | 16.25 | 1.55 | 0.43 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | small | 1 | 0 | 47.83 | 3.09 | 0.91 | 0.05 | 47af2fb7 | -| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 52.85 | 2.98 | 0.94 | 0.06 | 47af2fb7 | -| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 52.92 | 2.97 | 0.94 | 0.06 | 47af2fb7 | -| M2 ULTRA | METAL | small-q8_0 | 1 | 0 | 49.05 | 2.89 | 0.90 | 0.06 | 47af2fb7 | -| M2 ULTRA | METAL | medium | 1 | 0 | 127.98 | 6.62 | 2.05 | 0.12 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 145.42 | 6.09 | 2.12 | 0.14 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 145.16 | 6.08 | 2.14 | 0.14 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q8_0 | 1 | 0 | 132.72 | 6.10 | 2.07 | 0.13 | 47af2fb7 | -| M2 ULTRA | METAL | medium-dis | 1 | 0 | 115.09 | 0.91 | 0.25 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2 | 1 | 0 | 243.69 | 9.68 | 3.14 | 0.22 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 280.38 | 8.95 | 3.18 | 0.25 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 279.76 | 8.92 | 3.18 | 0.25 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 0 | 254.55 | 9.35 | 3.04 | 0.23 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 219.23 | 1.01 | 0.28 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo | 1 | 0 | 220.57 | 1.55 | 0.46 | 0.03 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 0 | 253.03 | 1.40 | 0.47 | 0.04 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 0 | 229.82 | 1.43 | 0.45 | 0.04 | 47af2fb7 | +| M2 ULTRA | METAL | tiny | 1 | 0 | 8.10 | 1.03 | 0.25 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 0 | 8.53 | 1.02 | 0.26 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 0 | 8.67 | 1.00 | 0.26 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q8_0 | 1 | 0 | 9.32 | 1.02 | 0.26 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | base | 1 | 0 | 15.50 | 1.51 | 0.40 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q5_0 | 1 | 0 | 16.63 | 1.45 | 0.40 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q5_1 | 1 | 0 | 16.76 | 1.44 | 0.39 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q8_0 | 1 | 0 | 15.73 | 1.43 | 0.38 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | small | 1 | 0 | 45.43 | 2.93 | 0.83 | 0.05 | f14ae77f | +| M2 ULTRA | METAL | small-q5_0 | 1 | 0 | 49.78 | 2.85 | 0.84 | 0.06 | f14ae77f | +| M2 ULTRA | METAL | small-q5_1 | 1 | 0 | 50.22 | 2.85 | 0.84 | 0.06 | f14ae77f | +| M2 ULTRA | METAL | small-q8_0 | 1 | 0 | 47.08 | 2.78 | 0.83 | 0.05 | f14ae77f | +| M2 ULTRA | METAL | medium | 1 | 0 | 125.19 | 6.10 | 1.88 | 0.12 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 0 | 142.49 | 5.59 | 1.90 | 0.14 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 0 | 142.63 | 5.68 | 1.92 | 0.14 | f14ae77f | +| M2 ULTRA | METAL | medium-q8_0 | 1 | 0 | 130.98 | 5.83 | 1.87 | 0.13 | f14ae77f | +| M2 ULTRA | METAL | medium-dis | 1 | 0 | 113.95 | 0.88 | 0.24 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | large-v2 | 1 | 0 | 239.27 | 8.97 | 2.92 | 0.21 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 0 | 275.07 | 8.56 | 2.92 | 0.24 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 0 | 274.28 | 8.62 | 2.93 | 0.24 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 0 | 248.90 | 8.32 | 2.81 | 0.22 | f14ae77f | +| M2 ULTRA | METAL | large-v2-dis | 1 | 0 | 214.26 | 0.97 | 0.27 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo | 1 | 0 | 222.47 | 1.49 | 0.45 | 0.03 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 0 | 250.56 | 1.35 | 0.45 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 0 | 228.57 | 1.33 | 0.43 | 0.03 | f14ae77f | make -j && ./scripts/bench-all.sh 1 1 1 | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M2 ULTRA | METAL | tiny | 1 | 1 | 6.19 | 0.93 | 0.21 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 6.64 | 0.89 | 0.22 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 6.65 | 0.91 | 0.23 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | tiny-q8_0 | 1 | 1 | 6.26 | 0.93 | 0.22 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | base | 1 | 1 | 10.89 | 1.31 | 0.32 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 12.10 | 1.22 | 0.33 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 12.05 | 1.22 | 0.33 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | base-q8_0 | 1 | 1 | 11.24 | 1.24 | 0.32 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | small | 1 | 1 | 32.06 | 2.41 | 0.64 | 0.04 | 47af2fb7 | -| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 37.20 | 2.32 | 0.67 | 0.04 | 47af2fb7 | -| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 37.13 | 2.30 | 0.67 | 0.04 | 47af2fb7 | -| M2 ULTRA | METAL | small-q8_0 | 1 | 1 | 33.63 | 2.28 | 0.64 | 0.04 | 47af2fb7 | -| M2 ULTRA | METAL | medium | 1 | 1 | 89.22 | 5.14 | 1.46 | 0.09 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 106.82 | 4.83 | 1.49 | 0.11 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 106.60 | 4.88 | 1.50 | 0.11 | 47af2fb7 | -| M2 ULTRA | METAL | medium-q8_0 | 1 | 1 | 94.48 | 4.93 | 1.43 | 0.09 | 47af2fb7 | -| M2 ULTRA | METAL | medium-dis | 1 | 1 | 77.85 | 0.80 | 0.20 | 0.01 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2 | 1 | 1 | 170.73 | 7.50 | 2.12 | 0.16 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 206.46 | 7.05 | 2.17 | 0.20 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 206.15 | 7.10 | 2.19 | 0.20 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 1 | 180.31 | 6.90 | 2.10 | 0.17 | 47af2fb7 | -| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 147.44 | 0.90 | 0.22 | 0.02 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo | 1 | 1 | 148.79 | 1.30 | 0.34 | 0.03 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 1 | 180.34 | 1.14 | 0.35 | 0.03 | 47af2fb7 | -| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 1 | 158.04 | 1.18 | 0.33 | 0.03 | 47af2fb7 | +| M2 ULTRA | METAL | tiny | 1 | 1 | 6.03 | 0.86 | 0.20 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_0 | 1 | 1 | 6.46 | 0.84 | 0.21 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q5_1 | 1 | 1 | 6.46 | 0.85 | 0.21 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | tiny-q8_0 | 1 | 1 | 6.14 | 0.88 | 0.20 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | base | 1 | 1 | 10.87 | 1.24 | 0.31 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | base-q5_0 | 1 | 1 | 11.98 | 1.18 | 0.31 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q5_1 | 1 | 1 | 12.07 | 1.18 | 0.31 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | base-q8_0 | 1 | 1 | 11.13 | 1.19 | 0.30 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | small | 1 | 1 | 31.46 | 2.37 | 0.63 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | small-q5_0 | 1 | 1 | 36.16 | 2.31 | 0.65 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | small-q5_1 | 1 | 1 | 36.57 | 2.31 | 0.65 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | small-q8_0 | 1 | 1 | 32.94 | 2.27 | 0.63 | 0.04 | f14ae77f | +| M2 ULTRA | METAL | medium | 1 | 1 | 89.86 | 4.92 | 1.41 | 0.09 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_0 | 1 | 1 | 107.12 | 4.72 | 1.42 | 0.10 | f14ae77f | +| M2 ULTRA | METAL | medium-q5_1 | 1 | 1 | 107.00 | 4.70 | 1.42 | 0.10 | f14ae77f | +| M2 ULTRA | METAL | medium-q8_0 | 1 | 1 | 94.93 | 4.56 | 1.37 | 0.09 | f14ae77f | +| M2 ULTRA | METAL | medium-dis | 1 | 1 | 79.66 | 0.78 | 0.20 | 0.01 | f14ae77f | +| M2 ULTRA | METAL | large-v2 | 1 | 1 | 170.06 | 7.13 | 2.15 | 0.16 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_0 | 1 | 1 | 205.16 | 6.80 | 2.18 | 0.20 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q5_1 | 1 | 1 | 204.22 | 6.69 | 2.16 | 0.20 | f14ae77f | +| M2 ULTRA | METAL | large-v2-q8_0 | 1 | 1 | 179.78 | 6.35 | 2.13 | 0.18 | f14ae77f | +| M2 ULTRA | METAL | large-v2-dis | 1 | 1 | 148.11 | 0.89 | 0.22 | 0.02 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo | 1 | 1 | 149.23 | 1.29 | 0.34 | 0.03 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q5_0 | 1 | 1 | 180.77 | 1.13 | 0.35 | 0.03 | f14ae77f | +| M2 ULTRA | METAL | large-v3-turbo-q8_0 | 1 | 1 | 158.66 | 1.10 | 0.33 | 0.03 | f14ae77f | ## M4 Max @@ -233,20 +233,6 @@ make -j && ./scripts/bench-all.sh 1 1 0 make -j && ./scripts/bench-all.sh 1 1 1 -| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | -| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| M4 Max | METAL | tiny | 1 | 1 | 8.23 | 0.71 | 0.16 | 0.01 | 47fcd7da | -| M4 Max | METAL | tiny-q8_0 | 1 | 1 | 8.47 | 0.67 | 0.16 | 0.01 | 47fcd7da | -| M4 Max | METAL | base | 1 | 1 | 15.47 | 1.12 | 0.26 | 0.02 | 47fcd7da | -| M4 Max | METAL | base-q8_0 | 1 | 1 | 15.70 | 1.05 | 0.27 | 0.02 | 47fcd7da | -| M4 Max | METAL | small | 1 | 1 | 49.82 | 2.37 | 0.53 | 0.05 | 47fcd7da | -| M4 Max | METAL | small-q8_0 | 1 | 1 | 51.76 | 1.99 | 0.53 | 0.05 | 47fcd7da | -| M4 Max | METAL | medium | 1 | 1 | 147.76 | 5.52 | 1.27 | 0.12 | 47fcd7da | -| M4 Max | METAL | medium-q8_0 | 1 | 1 | 153.98 | 4.59 | 1.24 | 0.13 | 47fcd7da | -| M4 Max | METAL | large-v2 | 1 | 1 | 282.89 | 9.06 | 2.11 | 0.22 | 47fcd7da | -| M4 Max | METAL | large-v2-q8_0 | 1 | 1 | 296.43 | 7.44 | 2.09 | 0.23 | 47fcd7da | -| M4 Max | METAL | large-v3-turbo | 1 | 1 | 249.91 | 1.65 | 0.38 | 0.04 | 47fcd7da | - | CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | | M4 Max | METAL | tiny | 1 | 1 | 8.23 | 0.72 | 0.16 | 0.01 | 47af2fb7 | @@ -262,41 +248,77 @@ make -j && ./scripts/bench-all.sh 1 1 1 | M4 Max | METAL | large-v3-turbo | 1 | 1 | 256.23 | 1.61 | 0.38 | 0.04 | 47af2fb7 | +## M5 Max + +make -j && ./scripts/bench-all.sh 1 1 0 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M5 Max | METAL | tiny | 1 | 0 | 4.88 | 0.65 | 0.17 | 0.01 | f14ae77f | +| M5 Max | METAL | tiny-q8_0 | 1 | 0 | 4.84 | 0.63 | 0.17 | 0.01 | f14ae77f | +| M5 Max | METAL | base | 1 | 0 | 8.95 | 1.02 | 0.24 | 0.01 | f14ae77f | +| M5 Max | METAL | base-q8_0 | 1 | 0 | 9.12 | 0.94 | 0.24 | 0.01 | f14ae77f | +| M5 Max | METAL | small | 1 | 0 | 25.61 | 2.15 | 0.52 | 0.03 | f14ae77f | +| M5 Max | METAL | small-q8_0 | 1 | 0 | 25.77 | 1.93 | 0.50 | 0.03 | f14ae77f | +| M5 Max | METAL | medium | 1 | 0 | 73.96 | 4.61 | 1.16 | 0.08 | f14ae77f | +| M5 Max | METAL | medium-q8_0 | 1 | 0 | 74.89 | 3.94 | 1.12 | 0.08 | f14ae77f | +| M5 Max | METAL | large-v2 | 1 | 0 | 132.06 | 6.91 | 1.86 | 0.13 | f14ae77f | +| M5 Max | METAL | large-v2-q8_0 | 1 | 0 | 132.56 | 6.00 | 1.76 | 0.13 | f14ae77f | +| M5 Max | METAL | large-v3-turbo | 1 | 0 | 119.34 | 1.30 | 0.32 | 0.02 | f14ae77f | + + +make -j && ./scripts/bench-all.sh 1 1 1 + +| CPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| M5 Max | METAL | tiny | 1 | 1 | 4.31 | 0.59 | 0.13 | 0.01 | f14ae77f | +| M5 Max | METAL | tiny-q8_0 | 1 | 1 | 4.51 | 0.55 | 0.12 | 0.01 | f14ae77f | +| M5 Max | METAL | base | 1 | 1 | 7.77 | 0.91 | 0.20 | 0.01 | f14ae77f | +| M5 Max | METAL | base-q8_0 | 1 | 1 | 7.67 | 0.78 | 0.19 | 0.01 | f14ae77f | +| M5 Max | METAL | small | 1 | 1 | 20.90 | 1.76 | 0.40 | 0.03 | f14ae77f | +| M5 Max | METAL | small-q8_0 | 1 | 1 | 21.32 | 1.62 | 0.38 | 0.03 | f14ae77f | +| M5 Max | METAL | medium | 1 | 1 | 60.40 | 3.98 | 0.89 | 0.07 | f14ae77f | +| M5 Max | METAL | medium-q8_0 | 1 | 1 | 60.72 | 3.35 | 0.86 | 0.07 | f14ae77f | +| M5 Max | METAL | large-v2 | 1 | 1 | 110.57 | 6.06 | 1.41 | 0.12 | f14ae77f | +| M5 Max | METAL | large-v2-q8_0 | 1 | 1 | 110.92 | 5.00 | 1.31 | 0.12 | f14ae77f | +| M5 Max | METAL | large-v3-turbo | 1 | 1 | 98.36 | 1.19 | 0.27 | 0.02 | f14ae77f | + + # RTX 5090 make -j && ./scripts/bench-all.sh 1 1 0 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| RTX 5090 | CUDA | tiny | 1 | 0 | 2.12 | 0.51 | 0.13 | 0.00 | 47af2fb7 | -| RTX 5090 | CUDA | tiny-q8_0 | 1 | 0 | 2.50 | 0.52 | 0.14 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | base | 1 | 0 | 3.74 | 0.76 | 0.19 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | base-q8_0 | 1 | 0 | 4.38 | 0.74 | 0.20 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | small | 1 | 0 | 11.25 | 1.46 | 0.39 | 0.02 | 47af2fb7 | -| RTX 5090 | CUDA | small-q8_0 | 1 | 0 | 12.70 | 1.58 | 0.41 | 0.02 | 47af2fb7 | -| RTX 5090 | CUDA | medium | 1 | 0 | 31.16 | 3.07 | 0.80 | 0.04 | 47af2fb7 | -| RTX 5090 | CUDA | medium-q8_0 | 1 | 0 | 32.50 | 3.23 | 0.83 | 0.05 | 47af2fb7 | -| RTX 5090 | CUDA | large-v2 | 1 | 0 | 50.04 | 4.59 | 1.15 | 0.05 | 47af2fb7 | -| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 0 | 52.17 | 4.38 | 1.14 | 0.07 | 47af2fb7 | -| RTX 5090 | CUDA | large-v3-turbo | 1 | 0 | 46.88 | 0.70 | 0.17 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 48.49 | 0.64 | 0.16 | 0.01 | 47af2fb7 | +| RTX 5090 | CUDA | tiny | 1 | 0 | 2.17 | 0.38 | 0.10 | 0.00 | f14ae77f | +| RTX 5090 | CUDA | tiny-q8_0 | 1 | 0 | 2.31 | 0.37 | 0.10 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | base | 1 | 0 | 3.94 | 0.56 | 0.17 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | base-q8_0 | 1 | 0 | 4.13 | 0.53 | 0.14 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | small | 1 | 0 | 12.06 | 1.09 | 0.34 | 0.02 | f14ae77f | +| RTX 5090 | CUDA | small-q8_0 | 1 | 0 | 12.50 | 1.11 | 0.30 | 0.02 | f14ae77f | +| RTX 5090 | CUDA | medium | 1 | 0 | 33.08 | 2.38 | 0.70 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | medium-q8_0 | 1 | 0 | 32.57 | 2.26 | 0.62 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | large-v2 | 1 | 0 | 54.27 | 3.68 | 1.03 | 0.06 | f14ae77f | +| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 0 | 53.11 | 3.22 | 0.89 | 0.06 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo | 1 | 0 | 50.56 | 0.58 | 0.15 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 49.39 | 0.49 | 0.13 | 0.01 | f14ae77f | make -j && ./scripts/bench-all.sh 1 1 1 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| RTX 5090 | CUDA | tiny | 1 | 1 | 1.42 | 0.44 | 0.11 | 0.00 | 47af2fb7 | -| RTX 5090 | CUDA | tiny-q8_0 | 1 | 1 | 1.83 | 0.45 | 0.12 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | base | 1 | 1 | 2.21 | 0.65 | 0.16 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | base-q8_0 | 1 | 1 | 2.85 | 0.62 | 0.17 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | small | 1 | 1 | 5.11 | 1.23 | 0.32 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | small-q8_0 | 1 | 1 | 6.50 | 1.35 | 0.34 | 0.02 | 47af2fb7 | -| RTX 5090 | CUDA | medium | 1 | 1 | 14.01 | 2.57 | 0.64 | 0.03 | 47af2fb7 | -| RTX 5090 | CUDA | medium-q8_0 | 1 | 1 | 15.34 | 2.72 | 0.67 | 0.04 | 47af2fb7 | -| RTX 5090 | CUDA | large-v2 | 1 | 1 | 21.70 | 3.96 | 0.97 | 0.04 | 47af2fb7 | -| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 1 | 23.57 | 3.70 | 0.94 | 0.05 | 47af2fb7 | -| RTX 5090 | CUDA | large-v3-turbo | 1 | 1 | 18.61 | 0.62 | 0.15 | 0.01 | 47af2fb7 | -| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 20.10 | 0.56 | 0.14 | 0.01 | 47af2fb7 | +| RTX 5090 | CUDA | tiny | 1 | 1 | 1.29 | 0.31 | 0.07 | 0.00 | f14ae77f | +| RTX 5090 | CUDA | tiny-q8_0 | 1 | 1 | 1.45 | 0.31 | 0.07 | 0.00 | f14ae77f | +| RTX 5090 | CUDA | base | 1 | 1 | 2.15 | 0.44 | 0.13 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | base-q8_0 | 1 | 1 | 2.27 | 0.43 | 0.10 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | small | 1 | 1 | 5.54 | 0.83 | 0.26 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | small-q8_0 | 1 | 1 | 5.95 | 0.84 | 0.22 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | medium | 1 | 1 | 15.43 | 1.81 | 0.53 | 0.02 | f14ae77f | +| RTX 5090 | CUDA | medium-q8_0 | 1 | 1 | 14.71 | 1.66 | 0.46 | 0.03 | f14ae77f | +| RTX 5090 | CUDA | large-v2 | 1 | 1 | 24.73 | 2.92 | 0.81 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | large-v2-q8_0 | 1 | 1 | 23.35 | 2.43 | 0.67 | 0.04 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo | 1 | 1 | 21.36 | 0.49 | 0.13 | 0.01 | f14ae77f | +| RTX 5090 | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 20.07 | 0.39 | 0.10 | 0.01 | f14ae77f | # DGX Spark @@ -305,35 +327,50 @@ make -j && ./scripts/bench-all.sh 1 1 0 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| DGX Spk. | CUDA | tiny | 1 | 0 | 9.42 | 0.85 | 0.22 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | tiny-q8_0 | 1 | 0 | 9.69 | 0.81 | 0.20 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | base | 1 | 0 | 18.81 | 1.36 | 0.33 | 0.02 | 47af2fb7 | -| DGX Spk. | CUDA | base-q8_0 | 1 | 0 | 18.11 | 1.20 | 0.30 | 0.02 | 47af2fb7 | -| DGX Spk. | CUDA | small | 1 | 0 | 59.83 | 3.01 | 0.74 | 0.04 | 47af2fb7 | -| DGX Spk. | CUDA | small-q8_0 | 1 | 0 | 59.12 | 2.66 | 0.67 | 0.05 | 47af2fb7 | -| DGX Spk. | CUDA | medium | 1 | 0 | 163.73 | 7.53 | 1.70 | 0.12 | 47af2fb7 | -| DGX Spk. | CUDA | medium-q8_0 | 1 | 0 | 157.54 | 5.98 | 1.48 | 0.13 | 47af2fb7 | -| DGX Spk. | CUDA | large-v2 | 1 | 0 | 279.83 | 12.26 | 2.77 | 0.21 | 47af2fb7 | -| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 0 | 273.05 | 9.31 | 2.33 | 0.22 | 47af2fb7 | -| DGX Spk. | CUDA | large-v3-turbo | 1 | 0 | 271.11 | 2.06 | 0.47 | 0.03 | 47af2fb7 | -| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 262.69 | 1.49 | 0.36 | 0.03 | 47af2fb7 | +| DGX Spk. | CUDA | tiny | 1 | 0 | 9.00 | 0.85 | 0.14 | 0.01 | f5b477ab | +| DGX Spk. | CUDA | tiny-q8_0 | 1 | 0 | 8.86 | 0.83 | 0.12 | 0.01 | f5b477ab | +| DGX Spk. | CUDA | base | 1 | 0 | 18.48 | 1.38 | 0.22 | 0.02 | f5b477ab | +| DGX Spk. | CUDA | base-q8_0 | 1 | 0 | 17.28 | 1.22 | 0.19 | 0.02 | f5b477ab | +| DGX Spk. | CUDA | small | 1 | 0 | 56.43 | 3.01 | 0.51 | 0.04 | f5b477ab | +| DGX Spk. | CUDA | small-q8_0 | 1 | 0 | 55.70 | 2.68 | 0.44 | 0.04 | f5b477ab | +| DGX Spk. | CUDA | medium | 1 | 0 | 160.20 | 7.52 | 1.25 | 0.11 | f5b477ab | +| DGX Spk. | CUDA | medium-q8_0 | 1 | 0 | 150.84 | 6.01 | 1.01 | 0.12 | f5b477ab | +| DGX Spk. | CUDA | large-v2 | 1 | 0 | 276.42 | 12.29 | 2.16 | 0.20 | f5b477ab | +| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 0 | 264.92 | 9.32 | 1.67 | 0.20 | f5b477ab | +| DGX Spk. | CUDA | large-v3-turbo | 1 | 0 | 264.90 | 2.03 | 0.37 | 0.03 | f5b477ab | +| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 253.56 | 1.48 | 0.27 | 0.03 | f5b477ab | + +| GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | +| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | +| DGX Spk. | CUDA | tiny | 1 | 0 | 9.79 | 0.65 | 0.14 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | tiny-q8_0 | 1 | 0 | 8.97 | 0.56 | 0.12 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base | 1 | 0 | 18.58 | 1.04 | 0.22 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base-q8_0 | 1 | 0 | 17.36 | 0.88 | 0.18 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | small | 1 | 0 | 56.78 | 2.33 | 0.51 | 0.04 | f14ae77f | +| DGX Spk. | CUDA | small-q8_0 | 1 | 0 | 55.47 | 1.99 | 0.43 | 0.04 | f14ae77f | +| DGX Spk. | CUDA | medium | 1 | 0 | 158.21 | 5.71 | 1.23 | 0.11 | f14ae77f | +| DGX Spk. | CUDA | medium-q8_0 | 1 | 0 | 151.17 | 4.54 | 0.97 | 0.11 | f14ae77f | +| DGX Spk. | CUDA | large-v2 | 1 | 0 | 269.59 | 10.48 | 2.13 | 0.20 | f14ae77f | +| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 0 | 262.82 | 7.43 | 1.61 | 0.20 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo | 1 | 0 | 263.91 | 1.80 | 0.37 | 0.03 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 0 | 252.89 | 1.23 | 0.26 | 0.03 | f14ae77f | make -j && ./scripts/bench-all.sh 1 1 1 | GPU | Config | Model | Th | FA | Enc. | Dec. | Bch5 | PP | Commit | | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | -| DGX Spk. | CUDA | tiny | 1 | 1 | 2.89 | 0.76 | 0.19 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | tiny-q8_0 | 1 | 1 | 3.06 | 0.72 | 0.17 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | base | 1 | 1 | 5.37 | 1.23 | 0.29 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | base-q8_0 | 1 | 1 | 4.70 | 1.07 | 0.26 | 0.01 | 47af2fb7 | -| DGX Spk. | CUDA | small | 1 | 1 | 17.70 | 2.73 | 0.66 | 0.02 | 47af2fb7 | -| DGX Spk. | CUDA | small-q8_0 | 1 | 1 | 16.77 | 2.38 | 0.58 | 0.03 | 47af2fb7 | -| DGX Spk. | CUDA | medium | 1 | 1 | 56.22 | 6.98 | 1.53 | 0.06 | 47af2fb7 | -| DGX Spk. | CUDA | medium-q8_0 | 1 | 1 | 46.39 | 5.46 | 1.28 | 0.07 | 47af2fb7 | -| DGX Spk. | CUDA | large-v2 | 1 | 1 | 100.33 | 11.59 | 2.53 | 0.09 | 47af2fb7 | -| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 1 | 97.28 | 8.60 | 2.10 | 0.10 | 47af2fb7 | -| DGX Spk. | CUDA | large-v3-turbo | 1 | 1 | 92.59 | 2.00 | 0.44 | 0.02 | 47af2fb7 | -| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 85.96 | 1.40 | 0.33 | 0.02 | 47af2fb7 | +| DGX Spk. | CUDA | tiny | 1 | 1 | 2.72 | 0.56 | 0.13 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | tiny-q8_0 | 1 | 1 | 2.55 | 0.47 | 0.11 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base | 1 | 1 | 5.08 | 0.90 | 0.20 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | base-q8_0 | 1 | 1 | 4.38 | 0.72 | 0.16 | 0.01 | f14ae77f | +| DGX Spk. | CUDA | small | 1 | 1 | 16.95 | 2.00 | 0.47 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | small-q8_0 | 1 | 1 | 15.67 | 1.67 | 0.39 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | medium | 1 | 1 | 53.12 | 5.10 | 1.24 | 0.06 | f14ae77f | +| DGX Spk. | CUDA | medium-q8_0 | 1 | 1 | 43.64 | 3.87 | 0.91 | 0.05 | f14ae77f | +| DGX Spk. | CUDA | large-v2 | 1 | 1 | 102.15 | 9.58 | 2.02 | 0.08 | f14ae77f | +| DGX Spk. | CUDA | large-v2-q8_0 | 1 | 1 | 93.86 | 6.54 | 1.49 | 0.08 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo | 1 | 1 | 90.29 | 1.69 | 0.36 | 0.02 | f14ae77f | +| DGX Spk. | CUDA | large-v3-turbo-q8_0 | 1 | 1 | 82.79 | 1.13 | 0.25 | 0.01 | f14ae77f | # V100 diff --git a/scripts/bench-all.sh b/scripts/bench-all.sh index a15a361c708..7a0d0c8764b 100755 --- a/scripts/bench-all.sh +++ b/scripts/bench-all.sh @@ -100,12 +100,14 @@ for model in "${models[@]}"; do if [[ $system_info == *"CUDA = 1"* ]]; then config="$config CUDA" + elif [[ $system_info == *"CUDA : ARCHS"* ]]; then + config="$config CUDA" fi - if [[ $system_info == *"METAL = 1"* ]]; then - config="$config METAL" - elif [[ $system_info == *"Metal : EMBED_LIBRARY = 1"* ]]; then - config="$config METAL" + if [[ $system_info == *"MTL = 1"* ]]; then + config="$config MTL" + elif [[ $system_info == *"MTL : EMBED_LIBRARY = 1"* ]]; then + config="$config MTL" fi commit=$(git rev-parse --short HEAD) diff --git a/scripts/quantize-parakeet.sh b/scripts/quantize-parakeet.sh new file mode 100755 index 00000000000..7816696bfcb --- /dev/null +++ b/scripts/quantize-parakeet.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +set -e + +build_dir=build +modelname=ggml-parakeet-tdt-0.6b-v3 +model=models/${modelname}-f32.bin +cmd=parakeet-quantize + +cmake --build ${build_dir} --target $cmd -j 12 + +${build_dir}/bin/${cmd} $model models/${modelname}-q8_0.bin q8_0 +${build_dir}/bin/${cmd} $model models/${modelname}-q4_0.bin q4_0 +${build_dir}/bin/${cmd} $model models/${modelname}-q4_k.bin q4_k +${build_dir}/bin/${cmd} $model models/${modelname}-q2_k.bin q2_k diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index 1f87e23122b..bc7c1b2fe15 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -60,8 +60,8 @@ while read c; do cmake/common.cmake \ cmake/ggml-config.cmake.in \ src/ggml-cpu/cmake/FindSIMD.cmake \ - src/ggml*.h \ src/ggml* \ + src/gguf* \ include/ggml*.h \ include/gguf*.h \ examples/common.h \ @@ -105,6 +105,7 @@ if [ -f $SRC_WHISPER/ggml-src.patch ]; then # src/ggml-cpu/cmake/FindSIMD.cmake -> ggml/src/ggml-cpu/cmake/FindSIMD.cmake # # src/ggml* -> ggml/src/ggml*.c + # src/gguf* -> ggml/src/gguf*.c # # include/ggml*.h -> ggml/include/ggml*.h # include/gguf*.h -> ggml/include/gguf*.h @@ -126,6 +127,7 @@ if [ -f $SRC_WHISPER/ggml-src.patch ]; then -e 's/(^[[:space:]]| [ab]\/)cmake\/ggml-config.cmake.in/\1ggml\/cmake\/ggml-config.cmake.in/g' \ -e 's/(^[[:space:]]| [ab]\/)src\/ggml-cpu\/cmake\/FindSIMD.cmake/\1ggml\/src\/ggml-cpu\/cmake\/FindSIMD.cmake/g' \ -e 's/([[:space:]]| [ab]\/)src\/ggml(.*)/\1ggml\/src\/ggml\2/g' \ + -e 's/([[:space:]]| [ab]\/)src\/gguf(.*)/\1ggml\/src\/gguf\2/g' \ -e 's/(^[[:space:]]| [ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \ -e 's/(^[[:space:]]| [ab]\/)include\/gguf(.*)\.h/\1ggml\/include\/gguf\2.h/g' \ -e 's/(^[[:space:]]| [ab]\/)examples\/common\.h/\1examples\/common.h/g' \ diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 44fa890d78b..87d353ef452 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -b6d1f0f247adcfa25c0ca1ffe97e651fe1afd5e2 +3af5f5760e19a96427f5f7a93b79cbdf3d4b265b diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index 4296ddf5f50..099d5445c8c 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -7,6 +7,7 @@ cp -rpv ../ggml/cmake/* ./ggml/cmake/ cp -rpv ../ggml/src/ggml-cpu/cmake/* ./ggml/src/ggml-cpu/cmake/ cp -rpv ../ggml/src/ggml* ./ggml/src/ +cp -rpv ../ggml/src/gguf* ./ggml/src/ cp -rpv ../ggml/include/ggml*.h ./ggml/include/ cp -rpv ../ggml/include/gguf*.h ./ggml/include/ diff --git a/scripts/upload-parakeet.py b/scripts/upload-parakeet.py new file mode 100644 index 00000000000..3644bec8bd3 --- /dev/null +++ b/scripts/upload-parakeet.py @@ -0,0 +1,157 @@ +import argparse +import os +from huggingface_hub import HfApi, create_repo + +USER_NAME = "ggml-org" +REPO_ID = f"{USER_NAME}/parakeet-GGUF" + +MODELS = { + "f32": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-f32.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-f32.bin", + "description": "Full precision (F32)", + }, + "f16": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-f16.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-f16.bin", + "description": "Half precision (F16)", + }, + "q8_0": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q8_0.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q8_0.bin", + "description": "8-bit quantized (Q8_0)", + }, + "q4_0": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_0.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_0.bin", + "description": "4-bit quantized (Q4_0)", + }, + "q4_k": { + "local_path": "models/ggml-parakeet-tdt-0.6b-v3-q4_k.bin", + "remote_name": "ggml-parakeet-tdt-0.6b-v3-q4_k.bin", + "description": "4-bit K-quantized (Q4_k)", + }, +} + +def build_model_card(uploaded_variants): + lines = [ + f"---", + f"license: mit", + f"base_model: nvidia/parakeet-tdt-0.6b-v3", + f"tags:", + f"- gguf", + f"- asr", + f"---", + f"", + f"# Parakeet TDT 0.6B v3 (GGUF)", + f"", + f"GGUF conversions of [nvidia/parakeet-tdt-0.6b-v3](https://huggingface.co/nvidia/parakeet-tdt-0.6b-v3) for use with [whisper.cpp](https://github.com/ggml-org/whisper.cpp).", + f"", + f"## Available files", + f"", + ] + + for key, m in MODELS.items(): + if key in uploaded_variants: + lines.append(f"- `{m['remote_name']}` — {m['description']}") + + lines += [ + f"", + f"## Usage", + f"", + f"Build parakeet-cli:", + f"```console", + f"git clone https://github.com/ggml-org/whisper.cpp.git", + f"cd whisper.cpp", + f"cmake -B build -S .", + f"cmake --build build --target parakeet-cli -j $(nproc)", + f"```", + f"", + f"Download a model (e.g. Q8_0):", + f"```console", + f"hf download {REPO_ID} {MODELS['q8_0']['remote_name']} --local-dir models", + f"```", + f"", + f"Run:", + f"```console", + f"./build/bin/parakeet-cli -m models/{MODELS['q8_0']['remote_name']} -f samples/jfk.wav", + f"```", + f"", + ] + + return "\n".join(lines) + + +def upload_variant(api, key): + m = MODELS[key] + local_path = m["local_path"] + + if not os.path.exists(local_path): + print(f" Skipping {key}: {local_path} not found") + return False + + print(f" Uploading {m['remote_name']} ({m['description']})...") + api.upload_file( + path_or_fileobj=local_path, + path_in_repo=m["remote_name"], + repo_id=REPO_ID, + repo_type="model", + commit_message=f"Upload {m['remote_name']}", + ) + return True + + +def main(): + parser = argparse.ArgumentParser(description="Upload parakeet GGUF models to Hugging Face") + parser.add_argument( + "variants", + nargs="*", + default=None, + metavar="{" + ",".join(MODELS.keys()) + "}", + help="Model variants to upload (default: all)", + ) + parser.add_argument( + "--no-model-card", + action="store_true", + help="Skip updating the model card README", + ) + args = parser.parse_args() + + api = HfApi() + create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True) + + variants = args.variants if args.variants else list(MODELS.keys()) + + unknown = [v for v in variants if v not in MODELS] + if unknown: + parser.error(f"unknown variant(s): {', '.join(unknown)} (choose from {', '.join(MODELS.keys())})") + + uploaded = [] + for key in variants: + if upload_variant(api, key): + uploaded.append(key) + + if not uploaded: + print("No models were uploaded.") + return + + if not args.no_model_card: + print("Updating model card...") + existing = [k for k in MODELS if k in uploaded or + any(f.rfilename == MODELS[k]["remote_name"] + for f in api.list_repo_files(REPO_ID, repo_type="model") + if hasattr(f, "rfilename"))] + card = build_model_card(existing if existing else uploaded) + api.upload_file( + path_or_fileobj=card.encode(), + path_in_repo="README.md", + repo_id=REPO_ID, + repo_type="model", + commit_message="Update README.md", + ) + + print(f"\nDone. Repository: https://huggingface.co/{REPO_ID}") + + +if __name__ == "__main__": + main() diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 095a2791de5..4e7c5b24dc3 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -109,23 +109,43 @@ add_library(whisper whisper.cpp ) +add_library(parakeet + ../include/parakeet.h + parakeet-arch.h + parakeet.cpp + ) + +target_include_directories(parakeet PUBLIC . ../include) +target_compile_features (parakeet PUBLIC cxx_std_11) +target_link_libraries(parakeet PUBLIC ggml Threads::Threads) + # Set the version numbers set_target_properties(whisper PROPERTIES VERSION ${PROJECT_VERSION} SOVERSION ${SOVERSION} ) +set_target_properties(parakeet PROPERTIES + VERSION ${PROJECT_VERSION} + SOVERSION ${SOVERSION} +) + target_include_directories(whisper PUBLIC . ../include) target_compile_features (whisper PUBLIC cxx_std_11) # don't bump if (CMAKE_CXX_BYTE_ORDER STREQUAL "BIG_ENDIAN") set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_BIG_ENDIAN) + set(PARAKEET_EXTRA_FLAGS ${PARAKEET_EXTRA_FLAGS} -DPARAKEET_BIG_ENDIAN) endif() if (WHISPER_EXTRA_FLAGS) target_compile_options(whisper PRIVATE ${WHISPER_EXTRA_FLAGS}) endif() +if (PARAKEET_EXTRA_FLAGS) + target_compile_options(parakeet PRIVATE ${PARAKEET_EXTRA_FLAGS}) +endif() + find_package(Threads REQUIRED) target_link_libraries(whisper PUBLIC ggml Threads::Threads) @@ -144,4 +164,7 @@ endif() if (BUILD_SHARED_LIBS) set_target_properties(whisper PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(whisper PRIVATE WHISPER_SHARED WHISPER_BUILD) + + set_target_properties(parakeet PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(parakeet PRIVATE PARAKEET_SHARED PARAKEET_BUILD) endif() diff --git a/src/parakeet-arch.h b/src/parakeet-arch.h new file mode 100644 index 00000000000..3407a95c9c7 --- /dev/null +++ b/src/parakeet-arch.h @@ -0,0 +1,188 @@ +#pragma once + +#include "ggml.h" + +#include <map> + +enum parakeet_tensor { + // Encoder pre_encode + PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, + PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, + PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, + + // Encoder layers (per-layer) + PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, + PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, + PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, + PARAKEET_TENSOR_ENC_CONV_BN_BIAS, + PARAKEET_TENSOR_ENC_CONV_BN_MEAN, + PARAKEET_TENSOR_ENC_CONV_BN_VAR, + PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, + PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, + PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, + PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, + PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, + PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, + PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, + + // Prediction network + PARAKEET_TENSOR_PRED_EMBED_WEIGHT, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, + PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, + PARAKEET_TENSOR_PRED_LSTM_BIAS_H, + + // Joint network + PARAKEET_TENSOR_JOINT_PRED_WEIGHT, + PARAKEET_TENSOR_JOINT_PRED_BIAS, + PARAKEET_TENSOR_JOINT_ENC_WEIGHT, + PARAKEET_TENSOR_JOINT_ENC_BIAS, + PARAKEET_TENSOR_JOINT_NET_WEIGHT, + PARAKEET_TENSOR_JOINT_NET_BIAS, +}; + +static const std::map<parakeet_tensor, const char *> PARAKEET_TENSOR_NAMES = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, "encoder.pre_encode.out.weight"}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, "encoder.pre_encode.out.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, "encoder.pre_encode.conv.0.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, "encoder.pre_encode.conv.0.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, "encoder.pre_encode.conv.2.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, "encoder.pre_encode.conv.2.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, "encoder.pre_encode.conv.3.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, "encoder.pre_encode.conv.3.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, "encoder.pre_encode.conv.5.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, "encoder.pre_encode.conv.5.bias"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, "encoder.pre_encode.conv.6.weight"}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, "encoder.pre_encode.conv.6.bias"}, + + // Encoder layers (use %d for layer number) + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, "encoder.layers.%d.norm_feed_forward1.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, "encoder.layers.%d.norm_feed_forward1.bias"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward1.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward1.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, "encoder.layers.%d.norm_conv.weight"}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, "encoder.layers.%d.norm_conv.bias"}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, "encoder.layers.%d.conv.pointwise_conv1.weight"}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, "encoder.layers.%d.conv.depthwise_conv.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, "encoder.layers.%d.conv.batch_norm.weight"}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, "encoder.layers.%d.conv.batch_norm.bias"}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, "encoder.layers.%d.conv.batch_norm.running_mean"}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, "encoder.layers.%d.conv.batch_norm.running_var"}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, "encoder.layers.%d.conv.batch_norm.num_batches_tracked"}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, "encoder.layers.%d.conv.pointwise_conv2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, "encoder.layers.%d.norm_self_att.weight"}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, "encoder.layers.%d.norm_self_att.bias"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, "encoder.layers.%d.self_attn.pos_bias_u"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, "encoder.layers.%d.self_attn.pos_bias_v"}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, "encoder.layers.%d.self_attn.linear_q.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, "encoder.layers.%d.self_attn.linear_k.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, "encoder.layers.%d.self_attn.linear_v.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, "encoder.layers.%d.self_attn.linear_out.weight"}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, "encoder.layers.%d.self_attn.linear_pos.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, "encoder.layers.%d.norm_feed_forward2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, "encoder.layers.%d.norm_feed_forward2.bias"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, "encoder.layers.%d.feed_forward2.linear1.weight"}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, "encoder.layers.%d.feed_forward2.linear2.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, "encoder.layers.%d.norm_out.weight"}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, "encoder.layers.%d.norm_out.bias"}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, "decoder.prediction.embed.weight"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, "decoder.prediction.dec_rnn.lstm.weight_ih_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, "decoder.prediction.dec_rnn.lstm.weight_hh_l%d"}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_H, "decoder.prediction.dec_rnn.lstm.bias_h_l%d"}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, "joint.pred.weight"}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, "joint.pred.bias"}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, "joint.enc.weight"}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, "joint.enc.bias"}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, "joint.joint_net.2.weight"}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, "joint.joint_net.2.bias"}, +}; + +static const std::map<parakeet_tensor, ggml_op> PARAKEET_TENSOR_INFO = { + // Encoder pre_encode + {PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, GGML_OP_ADD}, + + // Encoder layers + {PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_CONV_BN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_CONV_BN_MEAN, GGML_OP_SUB}, + {PARAKEET_TENSOR_ENC_CONV_BN_VAR, GGML_OP_DIV}, + {PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, GGML_OP_NONE}, + {PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, GGML_OP_IM2COL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, GGML_OP_MUL}, + {PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, GGML_OP_ADD}, + + // Prediction network + {PARAKEET_TENSOR_PRED_EMBED_WEIGHT, GGML_OP_GET_ROWS}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_PRED_LSTM_BIAS_H, GGML_OP_ADD}, + + // Joint network + {PARAKEET_TENSOR_JOINT_PRED_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_PRED_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_ENC_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_ENC_BIAS, GGML_OP_ADD}, + {PARAKEET_TENSOR_JOINT_NET_WEIGHT, GGML_OP_MUL_MAT}, + {PARAKEET_TENSOR_JOINT_NET_BIAS, GGML_OP_ADD}, +}; diff --git a/src/parakeet.cpp b/src/parakeet.cpp new file mode 100644 index 00000000000..b5da73e985c --- /dev/null +++ b/src/parakeet.cpp @@ -0,0 +1,3838 @@ +#include "parakeet.h" +#include "parakeet-arch.h" + +#include "ggml.h" +#include "ggml-cpp.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" + +#include <atomic> +#include <algorithm> +#include <cassert> +#include <cfloat> +#define _USE_MATH_DEFINES +#include <cmath> +#include <climits> +#include <cstdarg> +#include <cstdio> +#include <cstring> +#include <fstream> +#include <functional> +#include <cctype> +#include <map> +#include <random> +#include <set> +#include <string> +#include <thread> +#include <vector> + +#ifdef _MSC_VER +#include <codecvt> +#endif + +#if defined(PARAKEET_BIG_ENDIAN) +template<typename T> +static T byteswap(T value) { + T value_swapped; + char * source = reinterpret_cast<char *>(&value); + char * target = reinterpret_cast<char *>(&value_swapped); + int size = sizeof(T); + for (int i = 0; i < size; i++) { + target[size - 1 - i] = source[i]; + } + return value_swapped; +} + +template<typename T> +static void byteswap_tensor_data(ggml_tensor * tensor) { + T * datum = reinterpret_cast<T *>(tensor->data); + for (int i = 0; i < ggml_nelements(tensor); i++) { + datum[i] = byteswap(datum[i]); + } +} + +static void byteswap_tensor(ggml_tensor * tensor) { + switch (tensor->type) { + case GGML_TYPE_I16: { + byteswap_tensor_data<int16_t>(tensor); + break; + } + case GGML_TYPE_F16: { + byteswap_tensor_data<ggml_fp16_t>(tensor); + break; + } + case GGML_TYPE_I32: { + byteswap_tensor_data<int32_t>(tensor); + break; + } + case GGML_TYPE_F32: { + byteswap_tensor_data<float>(tensor); + break; + } + default: { // GML_TYPE_I8 + break; + } + } +} + +#define BYTESWAP_VALUE(d) d = byteswap(d) +#define BYTESWAP_FILTERS(f) \ + do { \ + for (auto & datum : f.data) { \ + datum = byteswap(datum); \ + } \ + } while (0) +#define BYTESWAP_TENSOR(t) \ + do { \ + byteswap_tensor(t); \ + } while (0) +#else +#define BYTESWAP_VALUE(d) do {} while (0) +#define BYTESWAP_FILTERS(f) do {} while (0) +#define BYTESWAP_TENSOR(t) do {} while (0) +#endif + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define PARAKEET_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// + +PARAKEET_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal (ggml_log_level level, const char * format, ...); +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data); + +#define PARAKEET_LOG_ERROR(...) parakeet_log_internal(GGML_LOG_LEVEL_ERROR, __VA_ARGS__) +#define PARAKEET_LOG_WARN(...) parakeet_log_internal(GGML_LOG_LEVEL_WARN , __VA_ARGS__) +#define PARAKEET_LOG_INFO(...) parakeet_log_internal(GGML_LOG_LEVEL_INFO , __VA_ARGS__) + +// define this to enable verbose trace logging - useful for debugging purposes +//#define PARAKEET_DEBUG + +#if defined(PARAKEET_DEBUG) +#define PARAKEET_LOG_DEBUG(...) parakeet_log_internal(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__) +#else +#define PARAKEET_LOG_DEBUG(...) +#endif + +#define PARAKEET_ASSERT(x) \ + do { \ + if (!(x)) { \ + PARAKEET_LOG_ERROR("PARAKEET_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ + abort(); \ + } \ + } while (0) + +#define PARAKEET_MAX_NODES 8192 + +// Threshold for when local attention should be used. +// 8192 frames x 80ms = 655 s (about 10.9 mins) +static constexpr int PARAKEET_LOCAL_ATTN_THRESHOLD = 8192; +// Window of context in each director of the current token. +// 128 frames * 80ms = 10.24 s +static constexpr int PARAKEET_LOCAL_ATTN_WINDOW = 128; + +static std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector<char> buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); +} + +// +// ggml helpers +// + +static bool ggml_graph_compute_helper( + struct ggml_cgraph * graph, + int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + ggml_backend_ptr backend { ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr) }; + + auto * reg = ggml_backend_dev_backend_reg(ggml_backend_get_device(backend.get())); + + auto * set_abort_callback_fn = (ggml_backend_set_abort_callback_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_abort_callback"); + if (set_abort_callback_fn) { + set_abort_callback_fn(backend.get(), abort_callback, abort_callback_data); + } + + auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (ggml_backend_set_n_threads_fn) { + ggml_backend_set_n_threads_fn(backend.get(), n_threads); + } + + return ggml_backend_graph_compute(backend.get(), graph) == GGML_STATUS_SUCCESS; +} + +static bool ggml_graph_compute_helper( + ggml_backend_sched_t sched, + struct ggml_cgraph * graph, + int n_threads, + bool sched_reset = true) { + for (int i = 0; i < ggml_backend_sched_get_n_backends(sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(sched, i); + ggml_backend_dev_t dev = ggml_backend_get_device(backend); + ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr; + + auto * fn_set_n_threads = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads"); + if (fn_set_n_threads) { + fn_set_n_threads(backend, n_threads); + } + } + + const bool t = (ggml_backend_sched_graph_compute(sched, graph) == GGML_STATUS_SUCCESS); + + if (!t || sched_reset) { + ggml_backend_sched_reset(sched); + } + + return t; +} + +// TODO: move these functions to ggml-base with support for ggml-backend? + + +struct parakeet_mel { + int n_len = 0; + int n_len_org = 0; + int n_mel = 0; + + std::vector<float> data; +}; + +struct parakeet_filters { + int32_t n_mel = 0; + int32_t n_fb = 0; // number of frequency bins + + std::vector<float> data; +}; + +struct parakeet_vocab { + using id = int32_t; + using token = std::string; + + int n_vocab = 8192; + size_t max_token_length = 0; + + std::map<token, id> token_to_id; + std::map<id, token> id_to_token; + + id token_unk; + id token_bos; + id token_blank; + id token_eos; +}; + +struct parakeet_segment { + int64_t t0; + int64_t t1; + + std::string text; + + std::vector<parakeet_token_data> tokens; +}; + +struct parakeet_batch { + int32_t n_tokens; + + parakeet_token * token; + int32_t * i_time; // index of the audio frame + parakeet_pos * pos; + int32_t * n_seq_id; // always 1, here for consistency with llama.cpp + parakeet_seq_id ** seq_id; // null terminated + int8_t * logits; +}; + +// ggml_backend_sched wrapper for parakeet usage +struct parakeet_sched { + ggml_backend_sched_t sched = nullptr; + + std::vector<uint8_t> meta; +}; + +// TODO: Find out is there a multiple version types. It is not yet clear to me +// at this point. +enum parakeet_arch { + PARAKEET_ARCH_UNKNOWN = 0, + PARAKEET_ARCH_TDT = 1, // NVIDIA Parakeet TDT (RNN-T) +}; + +struct parakeet_hparams { + int32_t n_vocab = 8192; + int32_t n_audio_ctx = 0; // 0 = unlimited, will be set based on input + int32_t n_audio_state = 1024; + int32_t n_audio_head = 8; + int32_t n_audio_layer = 24; + int32_t n_mels = 128; + int32_t ftype = 1; + int32_t n_fft = 512; // FFT size for mel spectrogram + float eps = 1e-5f; + int32_t subsampling_factor = 8; + int32_t n_subsampling_channels = 256; + int32_t n_conv_kernel = 9; + int32_t n_pred_dim = 640; + int32_t n_pred_layers = 2; + int32_t n_tdt_durations = 5; + int32_t n_max_tokens = 10; + + parakeet_arch arch = PARAKEET_ARCH_TDT; +}; + +struct parakeet_layer_encoder { + struct ggml_tensor * norm_ff1_w = nullptr; + struct ggml_tensor * norm_ff1_b = nullptr; + + struct ggml_tensor * ff1_linear1_w = nullptr; + struct ggml_tensor * ff1_linear2_w = nullptr; + + struct ggml_tensor * norm_conv_w = nullptr; + struct ggml_tensor * norm_conv_b = nullptr; + + struct ggml_tensor * conv_pw1_w = nullptr; // pointwise_conv1 + struct ggml_tensor * conv_dw_w = nullptr; // depthwise_conv + struct ggml_tensor * conv_bn_w = nullptr; // batch_norm weight + struct ggml_tensor * conv_bn_b = nullptr; // batch_norm bias + struct ggml_tensor * conv_bn_mean = nullptr; // batch_norm running_mean + struct ggml_tensor * conv_bn_var = nullptr; // batch_norm running_var + struct ggml_tensor * conv_bn_num_batches = nullptr; // batch_norm num_batches_tracked + struct ggml_tensor * conv_pw2_w = nullptr; // pointwise_conv2 + + struct ggml_tensor * norm_attn_w = nullptr; + struct ggml_tensor * norm_attn_b = nullptr; + + struct ggml_tensor * attn_pos_bias_u = nullptr; + struct ggml_tensor * attn_pos_bias_v = nullptr; + struct ggml_tensor * attn_q_w = nullptr; + struct ggml_tensor * attn_k_w = nullptr; + struct ggml_tensor * attn_v_w = nullptr; + struct ggml_tensor * attn_out_w = nullptr; + struct ggml_tensor * attn_pos_w = nullptr; + + struct ggml_tensor * norm_ff2_w = nullptr; + struct ggml_tensor * norm_ff2_b = nullptr; + + struct ggml_tensor * ff2_linear1_w = nullptr; + struct ggml_tensor * ff2_linear2_w = nullptr; + + struct ggml_tensor * norm_out_w = nullptr; + struct ggml_tensor * norm_out_b = nullptr; +}; + +struct parakeet_lsmt_layer { + struct ggml_tensor * ih_w = nullptr; // input-to-hidden weight + struct ggml_tensor * hh_w = nullptr; // hidden-to-hidden weight + struct ggml_tensor * b_h = nullptr; // bias (ih folded into hh at conversion time) +}; + +struct parakeet_prediction_network { + struct ggml_tensor * embed_w = nullptr; + + std::vector<parakeet_lsmt_layer> lstm_layer; +}; + +struct parakeet_joint_network { + struct ggml_tensor * pred_w = nullptr; + struct ggml_tensor * pred_b = nullptr; + struct ggml_tensor * enc_w = nullptr; + struct ggml_tensor * enc_b = nullptr; + struct ggml_tensor * net_w = nullptr; + struct ggml_tensor * net_b = nullptr; +}; + +struct parakeet_model { + parakeet_filters filters; + parakeet_hparams hparams; + + struct ggml_tensor * enc_pre_out_w = nullptr; + struct ggml_tensor * enc_pre_out_b = nullptr; + struct ggml_tensor * enc_pre_conv_0_w = nullptr; + struct ggml_tensor * enc_pre_conv_0_b = nullptr; + struct ggml_tensor * enc_pre_conv_2_w = nullptr; + struct ggml_tensor * enc_pre_conv_2_b = nullptr; + struct ggml_tensor * enc_pre_conv_3_w = nullptr; + struct ggml_tensor * enc_pre_conv_3_b = nullptr; + struct ggml_tensor * enc_pre_conv_5_w = nullptr; + struct ggml_tensor * enc_pre_conv_5_b = nullptr; + struct ggml_tensor * enc_pre_conv_6_w = nullptr; + struct ggml_tensor * enc_pre_conv_6_b = nullptr; + + std::vector<parakeet_layer_encoder> layers; + + parakeet_prediction_network prediction; + + parakeet_joint_network joint; + + std::vector<uint32_t> tdt_durations; + + std::vector<ggml_context *> ctxs; + + std::vector<ggml_backend_buffer_t> buffers; + + int n_loaded = 0; + std::map<std::string, struct ggml_tensor *> tensors; +}; + +struct parakeet_lstm_state_layer { + struct ggml_tensor * h_state = nullptr; + struct ggml_tensor * c_state = nullptr; +}; + +struct parakeet_lstm_state { + std::vector<parakeet_lstm_state_layer> layer; + + std::vector<uint8_t> ctx_buf; + + ggml_backend_buffer_t buffer = nullptr; +}; + +struct parakeet_state { + int64_t t_sample_us = 0; + int64_t t_encode_us = 0; + int64_t t_decode_us = 0; + int64_t t_predict_us = 0; + int64_t t_predict_build_us = 0; // time spent building the prediction graph + int64_t t_predict_alloc_us = 0; // time spent in ggml_backend_sched_alloc_graph + int64_t t_predict_compute_us = 0; // time spent in ggml_graph_compute_helper + int64_t t_mel_us = 0; + + int32_t n_sample = 0; // number of tokens sampled + int32_t n_encode = 0; // number of encoder calls + int32_t n_decode = 0; // number of decoder calls with n_tokens == 1 (text-generation) + int32_t n_predict = 0; // number of prediction network calls + int32_t n_fail_p = 0; // number of logprob threshold failures + int32_t n_fail_h = 0; // number of entropy threshold failures + + parakeet_mel mel; + + parakeet_batch batch; + + int n_frames = 0; + + std::vector<ggml_backend_t> backends; + + parakeet_sched sched_encode; + parakeet_sched sched_decode; + + // outputs from encoder stages + struct ggml_tensor * enc_out = nullptr; + struct ggml_tensor * pred_out = nullptr; + + std::vector<uint8_t> enc_out_buf; + ggml_backend_buffer_t enc_out_buffer = nullptr; + + std::vector<uint8_t> pred_out_buf; + ggml_backend_buffer_t pred_out_buffer = nullptr; + + struct ggml_tensor * attn_mask = nullptr; + + std::vector<float> inp_mel; + std::vector<float> inp_mask; + + std::vector<float> logits; + + std::vector<parakeet_segment> result_all; + + std::vector<parakeet_token> decoded_tokens; + std::vector<parakeet_token_data> decoded_token_data; + + std::string path_model; + + int32_t n_audio_ctx = 0; + int32_t sched_encode_n_audio_ctx = 0; + + parakeet_lstm_state lstm_state; +}; + +// FFT cache for mel spectrogram computation +struct parakeet_mel_cache { + int n_fft = 0; + + // In FFT, we frequently use sine and cosine operations with the same values. + // We can use precalculated values to speed up the process. + std::vector<float> sin_vals; + std::vector<float> cos_vals; + + // Hann window (Use cosf to eliminate difference) + // ref: https://pytorch.org/docs/stable/generated/torch.hann_window.html + // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L147 + std::vector<float> hann_window; + + // Window function from model (Parakeet uses actual window from training) + std::vector<float> window; + + void init(int fft_size) { + n_fft = fft_size; + sin_vals.resize(n_fft); + cos_vals.resize(n_fft); + hann_window.resize(n_fft); + + fill_sin_cos_table(); + fill_hann_window(n_fft, true, hann_window.data()); + } + + void fill_sin_cos_table() { + for (int i = 0; i < n_fft; i++) { + double theta = (2 * M_PI * i) / n_fft; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } + } + + void fill_hann_window(int length, bool periodic, float * output) { + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + output[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } + } +}; + +struct parakeet_context { + int64_t t_load_us = 0; + int64_t t_start_us = 0; + + ggml_type wtype = ggml_type::GGML_TYPE_F16; + ggml_type itype = ggml_type::GGML_TYPE_F16; + + parakeet_context_params params; + + parakeet_model model; + parakeet_vocab vocab; + + parakeet_state * state = nullptr; + + parakeet_mel_cache mel_cache; + + std::string path_model; +}; + +struct parakeet_global { + // We save the log callback globally + ggml_log_callback log_callback = parakeet_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static parakeet_global g_state; + +static const std::string PARAKEET_SPM_SPACE = "\xE2\x96\x81"; + +static inline int utf8_codepoint_len(unsigned char c) { + if ((c & 0x80) == 0x00) return 1; + if ((c & 0xE0) == 0xC0) return 2; + if ((c & 0xF0) == 0xE0) return 3; + if ((c & 0xF8) == 0xF0) return 4; + return 1; +} + +static bool is_sentencepiece_control(const std::string & piece) { + return piece == "<unk>" || piece == "<s>" || piece == "</s>" || piece == "[BLANK]"; +} + +static std::string sentencepiece_normalize(const std::string & text) { + std::string normalized; + normalized.reserve(text.size() + PARAKEET_SPM_SPACE.size()); + normalized += PARAKEET_SPM_SPACE; // SentencePiece dummy prefix + + for (unsigned char c : text) { + if (std::isspace(c)) { + normalized += PARAKEET_SPM_SPACE; + } else { + normalized += static_cast<char>(c); + } + } + + return normalized; +} + +static std::string sentencepiece_piece_to_text(const std::string & piece, bool is_first_piece) { + if (is_sentencepiece_control(piece)) { + return ""; + } + + std::string text; + text.reserve(piece.size()); + + size_t pos = 0; + while (pos < piece.size()) { + if (piece.compare(pos, PARAKEET_SPM_SPACE.size(), PARAKEET_SPM_SPACE) == 0) { + if (!is_first_piece || !text.empty()) { + text += ' '; + } + pos += PARAKEET_SPM_SPACE.size(); + continue; + } + + text += piece[pos]; + ++pos; + } + + return text; +} + + +static struct parakeet_batch parakeet_batch_init(int32_t n_tokens) { + parakeet_batch batch = { 0, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, }; + + batch.token = (parakeet_token * ) malloc(sizeof(parakeet_token) * (n_tokens)); + batch.i_time = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.pos = (parakeet_pos *) malloc(sizeof(parakeet_pos) * (n_tokens)); + batch.n_seq_id = (int32_t *) malloc(sizeof(int32_t) * (n_tokens)); + batch.seq_id = (parakeet_seq_id **) malloc(sizeof(parakeet_seq_id *) * (n_tokens + 1)); + for (int i = 0; i < n_tokens; ++i) { + batch.seq_id[i] = (parakeet_seq_id *) malloc(sizeof(parakeet_seq_id)); + } + batch.seq_id[n_tokens] = nullptr; + batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); + + return batch; +} + +static void parakeet_batch_free(struct parakeet_batch batch) { + if (batch.token) free(batch.token); + if (batch.i_time) free(batch.i_time); + if (batch.pos) free(batch.pos); + if (batch.n_seq_id) free(batch.n_seq_id); + if (batch.seq_id) { + for (int i = 0; batch.seq_id[i]; ++i) { + free(batch.seq_id[i]); + } + free(batch.seq_id); + } + if (batch.logits) free(batch.logits); +} + +static void parakeet_batch_prep_legacy(parakeet_batch & batch, const parakeet_token * tokens, int n_tokens, int n_past, int seq_id) { + batch.n_tokens = n_tokens; + for (int i = 0; i < n_tokens; ++i) { + if (tokens) { + batch.token[i] = tokens[i]; + } + batch.pos [i] = n_past + i; + batch.n_seq_id[i] = 1; + batch.seq_id [i][0] = seq_id; + batch.logits [i] = 0; + } + batch.logits[n_tokens - 1] = 1; +} + + +static size_t parakeet_sched_size(struct parakeet_sched & allocr) { + size_t size = allocr.meta.size(); + for (int i = 0; i < ggml_backend_sched_get_n_backends(allocr.sched); ++i) { + ggml_backend_t backend = ggml_backend_sched_get_backend(allocr.sched, i); + size += ggml_backend_sched_get_buffer_size(allocr.sched, backend); + } + return size; +} + +static bool parakeet_sched_graph_init(struct parakeet_sched & allocr, std::vector<ggml_backend_t> backends, std::function<struct ggml_cgraph *()> && get_graph) { + auto & sched = allocr.sched; + auto & meta = allocr.meta; + + sched = ggml_backend_sched_new(backends.data(), nullptr, backends.size(), PARAKEET_MAX_NODES, false, true); + + if (!sched) { + PARAKEET_LOG_ERROR("%s: failed to create scheduler\n", __func__); + return false; + } + + meta.resize(ggml_tensor_overhead()*PARAKEET_MAX_NODES + ggml_graph_overhead()); + + if (!ggml_backend_sched_alloc_graph(sched, get_graph())) { + PARAKEET_LOG_ERROR("%s: failed to allocate the compute buffer\n", __func__); + ggml_backend_sched_free(sched); + sched = nullptr; + return false; + } + + ggml_backend_sched_reset(sched); + + return true; +} + +static void parakeet_sched_free(struct parakeet_sched & sched) { + if (sched.sched) { + ggml_backend_sched_free(sched.sched); + sched.sched = nullptr; + } + + sched.meta.clear(); +} + + +template<typename T> +static void read_safe(parakeet_model_loader * loader, T & dest) { + loader->read(loader->context, &dest, sizeof(T)); + BYTESWAP_VALUE(dest); +} + +static bool parakeet_lstm_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_layer, + int n_pred_dim) { + parakeet_lstm_state & lstm_state = pstate.lstm_state; + + lstm_state.ctx_buf.resize(ggml_tensor_overhead() * n_layer * 2); + lstm_state.layer.resize(n_layer); + + struct ggml_init_params params = { + /*.mem_size =*/ lstm_state.ctx_buf.size(), + /*.mem_buffer =*/ lstm_state.ctx_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states context\n", __func__); + return false; + } + + + for (int il = 0; il < n_layer; ++il) { + lstm_state.layer[il].h_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + lstm_state.layer[il].c_state = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + } + + lstm_state.buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!lstm_state.buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for the lstm states\n", __func__); + return false; + } + + ggml_backend_buffer_clear(lstm_state.buffer, 0); + + ggml_free(ctx); + + return true; +} + +static bool parakeet_pred_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_pred_dim) { + pstate.pred_out_buf.resize(ggml_tensor_overhead()); + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.pred_out_buf.size(), + /*.mem_buffer =*/ pstate.pred_out_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor context\n", __func__); + return false; + } + + pstate.pred_out = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_pred_dim); + pstate.pred_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!pstate.pred_out_buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for pred tensor\n", __func__); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + + return true; +} + +static bool parakeet_enc_state_init( + struct parakeet_state & pstate, + ggml_backend_t backend, + int n_audio_state, + int n_frames_max) { + pstate.enc_out_buf.resize(ggml_tensor_overhead()); + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.enc_out_buf.size(), + /*.mem_buffer =*/ pstate.enc_out_buf.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx = ggml_init(params); + if (!ctx) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor context\n", __func__); + return false; + } + + pstate.enc_out = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_frames_max); + pstate.enc_out_buffer = ggml_backend_alloc_ctx_tensors(ctx, backend); + if (!pstate.enc_out_buffer) { + PARAKEET_LOG_ERROR("%s: failed to allocate memory for enc_out tensor\n", __func__); + ggml_free(ctx); + return false; + } + + ggml_free(ctx); + + return true; +} + +static ggml_backend_t parakeet_backend_init_gpu(const parakeet_context_params & params) { + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); + + ggml_backend_dev_t dev = nullptr; + + int cnt = 0; + if (params.use_gpu) { + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev_cur = ggml_backend_dev_get(i); + enum ggml_backend_dev_type dev_type = ggml_backend_dev_type(dev_cur); + const char * dev_name = ggml_backend_dev_name(dev_cur); + PARAKEET_LOG_INFO("%s: device %zu: %s (type: %d)\n", __func__, i, dev_name, dev_type); + if (dev_type == GGML_BACKEND_DEVICE_TYPE_GPU || dev_type == GGML_BACKEND_DEVICE_TYPE_IGPU) { + PARAKEET_LOG_INFO("%s: found GPU device %zu: %s (type: %d, cnt: %d)\n", __func__, i, dev_name, dev_type, cnt); + if (cnt == params.gpu_device) { + dev = dev_cur; + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + if (dev == nullptr) { + PARAKEET_LOG_INFO("%s: no GPU found\n", __func__); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t result = ggml_backend_dev_init(dev, nullptr); + if (!result) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + } + + return result; +} + +static std::vector<ggml_backend_t> parakeet_backend_init(const parakeet_context_params & params) { + std::vector<ggml_backend_t> result; + + ggml_backend_t backend_gpu = parakeet_backend_init_gpu(params); + + if (backend_gpu) { + result.push_back(backend_gpu); + } + + // ACCEL backends + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_ACCEL) { + PARAKEET_LOG_INFO("%s: using %s backend\n", __func__, ggml_backend_dev_name(dev)); + ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr); + if (!backend) { + PARAKEET_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev)); + continue; + } + result.push_back(backend); + } + } + + ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr); + if (backend_cpu == nullptr) { + throw std::runtime_error("failed to initialize CPU backend"); + } + result.push_back(backend_cpu); + + return result; +} + +using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>; + +static buft_list_t make_buft_list(parakeet_context_params & params) { + // Prio order: GPU -> CPU Extra -> CPU + buft_list_t buft_list; + + // GPU + if (params.use_gpu) { + int cnt = 0; + for (size_t i = 0; i < ggml_backend_dev_count(); ++i) { + ggml_backend_dev_t dev = ggml_backend_dev_get(i); + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU) { + if (cnt == params.gpu_device) { + auto * buft = ggml_backend_dev_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + } + } + + if (++cnt > params.gpu_device) { + break; + } + } + } + } + + // CPU Extra + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list.emplace_back(cpu_dev, *extra_bufts); + ++extra_bufts; + } + } + + // CPU + buft_list.emplace_back(cpu_dev, ggml_backend_cpu_buffer_type()); + + return buft_list; +} + +static bool weight_buft_supported(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, ggml_backend_buffer_type_t buft, ggml_backend_dev_t dev) { + bool op_supported = true; + + if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU || + ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_IGPU || + (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && buft == ggml_backend_cpu_buffer_type())) { + // GPU and default CPU backend support all operators + op_supported = true; + } else { + switch (op) { + // The current extra_buffer_type implementations only support GGML_OP_MUL_MAT and GGML_OP_GET_ROWS + case GGML_OP_GET_ROWS: + case GGML_OP_MUL_MAT: { + ggml_init_params params = { + /*.mem_size =*/ 2 * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error("failed to create ggml context"); + } + ggml_context * ctx = ctx_ptr.get(); + + ggml_tensor * op_tensor = nullptr; + + if (op == GGML_OP_MUL_MAT) { + int64_t n_ctx = hparams.n_audio_ctx; + ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], n_ctx, w->ne[2], w->ne[3]); + op_tensor = ggml_mul_mat(ctx, w, b); + } else if (op == GGML_OP_GET_ROWS) { + int64_t num_indices = 8; + ggml_tensor * indices = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, num_indices); + op_tensor = ggml_get_rows(ctx, w, indices); + } + + // create a temporary dummy buffer for the weight so that supports_op can check the buffer type + GGML_ASSERT(w->buffer == nullptr); + w->buffer = ggml_backend_buft_alloc_buffer(buft, 0); + op_supported = ggml_backend_dev_supports_op(dev, op_tensor); + ggml_backend_buffer_free(w->buffer); + w->buffer = nullptr; + break; + } + default: { + op_supported = false; + break; + } + }; + } + + return op_supported; +} + +static ggml_backend_buffer_type_t select_weight_buft(const parakeet_hparams & hparams, ggml_tensor * w, ggml_op op, buft_list_t buft_list) { + GGML_ASSERT(!buft_list.empty()); + for (const auto & p : buft_list) { + ggml_backend_dev_t dev = p.first; + ggml_backend_buffer_type_t buft = p.second; + if (weight_buft_supported(hparams, w, op, buft, dev)) { + return buft; + } + } + + return nullptr; +} + + +// load the model from a ggml file +// + +// see the convert-parakeet-to-ggml.py script for details +// +static bool parakeet_model_load(struct parakeet_model_loader * loader, parakeet_context & wctx) { + PARAKEET_LOG_INFO("%s: loading model\n", __func__); + + const int64_t t_start_us = ggml_time_us(); + + wctx.t_start_us = t_start_us; + + auto & model = wctx.model; + auto & vocab = wctx.vocab; + + // verify magic + { + uint32_t magic; + read_safe(loader, magic); + if (magic != GGML_FILE_MAGIC) { + PARAKEET_LOG_ERROR("%s: invalid model data (bad magic)\n", __func__); + return false; + } + } + + //load hparams + parakeet_hparams hparams; + { + read_safe(loader, hparams.n_vocab); + read_safe(loader, hparams.n_audio_ctx); + read_safe(loader, hparams.n_audio_state); + read_safe(loader, hparams.n_audio_head); + read_safe(loader, hparams.n_audio_layer); + read_safe(loader, hparams.n_mels); + read_safe(loader, hparams.ftype); + read_safe(loader, hparams.n_fft); + read_safe(loader, hparams.subsampling_factor); + read_safe(loader, hparams.n_subsampling_channels); + read_safe(loader, hparams.n_conv_kernel); + read_safe(loader, hparams.n_pred_dim); + read_safe(loader, hparams.n_pred_layers); + read_safe(loader, hparams.n_tdt_durations); + read_safe(loader, hparams.n_max_tokens); + + hparams.arch = PARAKEET_ARCH_TDT; + wctx.model.hparams = hparams; + + const int32_t qntvr = hparams.ftype / GGML_QNT_VERSION_FACTOR; + + hparams.ftype %= GGML_QNT_VERSION_FACTOR; + + // for the big tensors, we have the option to store the data in 16-bit floats or quantized + // in order to save memory and also to speed up the computation + wctx.wtype = ggml_ftype_to_ggml_type((ggml_ftype) hparams.ftype); + if (wctx.wtype == GGML_TYPE_COUNT) { + PARAKEET_LOG_ERROR("%s: invalid model (bad ftype value %d)\n", __func__, hparams.ftype); + return false; + } + + const char* arch_name = hparams.arch == PARAKEET_ARCH_TDT ? "Parakeet TDT" : "unknown"; + PARAKEET_LOG_INFO("%s: arch = %s\n", __func__, arch_name); + PARAKEET_LOG_INFO("%s: n_vocab = %d\n", __func__, hparams.n_vocab); + PARAKEET_LOG_INFO("%s: n_audio_ctx = %d\n", __func__, hparams.n_audio_ctx); + PARAKEET_LOG_INFO("%s: n_audio_state = %d\n", __func__, hparams.n_audio_state); + PARAKEET_LOG_INFO("%s: n_audio_head = %d\n", __func__, hparams.n_audio_head); + PARAKEET_LOG_INFO("%s: n_audio_layer = %d\n", __func__, hparams.n_audio_layer); + PARAKEET_LOG_INFO("%s: n_mels = %d\n", __func__, hparams.n_mels); + PARAKEET_LOG_INFO("%s: n_fft = %d\n", __func__, hparams.n_fft); + PARAKEET_LOG_INFO("%s: eps = %f\n", __func__, hparams.eps); + PARAKEET_LOG_INFO("%s: ftype = %d\n", __func__, hparams.ftype); + PARAKEET_LOG_INFO("%s: qntvr = %d\n", __func__, qntvr); + PARAKEET_LOG_INFO("%s: subsampling_factor = %d\n", __func__, hparams.subsampling_factor); + PARAKEET_LOG_INFO("%s: n_subsampling_channels = %d\n", __func__, hparams.n_subsampling_channels); + PARAKEET_LOG_INFO("%s: n_conv_kernel = %d\n", __func__, hparams.n_conv_kernel); + PARAKEET_LOG_INFO("%s: n_pred_dim = %d\n", __func__, hparams.n_pred_dim); + PARAKEET_LOG_INFO("%s: n_pred_layers = %d\n", __func__, hparams.n_pred_layers); + PARAKEET_LOG_INFO("%s: n_tdt_durations = %d\n", __func__, hparams.n_tdt_durations); + PARAKEET_LOG_INFO("%s: n_max_tokens = %d\n", __func__, hparams.n_max_tokens); + } + + // load mel filters + { + auto & filters = wctx.model.filters; + + read_safe(loader, filters.n_mel); + read_safe(loader, filters.n_fb); + + filters.data.resize(filters.n_mel * filters.n_fb); + loader->read(loader->context, filters.data.data(), filters.data.size() * sizeof(float)); + BYTESWAP_FILTERS(filters); + } + + // load window function + { + int32_t n_window = 0; + read_safe(loader, n_window); + + wctx.mel_cache.window.resize(n_window); + loader->read(loader->context, wctx.mel_cache.window.data(), n_window * sizeof(float)); + +#ifdef GGML_BIG_ENDIAN + for (auto & datum : wctx.mel_cache.window) { + datum = byteswap(datum); + } +#endif + + PARAKEET_LOG_INFO("%s: loaded window function with %d samples\n", __func__, n_window); + } + + // load TDT (Token and Duration Transducer) values + { + auto & tdt_durations = wctx.model.tdt_durations; + tdt_durations.resize(hparams.n_tdt_durations); + loader->read(loader->context, tdt_durations.data(), hparams.n_tdt_durations * sizeof(uint32_t)); + + PARAKEET_LOG_INFO("%s: loaded tdt_durations: [", __func__); + for (const auto value : tdt_durations) { + PARAKEET_LOG_INFO("%u ", value); + } + PARAKEET_LOG_INFO("]\n"); + } + + // load vocab + { + int32_t n_vocab = 0; + read_safe(loader, n_vocab); + + std::string word; + std::vector<char> tmp; + + tmp.reserve(128); + + for (int i = 0; i < n_vocab; i++) { + uint32_t len; + read_safe(loader, len); + + if (len > 0) { + tmp.resize(len); + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + word.assign(&tmp[0], tmp.size()); + } else { + PARAKEET_LOG_WARN("%s: warning: empty-string token in vocab, i = %d\n", __func__, i); + word = ""; + } + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + vocab.max_token_length = std::max(vocab.max_token_length, word.size()); + } + // Blank token for transducer is at index n_vocab (8192), outside the vocabulary + int blank_id = n_vocab; + vocab.token_blank = blank_id; + vocab.id_to_token[blank_id] = "[BLANK]"; + vocab.token_to_id["[BLANK]"] = blank_id; + + // Set special token IDs by looking them up in the loaded vocabulary + // These are from the SentencePiece vocab file loaded above + if (vocab.token_to_id.find("<unk>") != vocab.token_to_id.end()) { + vocab.token_unk = vocab.token_to_id.at("<unk>"); + } else { + vocab.token_unk = 0; // Fallback + } + + if (vocab.token_to_id.find("<s>") != vocab.token_to_id.end()) { + vocab.token_bos = vocab.token_to_id.at("<s>"); + } else if (vocab.token_to_id.find("<|startoftranscript|>") != vocab.token_to_id.end()) { + vocab.token_bos = vocab.token_to_id.at("<|startoftranscript|>"); + } else { + vocab.token_bos = 0; // Fallback + } + + if (vocab.token_to_id.find("</s>") != vocab.token_to_id.end()) { + vocab.token_eos = vocab.token_to_id.at("</s>"); + } else if (vocab.token_to_id.find("<|endoftext|>") != vocab.token_to_id.end()) { + vocab.token_eos = vocab.token_to_id.at("<|endoftext|>"); + } else { + vocab.token_eos = 0; // Fallback + } + + vocab.n_vocab = model.hparams.n_vocab; + + PARAKEET_LOG_INFO("%s: loaded vocab with %d tokens (blank_id=%d, unk=%d, bos=%d, eos=%d)\n", + __func__, n_vocab, blank_id, vocab.token_unk, vocab.token_bos, vocab.token_eos); + } + + const ggml_type wtype = wctx.wtype; + + + const int n_audio_layer = hparams.n_audio_layer; + + // Calculate tensor count: pre_encode (12) + encoder layers (29 per layer) + prediction (9) + joint (6) + size_t n_tensors = 12 + (29 * n_audio_layer) + 9 + 6; + + std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map; + auto get_ctx = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + throw std::runtime_error("failed to create ggml context"); + } + + ctx_map[buft] = ctx; + wctx.model.ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + // Create a list of available bufts, in priority order + buft_list_t buft_list = make_buft_list(wctx.params); + + auto create_tensor = [&](parakeet_tensor type, ggml_tensor * meta, int layer = -1) -> ggml_tensor * { + ggml_op op = PARAKEET_TENSOR_INFO.at(type); + ggml_backend_buffer_type_t buft = select_weight_buft(hparams, meta, op, buft_list); + if (!buft) { + throw std::runtime_error(format("failed to find a compatible buffer type for parakeet tensor %s", + PARAKEET_TENSOR_NAMES.at(type))); + } + + ggml_context * ctx = get_ctx(buft); + ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + + std::string tensor_name; + if (layer >= 0) { + tensor_name = format(PARAKEET_TENSOR_NAMES.at(type), layer); + } else { + tensor_name = PARAKEET_TENSOR_NAMES.at(type); + } + + wctx.model.tensors[tensor_name] = tensor; + + return tensor; + }; + + // prepare tensors for the weights + + ggml_init_params params = { + /*.mem_size =*/ n_tensors * ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + + const int n_audio_state = hparams.n_audio_state; + + model.layers.resize(n_audio_layer); + + // Encoder pre_encode + const int n_subsampling_channels = hparams.n_subsampling_channels; + const int n_pre_enc_features = (hparams.n_mels / hparams.subsampling_factor) * n_subsampling_channels; + model.enc_pre_out_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_pre_enc_features, n_audio_state)); + ggml_set_name(model.enc_pre_out_w, "enc_pre_out_w"); + model.enc_pre_out_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state)); + ggml_set_name(model.enc_pre_out_b, "enc_pre_out_b"); + + model.enc_pre_conv_0_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_0_w, "enc_pre_conv_0_w"); + model.enc_pre_conv_0_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_0_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_0_b, "enc_pre_conv_0_b"); + + model.enc_pre_conv_2_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_2_w, "enc_pre_conv_2_w"); + model.enc_pre_conv_2_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_2_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_2_b, "enc_pre_conv_2_b"); + + model.enc_pre_conv_3_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_3_w, "enc_pre_conv_3_w"); + model.enc_pre_conv_3_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_3_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_3_b, "enc_pre_conv_3_b"); + + model.enc_pre_conv_5_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 1, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_5_w, "enc_pre_conv_5_w"); + model.enc_pre_conv_5_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_5_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_5_b, "enc_pre_conv_5_b"); + + model.enc_pre_conv_6_w = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_WEIGHT, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, n_subsampling_channels)); + ggml_set_name(model.enc_pre_conv_6_w, "enc_pre_conv_6_w"); + model.enc_pre_conv_6_b = create_tensor(PARAKEET_TENSOR_ENC_PRE_CONV_6_BIAS, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 1, 1, n_subsampling_channels, 1)); + ggml_set_name(model.enc_pre_conv_6_b, "enc_pre_conv_6_b"); + + // Encoder layers + for (int i = 0; i < n_audio_layer; ++i) { + auto & layer = model.layers[i]; + + // Feed forward 1 + layer.norm_ff1_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_ff1_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF1_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.ff1_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + ggml_format_name(layer.ff1_linear1_w, "enc_%d_ff1_linear1_w", i); + layer.ff1_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF1_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + ggml_format_name(layer.ff1_linear2_w, "enc_%d_ff1_linear2_w", i); + + // Convolution module + layer.norm_conv_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.norm_conv_w, "enc_%d_norm_conv_w", i); + layer.norm_conv_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_CONV_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.norm_conv_b, "enc_%d_norm_conv_b", i); + layer.conv_pw1_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 2*n_audio_state), i); + ggml_format_name(layer.conv_pw1_w, "enc_%d_conv_pw1_w", i); + layer.conv_dw_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_DW_WEIGHT, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_conv_kernel, n_audio_state), i); + ggml_format_name(layer.conv_dw_w, "enc_%d_conv_dw_w", i); + layer.conv_bn_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_w, "enc_%d_conv_bn_w", i); + layer.conv_bn_b = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_b, "enc_%d_conv_bn_b", i); + layer.conv_bn_mean = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_MEAN, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.conv_bn_var = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_VAR, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + ggml_format_name(layer.conv_bn_var, "enc_%d_conv_bn_var", i); + layer.conv_bn_num_batches = create_tensor(PARAKEET_TENSOR_ENC_CONV_BN_NUM_BATCHES, ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1), i); + layer.conv_pw2_w = create_tensor(PARAKEET_TENSOR_ENC_CONV_PW2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + ggml_format_name(layer.conv_pw2_w, "enc_%d_conv_pw2_w", i); + + // Self attention + layer.norm_attn_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_attn_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_ATTN_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.attn_pos_bias_u = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_U, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i); + layer.attn_pos_bias_v = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_BIAS_V, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, hparams.n_audio_state / hparams.n_audio_head, hparams.n_audio_head), i); + layer.attn_q_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_Q_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_k_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_K_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_v_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_V_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_out_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_OUT_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + layer.attn_pos_w = create_tensor(PARAKEET_TENSOR_ENC_ATTN_POS_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, n_audio_state), i); + ggml_format_name(layer.attn_pos_w, "enc_%d_attn_pos_w", i); + + // Feed forward 2 + layer.norm_ff2_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_ff2_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_FF2_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.ff2_linear1_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR1_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, 4*n_audio_state), i); + layer.ff2_linear2_w = create_tensor(PARAKEET_TENSOR_ENC_FF2_LINEAR2_WEIGHT, ggml_new_tensor_2d(ctx, wtype, 4*n_audio_state, n_audio_state), i); + + // Output norm + layer.norm_out_w = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_WEIGHT, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + layer.norm_out_b = create_tensor(PARAKEET_TENSOR_ENC_NORM_OUT_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state), i); + } + + // Prediction network (decoder) + const int dec_hidden = hparams.n_pred_dim; + const int n_pred_embed = hparams.n_vocab + 1; // vocab + blank token + const int n_lstm_gates = 4 * dec_hidden; // 4 LSTM gates + const int n_joint_out = hparams.n_vocab + hparams.n_tdt_durations + 1; // vocab + durations + blank + + // The prediction/joint hidden dimension is 640, which is not a multiple of the + // K-quant block size (256). For K-quant models, we keep these tensors at F32. + const int blck = ggml_blck_size(wtype); + const ggml_type pred_wtype = (blck > 1 && dec_hidden % blck != 0) ? GGML_TYPE_F32 : wtype; + const ggml_type join_wtype = pred_wtype; + + model.prediction.embed_w = create_tensor(PARAKEET_TENSOR_PRED_EMBED_WEIGHT, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_pred_embed)); + model.prediction.lstm_layer.resize(hparams.n_pred_layers); + for (int i = 0; i < hparams.n_pred_layers; ++i) { + auto & layer = model.prediction.lstm_layer[i]; + layer.ih_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_IH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i); + ggml_format_name(layer.ih_w, "pred_%d_ih_w", i); + + layer.hh_w = create_tensor(PARAKEET_TENSOR_PRED_LSTM_WEIGHT_HH, ggml_new_tensor_2d(ctx, pred_wtype, dec_hidden, n_lstm_gates), i); + ggml_format_name(layer.hh_w, "pred_%d_hh_w", i); + + layer.b_h = create_tensor(PARAKEET_TENSOR_PRED_LSTM_BIAS_H, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_lstm_gates), i); + ggml_format_name(layer.b_h, "pred_%d_b_h", i); + } + + // Joint network + model.joint.pred_w = create_tensor(PARAKEET_TENSOR_JOINT_PRED_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, dec_hidden)); + ggml_set_name(model.joint.pred_w, "pred_w"); + model.joint.pred_b = create_tensor(PARAKEET_TENSOR_JOINT_PRED_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden)); + ggml_set_name(model.joint.pred_b, "pred_b"); + model.joint.enc_w = create_tensor(PARAKEET_TENSOR_JOINT_ENC_WEIGHT, ggml_new_tensor_2d(ctx, wtype, n_audio_state, dec_hidden)); + ggml_set_name(model.joint.enc_w, "enc_w"); + model.joint.enc_b = create_tensor(PARAKEET_TENSOR_JOINT_ENC_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dec_hidden)); + ggml_set_name(model.joint.enc_b, "enc_b"); + model.joint.net_w = create_tensor(PARAKEET_TENSOR_JOINT_NET_WEIGHT, ggml_new_tensor_2d(ctx, join_wtype, dec_hidden, n_joint_out)); + ggml_set_name(model.joint.net_w, "net_w"); + model.joint.net_b = create_tensor(PARAKEET_TENSOR_JOINT_NET_BIAS, ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_joint_out)); + ggml_set_name(model.joint.net_b, "net_b"); + + ggml_free(ctx); + + // allocate tensors in the backend buffers + for (auto & p : ctx_map) { + ggml_backend_buffer_type_t buft = p.first; + ggml_context * ctx = p.second; + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (buf) { + wctx.model.buffers.emplace_back(buf); + + size_t size_main = ggml_backend_buffer_get_size(buf); + PARAKEET_LOG_INFO("%s: %12s total size = %8.2f MB\n", __func__, ggml_backend_buffer_name(buf), size_main / 1e6); + } + } + + // load weights + { + size_t total_size = 0; + + auto & tensors_map = wctx.model.tensors; + int & n_loaded = wctx.model.n_loaded; + + n_loaded = 0; + + std::vector<char> read_buf; + + while (true) { + int32_t n_dims; + int32_t length; + int32_t ttype; + + read_safe(loader, n_dims); + read_safe(loader, length); + read_safe(loader, ttype); + + if (loader->eof(loader->context)) { + break; + } + + int32_t nelements = 1; + int32_t ne[4] = { 1, 1, 1, 1 }; + for (int i = 0; i < n_dims; ++i) { + read_safe(loader, ne[i]); + nelements *= ne[i]; + } + + std::string name; + std::vector<char> tmp(length); // create a buffer + loader->read(loader->context, &tmp[0], tmp.size()); // read to buffer + name.assign(&tmp[0], tmp.size()); + + if (tensors_map.find(name) == tensors_map.end()) { + PARAKEET_LOG_ERROR("%s: unknown tensor '%s' in model file\n", __func__, name.data()); + return false; + } + + auto tensor = tensors_map[name.data()]; + + if (ggml_nelements(tensor) != nelements) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + PARAKEET_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + return false; + } + + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2] || tensor->ne[3] != ne[3]) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d, %d], expected [%d, %d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], (int) tensor->ne[3], ne[0], ne[1], ne[2], ne[3]); + return false; + } + + const size_t bpe = ggml_type_size(ggml_type(ttype)); + + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + PARAKEET_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; + } + + if (ggml_backend_buffer_is_host(tensor->buffer)) { + // for the CPU and Metal backend, we can read directly into the tensor + loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); + BYTESWAP_TENSOR(tensor); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(ggml_nbytes(tensor)); + + loader->read(loader->context, read_buf.data(), read_buf.size()); + + ggml_backend_tensor_set(tensor, read_buf.data(), 0, ggml_nbytes(tensor)); + } + + total_size += ggml_nbytes(tensor); + n_loaded++; + } + + PARAKEET_LOG_INFO("%s: model size = %7.2f MB\n", __func__, total_size/1e6); + + if (n_loaded == 0) { + PARAKEET_LOG_WARN("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__); + } else if (n_loaded != (int) tensors_map.size()) { + PARAKEET_LOG_ERROR("%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, tensors_map.size(), n_loaded); + return false; + } + } + + auto & buffers = wctx.model.buffers; + for (auto & buf : buffers) { + ggml_backend_buffer_set_usage(buf, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); + } + + wctx.t_load_us = ggml_time_us() - t_start_us; + + return true; +} + +// conv subsampling + conformer encoder +static struct ggml_cgraph * parakeet_build_graph_encode(parakeet_context & pctx, parakeet_state & pstate) { + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_mel_time = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : hparams.n_audio_ctx; + const int n_mels = hparams.n_mels; + const int n_layer = hparams.n_audio_layer; + const int n_state = hparams.n_audio_state; + const float fc_factor = 0.5f; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_encode.meta.size(), + /*.mem_buffer =*/ pstate.sched_encode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + // Conv subsampling + + // [freq, time] + struct ggml_tensor * mel = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_mels, n_mel_time, 1, 1); + ggml_set_name(mel, "mel"); + ggml_set_input(mel); + + // [freq, time, channels, batch] + struct ggml_tensor * cur = ggml_conv_2d(ctx0, model.enc_pre_conv_0_w, mel, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_0_b); + ggml_set_name(cur, "pre_conv_0"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_0_relu"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_2_w, cur, 2, 2, 1, 1, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_2_b); + ggml_set_name(cur, "pre_conv_2"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.enc_pre_conv_3_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_3_b); + ggml_set_name(cur, "pre_conv_3"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_3_relu"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d_dw_direct(ctx0, model.enc_pre_conv_5_w, cur, 2, 2, 1, 1, 1, 1); + ggml_set_name(cur, "pre_conv_5_direct"); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_5_b); + ggml_set_name(cur, "pre_conv_5"); + + // [freq, time, channels, batch] + cur = ggml_conv_2d(ctx0, model.enc_pre_conv_6_w, cur, 1, 1, 0, 0, 1, 1); + cur = ggml_add(ctx0, cur, model.enc_pre_conv_6_b); + ggml_set_name(cur, "pre_conv_6"); + + cur = ggml_relu(ctx0, cur); + ggml_set_name(cur, "pre_conv_6_relu"); + + // [freq, time, chan] + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // [freq, chan, time] + cur = ggml_cont(ctx0, cur); + + const int n_freq = cur->ne[0]; // 16 + const int n_chan = cur->ne[1]; // 256 + const int n_frames = cur->ne[2]; // time + + // [freq, time, chan, batch] -> [(freq * chan), time] + cur = ggml_reshape_2d(ctx0, cur, n_freq * n_chan, n_frames); + + cur = ggml_mul_mat(ctx0, model.enc_pre_out_w, cur); + cur = ggml_add(ctx0, cur, model.enc_pre_out_b); + + ggml_set_name(cur, "pre_enc_out"); + + // Encoder + // cur: [n_state, n_enc_time] + + const int n_time = cur->ne[1]; + const bool local_attn = n_time > PARAKEET_LOCAL_ATTN_THRESHOLD; + const int att_left = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1; + const int att_right = local_attn ? PARAKEET_LOCAL_ATTN_WINDOW : n_time - 1; + const int window_size = local_attn ? att_left + att_right + 1 : 2 * n_time - 1; + const int d_half = n_state / 2; + const int mask_dim = local_attn ? window_size : n_time; + + // mask [key, n_time] + struct ggml_tensor * attn_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, mask_dim, n_time); + ggml_set_name(attn_mask, "attn_mask"); + ggml_set_input(attn_mask); + + struct ggml_tensor * local_mask = nullptr; + if (local_attn) { + const int chunk = att_left + att_right; + local_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, chunk + window_size - 1, chunk); + ggml_set_name(local_mask, "local_mask"); + ggml_set_input(local_mask); + } + + struct ggml_tensor * pos_freqs = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, d_half); + ggml_set_name(pos_freqs, "pos_freqs"); + ggml_set_input(pos_freqs); + + struct ggml_tensor * rel_positions = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, window_size); + ggml_set_name(rel_positions, "rel_positions"); + ggml_set_input(rel_positions); + + struct ggml_tensor * freqs = ggml_repeat_4d(ctx0, pos_freqs, d_half, window_size, 1, 1); + struct ggml_tensor * theta = ggml_mul(ctx0, freqs, rel_positions); + + struct ggml_tensor * sin_t = ggml_reshape_3d(ctx0, ggml_sin(ctx0, theta), 1, d_half, window_size); + struct ggml_tensor * cos_t = ggml_reshape_3d(ctx0, ggml_cos(ctx0, theta), 1, d_half, window_size); + // [n_state, window_size] + struct ggml_tensor * pos_emb = ggml_reshape_2d(ctx0, ggml_cont(ctx0, ggml_concat(ctx0, sin_t, cos_t, 0)), n_state, window_size); + ggml_set_name(pos_emb, "pos_emb"); + + for (int il = 0; il < n_layer; ++il) { + const auto & layer = model.layers[il]; + + // FFN1 + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_res", il); + + // norm + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff1_w), layer.norm_ff1_b); + ggml_format_name(cur, "enc_%d_ffn_norm_1", il); + + // ffn_1 + cur = ggml_mul_mat(ctx0, layer.ff1_linear1_w, cur); + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_silu", il); + + cur = ggml_mul_mat(ctx0, layer.ff1_linear2_w, cur); + ggml_format_name(cur, "enc_%d_ffn_1", il); + + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, fc_factor)); + ggml_format_name(cur, "enc_%d_res_ffn", il); + } + + // self attention block using relative positional encoding computed in graph. + { + // [feat, time_frames, 1, 1] + struct ggml_tensor * residual = cur; + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_attn_w), layer.norm_attn_b); + ggml_format_name(cur, "enc_%d_attn_norm", il); + + const int n_head = hparams.n_audio_head; + const int d_head = n_state / n_head; + + // [feat, time_frames, 1, 1] + struct ggml_tensor * Q_cur = ggml_mul_mat(ctx0, layer.attn_q_w, cur); + struct ggml_tensor * K_cur = ggml_mul_mat(ctx0, layer.attn_k_w, cur); + struct ggml_tensor * V_cur = ggml_mul_mat(ctx0, layer.attn_v_w, cur); + + Q_cur = ggml_reshape_3d(ctx0, Q_cur, d_head, n_head, n_time); + K_cur = ggml_reshape_3d(ctx0, K_cur, d_head, n_head, n_time); + V_cur = ggml_reshape_3d(ctx0, V_cur, d_head, n_head, n_time); + + struct ggml_tensor * pos = ggml_mul_mat(ctx0, layer.attn_pos_w, pos_emb); + pos = ggml_reshape_3d(ctx0, pos, d_head, n_head, window_size); + pos = ggml_cont(ctx0, ggml_permute(ctx0, pos, 0, 2, 1, 3)); + + if (local_attn) { + const int chunk = att_left + att_right; + const int n_group = (n_time + chunk - 1) / chunk; + const int n_time_padded = n_group * chunk; + const int n_kv_chunk = chunk + window_size - 1; + const int n_kv_dense = n_kv_chunk * n_group; + const bool need_padding = n_time_padded > n_time; + + Q_cur = ggml_cont(ctx0, ggml_permute(ctx0, Q_cur, 0, 2, 1, 3)); + K_cur = ggml_cont(ctx0, ggml_permute(ctx0, K_cur, 0, 2, 1, 3)); + V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 0, 2, 1, 3)); + + // content bias + struct ggml_tensor * bias_u = ggml_reshape_3d(ctx0, layer.attn_pos_bias_u, d_head, 1, n_head); + struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, bias_u); + + // position bias + struct ggml_tensor * bias_v = ggml_reshape_3d(ctx0, layer.attn_pos_bias_v, d_head, 1, n_head); + struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, bias_v); + + // right pad the time_frame. + struct ggml_tensor * Q_u_padded = need_padding ? + ggml_pad_ext(ctx0, Q_u, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : Q_u; + Q_u_padded = ggml_reshape_4d(ctx0, Q_u_padded, d_head, chunk, n_group, n_head); + + // Add padding to front and back (for the first timeframe and the last timeframe). + struct ggml_tensor * K_padded = ggml_pad_ext(ctx0, K_cur, 0, 0, att_left, att_right, 0, 0, 0, 0); + + // pad time axis to match n_kv_dense if needed. + if (n_kv_dense > K_padded->ne[1]) { + K_padded = ggml_pad_ext(ctx0, K_padded, 0, 0, 0, n_kv_dense - K_padded->ne[1], 0, 0, 0, 0); + } + + // Create a 4d tensor where each group spans a wide window of + // 512 keys (n_kv_chunk), but moving to the next group (nb[2]) + // only jumps forward by 256 frames (chunk * nb[1]). This creates + // a 256 frame overlap, shared keys in RAM without copies. + struct ggml_tensor * K_chunk = ggml_view_4d(ctx0, K_padded, + d_head, n_kv_chunk, n_group, n_head, + K_padded->nb[1], + (size_t) chunk * K_padded->nb[1], + K_padded->nb[2], + 0); + K_chunk = ggml_cont(ctx0, K_chunk); + + struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_chunk, Q_u_padded); + + // The above mul_mat operation, combined with K_chunk's overlapping + // frames, produces a dense matrix. But some of the results in + // this matrix were computed for keys that aren't part of that + // query's window. So we shift each row to keep only the results + // that we want. + content_scores = ggml_view_4d(ctx0, content_scores, + window_size, chunk, n_group, n_head, + (size_t) (chunk + window_size) * content_scores->nb[0], + content_scores->nb[2], + content_scores->nb[3], + 0); + content_scores = ggml_cont(ctx0, content_scores); + + // ungrouping. + content_scores = ggml_reshape_3d(ctx0, content_scores, window_size, n_time_padded, n_head); + + // remove padding if padding was applied (truncating to n_time). + if (need_padding) { + content_scores = ggml_view_3d(ctx0, content_scores, + window_size, n_time, n_head, + content_scores->nb[1], + content_scores->nb[2], + 0); + } + + struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v); + + // attention_score = content similarity + relative position scores + struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores); + + attn_scores = ggml_soft_max_ext(ctx0, attn_scores, attn_mask, 1.0f / std::sqrt(d_head), 0.0f); + + // right pad the probabilites. + struct ggml_tensor * probs_padded = need_padding ? + ggml_pad_ext(ctx0, attn_scores, 0, 0, 0, n_time_padded - n_time, 0, 0, 0, 0) : attn_scores; + + probs_padded = ggml_reshape_4d(ctx0, probs_padded, window_size, chunk, n_group, n_head); + probs_padded = ggml_pad_ext(ctx0, probs_padded, 0, chunk, 0, 0, 0, 0, 0, 0); + probs_padded = ggml_view_4d(ctx0, probs_padded, + n_kv_chunk, chunk, n_group, n_head, + (size_t) n_kv_chunk * probs_padded->nb[0], + probs_padded->nb[2], + probs_padded->nb[3], + 0); + probs_padded = ggml_cont(ctx0, probs_padded); + probs_padded = ggml_mul(ctx0, probs_padded, local_mask); + + // Add padding to front and back (for the first timeframe and the last timeframe). + struct ggml_tensor * V_padded = ggml_pad_ext(ctx0, V_cur, 0, 0, att_left, att_right, 0, 0, 0, 0); + + // pad time axis to match n_kv_dense if needed. + if (n_kv_dense > V_padded->ne[1]) { + V_padded = ggml_pad_ext(ctx0, V_padded, 0, 0, 0, n_kv_dense - V_padded->ne[1], 0, 0, 0, 0); + } + + V_padded = ggml_cont(ctx0, ggml_transpose(ctx0, V_padded)); + + struct ggml_tensor * V_chunk = ggml_view_4d(ctx0, V_padded, + n_kv_chunk, d_head, n_group, n_head, + V_padded->nb[1], + (size_t) chunk * V_padded->nb[0], + V_padded->nb[2], + 0); + V_chunk = ggml_cont(ctx0, V_chunk); + + cur = ggml_mul_mat(ctx0, V_chunk, probs_padded); + // ungroup. + cur = ggml_reshape_3d(ctx0, cur, d_head, n_time_padded, n_head); + // unpad + if (need_padding) { + cur = ggml_view_3d(ctx0, cur, d_head, n_time, n_head, cur->nb[1], cur->nb[2], 0); + } + + cur = ggml_cont(ctx0, ggml_permute(ctx0, cur, 0, 2, 1, 3)); + cur = ggml_reshape_2d(ctx0, cur, n_state, n_time); + cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur); + } else { + struct ggml_tensor * Q_u = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_u); + ggml_format_name(Q_u, "enc_%d_attn_q_u", il); + + struct ggml_tensor * K_prep = ggml_permute(ctx0, K_cur, 0, 2, 1, 3); + struct ggml_tensor * Q_prep = ggml_permute(ctx0, Q_u, 0, 2, 1, 3); + struct ggml_tensor * content_scores = ggml_mul_mat(ctx0, K_prep, Q_prep); + ggml_format_name(content_scores, "enc_%d_attn_content_scores", il); + + struct ggml_tensor * Q_v = ggml_add(ctx0, Q_cur, layer.attn_pos_bias_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v", il); + + Q_v = ggml_permute(ctx0, Q_v, 0, 2, 1, 3); + Q_v = ggml_cont(ctx0, Q_v); + ggml_format_name(Q_v, "enc_%d_attn_q_v_perm", il); + + struct ggml_tensor * rel_pos_scores = ggml_mul_mat(ctx0, pos, Q_v); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos", il); + + // Relative position shifting is performed in the following block. + // Some more details on the operations performed below can be found here: + // https://github.com/danbev/learning-ai/blob/main/notes/whisper/parakeet.md#relative-position-shift + { + const auto pos_window = rel_pos_scores->ne[0]; + const auto n_frame = rel_pos_scores->ne[1]; + const auto n_head_cur = rel_pos_scores->ne[2]; + + rel_pos_scores = ggml_pad(ctx0, rel_pos_scores, 1, 0, 0, 0); + rel_pos_scores = ggml_roll(ctx0, rel_pos_scores, 1, 0, 0, 0); + + rel_pos_scores = ggml_reshape_3d(ctx0, rel_pos_scores, n_frame, pos_window + 1, n_head_cur); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_reshaped", il); + + int center = pos_window / 2; + size_t offset = rel_pos_scores->nb[0] * (center+1); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + n_frame, pos_window, n_head_cur, + (pos_window) * 4, + rel_pos_scores->nb[2], + offset); + + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted", il); + + rel_pos_scores = ggml_view_3d(ctx0, rel_pos_scores, + content_scores->ne[0], + content_scores->ne[1], + rel_pos_scores->ne[2], + rel_pos_scores->nb[1], + rel_pos_scores->nb[2], + 0); + rel_pos_scores = ggml_cont(ctx0, rel_pos_scores); + ggml_format_name(rel_pos_scores, "enc_%d_attn_rel_pos_shifted_view", il); + } + + struct ggml_tensor * attn_scores = ggml_add(ctx0, content_scores, rel_pos_scores); + ggml_format_name(attn_scores, "enc_%d_attn_scores", il); + attn_scores = ggml_scale(ctx0, attn_scores, 1.0f / std::sqrt(d_head)); + attn_scores = ggml_add(ctx0, attn_scores, attn_mask); + ggml_format_name(attn_scores, "enc_%d_attn_scores_scaled", il); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, attn_scores); + ggml_format_name(probs, "enc_%d_attn_probs", il); + + V_cur = ggml_cont(ctx0, ggml_permute(ctx0, V_cur, 1, 2, 0, 3)); + ggml_format_name(V_cur, "enc_%d_attn_v_cur", il); + cur = ggml_mul_mat(ctx0, probs, V_cur); + ggml_format_name(cur, "enc_%d_attn_inp", il); + + cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); + cur = ggml_cont_2d(ctx0, cur, n_state, n_time); + cur = ggml_mul_mat(ctx0, layer.attn_out_w, cur); + } + ggml_format_name(cur, "enc_%d_attn_out", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_attn_res", il); + } + + // Convolution + { + struct ggml_tensor * residual = cur; + ggml_format_name(cur, "enc_%d_residual_conv", il); + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_conv_w), layer.norm_conv_b); + ggml_format_name(cur, "enc_%d_norm_conv", il); + + // pointwise 1d convolution: [1024, 138] -> [2048, 138] + cur = ggml_mul_mat(ctx0, layer.conv_pw1_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw1", il); + + { + int64_t d = cur->ne[0] / 2; + struct ggml_tensor * signal = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], 0); + struct ggml_tensor * gate = ggml_view_2d(ctx0, cur, d, cur->ne[1], cur->nb[1], d * cur->nb[0]); + + cur = ggml_mul(ctx0, signal, ggml_sigmoid(ctx0, gate)); + ggml_format_name(cur, "enc_%d_conv_glu", il); + } + + cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur)); + + // use ggml_ssm_conv for f32 precision + const int dw_pad = (hparams.n_conv_kernel - 1) / 2; + cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0); + cur = ggml_roll(ctx0, cur, dw_pad, 0, 0, 0); + cur = ggml_pad(ctx0, cur, dw_pad, 0, 0, 0); + ggml_format_name(cur, "enc_%d_conv_dw_pad", il); + + cur = ggml_ssm_conv(ctx0, cur, layer.conv_dw_w); + ggml_format_name(cur, "enc_%d_conv_1d_dw", il); + + cur = ggml_sub(ctx0, cur, layer.conv_bn_mean); + struct ggml_tensor * std = ggml_sqrt(ctx0, layer.conv_bn_var); + cur = ggml_div(ctx0, cur, std); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.conv_bn_w), layer.conv_bn_b); + ggml_format_name(cur, "enc_%d_conv_bn", il); + + cur = ggml_silu(ctx0, cur); + ggml_format_name(cur, "enc_%d_conv_silu", il); + + cur = ggml_mul_mat(ctx0, layer.conv_pw2_w, cur); + ggml_format_name(cur, "enc_%d_conv_pw2", il); + + cur = ggml_add(ctx0, residual, cur); + ggml_format_name(cur, "enc_%d_conv_res", il); + } + + // FFN2 + { + struct ggml_tensor * residual = cur; + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_ff2_w), layer.norm_ff2_b); + ggml_format_name(cur, "enc_%d_ffn_norm_2", il); + + cur = ggml_mul_mat(ctx0, layer.ff2_linear1_w, cur); + cur = ggml_silu(ctx0, cur); + cur = ggml_mul_mat(ctx0, layer.ff2_linear2_w, cur); + cur = ggml_add(ctx0, residual, ggml_scale(ctx0, cur, 0.5)); + ggml_format_name(cur, "enc_%d_ffn_res", il); + } + + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, layer.norm_out_w), layer.norm_out_b); + } + + ggml_set_name(cur, "encoder_out"); + pstate.n_frames = cur->ne[1]; + + struct ggml_tensor * enc_out_view = ggml_view_2d(ctx0, pstate.enc_out, n_state, pstate.n_frames, pstate.enc_out->nb[1], 0); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, cur, enc_out_view)); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_encode_internal( + parakeet_context & pctx, + parakeet_state & pstate, + const int mel_offset, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + auto & sched = pstate.sched_encode.sched; + + ggml_cgraph * gf = parakeet_build_graph_encode(pctx, pstate); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + // set mel input + { + struct ggml_tensor * mel = ggml_graph_get_tensor(gf, "mel"); + + const auto & mel_inp = pstate.mel; + const int n_ctx = pstate.n_audio_ctx > 0 ? pstate.n_audio_ctx : pctx.model.hparams.n_audio_ctx; + + assert(mel->type == GGML_TYPE_F32); + assert(mel_inp.n_mel == pctx.model.hparams.n_mels); + + pstate.inp_mel.resize(ggml_nelements(mel)); + + float * dst = pstate.inp_mel.data(); + memset(dst, 0, ggml_nbytes(mel)); + + const int i0 = std::min(mel_offset, mel_inp.n_len); + const int i1 = std::min(mel_offset + n_ctx, mel_inp.n_len); + + memcpy(dst, mel_inp.data.data() + i0 * mel_inp.n_mel, (i1 - i0) * mel_inp.n_mel * sizeof(float)); + + ggml_backend_tensor_set(mel, pstate.inp_mel.data(), 0, ggml_nelements(mel)*sizeof(float)); + } + + // set attention mask + { + struct ggml_tensor * attn_mask = ggml_graph_get_tensor(gf, "attn_mask"); + const int n_q = attn_mask->ne[1]; + const int n_k = attn_mask->ne[0]; + + const int32_t subsampl_factor = pctx.model.hparams.subsampling_factor; + const int n_tokens_real = (pstate.mel.n_len_org + subsampl_factor - 1) / subsampl_factor; + + std::vector<float> mask_data(n_q * n_k); + const float mask_value = -1e30f; + + if (n_k == n_q) { // full attention + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + mask_data[q * n_k + k] = (k >= n_tokens_real) ? mask_value : 0.0f; + } + } + } else { // local attention + const int att_left = n_k / 2; + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + const int key = q - att_left + k; + mask_data[q * n_k + k] = (key >= 0 && key < n_tokens_real) ? 0.0f : mask_value; + } + } + } + ggml_backend_tensor_set(attn_mask, mask_data.data(), 0, mask_data.size() * sizeof(float)); + } + + // set local attention skew mask + if (struct ggml_tensor * local_mask = ggml_graph_get_tensor(gf, "local_mask")) { + const int n_k = local_mask->ne[0]; + const int n_q = local_mask->ne[1]; + + std::vector<float> mask_data(n_q * n_k); + const int window_size = n_k - n_q + 1; + for (int q = 0; q < n_q; ++q) { + for (int k = 0; k < n_k; ++k) { + const int rel = k - q; + mask_data[q * n_k + k] = (rel >= 0 && rel < window_size) ? 1.0f : 0.0f; + } + } + ggml_backend_tensor_set(local_mask, mask_data.data(), 0, mask_data.size() * sizeof(float)); + } + + // set positional frequency + { + struct ggml_tensor * pos_freqs_t = ggml_graph_get_tensor(gf, "pos_freqs"); + const int d_half = pos_freqs_t->ne[0]; + const int n_state = pctx.model.hparams.n_audio_state; + const float log_10000 = logf(10000.0f); + std::vector<float> freqs(d_half); + for (int k = 0; k < d_half; ++k) { + freqs[k] = expf(-(float(k * 2) * log_10000 / float(n_state))); + } + ggml_backend_tensor_set(pos_freqs_t, freqs.data(), 0, freqs.size() * sizeof(float)); + } + + // set relative position offsets + { + struct ggml_tensor * rel_pos_t = ggml_graph_get_tensor(gf, "rel_positions"); + const int window_size = rel_pos_t->ne[1]; + std::vector<float> pos(window_size); + if (window_size == PARAKEET_LOCAL_ATTN_WINDOW * 2 + 1) { + for (int t = 0; t < window_size; ++t) { + pos[t] = float(PARAKEET_LOCAL_ATTN_WINDOW - t); + } + } else { + const int n_time = (window_size + 1) / 2; + for (int t = 0; t < window_size; ++t) { + pos[t] = float(n_time - 1 - t); + } + } + ggml_backend_tensor_set(rel_pos_t, pos.data(), 0, pos.size() * sizeof(float)); + } + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + pstate.t_encode_us += ggml_time_us() - t_start_us; + pstate.n_encode++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool parakeet_ensure_encode_sched( + parakeet_context & pctx, + parakeet_state & pstate, + int n_audio_ctx) { + if (pstate.sched_encode.sched && pstate.sched_encode_n_audio_ctx == n_audio_ctx) { + return true; + } + + parakeet_sched_free(pstate.sched_encode); + + const int32_t prev_n_audio_ctx = pstate.n_audio_ctx; + pstate.n_audio_ctx = n_audio_ctx; + + const int subsampl_factor = pctx.model.hparams.subsampling_factor; + const int n_frames_max = (n_audio_ctx + subsampl_factor - 1) / subsampl_factor; + if (n_frames_max > pstate.enc_out->ne[1]) { + ggml_backend_buffer_free(pstate.enc_out_buffer); + pstate.enc_out_buffer = nullptr; + pstate.enc_out = nullptr; + + if (!parakeet_enc_state_init(pstate, pstate.backends[0], pctx.model.hparams.n_audio_state, n_frames_max)) { + pstate.sched_encode_n_audio_ctx = 0; + pstate.n_audio_ctx = prev_n_audio_ctx; + return false; + } + } + + const bool ok = parakeet_sched_graph_init(pstate.sched_encode, pstate.backends, + [&]() { + return parakeet_build_graph_encode(pctx, pstate); + }); + + if (!ok) { + pstate.sched_encode_n_audio_ctx = 0; + pstate.n_audio_ctx = prev_n_audio_ctx; + return false; + } + + pstate.sched_encode_n_audio_ctx = n_audio_ctx; + return true; +} + +static struct ggml_tensor * parakeet_build_graph_lstm_layer( + struct ggml_context * ctx0, + struct ggml_cgraph * gf, + struct ggml_tensor * x_t, // the current input token embedding + struct ggml_tensor * w_ih, // input to hidden weights (4 weight tensors packed) + struct ggml_tensor * w_hh, // hidden to hidden weights (4 weight tensors packed) + struct ggml_tensor * b_h, // folded ih+hh bias (4 bias tensors packed) + struct ggml_tensor * h_state, // this layers hidden state + struct ggml_tensor * c_state, // this layers cell state + int li) { // layer index (for tensor naming) + + ggml_format_name(x_t, "lstm_layer_%d_x_t", li); + ggml_format_name(h_state, "lstm_layer_%d_h_state", li); + ggml_format_name(c_state, "lstm_layer_%d_c_state", li); + + // The 4 gates (i, f, o, c) are packed in the same weight tensor. + struct ggml_tensor * inp_gates = ggml_mul_mat(ctx0, w_ih, x_t); + + // Hidden-to-Hidden Projections are also packed in the same weight tensor. + // b_h holds the folded ih+hh bias (see parakeet_model_load), so it is + // the only bias that needs to be added here. + struct ggml_tensor * hid_gates = ggml_mul_mat(ctx0, w_hh, h_state); + hid_gates = ggml_add(ctx0, hid_gates, b_h); + + // Combine the input and hidden contributions of the gates. + struct ggml_tensor * gates = ggml_add(ctx0, inp_gates, hid_gates); + ggml_format_name(gates, "lstm_layer_%d_gates", li); + + const int h_dim = h_state->ne[0]; + const size_t row_size = ggml_row_size(gates->type, h_dim); + + // The gates are packed as [i, f, o, c] (reordered at convert time, see + // parakeet_model_load), so the three sigmoid-gated outputs (i, f, o) are + // contiguous and can be computed with a single ggml_sigmoid call. + struct ggml_tensor * ifo = ggml_sigmoid(ctx0, ggml_view_1d(ctx0, gates, 3 * h_dim, 0)); + ggml_format_name(ifo, "lstm_layer_%d_ifo", li); + + // 1. Input Gate at time t. + struct ggml_tensor * i_t = ggml_view_1d(ctx0, ifo, h_dim, 0 * row_size); + ggml_format_name(i_t, "lstm_layer_%d_i_t", li); + + // Forget gate. + struct ggml_tensor * f_t = ggml_view_1d(ctx0, ifo, h_dim, 1 * row_size); + ggml_format_name(f_t, "lstm_layer_%d_f_t", li); + + // Output gate. + struct ggml_tensor * o_t = ggml_view_1d(ctx0, ifo, h_dim, 2 * row_size); + ggml_format_name(o_t, "lstm_layer_%d_o_t", li); + + // Cell gate. + struct ggml_tensor * c_t = ggml_tanh(ctx0, ggml_view_1d(ctx0, gates, h_dim, 3 * row_size)); + ggml_format_name(c_t, "lstm_layer_%d_c_t", li); + + // Calculate the new cell state. + struct ggml_tensor * c_new = ggml_add(ctx0, + ggml_mul(ctx0, f_t, c_state), // apply forget gate to cell state. + ggml_mul(ctx0, i_t, c_t)); // apply input gate to cell gate. + ggml_build_forward_expand(gf, ggml_cpy(ctx0, c_new, c_state)); + + // Calculate the new hidden state. + struct ggml_tensor * h_new = ggml_mul(ctx0, o_t, ggml_tanh(ctx0, c_new)); + ggml_set_output(h_new); + ggml_format_name(h_new, "lstm_layer_%d_h_new", li); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, h_new, h_state)); + + return h_new; +} + +static struct ggml_cgraph * parakeet_build_graph_prediction( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + bool worst_case) { + GGML_UNUSED(worst_case); + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_decode.meta.size(), + /*.mem_buffer =*/ pstate.sched_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + // Prediction Network + struct ggml_tensor * token = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + ggml_set_name(token, "token_inp"); + ggml_set_input(token); + + struct ggml_tensor * token_embd = ggml_get_rows(ctx0, model.prediction.embed_w, token); + + struct ggml_tensor * inpL = token_embd; + + for (int il = 0; il < hparams.n_pred_layers; ++il) { + inpL = parakeet_build_graph_lstm_layer(ctx0, gf, inpL, + model.prediction.lstm_layer[il].ih_w, + model.prediction.lstm_layer[il].hh_w, + model.prediction.lstm_layer[il].b_h, + pstate.lstm_state.layer[il].h_state, + pstate.lstm_state.layer[il].c_state, + il); + } + + struct ggml_tensor * pred_out = inpL; + ggml_format_name(pred_out, "lstm_pred_out"); + + // Project the prediction network output to the joint network hidden dimension. + struct ggml_tensor * pred = ggml_mul_mat(ctx0, model.joint.pred_w, pred_out); + pred = ggml_add(ctx0, pred, model.joint.pred_b); + ggml_set_name(pred, "h_pred"); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, pred, pstate.pred_out)); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * parakeet_build_graph_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + bool worst_case) { + GGML_UNUSED(worst_case); + const auto & model = pctx.model; + const auto & hparams = model.hparams; + + struct ggml_init_params params = { + /*.mem_size =*/ pstate.sched_decode.meta.size(), + /*.mem_buffer =*/ pstate.sched_decode.meta.data(), + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, PARAKEET_MAX_NODES, false); + + struct ggml_tensor * pred = pstate.pred_out; + ggml_format_name(pred, "pred"); + + const int t_idx = batch.i_time[0]; + struct ggml_tensor * enc_out = ggml_view_1d(ctx0, pstate.enc_out, hparams.n_audio_state, + (size_t) t_idx * pstate.enc_out->nb[1]); + ggml_format_name(enc_out, "enc_out_view"); + + // Project the encoder output to the joint network hidden dimension. + struct ggml_tensor * enc = ggml_mul_mat(ctx0, model.joint.enc_w, enc_out); + enc = ggml_add(ctx0, enc, model.joint.enc_b); + ggml_set_name(enc, "enc"); + + struct ggml_tensor * joint = ggml_add(ctx0, enc, pred); + ggml_set_name(joint, "joint"); + joint = ggml_relu(ctx0, joint); + + struct ggml_tensor * logits = ggml_mul_mat(ctx0, model.joint.net_w, joint); + logits = ggml_add(ctx0, logits, model.joint.net_b); + ggml_set_output(logits); + ggml_set_name(logits, "logits"); + + struct ggml_tensor * probs = ggml_soft_max(ctx0, logits); + struct ggml_tensor * log_probs = ggml_log(ctx0, probs); + ggml_set_output(log_probs); + ggml_format_name(log_probs, "log_probs"); + + ggml_build_forward_expand(gf, log_probs); + + ggml_free(ctx0); + + return gf; +} + +static bool parakeet_predict( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + + const int n_tokens = batch.n_tokens; + + const int64_t t_start_us = ggml_time_us(); + + { + auto & sched = pstate.sched_decode.sched; + + const int64_t t_build_start_us = ggml_time_us(); + ggml_cgraph * gf = parakeet_build_graph_prediction(pctx, pstate, batch, false); + pstate.t_predict_build_us += ggml_time_us() - t_build_start_us; + + const int64_t t_alloc_start_us = ggml_time_us(); + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + pstate.t_predict_alloc_us += ggml_time_us() - t_alloc_start_us; + + // set the inputs + { + struct ggml_tensor * token_inp = ggml_graph_get_tensor(gf, "token_inp"); + ggml_backend_tensor_set(token_inp, batch.token, 0, n_tokens * ggml_element_size(token_inp)); + } + + const int64_t t_compute_start_us = ggml_time_us(); + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + pstate.t_predict_compute_us += ggml_time_us() - t_compute_start_us; + } + + pstate.t_predict_us += ggml_time_us() - t_start_us; + pstate.n_predict++; + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool parakeet_joint( + parakeet_context & pctx, + parakeet_state & pstate, + const parakeet_batch & batch, + const int n_threads, + ggml_abort_callback abort_callback, + void * abort_callback_data) { + const int64_t t_start_us = ggml_time_us(); + + const auto & model = pctx.model; + const auto & hparams = model.hparams; + const int n_tokens = batch.n_tokens; + + auto & logits_out = pstate.logits; + + struct ggml_tensor * logits; + + { + auto & sched = pstate.sched_decode.sched; + + ggml_cgraph * gf = parakeet_build_graph_joint(pctx, pstate, batch, false); + + if (!ggml_backend_sched_alloc_graph(sched, gf)) { + // should never happen as we pre-allocate the memory + return false; + } + + logits = ggml_graph_node(gf, -1); + + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { + return false; + } + + } + + const int n_logits = hparams.n_vocab + hparams.n_tdt_durations + 1; // one for the blank token + logits_out.resize(n_tokens * n_logits); + for (int i = 0; i < n_tokens; i++) { + if (batch.logits[i] == 0) { + continue; + } + ggml_backend_tensor_get(logits, logits_out.data() + (n_logits*i), sizeof(float)*(n_logits*i), sizeof(float)*n_logits); + } + + if (batch.n_tokens == 1) { + pstate.t_decode_us += ggml_time_us() - t_start_us; + pstate.n_decode++; + } + + return !(abort_callback && abort_callback(abort_callback_data)); +} + +static bool is_word_start_token(parakeet_vocab & vocab, parakeet_token token_id) { + const std::string & token_str = vocab.id_to_token[token_id]; + // check if it starts with the SentencePiece meta-space "▁" (U+2581) or 3-byte UTF-8 character: 0xE2 0x96 0x81 + if (!token_str.empty()) { + if (token_str.find("\xE2\x96\x81") == 0 || token_str[0] == '_') { + return true; + } + } + return false; +} + +static bool is_punctuation_token(parakeet_vocab & vocab, parakeet_token token_id) { + const std::string & token_str = vocab.id_to_token[token_id]; + static const std::string punct_chars = ".,!?;:'\"-()[]{}"; + + if (token_str.empty()) { + return false; + } + + std::string clean_token = token_str; + if (clean_token.find("\xE2\x96\x81") == 0) { + clean_token = clean_token.substr(3); // Remove the 3-byte UTF-8 character + } else if (clean_token[0] == '_') { + clean_token = clean_token.substr(1); + } + + return clean_token.length() == 1 && punct_chars.find(clean_token[0]) != std::string::npos; +} + +// Collapse punctuation timestamps to match the original Parakeet model. +// Punctuations symbols like ',', '.' and others are not spoken words but the +// model will still produce a duration for these tokens. But since these are +// non-spoken we collapse the timestamps so that they don't have an time duration. +static void refine_timestamps_tdt(parakeet_vocab & vocab, std::vector<parakeet_token_data> & tokens) { + if (tokens.empty()) { + return; + } + + int64_t last_non_punct_t1 = -1; + + for (size_t i = 0; i < tokens.size(); ++i) { + if (is_punctuation_token(vocab, tokens[i].id)) { + if (last_non_punct_t1 >= 0) { + tokens[i].t0 = last_non_punct_t1; + tokens[i].t1 = last_non_punct_t1; + } + } else { + last_non_punct_t1 = tokens[i].t1; + } + } +} + +static parakeet_token_data create_token_data( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_token token_id, + int duration_idx, + int duration_value, + int frame_index, + float token_logit, + int n_vocab_logits) { + + float token_sum = 0.0f; + for (int i = 0; i < n_vocab_logits; ++i) { + token_sum += expf(pstate.logits[i]); + } + float token_p = expf(token_logit) / token_sum; + + parakeet_token_data token_data; + token_data.id = token_id; + token_data.duration_idx = duration_idx; + token_data.duration_value = duration_value; + token_data.frame_index = frame_index; + token_data.p = token_p; + token_data.plog = token_logit; + token_data.t0 = frame_index * pctx.model.hparams.subsampling_factor; + token_data.t1 = (frame_index + duration_value) * pctx.model.hparams.subsampling_factor; + token_data.is_word_start = is_word_start_token(pctx.vocab, token_id); + + return token_data; +} + +static bool parakeet_decode( + parakeet_context & pctx, + parakeet_state & pstate, + parakeet_batch & batch, + const int n_threads, + const parakeet_full_params * params = nullptr) { + const auto & hparams = pctx.model.hparams; + const auto & tdt_durations = pctx.model.tdt_durations; + + const int n_tdt_durations = hparams.n_tdt_durations; + const int n_frames = pstate.n_frames; + const int blank_id = pctx.vocab.token_blank; + const int n_vocab_logits = blank_id + 1; + const int max_tokens_per_timestep = hparams.n_max_tokens; + + // time index into the encoder frame (current time frame) + int t = 0; + // number of symbols emitted for the current time frame + int tokens_emitted = 0; + + // Start with the blank token (8192) + parakeet_token last_token = blank_id; + + PARAKEET_LOG_DEBUG("parakeet_decode: starting decode with n_frames=%d\n", n_frames); + + batch.n_tokens = 1; + batch.token[0] = last_token; + batch.logits[0] = 1; + batch.i_time[0] = 0; + + // run the prediction network for the initial blank token. This will + // initialize the LSTM state and produce an initial hidden state that can + // be used in the joint network below. + if (!parakeet_predict(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + // process all time frames of the encoder output + while (t < n_frames) { + batch.n_tokens = 1; + batch.i_time[0] = t; + batch.logits[0] = 1; + + // Use the current encoder frame (t) and the output of the prediction to + // generate probabilities for the next token and duration. batch.i_time + // is used in to select the correct frame from the encoder output. + // The joint network outputs logits for all the tokens in the vocabulary + // plus the blank token, and also n_duration logits for the duration + // tokens which contain information about how many frames to skip/advance forward. + if (!parakeet_joint(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + const int64_t t_start_sample_us = ggml_time_us(); + + // find the best token (greedy). + // TODO: implement beam search? + int best_token = 0; + float max_logit = -1e10f; + for (int i = 0; i < n_vocab_logits; ++i) { + if (pstate.logits[i] > max_logit) { + max_logit = pstate.logits[i]; + best_token = i; + } + } + + // find the max index of the duration logits, and look up that index + // value in the tdt_durations array to get the actual duration value. + int best_duration_idx = 0; + float best_duration_logit = -1e10f; + for (int i = 0; i < n_tdt_durations; ++i) { + if (pstate.logits[n_vocab_logits + i] > best_duration_logit) { + best_duration_logit = pstate.logits[n_vocab_logits + i]; + best_duration_idx = i; + } + } + // look up that max duration index value in the tdt_durations array to + // get the actual duration value. + int duration = tdt_durations[best_duration_idx]; + + if (best_token == blank_id) { + if (duration == 0) { + duration = 1; + } + // skip forward by duration time frames. + t += duration; + // reset symbols emitted counter + tokens_emitted = 0; + // continue without predicting. + continue; + } + + // Emit non-blank token at current frame t. + pstate.decoded_tokens.push_back(best_token); + pstate.t_sample_us += ggml_time_us() - t_start_sample_us; + pstate.n_sample++; + + parakeet_token_data token_data = create_token_data( + pctx, pstate, best_token, best_duration_idx, duration, t, + max_logit, n_vocab_logits); + + pstate.decoded_token_data.push_back(token_data); + + // Call token callback if registered (for real-time streaming) + if (params && params->new_token_callback) { + params->new_token_callback(&pctx, &pstate, &token_data, params->new_token_callback_user_data); + } + + last_token = best_token; + + // advance predictor for the non-blank token. + batch.token[0] = last_token; + if (!parakeet_predict(pctx, pstate, batch, n_threads, + params ? params->abort_callback : nullptr, + params ? params->abort_callback_user_data : nullptr)) { + return false; + } + + // if duration greater than 0, continue looping over the encoder frames + // and skip to the updated time frame (t). + if (duration > 0) { + t += duration; + tokens_emitted = 0; + continue; + } + + // if duration is zero we stay on the current time frame. + tokens_emitted++; + if (tokens_emitted >= max_tokens_per_timestep) { + t += 1; // forced blank/time advance behavior + tokens_emitted = 0; + } + } + + return true; +} + +// 500 -> 00:05.000 +// 6000 -> 01:00.000 +// naive Discrete Fourier Transform +// input is real-valued +// output is complex-valued +static void dft(const float* in, int N, float* out, const parakeet_mel_cache & cache) { + const int sin_cos_step = cache.n_fft / N; + + for (int k = 0; k < N; k++) { + float re = 0; + float im = 0; + + for (int n = 0; n < N; n++) { + int idx = (k * n * sin_cos_step) % cache.n_fft; // t = 2*M_PI*k*n/N + re += in[n]*cache.cos_vals[idx]; // cos(t) + im -= in[n]*cache.sin_vals[idx]; // sin(t) + } + + out[k*2 + 0] = re; + out[k*2 + 1] = im; + } +} + +// Cooley-Tukey FFT +// poor man's implementation - use something better +// input is real-valued +// output is complex-valued +static void fft(float* in, int N, float* out, const parakeet_mel_cache & cache) { + if (N == 1) { + out[0] = in[0]; + out[1] = 0; + return; + } + + const int half_N = N / 2; + if (N - half_N*2 == 1) { + dft(in, N, out, cache); + return; + } + + float* even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i]= in[2*i]; + } + float* even_fft = out + 2 * N; + fft(even, half_N, even_fft, cache); + + float* odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2*i + 1]; + } + float* odd_fft = even_fft + N; + fft(odd, half_N, odd_fft, cache); + + const int sin_cos_step = cache.n_fft / N; + for (int k = 0; k < half_N; k++) { + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; // cos(t) + float im = -cache.sin_vals[idx]; // sin(t) + + float re_odd = odd_fft[2*k + 0]; + float im_odd = odd_fft[2*k + 1]; + + out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; + out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + + out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; + out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + } +} + +struct mel_worker_params { + int ith; + int window_size; + int n_samples; + int frame_size; + int frame_step; + int n_threads; +}; + +static void log_mel_spectrogram_worker_thread( + mel_worker_params params, + const float * window_func, + const std::vector<float> & samples, + const parakeet_filters & filters, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + std::vector<float> fft_in(params.frame_size * 2, 0.0); + std::vector<float> fft_out(params.frame_size * 2 * 2 * 2); + + int n_fb = filters.n_fb; // number of frequency bins + int i = params.ith; + + // make sure n_fb == 1 + (frame_size / 2), bin_0 to bin_nyquist + assert(n_fb == 1 + (params.frame_size / 2)); + + const double eps = 5.960464477539063e-08; + + // calculate FFT only when fft_in are not all zero + for (; i < std::min(params.n_samples / params.frame_step + 1, mel.n_len); i += params.n_threads) { + const int offset = i * params.frame_step; + + const int window_pad_left = (params.frame_size - params.window_size) / 2; + + // Zero-pad left + std::fill(fft_in.begin(), fft_in.begin() + window_pad_left, 0.0f); + + // Apply windowed samples in the center + const int n_to_process = std::min({params.window_size, params.n_samples - offset}); + for (int j = 0; j < n_to_process; j++) { + fft_in[window_pad_left + j] = window_func[j] * samples[offset + window_pad_left + j]; + } + + // Zero-pad right (and any samples we didn't have) + std::fill(fft_in.begin() + window_pad_left + n_to_process, fft_in.begin() + params.frame_size, 0.0f); + + // FFT + fft(fft_in.data(), params.frame_size, fft_out.data(), cache); + + // Calculate modulus^2 of complex numbers + // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. + for (int j = 0; j < n_fb; j++) { + fft_out[j] = (fft_out[2 * j + 0] * fft_out[2 * j + 0] + fft_out[2 * j + 1] * fft_out[2 * j + 1]); + } + + // mel spectrogram + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + // unroll loop (suggested by GH user @lunixbochs) + int k = 0; + for (k = 0; k < n_fb - 3; k += 4) { + sum += + fft_out[k + 0] * filters.data[j * n_fb + k + 0] + + fft_out[k + 1] * filters.data[j * n_fb + k + 1] + + fft_out[k + 2] * filters.data[j * n_fb + k + 2] + + fft_out[k + 3] * filters.data[j * n_fb + k + 3]; + } + // handle n_fb remainder + for (; k < n_fb; k++) { + sum += fft_out[k] * filters.data[j * n_fb + k]; + } + + mel.data[i * mel.n_mel + j] = std::log(sum + eps); + } + } + + // Otherwise fft_out are all zero - use log(eps) for consistency + const double empty_sum = std::log(eps); + for (; i < mel.n_len; i += params.n_threads) { + for (int j = 0; j < mel.n_mel; j++) { + mel.data[i * mel.n_mel + j] = empty_sum; + } + } +} + +static bool log_mel_spectrogram( + parakeet_state & wstate, + const float * samples, + const int n_samples, + const int /*sample_rate*/, + const int frame_size, + const int frame_step, + const int n_mel, + const int n_threads, + const parakeet_filters & filters, + const bool debug, + parakeet_mel & mel, + const parakeet_mel_cache & cache) { + const int64_t t_start_us = ggml_time_us(); + + const float * window_func = cache.window.empty() ? cache.hann_window.data() : cache.window.data(); + const int window_size = cache.window.empty() ? cache.n_fft : cache.window.size(); + + std::vector<float> samples_preprocessed(samples, samples + n_samples); + + // Apply preemphasis filter (high-pass): x[i] = x[i] - 0.97 * x[i-1] + { + const float preemph = 0.97f; + for (int i = n_samples - 1; i > 0; i--) { + samples_preprocessed[i] = samples_preprocessed[i] - preemph * samples_preprocessed[i - 1]; + } + } + + // Parakeet Pytorch implementation uses centered contant padding. + const size_t pad = (size_t)(frame_size / 2); + std::vector<float> samples_padded(n_samples + 2 * pad, 0.0f); + std::copy(samples_preprocessed.begin(), samples_preprocessed.end(), samples_padded.begin() + pad); + + mel.n_mel = n_mel; + mel.n_len = (samples_padded.size() - frame_size) / frame_step + 1; + mel.n_len_org = mel.n_len; + mel.data.resize(mel.n_mel * mel.n_len); + + // Worker Threads (STFT + Mel + Natural Log) + { + std::vector<std::thread> workers(n_threads - 1); + const mel_worker_params mel_params { 0, window_size, (int)samples_padded.size(), frame_size, frame_step, n_threads }; + + for (int iw = 0; iw < n_threads - 1; ++iw) { + mel_worker_params params = mel_params; + params.ith = iw + 1; + workers[iw] = std::thread(log_mel_spectrogram_worker_thread, + params, + window_func, + std::cref(samples_padded), + std::cref(filters), + std::ref(mel), + std::cref(cache)); + } + + log_mel_spectrogram_worker_thread( + mel_params, + window_func, + samples_padded, + filters, + mel, + cache); + + for (int iw = 0; iw < n_threads - 1; ++iw) { + workers[iw].join(); + } + } + + { + const double eps = 1e-5; + int valid_frames = n_samples / frame_step; + + for (int j = 0; j < mel.n_mel; j++) { + double sum = 0.0; + double sq_diff_sum = 0.0; + + // Calculate Mean ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + sum += (double)mel.data[i * mel.n_mel + j]; + } + double mean = sum / valid_frames; + + // Calculate Variance ONLY on valid audio frames + for (int i = 0; i < valid_frames; i++) { + double diff = (double)mel.data[i * mel.n_mel + j] - mean; + sq_diff_sum += diff * diff; + } + + double std_dev = std::sqrt(sq_diff_sum / (valid_frames - 1.0)); + double denominator = std_dev + eps; + + // Apply to ALL frames (including the padded ones) + for (int i = 0; i < mel.n_len; i++) { + mel.data[i * mel.n_mel + j] = (float)((mel.data[i * mel.n_mel + j] - mean) / denominator); + } + } + } + + wstate.t_mel_us += ggml_time_us() - t_start_us; + + if (debug) { + std::ofstream outFile("log_mel_spectrogram.json"); + outFile << "["; + for (uint64_t i = 0; i < mel.data.size() - 1; i++) { + outFile << mel.data[i] << ", "; + } + outFile << mel.data[mel.data.size() - 1] << "]"; + outFile.close(); + } + + return true; +} + +static std::vector<parakeet_vocab::id> tokenize(const parakeet_vocab & vocab, const std::string & text) { + std::vector<parakeet_vocab::id> tokens; + const std::string normalized = sentencepiece_normalize(text); + + size_t i = 0; + while (i < normalized.size()) { + const size_t remaining = normalized.size() - i; + const size_t max_len = std::min(vocab.max_token_length, remaining); + + bool found = false; + for (size_t len = max_len; len > 0; --len) { + const auto it = vocab.token_to_id.find(normalized.substr(i, len)); + if (it != vocab.token_to_id.end() && !is_sentencepiece_control(it->first)) { + tokens.push_back(it->second); + i += len; + found = true; + break; + } + } + + if (!found) { + if (vocab.token_unk >= 0) { + tokens.push_back(vocab.token_unk); + } + + const unsigned char c = static_cast<unsigned char>(normalized[i]); + i += utf8_codepoint_len(c); + } + } + + return tokens; +} + + +// +// interface implementation +// + +struct parakeet_state * parakeet_init_state(parakeet_context * ctx) { + parakeet_state * state = new parakeet_state; + + state->backends = parakeet_backend_init(ctx->params); + if (state->backends.empty()) { + PARAKEET_LOG_ERROR("%s: parakeet_backend_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + const int batch_size = ctx->model.hparams.n_audio_ctx; + + state->logits.reserve(ctx->vocab.n_vocab * batch_size); + + state->batch = parakeet_batch_init(batch_size); + + { + const int n_audio_state = ctx->model.hparams.n_audio_state; + const int subsampl_factor = ctx->model.hparams.subsampling_factor; + const int n_frames_max = (batch_size + subsampl_factor - 1) / subsampl_factor; + + if (!parakeet_enc_state_init(*state, state->backends[0], n_audio_state, n_frames_max)) { + PARAKEET_LOG_ERROR("%s: parakeet_enc_state_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + const size_t mem_enc_ctx = state->enc_out_buf.size(); + const size_t mem_enc_out_buf = ggml_backend_buffer_get_size(state->enc_out_buffer); + PARAKEET_LOG_INFO("%s: enc_out state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_enc_ctx / 1024.0 / 1024.0, mem_enc_out_buf / 1024.0 / 1024.0); + } + + // conv/encoder allocator + bool ok = parakeet_sched_graph_init(state->sched_encode, state->backends, + [&]() { + return parakeet_build_graph_encode(*ctx, *state); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init encode allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + state->sched_encode_n_audio_ctx = state->n_audio_ctx > 0 ? state->n_audio_ctx : ctx->model.hparams.n_audio_ctx; + + if (!parakeet_lstm_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_layers, ctx->model.hparams.n_pred_dim)) { + PARAKEET_LOG_ERROR("%s: parakeet_lstm_states_init () failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_lstm_ctx = state->lstm_state.ctx_buf.size(); + const size_t mem_lstm_buf = ggml_backend_buffer_get_size(state->lstm_state.buffer); + PARAKEET_LOG_INFO("%s: lstm state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_lstm_ctx / 1024.0 / 1024.0, mem_lstm_buf / 1024.0 / 1024.0); + } + + if (!parakeet_pred_state_init(*state, state->backends[0], ctx->model.hparams.n_pred_dim)) { + PARAKEET_LOG_ERROR("%s: parakeet_pred_state_init() failed\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + { + const size_t mem_pred_ctx = state->pred_out_buf.size(); + const size_t mem_pred_out_buf = ggml_backend_buffer_get_size(state->pred_out_buffer); + PARAKEET_LOG_INFO("%s: pred state: %7.2f MB (meta) + %7.2f MB (data)\n", __func__, + mem_pred_ctx / 1024.0 / 1024.0, mem_pred_out_buf / 1024.0 / 1024.0); + } + + PARAKEET_LOG_INFO("%s: compute buffer (encode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_encode) / 1e6); + + { + bool ok = parakeet_sched_graph_init(state->sched_decode, state->backends, + [&]() { + const auto & hparams = ctx->model.hparams; + const int n_tokens = hparams.n_audio_ctx; // Use audio ctx for Parakeet + + parakeet_batch_prep_legacy(state->batch, nullptr, n_tokens, 0, 0); + + return parakeet_build_graph_prediction(*ctx, *state, state->batch, true); + }); + + if (!ok) { + PARAKEET_LOG_ERROR("%s: failed to init decoder allocator\n", __func__); + parakeet_free_state(state); + return nullptr; + } + + PARAKEET_LOG_INFO("%s: compute buffer (decode) = %7.2f MB\n", __func__, parakeet_sched_size(state->sched_decode) / 1e6); + } + + return state; +} + +struct parakeet_context_params parakeet_context_default_params() { + struct parakeet_context_params result = { + /*.use_gpu =*/ true, + /*.gpu_device =*/ 0, + }; + return result; +} + +struct parakeet_context * parakeet_init_from_file_with_params_no_state(const char * path_model, struct parakeet_context_params params) { + PARAKEET_LOG_INFO("%s: loading model from '%s'\n", __func__, path_model); +#ifdef _MSC_VER + // Convert UTF-8 path to wide string (UTF-16) for Windows, resolving character encoding issues. + std::wstring_convert<std::codecvt_utf8<wchar_t>> converter; + std::wstring path_model_wide = converter.from_bytes(path_model); + auto fin = std::ifstream(path_model_wide, std::ios::binary); +#else + auto fin = std::ifstream(path_model, std::ios::binary); +#endif + if (!fin) { + PARAKEET_LOG_ERROR("%s: failed to open '%s'\n", __func__, path_model); + return nullptr; + } + + parakeet_model_loader loader = {}; + + loader.context = &fin; + + loader.read = [](void * ctx, void * output, size_t read_size) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->read((char *)output, read_size); + return read_size; + }; + + loader.eof = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + return fin->eof(); + }; + + loader.close = [](void * ctx) { + std::ifstream * fin = (std::ifstream*)ctx; + fin->close(); + }; + + auto ctx = parakeet_init_with_params_no_state(&loader, params); + + if (ctx) { + ctx->path_model = path_model; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct parakeet_context_params params) { + struct buf_context { + uint8_t* buffer; + size_t size; + size_t current_offset; + }; + + buf_context ctx = { reinterpret_cast<uint8_t*>(buffer), buffer_size, 0 }; + + PARAKEET_LOG_INFO("%s: loading model from buffer\n", __func__); + + parakeet_model_loader loader = {}; + + loader.context = &ctx; + + loader.read = [](void * ctx, void * output, size_t read_size) { + buf_context * buf = reinterpret_cast<buf_context *>(ctx); + + size_t size_to_copy = buf->current_offset + read_size < buf->size ? read_size : buf->size - buf->current_offset; + + memcpy(output, buf->buffer + buf->current_offset, size_to_copy); + buf->current_offset += size_to_copy; + + return size_to_copy; + }; + + loader.eof = [](void * ctx) { + buf_context * buf = reinterpret_cast<buf_context *>(ctx); + + return buf->current_offset >= buf->size; + }; + + loader.close = [](void * /*ctx*/) { }; + + return parakeet_init_with_params_no_state(&loader, params); +} + +struct parakeet_context * parakeet_init_with_params_no_state(struct parakeet_model_loader * loader, struct parakeet_context_params params) { + ggml_time_init(); + + PARAKEET_LOG_INFO("%s: use gpu = %d\n", __func__, params.use_gpu); + PARAKEET_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device); + PARAKEET_LOG_INFO("%s: devices = %zu\n", __func__, ggml_backend_dev_count()); + PARAKEET_LOG_INFO("%s: backends = %zu\n", __func__, ggml_backend_reg_count()); + + parakeet_context * ctx = new parakeet_context; + ctx->params = params; + + bool model_loaded = false; + try { + model_loaded = parakeet_model_load(loader, *ctx); + } catch (const std::exception & e) { + PARAKEET_LOG_ERROR("%s: exception during model load: %s\n", __func__, e.what()); + } catch (...) { + PARAKEET_LOG_ERROR("%s: unknown exception during model load\n", __func__); + } + + if (!model_loaded) { + loader->close(loader->context); + PARAKEET_LOG_ERROR("%s: failed to load model\n", __func__); + delete ctx; + return nullptr; + } + + loader->close(loader->context); + + // Initialize mel cache with model's FFT size + ctx->mel_cache.init(ctx->model.hparams.n_fft); + PARAKEET_LOG_INFO("%s: initialized mel cache with n_fft = %d\n", __func__, ctx->model.hparams.n_fft); + + return ctx; +} + +struct parakeet_context * parakeet_init_from_file_with_params(const char * path_model, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_from_file_with_params_no_state(path_model, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_from_buffer_with_params(void * buffer, size_t buffer_size, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_from_buffer_with_params_no_state(buffer, buffer_size, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +struct parakeet_context * parakeet_init_with_params(struct parakeet_model_loader * loader, struct parakeet_context_params params) { + parakeet_context * ctx = parakeet_init_with_params_no_state(loader, params); + if (!ctx) { + return nullptr; + } + + ctx->state = parakeet_init_state(ctx); + if (!ctx->state) { + parakeet_free(ctx); + return nullptr; + } + + return ctx; +} + +void parakeet_free_state(struct parakeet_state * state) { + if (state) { + ggml_backend_buffer_free(state->lstm_state.buffer); + ggml_backend_buffer_free(state->pred_out_buffer); + ggml_backend_buffer_free(state->enc_out_buffer); + + parakeet_batch_free(state->batch); + + parakeet_sched_free(state->sched_encode); + parakeet_sched_free(state->sched_decode); + + for (auto & backend : state->backends) { + ggml_backend_free(backend); + } + + delete state; + } +} + +void parakeet_free(struct parakeet_context * ctx) { + if (ctx) { + for (ggml_context * context : ctx->model.ctxs) { + ggml_free(context); + } + + for (ggml_backend_buffer_t buf : ctx->model.buffers) { + ggml_backend_buffer_free(buf); + } + + parakeet_free_state(ctx->state); + + delete ctx; + } +} + +void parakeet_free_context_params(struct parakeet_context_params * params) { + if (params) { + delete params; + } +} + +void parakeet_free_params(struct parakeet_full_params * params) { + if (params) { + delete params; + } +} + +int parakeet_pcm_to_mel_with_state(struct parakeet_context * ctx, struct parakeet_state * state, const float * samples, int n_samples, int n_threads) { + if (!log_mel_spectrogram(*state, + samples, + n_samples, + PARAKEET_SAMPLE_RATE, + ctx->model.hparams.n_fft, + PARAKEET_HOP_LENGTH, + ctx->model.filters.n_mel, + n_threads, + ctx->model.filters, + false, // debug + state->mel, + ctx->mel_cache)) { + PARAKEET_LOG_ERROR("%s: failed to compute mel spectrogram\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_pcm_to_mel(struct parakeet_context * ctx, const float * samples, int n_samples, int n_threads) { + return parakeet_pcm_to_mel_with_state(ctx, ctx->state, samples, n_samples, n_threads); +} + +int parakeet_set_mel_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + const float * data, + int n_len, + int n_mel) { + if (n_mel != ctx->model.filters.n_mel) { + PARAKEET_LOG_ERROR("%s: invalid number of mel bands: %d (expected %d)\n", __func__, n_mel, ctx->model.filters.n_mel); + return -1; + } + + state->mel.n_len = n_len; + state->mel.n_len_org = n_len; + state->mel.n_mel = n_mel; + + state->mel.data.resize(n_len*n_mel); + memcpy(state->mel.data.data(), data, n_len*n_mel*sizeof(float)); + + return 0; +} + +int parakeet_set_mel( + struct parakeet_context * ctx, + const float * data, + int n_len, + int n_mel) { + return parakeet_set_mel_with_state(ctx, ctx->state, data, n_len, n_mel); +} + +int parakeet_encode_with_state(struct parakeet_context * ctx, struct parakeet_state * state, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_encode(struct parakeet_context * ctx, int offset, int n_threads) { + if (!parakeet_encode_internal(*ctx, *ctx->state, offset, n_threads, nullptr, nullptr)) { + PARAKEET_LOG_ERROR("%s: failed to eval\n", __func__); + return -1; + } + + return 0; +} + +int parakeet_tokenize(struct parakeet_context * ctx, const char * text, parakeet_token * tokens, int n_max_tokens) { + const auto res = tokenize(ctx->vocab, text); + + if (n_max_tokens < (int) res.size()) { + PARAKEET_LOG_ERROR("%s: too many resulting tokens: %d (max %d)\n", __func__, (int) res.size(), n_max_tokens); + return -(int) res.size(); + } + + for (int i = 0; i < (int) res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); +} + +int parakeet_token_count(struct parakeet_context * ctx, const char * text) { + return -parakeet_tokenize(ctx, text, NULL, 0); +} + +int parakeet_model_n_vocab(struct parakeet_context * ctx) { + return ctx->model.hparams.n_vocab; +} + +int parakeet_model_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +int parakeet_model_n_audio_state(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_state; +} + +int parakeet_model_n_audio_head(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_head; +} + +int parakeet_model_n_audio_layer(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_layer; +} + +int parakeet_model_n_mels(struct parakeet_context * ctx) { + return ctx->model.hparams.n_mels; +} + +int parakeet_model_ftype(struct parakeet_context * ctx) { + return ctx->model.hparams.ftype; +} + +int parakeet_n_len_from_state(struct parakeet_state * state) { + return state->mel.n_len_org; +} + +int parakeet_n_len(struct parakeet_context * ctx) { + return ctx->state->mel.n_len_org; +} + +int parakeet_n_vocab(struct parakeet_context * ctx) { + return ctx->vocab.n_vocab; +} + +int parakeet_n_audio_ctx(struct parakeet_context * ctx) { + return ctx->model.hparams.n_audio_ctx; +} + +float * parakeet_get_logits(struct parakeet_context * ctx) { + return ctx->state->logits.data(); +} + +float * parakeet_get_logits_from_state(struct parakeet_state * state) { + return state->logits.data(); +} + +const char * parakeet_token_to_str(struct parakeet_context * ctx, parakeet_token token) { + return ctx->vocab.id_to_token.at(token).c_str(); +} + +int parakeet_token_to_text(const char * token_str, bool is_first, char * output, int max_len) { + std::string text = sentencepiece_piece_to_text(token_str, is_first); + + if (output == nullptr) { + return text.size(); + } + + int bytes_to_copy = std::min((int)text.size(), max_len - 1); + if (bytes_to_copy > 0) { + memcpy(output, text.c_str(), bytes_to_copy); + output[bytes_to_copy] = '\0'; + } else if (max_len > 0) { + output[0] = '\0'; + } + + return text.size(); +} + +parakeet_token parakeet_token_bos(struct parakeet_context * ctx) { + return ctx->vocab.token_bos; +} + +parakeet_token parakeet_token_unk(struct parakeet_context * ctx) { + return ctx->vocab.token_unk; +} + +parakeet_token parakeet_token_blank(struct parakeet_context * ctx) { + return ctx->vocab.token_blank; +} + +struct parakeet_timings * parakeet_get_timings(struct parakeet_context * ctx) { + if (ctx->state == nullptr) { + return nullptr; + } + parakeet_timings * timings = new parakeet_timings; + timings->sample_ms = 1e-3f * ctx->state->t_sample_us / std::max(1, ctx->state->n_sample); + timings->encode_ms = 1e-3f * ctx->state->t_encode_us / std::max(1, ctx->state->n_encode); + timings->decode_ms = 1e-3f * ctx->state->t_decode_us / std::max(1, ctx->state->n_decode); + return timings; +} + +void parakeet_print_timings(struct parakeet_context * ctx) { + const int64_t t_end_us = ggml_time_us(); + + PARAKEET_LOG_INFO("\n"); + PARAKEET_LOG_INFO("%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); + if (ctx->state != nullptr) { + + const int32_t n_sample = std::max(1, ctx->state->n_sample); + const int32_t n_encode = std::max(1, ctx->state->n_encode); + const int32_t n_decode = std::max(1, ctx->state->n_decode); + const int32_t n_predict = std::max(1, ctx->state->n_predict); + + PARAKEET_LOG_INFO("%s: fallbacks = %3d p / %3d h\n", __func__, ctx->state->n_fail_p, ctx->state->n_fail_h); + PARAKEET_LOG_INFO("%s: mel time = %8.2f ms\n", __func__, ctx->state->t_mel_us / 1000.0f); + PARAKEET_LOG_INFO("%s: sample time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_sample_us, n_sample, 1e-3f * ctx->state->t_sample_us / n_sample); + PARAKEET_LOG_INFO("%s: encode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_encode_us, n_encode, 1e-3f * ctx->state->t_encode_us / n_encode); + PARAKEET_LOG_INFO("%s: decode time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_decode_us, n_decode, 1e-3f * ctx->state->t_decode_us / n_decode); + PARAKEET_LOG_INFO("%s: predict time = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_us, n_predict, 1e-3f * ctx->state->t_predict_us / n_predict); + PARAKEET_LOG_INFO("%s: - build = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_build_us, n_predict, 1e-3f * ctx->state->t_predict_build_us / n_predict); + PARAKEET_LOG_INFO("%s: - alloc = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_alloc_us, n_predict, 1e-3f * ctx->state->t_predict_alloc_us / n_predict); + PARAKEET_LOG_INFO("%s: - compute = %8.2f ms / %5d runs ( %8.2f ms per run)\n", __func__, 1e-3f * ctx->state->t_predict_compute_us, n_predict, 1e-3f * ctx->state->t_predict_compute_us / n_predict); + + } + PARAKEET_LOG_INFO("%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); +} + +void parakeet_reset_timings(struct parakeet_context * ctx) { + ctx->t_start_us = ggml_time_us(); + if (ctx->state != nullptr) { + ctx->state->t_mel_us = 0; + ctx->state->t_sample_us = 0; + ctx->state->t_encode_us = 0; + ctx->state->t_decode_us = 0; + ctx->state->t_predict_us = 0; + ctx->state->t_predict_build_us = 0; + ctx->state->t_predict_alloc_us = 0; + ctx->state->t_predict_compute_us = 0; + + ctx->state->n_sample = 0; + ctx->state->n_encode = 0; + ctx->state->n_decode = 0; + ctx->state->n_predict = 0; + } +} + +const char * parakeet_print_system_info(void) { + static std::string s; + + s = ""; + s += "PARAKEET : "; + + for (size_t i = 0; i < ggml_backend_reg_count(); i++) { + auto * reg = ggml_backend_reg_get(i); + auto * get_features_fn = (ggml_backend_get_features_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_get_features"); + if (get_features_fn) { + ggml_backend_feature * features = get_features_fn(reg); + s += ggml_backend_reg_name(reg); + s += " : "; + for (; features->name; features++) { + s += features->name; + s += " = "; + s += features->value; + s += " | "; + } + } + } + return s.c_str(); +} + +struct parakeet_context_params * parakeet_context_default_params_by_ref(void) { + struct parakeet_context_params params = parakeet_context_default_params(); + + struct parakeet_context_params* result = new parakeet_context_params(); + *result = params; + return result; +} + +struct parakeet_full_params * parakeet_full_default_params_by_ref(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params params = parakeet_full_default_params(strategy); + + struct parakeet_full_params* result = new parakeet_full_params(); + *result = params; + return result; +} + +struct parakeet_full_params parakeet_full_default_params(enum parakeet_sampling_strategy strategy) { + struct parakeet_full_params result = { + /*.strategy =*/ strategy, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, + /*.duration_ms =*/ 0, + /*.no_context =*/ true, + /*.audio_ctx =*/ 0, + /*.new_token_callback =*/ nullptr, + /*.new_token_callback_user_data =*/ nullptr, + /*.new_segment_callback =*/ nullptr, + /*.new_segment_callback_user_data =*/ nullptr, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.encoder_begin_callback =*/ nullptr, + /*.encoder_begin_callback_user_data =*/ nullptr, + /*.abort_callback =*/ nullptr, + /*.abort_callback_user_data =*/ nullptr, + }; + + return result; +} + +static void parakeet_reset_state(struct parakeet_state * state) { + state->decoded_tokens.clear(); + state->decoded_token_data.clear(); + + if (state->lstm_state.buffer) { + ggml_backend_buffer_clear(state->lstm_state.buffer, 0); + } + +} + +// Encode and decode the mel spectrogram already in state, without recomputing it. +static int parakeet_chunk_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params) { + return parakeet_chunk(ctx, state, params, nullptr, 0); +} + +int parakeet_full_with_state( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + state->result_all.clear(); + + if (params.no_context) { + parakeet_reset_state(state); + } + + if (n_samples > 0) { + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + const int n_mel_total = state->mel.n_len; + const int n_audio_ctx = ctx->model.hparams.n_audio_ctx; + + if (n_mel_total <= n_audio_ctx) { + if (params.progress_callback) { + params.progress_callback(ctx, state, 0, params.progress_callback_user_data); + } + return parakeet_chunk_with_state(ctx, state, params); + } + + PARAKEET_LOG_DEBUG("%s: audio too long (%d mel > n_audio_ctx=%d), using dynamic encoder graph\n", + __func__, n_mel_total, n_audio_ctx); + + if (params.encoder_begin_callback) { + if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false\n", __func__); + return -6; + } + } + + if (params.progress_callback) { + params.progress_callback(ctx, state, 0, params.progress_callback_user_data); + } + + if (!parakeet_ensure_encode_sched(*ctx, *state, n_mel_total)) { + PARAKEET_LOG_ERROR("%s: failed to allocate dynamic encoder graph for %d mel frames\n", + __func__, n_mel_total); + return -6; + } + + state->n_audio_ctx = n_mel_total; + + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, + params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + if (params.progress_callback) { + params.progress_callback(ctx, state, 100, params.progress_callback_user_data); + } + + const size_t tokens_before = state->decoded_tokens.size(); + + if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) { + PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector<parakeet_token_data> result_tokens; + + for (size_t i = tokens_before; i < tokens_after; i++) { + const auto token_id = state->decoded_tokens[i]; + const char * tok_str = parakeet_token_to_str(ctx, token_id); + if (tok_str) { + const bool is_first = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(tok_str, is_first); + } + result_tokens.push_back(state->decoded_token_data[i]); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment seg; + seg.t0 = 0; + seg.t1 = state->n_frames; + seg.text = text; + seg.tokens = result_tokens; + state->result_all.push_back(std::move(seg)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +int parakeet_full( + struct parakeet_context * ctx, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + return parakeet_full_with_state(ctx, ctx->state, params, samples, n_samples); +} + +int parakeet_chunk( + struct parakeet_context * ctx, + struct parakeet_state * state, + struct parakeet_full_params params, + const float * samples, + int n_samples) { + + if (params.no_context) { + parakeet_reset_state(state); + } + + if (n_samples > 0) { + if (parakeet_pcm_to_mel_with_state(ctx, state, samples, n_samples, params.n_threads) != 0) { + PARAKEET_LOG_ERROR("%s: failed to compute log mel spectrogram\n", __func__); + return -2; + } + } + + if (params.audio_ctx == 0) { + const int total_len = parakeet_n_len_from_state(state); + const int model_max_ctx = parakeet_n_audio_ctx(ctx); + params.audio_ctx = std::min(total_len, model_max_ctx); + PARAKEET_LOG_DEBUG("Processing audio: total_frames=%d, chunk_size=%d\n", total_len, params.audio_ctx); + } + state->n_audio_ctx = params.audio_ctx; + + const int n_frames = parakeet_n_len_from_state(state); + + if (!parakeet_ensure_encode_sched(*ctx, *state, state->n_audio_ctx)) { + PARAKEET_LOG_ERROR("%s: failed to allocate encoder graph for %d mel frames\n", + __func__, state->n_audio_ctx); + return -6; + } + + if (params.encoder_begin_callback) { + if (!params.encoder_begin_callback(ctx, state, params.encoder_begin_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: encoder_begin_callback returned false - aborting\n", __func__); + return -6; + } + } + if (!parakeet_encode_internal(*ctx, *state, 0, params.n_threads, params.abort_callback, params.abort_callback_user_data)) { + PARAKEET_LOG_ERROR("%s: failed to encode\n", __func__); + return -6; + } + + const size_t tokens_before = state->decoded_tokens.size(); + + if (!parakeet_decode(*ctx, *state, state->batch, params.n_threads, ¶ms)) { + PARAKEET_LOG_ERROR("%s: failed to decode\n", __func__); + return -7; + } + + const size_t tokens_after = state->decoded_tokens.size(); + const size_t new_token_count = tokens_after - tokens_before; + + if (new_token_count > 0) { + std::string text; + std::vector<parakeet_token_data> result_tokens; + + for (size_t i = tokens_before; i < tokens_after; i++) { + const auto token_id = state->decoded_tokens[i]; + const char * token_str = parakeet_token_to_str(ctx, token_id); + if (token_str) { + const bool is_first_piece = (tokens_before == 0) && text.empty(); + text += sentencepiece_piece_to_text(token_str, is_first_piece); + } + + // Use the stored token data from parakeet_decode + result_tokens.push_back(state->decoded_token_data[i]); + } + + refine_timestamps_tdt(ctx->vocab, result_tokens); + + if (!text.empty()) { + parakeet_segment segment; + segment.t0 = 0; // Caller tracks timing + segment.t1 = n_frames; + segment.text = text; + segment.tokens = result_tokens; + + state->result_all.push_back(std::move(segment)); + + if (params.new_segment_callback) { + params.new_segment_callback(ctx, state, 1, params.new_segment_callback_user_data); + } + } + } + + return 0; +} + +int parakeet_full_n_segments_from_state(struct parakeet_state * state) { + return state->result_all.size(); +} + +int parakeet_full_n_segments(struct parakeet_context * ctx) { + return ctx->state->result_all.size(); +} + +int64_t parakeet_full_get_segment_t0_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t0; +} + +int64_t parakeet_full_get_segment_t1_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].t1; +} + +int64_t parakeet_full_get_segment_t0(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t0_from_state(ctx->state, i_segment); +} + +int64_t parakeet_full_get_segment_t1(struct parakeet_context * ctx, int i_segment) { + return parakeet_full_get_segment_t1_from_state(ctx->state, i_segment); +} + +const char * parakeet_full_get_segment_text_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].text.c_str(); +} + +const char * parakeet_full_get_segment_text(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].text.c_str(); +} + +int parakeet_full_n_tokens_from_state(struct parakeet_state * state, int i_segment) { + return state->result_all[i_segment].tokens.size(); +} + +int parakeet_full_n_tokens(struct parakeet_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].tokens.size(); +} + +const char * parakeet_full_get_token_text_from_state(struct parakeet_context * ctx, struct parakeet_state * state, int i_segment, int i_token) { + return ctx->vocab.id_to_token[state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +const char* parakeet_full_get_token_text(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->vocab.id_to_token[ctx->state->result_all[i_segment].tokens[i_token].id].c_str(); +} + +parakeet_token parakeet_full_get_token_id_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].id; +} + +parakeet_token parakeet_full_get_token_id(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].id; +} + +struct parakeet_token_data parakeet_full_get_token_data_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token]; +} + +struct parakeet_token_data parakeet_full_get_token_data(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token]; +} + +float parakeet_full_get_token_p_from_state(struct parakeet_state * state, int i_segment, int i_token) { + return state->result_all[i_segment].tokens[i_token].p; +} + +float parakeet_full_get_token_p(struct parakeet_context * ctx, int i_segment, int i_token) { + return ctx->state->result_all[i_segment].tokens[i_token].p; +} + +void parakeet_log_set(ggml_log_callback log_callback, void * user_data) { + g_state.log_callback = log_callback ? log_callback : parakeet_log_callback_default; + g_state.log_callback_user_data = user_data; + ggml_log_set(g_state.log_callback, g_state.log_callback_user_data); +} + +const char * parakeet_version(void) { + return PARAKEET_VERSION; +} + +GGML_ATTRIBUTE_FORMAT(2, 3) +static void parakeet_log_internal(ggml_log_level level, const char * format, ...) { + va_list args; + va_start(args, format); + char buffer[1024]; + int len = vsnprintf(buffer, 1024, format, args); + if (len < 1024) { + g_state.log_callback(level, buffer, g_state.log_callback_user_data); + } else { + char* buffer2 = new char[len+1]; + vsnprintf(buffer2, len+1, format, args); + buffer2[len] = 0; + g_state.log_callback(level, buffer2, g_state.log_callback_user_data); + delete[] buffer2; + } + va_end(args); +} + +static void parakeet_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; +#ifndef PARAKEET_DEBUG + if (level == GGML_LOG_LEVEL_DEBUG) { + return; + } +#endif + fputs(text, stderr); + fflush(stderr); +} diff --git a/src/whisper.cpp b/src/whisper.cpp index 796bccfb45d..5ffc70af00e 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -3720,7 +3720,21 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_ whisper_context * ctx = new whisper_context; ctx->params = params; - if (!whisper_model_load(loader, *ctx)) { + // A C++ exception escaping this extern "C" function aborts non-C++ callers + // (Rust via whisper-rs, Go via cgo, ...). whisper_model_load can throw + // (std::runtime_error here; vk::SystemError from the Vulkan backend during + // device/buffer allocation), so funnel any throw into the existing + // NULL-return failure path instead of letting it cross the C ABI. + bool model_loaded = false; + try { + model_loaded = whisper_model_load(loader, *ctx); + } catch (const std::exception & e) { + WHISPER_LOG_ERROR("%s: exception during model load: %s\n", __func__, e.what()); + } catch (...) { + WHISPER_LOG_ERROR("%s: unknown exception during model load\n", __func__); + } + + if (!model_loaded) { loader->close(loader->context); WHISPER_LOG_ERROR("%s: failed to load model\n", __func__); delete ctx; @@ -5083,7 +5097,11 @@ struct whisper_vad_context * whisper_vad_init_with_params( return vctx; } -bool whisper_vad_detect_speech( +void whisper_vad_reset_state(whisper_vad_context * vctx) { + ggml_backend_buffer_clear(vctx->buffer, 0); +} + +bool whisper_vad_detect_speech_no_reset( struct whisper_vad_context * vctx, const float * samples, int n_samples) { @@ -5095,9 +5113,6 @@ bool whisper_vad_detect_speech( WHISPER_LOG_INFO("%s: detecting speech in %d samples\n", __func__, n_samples); WHISPER_LOG_INFO("%s: n_chunks: %d\n", __func__, n_chunks); - // Reset LSTM hidden/cell states - ggml_backend_buffer_clear(vctx->buffer, 0); - vctx->probs.resize(n_chunks); WHISPER_LOG_INFO("%s: props size: %u\n", __func__, n_chunks); @@ -5165,6 +5180,14 @@ bool whisper_vad_detect_speech( return true; } +bool whisper_vad_detect_speech( + struct whisper_vad_context * vctx, + const float * samples, + int n_samples) { + whisper_vad_reset_state(vctx); + return whisper_vad_detect_speech_no_reset(vctx, samples, n_samples); +} + int whisper_vad_segments_n_segments(struct whisper_vad_segments * segments) { return segments->data.size(); } @@ -6207,6 +6230,13 @@ static void whisper_process_logits( } } + // ref: https://github.com/ggml-org/whisper.cpp/pull/3798 + if (!params.no_timestamps && !params.single_segment && params.max_tokens > 0 && (int) tokens_cur.size() >= params.max_tokens) { + for (int i = 0; i < vocab.token_eot; ++i) { + logits[i] = -INFINITY; + } + } + // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; logits[vocab.token_nosp] = -INFINITY; @@ -6701,12 +6731,13 @@ static bool whisper_vad( int segment_start_samples = cs_to_samples(vad_segments->data[i].start); int segment_end_samples = cs_to_samples(vad_segments->data[i].end); - if (i < (int)vad_segments->data.size() - 1) { - segment_end_samples += overlap_samples; - } - segment_start_samples = std::min(segment_start_samples, n_samples - 1); segment_end_samples = std::min(segment_end_samples, n_samples - 1); + int original_segment_length = segment_end_samples - segment_start_samples; + + if (i < (int)vad_segments->data.size() - 1) { + segment_end_samples = std::min(segment_end_samples + overlap_samples, n_samples - 1); + } int segment_length = segment_end_samples - segment_start_samples; if (segment_length > 0) { whisper_state::vad_segment_info segment; @@ -6715,7 +6746,7 @@ static bool whisper_vad( segment.orig_end = vad_segments->data[i].end; segment.vad_start = samples_to_cs(offset); - segment.vad_end = samples_to_cs(offset + segment_length); + segment.vad_end = samples_to_cs(offset + original_segment_length); // Add segment boundaries to mapping table vad_time_mapping start_mapping = {segment.vad_start, segment.orig_start}; @@ -6724,29 +6755,6 @@ static bool whisper_vad( state->vad_mapping_table.push_back(start_mapping); state->vad_mapping_table.push_back(end_mapping); - // Add intermediate points for longer segments to improve interpolation accuracy - const int64_t min_segment_length = 100; // 1 second - const int64_t point_interval = 20; // Add a point every 200ms - - if (segment.vad_end - segment.vad_start > min_segment_length) { - int64_t segment_duration = segment.vad_end - segment.vad_start; - int num_points = (int)(segment_duration / point_interval) - 1; - - for (int j = 1; j <= num_points; j++) { - int64_t vad_time = segment.vad_start + j * point_interval; - - if (vad_time >= segment.vad_end) continue; - - int64_t vad_elapsed = vad_time - segment.vad_start; - int64_t vad_total = segment.vad_end - segment.vad_start; - int64_t orig_total = segment.orig_end - segment.orig_start; - int64_t orig_time = segment.orig_start + (vad_elapsed * orig_total) / vad_total; - - vad_time_mapping intermediate_mapping = {vad_time, orig_time}; - state->vad_mapping_table.push_back(intermediate_mapping); - } - } - WHISPER_LOG_INFO("%s: vad_segment_info: orig_start: %.2f, orig_end: %.2f, vad_start: %.2f, vad_end: %.2f\n", __func__, segment.orig_start/100.0, segment.orig_end/100.0, segment.vad_start/100.0, segment.vad_end/100.0); ctx->state->vad_segments.push_back(segment); @@ -7672,11 +7680,14 @@ int whisper_full_with_state( } } text = ""; - while (i < (int) tokens_cur.size() && tokens_cur[i].id > whisper_token_beg(ctx)) { + t0 = t1; + while (i + 1 < (int) tokens_cur.size() && tokens_cur[i + 1].id > whisper_token_beg(ctx)) { i++; + if (params.print_special) { + text += whisper_token_to_str(ctx, tokens_cur[i].id); + } + t0 = seek + 2 * (tokens_cur[i].tid - whisper_token_beg(ctx)); } - i--; - t0 = t1; i0 = i + 1; speaker_turn_next = false; } @@ -7693,8 +7704,8 @@ int whisper_full_with_state( printf("[%s --> %s] %s\n", to_timestamp(tt0).c_str(), to_timestamp(tt1).c_str(), text.c_str()); } else { printf("%s", text.c_str()); - fflush(stdout); } + fflush(stdout); } result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next }); @@ -7735,7 +7746,12 @@ int whisper_full_with_state( } // ref: https://github.com/ggml-org/whisper.cpp/pull/2629 + const bool max_tokens_timestamp_ending = params.max_tokens > 0 && + !params.single_segment && + tokens_cur.size() > (size_t) params.max_tokens; + const bool single_timestamp_ending = tokens_cur.size() > 1 && + !max_tokens_timestamp_ending && tokens_cur[tokens_cur.size() - 2].id < whisper_token_beg(ctx) && tokens_cur[tokens_cur.size() - 1].id > whisper_token_beg(ctx); if (single_timestamp_ending) { @@ -8256,9 +8272,6 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { // when F16 is used, there is an extra work buffer of size N*N*sizeof(float) std::vector<uint8_t> buf(3llu*N_max*N_max*sizeof(float) + 3*ggml_tensor_overhead() + ggml_graph_overhead()); - // put a bunch of random data in the buffer - for (size_t i = 0; i < buf.size(); i++) buf[i] = i; - for (int j = 0; j < (int) sizes.size(); j++) { int n_q4_0 = 0; int n_q4_1 = 0; @@ -8302,6 +8315,15 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { struct ggml_tensor * a = ggml_new_tensor_2d(ctx0, wtype, N, N); struct ggml_tensor * b = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, N, N); + // set tensor data after allocation so previous iteration results don't corrupt it. + { + uint8_t * a_data = (uint8_t *) a->data; + for (size_t ii = 0; ii < ggml_nbytes(a); ii++) a_data[ii] = ii & 0x3F; + + uint8_t * b_data = (uint8_t *) b->data; + for (size_t ii = 0; ii < ggml_nbytes(b); ii++) b_data[ii] = ii & 0x3F; + } + struct ggml_tensor * c = ggml_mul_mat(ctx0, a, b); struct ggml_cgraph * gf = ggml_new_graph(ctx0); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 09e77ea89c2..74a5b142948 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -78,7 +78,7 @@ add_test(NAME ${TEST_TARGET} -f ${PROJECT_SOURCE_DIR}/samples/jfk.wav) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "large") -if (WHISPER_FFMPEG) +if (WHISPER_COMMON_FFMPEG) set(TEST_TARGET test-whisper-cli-tiny-mp3) # Check with reviewers: any way to check the output transcription via ctest (diff, ...)? add_test(NAME ${TEST_TARGET} @@ -88,6 +88,14 @@ if (WHISPER_FFMPEG) set_tests_properties(${TEST_TARGET} PROPERTIES LABELS "tiny;mp3") endif() +# UTF-8 helper unit test +set(UTF8_TEST test-common-utf8) +add_executable(${UTF8_TEST} ${UTF8_TEST}.cpp) +target_include_directories(${UTF8_TEST} PRIVATE ../examples) +target_link_libraries(${UTF8_TEST} PRIVATE common) +add_test(NAME ${UTF8_TEST} COMMAND ${UTF8_TEST}) +set_tests_properties(${UTF8_TEST} PROPERTIES LABELS "unit") + # VAD test tests VAD in isolation set(VAD_TEST test-vad) add_executable(${VAD_TEST} ${VAD_TEST}.cpp) @@ -110,3 +118,62 @@ target_compile_definitions(${VAD_TEST} PRIVATE SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") add_test(NAME ${VAD_TEST} COMMAND ${VAD_TEST}) set_tests_properties(${VAD_TEST} PROPERTIES LABELS "base;en") + +# Parakeet model loading test +set(PARAKEET_TEST test-parakeet) +add_executable(${PARAKEET_TEST} ${PARAKEET_TEST}.cpp) +target_include_directories(${PARAKEET_TEST} PRIVATE ../include ../ggml/include ../examples) +target_link_libraries(${PARAKEET_TEST} PRIVATE parakeet common) +target_compile_definitions(${PARAKEET_TEST} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/for-tests-ggml-parakeet-tdt.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/samples/jfk.wav") +add_test(NAME ${PARAKEET_TEST} COMMAND ${PARAKEET_TEST}) +set_tests_properties(${PARAKEET_TEST} PROPERTIES LABELS "parakeet;gh") + +# The following parakeet test require a real ggml-parakeet-tdt model to have +# been converted or downloaded: +# $ hf download danbev/parakeet parakeet-tdt-0.6b-v3-f32.bin --local-dir models +# +# And also required more audio samples that are shipped by default. These can +# downloaded by running: +# $ make samples +function(add_parakeet_transcription_test TEST_TARGET TEST_SOURCE SAMPLE_PATH EXPECTED_TRANSCRIPTION_PATH) + set(TRANSCRIPTION_SIMILARITY_THRESHOLD "1.0") + if (ARGC GREATER 4) + set(TRANSCRIPTION_SIMILARITY_THRESHOLD "${ARGV4}") + endif() + + add_executable(${TEST_TARGET} ${TEST_SOURCE}) + target_include_directories(${TEST_TARGET} PRIVATE ../include ../ggml/include ../examples) + target_link_libraries(${TEST_TARGET} PRIVATE parakeet common) + target_compile_definitions(${TEST_TARGET} PRIVATE + PARAKEET_MODEL_PATH="${PROJECT_SOURCE_DIR}/models/ggml-parakeet-tdt-0.6b-v3-f32.bin" + SAMPLE_PATH="${PROJECT_SOURCE_DIR}/${SAMPLE_PATH}" + EXPECTED_TRANSCRIPTION_PATH="${PROJECT_SOURCE_DIR}/${EXPECTED_TRANSCRIPTION_PATH}" + TRANSCRIPTION_SIMILARITY_THRESHOLD=${TRANSCRIPTION_SIMILARITY_THRESHOLD}) + + add_custom_target(run-${TEST_TARGET} + COMMAND $<TARGET_FILE:${TEST_TARGET}> + DEPENDS ${TEST_TARGET} + WORKING_DIRECTORY ${PROJECT_BINARY_DIR}) +endfunction() + +add_parakeet_transcription_test( + test-parakeet-full-jfk + test-parakeet-full.cpp + samples/jfk.wav + tests/parakeet-expected-jfk-output.txt) + +add_parakeet_transcription_test( + test-parakeet-full-gb1 + test-parakeet-full.cpp + samples/gb1.wav + tests/parakeet-expected-gb1-output.txt) + +add_parakeet_transcription_test( + test-parakeet-full-diffusion + test-parakeet-full.cpp + samples/diffusion2023-07-03.flac + tests/parakeet-expected-diffusion-output.txt + 0.95) + diff --git a/tests/librispeech-parakeet/.gitignore b/tests/librispeech-parakeet/.gitignore new file mode 100644 index 00000000000..838bfeae9db --- /dev/null +++ b/tests/librispeech-parakeet/.gitignore @@ -0,0 +1,6 @@ +__pycache__ +*.tar.gz +*.txt +eval.conf +venv +LibriSpeech diff --git a/tests/librispeech-parakeet/Makefile b/tests/librispeech-parakeet/Makefile new file mode 100644 index 00000000000..0afa2465f49 --- /dev/null +++ b/tests/librispeech-parakeet/Makefile @@ -0,0 +1,15 @@ +TAR_URL = https://www.openslr.org/resources/12/test-clean.tar.gz + +all: eval + +eval: + $(MAKE) -f eval.mk + +clean: + $(MAKE) -f eval.mk clean + +get-audio: + wget -c $(TAR_URL) + tar -xf test-clean.tar.gz + +.PHONY: all eval clean setup-venv clean-venv get-audio diff --git a/tests/librispeech-parakeet/README.md b/tests/librispeech-parakeet/README.md new file mode 100644 index 00000000000..e09cba405ef --- /dev/null +++ b/tests/librispeech-parakeet/README.md @@ -0,0 +1,57 @@ +# parakeet.cpp/tests/librispeech + +[LibriSpeech](https://www.openslr.org/12) is a standard dataset for +training and evaluating automatic speech recognition systems. + +This directory contains a set of tools to evaluate the recognition +performance of parakeet.cpp on LibriSpeech corpus. + +## Quick Start + +1. (Pre-requirement) Compile `parakeet-cli` and prepare the Parakeet + model in `ggml` format. + + ``` + $ # Execute the commands below in the project root dir. + $ cmake -B build + $ cmake --build build --config Release + ``` + +2. Download the audio files from LibriSpeech project. + + ``` + $ make get-audio + ``` + +3. Set up the environment to compute WER score. + + ``` + $ pip install -r requirements.txt + ``` + + For example, if you use `virtualenv`, you can set up it as follows: + + ``` + $ python3 -m venv venv + $ . venv/bin/activate + $ pip install -r requirements.txt + ``` + +4. Run the benchmark test. + + ``` + $ make + ``` + +## How-to guides + +### How to change the inference parameters + +Create `eval.conf` and override variables. + +``` +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 +PARAKEET_FLAGS = --no-prints --threads 8 --language en --output-txt +``` + +Check out `eval.mk` for more details. diff --git a/tests/librispeech-parakeet/eval.mk b/tests/librispeech-parakeet/eval.mk new file mode 100644 index 00000000000..7d8992ec471 --- /dev/null +++ b/tests/librispeech-parakeet/eval.mk @@ -0,0 +1,39 @@ +PYTHON = python + +PARAKEET_PREFIX = ../../ +PARAKEET_MODEL = parakeet-tdt-0.6b-v3 + +PARAKEET_CLI = $(PARAKEET_PREFIX)build/bin/parakeet-cli +PARAKEET_FLAGS = --no-prints --output-txt + +# You can create eval.conf to override the PARAKEET_* variables +# defined above. +-include eval.conf + +# This follows the file structure of the LibriSpeech project. +AUDIO_SRCS = $(sort $(wildcard LibriSpeech/*/*/*/*.flac)) +TRANS_TXTS = $(addsuffix .txt, $(AUDIO_SRCS)) + +# We output the evaluation result to this file. +DONE = $(PARAKEET_MODEL).txt + +all: $(DONE) + +$(DONE): $(TRANS_TXTS) + $(PYTHON) eval.py > $@.tmp + mv $@.tmp $@ + +# Note: This task writes to a temporary file first to +# create the target file atomically. +%.flac.txt: %.flac + $(PARAKEET_CLI) $(PARAKEET_FLAGS) --model $(PARAKEET_PREFIX)models/ggml-$(PARAKEET_MODEL).bin --file $^ --output-file $^.tmp + mv $^.tmp.txt $^.txt + +archive: + tar -czf $(PARAKEET_MODEL).tar.gz --exclude="*.flac" LibriSpeech $(DONE) + +clean: + @rm -f $(TRANS_TXTS) + @rm -f $(DONE) + +.PHONY: all clean diff --git a/tests/librispeech-parakeet/eval.py b/tests/librispeech-parakeet/eval.py new file mode 100644 index 00000000000..cdaf8352fd4 --- /dev/null +++ b/tests/librispeech-parakeet/eval.py @@ -0,0 +1,47 @@ +import os +import glob +import jiwer +from normalizers import EnglishTextNormalizer + +def get_reference(): + ref = {} + for path in glob.glob('LibriSpeech/*/*/*/*.trans.txt'): + with open(path) as fp: + for line in fp: + code, text = line.strip().split(" ", maxsplit=1) + ref [code] = text + return ref + +def get_hypothesis(): + hyp = {} + for path in glob.glob('LibriSpeech/*/*/*/*.flac.txt'): + with open(path) as fp: + text = fp.read().strip() + code = os.path.basename(path).replace('.flac.txt', '') + hyp[code] = text + return hyp + +def get_codes(): + codes = [] + for path in glob.glob('LibriSpeech/*/*/*/*.flac'): + codes.append(os.path.basename(path).replace('.flac', '')) + return sorted(codes) + +def main(): + normalizer = EnglishTextNormalizer() + + ref_orig = get_reference() + hyp_orig = get_hypothesis() + + ref_clean = [] + hyp_clean = [] + + for code in get_codes(): + ref_clean.append(normalizer(ref_orig[code])) + hyp_clean.append(normalizer(hyp_orig[code])) + + wer = jiwer.wer(ref_clean, hyp_clean) + print(f"WER: {wer * 100:.2f}%") + +if __name__ == '__main__': + main() diff --git a/tests/librispeech-parakeet/normalizers/LICENSE b/tests/librispeech-parakeet/normalizers/LICENSE new file mode 100644 index 00000000000..7c8e603b0fc --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/LICENSE @@ -0,0 +1,25 @@ +Code in this directory is adapted from OpenAI Whisper project +(https://github.com/openai/whisper) and carries the following +copyright and license. + + MIT License + + Copyright (c) 2022 OpenAI + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. diff --git a/tests/librispeech-parakeet/normalizers/__init__.py b/tests/librispeech-parakeet/normalizers/__init__.py new file mode 100644 index 00000000000..896d5e33641 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/__init__.py @@ -0,0 +1,2 @@ +from .basic import BasicTextNormalizer as BasicTextNormalizer +from .english import EnglishTextNormalizer as EnglishTextNormalizer diff --git a/tests/librispeech-parakeet/normalizers/basic.py b/tests/librispeech-parakeet/normalizers/basic.py new file mode 100644 index 00000000000..8690ae71c5f --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/basic.py @@ -0,0 +1,80 @@ +import re +import unicodedata + +import regex + +# non-ASCII letters that are not separated by "NFKD" normalization +ADDITIONAL_DIACRITICS = { + "œ": "oe", + "Œ": "OE", + "ø": "o", + "Ø": "O", + "æ": "ae", + "Æ": "AE", + "ß": "ss", + "ẞ": "SS", + "đ": "d", + "Đ": "D", + "ð": "d", + "Ð": "D", + "þ": "th", + "Þ": "th", + "ł": "l", + "Ł": "L", +} + + +def remove_symbols_and_diacritics(s: str, keep=""): + """ + Replace any other markers, symbols, and punctuations with a space, + and drop any diacritics (category 'Mn' and some manual mappings) + """ + return "".join( + ( + c + if c in keep + else ( + ADDITIONAL_DIACRITICS[c] + if c in ADDITIONAL_DIACRITICS + else ( + "" + if unicodedata.category(c) == "Mn" + else " " if unicodedata.category(c)[0] in "MSP" else c + ) + ) + ) + for c in unicodedata.normalize("NFKD", s) + ) + + +def remove_symbols(s: str): + """ + Replace any other markers, symbols, punctuations with a space, keeping diacritics + """ + return "".join( + " " if unicodedata.category(c)[0] in "MSP" else c + for c in unicodedata.normalize("NFKC", s) + ) + + +class BasicTextNormalizer: + def __init__(self, remove_diacritics: bool = False, split_letters: bool = False): + self.clean = ( + remove_symbols_and_diacritics if remove_diacritics else remove_symbols + ) + self.split_letters = split_letters + + def __call__(self, s: str): + s = s.lower() + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = self.clean(s).lower() + + if self.split_letters: + s = " ".join(regex.findall(r"\X", s, regex.U)) + + s = re.sub( + r"\s+", " ", s + ) # replace any successive whitespace characters with a space + + return s diff --git a/tests/librispeech-parakeet/normalizers/english.json b/tests/librispeech-parakeet/normalizers/english.json new file mode 100644 index 00000000000..74a1c3521d9 --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.json @@ -0,0 +1,1741 @@ +{ + "accessorise": "accessorize", + "accessorised": "accessorized", + "accessorises": "accessorizes", + "accessorising": "accessorizing", + "acclimatisation": "acclimatization", + "acclimatise": "acclimatize", + "acclimatised": "acclimatized", + "acclimatises": "acclimatizes", + "acclimatising": "acclimatizing", + "accoutrements": "accouterments", + "aeon": "eon", + "aeons": "eons", + "aerogramme": "aerogram", + "aerogrammes": "aerograms", + "aeroplane": "airplane", + "aeroplanes": "airplanes", + "aesthete": "esthete", + "aesthetes": "esthetes", + "aesthetic": "esthetic", + "aesthetically": "esthetically", + "aesthetics": "esthetics", + "aetiology": "etiology", + "ageing": "aging", + "aggrandisement": "aggrandizement", + "agonise": "agonize", + "agonised": "agonized", + "agonises": "agonizes", + "agonising": "agonizing", + "agonisingly": "agonizingly", + "almanack": "almanac", + "almanacks": "almanacs", + "aluminium": "aluminum", + "amortisable": "amortizable", + "amortisation": "amortization", + "amortisations": "amortizations", + "amortise": "amortize", + "amortised": "amortized", + "amortises": "amortizes", + "amortising": "amortizing", + "amphitheatre": "amphitheater", + "amphitheatres": "amphitheaters", + "anaemia": "anemia", + "anaemic": "anemic", + "anaesthesia": "anesthesia", + "anaesthetic": "anesthetic", + "anaesthetics": "anesthetics", + "anaesthetise": "anesthetize", + "anaesthetised": "anesthetized", + "anaesthetises": "anesthetizes", + "anaesthetising": "anesthetizing", + "anaesthetist": "anesthetist", + "anaesthetists": "anesthetists", + "anaesthetize": "anesthetize", + "anaesthetized": "anesthetized", + "anaesthetizes": "anesthetizes", + "anaesthetizing": "anesthetizing", + "analogue": "analog", + "analogues": "analogs", + "analyse": "analyze", + "analysed": "analyzed", + "analyses": "analyzes", + "analysing": "analyzing", + "anglicise": "anglicize", + "anglicised": "anglicized", + "anglicises": "anglicizes", + "anglicising": "anglicizing", + "annualised": "annualized", + "antagonise": "antagonize", + "antagonised": "antagonized", + "antagonises": "antagonizes", + "antagonising": "antagonizing", + "apologise": "apologize", + "apologised": "apologized", + "apologises": "apologizes", + "apologising": "apologizing", + "appal": "appall", + "appals": "appalls", + "appetiser": "appetizer", + "appetisers": "appetizers", + "appetising": "appetizing", + "appetisingly": "appetizingly", + "arbour": "arbor", + "arbours": "arbors", + "archeological": "archaeological", + "archaeologically": "archeologically", + "archaeologist": "archeologist", + "archaeologists": "archeologists", + "archaeology": "archeology</span>", + "ardour": "ardor", + "armour": "armor", + "armoured": "armored", + "armourer": "armorer", + "armourers": "armorers", + "armouries": "armories", + "armoury": "armory", + "artefact": "artifact", + "artefacts": "artifacts", + "authorise": "authorize", + "authorised": "authorized", + "authorises": "authorizes", + "authorising": "authorizing", + "axe": "ax", + "backpedalled": "backpedaled", + "backpedalling": "backpedaling", + "bannister": "banister", + "bannisters": "banisters", + "baptise": "baptize", + "baptised": "baptized", + "baptises": "baptizes", + "baptising": "baptizing", + "bastardise": "bastardize", + "bastardised": "bastardized", + "bastardises": "bastardizes", + "bastardising": "bastardizing", + "battleax": "battleaxe", + "baulk": "balk", + "baulked": "balked", + "baulking": "balking", + "baulks": "balks", + "bedevilled": "bedeviled", + "bedevilling": "bedeviling", + "behaviour": "behavior", + "behavioural": "behavioral", + "behaviourism": "behaviorism", + "behaviourist": "behaviorist", + "behaviourists": "behaviorists", + "behaviours": "behaviors", + "behove": "behoove", + "behoved": "behooved", + "behoves": "behooves", + "bejewelled": "bejeweled", + "belabour": "belabor", + "belaboured": "belabored", + "belabouring": "belaboring", + "belabours": "belabors", + "bevelled": "beveled", + "bevvies": "bevies", + "bevvy": "bevy", + "biassed": "biased", + "biassing": "biasing", + "bingeing": "binging", + "bougainvillaea": "bougainvillea", + "bougainvillaeas": "bougainvilleas", + "bowdlerise": "bowdlerize", + "bowdlerised": "bowdlerized", + "bowdlerises": "bowdlerizes", + "bowdlerising": "bowdlerizing", + "breathalyse": "breathalyze", + "breathalysed": "breathalyzed", + "breathalyser": "breathalyzer", + "breathalysers": "breathalyzers", + "breathalyses": "breathalyzes", + "breathalysing": "breathalyzing", + "brutalise": "brutalize", + "brutalised": "brutalized", + "brutalises": "brutalizes", + "brutalising": "brutalizing", + "busses": "buses", + "bussing": "busing", + "caesarean": "cesarean", + "caesareans": "cesareans", + "calibre": "caliber", + "calibres": "calibers", + "calliper": "caliper", + "callipers": "calipers", + "callisthenics": "calisthenics", + "canalise": "canalize", + "canalised": "canalized", + "canalises": "canalizes", + "canalising": "canalizing", + "cancelation": "cancellation", + "cancelations": "cancellations", + "cancelled": "canceled", + "cancelling": "canceling", + "candour": "candor", + "cannibalise": "cannibalize", + "cannibalised": "cannibalized", + "cannibalises": "cannibalizes", + "cannibalising": "cannibalizing", + "canonise": "canonize", + "canonised": "canonized", + "canonises": "canonizes", + "canonising": "canonizing", + "capitalise": "capitalize", + "capitalised": "capitalized", + "capitalises": "capitalizes", + "capitalising": "capitalizing", + "caramelise": "caramelize", + "caramelised": "caramelized", + "caramelises": "caramelizes", + "caramelising": "caramelizing", + "carbonise": "carbonize", + "carbonised": "carbonized", + "carbonises": "carbonizes", + "carbonising": "carbonizing", + "carolled": "caroled", + "carolling": "caroling", + "catalogue": "catalog", + "catalogued": "cataloged", + "catalogues": "catalogs", + "cataloguing": "cataloging", + "catalyse": "catalyze", + "catalysed": "catalyzed", + "catalyses": "catalyzes", + "catalysing": "catalyzing", + "categorise": "categorize", + "categorised": "categorized", + "categorises": "categorizes", + "categorising": "categorizing", + "cauterise": "cauterize", + "cauterised": "cauterized", + "cauterises": "cauterizes", + "cauterising": "cauterizing", + "cavilled": "caviled", + "cavilling": "caviling", + "centigramme": "centigram", + "centigrammes": "centigrams", + "centilitre": "centiliter", + "centilitres": "centiliters", + "centimetre": "centimeter", + "centimetres": "centimeters", + "centralise": "centralize", + "centralised": "centralized", + "centralises": "centralizes", + "centralising": "centralizing", + "centre": "center", + "centred": "centered", + "centrefold": "centerfold", + "centrefolds": "centerfolds", + "centrepiece": "centerpiece", + "centrepieces": "centerpieces", + "centres": "centers", + "channelled": "channeled", + "channelling": "channeling", + "characterise": "characterize", + "characterised": "characterized", + "characterises": "characterizes", + "characterising": "characterizing", + "cheque": "check", + "chequebook": "checkbook", + "chequebooks": "checkbooks", + "chequered": "checkered", + "cheques": "checks", + "chilli": "chili", + "chimaera": "chimera", + "chimaeras": "chimeras", + "chiselled": "chiseled", + "chiselling": "chiseling", + "circularise": "circularize", + "circularised": "circularized", + "circularises": "circularizes", + "circularising": "circularizing", + "civilise": "civilize", + "civilised": "civilized", + "civilises": "civilizes", + "civilising": "civilizing", + "clamour": "clamor", + "clamoured": "clamored", + "clamouring": "clamoring", + "clamours": "clamors", + "clangour": "clangor", + "clarinettist": "clarinetist", + "clarinettists": "clarinetists", + "collectivise": "collectivize", + "collectivised": "collectivized", + "collectivises": "collectivizes", + "collectivising": "collectivizing", + "colonisation": "colonization", + "colonise": "colonize", + "colonised": "colonized", + "coloniser": "colonizer", + "colonisers": "colonizers", + "colonises": "colonizes", + "colonising": "colonizing", + "colour": "color", + "colourant": "colorant", + "colourants": "colorants", + "coloured": "colored", + "coloureds": "coloreds", + "colourful": "colorful", + "colourfully": "colorfully", + "colouring": "coloring", + "colourize": "colorize", + "colourized": "colorized", + "colourizes": "colorizes", + "colourizing": "colorizing", + "colourless": "colorless", + "colours": "colors", + "commercialise": "commercialize", + "commercialised": "commercialized", + "commercialises": "commercializes", + "commercialising": "commercializing", + "compartmentalise": "compartmentalize", + "compartmentalised": "compartmentalized", + "compartmentalises": "compartmentalizes", + "compartmentalising": "compartmentalizing", + "computerise": "computerize", + "computerised": "computerized", + "computerises": "computerizes", + "computerising": "computerizing", + "conceptualise": "conceptualize", + "conceptualised": "conceptualized", + "conceptualises": "conceptualizes", + "conceptualising": "conceptualizing", + "connexion": "connection", + "connexions": "connections", + "contextualise": "contextualize", + "contextualised": "contextualized", + "contextualises": "contextualizes", + "contextualising": "contextualizing", + "cosier": "cozier", + "cosies": "cozies", + "cosiest": "coziest", + "cosily": "cozily", + "cosiness": "coziness", + "cosy": "cozy", + "councillor": "councilor", + "councillors": "councilors", + "counselled": "counseled", + "counselling": "counseling", + "counsellor": "counselor", + "counsellors": "counselors", + "crenelated": "crenellated", + "criminalise": "criminalize", + "criminalised": "criminalized", + "criminalises": "criminalizes", + "criminalising": "criminalizing", + "criticise": "criticize", + "criticised": "criticized", + "criticises": "criticizes", + "criticising": "criticizing", + "crueller": "crueler", + "cruellest": "cruelest", + "crystallisation": "crystallization", + "crystallise": "crystallize", + "crystallised": "crystallized", + "crystallises": "crystallizes", + "crystallising": "crystallizing", + "cudgelled": "cudgeled", + "cudgelling": "cudgeling", + "customise": "customize", + "customised": "customized", + "customises": "customizes", + "customising": "customizing", + "cypher": "cipher", + "cyphers": "ciphers", + "decentralisation": "decentralization", + "decentralise": "decentralize", + "decentralised": "decentralized", + "decentralises": "decentralizes", + "decentralising": "decentralizing", + "decriminalisation": "decriminalization", + "decriminalise": "decriminalize", + "decriminalised": "decriminalized", + "decriminalises": "decriminalizes", + "decriminalising": "decriminalizing", + "defence": "defense", + "defenceless": "defenseless", + "defences": "defenses", + "dehumanisation": "dehumanization", + "dehumanise": "dehumanize", + "dehumanised": "dehumanized", + "dehumanises": "dehumanizes", + "dehumanising": "dehumanizing", + "demeanour": "demeanor", + "demilitarisation": "demilitarization", + "demilitarise": "demilitarize", + "demilitarised": "demilitarized", + "demilitarises": "demilitarizes", + "demilitarising": "demilitarizing", + "demobilisation": "demobilization", + "demobilise": "demobilize", + "demobilised": "demobilized", + "demobilises": "demobilizes", + "demobilising": "demobilizing", + "democratisation": "democratization", + "democratise": "democratize", + "democratised": "democratized", + "democratises": "democratizes", + "democratising": "democratizing", + "demonise": "demonize", + "demonised": "demonized", + "demonises": "demonizes", + "demonising": "demonizing", + "demoralisation": "demoralization", + "demoralise": "demoralize", + "demoralised": "demoralized", + "demoralises": "demoralizes", + "demoralising": "demoralizing", + "denationalisation": "denationalization", + "denationalise": "denationalize", + "denationalised": "denationalized", + "denationalises": "denationalizes", + "denationalising": "denationalizing", + "deodorise": "deodorize", + "deodorised": "deodorized", + "deodorises": "deodorizes", + "deodorising": "deodorizing", + "depersonalise": "depersonalize", + "depersonalised": "depersonalized", + "depersonalises": "depersonalizes", + "depersonalising": "depersonalizing", + "deputise": "deputize", + "deputised": "deputized", + "deputises": "deputizes", + "deputising": "deputizing", + "desensitisation": "desensitization", + "desensitise": "desensitize", + "desensitised": "desensitized", + "desensitises": "desensitizes", + "desensitising": "desensitizing", + "destabilisation": "destabilization", + "destabilise": "destabilize", + "destabilised": "destabilized", + "destabilises": "destabilizes", + "destabilising": "destabilizing", + "dialled": "dialed", + "dialling": "dialing", + "dialogue": "dialog", + "dialogues": "dialogs", + "diarrhoea": "diarrhea", + "digitise": "digitize", + "digitised": "digitized", + "digitises": "digitizes", + "digitising": "digitizing", + "disc": "disk", + "discolour": "discolor", + "discoloured": "discolored", + "discolouring": "discoloring", + "discolours": "discolors", + "discs": "disks", + "disembowelled": "disemboweled", + "disembowelling": "disemboweling", + "disfavour": "disfavor", + "dishevelled": "disheveled", + "dishonour": "dishonor", + "dishonourable": "dishonorable", + "dishonourably": "dishonorably", + "dishonoured": "dishonored", + "dishonouring": "dishonoring", + "dishonours": "dishonors", + "disorganisation": "disorganization", + "disorganised": "disorganized", + "distil": "distill", + "distils": "distills", + "dramatisation": "dramatization", + "dramatisations": "dramatizations", + "dramatise": "dramatize", + "dramatised": "dramatized", + "dramatises": "dramatizes", + "dramatising": "dramatizing", + "draught": "draft", + "draughtboard": "draftboard", + "draughtboards": "draftboards", + "draughtier": "draftier", + "draughtiest": "draftiest", + "draughts": "drafts", + "draughtsman": "draftsman", + "draughtsmanship": "draftsmanship", + "draughtsmen": "draftsmen", + "draughtswoman": "draftswoman", + "draughtswomen": "draftswomen", + "draughty": "drafty", + "drivelled": "driveled", + "drivelling": "driveling", + "duelled": "dueled", + "duelling": "dueling", + "economise": "economize", + "economised": "economized", + "economises": "economizes", + "economising": "economizing", + "edoema": "edema", + "editorialise": "editorialize", + "editorialised": "editorialized", + "editorialises": "editorializes", + "editorialising": "editorializing", + "empathise": "empathize", + "empathised": "empathized", + "empathises": "empathizes", + "empathising": "empathizing", + "emphasise": "emphasize", + "emphasised": "emphasized", + "emphasises": "emphasizes", + "emphasising": "emphasizing", + "enamelled": "enameled", + "enamelling": "enameling", + "enamoured": "enamored", + "encyclopaedia": "encyclopedia", + "encyclopaedias": "encyclopedias", + "encyclopaedic": "encyclopedic", + "endeavour": "endeavor", + "endeavoured": "endeavored", + "endeavouring": "endeavoring", + "endeavours": "endeavors", + "energise": "energize", + "energised": "energized", + "energises": "energizes", + "energising": "energizing", + "enrol": "enroll", + "enrols": "enrolls", + "enthral": "enthrall", + "enthrals": "enthralls", + "epaulette": "epaulet", + "epaulettes": "epaulets", + "epicentre": "epicenter", + "epicentres": "epicenters", + "epilogue": "epilog", + "epilogues": "epilogs", + "epitomise": "epitomize", + "epitomised": "epitomized", + "epitomises": "epitomizes", + "epitomising": "epitomizing", + "equalisation": "equalization", + "equalise": "equalize", + "equalised": "equalized", + "equaliser": "equalizer", + "equalisers": "equalizers", + "equalises": "equalizes", + "equalising": "equalizing", + "eulogise": "eulogize", + "eulogised": "eulogized", + "eulogises": "eulogizes", + "eulogising": "eulogizing", + "evangelise": "evangelize", + "evangelised": "evangelized", + "evangelises": "evangelizes", + "evangelising": "evangelizing", + "exorcise": "exorcize", + "exorcised": "exorcized", + "exorcises": "exorcizes", + "exorcising": "exorcizing", + "extemporisation": "extemporization", + "extemporise": "extemporize", + "extemporised": "extemporized", + "extemporises": "extemporizes", + "extemporising": "extemporizing", + "externalisation": "externalization", + "externalisations": "externalizations", + "externalise": "externalize", + "externalised": "externalized", + "externalises": "externalizes", + "externalising": "externalizing", + "factorise": "factorize", + "factorised": "factorized", + "factorises": "factorizes", + "factorising": "factorizing", + "faecal": "fecal", + "faeces": "feces", + "familiarisation": "familiarization", + "familiarise": "familiarize", + "familiarised": "familiarized", + "familiarises": "familiarizes", + "familiarising": "familiarizing", + "fantasise": "fantasize", + "fantasised": "fantasized", + "fantasises": "fantasizes", + "fantasising": "fantasizing", + "favour": "favor", + "favourable": "favorable", + "favourably": "favorably", + "favoured": "favored", + "favouring": "favoring", + "favourite": "favorite", + "favourites": "favorites", + "favouritism": "favoritism", + "favours": "favors", + "feminise": "feminize", + "feminised": "feminized", + "feminises": "feminizes", + "feminising": "feminizing", + "fertilisation": "fertilization", + "fertilise": "fertilize", + "fertilised": "fertilized", + "fertiliser": "fertilizer", + "fertilisers": "fertilizers", + "fertilises": "fertilizes", + "fertilising": "fertilizing", + "fervour": "fervor", + "fibre": "fiber", + "fibreglass": "fiberglass", + "fibres": "fibers", + "fictionalisation": "fictionalization", + "fictionalisations": "fictionalizations", + "fictionalise": "fictionalize", + "fictionalised": "fictionalized", + "fictionalises": "fictionalizes", + "fictionalising": "fictionalizing", + "fillet": "filet", + "filleted": "fileted", + "filleting": "fileting", + "fillets": "filets", + "finalisation": "finalization", + "finalise": "finalize", + "finalised": "finalized", + "finalises": "finalizes", + "finalising": "finalizing", + "flautist": "flutist", + "flautists": "flutists", + "flavour": "flavor", + "flavoured": "flavored", + "flavouring": "flavoring", + "flavourings": "flavorings", + "flavourless": "flavorless", + "flavours": "flavors", + "flavoursome": "flavorsome", + "flyer / flier": "flier / flyer", + "foetal": "fetal", + "foetid": "fetid", + "foetus": "fetus", + "foetuses": "fetuses", + "formalisation": "formalization", + "formalise": "formalize", + "formalised": "formalized", + "formalises": "formalizes", + "formalising": "formalizing", + "fossilisation": "fossilization", + "fossilise": "fossilize", + "fossilised": "fossilized", + "fossilises": "fossilizes", + "fossilising": "fossilizing", + "fraternisation": "fraternization", + "fraternise": "fraternize", + "fraternised": "fraternized", + "fraternises": "fraternizes", + "fraternising": "fraternizing", + "fulfil": "fulfill", + "fulfilment": "fulfillment", + "fulfils": "fulfills", + "funnelled": "funneled", + "funnelling": "funneling", + "galvanise": "galvanize", + "galvanised": "galvanized", + "galvanises": "galvanizes", + "galvanising": "galvanizing", + "gambolled": "gamboled", + "gambolling": "gamboling", + "gaol": "jail", + "gaolbird": "jailbird", + "gaolbirds": "jailbirds", + "gaolbreak": "jailbreak", + "gaolbreaks": "jailbreaks", + "gaoled": "jailed", + "gaoler": "jailer", + "gaolers": "jailers", + "gaoling": "jailing", + "gaols": "jails", + "gasses": "gases", + "gage": "gauge", + "gaged": "gauged", + "gages": "gauges", + "gaging": "gauging", + "generalisation": "generalization", + "generalisations": "generalizations", + "generalise": "generalize", + "generalised": "generalized", + "generalises": "generalizes", + "generalising": "generalizing", + "ghettoise": "ghettoize", + "ghettoised": "ghettoized", + "ghettoises": "ghettoizes", + "ghettoising": "ghettoizing", + "gipsies": "gypsies", + "glamorise": "glamorize", + "glamorised": "glamorized", + "glamorises": "glamorizes", + "glamorising": "glamorizing", + "glamor": "glamour", + "globalisation": "globalization", + "globalise": "globalize", + "globalised": "globalized", + "globalises": "globalizes", + "globalising": "globalizing", + "glueing": "gluing", + "goitre": "goiter", + "goitres": "goiters", + "gonorrhoea": "gonorrhea", + "gramme": "gram", + "grammes": "grams", + "gravelled": "graveled", + "grey": "gray", + "greyed": "grayed", + "greying": "graying", + "greyish": "grayish", + "greyness": "grayness", + "greys": "grays", + "grovelled": "groveled", + "grovelling": "groveling", + "groyne": "groin", + "groynes": "groins", + "gruelling": "grueling", + "gruellingly": "gruelingly", + "gryphon": "griffin", + "gryphons": "griffins", + "gynaecological": "gynecological", + "gynaecologist": "gynecologist", + "gynaecologists": "gynecologists", + "gynaecology": "gynecology", + "haematological": "hematological", + "haematologist": "hematologist", + "haematologists": "hematologists", + "haematology": "hematology", + "haemoglobin": "hemoglobin", + "haemophilia": "hemophilia", + "haemophiliac": "hemophiliac", + "haemophiliacs": "hemophiliacs", + "haemorrhage": "hemorrhage", + "haemorrhaged": "hemorrhaged", + "haemorrhages": "hemorrhages", + "haemorrhaging": "hemorrhaging", + "haemorrhoids": "hemorrhoids", + "harbour": "harbor", + "harboured": "harbored", + "harbouring": "harboring", + "harbours": "harbors", + "harmonisation": "harmonization", + "harmonise": "harmonize", + "harmonised": "harmonized", + "harmonises": "harmonizes", + "harmonising": "harmonizing", + "homoeopath": "homeopath", + "homoeopathic": "homeopathic", + "homoeopaths": "homeopaths", + "homoeopathy": "homeopathy", + "homogenise": "homogenize", + "homogenised": "homogenized", + "homogenises": "homogenizes", + "homogenising": "homogenizing", + "honour": "honor", + "honourable": "honorable", + "honourably": "honorably", + "honoured": "honored", + "honouring": "honoring", + "honours": "honors", + "hospitalisation": "hospitalization", + "hospitalise": "hospitalize", + "hospitalised": "hospitalized", + "hospitalises": "hospitalizes", + "hospitalising": "hospitalizing", + "humanise": "humanize", + "humanised": "humanized", + "humanises": "humanizes", + "humanising": "humanizing", + "humour": "humor", + "humoured": "humored", + "humouring": "humoring", + "humourless": "humorless", + "humours": "humors", + "hybridise": "hybridize", + "hybridised": "hybridized", + "hybridises": "hybridizes", + "hybridising": "hybridizing", + "hypnotise": "hypnotize", + "hypnotised": "hypnotized", + "hypnotises": "hypnotizes", + "hypnotising": "hypnotizing", + "hypothesise": "hypothesize", + "hypothesised": "hypothesized", + "hypothesises": "hypothesizes", + "hypothesising": "hypothesizing", + "idealisation": "idealization", + "idealise": "idealize", + "idealised": "idealized", + "idealises": "idealizes", + "idealising": "idealizing", + "idolise": "idolize", + "idolised": "idolized", + "idolises": "idolizes", + "idolising": "idolizing", + "immobilisation": "immobilization", + "immobilise": "immobilize", + "immobilised": "immobilized", + "immobiliser": "immobilizer", + "immobilisers": "immobilizers", + "immobilises": "immobilizes", + "immobilising": "immobilizing", + "immortalise": "immortalize", + "immortalised": "immortalized", + "immortalises": "immortalizes", + "immortalising": "immortalizing", + "immunisation": "immunization", + "immunise": "immunize", + "immunised": "immunized", + "immunises": "immunizes", + "immunising": "immunizing", + "impanelled": "impaneled", + "impanelling": "impaneling", + "imperilled": "imperiled", + "imperilling": "imperiling", + "individualise": "individualize", + "individualised": "individualized", + "individualises": "individualizes", + "individualising": "individualizing", + "industrialise": "industrialize", + "industrialised": "industrialized", + "industrialises": "industrializes", + "industrialising": "industrializing", + "inflexion": "inflection", + "inflexions": "inflections", + "initialise": "initialize", + "initialised": "initialized", + "initialises": "initializes", + "initialising": "initializing", + "initialled": "initialed", + "initialling": "initialing", + "instal": "install", + "instalment": "installment", + "instalments": "installments", + "instals": "installs", + "instil": "instill", + "instils": "instills", + "institutionalisation": "institutionalization", + "institutionalise": "institutionalize", + "institutionalised": "institutionalized", + "institutionalises": "institutionalizes", + "institutionalising": "institutionalizing", + "intellectualise": "intellectualize", + "intellectualised": "intellectualized", + "intellectualises": "intellectualizes", + "intellectualising": "intellectualizing", + "internalisation": "internalization", + "internalise": "internalize", + "internalised": "internalized", + "internalises": "internalizes", + "internalising": "internalizing", + "internationalisation": "internationalization", + "internationalise": "internationalize", + "internationalised": "internationalized", + "internationalises": "internationalizes", + "internationalising": "internationalizing", + "ionisation": "ionization", + "ionise": "ionize", + "ionised": "ionized", + "ioniser": "ionizer", + "ionisers": "ionizers", + "ionises": "ionizes", + "ionising": "ionizing", + "italicise": "italicize", + "italicised": "italicized", + "italicises": "italicizes", + "italicising": "italicizing", + "itemise": "itemize", + "itemised": "itemized", + "itemises": "itemizes", + "itemising": "itemizing", + "jeopardise": "jeopardize", + "jeopardised": "jeopardized", + "jeopardises": "jeopardizes", + "jeopardising": "jeopardizing", + "jewelled": "jeweled", + "jeweller": "jeweler", + "jewellers": "jewelers", + "jewellery": "jewelry", + "judgement": "judgment", + "kilogramme": "kilogram", + "kilogrammes": "kilograms", + "kilometre": "kilometer", + "kilometres": "kilometers", + "labelled": "labeled", + "labelling": "labeling", + "labour": "labor", + "laboured": "labored", + "labourer": "laborer", + "labourers": "laborers", + "labouring": "laboring", + "labours": "labors", + "lacklustre": "lackluster", + "legalisation": "legalization", + "legalise": "legalize", + "legalised": "legalized", + "legalises": "legalizes", + "legalising": "legalizing", + "legitimise": "legitimize", + "legitimised": "legitimized", + "legitimises": "legitimizes", + "legitimising": "legitimizing", + "leukaemia": "leukemia", + "levelled": "leveled", + "leveller": "leveler", + "levellers": "levelers", + "levelling": "leveling", + "libelled": "libeled", + "libelling": "libeling", + "libellous": "libelous", + "liberalisation": "liberalization", + "liberalise": "liberalize", + "liberalised": "liberalized", + "liberalises": "liberalizes", + "liberalising": "liberalizing", + "licence": "license", + "licenced": "licensed", + "licences": "licenses", + "licencing": "licensing", + "likeable": "likable", + "lionisation": "lionization", + "lionise": "lionize", + "lionised": "lionized", + "lionises": "lionizes", + "lionising": "lionizing", + "liquidise": "liquidize", + "liquidised": "liquidized", + "liquidiser": "liquidizer", + "liquidisers": "liquidizers", + "liquidises": "liquidizes", + "liquidising": "liquidizing", + "litre": "liter", + "litres": "liters", + "localise": "localize", + "localised": "localized", + "localises": "localizes", + "localising": "localizing", + "louvre": "louver", + "louvred": "louvered", + "louvres": "louvers", + "lustre": "luster", + "magnetise": "magnetize", + "magnetised": "magnetized", + "magnetises": "magnetizes", + "magnetising": "magnetizing", + "manoeuvrability": "maneuverability", + "manoeuvrable": "maneuverable", + "manoeuvre": "maneuver", + "manoeuvred": "maneuvered", + "manoeuvres": "maneuvers", + "manoeuvring": "maneuvering", + "manoeuvrings": "maneuverings", + "marginalisation": "marginalization", + "marginalise": "marginalize", + "marginalised": "marginalized", + "marginalises": "marginalizes", + "marginalising": "marginalizing", + "marshalled": "marshaled", + "marshalling": "marshaling", + "marvelled": "marveled", + "marvelling": "marveling", + "marvellous": "marvelous", + "marvellously": "marvelously", + "materialisation": "materialization", + "materialise": "materialize", + "materialised": "materialized", + "materialises": "materializes", + "materialising": "materializing", + "maximisation": "maximization", + "maximise": "maximize", + "maximised": "maximized", + "maximises": "maximizes", + "maximising": "maximizing", + "meagre": "meager", + "mechanisation": "mechanization", + "mechanise": "mechanize", + "mechanised": "mechanized", + "mechanises": "mechanizes", + "mechanising": "mechanizing", + "mediaeval": "medieval", + "memorialise": "memorialize", + "memorialised": "memorialized", + "memorialises": "memorializes", + "memorialising": "memorializing", + "memorise": "memorize", + "memorised": "memorized", + "memorises": "memorizes", + "memorising": "memorizing", + "mesmerise": "mesmerize", + "mesmerised": "mesmerized", + "mesmerises": "mesmerizes", + "mesmerising": "mesmerizing", + "metabolise": "metabolize", + "metabolised": "metabolized", + "metabolises": "metabolizes", + "metabolising": "metabolizing", + "metre": "meter", + "metres": "meters", + "micrometre": "micrometer", + "micrometres": "micrometers", + "militarise": "militarize", + "militarised": "militarized", + "militarises": "militarizes", + "militarising": "militarizing", + "milligramme": "milligram", + "milligrammes": "milligrams", + "millilitre": "milliliter", + "millilitres": "milliliters", + "millimetre": "millimeter", + "millimetres": "millimeters", + "miniaturisation": "miniaturization", + "miniaturise": "miniaturize", + "miniaturised": "miniaturized", + "miniaturises": "miniaturizes", + "miniaturising": "miniaturizing", + "minibusses": "minibuses", + "minimise": "minimize", + "minimised": "minimized", + "minimises": "minimizes", + "minimising": "minimizing", + "misbehaviour": "misbehavior", + "misdemeanour": "misdemeanor", + "misdemeanours": "misdemeanors", + "misspelt": "misspelled", + "mitre": "miter", + "mitres": "miters", + "mobilisation": "mobilization", + "mobilise": "mobilize", + "mobilised": "mobilized", + "mobilises": "mobilizes", + "mobilising": "mobilizing", + "modelled": "modeled", + "modeller": "modeler", + "modellers": "modelers", + "modelling": "modeling", + "modernise": "modernize", + "modernised": "modernized", + "modernises": "modernizes", + "modernising": "modernizing", + "moisturise": "moisturize", + "moisturised": "moisturized", + "moisturiser": "moisturizer", + "moisturisers": "moisturizers", + "moisturises": "moisturizes", + "moisturising": "moisturizing", + "monologue": "monolog", + "monologues": "monologs", + "monopolisation": "monopolization", + "monopolise": "monopolize", + "monopolised": "monopolized", + "monopolises": "monopolizes", + "monopolising": "monopolizing", + "moralise": "moralize", + "moralised": "moralized", + "moralises": "moralizes", + "moralising": "moralizing", + "motorised": "motorized", + "mould": "mold", + "moulded": "molded", + "moulder": "molder", + "mouldered": "moldered", + "mouldering": "moldering", + "moulders": "molders", + "mouldier": "moldier", + "mouldiest": "moldiest", + "moulding": "molding", + "mouldings": "moldings", + "moulds": "molds", + "mouldy": "moldy", + "moult": "molt", + "moulted": "molted", + "moulting": "molting", + "moults": "molts", + "moustache": "mustache", + "moustached": "mustached", + "moustaches": "mustaches", + "moustachioed": "mustachioed", + "multicoloured": "multicolored", + "nationalisation": "nationalization", + "nationalisations": "nationalizations", + "nationalise": "nationalize", + "nationalised": "nationalized", + "nationalises": "nationalizes", + "nationalising": "nationalizing", + "naturalisation": "naturalization", + "naturalise": "naturalize", + "naturalised": "naturalized", + "naturalises": "naturalizes", + "naturalising": "naturalizing", + "neighbour": "neighbor", + "neighbourhood": "neighborhood", + "neighbourhoods": "neighborhoods", + "neighbouring": "neighboring", + "neighbourliness": "neighborliness", + "neighbourly": "neighborly", + "neighbours": "neighbors", + "neutralisation": "neutralization", + "neutralise": "neutralize", + "neutralised": "neutralized", + "neutralises": "neutralizes", + "neutralising": "neutralizing", + "normalisation": "normalization", + "normalise": "normalize", + "normalised": "normalized", + "normalises": "normalizes", + "normalising": "normalizing", + "odour": "odor", + "odourless": "odorless", + "odours": "odors", + "oesophagus": "esophagus", + "oesophaguses": "esophaguses", + "oestrogen": "estrogen", + "offence": "offense", + "offences": "offenses", + "omelette": "omelet", + "omelettes": "omelets", + "optimise": "optimize", + "optimised": "optimized", + "optimises": "optimizes", + "optimising": "optimizing", + "organisation": "organization", + "organisational": "organizational", + "organisations": "organizations", + "organise": "organize", + "organised": "organized", + "organiser": "organizer", + "organisers": "organizers", + "organises": "organizes", + "organising": "organizing", + "orthopaedic": "orthopedic", + "orthopaedics": "orthopedics", + "ostracise": "ostracize", + "ostracised": "ostracized", + "ostracises": "ostracizes", + "ostracising": "ostracizing", + "outmanoeuvre": "outmaneuver", + "outmanoeuvred": "outmaneuvered", + "outmanoeuvres": "outmaneuvers", + "outmanoeuvring": "outmaneuvering", + "overemphasise": "overemphasize", + "overemphasised": "overemphasized", + "overemphasises": "overemphasizes", + "overemphasising": "overemphasizing", + "oxidisation": "oxidization", + "oxidise": "oxidize", + "oxidised": "oxidized", + "oxidises": "oxidizes", + "oxidising": "oxidizing", + "paederast": "pederast", + "paederasts": "pederasts", + "paediatric": "pediatric", + "paediatrician": "pediatrician", + "paediatricians": "pediatricians", + "paediatrics": "pediatrics", + "paedophile": "pedophile", + "paedophiles": "pedophiles", + "paedophilia": "pedophilia", + "palaeolithic": "paleolithic", + "palaeontologist": "paleontologist", + "palaeontologists": "paleontologists", + "palaeontology": "paleontology", + "panelled": "paneled", + "panelling": "paneling", + "panellist": "panelist", + "panellists": "panelists", + "paralyse": "paralyze", + "paralysed": "paralyzed", + "paralyses": "paralyzes", + "paralysing": "paralyzing", + "parcelled": "parceled", + "parcelling": "parceling", + "parlour": "parlor", + "parlours": "parlors", + "particularise": "particularize", + "particularised": "particularized", + "particularises": "particularizes", + "particularising": "particularizing", + "passivisation": "passivization", + "passivise": "passivize", + "passivised": "passivized", + "passivises": "passivizes", + "passivising": "passivizing", + "pasteurisation": "pasteurization", + "pasteurise": "pasteurize", + "pasteurised": "pasteurized", + "pasteurises": "pasteurizes", + "pasteurising": "pasteurizing", + "patronise": "patronize", + "patronised": "patronized", + "patronises": "patronizes", + "patronising": "patronizing", + "patronisingly": "patronizingly", + "pedalled": "pedaled", + "pedalling": "pedaling", + "pedestrianisation": "pedestrianization", + "pedestrianise": "pedestrianize", + "pedestrianised": "pedestrianized", + "pedestrianises": "pedestrianizes", + "pedestrianising": "pedestrianizing", + "penalise": "penalize", + "penalised": "penalized", + "penalises": "penalizes", + "penalising": "penalizing", + "pencilled": "penciled", + "pencilling": "penciling", + "personalise": "personalize", + "personalised": "personalized", + "personalises": "personalizes", + "personalising": "personalizing", + "pharmacopoeia": "pharmacopeia", + "pharmacopoeias": "pharmacopeias", + "philosophise": "philosophize", + "philosophised": "philosophized", + "philosophises": "philosophizes", + "philosophising": "philosophizing", + "philtre": "filter", + "philtres": "filters", + "phoney": "phony", + "plagiarise": "plagiarize", + "plagiarised": "plagiarized", + "plagiarises": "plagiarizes", + "plagiarising": "plagiarizing", + "plough": "plow", + "ploughed": "plowed", + "ploughing": "plowing", + "ploughman": "plowman", + "ploughmen": "plowmen", + "ploughs": "plows", + "ploughshare": "plowshare", + "ploughshares": "plowshares", + "polarisation": "polarization", + "polarise": "polarize", + "polarised": "polarized", + "polarises": "polarizes", + "polarising": "polarizing", + "politicisation": "politicization", + "politicise": "politicize", + "politicised": "politicized", + "politicises": "politicizes", + "politicising": "politicizing", + "popularisation": "popularization", + "popularise": "popularize", + "popularised": "popularized", + "popularises": "popularizes", + "popularising": "popularizing", + "pouffe": "pouf", + "pouffes": "poufs", + "practise": "practice", + "practised": "practiced", + "practises": "practices", + "practising": "practicing", + "praesidium": "presidium", + "praesidiums": "presidiums", + "pressurisation": "pressurization", + "pressurise": "pressurize", + "pressurised": "pressurized", + "pressurises": "pressurizes", + "pressurising": "pressurizing", + "pretence": "pretense", + "pretences": "pretenses", + "primaeval": "primeval", + "prioritisation": "prioritization", + "prioritise": "prioritize", + "prioritised": "prioritized", + "prioritises": "prioritizes", + "prioritising": "prioritizing", + "privatisation": "privatization", + "privatisations": "privatizations", + "privatise": "privatize", + "privatised": "privatized", + "privatises": "privatizes", + "privatising": "privatizing", + "professionalisation": "professionalization", + "professionalise": "professionalize", + "professionalised": "professionalized", + "professionalises": "professionalizes", + "professionalising": "professionalizing", + "programme": "program", + "programmes": "programs", + "prologue": "prolog", + "prologues": "prologs", + "propagandise": "propagandize", + "propagandised": "propagandized", + "propagandises": "propagandizes", + "propagandising": "propagandizing", + "proselytise": "proselytize", + "proselytised": "proselytized", + "proselytiser": "proselytizer", + "proselytisers": "proselytizers", + "proselytises": "proselytizes", + "proselytising": "proselytizing", + "psychoanalyse": "psychoanalyze", + "psychoanalysed": "psychoanalyzed", + "psychoanalyses": "psychoanalyzes", + "psychoanalysing": "psychoanalyzing", + "publicise": "publicize", + "publicised": "publicized", + "publicises": "publicizes", + "publicising": "publicizing", + "pulverisation": "pulverization", + "pulverise": "pulverize", + "pulverised": "pulverized", + "pulverises": "pulverizes", + "pulverising": "pulverizing", + "pummelled": "pummel", + "pummelling": "pummeled", + "pyjama": "pajama", + "pyjamas": "pajamas", + "pzazz": "pizzazz", + "quarrelled": "quarreled", + "quarrelling": "quarreling", + "radicalise": "radicalize", + "radicalised": "radicalized", + "radicalises": "radicalizes", + "radicalising": "radicalizing", + "rancour": "rancor", + "randomise": "randomize", + "randomised": "randomized", + "randomises": "randomizes", + "randomising": "randomizing", + "rationalisation": "rationalization", + "rationalisations": "rationalizations", + "rationalise": "rationalize", + "rationalised": "rationalized", + "rationalises": "rationalizes", + "rationalising": "rationalizing", + "ravelled": "raveled", + "ravelling": "raveling", + "realisable": "realizable", + "realisation": "realization", + "realisations": "realizations", + "realise": "realize", + "realised": "realized", + "realises": "realizes", + "realising": "realizing", + "recognisable": "recognizable", + "recognisably": "recognizably", + "recognisance": "recognizance", + "recognise": "recognize", + "recognised": "recognized", + "recognises": "recognizes", + "recognising": "recognizing", + "reconnoitre": "reconnoiter", + "reconnoitred": "reconnoitered", + "reconnoitres": "reconnoiters", + "reconnoitring": "reconnoitering", + "refuelled": "refueled", + "refuelling": "refueling", + "regularisation": "regularization", + "regularise": "regularize", + "regularised": "regularized", + "regularises": "regularizes", + "regularising": "regularizing", + "remodelled": "remodeled", + "remodelling": "remodeling", + "remould": "remold", + "remoulded": "remolded", + "remoulding": "remolding", + "remoulds": "remolds", + "reorganisation": "reorganization", + "reorganisations": "reorganizations", + "reorganise": "reorganize", + "reorganised": "reorganized", + "reorganises": "reorganizes", + "reorganising": "reorganizing", + "revelled": "reveled", + "reveller": "reveler", + "revellers": "revelers", + "revelling": "reveling", + "revitalise": "revitalize", + "revitalised": "revitalized", + "revitalises": "revitalizes", + "revitalising": "revitalizing", + "revolutionise": "revolutionize", + "revolutionised": "revolutionized", + "revolutionises": "revolutionizes", + "revolutionising": "revolutionizing", + "rhapsodise": "rhapsodize", + "rhapsodised": "rhapsodized", + "rhapsodises": "rhapsodizes", + "rhapsodising": "rhapsodizing", + "rigour": "rigor", + "rigours": "rigors", + "ritualised": "ritualized", + "rivalled": "rivaled", + "rivalling": "rivaling", + "romanticise": "romanticize", + "romanticised": "romanticized", + "romanticises": "romanticizes", + "romanticising": "romanticizing", + "rumour": "rumor", + "rumoured": "rumored", + "rumours": "rumors", + "sabre": "saber", + "sabres": "sabers", + "saltpetre": "saltpeter", + "sanitise": "sanitize", + "sanitised": "sanitized", + "sanitises": "sanitizes", + "sanitising": "sanitizing", + "satirise": "satirize", + "satirised": "satirized", + "satirises": "satirizes", + "satirising": "satirizing", + "saviour": "savior", + "saviours": "saviors", + "savour": "savor", + "savoured": "savored", + "savouries": "savories", + "savouring": "savoring", + "savours": "savors", + "savoury": "savory", + "scandalise": "scandalize", + "scandalised": "scandalized", + "scandalises": "scandalizes", + "scandalising": "scandalizing", + "sceptic": "skeptic", + "sceptical": "skeptical", + "sceptically": "skeptically", + "scepticism": "skepticism", + "sceptics": "skeptics", + "sceptre": "scepter", + "sceptres": "scepters", + "scrutinise": "scrutinize", + "scrutinised": "scrutinized", + "scrutinises": "scrutinizes", + "scrutinising": "scrutinizing", + "secularisation": "secularization", + "secularise": "secularize", + "secularised": "secularized", + "secularises": "secularizes", + "secularising": "secularizing", + "sensationalise": "sensationalize", + "sensationalised": "sensationalized", + "sensationalises": "sensationalizes", + "sensationalising": "sensationalizing", + "sensitise": "sensitize", + "sensitised": "sensitized", + "sensitises": "sensitizes", + "sensitising": "sensitizing", + "sentimentalise": "sentimentalize", + "sentimentalised": "sentimentalized", + "sentimentalises": "sentimentalizes", + "sentimentalising": "sentimentalizing", + "sepulchre": "sepulcher", + "sepulchres": "sepulchers", + "serialisation": "serialization", + "serialisations": "serializations", + "serialise": "serialize", + "serialised": "serialized", + "serialises": "serializes", + "serialising": "serializing", + "sermonise": "sermonize", + "sermonised": "sermonized", + "sermonises": "sermonizes", + "sermonising": "sermonizing", + "sheikh": "sheik", + "shovelled": "shoveled", + "shovelling": "shoveling", + "shrivelled": "shriveled", + "shrivelling": "shriveling", + "signalise": "signalize", + "signalised": "signalized", + "signalises": "signalizes", + "signalising": "signalizing", + "signalled": "signaled", + "signalling": "signaling", + "smoulder": "smolder", + "smouldered": "smoldered", + "smouldering": "smoldering", + "smoulders": "smolders", + "snivelled": "sniveled", + "snivelling": "sniveling", + "snorkelled": "snorkeled", + "snorkelling": "snorkeling", + "snowplough": "snowplow", + "snowploughs": "snowplow", + "socialisation": "socialization", + "socialise": "socialize", + "socialised": "socialized", + "socialises": "socializes", + "socialising": "socializing", + "sodomise": "sodomize", + "sodomised": "sodomized", + "sodomises": "sodomizes", + "sodomising": "sodomizing", + "solemnise": "solemnize", + "solemnised": "solemnized", + "solemnises": "solemnizes", + "solemnising": "solemnizing", + "sombre": "somber", + "specialisation": "specialization", + "specialisations": "specializations", + "specialise": "specialize", + "specialised": "specialized", + "specialises": "specializes", + "specialising": "specializing", + "spectre": "specter", + "spectres": "specters", + "spiralled": "spiraled", + "spiralling": "spiraling", + "splendour": "splendor", + "splendours": "splendors", + "squirrelled": "squirreled", + "squirrelling": "squirreling", + "stabilisation": "stabilization", + "stabilise": "stabilize", + "stabilised": "stabilized", + "stabiliser": "stabilizer", + "stabilisers": "stabilizers", + "stabilises": "stabilizes", + "stabilising": "stabilizing", + "standardisation": "standardization", + "standardise": "standardize", + "standardised": "standardized", + "standardises": "standardizes", + "standardising": "standardizing", + "stencilled": "stenciled", + "stencilling": "stenciling", + "sterilisation": "sterilization", + "sterilisations": "sterilizations", + "sterilise": "sterilize", + "sterilised": "sterilized", + "steriliser": "sterilizer", + "sterilisers": "sterilizers", + "sterilises": "sterilizes", + "sterilising": "sterilizing", + "stigmatisation": "stigmatization", + "stigmatise": "stigmatize", + "stigmatised": "stigmatized", + "stigmatises": "stigmatizes", + "stigmatising": "stigmatizing", + "storey": "story", + "storeys": "stories", + "subsidisation": "subsidization", + "subsidise": "subsidize", + "subsidised": "subsidized", + "subsidiser": "subsidizer", + "subsidisers": "subsidizers", + "subsidises": "subsidizes", + "subsidising": "subsidizing", + "succour": "succor", + "succoured": "succored", + "succouring": "succoring", + "succours": "succors", + "sulphate": "sulfate", + "sulphates": "sulfates", + "sulphide": "sulfide", + "sulphides": "sulfides", + "sulphur": "sulfur", + "sulphurous": "sulfurous", + "summarise": "summarize", + "summarised": "summarized", + "summarises": "summarizes", + "summarising": "summarizing", + "swivelled": "swiveled", + "swivelling": "swiveling", + "symbolise": "symbolize", + "symbolised": "symbolized", + "symbolises": "symbolizes", + "symbolising": "symbolizing", + "sympathise": "sympathize", + "sympathised": "sympathized", + "sympathiser": "sympathizer", + "sympathisers": "sympathizers", + "sympathises": "sympathizes", + "sympathising": "sympathizing", + "synchronisation": "synchronization", + "synchronise": "synchronize", + "synchronised": "synchronized", + "synchronises": "synchronizes", + "synchronising": "synchronizing", + "synthesise": "synthesize", + "synthesised": "synthesized", + "synthesiser": "synthesizer", + "synthesisers": "synthesizers", + "synthesises": "synthesizes", + "synthesising": "synthesizing", + "syphon": "siphon", + "syphoned": "siphoned", + "syphoning": "siphoning", + "syphons": "siphons", + "systematisation": "systematization", + "systematise": "systematize", + "systematised": "systematized", + "systematises": "systematizes", + "systematising": "systematizing", + "tantalise": "tantalize", + "tantalised": "tantalized", + "tantalises": "tantalizes", + "tantalising": "tantalizing", + "tantalisingly": "tantalizingly", + "tasselled": "tasseled", + "technicolour": "technicolor", + "temporise": "temporize", + "temporised": "temporized", + "temporises": "temporizes", + "temporising": "temporizing", + "tenderise": "tenderize", + "tenderised": "tenderized", + "tenderises": "tenderizes", + "tenderising": "tenderizing", + "terrorise": "terrorize", + "terrorised": "terrorized", + "terrorises": "terrorizes", + "terrorising": "terrorizing", + "theatre": "theater", + "theatregoer": "theatergoer", + "theatregoers": "theatergoers", + "theatres": "theaters", + "theorise": "theorize", + "theorised": "theorized", + "theorises": "theorizes", + "theorising": "theorizing", + "tonne": "ton", + "tonnes": "tons", + "towelled": "toweled", + "towelling": "toweling", + "toxaemia": "toxemia", + "tranquillise": "tranquilize", + "tranquillised": "tranquilized", + "tranquilliser": "tranquilizer", + "tranquillisers": "tranquilizers", + "tranquillises": "tranquilizes", + "tranquillising": "tranquilizing", + "tranquillity": "tranquility", + "tranquillize": "tranquilize", + "tranquillized": "tranquilized", + "tranquillizer": "tranquilizer", + "tranquillizers": "tranquilizers", + "tranquillizes": "tranquilizes", + "tranquillizing": "tranquilizing", + "tranquilly": "tranquility", + "transistorised": "transistorized", + "traumatise": "traumatize", + "traumatised": "traumatized", + "traumatises": "traumatizes", + "traumatising": "traumatizing", + "travelled": "traveled", + "traveller": "traveler", + "travellers": "travelers", + "travelling": "traveling", + "travelog": "travelogue", + "travelogs": "travelogues", + "trialled": "trialed", + "trialling": "trialing", + "tricolour": "tricolor", + "tricolours": "tricolors", + "trivialise": "trivialize", + "trivialised": "trivialized", + "trivialises": "trivializes", + "trivialising": "trivializing", + "tumour": "tumor", + "tumours": "tumors", + "tunnelled": "tunneled", + "tunnelling": "tunneling", + "tyrannise": "tyrannize", + "tyrannised": "tyrannized", + "tyrannises": "tyrannizes", + "tyrannising": "tyrannizing", + "tyre": "tire", + "tyres": "tires", + "unauthorised": "unauthorized", + "uncivilised": "uncivilized", + "underutilised": "underutilized", + "unequalled": "unequaled", + "unfavourable": "unfavorable", + "unfavourably": "unfavorably", + "unionisation": "unionization", + "unionise": "unionize", + "unionised": "unionized", + "unionises": "unionizes", + "unionising": "unionizing", + "unorganised": "unorganized", + "unravelled": "unraveled", + "unravelling": "unraveling", + "unrecognisable": "unrecognizable", + "unrecognised": "unrecognized", + "unrivalled": "unrivaled", + "unsavoury": "unsavory", + "untrammelled": "untrammeled", + "urbanisation": "urbanization", + "urbanise": "urbanize", + "urbanised": "urbanized", + "urbanises": "urbanizes", + "urbanising": "urbanizing", + "utilisable": "utilizable", + "utilisation": "utilization", + "utilise": "utilize", + "utilised": "utilized", + "utilises": "utilizes", + "utilising": "utilizing", + "valour": "valor", + "vandalise": "vandalize", + "vandalised": "vandalized", + "vandalises": "vandalizes", + "vandalising": "vandalizing", + "vaporisation": "vaporization", + "vaporise": "vaporize", + "vaporised": "vaporized", + "vaporises": "vaporizes", + "vaporising": "vaporizing", + "vapour": "vapor", + "vapours": "vapors", + "verbalise": "verbalize", + "verbalised": "verbalized", + "verbalises": "verbalizes", + "verbalising": "verbalizing", + "victimisation": "victimization", + "victimise": "victimize", + "victimised": "victimized", + "victimises": "victimizes", + "victimising": "victimizing", + "videodisc": "videodisk", + "videodiscs": "videodisks", + "vigour": "vigor", + "visualisation": "visualization", + "visualisations": "visualizations", + "visualise": "visualize", + "visualised": "visualized", + "visualises": "visualizes", + "visualising": "visualizing", + "vocalisation": "vocalization", + "vocalisations": "vocalizations", + "vocalise": "vocalize", + "vocalised": "vocalized", + "vocalises": "vocalizes", + "vocalising": "vocalizing", + "vulcanised": "vulcanized", + "vulgarisation": "vulgarization", + "vulgarise": "vulgarize", + "vulgarised": "vulgarized", + "vulgarises": "vulgarizes", + "vulgarising": "vulgarizing", + "waggon": "wagon", + "waggons": "wagons", + "watercolour": "watercolor", + "watercolours": "watercolors", + "weaselled": "weaseled", + "weaselling": "weaseling", + "westernisation": "westernization", + "westernise": "westernize", + "westernised": "westernized", + "westernises": "westernizes", + "westernising": "westernizing", + "womanise": "womanize", + "womanised": "womanized", + "womaniser": "womanizer", + "womanisers": "womanizers", + "womanises": "womanizes", + "womanising": "womanizing", + "woollen": "woolen", + "woollens": "woolens", + "woollies": "woolies", + "woolly": "wooly", + "worshipped": "worshiped", + "worshipping": "worshiping", + "worshipper": "worshiper", + "yodelled": "yodeled", + "yodelling": "yodeling", + "yoghourt": "yogurt", + "yoghourts": "yogurts", + "yoghurt": "yogurt", + "yoghurts": "yogurts", + "mhm": "hmm", + "mmm": "hmm" +} \ No newline at end of file diff --git a/tests/librispeech-parakeet/normalizers/english.py b/tests/librispeech-parakeet/normalizers/english.py new file mode 100644 index 00000000000..4932042bc5b --- /dev/null +++ b/tests/librispeech-parakeet/normalizers/english.py @@ -0,0 +1,550 @@ +import json +import os +import re +from fractions import Fraction +from typing import Iterator, List, Match, Optional, Union + +from more_itertools import windowed + +from .basic import remove_symbols_and_diacritics + + +class EnglishNumberNormalizer: + """ + Convert any spelled-out numbers into arabic numbers, while handling: + + - remove any commas + - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc. + - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars` + - spell out `one` and `ones` + - interpret successive single-digit numbers as nominal: `one oh one` -> `101` + """ + + def __init__(self): + super().__init__() + + self.zeros = {"o", "oh", "zero"} + self.ones = { + name: i + for i, name in enumerate( + [ + "one", + "two", + "three", + "four", + "five", + "six", + "seven", + "eight", + "nine", + "ten", + "eleven", + "twelve", + "thirteen", + "fourteen", + "fifteen", + "sixteen", + "seventeen", + "eighteen", + "nineteen", + ], + start=1, + ) + } + self.ones_plural = { + "sixes" if name == "six" else name + "s": (value, "s") + for name, value in self.ones.items() + } + self.ones_ordinal = { + "zeroth": (0, "th"), + "first": (1, "st"), + "second": (2, "nd"), + "third": (3, "rd"), + "fifth": (5, "th"), + "twelfth": (12, "th"), + **{ + name + ("h" if name.endswith("t") else "th"): (value, "th") + for name, value in self.ones.items() + if value > 3 and value != 5 and value != 12 + }, + } + self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal} + + self.tens = { + "twenty": 20, + "thirty": 30, + "forty": 40, + "fifty": 50, + "sixty": 60, + "seventy": 70, + "eighty": 80, + "ninety": 90, + } + self.tens_plural = { + name.replace("y", "ies"): (value, "s") for name, value in self.tens.items() + } + self.tens_ordinal = { + name.replace("y", "ieth"): (value, "th") + for name, value in self.tens.items() + } + self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal} + + self.multipliers = { + "hundred": 100, + "thousand": 1_000, + "million": 1_000_000, + "billion": 1_000_000_000, + "trillion": 1_000_000_000_000, + "quadrillion": 1_000_000_000_000_000, + "quintillion": 1_000_000_000_000_000_000, + "sextillion": 1_000_000_000_000_000_000_000, + "septillion": 1_000_000_000_000_000_000_000_000, + "octillion": 1_000_000_000_000_000_000_000_000_000, + "nonillion": 1_000_000_000_000_000_000_000_000_000_000, + "decillion": 1_000_000_000_000_000_000_000_000_000_000_000, + } + self.multipliers_plural = { + name + "s": (value, "s") for name, value in self.multipliers.items() + } + self.multipliers_ordinal = { + name + "th": (value, "th") for name, value in self.multipliers.items() + } + self.multipliers_suffixed = { + **self.multipliers_plural, + **self.multipliers_ordinal, + } + self.decimals = {*self.ones, *self.tens, *self.zeros} + + self.preceding_prefixers = { + "minus": "-", + "negative": "-", + "plus": "+", + "positive": "+", + } + self.following_prefixers = { + "pound": "£", + "pounds": "£", + "euro": "€", + "euros": "€", + "dollar": "$", + "dollars": "$", + "cent": "¢", + "cents": "¢", + } + self.prefixes = set( + list(self.preceding_prefixers.values()) + + list(self.following_prefixers.values()) + ) + self.suffixers = { + "per": {"cent": "%"}, + "percent": "%", + } + self.specials = {"and", "double", "triple", "point"} + + self.words = set( + [ + key + for mapping in [ + self.zeros, + self.ones, + self.ones_suffixed, + self.tens, + self.tens_suffixed, + self.multipliers, + self.multipliers_suffixed, + self.preceding_prefixers, + self.following_prefixers, + self.suffixers, + self.specials, + ] + for key in mapping + ] + ) + self.literal_words = {"one", "ones"} + + def process_words(self, words: List[str]) -> Iterator[str]: + prefix: Optional[str] = None + value: Optional[Union[str, int]] = None + skip = False + + def to_fraction(s: str): + try: + return Fraction(s) + except ValueError: + return None + + def output(result: Union[str, int]): + nonlocal prefix, value + result = str(result) + if prefix is not None: + result = prefix + result + value = None + prefix = None + return result + + if len(words) == 0: + return + + for prev, current, next in windowed([None] + words + [None], 3): + if skip: + skip = False + continue + + next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next) + has_prefix = current[0] in self.prefixes + current_without_prefix = current[1:] if has_prefix else current + if re.match(r"^\d+(\.\d+)?$", current_without_prefix): + # arabic numbers (potentially with signs and fractions) + f = to_fraction(current_without_prefix) + assert f is not None + if value is not None: + if isinstance(value, str) and value.endswith("."): + # concatenate decimals / ip address components + value = str(value) + str(current) + continue + else: + yield output(value) + + prefix = current[0] if has_prefix else prefix + if f.denominator == 1: + value = f.numerator # store integers as int + else: + value = current_without_prefix + elif current not in self.words: + # non-numeric words + if value is not None: + yield output(value) + yield output(current) + elif current in self.zeros: + value = str(value or "") + "0" + elif current in self.ones: + ones = self.ones[current] + + if value is None: + value = ones + elif isinstance(value, str) or prev in self.ones: + if ( + prev in self.tens and ones < 10 + ): # replace the last zero with the digit + assert value[-1] == "0" + value = value[:-1] + str(ones) + else: + value = str(value) + str(ones) + elif ones < 10: + if value % 10 == 0: + value += ones + else: + value = str(value) + str(ones) + else: # eleven to nineteen + if value % 100 == 0: + value += ones + else: + value = str(value) + str(ones) + elif current in self.ones_suffixed: + # ordinal or cardinal; yield the number right away + ones, suffix = self.ones_suffixed[current] + if value is None: + yield output(str(ones) + suffix) + elif isinstance(value, str) or prev in self.ones: + if prev in self.tens and ones < 10: + assert value[-1] == "0" + yield output(value[:-1] + str(ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + elif ones < 10: + if value % 10 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + else: # eleven to nineteen + if value % 100 == 0: + yield output(str(value + ones) + suffix) + else: + yield output(str(value) + str(ones) + suffix) + value = None + elif current in self.tens: + tens = self.tens[current] + if value is None: + value = tens + elif isinstance(value, str): + value = str(value) + str(tens) + else: + if value % 100 == 0: + value += tens + else: + value = str(value) + str(tens) + elif current in self.tens_suffixed: + # ordinal or cardinal; yield the number right away + tens, suffix = self.tens_suffixed[current] + if value is None: + yield output(str(tens) + suffix) + elif isinstance(value, str): + yield output(str(value) + str(tens) + suffix) + else: + if value % 100 == 0: + yield output(str(value + tens) + suffix) + else: + yield output(str(value) + str(tens) + suffix) + elif current in self.multipliers: + multiplier = self.multipliers[current] + if value is None: + value = multiplier + elif isinstance(value, str) or value == 0: + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + value = p.numerator + else: + yield output(value) + value = multiplier + else: + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + elif current in self.multipliers_suffixed: + multiplier, suffix = self.multipliers_suffixed[current] + if value is None: + yield output(str(multiplier) + suffix) + elif isinstance(value, str): + f = to_fraction(value) + p = f * multiplier if f is not None else None + if f is not None and p.denominator == 1: + yield output(str(p.numerator) + suffix) + else: + yield output(value) + yield output(str(multiplier) + suffix) + else: # int + before = value // 1000 * 1000 + residual = value % 1000 + value = before + residual * multiplier + yield output(str(value) + suffix) + value = None + elif current in self.preceding_prefixers: + # apply prefix (positive, minus, etc.) if it precedes a number + if value is not None: + yield output(value) + + if next in self.words or next_is_numeric: + prefix = self.preceding_prefixers[current] + else: + yield output(current) + elif current in self.following_prefixers: + # apply prefix (dollars, cents, etc.) only after a number + if value is not None: + prefix = self.following_prefixers[current] + yield output(value) + else: + yield output(current) + elif current in self.suffixers: + # apply suffix symbols (percent -> '%') + if value is not None: + suffix = self.suffixers[current] + if isinstance(suffix, dict): + if next in suffix: + yield output(str(value) + suffix[next]) + skip = True + else: + yield output(value) + yield output(current) + else: + yield output(str(value) + suffix) + else: + yield output(current) + elif current in self.specials: + if next not in self.words and not next_is_numeric: + # apply special handling only if the next word can be numeric + if value is not None: + yield output(value) + yield output(current) + elif current == "and": + # ignore "and" after hundreds, thousands, etc. + if prev not in self.multipliers: + if value is not None: + yield output(value) + yield output(current) + elif current == "double" or current == "triple": + if next in self.ones or next in self.zeros: + repeats = 2 if current == "double" else 3 + ones = self.ones.get(next, 0) + value = str(value or "") + str(ones) * repeats + skip = True + else: + if value is not None: + yield output(value) + yield output(current) + elif current == "point": + if next in self.decimals or next_is_numeric: + value = str(value or "") + "." + else: + # should all have been covered at this point + raise ValueError(f"Unexpected token: {current}") + else: + # all should have been covered at this point + raise ValueError(f"Unexpected token: {current}") + + if value is not None: + yield output(value) + + def preprocess(self, s: str): + # replace "<number> and a half" with "<number> point five" + results = [] + + segments = re.split(r"\band\s+a\s+half\b", s) + for i, segment in enumerate(segments): + if len(segment.strip()) == 0: + continue + if i == len(segments) - 1: + results.append(segment) + else: + results.append(segment) + last_word = segment.rsplit(maxsplit=2)[-1] + if last_word in self.decimals or last_word in self.multipliers: + results.append("point five") + else: + results.append("and a half") + + s = " ".join(results) + + # put a space at number/letter boundary + s = re.sub(r"([a-z])([0-9])", r"\1 \2", s) + s = re.sub(r"([0-9])([a-z])", r"\1 \2", s) + + # but remove spaces which could be a suffix + s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s) + + return s + + def postprocess(self, s: str): + def combine_cents(m: Match): + try: + currency = m.group(1) + integer = m.group(2) + cents = int(m.group(3)) + return f"{currency}{integer}.{cents:02d}" + except ValueError: + return m.string + + def extract_cents(m: Match): + try: + return f"¢{int(m.group(1))}" + except ValueError: + return m.string + + # apply currency postprocessing; "$2 and ¢7" -> "$2.07" + s = re.sub(r"([€£$])([0-9]+) (?:and )?¢([0-9]{1,2})\b", combine_cents, s) + s = re.sub(r"[€£$]0.([0-9]{1,2})\b", extract_cents, s) + + # write "one(s)" instead of "1(s)", just for the readability + s = re.sub(r"\b1(s?)\b", r"one\1", s) + + return s + + def __call__(self, s: str): + s = self.preprocess(s) + s = " ".join(word for word in self.process_words(s.split()) if word is not None) + s = self.postprocess(s) + + return s + + +class EnglishSpellingNormalizer: + """ + Applies British-American spelling mappings as listed in [1]. + + [1] https://www.tysto.com/uk-us-spelling-list.html + """ + + def __init__(self): + mapping_path = os.path.join(os.path.dirname(__file__), "english.json") + self.mapping = json.load(open(mapping_path)) + + def __call__(self, s: str): + return " ".join(self.mapping.get(word, word) for word in s.split()) + + +class EnglishTextNormalizer: + def __init__(self): + self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b" + self.replacers = { + # common contractions + r"\bwon't\b": "will not", + r"\bcan't\b": "can not", + r"\blet's\b": "let us", + r"\bain't\b": "aint", + r"\by'all\b": "you all", + r"\bwanna\b": "want to", + r"\bgotta\b": "got to", + r"\bgonna\b": "going to", + r"\bi'ma\b": "i am going to", + r"\bimma\b": "i am going to", + r"\bwoulda\b": "would have", + r"\bcoulda\b": "could have", + r"\bshoulda\b": "should have", + r"\bma'am\b": "madam", + # contractions in titles/prefixes + r"\bmr\b": "mister ", + r"\bmrs\b": "missus ", + r"\bst\b": "saint ", + r"\bdr\b": "doctor ", + r"\bprof\b": "professor ", + r"\bcapt\b": "captain ", + r"\bgov\b": "governor ", + r"\bald\b": "alderman ", + r"\bgen\b": "general ", + r"\bsen\b": "senator ", + r"\brep\b": "representative ", + r"\bpres\b": "president ", + r"\brev\b": "reverend ", + r"\bhon\b": "honorable ", + r"\basst\b": "assistant ", + r"\bassoc\b": "associate ", + r"\blt\b": "lieutenant ", + r"\bcol\b": "colonel ", + r"\bjr\b": "junior ", + r"\bsr\b": "senior ", + r"\besq\b": "esquire ", + # prefect tenses, ideally it should be any past participles, but it's harder.. + r"'d been\b": " had been", + r"'s been\b": " has been", + r"'d gone\b": " had gone", + r"'s gone\b": " has gone", + r"'d done\b": " had done", # "'s done" is ambiguous + r"'s got\b": " has got", + # general contractions + r"n't\b": " not", + r"'re\b": " are", + r"'s\b": " is", + r"'d\b": " would", + r"'ll\b": " will", + r"'t\b": " not", + r"'ve\b": " have", + r"'m\b": " am", + } + self.standardize_numbers = EnglishNumberNormalizer() + self.standardize_spellings = EnglishSpellingNormalizer() + + def __call__(self, s: str): + s = s.lower() + + s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets + s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis + s = re.sub(self.ignore_patterns, "", s) + s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe + + for pattern, replacement in self.replacers.items(): + s = re.sub(pattern, replacement, s) + + s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits + s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers + s = remove_symbols_and_diacritics(s, keep=".%$¢€£") # keep numeric symbols + + s = self.standardize_numbers(s) + s = self.standardize_spellings(s) + + # now remove prefix/suffix symbols that are not preceded/followed by numbers + s = re.sub(r"[.$¢€£]([^0-9])", r" \1", s) + s = re.sub(r"([^0-9])%", r"\1 ", s) + + s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space + + return s diff --git a/tests/parakeet-expected-diffusion-output.txt b/tests/parakeet-expected-diffusion-output.txt new file mode 100644 index 00000000000..9753a86953a --- /dev/null +++ b/tests/parakeet-expected-diffusion-output.txt @@ -0,0 +1 @@ +Hello and welcome to Diffusion. Sit back and relax while we stretch your brain with weird and wonderful science. I'm Ian Wolf. On this edition, Dr. Viv Robinson rewrites cosmology. But first up, here's news of two massive galaxies that might be older than the Big Bang. Galaxies too massive. Astronomers from the Swinburne University of Technology in Melbourne, using the James Webb Space Telescope, have observed six galaxies that formed in the universe's first 700 million years appear to be up to a hundred times more massive than our best theories say can possibly exist. Astronomer Ivo Labe and his colleagues wrote in his paper, adding up the stars in those galaxies, it would exceed the total amount of mass available in the universe at that time. There's too much mass and not enough time for it to get together. The galaxies must have had much longer than the 700 million years after the Big Bang that our standard model of the universe gives them, and the universe must have had more mass available, or galaxies must have formed differently than what we think. The Big Bang is currently thought to have started everything 13.77 billion years ago. And these galaxies, we're watching them at 0.77 billion years ago because they're so far away. Galaxies are thought to accumulate gas moved together by giant clumps of dark matter in their region. Generally, only about 10% of the gas in the galaxy ignites to make a star. For galaxies in the remotest parts of the universe where the gas is thin, it takes a long time to accumulate this much gas for this many stars. These six galaxies, however, have so many stars adding up to so much mass that all of the gas in each galaxy had to have become 100% converted into stars in the 700 million years since the universe started in the Big Bang. Under our current understanding, this is impossible. It suggests something in our understanding of the cosmos is wrong. Are we wrong about how to calculate astronomical masses, galaxy formation, dark matter, and the Big Bang and the age of the universe? An astronomer from the Cosmic Dawn Centre in Denmark used the James Webb telescope to look at closer galaxies, and then used the very high resolution of that telescope to calculate the mass more precisely with a different method, and found that these galaxies are three to ten times more massive than we previously thought. Applying this more accurate technique to the six galaxies that are 13 billion light years away would increase their mass, which makes it much worse than what we thought. The paper was titled A Population of Red Candidate Massive Galaxies, approximately 600 million years after the Big Bang, and was published in the journal Nature.com. We're brought to you across Australia on the Community Radio Network and podcast over the internet on www.diffusionradio.com Challenging Physics Newton said everything is either a particle or a wave. Faraday and Maxwell added fields. Einstein added space-time. Quantum physics says everything is made of quanta, which have the properties of both waves and particles, but is neither. Quantum mechanics has no explanation for gravity, and relativity doesn't account for the quantum world. There's a contradiction between our most basic explanations of the universe. Dr. Viv Robinson was the first person to create a physical explanation of Einstein's gravity in a paper published in the Journal of Physics Communications. He's made corrections to people's extensions of Einstein's mathematics and has a different way to interpret those mathematics that gives a different picture of the age of the universe and a different way of looking at how the physics works. From the standard model of quantum physics to Big Bang cosmology. Everything, including you and me, is made of light. It's a very big and very bold claim. I spoke to Dr. Viv Robinson via Zoom and began by asking him, what is the universe made of? The whole stuff of the universe, or entity. I won't call it items because one of them is absolutely nothing. The first thing to all the mass and all the energy is made up of photons. They're little packets of electromagnetic energy, postulated by Maxwell and Planck and proven by Einstein. They come in many different sizes, shapes, and which make that they make up all the mass and energy of the universe. The volume is made up by empty space, absolutely nothing. But it's the properties of the space that are important. And it does this through two of its properties, electric permittivity and magnetic permeability. And it's those properties which then transmit all of the fields. So that's really all it is. They're just the only two stars in a call because the photos are physical things, and space is just the absence of everything, but its property, its properties are what is important about it. And that's a little bit different to what you might hear from a quantum physics class where they talk about space being full of virtual particles coming into and out of existence so that it's not totally empty, or sometimes they say it's full of fields. The fields of every force is in there and things are coming up all the time. So if you go very fast, you'll interact with the fields, all the virtual particles, and you'll get radiation. Yes, well, uh the unfortunate part is that physics is doing exceedingly well under Newtonian mechanics and exceedingly well under Maxwell's mechanics. But as things get smaller and smaller, you get to a stage where things aren't continuous. I mean, Newton's work will anything that's continuous, but eventually you get to the stage where you know a droplet of water is fine, it has surface tension, evaporates, and you're left with one molecule of water. That doesn't behave the same as bulk water. Into that molecule you go hydrogen atoms and oxygen atoms, they behave nothing like water. And then you get, well, they're made of protons, neutrons, electrons, and they have completely different properties from bulk water. So quantum mechanics, things get quantized, and you get the smallest quantity you can get, and that has very, very different properties from the bulk. And what has happened in the past is that uh the uh early on in quantum mechanics and met men like Dirac and Schrdinger, they didn't know what an the structure was an electron was. Also, all they had to know, they knew it was it had wave properties. And so all they did was they attributed it to a way a wave property to it. Now, waves have the advantage over particles, you can manipulate them almost forever with all sorts of different transforms until you get the answer you want. And that gave some confidence to quantum mechanics guys that yes, waves work, and they've been using that forever, and all I'm saying, no, no, no, no, no. Everything is particles, and the particles have specific properties, and you can't manipulate those properties, or you can to a certain extent, but they are what they are, and it's when you know what those properties are that the whole quantum mechanics becomes much simpler. You don't need any of that uh foamy sort of stuff to get to explain whatever you want to explain. I mentioned that there are many different forms of photons, and photons are electromagnetic radiation with an electric field, saying on a magnetic field perpendicular to it, and the whole lot travels in the speed of light in the third dimension. There are many, many variations of that. So that that's fine for energy radiation. But how about matter particles? Well, matter particles are nothing more than photons of the appropriate wavelength making uh appropriate energy making two revolutions per wavelength. And when they do that, what holds what allows them to do that is that they rotate around the magnetic field. And suddenly, instead of in a linear photon, magnetic fields are open. When they rotate around the magnetic field, then the magnetic field of a particle is closed. And a closed magnetic field is much more stable than an open magnetic field, and that's why most of the universe, for example, when uh less about, I think the best estimate I've seen, one percent is radiation, the other 99% is photons struggling in circles, making two revolutions per wavelength. And it's for that that gives particles all their properties. Now, I may say this is a bit hairy-fairy, but it's been known for a long, long time that you get a particle and an antiparticle, you put them together, bing, two photons. At the same time, you can get a photon and goes and hit the target, bang, a particle and an antiparticle. Now that shows a relationship between the two that somehow lots of people missed. But what's the simplest relationship you can have? The simplest relationship is that a particle is a photon making two revolutions in one direction, an antiparticle is the same particle making two revolutions in the other direction. Put them together, they unlock. Because they have mass, they have this thing called angular momentum, which is a great Newtonian property. But because mathematicians sort of didn't know what an electron was, they called it a point particle. You can't have angular momentum with a point particle, so they call it spin and they wave all sorts of different things to make it seem as if they know what they're talking about. It's really just angular momentum. And that's the relationship between mass and energy. Energy is the photon zipping along at the speed of light. Mass is the same photon making two revolutions per wavelength. That's how they can interchange so easily. And that property gives particles all of their properties, including mass. And one of the things that Einstein did work out in 1905, those little what they called uh packets of radio of electromagnetic energy, he did work out that they carried momentum or carried inertia, they had momentum, they had mass. I don't know why people want to prove Einstein wrong. Photons have mass. Now I think the reason for this is that they think oh, Einstein's special relativity corrections, anything traveling at the speed of light, will have an infinite mass. The special relativity corrections only apply to photons which are spiraling. And that's just as um the reason for that is about as complicated as uh post Thagoras' theorem. And what he was at 300 BC or something like that, not difficult. And so photons themselves always travel at the speed of light. And so the rotating photons, photons that are rotating, are rotating also at the same speed of light. Well, that's one old hell of a gyroscope. And that is what gives particles a spin, that's why E equals mc squared, and it's all straightforward. There you go. Really? Well. So if we go back a little bit there where you're saying there's no wave nature, what about the double-slit experiment and other sorts of experiments that seem to show wave properties of particles other than photons? Particles um De Royal worked out in 1925 that if if photons, if um photons behave like particles, and particles to behave like photons, I agree with him, it's completely it's completely true. The actual nature of the rotating photon generates the de Broilie wavelength, and it has all the right properties. For me, and to me, Einstein's special and general relativity theories are relatively simple, so it may I may be talking a little bit out of line here. But the deuil wavelength is automatically generated by the particle as it moves. So it's not something that they hypothesize and don't know what occurs. They they hypothesized it, they measured it, but they don't know how it occurs. Well, yeah, it's quite it's fairly straightforward, but not at uh not not not at this level. What are the implications for this difference in understanding? So are there predictions that you would make that are different to the ones that people following the standard model would make? Oh, not the numbers of them, yeah. So probably the electron tunneling. Where electrons hit a barrier. That's got a very simple mechanical analog. I mean, the electrons are held in uh what you call a very taut field. Now, if you've got something coming up, you've got everything in a tight situation, you come something up banging it at this end, you can do it with billiard balls that'll transport through, and another one will knock out. So, what they call tunneling under this model, but in reality, what they call tunneling is just really a momentum exchange. So that's a little bit like one of those Newton cradles. Where you've got the balls on all attached by a string or a chain to a fulcrum over the top, and one will hit the other one and transfer the momentum to the other one without actually transferring itself. Yeah, you don't get electrons, you know, they have they have wave properties, but yes, but you won't get an electron uh tunneling the wave, the wave is in a very fixed position with respect to the uh electron. It's equal on either side of it. If their tunneling theory were correct, then the lower the energy of the electron, the longer its wavelength, therefore the easier it would be to tunnel. However, in the energy transfer one, the higher the energy, the greater probability it'll knock another electron out the other side. Or it's a simple experiment to do. Just increase the energy of uh an electron coming up to a barrier and see which ones go come out the other end first. Is anyone set up to do that? Oh, anyone could set up to do it. Well, a lot of laboratories could do it. And the so-called tunneling effect is what they use in all of the microelectronics systems. And they wouldn't, it wouldn't, it'd be a very, very simple exercise to carry that out. They may well have done it, and the mathematicians have turned around and added another factor. Yeah, it's a standard thing they do when they don't get the right answer, just add another factor. I can't do that. It's physical reality is physical reality. End of story. I guess that's something to look up and see if someone's done those experiments and and what they did with the results. I think there is I think I'm sure it has been done, and the result is that the higher the energy of the electron, the greater the probability of it emerging on the other side of the barrier. And on the very much bigger scale, are there differences in the way the universe looks for astronomy? Yeah, not as far as astronomy is concerned. What the astronomers see is what there is. No question about it. They're great, they're brilliant, as the astronomers, and most of the experimentalists are they're doing an exceedingly good job. The problem becomes in interpreting what they've seen. And when it comes to the whole universe, for example, it's all based on Einstein's theory of gravity. Well, it should be, but it's more advanced than Newton's inverse square, but for most practical purposes, uh Newton's inverse square works quite well. The two situations where it doesn't work, when the mass is so large, like the mass of the sun or the mass of the center of uh Sagittarius A with the planet or star S2 going around it. That's one situation. The reason why a planet uh or Mercury's orbit precesses in its direction of travel is simply that gravity, when mass is strong enough, gravity actually becomes weaker than inverse square. And that's one of the things you get when you solve Einstein's gravity theory accurately. It becomes weaker than inverse square. Now, when it's weak, if it's weaker than inverse square, Mercury travels a little bit closer to the Sun and is attracted by a slightly stronger force. So it'll arrive back at its perihelion point a little later, and it it'll um process in its direction of travel. And Newton pointed that out in 1687. So I don't know why they didn't sort of work it out correctly. But gravity is weaker than inverse square, is the solution to Einstein's gravity. The other thing is that when gravity is an infinite steady state universe under Newton's theory of gravity, inverse square, will collapse. The reason being that the relative to the universe density mass increases as r cubed, gravity decreases as r squared, so eventually you get to the stage where gravity just uh dominates mass and it collapses. But if gravity is weaker than inverse square, and I just tried to show you that Mercury is precessing orbit because the sun's gravity is weaker than inverse square, well, that applies to all gravity. There's nothing special about our sun, except that it's keeping all us alive on this. When you have an infinite steady-state universe, if gravity is weaker than inverse squares, its effect gets relatively weaker over long distances. And I'm talking typically uh 10 billion light years or something like that, maybe more. But that means an infinite steady-state universe won't collapse. That's a huge, huge difference. That's the biggest thing, mind you, what difference does it make to us here on Earth if uh if Bang's web has seen galaxies, fully formed galaxies 20 billion light years away, doesn't make a scrap of difference to us. But as far as understanding how the universe works, that mistake, and the simple the simple mistake that they the um all mathematicians were uh made, Einstein introduced approximations. He couldn't solve the gravity exactly himself. I have no problem solving his uh his gravity exactly. But he he uh introduced the approximation that one over one plus x approximately equals one minus x. You know, when x is ten to the minus seven or which is or ten to the minus eight, that's a good approximation. I mean you you just read his paper, he says so. And you read the mathematics, you don't even you could read the German version, look at the mathematics, and he says so, and you just work it out, and that was the difference. So, all of their exact solutions to Einstein's gravity, they took where he used the approximation, he derived the figure from one plus one over x, the equivalent of that, and then he rather than do that, he equated it to one minus x, which is which is true. You know, one plus one millionth is nine hundred and ninety millionth. Why they did it, I have no idea. Mind you, it'd be interesting to try and find out why. Uh I think it's if a mathematician of repute says one thing, and I I I will agree that uh on my first readings of Einstein's relativity theories, you think, oh my god, really? Could he understand that then? Then you get in and you start. It's not that difficult. And I think most of them had a solution. You know, somebody came up with a solution to Einstein's group, and everybody just followed it. And nobody, and this is the big thing that I always stress to everybody, don't take somebody's word for it. Go back and check the original yourself. I've seen a few times where people have just made terrible, terrible mistakes. But this would probably be the biggest one in the whole field of cosmology, sorry. Astronomy? You guys, great. Thanks, Uncle Sam, for providing us with all this information. That was part one of my interview with Dr. Viv Robinson. You heard Viv say that matter is made of photons moving in circles. Physicists took Einstein's approximations as gospel instead of using the exact solutions available with lather mathematics. Gravity changes to be weaker over distances, and the universe isn't expanding. Listen next week for part two. If you have any questions for Dr. Robinson, he'd love to answer them on the show. So send your questions to science at diffusionradio.com. If you're in Darlinghurst this Wednesday night, the 5th of July, I will be part of the lineup of scientists speaking at Future Science Talks at the East Village. Go to www.futurescience talks.com.au to grab a ticket and come up and say hello. And if you can't make it Wednesday night, I'll keep you posted on some future talks I'll be giving. And that's all from us this week on Diffusion. Are you a scientist, artist, biohacker, or maker who'd like to be interviewed about your work? Would your company like to sponsor diffusion? Send your contributions, opinions, helpful suggestions and donations to science at diffusionradio.com. That's science at diffusionradio.com. Please subscribe to the Diffusion Science Radio channel on youtube.com slash C slash Diffusion Radio and rate the show on iTunes and tell your friends. Follow me on Twitter at IanWorf. The news music was Rhinos Theme by Kevin McLeod of Incompitech.com. I produce diffusion, which is broadcast around Australia, to 28 stations on the community radio network, including Radio Blue Mountains 89.1 FM in New South Wales, 8CCC in Alice Springs and Tennant Creek, 2 MVR in Nambucker Valley, 3 MVR in the Malleigh Border Districts of Victoria and South Australia, City Park Radio 7LTN in Launcest and Tasmania, and 2XFM in Canberra. Diffusion is narrowcast on Indigo FM88 in Northeast Victoria. Diffusion is syndicated globally on astronomy.fm. Subscribe to the podcast on the diffusion website www.diffusionradio.com. That's www.diffusionradio.com and check the website for links, photos, and videos about this week's show. If you enjoyed the show, you can explore more than a thousand previous episodes archived on diffusionradio.com where the shows are labelled by keywords so you can focus in on the stories you want to hear. Make a donation through PayPal.me slash Ian Worf. Or join my patrons at patreon.com slash Diffusion Radio. I'm Ian Worf. Join us inside your audio device of choice for more science wondering next week on Diffusion Science Radio. Science is fun. It helps you to learn, to know, and to appreciate. When you study science, you make fun feel. diff --git a/tests/parakeet-expected-gb1-output.txt b/tests/parakeet-expected-gb1-output.txt new file mode 100644 index 00000000000..312ed1ce048 --- /dev/null +++ b/tests/parakeet-expected-gb1-output.txt @@ -0,0 +1 @@ +My fellow Americans, this day has brought terrible news and great sadness to our country. At nine o'clock this morning, mission control in Houston lost contact with our space shuttle Columbia. A short time later, debris was seen falling from the skies above Texas. The Columbia's lost. There are no survivors. On board was a crew of seven. Colonel Rick Husband, Lieutenant Colonel Michael Anderson, Commander Laurel Clark, Captain David Brown, Commander William McCool, Dr. Kulpna Shavla, and Ilan Ramon, a colonel in the Israeli Air Force. These men and women assumed great risk in the service to all humanity. In an age when space flight has come to seem almost routine. It is easy to overlook the dangers of travel by rocket and the difficulties of navigating the fierce outer atmosphere of the earth. These astronauts knew the dangers, and they faced them willingly, knowing they had a high and noble purpose in life. Because of their courage and daring and idealism, we will miss them all the more. And those you loved will always have the respect and gratitude of this country. The cause in which they died will continue. Mankind is led into the darkness beyond our world by the inspiration of discovery and the longing to understand. Our journey into space will go on. In the skies today, we saw destruction and tragedy. Yet farther than we can see, there is comfort and hope. In the words of the prophet Isaiah, lift your eyes and look to the heavens. Who created all these? He who brings out the starry hosts one by one and calls them each by name. Because of his great power and mighty strength, not one of them is missing. The same creator who names the stars also knows the names of the seven souls we mourn today. The crew of the shuttle Columbia did not return safely to Earth. Yet we can pray that all are safely home. May God bless the grieving families and make out may God continue to bless America. diff --git a/tests/parakeet-expected-jfk-output.txt b/tests/parakeet-expected-jfk-output.txt new file mode 100644 index 00000000000..ece35697ae8 --- /dev/null +++ b/tests/parakeet-expected-jfk-output.txt @@ -0,0 +1 @@ +And so, my fellow Americans, ask not what your country can do for you, ask what you can do for your country. diff --git a/tests/parakeet-verification.h b/tests/parakeet-verification.h new file mode 100644 index 00000000000..0e95610ba26 --- /dev/null +++ b/tests/parakeet-verification.h @@ -0,0 +1,110 @@ +#pragma once + +#include <algorithm> +#include <cassert> +#include <cctype> +#include <cstdio> +#include <fstream> +#include <iterator> +#include <string> +#include <vector> + +#ifndef TRANSCRIPTION_SIMILARITY_THRESHOLD +#define TRANSCRIPTION_SIMILARITY_THRESHOLD 1.0 +#endif + +static std::string read_expected_transcription(const char * path) { + std::ifstream fin(path); + assert(fin.is_open()); + + std::string text( + (std::istreambuf_iterator<char>(fin)), + std::istreambuf_iterator<char>()); + + while (!text.empty() && (text.back() == '\n' || text.back() == '\r')) { + text.pop_back(); + } + + return text; +} + +static std::vector<std::string> transcription_words(const std::string & text) { + std::vector<std::string> words; + std::string word; + + for (unsigned char ch : text) { + if (std::isalnum(ch)) { + word.push_back((char) std::tolower(ch)); + } else if (!word.empty()) { + words.push_back(word); + word.clear(); + } + } + + if (!word.empty()) { + words.push_back(word); + } + + return words; +} + +static double transcription_lcs_similarity(const std::string & expected, const std::string & actual) { + const std::vector<std::string> expected_words = transcription_words(expected); + const std::vector<std::string> actual_words = transcription_words(actual); + + if (expected_words.empty() && actual_words.empty()) { + return 1.0; + } + + if (expected_words.empty() || actual_words.empty()) { + return 0.0; + } + + std::vector<int> prev(actual_words.size() + 1, 0); + std::vector<int> cur (actual_words.size() + 1, 0); + + for (size_t i = 0; i < expected_words.size(); ++i) { + std::fill(cur.begin(), cur.end(), 0); + + for (size_t j = 0; j < actual_words.size(); ++j) { + if (expected_words[i] == actual_words[j]) { + cur[j + 1] = prev[j] + 1; + } else { + cur[j + 1] = std::max(prev[j + 1], cur[j]); + } + } + + prev.swap(cur); + } + + const int lcs = prev[actual_words.size()]; + return (2.0 * lcs) / (expected_words.size() + actual_words.size()); +} + +static bool verify_transcription(const std::string & expected, const std::string & actual) { + const double threshold = TRANSCRIPTION_SIMILARITY_THRESHOLD; + + if (threshold >= 1.0) { + if (actual == expected) { + return true; + } + + fprintf(stderr, "\n\n"); + fprintf(stderr, "[Failed] Transcript mismatched\n"); + fprintf(stderr, "expected:\n%s\n\n", expected.c_str()); + fprintf(stderr, "actual:\n%s\n", actual.c_str()); + return false; + } + + const double similarity = transcription_lcs_similarity(expected, actual); + printf("\nTranscript similarity: %.6f (threshold %.6f)\n", similarity, threshold); + + if (similarity >= threshold) { + return true; + } + + fprintf(stderr, "\n\nTranscript similarity below threshold: %.6f < %.6f\n", similarity, threshold); + fprintf(stderr, "Expected:\n%s\n\n", expected.c_str()); + fprintf(stderr, "Actual:\n%s\n", actual.c_str()); + return false; +} diff --git a/tests/run-tests.sh b/tests/run-tests.sh index ad2b8d3ec09..bc28314a704 100755 --- a/tests/run-tests.sh +++ b/tests/run-tests.sh @@ -21,13 +21,21 @@ cd `dirname $0` # Whisper models models=( "tiny.en" "tiny" "base.en" "base" "small.en" "small" "medium.en" "medium" "large-v1" "large-v2" "large-v3" "large-v3-turbo" ) +# Parakeet model variants +parakeet_models=( "f16" "f32" "q2_k" "q4_0" "q4_k" "q8_0" ) + # list available models function list_models { printf "\n" - printf " Available models:" + printf " Available whisper models:" for model in "${models[@]}"; do printf " $model" done + printf "\n" + printf " Available parakeet models:" + for model in "${parakeet_models[@]}"; do + printf " parakeet-$model" + done printf "\n\n" } @@ -39,15 +47,37 @@ if [ $# -eq 0 ]; then fi model=$1 -main="../build/bin/whisper-cli" threads="" if [ $# -eq 2 ]; then threads="-t $2" fi -if [ ! -f ../models/ggml-$model.bin ]; then - printf "Model $model not found. Aborting\n" +# Detect parakeet model (prefix "parakeet-" or a bare variant like "f32") +is_parakeet=0 +parakeet_variant="" +if [[ $model == parakeet-* ]]; then + is_parakeet=1 + parakeet_variant="${model#parakeet-}" +fi +for v in "${parakeet_models[@]}"; do + if [[ $model == "$v" ]]; then + is_parakeet=1 + parakeet_variant="$v" + break + fi +done + +if [ $is_parakeet -eq 1 ]; then + main="../build/bin/parakeet-cli" + model_path="../models/ggml-parakeet-tdt-0.6b-v3-${parakeet_variant}.bin" +else + main="../build/bin/whisper-cli" + model_path="../models/ggml-${model}.bin" +fi + +if [ ! -f $model_path ]; then + printf "Model $model not found ($model_path). Aborting\n" list_models exit 1 fi @@ -110,7 +140,11 @@ function run_lang() { fi fi - $main -m ../models/ggml-$model.bin $threads -f $fname_dst -l $lang -otxt 2> /dev/null + if [ $is_parakeet -eq 1 ]; then + $main -m $model_path $threads -f $fname_dst -otxt 2> /dev/null + else + $main -m $model_path $threads -f $fname_dst -l $lang -otxt 2> /dev/null + fi git diff --no-index --word-diff=color --word-diff-regex=. $lang-$i-ref.txt $fname_dst.txt @@ -120,7 +154,7 @@ function run_lang() { run_lang "en" "${urls_en[@]}" -if [[ $model != *.en* ]]; then +if [ $is_parakeet -eq 0 ] && [[ $model != *.en* ]]; then run_lang "es" "${urls_es[@]}" run_lang "it" "${urls_it[@]}" run_lang "pt" "${urls_pt[@]}" diff --git a/tests/test-common-utf8.cpp b/tests/test-common-utf8.cpp new file mode 100644 index 00000000000..91c73a7428d --- /dev/null +++ b/tests/test-common-utf8.cpp @@ -0,0 +1,34 @@ +#include "common-whisper.h" + +#include <cstdlib> +#include <cstdio> +#include <string> + +static void expect_needed(const std::string & input, int expected) { + const int actual = utf8_trailing_bytes_needed(input); + if (actual != expected) { + fprintf(stderr, "expected %d trailing UTF-8 bytes, got %d\n", expected, actual); + std::abort(); + } +} + +int main() { + expect_needed("", 0); + expect_needed("plain ascii", 0); + + const std::string cjk = "\xE4\xBD\xA0"; // U+4F60 + expect_needed(cjk.substr(0, 1), 2); + expect_needed(cjk.substr(0, 2), 1); + expect_needed(cjk, 0); + + const std::string emoji = "\xF0\x9F\x98\x80"; // U+1F600 + expect_needed(emoji.substr(0, 1), 3); + expect_needed(emoji.substr(0, 2), 2); + expect_needed(emoji.substr(0, 3), 1); + expect_needed(emoji, 0); + + expect_needed("\x80\x80", 0); + expect_needed("\xFF", 0); + + return 0; +} diff --git a/tests/test-parakeet-full.cpp b/tests/test-parakeet-full.cpp new file mode 100644 index 00000000000..22ac4c20e31 --- /dev/null +++ b/tests/test-parakeet-full.cpp @@ -0,0 +1,101 @@ +#include "parakeet.h" +#include "common-whisper.h" +#include "parakeet-verification.h" + +#include <cstdio> +#include <string> + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include <cassert> + +struct test_state { + bool is_first = true; + std::string transcript; +}; + +void progress_callback(parakeet_context * ctx, parakeet_state * state, int progress, void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; +} + +bool encoder_begin_callback(parakeet_context * ctx, parakeet_state * state, void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; + return true; +} + +bool abort_callback(void * user_data) { + bool * called = static_cast<bool *>(user_data); + *called = true; + return false; // just continue without aborting. +} + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + test_state * tstate = static_cast<test_state *>(user_data); + + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, tstate->is_first, text_buf, sizeof(text_buf)); + + printf("%s", text_buf); + fflush(stdout); + + tstate->transcript += text_buf; + tstate->is_first = false; +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); // no stereo vector + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + test_state tstate; + params.new_token_callback = token_callback; + params.new_token_callback_user_data = &tstate; + bool progress_callback_called = false; + params.progress_callback = progress_callback; + params.progress_callback_user_data = &progress_callback_called; + bool encoder_begin_callback_called = false; + params.encoder_begin_callback = encoder_begin_callback; + params.encoder_begin_callback_user_data = &encoder_begin_callback_called; + bool abort_callback_called = false; + params.abort_callback = abort_callback; + params.abort_callback_user_data = &abort_callback_called; + + int ret = parakeet_full(pctx, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + assert(progress_callback_called); + assert(encoder_begin_callback_called); + assert(abort_callback_called); + + const std::string expected = read_expected_transcription(EXPECTED_TRANSCRIPTION_PATH); + const bool transcript_matches = verify_transcription(expected, tstate.transcript); + + parakeet_free(pctx); + + if (!transcript_matches) { + return 1; + } + + printf("\nTest passed: parakeet_full succeeded!\n"); + return 0; +} diff --git a/tests/test-parakeet.cpp b/tests/test-parakeet.cpp new file mode 100644 index 00000000000..83237c600ac --- /dev/null +++ b/tests/test-parakeet.cpp @@ -0,0 +1,99 @@ +#include "parakeet.h" +#include "common-whisper.h" + +#include <cstdio> +#include <string> + +#ifdef NDEBUG +#undef NDEBUG +#endif +#include <cassert> + +void token_callback(parakeet_context * ctx, parakeet_state * state, const parakeet_token_data * token_data, void * user_data) { + static bool is_first = true; + const char * token_str = parakeet_token_to_str(ctx, token_data->id); + char text_buf[256]; + parakeet_token_to_text(token_str, is_first, text_buf, sizeof(text_buf)); + + int32_t time_ms = token_data->frame_index * 10; + + printf("%s", text_buf); + fflush(stdout); + + is_first = false; +} + +void segment_callback(parakeet_context * ctx, parakeet_state * state, int n_new, void * user_data) { + const int n_segments = parakeet_full_n_segments_from_state(state); + const int s0 = n_segments - n_new; + + printf("\nSegment Callback: %d new segment(s)\n", n_new); + + for (int i = s0; i < n_segments; i++) { + const char * text = parakeet_full_get_segment_text_from_state(state, i); + const int64_t t0 = parakeet_full_get_segment_t0_from_state(state, i); + const int64_t t1 = parakeet_full_get_segment_t1_from_state(state, i); + + printf("Segment %d: [%lld -> %lld] \"%s\"\n", i, (long long)t0, (long long)t1, text); + printf("Tokens:\n"); + + const int n_tokens = parakeet_full_n_tokens_from_state(state, i); + for (int j = 0; j < n_tokens; j++) { + parakeet_token_data token_data = parakeet_full_get_token_data_from_state(state, i, j); + const char * token_str = parakeet_token_to_str(ctx, token_data.id); + + printf(" [%2d] id=%5d frame=%3d dur_idx=%2d dur_val=%2d p=%.4f plog=%.4f t0=%4lld t1=%4lld word_start=%d \"%s\"\n", + j, + token_data.id, + token_data.frame_index, + token_data.duration_idx, + token_data.duration_value, + token_data.p, + token_data.plog, + (long long)token_data.t0, + (long long)token_data.t1, + token_data.is_word_start, + token_str); + } + } + printf("\n"); +} + +int main() { + std::string model_path = PARAKEET_MODEL_PATH; + std::string sample_path = SAMPLE_PATH; + + // Load the sample audio file + std::vector<float> pcmf32; + std::vector<std::vector<float>> pcmf32s; + assert(read_audio_data(sample_path.c_str(), pcmf32, pcmf32s, false)); + assert(pcmf32.size() > 0); + assert(pcmf32s.size() == 0); + + printf("Loading Parakeet model from: %s\n", model_path.c_str()); + + struct parakeet_context_params ctx_params = parakeet_context_default_params(); + + struct parakeet_context * pctx = parakeet_init_from_file_with_params_no_state(model_path.c_str(), ctx_params); + if (pctx == nullptr) { + fprintf(stderr, "Failed to load Parakeet model\n"); + return 1; + } + printf("Successfully loaded Parakeet model\n"); + + struct parakeet_full_params params = parakeet_full_default_params(PARAKEET_SAMPLING_GREEDY); + params.new_token_callback = token_callback; + params.new_token_callback_user_data = nullptr; + params.new_segment_callback = segment_callback; + params.new_segment_callback_user_data = nullptr; + parakeet_state * state = parakeet_init_state(pctx); + + int ret = parakeet_chunk(pctx, state, params, pcmf32.data(), pcmf32.size()); + assert(ret == 0); + + parakeet_free_state(state); + parakeet_free(pctx); + + printf("\nTest passed: Parakeet model loaded and freed successfully\n"); + return 0; +}